use std::path::Path;
use std::time::Instant;
use clap::Parser;
#[cfg(feature = "wgpu")]
mod backend {
    //! Backend selection when the `wgpu` cargo feature is enabled.
    use burn::backend::wgpu::WgpuDevice;

    pub use burn::backend::Wgpu as B;

    /// Device the encoder runs on: whatever `WgpuDevice::default()` selects.
    pub fn device() -> WgpuDevice {
        WgpuDevice::default()
    }
}
#[cfg(not(feature = "wgpu"))]
mod backend {
    //! Fallback backend selection used when the `wgpu` cargo feature is disabled.
    use burn::backend::ndarray::NdArrayDevice;

    pub use burn::backend::NdArray as B;

    /// Device the encoder runs on: the ndarray backend's `Cpu` device.
    pub fn device() -> NdArrayDevice {
        NdArrayDevice::Cpu
    }
}
use backend::B;
// Command-line options. Field `///` docs below become clap `--help` text; this note is a
// plain comment on purpose so the explicit `about` string stays the only struct-level help.
#[derive(Parser, Debug)]
#[command(about = "OSF — PSG embedding extraction")]
struct Args {
    /// Path to the OSF model weights file
    #[arg(long)]
    weights: String,
    /// Optional JSON model config; built-in defaults are used when omitted
    #[arg(long)]
    config: Option<String>,
    /// Output safetensors file (parent directories are created if missing)
    #[arg(long, default_value = "data/embeddings.safetensors")]
    output: String,
    /// Print extra diagnostics (first CLS values of the first epoch)
    #[arg(long, short = 'v')]
    verbose: bool,
}
/// Generates a deterministic synthetic multi-channel signal for smoke-testing the encoder.
///
/// Channel `ch` is a sine of frequency `1.0 + 0.5 * ch` (in cycles per unit time, where the
/// sample index is converted to time at 64 samples per unit) with amplitude `50e-6`, plus
/// uniform xorshift32 noise in `[-5e-6, 5e-6]` seeded per channel.
///
/// The result is channel-major: sample `t` of channel `ch` lives at `ch * n_samples + t`.
/// Identical arguments always produce an identical signal. Zero channels or zero samples
/// yield an empty vector.
fn generate_synthetic_psg(n_channels: usize, n_samples: usize) -> Vec<f32> {
    // Divisor converting a sample index into time for the sine argument.
    const SAMPLE_RATE: f32 = 64.0;
    const SINE_AMPLITUDE: f32 = 50e-6;
    // Full peak-to-peak width of the uniform noise (centred on zero below).
    const NOISE_SPAN: f32 = 10e-6;

    let mut signal = Vec::with_capacity(n_channels * n_samples);
    for ch in 0..n_channels {
        let freq = 1.0 + ch as f32 * 0.5;
        // Per-channel xorshift32 seed. Plain `*` overflows u32 from ch = 1 onward
        // (2 * 0xDEAD_BEEF > u32::MAX) and panics in debug builds; wrapping arithmetic is
        // exactly what release builds already computed. The multiplier is odd, so the seed
        // is nonzero unless (ch + 1) is a multiple of 2^32, and xorshift never sticks at 0.
        let mut state: u32 = (ch as u32).wrapping_add(1).wrapping_mul(0xDEAD_BEEF);
        signal.extend((0..n_samples).map(|t| {
            let time = t as f32 / SAMPLE_RATE;
            let sine = (2.0 * std::f32::consts::PI * freq * time).sin() * SINE_AMPLITUDE;
            state ^= state << 13;
            state ^= state >> 17;
            state ^= state << 5;
            let noise = (state as f32 / u32::MAX as f32 - 0.5) * NOISE_SPAN;
            sine + noise
        }));
    }
    signal
}
fn main() -> anyhow::Result<()> {
let args = Args::parse();
let t0 = Instant::now();
let device = backend::device();
println!("╔══════════════════════════════════════════════════════════════╗");
println!("║ OSF — PSG Embedding Extraction ║");
println!("╚══════════════════════════════════════════════════════════════╝\n");
let model_cfg = if let Some(ref cfg_path) = args.config {
let cfg_str = std::fs::read_to_string(cfg_path)?;
serde_json::from_str(&cfg_str)?
} else {
osf_rs::ModelConfig::default()
};
println!("▸ Loading OSF model …");
let (encoder, ms_load) = osf_rs::OsfEncoder::<B>::load_with_config(
model_cfg,
Path::new(&args.weights),
device.clone(),
)?;
println!(" {} ({ms_load:.0} ms)\n", encoder.describe());
let n_channels = osf_rs::NUM_PSG_CHANNELS;
let n_samples = osf_rs::EPOCH_SAMPLES;
let n_epochs = 3;
println!("▸ Generating {} synthetic PSG epochs ({} ch × {} samples each)\n",
n_epochs, n_channels, n_samples);
let mut all_outputs = Vec::new();
for epoch_idx in 0..n_epochs {
let signal = generate_synthetic_psg(n_channels, n_samples);
let batch = osf_rs::build_batch::<B>(signal, n_channels, n_samples, &device);
let t = Instant::now();
let result = encoder.run_batch(&batch)?;
let ms = t.elapsed().as_secs_f64() * 1000.0;
let cls = &result.cls_emb;
let mean: f64 = cls.iter().map(|&v| v as f64).sum::<f64>() / cls.len() as f64;
let std: f64 = (cls.iter().map(|&v| { let d = v as f64 - mean; d * d }).sum::<f64>()
/ cls.len() as f64).sqrt();
println!(" Epoch {epoch_idx}: cls=[{}] patches=[{},{}] mean={mean:+.4} std={std:.4} {ms:.1}ms",
result.embed_dim, result.num_patches, result.embed_dim);
if args.verbose && epoch_idx == 0 {
println!(" First 5 CLS values: {:?}", &cls[..5.min(cls.len())]);
}
all_outputs.push(result);
}
let encoding = osf_rs::EncodingResult {
epochs: all_outputs,
ms_load,
ms_encode: t0.elapsed().as_secs_f64() * 1000.0,
};
if let Some(p) = Path::new(&args.output).parent() {
std::fs::create_dir_all(p)?;
}
encoding.save_safetensors(&args.output)?;
println!("\n▸ Saved {} epochs → {}", n_epochs, args.output);
let ms_total = t0.elapsed().as_secs_f64() * 1000.0;
println!("\n Total: {ms_total:.0} ms");
Ok(())
}