use std::io::{self, BufRead, BufWriter, Write};
use std::path::Path;
/// Writes embeddings as tab-separated text with a word2vec-style header.
///
/// The output starts with a `"{count} {dim}"` line (space-separated), followed by one
/// line per embedding: the name, then each component formatted with `f32`'s `Display`,
/// all separated by tabs. Nothing at all is written when `vecs` is empty.
///
/// # Errors
///
/// Returns an [`io::ErrorKind::InvalidInput`] error, before any byte is written, if a
/// vector's length differs from the first vector's length, or if a name contains a tab
/// or a newline. Either would produce a file whose rows cannot be split back into the
/// original name and values. Errors from `writer` are propagated.
///
/// # Panics
///
/// Panics if `names` and `vecs` have different lengths.
pub fn write_w2v_tsv(
    writer: &mut impl Write,
    names: &[String],
    vecs: &[Vec<f32>],
) -> io::Result<()> {
    assert_eq!(names.len(), vecs.len(), "names and vecs must match");
    if vecs.is_empty() {
        return Ok(());
    }
    let dim = vecs[0].len();
    let invalid = |msg: String| io::Error::new(io::ErrorKind::InvalidInput, msg);

    // Validate everything up front so a bad input cannot leave a partially
    // written file behind.
    for (name, vec) in names.iter().zip(vecs) {
        // Tabs separate fields and newlines separate rows; either one inside a
        // name would shift the columns when the file is parsed again.
        if name.contains('\t') || name.contains('\n') {
            return Err(invalid(format!("name {name:?} contains a tab or newline")));
        }
        // The header records a single `dim`, so every row must agree with it.
        if vec.len() != dim {
            return Err(invalid(format!(
                "expected {dim} values for '{name}', got {}",
                vec.len()
            )));
        }
    }

    let mut w = BufWriter::new(writer);
    writeln!(w, "{} {dim}", names.len())?;
    for (name, vec) in names.iter().zip(vecs) {
        write!(w, "{name}")?;
        for v in vec {
            write!(w, "\t{v}")?;
        }
        writeln!(w)?;
    }
    w.flush()
}
/// Reads embeddings in the tab-separated, word2vec-style format produced by
/// [`write_w2v_tsv`].
///
/// The first line is a whitespace-separated `count dim` header. Every following
/// non-empty line is `name\tv1\t...\tv{dim}`. Blank lines are skipped. An input with
/// no lines at all yields empty vectors, matching what [`write_w2v_tsv`] writes for an
/// empty embedding set.
///
/// # Errors
///
/// Returns an [`io::ErrorKind::InvalidData`] error in any of these cases:
/// - the header is not exactly two unsigned integers;
/// - a value is not a valid `f32`;
/// - a row does not have exactly `dim` values;
/// - the number of rows differs from the header's `count`.
///
/// Read errors from `reader`, including invalid UTF-8, are propagated.
pub fn read_w2v_tsv(reader: impl io::Read) -> io::Result<(Vec<String>, Vec<Vec<f32>>)> {
    // `count` comes from untrusted input: cap the up-front reservation so a corrupt
    // header cannot request an enormous allocation. The vectors still grow normally
    // if the file really contains more rows.
    const MAX_PREALLOC: usize = 1 << 16;

    let invalid = |msg: String| io::Error::new(io::ErrorKind::InvalidData, msg);

    let mut lines = io::BufReader::new(reader).lines();
    let header = match lines.next() {
        Some(line) => line?,
        // `write_w2v_tsv` writes nothing for an empty embedding set, so an empty
        // input is a valid, empty table rather than an error.
        None => return Ok((Vec::new(), Vec::new())),
    };
    let parts: Vec<&str> = header.split_whitespace().collect();
    if parts.len() != 2 {
        return Err(invalid(format!(
            "expected 'count dim' header, got: {header}"
        )));
    }
    let count: usize = parts[0]
        .parse()
        .map_err(|e| invalid(format!("bad count: {e}")))?;
    let dim: usize = parts[1]
        .parse()
        .map_err(|e| invalid(format!("bad dim: {e}")))?;

    let mut names = Vec::with_capacity(count.min(MAX_PREALLOC));
    let mut vecs = Vec::with_capacity(count.min(MAX_PREALLOC));
    for line in lines {
        let line = line?;
        if line.is_empty() {
            continue;
        }
        let mut fields = line.split('\t');
        // `split` always yields at least one item, so the name field is always present.
        let name = fields.next().unwrap_or_default().to_string();
        let vec: Vec<f32> = fields
            .map(|s| {
                s.parse::<f32>()
                    .map_err(|e| invalid(format!("bad float: {e}")))
            })
            .collect::<io::Result<_>>()?;
        if vec.len() != dim {
            return Err(invalid(format!(
                "expected {dim} values for '{name}', got {}",
                vec.len()
            )));
        }
        names.push(name);
        vecs.push(vec);
    }
    // A mismatch here means the file was truncated or the header is wrong.
    if names.len() != count {
        return Err(invalid(format!(
            "header declares {count} entries but the file contains {}",
            names.len()
        )));
    }
    Ok((names, vecs))
}
/// Writes every component of `vecs`, row after row, as raw little-endian `f32` bytes.
///
/// No header, separator, or shape information is emitted. The output is exactly four
/// bytes per component, so a reader must know the dimensions from elsewhere.
///
/// # Errors
///
/// Propagates any error from `writer`.
pub fn write_binary(writer: &mut impl Write, vecs: &[Vec<f32>]) -> io::Result<()> {
    let mut out = BufWriter::new(writer);
    vecs.iter()
        .flat_map(|row| row.iter())
        .try_for_each(|value| out.write_all(&value.to_le_bytes()))?;
    out.flush()
}
/// Writes a vocabulary listing with one `id<TAB>name` line per entry. The `id` is the
/// name's zero-based position in `names`.
///
/// # Errors
///
/// Propagates any error from `writer`.
pub fn write_vocab_tsv(writer: &mut impl Write, names: &[String]) -> io::Result<()> {
    let mut out = BufWriter::new(writer);
    names
        .iter()
        .enumerate()
        .try_for_each(|(id, name)| writeln!(out, "{id}\t{name}"))?;
    out.flush()
}
pub fn export_embeddings(
dir: &Path,
entity_names: &[String],
entity_vecs: &[Vec<f32>],
relation_names: &[String],
relation_vecs: &[Vec<f32>],
) -> io::Result<()> {
std::fs::create_dir_all(dir)?;
let mut ent_file = std::fs::File::create(dir.join("entities.tsv"))?;
write_w2v_tsv(&mut ent_file, entity_names, entity_vecs)?;
let mut rel_file = std::fs::File::create(dir.join("relations.tsv"))?;
write_w2v_tsv(&mut rel_file, relation_names, relation_vecs)?;
Ok(())
}
/// Opens the file at `path` and parses it with [`read_w2v_tsv`].
///
/// # Errors
///
/// Propagates the error from opening the file, or any error reported by
/// [`read_w2v_tsv`].
pub fn import_embeddings(path: &Path) -> io::Result<(Vec<String>, Vec<Vec<f32>>)> {
    read_w2v_tsv(std::fs::File::open(path)?)
}
/// Entity and relation embeddings loaded from a directory by [`load_embeddings`].
///
/// Names and vectors are parallel: `entity_names[i]` labels `entity_vecs[i]`, and
/// likewise for relations.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct LoadedEmbeddings {
    /// Entity names, in file order.
    pub entity_names: Vec<String>,
    /// Entity vectors, parallel to `entity_names`.
    pub entity_vecs: Vec<Vec<f32>>,
    /// Relation names, in file order.
    pub relation_names: Vec<String>,
    /// Relation vectors, parallel to `relation_names`.
    pub relation_vecs: Vec<Vec<f32>>,
}
/// Loads `entities.tsv` and then `relations.tsv` from `dir` using
/// [`import_embeddings`].
///
/// # Errors
///
/// Returns the first error encountered. If the entity file fails, the relation file is
/// not read.
pub fn load_embeddings(dir: &Path) -> io::Result<LoadedEmbeddings> {
    let entities = import_embeddings(&dir.join("entities.tsv"))?;
    let relations = import_embeddings(&dir.join("relations.tsv"))?;
    Ok(LoadedEmbeddings {
        entity_names: entities.0,
        entity_vecs: entities.1,
        relation_names: relations.0,
        relation_vecs: relations.1,
    })
}
/// Concatenates the rows of `vecs`, in order, into one contiguous `Vec<f32>`.
///
/// Rows may have different lengths; each is simply appended after the previous one.
/// An empty input yields an empty vector.
pub fn flatten_matrix(vecs: &[Vec<f32>]) -> Vec<f32> {
    vecs.concat()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts string literals into the owned `Vec<String>` the API expects.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| s.to_string()).collect()
    }

    #[test]
    fn w2v_roundtrip() {
        let names = owned(&["alice", "bob"]);
        let vecs = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
        let mut out: Vec<u8> = Vec::new();
        write_w2v_tsv(&mut out, &names, &vecs).unwrap();

        let (got_names, got_vecs) = read_w2v_tsv(out.as_slice()).unwrap();
        assert_eq!(got_names, names);
        assert_eq!(got_vecs.len(), 2);
        // Compare component-wise with a tolerance instead of exact float equality.
        for (want_row, got_row) in vecs.iter().zip(&got_vecs) {
            assert!(want_row
                .iter()
                .zip(got_row)
                .all(|(want, got)| (want - got).abs() < 1e-5));
        }
    }

    #[test]
    fn w2v_empty() {
        let mut out: Vec<u8> = Vec::new();
        write_w2v_tsv(&mut out, &[], &[]).unwrap();
        assert!(out.is_empty());
    }

    #[test]
    fn binary_write() {
        let vecs = vec![vec![1.0_f32, 2.0], vec![3.0, 4.0]];
        let mut out: Vec<u8> = Vec::new();
        write_binary(&mut out, &vecs).unwrap();
        // Four f32 components, four bytes each.
        assert_eq!(out.len(), 4 * 4);
        let mut head = [0u8; 4];
        head.copy_from_slice(&out[..4]);
        assert!((f32::from_le_bytes(head) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn export_import_roundtrip() {
        let dir = tempfile::tempdir().unwrap();
        let ent_names = owned(&["a", "b"]);
        let ent_vecs = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let rel_names = owned(&["r1"]);
        let rel_vecs = vec![vec![0.5, 0.5]];

        export_embeddings(dir.path(), &ent_names, &ent_vecs, &rel_names, &rel_vecs).unwrap();
        let loaded = load_embeddings(dir.path()).unwrap();

        assert_eq!(loaded.entity_names, ent_names);
        assert_eq!(loaded.relation_names, rel_names);
        assert_eq!(loaded.entity_vecs.len(), 2);
        assert_eq!(loaded.relation_vecs.len(), 1);
    }

    #[test]
    fn flatten_matrix_works() {
        let rows = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        assert_eq!(flatten_matrix(&rows), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn flatten_empty() {
        assert!(flatten_matrix(&[]).is_empty());
    }

    #[test]
    fn w2v_tsv_preserves_precision() {
        use std::f32::consts::{E, PI};

        let names = owned(&["x"]);
        let mut out: Vec<u8> = Vec::new();
        write_w2v_tsv(&mut out, &names, &[vec![PI, E]]).unwrap();

        let (_, got) = read_w2v_tsv(out.as_slice()).unwrap();
        assert!((got[0][0] - PI).abs() < 1e-4);
        assert!((got[0][1] - E).abs() < 1e-4);
    }

    #[test]
    fn read_w2v_bad_header() {
        let bad = b"not_a_number dim\n";
        assert!(read_w2v_tsv(&bad[..]).is_err());
    }

    #[test]
    fn read_w2v_dim_mismatch() {
        let outcome = read_w2v_tsv(&b"1 3\nalice\t1.0\t2.0\n"[..]);
        assert!(outcome.is_err());
        let msg = outcome.unwrap_err().to_string();
        assert!(
            msg.contains("expected 3"),
            "Error should mention expected dim: {msg}"
        );
    }

    #[test]
    fn write_vocab_tsv_roundtrip() {
        let names = owned(&["alice", "bob", "charlie"]);
        let mut out: Vec<u8> = Vec::new();
        write_vocab_tsv(&mut out, &names).unwrap();
        assert_eq!(
            String::from_utf8(out).unwrap(),
            "0\talice\n1\tbob\n2\tcharlie\n"
        );
    }
}