use crate::Result;
use crate::vocab::Vocabulary;
use byteorder::{LittleEndian, WriteBytesExt};
use std::fs::File;
use std::io::{BufWriter, Write};
use std::path::Path;
/// Writes the embeddings in `syn0` to `path` as word2vec-style text.
///
/// The file starts with a header line `"<rows> <vector_size>"`, followed by one
/// line per vocabulary entry: the entry's `word_id`, then its `vector_size`
/// components, all separated by single spaces.
///
/// The vector for an entry is the slice of `syn0` starting at
/// `remapped_id * vector_size`. Entries whose vector would fall outside `syn0`
/// are skipped; the header row count reflects only the rows actually written,
/// so the header and the body always agree. A warning is printed to stderr
/// when any entry is skipped.
///
/// # Errors
///
/// Returns an error if the file cannot be created or any write fails.
pub fn save_word2vec_text<P: AsRef<Path>>(
    path: P,
    syn0: &[f32],
    vocab: &Vocabulary,
    vector_size: usize,
) -> Result<()> {
    let path_ref = path.as_ref();

    // Locates the vector for a remapped id, or `None` if it is not fully
    // contained in `syn0`. Checked arithmetic keeps a huge id from wrapping
    // around into a valid-looking offset.
    let row_of = |remapped: usize| {
        remapped
            .checked_mul(vector_size)
            .and_then(|start| start.checked_add(vector_size).map(|end| start..end))
            .and_then(|range| syn0.get(range))
    };

    // First pass: count the rows that will really be emitted, because the
    // header has to be written before any of them.
    let mut total = 0usize;
    let mut writable = 0usize;
    for info in vocab.iter() {
        total += 1;
        if row_of(info.remapped_id as usize).is_some() {
            writable += 1;
        }
    }
    if writable < total {
        eprintln!(
            "Warning: skipping {} of {} vocabulary entries whose vectors fall outside syn0",
            total - writable,
            total
        );
    }

    let file = File::create(path_ref)?;
    let mut writer = BufWriter::new(file);
    writeln!(writer, "{} {}", writable, vector_size)?;

    // Second pass: emit one line per entry that has a complete vector.
    for info in vocab.iter() {
        let word_id = info.word_id;
        let row = match row_of(info.remapped_id as usize) {
            Some(row) => row,
            None => continue,
        };
        write!(writer, "{}", word_id)?;
        for value in row {
            write!(writer, " {}", value)?;
        }
        writeln!(writer)?;
    }

    // Flush explicitly: dropping a BufWriter swallows write errors.
    writer.flush()?;
    eprintln!("Saved word2vec text format: {:?}", path_ref);
    Ok(())
}
/// Writes the embeddings in `syn0` to `path` in the binary layout this crate
/// calls "MCV1".
///
/// Layout, all integers and floats little-endian:
///
/// * 32-byte header: magic `0x3143564D` (u32), row count `max_word_id + 1`
///   (u32), `vector_size` (u32), a zero u32, then 16 zero bytes.
/// * `max_word_id + 1` rows of `vector_size` `f32` values, indexed by word id.
///
/// For each word id the row is the slice of `syn0` starting at
/// `remapped_id * vector_size`, where `remapped_id` comes from
/// `vocab.get(word_id)`. Word ids that are not in `vocab`, or whose vector
/// would fall outside `syn0`, are written as all-zero rows.
///
/// Progress and size information is printed to stderr.
///
/// # Errors
///
/// Returns an error if `max_word_id + 1` or `vector_size` does not fit in the
/// `u32` header fields, if the file cannot be created, or if any write fails.
pub fn save_mcv1_format<P: AsRef<Path>>(
    path: P,
    syn0: &[f32],
    vocab: &Vocabulary,
    vector_size: usize,
    max_word_id: u32,
) -> Result<()> {
    // Written little-endian this is the byte sequence 4D 56 43 31 (ASCII "MVC1").
    // NOTE(review): the function is named "mcv1" -- confirm the intended magic
    // against the reader before changing anything here.
    const MAGIC: u32 = 0x3143564D;
    // 4 u32 fields + 16 zero bytes.
    const HEADER_BYTES: u64 = 32;

    let invalid_input =
        |msg: &str| std::io::Error::new(std::io::ErrorKind::InvalidInput, msg.to_string());

    // Validate both header fields before touching the file system, so invalid
    // input neither panics/wraps nor leaves a truncated file behind.
    let row_count = max_word_id
        .checked_add(1)
        .ok_or_else(|| invalid_input("max_word_id + 1 does not fit in the u32 header field"))?;
    if (vector_size as u64) > u64::from(u32::MAX) {
        return Err(invalid_input("vector_size does not fit in the u32 header field").into());
    }
    let dims = vector_size as u32;

    let path_ref = path.as_ref();
    let file = File::create(path_ref)?;
    let mut writer = BufWriter::new(file);

    eprintln!("Saving MCV1 format:");
    eprintln!(
        " Vocab size: {} (max_word_id: {})",
        row_count, max_word_id
    );
    eprintln!(" Vector size: {}", vector_size);
    eprintln!(" Trained words: {}", vocab.len());

    // Header.
    writer.write_u32::<LittleEndian>(MAGIC)?;
    writer.write_u32::<LittleEndian>(row_count)?;
    writer.write_u32::<LittleEndian>(dims)?;
    writer.write_u32::<LittleEndian>(0)?;
    writer.write_all(&[0u8; 16])?;

    // Locates the vector for a remapped id, or `None` if it is not fully
    // contained in `syn0` (checked arithmetic, so no wrap-around).
    let row_of = |remapped: usize| {
        remapped
            .checked_mul(vector_size)
            .and_then(|start| start.checked_add(vector_size).map(|end| start..end))
            .and_then(|range| syn0.get(range))
    };

    // Rows are dense over 0..=max_word_id; anything without a usable vector
    // is padded with zeros so row index == word id.
    let zero_row = vec![0.0f32; vector_size];
    for word_id in 0..row_count {
        let row = vocab
            .get(word_id)
            .and_then(|info| row_of(info.remapped_id as usize))
            .unwrap_or(zero_row.as_slice());
        for &value in row {
            writer.write_f32::<LittleEndian>(value)?;
        }
    }

    // Flush explicitly: dropping a BufWriter swallows write errors.
    writer.flush()?;

    // Informational only; saturating so the log line can never overflow.
    let file_size = u64::from(row_count)
        .saturating_mul(u64::from(dims))
        .saturating_mul(4)
        .saturating_add(HEADER_BYTES);
    eprintln!(
        " File size: {} bytes ({} MB)",
        file_size,
        file_size / 1024 / 1024
    );
    eprintln!("Saved MCV1 format: {:?}", path_ref);
    Ok(())
}