Skip to main content

OsfEncoder

Struct OsfEncoder 

Source
pub struct OsfEncoder<B: Backend> {
    pub model_cfg: ModelConfig,
    /* private fields */
}
Expand description

High-level OSF encoder for PSG signal processing.

Fields

§model_cfg: ModelConfig

Implementations

Source§

impl<B: Backend> OsfEncoder<B>

Source

pub fn load( config_path: &Path, weights_path: &Path, device: B::Device, ) -> Result<(Self, f64)>

Load the model from a JSON config file and a safetensors weights file.

Source

pub fn load_with_config( model_cfg: ModelConfig, weights_path: &Path, device: B::Device, ) -> Result<(Self, f64)>

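Load the model directly from an already-constructed ModelConfig and a safetensors weights path. Returns the encoder together with the elapsed load time in milliseconds.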

Examples found in repository?
examples/embed.rs (lines 76-80)
58fn main() -> anyhow::Result<()> {
59    let args = Args::parse();
60    let t0 = Instant::now();
61    let device = backend::device();
62
63    println!("╔══════════════════════════════════════════════════════════════╗");
64    println!("║  OSF — PSG Embedding Extraction                             ║");
65    println!("╚══════════════════════════════════════════════════════════════╝\n");
66
67    // 1. Load model
68    let model_cfg = if let Some(ref cfg_path) = args.config {
69        let cfg_str = std::fs::read_to_string(cfg_path)?;
70        serde_json::from_str(&cfg_str)?
71    } else {
72        osf_rs::ModelConfig::default()
73    };
74
75    println!("▸ Loading OSF model …");
76    let (encoder, ms_load) = osf_rs::OsfEncoder::<B>::load_with_config(
77        model_cfg,
78        Path::new(&args.weights),
79        device.clone(),
80    )?;
81    println!("  {}  ({ms_load:.0} ms)\n", encoder.describe());
82
83    // 2. Generate synthetic PSG epochs
84    let n_channels = osf_rs::NUM_PSG_CHANNELS;
85    let n_samples = osf_rs::EPOCH_SAMPLES;
86    let n_epochs = 3;
87
88    println!("▸ Generating {} synthetic PSG epochs ({} ch × {} samples each)\n",
89        n_epochs, n_channels, n_samples);
90
91    let mut all_outputs = Vec::new();
92
93    for epoch_idx in 0..n_epochs {
94        let signal = generate_synthetic_psg(n_channels, n_samples);
95        let batch = osf_rs::build_batch::<B>(signal, n_channels, n_samples, &device);
96
97        let t = Instant::now();
98        let result = encoder.run_batch(&batch)?;
99        let ms = t.elapsed().as_secs_f64() * 1000.0;
100
101        // Stats
102        let cls = &result.cls_emb;
103        let mean: f64 = cls.iter().map(|&v| v as f64).sum::<f64>() / cls.len() as f64;
104        let std: f64 = (cls.iter().map(|&v| { let d = v as f64 - mean; d * d }).sum::<f64>()
105            / cls.len() as f64).sqrt();
106
107        println!("  Epoch {epoch_idx}: cls=[{}]  patches=[{},{}]  mean={mean:+.4}  std={std:.4}  {ms:.1}ms",
108            result.embed_dim, result.num_patches, result.embed_dim);
109
110        if args.verbose && epoch_idx == 0 {
111            println!("    First 5 CLS values: {:?}", &cls[..5.min(cls.len())]);
112        }
113
114        all_outputs.push(result);
115    }
116
117    // 3. Save
118    let encoding = osf_rs::EncodingResult {
119        epochs: all_outputs,
120        ms_load,
121        ms_encode: t0.elapsed().as_secs_f64() * 1000.0,
122    };
123
124    if let Some(p) = Path::new(&args.output).parent() {
125        std::fs::create_dir_all(p)?;
126    }
127    encoding.save_safetensors(&args.output)?;
128    println!("\n▸ Saved {} epochs → {}", n_epochs, args.output);
129
130    let ms_total = t0.elapsed().as_secs_f64() * 1000.0;
131    println!("\n  Total: {ms_total:.0} ms");
132    Ok(())
133}
More examples
Hide additional examples
examples/benchmark.rs (lines 116-120)
95fn main() -> anyhow::Result<()> {
96    let args = Args::parse();
97    let dev = device();
98    let json_mode = args.json;
99
100    if !json_mode {
101        eprintln!("╔══════════════════════════════════════════════════════════════╗");
102        eprintln!("║  OSF-RS — Inference Benchmark                               ║");
103        eprintln!("╚══════════════════════════════════════════════════════════════╝\n");
104        eprintln!("  Backend: {}", backend::NAME);
105    }
106
107    // Load config
108    let model_cfg = if let Some(ref cfg_path) = args.config {
109        let s = std::fs::read_to_string(cfg_path)?;
110        serde_json::from_str(&s)?
111    } else {
112        osf_rs::ModelConfig::default()
113    };
114
115    // ── 1. Weight loading benchmark ─────────────────────────────────────────
116    let (encoder, ms_load) = osf_rs::OsfEncoder::<B>::load_with_config(
117        model_cfg.clone(),
118        Path::new(&args.weights),
119        dev.clone(),
120    )?;
121
122    if !json_mode {
123        eprintln!("  Model:   {}", encoder.describe());
124        eprintln!("  Load:    {ms_load:.0} ms\n");
125    }
126
127    let n_ch = osf_rs::NUM_PSG_CHANNELS;   // 12
128    let n_t  = osf_rs::EPOCH_SAMPLES;      // 1920
129
130    // ── 2. Standard inference benchmark (B=1) ───────────────────────────────
131    if !json_mode {
132        eprintln!("  ▸ Standard inference (B=1, {}ch × {} samples)", n_ch, n_t);
133    }
134
135    let signal = generate_psg(n_ch, n_t, 42);
136    let batch = osf_rs::build_batch::<B>(signal, n_ch, n_t, &dev);
137
138    // Warmup
139    for _ in 0..args.warmup {
140        let _ = encoder.run_batch(&batch)?;
141    }
142
143    // Timed runs
144    let mut infer_times = Vec::with_capacity(args.runs);
145    for _ in 0..args.runs {
146        let (_, ms) = timed(|| encoder.run_batch(&batch));
147        infer_times.push(ms);
148    }
149
150    let infer_mean = infer_times.iter().sum::<f64>() / infer_times.len() as f64;
151    let infer_min = infer_times.iter().cloned().fold(f64::INFINITY, f64::min);
152    let infer_max = infer_times.iter().cloned().fold(0.0f64, f64::max);
153    let infer_std = (infer_times.iter().map(|t| (t - infer_mean).powi(2)).sum::<f64>()
154        / infer_times.len() as f64).sqrt();
155
156    if !json_mode {
157        eprintln!("    mean={infer_mean:.1}ms  min={infer_min:.1}ms  max={infer_max:.1}ms  std={infer_std:.1}ms  (n={})",
158            args.runs);
159    }
160
161    // ── 3. Batch size scaling ───────────────────────────────────────────────
162    let batch_sizes = [1, 2, 4, 8, 16];
163    let mut batch_scaling: Vec<serde_json::Value> = Vec::new();
164
165    if !json_mode {
166        eprintln!("\n  ▸ Batch size scaling ({}ch × {} samples):", n_ch, n_t);
167        eprintln!("    {:>6}  {:>10}  {:>12}", "Batch", "Mean (ms)", "Per-epoch (ms)");
168    }
169
170    for &bs in &batch_sizes {
171        // Build a batched input: [bs, 12, 1920]
172        let signal_batch: Vec<f32> = (0..bs).flat_map(|i| generate_psg(n_ch, n_t, 42 + i as u32)).collect();
173        let signal_tensor = burn::tensor::Tensor::<B, 2>::from_data(
174            burn::tensor::TensorData::new(signal_batch, vec![bs * n_ch, n_t]),
175            &dev,
176        ).reshape([bs, n_ch, n_t]);
177
178        // Warmup
179        let _ = encoder.model().forward_encoding(signal_tensor.clone());
180
181        // Timed
182        let mut t_vec = Vec::new();
183        for _ in 0..5.max(args.runs / 2) {
184            let (_, ms) = timed(|| encoder.model().forward_encoding(signal_tensor.clone()));
185            t_vec.push(ms);
186        }
187        let avg = t_vec.iter().sum::<f64>() / t_vec.len() as f64;
188        let per_epoch = avg / bs as f64;
189        let bmin = t_vec.iter().cloned().fold(f64::INFINITY, f64::min);
190        let bmax = t_vec.iter().cloned().fold(0.0f64, f64::max);
191
192        if !json_mode {
193            eprintln!("    {:>6}  {:>7.1} ms  {:>9.1} ms", bs, avg, per_epoch);
194        }
195
196        batch_scaling.push(serde_json::json!({
197            "batch_size": bs,
198            "mean_ms": round2(avg),
199            "min_ms": round2(bmin),
200            "max_ms": round2(bmax),
201            "per_epoch_ms": round2(per_epoch),
202            "runs": t_vec,
203        }));
204    }
205
206    // ── 4. Throughput (epochs/sec) ──────────────────────────────────────────
207    let throughput_epochs_sec = 1000.0 / infer_mean;
208
209    if !json_mode {
210        eprintln!("\n  ▸ Throughput: {throughput_epochs_sec:.1} epochs/sec (single-epoch inference)");
211    }
212
213    // ── 5. Channel subset scaling (robustness test with padding) ────────────
214    let channel_subsets = [2, 4, 6, 8, 10, 12];
215    let mut channel_scaling: Vec<serde_json::Value> = Vec::new();
216
217    if !json_mode {
218        eprintln!("\n  ▸ Channel count scaling (zero-padded to 12ch, T={}):", n_t);
219        eprintln!("    {:>6}  {:>10}", "Active", "Mean (ms)");
220    }
221
222    for &active_ch in &channel_subsets {
223        // Create signal with `active_ch` active channels, rest zero-padded to 12
224        let mut sig = vec![0.0f32; n_ch * n_t];
225        let active = generate_psg(active_ch, n_t, 100 + active_ch as u32);
226        for ch in 0..active_ch {
227            for t in 0..n_t {
228                sig[ch * n_t + t] = active[ch * n_t + t];
229            }
230        }
231        let b = osf_rs::build_batch::<B>(sig, n_ch, n_t, &dev);
232
233        let _ = encoder.run_batch(&b)?; // warmup
234        let mut t_vec = Vec::new();
235        for _ in 0..5 {
236            let (_, ms) = timed(|| encoder.run_batch(&b));
237            t_vec.push(ms);
238        }
239        let avg = t_vec.iter().sum::<f64>() / t_vec.len() as f64;
240        let cmin = t_vec.iter().cloned().fold(f64::INFINITY, f64::min);
241        let cmax = t_vec.iter().cloned().fold(0.0f64, f64::max);
242
243        if !json_mode {
244            eprintln!("    {:>6}  {:>7.1} ms", active_ch, avg);
245        }
246
247        channel_scaling.push(serde_json::json!({
248            "active_channels": active_ch,
249            "total_channels": n_ch,
250            "mean_ms": round2(avg),
251            "min_ms": round2(cmin),
252            "max_ms": round2(cmax),
253            "runs": t_vec,
254        }));
255    }
256
257    if !json_mode { eprintln!(); }
258
259    // ── JSON output ─────────────────────────────────────────────────────────
260    let result = serde_json::json!({
261        "backend": backend::NAME,
262        "model": {
263            "encoder_name": model_cfg.encoder_name,
264            "width": model_cfg.width,
265            "depth": model_cfg.depth,
266            "heads": model_cfg.heads,
267            "num_leads": model_cfg.num_leads,
268            "patch_size_time": model_cfg.patch_size_time,
269            "patch_size_ch": model_cfg.patch_size_ch,
270            "seq_len": model_cfg.seq_len,
271        },
272        "load_ms": round2(ms_load),
273        "inference": {
274            "channels": n_ch,
275            "samples": n_t,
276            "warmup": args.warmup,
277            "runs": args.runs,
278            "mean_ms": round2(infer_mean),
279            "min_ms": round2(infer_min),
280            "max_ms": round2(infer_max),
281            "std_ms": round2(infer_std),
282            "all_ms": infer_times,
283        },
284        "batch_scaling": batch_scaling,
285        "channel_scaling": channel_scaling,
286        "throughput_epochs_sec": round2(throughput_epochs_sec),
287    });
288
289    if json_mode {
290        println!("{}", serde_json::to_string_pretty(&result)?);
291    }
292
293    Ok(())
294}
Source

pub fn describe(&self) -> String

Returns a human-readable description of the model, as printed by the examples right after loading.

Examples found in repository?
examples/embed.rs (line 81)
58fn main() -> anyhow::Result<()> {
59    let args = Args::parse();
60    let t0 = Instant::now();
61    let device = backend::device();
62
63    println!("╔══════════════════════════════════════════════════════════════╗");
64    println!("║  OSF — PSG Embedding Extraction                             ║");
65    println!("╚══════════════════════════════════════════════════════════════╝\n");
66
67    // 1. Load model
68    let model_cfg = if let Some(ref cfg_path) = args.config {
69        let cfg_str = std::fs::read_to_string(cfg_path)?;
70        serde_json::from_str(&cfg_str)?
71    } else {
72        osf_rs::ModelConfig::default()
73    };
74
75    println!("▸ Loading OSF model …");
76    let (encoder, ms_load) = osf_rs::OsfEncoder::<B>::load_with_config(
77        model_cfg,
78        Path::new(&args.weights),
79        device.clone(),
80    )?;
81    println!("  {}  ({ms_load:.0} ms)\n", encoder.describe());
82
83    // 2. Generate synthetic PSG epochs
84    let n_channels = osf_rs::NUM_PSG_CHANNELS;
85    let n_samples = osf_rs::EPOCH_SAMPLES;
86    let n_epochs = 3;
87
88    println!("▸ Generating {} synthetic PSG epochs ({} ch × {} samples each)\n",
89        n_epochs, n_channels, n_samples);
90
91    let mut all_outputs = Vec::new();
92
93    for epoch_idx in 0..n_epochs {
94        let signal = generate_synthetic_psg(n_channels, n_samples);
95        let batch = osf_rs::build_batch::<B>(signal, n_channels, n_samples, &device);
96
97        let t = Instant::now();
98        let result = encoder.run_batch(&batch)?;
99        let ms = t.elapsed().as_secs_f64() * 1000.0;
100
101        // Stats
102        let cls = &result.cls_emb;
103        let mean: f64 = cls.iter().map(|&v| v as f64).sum::<f64>() / cls.len() as f64;
104        let std: f64 = (cls.iter().map(|&v| { let d = v as f64 - mean; d * d }).sum::<f64>()
105            / cls.len() as f64).sqrt();
106
107        println!("  Epoch {epoch_idx}: cls=[{}]  patches=[{},{}]  mean={mean:+.4}  std={std:.4}  {ms:.1}ms",
108            result.embed_dim, result.num_patches, result.embed_dim);
109
110        if args.verbose && epoch_idx == 0 {
111            println!("    First 5 CLS values: {:?}", &cls[..5.min(cls.len())]);
112        }
113
114        all_outputs.push(result);
115    }
116
117    // 3. Save
118    let encoding = osf_rs::EncodingResult {
119        epochs: all_outputs,
120        ms_load,
121        ms_encode: t0.elapsed().as_secs_f64() * 1000.0,
122    };
123
124    if let Some(p) = Path::new(&args.output).parent() {
125        std::fs::create_dir_all(p)?;
126    }
127    encoding.save_safetensors(&args.output)?;
128    println!("\n▸ Saved {} epochs → {}", n_epochs, args.output);
129
130    let ms_total = t0.elapsed().as_secs_f64() * 1000.0;
131    println!("\n  Total: {ms_total:.0} ms");
132    Ok(())
133}
More examples
Hide additional examples
examples/benchmark.rs (line 123)
95fn main() -> anyhow::Result<()> {
96    let args = Args::parse();
97    let dev = device();
98    let json_mode = args.json;
99
100    if !json_mode {
101        eprintln!("╔══════════════════════════════════════════════════════════════╗");
102        eprintln!("║  OSF-RS — Inference Benchmark                               ║");
103        eprintln!("╚══════════════════════════════════════════════════════════════╝\n");
104        eprintln!("  Backend: {}", backend::NAME);
105    }
106
107    // Load config
108    let model_cfg = if let Some(ref cfg_path) = args.config {
109        let s = std::fs::read_to_string(cfg_path)?;
110        serde_json::from_str(&s)?
111    } else {
112        osf_rs::ModelConfig::default()
113    };
114
115    // ── 1. Weight loading benchmark ─────────────────────────────────────────
116    let (encoder, ms_load) = osf_rs::OsfEncoder::<B>::load_with_config(
117        model_cfg.clone(),
118        Path::new(&args.weights),
119        dev.clone(),
120    )?;
121
122    if !json_mode {
123        eprintln!("  Model:   {}", encoder.describe());
124        eprintln!("  Load:    {ms_load:.0} ms\n");
125    }
126
127    let n_ch = osf_rs::NUM_PSG_CHANNELS;   // 12
128    let n_t  = osf_rs::EPOCH_SAMPLES;      // 1920
129
130    // ── 2. Standard inference benchmark (B=1) ───────────────────────────────
131    if !json_mode {
132        eprintln!("  ▸ Standard inference (B=1, {}ch × {} samples)", n_ch, n_t);
133    }
134
135    let signal = generate_psg(n_ch, n_t, 42);
136    let batch = osf_rs::build_batch::<B>(signal, n_ch, n_t, &dev);
137
138    // Warmup
139    for _ in 0..args.warmup {
140        let _ = encoder.run_batch(&batch)?;
141    }
142
143    // Timed runs
144    let mut infer_times = Vec::with_capacity(args.runs);
145    for _ in 0..args.runs {
146        let (_, ms) = timed(|| encoder.run_batch(&batch));
147        infer_times.push(ms);
148    }
149
150    let infer_mean = infer_times.iter().sum::<f64>() / infer_times.len() as f64;
151    let infer_min = infer_times.iter().cloned().fold(f64::INFINITY, f64::min);
152    let infer_max = infer_times.iter().cloned().fold(0.0f64, f64::max);
153    let infer_std = (infer_times.iter().map(|t| (t - infer_mean).powi(2)).sum::<f64>()
154        / infer_times.len() as f64).sqrt();
155
156    if !json_mode {
157        eprintln!("    mean={infer_mean:.1}ms  min={infer_min:.1}ms  max={infer_max:.1}ms  std={infer_std:.1}ms  (n={})",
158            args.runs);
159    }
160
161    // ── 3. Batch size scaling ───────────────────────────────────────────────
162    let batch_sizes = [1, 2, 4, 8, 16];
163    let mut batch_scaling: Vec<serde_json::Value> = Vec::new();
164
165    if !json_mode {
166        eprintln!("\n  ▸ Batch size scaling ({}ch × {} samples):", n_ch, n_t);
167        eprintln!("    {:>6}  {:>10}  {:>12}", "Batch", "Mean (ms)", "Per-epoch (ms)");
168    }
169
170    for &bs in &batch_sizes {
171        // Build a batched input: [bs, 12, 1920]
172        let signal_batch: Vec<f32> = (0..bs).flat_map(|i| generate_psg(n_ch, n_t, 42 + i as u32)).collect();
173        let signal_tensor = burn::tensor::Tensor::<B, 2>::from_data(
174            burn::tensor::TensorData::new(signal_batch, vec![bs * n_ch, n_t]),
175            &dev,
176        ).reshape([bs, n_ch, n_t]);
177
178        // Warmup
179        let _ = encoder.model().forward_encoding(signal_tensor.clone());
180
181        // Timed
182        let mut t_vec = Vec::new();
183        for _ in 0..5.max(args.runs / 2) {
184            let (_, ms) = timed(|| encoder.model().forward_encoding(signal_tensor.clone()));
185            t_vec.push(ms);
186        }
187        let avg = t_vec.iter().sum::<f64>() / t_vec.len() as f64;
188        let per_epoch = avg / bs as f64;
189        let bmin = t_vec.iter().cloned().fold(f64::INFINITY, f64::min);
190        let bmax = t_vec.iter().cloned().fold(0.0f64, f64::max);
191
192        if !json_mode {
193            eprintln!("    {:>6}  {:>7.1} ms  {:>9.1} ms", bs, avg, per_epoch);
194        }
195
196        batch_scaling.push(serde_json::json!({
197            "batch_size": bs,
198            "mean_ms": round2(avg),
199            "min_ms": round2(bmin),
200            "max_ms": round2(bmax),
201            "per_epoch_ms": round2(per_epoch),
202            "runs": t_vec,
203        }));
204    }
205
206    // ── 4. Throughput (epochs/sec) ──────────────────────────────────────────
207    let throughput_epochs_sec = 1000.0 / infer_mean;
208
209    if !json_mode {
210        eprintln!("\n  ▸ Throughput: {throughput_epochs_sec:.1} epochs/sec (single-epoch inference)");
211    }
212
213    // ── 5. Channel subset scaling (robustness test with padding) ────────────
214    let channel_subsets = [2, 4, 6, 8, 10, 12];
215    let mut channel_scaling: Vec<serde_json::Value> = Vec::new();
216
217    if !json_mode {
218        eprintln!("\n  ▸ Channel count scaling (zero-padded to 12ch, T={}):", n_t);
219        eprintln!("    {:>6}  {:>10}", "Active", "Mean (ms)");
220    }
221
222    for &active_ch in &channel_subsets {
223        // Create signal with `active_ch` active channels, rest zero-padded to 12
224        let mut sig = vec![0.0f32; n_ch * n_t];
225        let active = generate_psg(active_ch, n_t, 100 + active_ch as u32);
226        for ch in 0..active_ch {
227            for t in 0..n_t {
228                sig[ch * n_t + t] = active[ch * n_t + t];
229            }
230        }
231        let b = osf_rs::build_batch::<B>(sig, n_ch, n_t, &dev);
232
233        let _ = encoder.run_batch(&b)?; // warmup
234        let mut t_vec = Vec::new();
235        for _ in 0..5 {
236            let (_, ms) = timed(|| encoder.run_batch(&b));
237            t_vec.push(ms);
238        }
239        let avg = t_vec.iter().sum::<f64>() / t_vec.len() as f64;
240        let cmin = t_vec.iter().cloned().fold(f64::INFINITY, f64::min);
241        let cmax = t_vec.iter().cloned().fold(0.0f64, f64::max);
242
243        if !json_mode {
244            eprintln!("    {:>6}  {:>7.1} ms", active_ch, avg);
245        }
246
247        channel_scaling.push(serde_json::json!({
248            "active_channels": active_ch,
249            "total_channels": n_ch,
250            "mean_ms": round2(avg),
251            "min_ms": round2(cmin),
252            "max_ms": round2(cmax),
253            "runs": t_vec,
254        }));
255    }
256
257    if !json_mode { eprintln!(); }
258
259    // ── JSON output ─────────────────────────────────────────────────────────
260    let result = serde_json::json!({
261        "backend": backend::NAME,
262        "model": {
263            "encoder_name": model_cfg.encoder_name,
264            "width": model_cfg.width,
265            "depth": model_cfg.depth,
266            "heads": model_cfg.heads,
267            "num_leads": model_cfg.num_leads,
268            "patch_size_time": model_cfg.patch_size_time,
269            "patch_size_ch": model_cfg.patch_size_ch,
270            "seq_len": model_cfg.seq_len,
271        },
272        "load_ms": round2(ms_load),
273        "inference": {
274            "channels": n_ch,
275            "samples": n_t,
276            "warmup": args.warmup,
277            "runs": args.runs,
278            "mean_ms": round2(infer_mean),
279            "min_ms": round2(infer_min),
280            "max_ms": round2(infer_max),
281            "std_ms": round2(infer_std),
282            "all_ms": infer_times,
283        },
284        "batch_scaling": batch_scaling,
285        "channel_scaling": channel_scaling,
286        "throughput_epochs_sec": round2(throughput_epochs_sec),
287    });
288
289    if json_mode {
290        println!("{}", serde_json::to_string_pretty(&result)?);
291    }
292
293    Ok(())
294}
Source

pub fn run_batch(&self, batch: &InputBatch<B>) -> Result<EpochEmbedding>

Run inference on a prepared InputBatch.

Returns an EpochEmbedding with CLS and patch embeddings.

Examples found in repository?
examples/embed.rs (line 98)
58fn main() -> anyhow::Result<()> {
59    let args = Args::parse();
60    let t0 = Instant::now();
61    let device = backend::device();
62
63    println!("╔══════════════════════════════════════════════════════════════╗");
64    println!("║  OSF — PSG Embedding Extraction                             ║");
65    println!("╚══════════════════════════════════════════════════════════════╝\n");
66
67    // 1. Load model
68    let model_cfg = if let Some(ref cfg_path) = args.config {
69        let cfg_str = std::fs::read_to_string(cfg_path)?;
70        serde_json::from_str(&cfg_str)?
71    } else {
72        osf_rs::ModelConfig::default()
73    };
74
75    println!("▸ Loading OSF model …");
76    let (encoder, ms_load) = osf_rs::OsfEncoder::<B>::load_with_config(
77        model_cfg,
78        Path::new(&args.weights),
79        device.clone(),
80    )?;
81    println!("  {}  ({ms_load:.0} ms)\n", encoder.describe());
82
83    // 2. Generate synthetic PSG epochs
84    let n_channels = osf_rs::NUM_PSG_CHANNELS;
85    let n_samples = osf_rs::EPOCH_SAMPLES;
86    let n_epochs = 3;
87
88    println!("▸ Generating {} synthetic PSG epochs ({} ch × {} samples each)\n",
89        n_epochs, n_channels, n_samples);
90
91    let mut all_outputs = Vec::new();
92
93    for epoch_idx in 0..n_epochs {
94        let signal = generate_synthetic_psg(n_channels, n_samples);
95        let batch = osf_rs::build_batch::<B>(signal, n_channels, n_samples, &device);
96
97        let t = Instant::now();
98        let result = encoder.run_batch(&batch)?;
99        let ms = t.elapsed().as_secs_f64() * 1000.0;
100
101        // Stats
102        let cls = &result.cls_emb;
103        let mean: f64 = cls.iter().map(|&v| v as f64).sum::<f64>() / cls.len() as f64;
104        let std: f64 = (cls.iter().map(|&v| { let d = v as f64 - mean; d * d }).sum::<f64>()
105            / cls.len() as f64).sqrt();
106
107        println!("  Epoch {epoch_idx}: cls=[{}]  patches=[{},{}]  mean={mean:+.4}  std={std:.4}  {ms:.1}ms",
108            result.embed_dim, result.num_patches, result.embed_dim);
109
110        if args.verbose && epoch_idx == 0 {
111            println!("    First 5 CLS values: {:?}", &cls[..5.min(cls.len())]);
112        }
113
114        all_outputs.push(result);
115    }
116
117    // 3. Save
118    let encoding = osf_rs::EncodingResult {
119        epochs: all_outputs,
120        ms_load,
121        ms_encode: t0.elapsed().as_secs_f64() * 1000.0,
122    };
123
124    if let Some(p) = Path::new(&args.output).parent() {
125        std::fs::create_dir_all(p)?;
126    }
127    encoding.save_safetensors(&args.output)?;
128    println!("\n▸ Saved {} epochs → {}", n_epochs, args.output);
129
130    let ms_total = t0.elapsed().as_secs_f64() * 1000.0;
131    println!("\n  Total: {ms_total:.0} ms");
132    Ok(())
133}
More examples
Hide additional examples
examples/benchmark.rs (line 140)
95fn main() -> anyhow::Result<()> {
96    let args = Args::parse();
97    let dev = device();
98    let json_mode = args.json;
99
100    if !json_mode {
101        eprintln!("╔══════════════════════════════════════════════════════════════╗");
102        eprintln!("║  OSF-RS — Inference Benchmark                               ║");
103        eprintln!("╚══════════════════════════════════════════════════════════════╝\n");
104        eprintln!("  Backend: {}", backend::NAME);
105    }
106
107    // Load config
108    let model_cfg = if let Some(ref cfg_path) = args.config {
109        let s = std::fs::read_to_string(cfg_path)?;
110        serde_json::from_str(&s)?
111    } else {
112        osf_rs::ModelConfig::default()
113    };
114
115    // ── 1. Weight loading benchmark ─────────────────────────────────────────
116    let (encoder, ms_load) = osf_rs::OsfEncoder::<B>::load_with_config(
117        model_cfg.clone(),
118        Path::new(&args.weights),
119        dev.clone(),
120    )?;
121
122    if !json_mode {
123        eprintln!("  Model:   {}", encoder.describe());
124        eprintln!("  Load:    {ms_load:.0} ms\n");
125    }
126
127    let n_ch = osf_rs::NUM_PSG_CHANNELS;   // 12
128    let n_t  = osf_rs::EPOCH_SAMPLES;      // 1920
129
130    // ── 2. Standard inference benchmark (B=1) ───────────────────────────────
131    if !json_mode {
132        eprintln!("  ▸ Standard inference (B=1, {}ch × {} samples)", n_ch, n_t);
133    }
134
135    let signal = generate_psg(n_ch, n_t, 42);
136    let batch = osf_rs::build_batch::<B>(signal, n_ch, n_t, &dev);
137
138    // Warmup
139    for _ in 0..args.warmup {
140        let _ = encoder.run_batch(&batch)?;
141    }
142
143    // Timed runs
144    let mut infer_times = Vec::with_capacity(args.runs);
145    for _ in 0..args.runs {
146        let (_, ms) = timed(|| encoder.run_batch(&batch));
147        infer_times.push(ms);
148    }
149
150    let infer_mean = infer_times.iter().sum::<f64>() / infer_times.len() as f64;
151    let infer_min = infer_times.iter().cloned().fold(f64::INFINITY, f64::min);
152    let infer_max = infer_times.iter().cloned().fold(0.0f64, f64::max);
153    let infer_std = (infer_times.iter().map(|t| (t - infer_mean).powi(2)).sum::<f64>()
154        / infer_times.len() as f64).sqrt();
155
156    if !json_mode {
157        eprintln!("    mean={infer_mean:.1}ms  min={infer_min:.1}ms  max={infer_max:.1}ms  std={infer_std:.1}ms  (n={})",
158            args.runs);
159    }
160
161    // ── 3. Batch size scaling ───────────────────────────────────────────────
162    let batch_sizes = [1, 2, 4, 8, 16];
163    let mut batch_scaling: Vec<serde_json::Value> = Vec::new();
164
165    if !json_mode {
166        eprintln!("\n  ▸ Batch size scaling ({}ch × {} samples):", n_ch, n_t);
167        eprintln!("    {:>6}  {:>10}  {:>12}", "Batch", "Mean (ms)", "Per-epoch (ms)");
168    }
169
170    for &bs in &batch_sizes {
171        // Build a batched input: [bs, 12, 1920]
172        let signal_batch: Vec<f32> = (0..bs).flat_map(|i| generate_psg(n_ch, n_t, 42 + i as u32)).collect();
173        let signal_tensor = burn::tensor::Tensor::<B, 2>::from_data(
174            burn::tensor::TensorData::new(signal_batch, vec![bs * n_ch, n_t]),
175            &dev,
176        ).reshape([bs, n_ch, n_t]);
177
178        // Warmup
179        let _ = encoder.model().forward_encoding(signal_tensor.clone());
180
181        // Timed
182        let mut t_vec = Vec::new();
183        for _ in 0..5.max(args.runs / 2) {
184            let (_, ms) = timed(|| encoder.model().forward_encoding(signal_tensor.clone()));
185            t_vec.push(ms);
186        }
187        let avg = t_vec.iter().sum::<f64>() / t_vec.len() as f64;
188        let per_epoch = avg / bs as f64;
189        let bmin = t_vec.iter().cloned().fold(f64::INFINITY, f64::min);
190        let bmax = t_vec.iter().cloned().fold(0.0f64, f64::max);
191
192        if !json_mode {
193            eprintln!("    {:>6}  {:>7.1} ms  {:>9.1} ms", bs, avg, per_epoch);
194        }
195
196        batch_scaling.push(serde_json::json!({
197            "batch_size": bs,
198            "mean_ms": round2(avg),
199            "min_ms": round2(bmin),
200            "max_ms": round2(bmax),
201            "per_epoch_ms": round2(per_epoch),
202            "runs": t_vec,
203        }));
204    }
205
206    // ── 4. Throughput (epochs/sec) ──────────────────────────────────────────
207    let throughput_epochs_sec = 1000.0 / infer_mean;
208
209    if !json_mode {
210        eprintln!("\n  ▸ Throughput: {throughput_epochs_sec:.1} epochs/sec (single-epoch inference)");
211    }
212
213    // ── 5. Channel subset scaling (robustness test with padding) ────────────
214    let channel_subsets = [2, 4, 6, 8, 10, 12];
215    let mut channel_scaling: Vec<serde_json::Value> = Vec::new();
216
217    if !json_mode {
218        eprintln!("\n  ▸ Channel count scaling (zero-padded to 12ch, T={}):", n_t);
219        eprintln!("    {:>6}  {:>10}", "Active", "Mean (ms)");
220    }
221
222    for &active_ch in &channel_subsets {
223        // Create signal with `active_ch` active channels, rest zero-padded to 12
224        let mut sig = vec![0.0f32; n_ch * n_t];
225        let active = generate_psg(active_ch, n_t, 100 + active_ch as u32);
226        for ch in 0..active_ch {
227            for t in 0..n_t {
228                sig[ch * n_t + t] = active[ch * n_t + t];
229            }
230        }
231        let b = osf_rs::build_batch::<B>(sig, n_ch, n_t, &dev);
232
233        let _ = encoder.run_batch(&b)?; // warmup
234        let mut t_vec = Vec::new();
235        for _ in 0..5 {
236            let (_, ms) = timed(|| encoder.run_batch(&b));
237            t_vec.push(ms);
238        }
239        let avg = t_vec.iter().sum::<f64>() / t_vec.len() as f64;
240        let cmin = t_vec.iter().cloned().fold(f64::INFINITY, f64::min);
241        let cmax = t_vec.iter().cloned().fold(0.0f64, f64::max);
242
243        if !json_mode {
244            eprintln!("    {:>6}  {:>7.1} ms", active_ch, avg);
245        }
246
247        channel_scaling.push(serde_json::json!({
248            "active_channels": active_ch,
249            "total_channels": n_ch,
250            "mean_ms": round2(avg),
251            "min_ms": round2(cmin),
252            "max_ms": round2(cmax),
253            "runs": t_vec,
254        }));
255    }
256
257    if !json_mode { eprintln!(); }
258
259    // ── JSON output ─────────────────────────────────────────────────────────
260    let result = serde_json::json!({
261        "backend": backend::NAME,
262        "model": {
263            "encoder_name": model_cfg.encoder_name,
264            "width": model_cfg.width,
265            "depth": model_cfg.depth,
266            "heads": model_cfg.heads,
267            "num_leads": model_cfg.num_leads,
268            "patch_size_time": model_cfg.patch_size_time,
269            "patch_size_ch": model_cfg.patch_size_ch,
270            "seq_len": model_cfg.seq_len,
271        },
272        "load_ms": round2(ms_load),
273        "inference": {
274            "channels": n_ch,
275            "samples": n_t,
276            "warmup": args.warmup,
277            "runs": args.runs,
278            "mean_ms": round2(infer_mean),
279            "min_ms": round2(infer_min),
280            "max_ms": round2(infer_max),
281            "std_ms": round2(infer_std),
282            "all_ms": infer_times,
283        },
284        "batch_scaling": batch_scaling,
285        "channel_scaling": channel_scaling,
286        "throughput_epochs_sec": round2(throughput_epochs_sec),
287    });
288
289    if json_mode {
290        println!("{}", serde_json::to_string_pretty(&result)?);
291    }
292
293    Ok(())
294}
Source

pub fn run_batches( &self, batches: &[InputBatch<B>], ) -> Result<Vec<EpochEmbedding>>

Runs the encoder on each batch in `batches` and returns the resulting epoch embeddings.

Source

pub fn model(&self) -> &OsfViT<B>

Returns a reference to the underlying raw ViT model (`OsfViT<B>`).

Examples found in repository?
examples/benchmark.rs (line 179)
95fn main() -> anyhow::Result<()> {
96    let args = Args::parse();
97    let dev = device();
98    let json_mode = args.json;
99
100    if !json_mode {
101        eprintln!("╔══════════════════════════════════════════════════════════════╗");
102        eprintln!("║  OSF-RS — Inference Benchmark                               ║");
103        eprintln!("╚══════════════════════════════════════════════════════════════╝\n");
104        eprintln!("  Backend: {}", backend::NAME);
105    }
106
107    // Load config
108    let model_cfg = if let Some(ref cfg_path) = args.config {
109        let s = std::fs::read_to_string(cfg_path)?;
110        serde_json::from_str(&s)?
111    } else {
112        osf_rs::ModelConfig::default()
113    };
114
115    // ── 1. Weight loading benchmark ─────────────────────────────────────────
116    let (encoder, ms_load) = osf_rs::OsfEncoder::<B>::load_with_config(
117        model_cfg.clone(),
118        Path::new(&args.weights),
119        dev.clone(),
120    )?;
121
122    if !json_mode {
123        eprintln!("  Model:   {}", encoder.describe());
124        eprintln!("  Load:    {ms_load:.0} ms\n");
125    }
126
127    let n_ch = osf_rs::NUM_PSG_CHANNELS;   // 12
128    let n_t  = osf_rs::EPOCH_SAMPLES;      // 1920
129
130    // ── 2. Standard inference benchmark (B=1) ───────────────────────────────
131    if !json_mode {
132        eprintln!("  ▸ Standard inference (B=1, {}ch × {} samples)", n_ch, n_t);
133    }
134
135    let signal = generate_psg(n_ch, n_t, 42);
136    let batch = osf_rs::build_batch::<B>(signal, n_ch, n_t, &dev);
137
138    // Warmup
139    for _ in 0..args.warmup {
140        let _ = encoder.run_batch(&batch)?;
141    }
142
143    // Timed runs
144    let mut infer_times = Vec::with_capacity(args.runs);
145    for _ in 0..args.runs {
146        let (_, ms) = timed(|| encoder.run_batch(&batch));
147        infer_times.push(ms);
148    }
149
150    let infer_mean = infer_times.iter().sum::<f64>() / infer_times.len() as f64;
151    let infer_min = infer_times.iter().cloned().fold(f64::INFINITY, f64::min);
152    let infer_max = infer_times.iter().cloned().fold(0.0f64, f64::max);
153    let infer_std = (infer_times.iter().map(|t| (t - infer_mean).powi(2)).sum::<f64>()
154        / infer_times.len() as f64).sqrt();
155
156    if !json_mode {
157        eprintln!("    mean={infer_mean:.1}ms  min={infer_min:.1}ms  max={infer_max:.1}ms  std={infer_std:.1}ms  (n={})",
158            args.runs);
159    }
160
161    // ── 3. Batch size scaling ───────────────────────────────────────────────
162    let batch_sizes = [1, 2, 4, 8, 16];
163    let mut batch_scaling: Vec<serde_json::Value> = Vec::new();
164
165    if !json_mode {
166        eprintln!("\n  ▸ Batch size scaling ({}ch × {} samples):", n_ch, n_t);
167        eprintln!("    {:>6}  {:>10}  {:>12}", "Batch", "Mean (ms)", "Per-epoch (ms)");
168    }
169
170    for &bs in &batch_sizes {
171        // Build a batched input: [bs, 12, 1920]
172        let signal_batch: Vec<f32> = (0..bs).flat_map(|i| generate_psg(n_ch, n_t, 42 + i as u32)).collect();
173        let signal_tensor = burn::tensor::Tensor::<B, 2>::from_data(
174            burn::tensor::TensorData::new(signal_batch, vec![bs * n_ch, n_t]),
175            &dev,
176        ).reshape([bs, n_ch, n_t]);
177
178        // Warmup
179        let _ = encoder.model().forward_encoding(signal_tensor.clone());
180
181        // Timed
182        let mut t_vec = Vec::new();
183        for _ in 0..5.max(args.runs / 2) {
184            let (_, ms) = timed(|| encoder.model().forward_encoding(signal_tensor.clone()));
185            t_vec.push(ms);
186        }
187        let avg = t_vec.iter().sum::<f64>() / t_vec.len() as f64;
188        let per_epoch = avg / bs as f64;
189        let bmin = t_vec.iter().cloned().fold(f64::INFINITY, f64::min);
190        let bmax = t_vec.iter().cloned().fold(0.0f64, f64::max);
191
192        if !json_mode {
193            eprintln!("    {:>6}  {:>7.1} ms  {:>9.1} ms", bs, avg, per_epoch);
194        }
195
196        batch_scaling.push(serde_json::json!({
197            "batch_size": bs,
198            "mean_ms": round2(avg),
199            "min_ms": round2(bmin),
200            "max_ms": round2(bmax),
201            "per_epoch_ms": round2(per_epoch),
202            "runs": t_vec,
203        }));
204    }
205
206    // ── 4. Throughput (epochs/sec) ──────────────────────────────────────────
207    let throughput_epochs_sec = 1000.0 / infer_mean;
208
209    if !json_mode {
210        eprintln!("\n  ▸ Throughput: {throughput_epochs_sec:.1} epochs/sec (single-epoch inference)");
211    }
212
213    // ── 5. Channel subset scaling (robustness test with padding) ────────────
214    let channel_subsets = [2, 4, 6, 8, 10, 12];
215    let mut channel_scaling: Vec<serde_json::Value> = Vec::new();
216
217    if !json_mode {
218        eprintln!("\n  ▸ Channel count scaling (zero-padded to 12ch, T={}):", n_t);
219        eprintln!("    {:>6}  {:>10}", "Active", "Mean (ms)");
220    }
221
222    for &active_ch in &channel_subsets {
223        // Create signal with `active_ch` active channels, rest zero-padded to 12
224        let mut sig = vec![0.0f32; n_ch * n_t];
225        let active = generate_psg(active_ch, n_t, 100 + active_ch as u32);
226        for ch in 0..active_ch {
227            for t in 0..n_t {
228                sig[ch * n_t + t] = active[ch * n_t + t];
229            }
230        }
231        let b = osf_rs::build_batch::<B>(sig, n_ch, n_t, &dev);
232
233        let _ = encoder.run_batch(&b)?; // warmup
234        let mut t_vec = Vec::new();
235        for _ in 0..5 {
236            let (_, ms) = timed(|| encoder.run_batch(&b));
237            t_vec.push(ms);
238        }
239        let avg = t_vec.iter().sum::<f64>() / t_vec.len() as f64;
240        let cmin = t_vec.iter().cloned().fold(f64::INFINITY, f64::min);
241        let cmax = t_vec.iter().cloned().fold(0.0f64, f64::max);
242
243        if !json_mode {
244            eprintln!("    {:>6}  {:>7.1} ms", active_ch, avg);
245        }
246
247        channel_scaling.push(serde_json::json!({
248            "active_channels": active_ch,
249            "total_channels": n_ch,
250            "mean_ms": round2(avg),
251            "min_ms": round2(cmin),
252            "max_ms": round2(cmax),
253            "runs": t_vec,
254        }));
255    }
256
257    if !json_mode { eprintln!(); }
258
259    // ── JSON output ─────────────────────────────────────────────────────────
260    let result = serde_json::json!({
261        "backend": backend::NAME,
262        "model": {
263            "encoder_name": model_cfg.encoder_name,
264            "width": model_cfg.width,
265            "depth": model_cfg.depth,
266            "heads": model_cfg.heads,
267            "num_leads": model_cfg.num_leads,
268            "patch_size_time": model_cfg.patch_size_time,
269            "patch_size_ch": model_cfg.patch_size_ch,
270            "seq_len": model_cfg.seq_len,
271        },
272        "load_ms": round2(ms_load),
273        "inference": {
274            "channels": n_ch,
275            "samples": n_t,
276            "warmup": args.warmup,
277            "runs": args.runs,
278            "mean_ms": round2(infer_mean),
279            "min_ms": round2(infer_min),
280            "max_ms": round2(infer_max),
281            "std_ms": round2(infer_std),
282            "all_ms": infer_times,
283        },
284        "batch_scaling": batch_scaling,
285        "channel_scaling": channel_scaling,
286        "throughput_epochs_sec": round2(throughput_epochs_sec),
287    });
288
289    if json_mode {
290        println!("{}", serde_json::to_string_pretty(&result)?);
291    }
292
293    Ok(())
294}
Source

pub fn device(&self) -> &B::Device

Returns a reference to the backend device held by this encoder.

Auto Trait Implementations§

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T> IntoEither for T

Source§

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

impl<T> Pointable for T

Source§

const ALIGN: usize

The alignment of pointer.
Source§

type Init = T

The type for initializers.
Source§

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes a value with the given initializer. Read more
Source§

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer. Read more
Source§

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer. Read more
Source§

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer. Read more
Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.
Source§

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

Source§

fn vzip(self) -> V