1use std::path::Path;
7use std::time::Instant;
8use clap::Parser;
9
/// Backend selection used when the `wgpu` feature is enabled.
///
/// Re-exports burn's `Wgpu` backend under the alias `B` and provides the
/// device handle to run it on.
#[cfg(feature = "wgpu")]
mod backend {
    use burn::backend::wgpu::WgpuDevice;

    pub use burn::backend::Wgpu as B;

    /// Returns the `Default` wgpu device.
    pub fn device() -> WgpuDevice {
        Default::default()
    }
}
18#[cfg(not(feature = "wgpu"))]
19mod backend {
20 pub use burn::backend::NdArray as B;
21 pub fn device() -> burn::backend::ndarray::NdArrayDevice {
22 burn::backend::ndarray::NdArrayDevice::Cpu
23 }
24}
25use backend::B;
26
27#[derive(Parser, Debug)]
28#[command(about = "OSF — PSG embedding extraction")]
29struct Args {
30 #[arg(long)]
31 weights: String,
32 #[arg(long)]
33 config: Option<String>,
34 #[arg(long, default_value = "data/embeddings.safetensors")]
35 output: String,
36 #[arg(long, short = 'v')]
37 verbose: bool,
38}
39
/// Builds a deterministic synthetic multi-channel signal.
///
/// The result is a flat, channel-major buffer of `n_channels * n_samples`
/// values: sample `t` of channel `ch` lives at index `ch * n_samples + t`.
/// Each channel is a sine wave (frequency `1.0 + 0.5 * ch`) plus xorshift32
/// pseudo-random noise seeded from the channel index, so identical arguments
/// always produce an identical buffer.
///
/// Returns an empty vector when either dimension is zero.
fn generate_synthetic_psg(n_channels: usize, n_samples: usize) -> Vec<f32> {
    // Divisor turning a sample index into "time"; presumably a 64 Hz sample
    // rate — confirm against the model's expected input.
    const SAMPLE_RATE: f32 = 64.0;
    // Peak sine amplitude and peak-to-peak noise span (presumably volts,
    // i.e. 50 µV / 10 µV — confirm).
    const SINE_AMPLITUDE: f32 = 50e-6;
    const NOISE_SPAN: f32 = 10e-6;
    const SEED_MULTIPLIER: u32 = 0xDEAD_BEEF;

    let mut signal = vec![0.0f32; n_channels * n_samples];
    if n_samples == 0 {
        // `chunks_exact_mut(0)` would panic; there is nothing to fill anyway.
        return signal;
    }

    for (ch, channel) in signal.chunks_exact_mut(n_samples).enumerate() {
        let freq = 1.0 + ch as f32 * 0.5;
        // Wrapping arithmetic: `(ch + 1) * 0xDEAD_BEEF` exceeds `u32::MAX` for
        // every channel after the first, which panics with plain `*` in
        // builds that have overflow checks enabled. The multiplier is odd, so
        // the seed is non-zero for all realistic channel counts, which
        // xorshift needs to avoid getting stuck at zero.
        let mut noise_state = (ch as u32).wrapping_add(1).wrapping_mul(SEED_MULTIPLIER);

        for (t, sample) in channel.iter_mut().enumerate() {
            let time = t as f32 / SAMPLE_RATE;
            let sine = (2.0 * std::f32::consts::PI * freq * time).sin() * SINE_AMPLITUDE;

            // xorshift32 step.
            noise_state ^= noise_state << 13;
            noise_state ^= noise_state >> 17;
            noise_state ^= noise_state << 5;

            // Map the state to [-0.5, 0.5] and scale it.
            let noise = (noise_state as f32 / u32::MAX as f32 - 0.5) * NOISE_SPAN;
            *sample = sine + noise;
        }
    }
    signal
}
57
58fn main() -> anyhow::Result<()> {
59 let args = Args::parse();
60 let t0 = Instant::now();
61 let device = backend::device();
62
63 println!("╔══════════════════════════════════════════════════════════════╗");
64 println!("║ OSF — PSG Embedding Extraction ║");
65 println!("╚══════════════════════════════════════════════════════════════╝\n");
66
67 let model_cfg = if let Some(ref cfg_path) = args.config {
69 let cfg_str = std::fs::read_to_string(cfg_path)?;
70 serde_json::from_str(&cfg_str)?
71 } else {
72 osf_rs::ModelConfig::default()
73 };
74
75 println!("▸ Loading OSF model …");
76 let (encoder, ms_load) = osf_rs::OsfEncoder::<B>::load_with_config(
77 model_cfg,
78 Path::new(&args.weights),
79 device.clone(),
80 )?;
81 println!(" {} ({ms_load:.0} ms)\n", encoder.describe());
82
83 let n_channels = osf_rs::NUM_PSG_CHANNELS;
85 let n_samples = osf_rs::EPOCH_SAMPLES;
86 let n_epochs = 3;
87
88 println!("▸ Generating {} synthetic PSG epochs ({} ch × {} samples each)\n",
89 n_epochs, n_channels, n_samples);
90
91 let mut all_outputs = Vec::new();
92
93 for epoch_idx in 0..n_epochs {
94 let signal = generate_synthetic_psg(n_channels, n_samples);
95 let batch = osf_rs::build_batch::<B>(signal, n_channels, n_samples, &device);
96
97 let t = Instant::now();
98 let result = encoder.run_batch(&batch)?;
99 let ms = t.elapsed().as_secs_f64() * 1000.0;
100
101 let cls = &result.cls_emb;
103 let mean: f64 = cls.iter().map(|&v| v as f64).sum::<f64>() / cls.len() as f64;
104 let std: f64 = (cls.iter().map(|&v| { let d = v as f64 - mean; d * d }).sum::<f64>()
105 / cls.len() as f64).sqrt();
106
107 println!(" Epoch {epoch_idx}: cls=[{}] patches=[{},{}] mean={mean:+.4} std={std:.4} {ms:.1}ms",
108 result.embed_dim, result.num_patches, result.embed_dim);
109
110 if args.verbose && epoch_idx == 0 {
111 println!(" First 5 CLS values: {:?}", &cls[..5.min(cls.len())]);
112 }
113
114 all_outputs.push(result);
115 }
116
117 let encoding = osf_rs::EncodingResult {
119 epochs: all_outputs,
120 ms_load,
121 ms_encode: t0.elapsed().as_secs_f64() * 1000.0,
122 };
123
124 if let Some(p) = Path::new(&args.output).parent() {
125 std::fs::create_dir_all(p)?;
126 }
127 encoding.save_safetensors(&args.output)?;
128 println!("\n▸ Saved {} epochs → {}", n_epochs, args.output);
129
130 let ms_total = t0.elapsed().as_secs_f64() * 1000.0;
131 println!("\n Total: {ms_total:.0} ms");
132 Ok(())
133}