use brainharmony::prelude::*;
use burn::backend::NdArray;
/// Burn backend used by this binary: the `NdArray` backend (paired with the
/// `NdArrayDevice::Cpu` device created in `main`).
type B = NdArray;
fn main() -> anyhow::Result<()> {
let args: Vec<String> = std::env::args().collect();
if args.len() < 5 {
eprintln!("usage: embed <weights.safetensors> <gradient.csv> <geoh.csv> <input.safetensors> [output.safetensors]");
std::process::exit(1);
}
let device = burn::backend::ndarray::NdArrayDevice::Cpu;
brainharmony::init_threads(None);
let (encoder, ms) = BrainHarmonyEncoder::<B>::from_weights(
&args[1], &args[2], &args[3],
&ModelConfig::default(),
&DataConfig::default(),
&device,
)?;
println!("Loaded in {ms:.0} ms: {}", encoder.describe());
let result = encoder.encode_safetensors(&args[4])?;
println!("Encoded: {} patches x {} dims in {:.1} ms",
result.n_patches(), result.embed_dim(), result.ms_encode);
let out = args.get(5).map(|s| s.as_str()).unwrap_or("embeddings.safetensors");
result.save_safetensors(out)?;
println!("Saved: {out}");
Ok(())
}