use brainharmony::prelude::*;
use burn::backend::NdArray;
type B = NdArray;
/// Command-line entry point: encodes one input signal and prints a two-way
/// classification produced by a freshly constructed head.
///
/// Usage: `classify <weights> <gradient.csv> <geoh.csv> <input.safetensors>`
///
/// Steps, in order:
/// 1. Build the encoder from the weights / gradient / geoh paths.
/// 2. Load the signal from the safetensors file and encode it.
/// 3. Re-wrap the flat embeddings as a rank-3 tensor and run the head.
/// 4. Print the first two logits and the first predicted class.
///
/// The head comes straight from `ClassificationHead::new` and is never loaded
/// from the weights file, which is why the output is labelled as untrained.
///
/// # Errors
///
/// Returns an error if the encoder cannot be built, the input cannot be
/// loaded or encoded, the head's outputs cannot be read back as `f32` / `i64`,
/// or the head yields fewer values than are printed.
///
/// With fewer than four arguments the usage line is written to stderr and the
/// process exits with status 1.
fn main() -> anyhow::Result<()> {
    let args: Vec<String> = std::env::args().collect();
    // args[0] is the program name; extra trailing arguments are tolerated.
    let [_, weights, gradient, geoh, input_path, ..] = args.as_slice() else {
        eprintln!("usage: classify <weights> <gradient.csv> <geoh.csv> <input.safetensors>");
        std::process::exit(1);
    };

    let device = burn::backend::ndarray::NdArrayDevice::Cpu;
    // NOTE(review): `None` presumably selects the crate's default thread count — confirm.
    brainharmony::init_threads(None);

    let cfg = ModelConfig::default();
    let (encoder, _) = BrainHarmonyEncoder::<B>::from_weights(
        weights,
        gradient,
        geoh,
        &cfg,
        &DataConfig::default(),
        &device,
    )?;
    // NOTE(review): the literal 2 is presumably the number of output classes;
    // exactly two logits are printed below.
    let head = ClassificationHead::<B>::new(cfg.embed_dim, 2, &device);

    let input = brainharmony::data::load_signal_safetensors::<B>(input_path, &device)?;
    let enc_out = encoder.encode_tensor(input.data)?;
    let n = enc_out.n_patches();
    let d = enc_out.embed_dim();

    // Rebuild an [n, d] tensor from the flat embeddings, then add a leading
    // axis at index 0 so the head receives a rank-3 tensor.
    let emb_tensor = burn::prelude::Tensor::<B, 2>::from_data(
        burn::prelude::TensorData::new(enc_out.embeddings.clone(), vec![n, d]),
        &device,
    )
    .unsqueeze_dim::<3>(0);

    let logits = head.forward(emb_tensor);
    // `predict_classes` takes the tensor by value, so it gets its own handle.
    let classes = predict_classes(logits.clone());

    // Surface conversion failures as errors instead of panicking.
    let logit_data: Vec<f32> = logits
        .into_data()
        .to_vec::<f32>()
        .map_err(|e| anyhow::anyhow!("failed to read logits as f32: {:?}", e))?;
    let class_data: Vec<i64> = classes
        .into_data()
        .to_vec::<i64>()
        .map_err(|e| anyhow::anyhow!("failed to read predicted classes as i64: {:?}", e))?;

    // Checked access: an unexpectedly short output becomes an error, not a panic.
    let (first_logit, second_logit) = match logit_data.as_slice() {
        [a, b, ..] => (*a, *b),
        _ => anyhow::bail!("expected at least 2 logits, got {}", logit_data.len()),
    };
    let predicted = class_data
        .first()
        .copied()
        .ok_or_else(|| anyhow::anyhow!("classification head returned no predicted class"))?;

    println!("Logits: [{:.4}, {:.4}]", first_logit, second_logit);
    println!("Predicted class: {} (untrained head — random)", predicted);
    Ok(())
}