use crate::{space::VecSpace, vector::Vector};
use std::io::Write;
/// Initial value of the exporter's `write_header` flag (see `Exporter::new`).
pub const DEFAULT_WRITE_HEADER: bool = true;
/// Default character written between a vector's term and its first value.
pub const DEFAULT_TERM_SEP: char = ' ';
/// Default character written between consecutive values of one vector.
pub const DEFAULT_VEC_SEP: char = ' ';
/// Writes vectors (a term followed by its values) to an underlying writer `W`.
///
/// Configured builder-style via the consuming methods on `Exporter<W>`; the
/// actual export methods are available when `W: Write`.
#[derive(Debug, Clone, Copy)]
pub struct Exporter<W> {
/// Whether a header is expected in the output. `export_vectors` panics if
/// this is set and no header has been written yet.
write_header: bool,
/// Character written between a vector's term and its first value.
term_separator: char,
/// Character written between consecutive values of a vector.
vec_separator: char,
/// Selects the binary writers instead of the text writers.
binary: bool,
/// Destination of all output.
writer: W,
/// Set once a header has been emitted through this exporter.
header_written: bool,
}
impl<W> Exporter<W> {
#[inline]
pub fn new(w: W) -> Self {
Self {
write_header: DEFAULT_WRITE_HEADER,
term_separator: DEFAULT_TERM_SEP,
vec_separator: DEFAULT_VEC_SEP,
binary: false,
writer: w,
header_written: false,
}
}
pub fn use_binary(mut self) -> Self {
self.binary = true;
self
}
pub fn no_header(mut self) -> Self {
self.write_header = false;
self
}
}
impl<W: Write> Exporter<W> {
    /// Exports every vector of `space`, preceded by a header unless it was
    /// disabled with `no_header`.
    ///
    /// # Errors
    /// Returns any I/O error reported by the underlying writer.
    ///
    /// Returns the number of bytes written.
    pub fn export_space(self, space: &VecSpace) -> Result<usize, std::io::Error> {
        self.export_space_filtered(space, |_| true)
    }

    /// Exports the vectors of `space` for which `filter` returns `true`.
    ///
    /// When the header is enabled, it advertises the number of vectors that
    /// pass `filter` (not `space.len()`), so the header always matches the
    /// body. To obtain that count the space is iterated twice and `filter` is
    /// therefore invoked twice per vector.
    ///
    /// # Errors
    /// Returns any I/O error reported by the underlying writer.
    ///
    /// Returns the number of bytes written.
    pub fn export_space_filtered<F>(
        mut self,
        space: &VecSpace,
        filter: F,
    ) -> Result<usize, std::io::Error>
    where
        F: Fn(&Vector) -> bool,
    {
        let mut n = 0;
        // Honour `no_header()`: only emit the header when it is requested.
        if self.write_header {
            let len = space.iter().filter(|v| filter(v)).count();
            n += self.write_header(len, space.dim())?;
        }
        n += self.export_vectors(space.iter().filter(|v| filter(v)))?;
        Ok(n)
    }

    /// Writes the given vectors, separated by `\n` (no trailing newline).
    ///
    /// # Panics
    /// Panics if the header is enabled but has not been written yet.
    ///
    /// # Errors
    /// Returns any I/O error reported by the underlying writer.
    ///
    /// Returns the number of bytes written.
    pub fn export_vectors<'a, 'b, I>(&mut self, iter: I) -> Result<usize, std::io::Error>
    where
        I: IntoIterator<Item = Vector<'a, 'b>>,
    {
        if self.write_header && !self.header_written {
            panic!("Expected header to be written");
        }
        let mut n = 0;
        for (pos, vector) in iter.into_iter().enumerate() {
            if pos > 0 {
                n += self.write_bytes(b"\n")?;
            }
            n += self.write_vector(vector)?;
        }
        Ok(n)
    }

    /// Writes all of `bytes` and returns `bytes.len()`.
    ///
    /// Uses `write_all` because a bare `Write::write` is allowed to write
    /// only part of the buffer, which would silently truncate the output.
    fn write_bytes(&mut self, bytes: &[u8]) -> Result<usize, std::io::Error> {
        self.writer.write_all(bytes)?;
        Ok(bytes.len())
    }

    /// Dispatches to the binary or text vector writer.
    fn write_vector(&mut self, vec: Vector) -> Result<usize, std::io::Error> {
        if self.binary {
            self.write_vector_bin(vec)
        } else {
            self.write_vector_txt(vec)
        }
    }

    /// Binary vector output; not implemented yet.
    fn write_vector_bin(&mut self, _vec: Vector) -> Result<usize, std::io::Error> {
        todo!()
    }

    /// Writes `<term><term_separator><v0><vec_separator><v1>...` for one
    /// vector and returns the number of bytes written.
    fn write_vector_txt(&mut self, vec: Vector) -> Result<usize, std::io::Error> {
        // Encode the separators once into stack buffers instead of
        // allocating a `String` for every single write.
        let mut term_sep_buf = [0u8; 4];
        let term_sep: &str = self.term_separator.encode_utf8(&mut term_sep_buf);
        let mut vec_sep_buf = [0u8; 4];
        let vec_sep: &str = self.vec_separator.encode_utf8(&mut vec_sep_buf);

        let mut n = self.write_bytes(vec.term().as_bytes())?;
        n += self.write_bytes(term_sep.as_bytes())?;
        for (pos, value) in vec.data().iter().enumerate() {
            if pos > 0 {
                n += self.write_bytes(vec_sep.as_bytes())?;
            }
            n += self.write_bytes(value.to_string().as_bytes())?;
        }
        Ok(n)
    }

    /// Writes the header (vector count `len`, then dimension `dim`) in the
    /// selected format and records that it has been written.
    fn write_header(&mut self, len: usize, dim: usize) -> Result<usize, std::io::Error> {
        let n = if self.binary {
            self.write_header_bin(len, dim)?
        } else {
            self.write_header_txt(len, dim)?
        };
        // Only mark the header as written once the write actually succeeded.
        self.header_written = true;
        Ok(n)
    }

    /// Binary header output; not implemented yet.
    fn write_header_bin(&mut self, _len: usize, _dim: usize) -> Result<usize, std::io::Error> {
        todo!()
    }

    /// Writes the text header line `"<len> <dim>\n"`.
    fn write_header_txt(&mut self, len: usize, dim: usize) -> Result<usize, std::io::Error> {
        self.write_bytes(format!("{} {}\n", len, dim).as_bytes())
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::parse::Word2VecParser;
    use std::io::Cursor;

    /// Round-trip check: a space exported as text and parsed back with
    /// `Word2VecParser` must compare equal to the original space.
    #[test]
    fn test_txt_export() {
        // NOTE(review): the last two vectors share the term "term3" —
        // confirm whether the second one was meant to be "term2".
        let fixtures = [
            Vector::new(&[1.2, 2.0, 4.4], "term1"),
            Vector::new(&[2.3, 1.0, 3.4], "term3"),
            Vector::new(&[3.1, 9.4, 3.0], "term3"),
        ];
        let mut space = VecSpace::new(3);
        space.extend(fixtures);

        let mut exported = Vec::<u8>::new();
        Exporter::new(&mut exported)
            .export_space(&space)
            .unwrap();

        let reparsed = Word2VecParser::new()
            .parse(Cursor::new(&exported))
            .unwrap();
        assert_eq!(space, reparsed);
    }
}