use std::{path::Path, time::Instant};
use clap::Parser;
/// GPU backend: Burn's wgpu backend on the default device.
///
/// Compiled only when the `wgpu` feature is on and `ndarray` is off, so
/// `ndarray` takes precedence if both backend features are enabled.
#[cfg(all(feature = "wgpu", not(feature = "ndarray")))]
mod backend {
    pub use burn::backend::{Wgpu as B, wgpu::WgpuDevice as Device};

    /// Device every tensor is created on: wgpu's `DefaultDevice`.
    pub fn device() -> Device {
        Device::DefaultDevice
    }

    /// Human-readable backend label printed at startup.
    ///
    /// Defined once via `cfg!` (instead of one `#[cfg]`-gated const per
    /// feature) so that enabling both `metal` and `vulkan` -- e.g. with
    /// `--all-features` -- does not define `NAME` twice. If both are on, the
    /// Metal label is reported.
    /// NOTE(review): which shader compiler Burn actually selects when both
    /// features are enabled is decided inside Burn, not here -- confirm.
    pub const NAME: &str = if cfg!(feature = "metal") {
        "GPU (wgpu — Metal / MSL)"
    } else if cfg!(feature = "vulkan") {
        "GPU (wgpu — Vulkan / SPIR-V)"
    } else {
        "GPU (wgpu — WGSL)"
    };
}
/// CPU backend: Burn's NdArray backend.
///
/// Takes precedence over the wgpu backend whenever the `ndarray` feature is on.
#[cfg(feature = "ndarray")]
mod backend {
    pub use burn::backend::NdArray as B;
    pub type Device = burn::backend::ndarray::NdArrayDevice;

    /// Device every tensor is created on: `NdArrayDevice::Cpu`.
    pub fn device() -> Device {
        Device::Cpu
    }

    /// Human-readable backend label printed at startup.
    ///
    /// Defined once via `cfg!` (instead of one `#[cfg]`-gated const per
    /// feature) so that enabling both `blas-accelerate` and `openblas-system`
    /// -- as `--all-features` does -- no longer defines `NAME` twice. If both
    /// are on, the Accelerate label is reported.
    /// NOTE(review): which BLAS Burn actually links when both features are
    /// enabled is decided by Burn's build, not here -- confirm.
    pub const NAME: &str = if cfg!(feature = "blas-accelerate") {
        "CPU (NdArray + Apple Accelerate)"
    } else if cfg!(feature = "openblas-system") {
        "CPU (NdArray + OpenBLAS)"
    } else {
        "CPU (NdArray + Rayon)"
    };
}
use backend::{B, device};
// Command-line options. Field doc comments below become the `--help` text;
// the program description comes from the explicit `about` attribute.
#[derive(Parser, Debug)]
#[command(about = "OSF Sleep Foundation Model inference (Burn 0.20.1)")]
struct Args {
    /// Path to the model weights file
    #[arg(long)]
    weights: String,
    /// JSON model config; the built-in default config is used when omitted
    #[arg(long)]
    config: Option<String>,
    /// Where to save the encoding result (safetensors); parent directories are created
    #[arg(long)]
    output: Option<String>,
    /// Print mean / standard deviation of the CLS and patch embeddings
    #[arg(long, short = 'v')]
    verbose: bool,
}
fn main() -> anyhow::Result<()> {
let args = Args::parse();
let t0 = Instant::now();
let dev = device();
println!("Backend : {}", backend::NAME);
let model_cfg = if let Some(ref cfg_path) = args.config {
let cfg_str = std::fs::read_to_string(cfg_path)?;
serde_json::from_str(&cfg_str)?
} else {
osf_rs::ModelConfig::default()
};
let (encoder, ms_weights) = osf_rs::OsfEncoder::<B>::load_with_config(
model_cfg,
Path::new(&args.weights),
dev.clone(),
)?;
println!("Model : {} ({ms_weights:.0} ms)", encoder.describe());
let n_channels = osf_rs::NUM_PSG_CHANNELS;
let n_samples = osf_rs::EPOCH_SAMPLES;
let signal = vec![0.0f32; n_channels * n_samples];
let batch = osf_rs::build_batch::<B>(signal, n_channels, n_samples, &dev);
let t_inf = Instant::now();
let result = encoder.run_batch(&batch)?;
let ms_infer = t_inf.elapsed().as_secs_f64() * 1000.0;
println!("Output : cls=[{}] patches=[{}, {}] ({ms_infer:.1} ms)",
result.embed_dim, result.num_patches, result.embed_dim);
if args.verbose {
let mean: f64 = result.cls_emb.iter().map(|&v| v as f64).sum::<f64>()
/ result.cls_emb.len() as f64;
let std: f64 = (result.cls_emb.iter().map(|&v| {
let d = v as f64 - mean; d * d
}).sum::<f64>() / result.cls_emb.len() as f64).sqrt();
println!(" CLS: mean={mean:+.6} std={std:.6}");
let p_mean: f64 = result.patch_embs.iter().map(|&v| v as f64).sum::<f64>()
/ result.patch_embs.len() as f64;
let p_std: f64 = (result.patch_embs.iter().map(|&v| {
let d = v as f64 - p_mean; d * d
}).sum::<f64>() / result.patch_embs.len() as f64).sqrt();
println!(" Patches: mean={p_mean:+.6} std={p_std:.6}");
}
if let Some(ref out_path) = args.output {
let encoding = osf_rs::EncodingResult {
epochs: vec![result],
ms_load: ms_weights,
ms_encode: ms_infer,
};
if let Some(p) = Path::new(out_path).parent() {
std::fs::create_dir_all(p)?;
}
encoding.save_safetensors(out_path)?;
println!("Saved → {out_path}");
}
let ms_total = t0.elapsed().as_secs_f64() * 1000.0;
println!("── Timing ───────────────────────────────────────────────────────");
println!(" Weights : {ms_weights:.0} ms");
println!(" Infer : {ms_infer:.0} ms");
println!(" Total : {ms_total:.0} ms");
Ok(())
}