Skip to main content

embed/
embed.rs

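//! Example: extract embeddings from synthetic PSG data.
//!
//! Usage:
//!   cargo run --example embed --release -- --weights data/osf_backbone.safetensors
//!
//! The CPU `NdArray` backend is used by default; build with `--features wgpu`
//! to run on the WGPU backend instead.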
5
6use std::path::Path;
7use std::time::Instant;
8use clap::Parser;
9
10// ── Backend ───────────────────────────────────────────────────────────────────
11#[cfg(feature = "wgpu")]
12mod backend {
13    pub use burn::backend::Wgpu as B;
14    pub fn device() -> burn::backend::wgpu::WgpuDevice {
15        burn::backend::wgpu::WgpuDevice::default()
16    }
17}
18#[cfg(not(feature = "wgpu"))]
19mod backend {
20    pub use burn::backend::NdArray as B;
21    pub fn device() -> burn::backend::ndarray::NdArrayDevice {
22        burn::backend::ndarray::NdArrayDevice::Cpu
23    }
24}
25use backend::B;
26
27#[derive(Parser, Debug)]
28#[command(about = "OSF — PSG embedding extraction")]
29struct Args {
30    #[arg(long)]
31    weights: String,
32    #[arg(long)]
33    config: Option<String>,
34    #[arg(long, default_value = "data/embeddings.safetensors")]
35    output: String,
36    #[arg(long, short = 'v')]
37    verbose: bool,
38}
39
/// Generate a deterministic synthetic PSG-like signal, stored channel-major.
///
/// Channel `ch` is a sine wave at `1.0 + 0.5 * ch` Hz with amplitude `50e-6`,
/// plus uniform noise in `[-5e-6, 5e-6]` drawn from a per-channel xorshift32
/// generator, sampled at 64 Hz. Sample `t` of channel `ch` lives at index
/// `ch * n_samples + t`; the returned vector has `n_channels * n_samples`
/// elements (empty if either argument is 0).
///
/// The output depends only on `n_channels` and `n_samples`, so repeated calls
/// return identical data.
fn generate_synthetic_psg(n_channels: usize, n_samples: usize) -> Vec<f32> {
    const SAMPLE_RATE_HZ: f32 = 64.0;
    const SINE_AMPLITUDE: f32 = 50e-6;
    // Peak-to-peak noise amplitude; the noise is centred on zero.
    const NOISE_AMPLITUDE: f32 = 10e-6;

    /// One step of Marsaglia's xorshift32 (shift triple 13, 17, 5).
    fn xorshift32(mut x: u32) -> u32 {
        x ^= x << 13;
        x ^= x >> 17;
        x ^= x << 5;
        x
    }

    let mut signal = Vec::with_capacity(n_channels * n_samples);
    for ch in 0..n_channels {
        let freq = 1.0 + ch as f32 * 0.5;
        // The seed product exceeds u32::MAX for every ch >= 1. A plain `*`
        // panics in debug builds, so wrap explicitly; this matches the value
        // release builds already produced.
        let mut noise_state = (ch as u32).wrapping_add(1).wrapping_mul(0xDEAD_BEEF);
        signal.extend((0..n_samples).map(|t| {
            let time = t as f32 / SAMPLE_RATE_HZ;
            let sine = (2.0 * std::f32::consts::PI * freq * time).sin() * SINE_AMPLITUDE;
            noise_state = xorshift32(noise_state);
            let noise = (noise_state as f32 / u32::MAX as f32 - 0.5) * NOISE_AMPLITUDE;
            sine + noise
        }));
    }
    signal
}
57
58fn main() -> anyhow::Result<()> {
59    let args = Args::parse();
60    let t0 = Instant::now();
61    let device = backend::device();
62
63    println!("╔══════════════════════════════════════════════════════════════╗");
64    println!("║  OSF — PSG Embedding Extraction                             ║");
65    println!("╚══════════════════════════════════════════════════════════════╝\n");
66
67    // 1. Load model
68    let model_cfg = if let Some(ref cfg_path) = args.config {
69        let cfg_str = std::fs::read_to_string(cfg_path)?;
70        serde_json::from_str(&cfg_str)?
71    } else {
72        osf_rs::ModelConfig::default()
73    };
74
75    println!("▸ Loading OSF model …");
76    let (encoder, ms_load) = osf_rs::OsfEncoder::<B>::load_with_config(
77        model_cfg,
78        Path::new(&args.weights),
79        device.clone(),
80    )?;
81    println!("  {}  ({ms_load:.0} ms)\n", encoder.describe());
82
83    // 2. Generate synthetic PSG epochs
84    let n_channels = osf_rs::NUM_PSG_CHANNELS;
85    let n_samples = osf_rs::EPOCH_SAMPLES;
86    let n_epochs = 3;
87
88    println!("▸ Generating {} synthetic PSG epochs ({} ch × {} samples each)\n",
89        n_epochs, n_channels, n_samples);
90
91    let mut all_outputs = Vec::new();
92
93    for epoch_idx in 0..n_epochs {
94        let signal = generate_synthetic_psg(n_channels, n_samples);
95        let batch = osf_rs::build_batch::<B>(signal, n_channels, n_samples, &device);
96
97        let t = Instant::now();
98        let result = encoder.run_batch(&batch)?;
99        let ms = t.elapsed().as_secs_f64() * 1000.0;
100
101        // Stats
102        let cls = &result.cls_emb;
103        let mean: f64 = cls.iter().map(|&v| v as f64).sum::<f64>() / cls.len() as f64;
104        let std: f64 = (cls.iter().map(|&v| { let d = v as f64 - mean; d * d }).sum::<f64>()
105            / cls.len() as f64).sqrt();
106
107        println!("  Epoch {epoch_idx}: cls=[{}]  patches=[{},{}]  mean={mean:+.4}  std={std:.4}  {ms:.1}ms",
108            result.embed_dim, result.num_patches, result.embed_dim);
109
110        if args.verbose && epoch_idx == 0 {
111            println!("    First 5 CLS values: {:?}", &cls[..5.min(cls.len())]);
112        }
113
114        all_outputs.push(result);
115    }
116
117    // 3. Save
118    let encoding = osf_rs::EncodingResult {
119        epochs: all_outputs,
120        ms_load,
121        ms_encode: t0.elapsed().as_secs_f64() * 1000.0,
122    };
123
124    if let Some(p) = Path::new(&args.output).parent() {
125        std::fs::create_dir_all(p)?;
126    }
127    encoding.save_safetensors(&args.output)?;
128    println!("\n▸ Saved {} epochs → {}", n_epochs, args.output);
129
130    let ms_total = t0.elapsed().as_secs_f64() * 1000.0;
131    println!("\n  Total: {ms_total:.0} ms");
132    Ok(())
133}