pub struct OsfEncoder<B: Backend> {
pub model_cfg: ModelConfig,
/* private fields */
}

High-level OSF encoder for PSG signal processing.
Fields

model_cfg: ModelConfig

Implementations
impl<B: Backend> OsfEncoder<B>
pub fn load(
    config_path: &Path,
    weights_path: &Path,
    device: B::Device,
) -> Result<(Self, f64)>

Loads the model from a JSON config file (`config_path`) and safetensors weights (`weights_path`) onto `device`. On success, returns the encoder together with an f64 timing value (presumably the load time in milliseconds, as with `load_with_config`).
pub fn load_with_config(
    model_cfg: ModelConfig,
    weights_path: &Path,
    device: B::Device,
) -> Result<(Self, f64)>

Loads the model directly from an in-memory `ModelConfig` and a safetensors weights path. On success, returns the encoder together with the load time in milliseconds (the examples below print it as `{ms_load:.0} ms`).
Examples found in repository?
examples/embed.rs (lines 76-80)
58fn main() -> anyhow::Result<()> {
59 let args = Args::parse();
60 let t0 = Instant::now();
61 let device = backend::device();
62
63 println!("╔══════════════════════════════════════════════════════════════╗");
64 println!("║ OSF — PSG Embedding Extraction ║");
65 println!("╚══════════════════════════════════════════════════════════════╝\n");
66
67 // 1. Load model
68 let model_cfg = if let Some(ref cfg_path) = args.config {
69 let cfg_str = std::fs::read_to_string(cfg_path)?;
70 serde_json::from_str(&cfg_str)?
71 } else {
72 osf_rs::ModelConfig::default()
73 };
74
75 println!("▸ Loading OSF model …");
76 let (encoder, ms_load) = osf_rs::OsfEncoder::<B>::load_with_config(
77 model_cfg,
78 Path::new(&args.weights),
79 device.clone(),
80 )?;
81 println!(" {} ({ms_load:.0} ms)\n", encoder.describe());
82
83 // 2. Generate synthetic PSG epochs
84 let n_channels = osf_rs::NUM_PSG_CHANNELS;
85 let n_samples = osf_rs::EPOCH_SAMPLES;
86 let n_epochs = 3;
87
88 println!("▸ Generating {} synthetic PSG epochs ({} ch × {} samples each)\n",
89 n_epochs, n_channels, n_samples);
90
91 let mut all_outputs = Vec::new();
92
93 for epoch_idx in 0..n_epochs {
94 let signal = generate_synthetic_psg(n_channels, n_samples);
95 let batch = osf_rs::build_batch::<B>(signal, n_channels, n_samples, &device);
96
97 let t = Instant::now();
98 let result = encoder.run_batch(&batch)?;
99 let ms = t.elapsed().as_secs_f64() * 1000.0;
100
101 // Stats
102 let cls = &result.cls_emb;
103 let mean: f64 = cls.iter().map(|&v| v as f64).sum::<f64>() / cls.len() as f64;
104 let std: f64 = (cls.iter().map(|&v| { let d = v as f64 - mean; d * d }).sum::<f64>()
105 / cls.len() as f64).sqrt();
106
107 println!(" Epoch {epoch_idx}: cls=[{}] patches=[{},{}] mean={mean:+.4} std={std:.4} {ms:.1}ms",
108 result.embed_dim, result.num_patches, result.embed_dim);
109
110 if args.verbose && epoch_idx == 0 {
111 println!(" First 5 CLS values: {:?}", &cls[..5.min(cls.len())]);
112 }
113
114 all_outputs.push(result);
115 }
116
117 // 3. Save
118 let encoding = osf_rs::EncodingResult {
119 epochs: all_outputs,
120 ms_load,
121 ms_encode: t0.elapsed().as_secs_f64() * 1000.0,
122 };
123
124 if let Some(p) = Path::new(&args.output).parent() {
125 std::fs::create_dir_all(p)?;
126 }
127 encoding.save_safetensors(&args.output)?;
128 println!("\n▸ Saved {} epochs → {}", n_epochs, args.output);
129
130 let ms_total = t0.elapsed().as_secs_f64() * 1000.0;
131 println!("\n Total: {ms_total:.0} ms");
132 Ok(())
133}
More examples
examples/benchmark.rs (lines 116-120)
95fn main() -> anyhow::Result<()> {
96 let args = Args::parse();
97 let dev = device();
98 let json_mode = args.json;
99
100 if !json_mode {
101 eprintln!("╔══════════════════════════════════════════════════════════════╗");
102 eprintln!("║ OSF-RS — Inference Benchmark ║");
103 eprintln!("╚══════════════════════════════════════════════════════════════╝\n");
104 eprintln!(" Backend: {}", backend::NAME);
105 }
106
107 // Load config
108 let model_cfg = if let Some(ref cfg_path) = args.config {
109 let s = std::fs::read_to_string(cfg_path)?;
110 serde_json::from_str(&s)?
111 } else {
112 osf_rs::ModelConfig::default()
113 };
114
115 // ── 1. Weight loading benchmark ─────────────────────────────────────────
116 let (encoder, ms_load) = osf_rs::OsfEncoder::<B>::load_with_config(
117 model_cfg.clone(),
118 Path::new(&args.weights),
119 dev.clone(),
120 )?;
121
122 if !json_mode {
123 eprintln!(" Model: {}", encoder.describe());
124 eprintln!(" Load: {ms_load:.0} ms\n");
125 }
126
127 let n_ch = osf_rs::NUM_PSG_CHANNELS; // 12
128 let n_t = osf_rs::EPOCH_SAMPLES; // 1920
129
130 // ── 2. Standard inference benchmark (B=1) ───────────────────────────────
131 if !json_mode {
132 eprintln!(" ▸ Standard inference (B=1, {}ch × {} samples)", n_ch, n_t);
133 }
134
135 let signal = generate_psg(n_ch, n_t, 42);
136 let batch = osf_rs::build_batch::<B>(signal, n_ch, n_t, &dev);
137
138 // Warmup
139 for _ in 0..args.warmup {
140 let _ = encoder.run_batch(&batch)?;
141 }
142
143 // Timed runs
144 let mut infer_times = Vec::with_capacity(args.runs);
145 for _ in 0..args.runs {
146 let (_, ms) = timed(|| encoder.run_batch(&batch));
147 infer_times.push(ms);
148 }
149
150 let infer_mean = infer_times.iter().sum::<f64>() / infer_times.len() as f64;
151 let infer_min = infer_times.iter().cloned().fold(f64::INFINITY, f64::min);
152 let infer_max = infer_times.iter().cloned().fold(0.0f64, f64::max);
153 let infer_std = (infer_times.iter().map(|t| (t - infer_mean).powi(2)).sum::<f64>()
154 / infer_times.len() as f64).sqrt();
155
156 if !json_mode {
157 eprintln!(" mean={infer_mean:.1}ms min={infer_min:.1}ms max={infer_max:.1}ms std={infer_std:.1}ms (n={})",
158 args.runs);
159 }
160
161 // ── 3. Batch size scaling ───────────────────────────────────────────────
162 let batch_sizes = [1, 2, 4, 8, 16];
163 let mut batch_scaling: Vec<serde_json::Value> = Vec::new();
164
165 if !json_mode {
166 eprintln!("\n ▸ Batch size scaling ({}ch × {} samples):", n_ch, n_t);
167 eprintln!(" {:>6} {:>10} {:>12}", "Batch", "Mean (ms)", "Per-epoch (ms)");
168 }
169
170 for &bs in &batch_sizes {
171 // Build a batched input: [bs, 12, 1920]
172 let signal_batch: Vec<f32> = (0..bs).flat_map(|i| generate_psg(n_ch, n_t, 42 + i as u32)).collect();
173 let signal_tensor = burn::tensor::Tensor::<B, 2>::from_data(
174 burn::tensor::TensorData::new(signal_batch, vec![bs * n_ch, n_t]),
175 &dev,
176 ).reshape([bs, n_ch, n_t]);
177
178 // Warmup
179 let _ = encoder.model().forward_encoding(signal_tensor.clone());
180
181 // Timed
182 let mut t_vec = Vec::new();
183 for _ in 0..5.max(args.runs / 2) {
184 let (_, ms) = timed(|| encoder.model().forward_encoding(signal_tensor.clone()));
185 t_vec.push(ms);
186 }
187 let avg = t_vec.iter().sum::<f64>() / t_vec.len() as f64;
188 let per_epoch = avg / bs as f64;
189 let bmin = t_vec.iter().cloned().fold(f64::INFINITY, f64::min);
190 let bmax = t_vec.iter().cloned().fold(0.0f64, f64::max);
191
192 if !json_mode {
193 eprintln!(" {:>6} {:>7.1} ms {:>9.1} ms", bs, avg, per_epoch);
194 }
195
196 batch_scaling.push(serde_json::json!({
197 "batch_size": bs,
198 "mean_ms": round2(avg),
199 "min_ms": round2(bmin),
200 "max_ms": round2(bmax),
201 "per_epoch_ms": round2(per_epoch),
202 "runs": t_vec,
203 }));
204 }
205
206 // ── 4. Throughput (epochs/sec) ──────────────────────────────────────────
207 let throughput_epochs_sec = 1000.0 / infer_mean;
208
209 if !json_mode {
210 eprintln!("\n ▸ Throughput: {throughput_epochs_sec:.1} epochs/sec (single-epoch inference)");
211 }
212
213 // ── 5. Channel subset scaling (robustness test with padding) ────────────
214 let channel_subsets = [2, 4, 6, 8, 10, 12];
215 let mut channel_scaling: Vec<serde_json::Value> = Vec::new();
216
217 if !json_mode {
218 eprintln!("\n ▸ Channel count scaling (zero-padded to 12ch, T={}):", n_t);
219 eprintln!(" {:>6} {:>10}", "Active", "Mean (ms)");
220 }
221
222 for &active_ch in &channel_subsets {
223 // Create signal with `active_ch` active channels, rest zero-padded to 12
224 let mut sig = vec![0.0f32; n_ch * n_t];
225 let active = generate_psg(active_ch, n_t, 100 + active_ch as u32);
226 for ch in 0..active_ch {
227 for t in 0..n_t {
228 sig[ch * n_t + t] = active[ch * n_t + t];
229 }
230 }
231 let b = osf_rs::build_batch::<B>(sig, n_ch, n_t, &dev);
232
233 let _ = encoder.run_batch(&b)?; // warmup
234 let mut t_vec = Vec::new();
235 for _ in 0..5 {
236 let (_, ms) = timed(|| encoder.run_batch(&b));
237 t_vec.push(ms);
238 }
239 let avg = t_vec.iter().sum::<f64>() / t_vec.len() as f64;
240 let cmin = t_vec.iter().cloned().fold(f64::INFINITY, f64::min);
241 let cmax = t_vec.iter().cloned().fold(0.0f64, f64::max);
242
243 if !json_mode {
244 eprintln!(" {:>6} {:>7.1} ms", active_ch, avg);
245 }
246
247 channel_scaling.push(serde_json::json!({
248 "active_channels": active_ch,
249 "total_channels": n_ch,
250 "mean_ms": round2(avg),
251 "min_ms": round2(cmin),
252 "max_ms": round2(cmax),
253 "runs": t_vec,
254 }));
255 }
256
257 if !json_mode { eprintln!(); }
258
259 // ── JSON output ─────────────────────────────────────────────────────────
260 let result = serde_json::json!({
261 "backend": backend::NAME,
262 "model": {
263 "encoder_name": model_cfg.encoder_name,
264 "width": model_cfg.width,
265 "depth": model_cfg.depth,
266 "heads": model_cfg.heads,
267 "num_leads": model_cfg.num_leads,
268 "patch_size_time": model_cfg.patch_size_time,
269 "patch_size_ch": model_cfg.patch_size_ch,
270 "seq_len": model_cfg.seq_len,
271 },
272 "load_ms": round2(ms_load),
273 "inference": {
274 "channels": n_ch,
275 "samples": n_t,
276 "warmup": args.warmup,
277 "runs": args.runs,
278 "mean_ms": round2(infer_mean),
279 "min_ms": round2(infer_min),
280 "max_ms": round2(infer_max),
281 "std_ms": round2(infer_std),
282 "all_ms": infer_times,
283 },
284 "batch_scaling": batch_scaling,
285 "channel_scaling": channel_scaling,
286 "throughput_epochs_sec": round2(throughput_epochs_sec),
287 });
288
289 if json_mode {
290 println!("{}", serde_json::to_string_pretty(&result)?);
291 }
292
293 Ok(())
294}

pub fn describe(&self) -> String

Returns a `String` describing the loaded model; the examples print it right after loading.
Examples found in repository?
examples/embed.rs (line 81)
58fn main() -> anyhow::Result<()> {
59 let args = Args::parse();
60 let t0 = Instant::now();
61 let device = backend::device();
62
63 println!("╔══════════════════════════════════════════════════════════════╗");
64 println!("║ OSF — PSG Embedding Extraction ║");
65 println!("╚══════════════════════════════════════════════════════════════╝\n");
66
67 // 1. Load model
68 let model_cfg = if let Some(ref cfg_path) = args.config {
69 let cfg_str = std::fs::read_to_string(cfg_path)?;
70 serde_json::from_str(&cfg_str)?
71 } else {
72 osf_rs::ModelConfig::default()
73 };
74
75 println!("▸ Loading OSF model …");
76 let (encoder, ms_load) = osf_rs::OsfEncoder::<B>::load_with_config(
77 model_cfg,
78 Path::new(&args.weights),
79 device.clone(),
80 )?;
81 println!(" {} ({ms_load:.0} ms)\n", encoder.describe());
82
83 // 2. Generate synthetic PSG epochs
84 let n_channels = osf_rs::NUM_PSG_CHANNELS;
85 let n_samples = osf_rs::EPOCH_SAMPLES;
86 let n_epochs = 3;
87
88 println!("▸ Generating {} synthetic PSG epochs ({} ch × {} samples each)\n",
89 n_epochs, n_channels, n_samples);
90
91 let mut all_outputs = Vec::new();
92
93 for epoch_idx in 0..n_epochs {
94 let signal = generate_synthetic_psg(n_channels, n_samples);
95 let batch = osf_rs::build_batch::<B>(signal, n_channels, n_samples, &device);
96
97 let t = Instant::now();
98 let result = encoder.run_batch(&batch)?;
99 let ms = t.elapsed().as_secs_f64() * 1000.0;
100
101 // Stats
102 let cls = &result.cls_emb;
103 let mean: f64 = cls.iter().map(|&v| v as f64).sum::<f64>() / cls.len() as f64;
104 let std: f64 = (cls.iter().map(|&v| { let d = v as f64 - mean; d * d }).sum::<f64>()
105 / cls.len() as f64).sqrt();
106
107 println!(" Epoch {epoch_idx}: cls=[{}] patches=[{},{}] mean={mean:+.4} std={std:.4} {ms:.1}ms",
108 result.embed_dim, result.num_patches, result.embed_dim);
109
110 if args.verbose && epoch_idx == 0 {
111 println!(" First 5 CLS values: {:?}", &cls[..5.min(cls.len())]);
112 }
113
114 all_outputs.push(result);
115 }
116
117 // 3. Save
118 let encoding = osf_rs::EncodingResult {
119 epochs: all_outputs,
120 ms_load,
121 ms_encode: t0.elapsed().as_secs_f64() * 1000.0,
122 };
123
124 if let Some(p) = Path::new(&args.output).parent() {
125 std::fs::create_dir_all(p)?;
126 }
127 encoding.save_safetensors(&args.output)?;
128 println!("\n▸ Saved {} epochs → {}", n_epochs, args.output);
129
130 let ms_total = t0.elapsed().as_secs_f64() * 1000.0;
131 println!("\n Total: {ms_total:.0} ms");
132 Ok(())
133}
More examples
examples/benchmark.rs (line 123)
95fn main() -> anyhow::Result<()> {
96 let args = Args::parse();
97 let dev = device();
98 let json_mode = args.json;
99
100 if !json_mode {
101 eprintln!("╔══════════════════════════════════════════════════════════════╗");
102 eprintln!("║ OSF-RS — Inference Benchmark ║");
103 eprintln!("╚══════════════════════════════════════════════════════════════╝\n");
104 eprintln!(" Backend: {}", backend::NAME);
105 }
106
107 // Load config
108 let model_cfg = if let Some(ref cfg_path) = args.config {
109 let s = std::fs::read_to_string(cfg_path)?;
110 serde_json::from_str(&s)?
111 } else {
112 osf_rs::ModelConfig::default()
113 };
114
115 // ── 1. Weight loading benchmark ─────────────────────────────────────────
116 let (encoder, ms_load) = osf_rs::OsfEncoder::<B>::load_with_config(
117 model_cfg.clone(),
118 Path::new(&args.weights),
119 dev.clone(),
120 )?;
121
122 if !json_mode {
123 eprintln!(" Model: {}", encoder.describe());
124 eprintln!(" Load: {ms_load:.0} ms\n");
125 }
126
127 let n_ch = osf_rs::NUM_PSG_CHANNELS; // 12
128 let n_t = osf_rs::EPOCH_SAMPLES; // 1920
129
130 // ── 2. Standard inference benchmark (B=1) ───────────────────────────────
131 if !json_mode {
132 eprintln!(" ▸ Standard inference (B=1, {}ch × {} samples)", n_ch, n_t);
133 }
134
135 let signal = generate_psg(n_ch, n_t, 42);
136 let batch = osf_rs::build_batch::<B>(signal, n_ch, n_t, &dev);
137
138 // Warmup
139 for _ in 0..args.warmup {
140 let _ = encoder.run_batch(&batch)?;
141 }
142
143 // Timed runs
144 let mut infer_times = Vec::with_capacity(args.runs);
145 for _ in 0..args.runs {
146 let (_, ms) = timed(|| encoder.run_batch(&batch));
147 infer_times.push(ms);
148 }
149
150 let infer_mean = infer_times.iter().sum::<f64>() / infer_times.len() as f64;
151 let infer_min = infer_times.iter().cloned().fold(f64::INFINITY, f64::min);
152 let infer_max = infer_times.iter().cloned().fold(0.0f64, f64::max);
153 let infer_std = (infer_times.iter().map(|t| (t - infer_mean).powi(2)).sum::<f64>()
154 / infer_times.len() as f64).sqrt();
155
156 if !json_mode {
157 eprintln!(" mean={infer_mean:.1}ms min={infer_min:.1}ms max={infer_max:.1}ms std={infer_std:.1}ms (n={})",
158 args.runs);
159 }
160
161 // ── 3. Batch size scaling ───────────────────────────────────────────────
162 let batch_sizes = [1, 2, 4, 8, 16];
163 let mut batch_scaling: Vec<serde_json::Value> = Vec::new();
164
165 if !json_mode {
166 eprintln!("\n ▸ Batch size scaling ({}ch × {} samples):", n_ch, n_t);
167 eprintln!(" {:>6} {:>10} {:>12}", "Batch", "Mean (ms)", "Per-epoch (ms)");
168 }
169
170 for &bs in &batch_sizes {
171 // Build a batched input: [bs, 12, 1920]
172 let signal_batch: Vec<f32> = (0..bs).flat_map(|i| generate_psg(n_ch, n_t, 42 + i as u32)).collect();
173 let signal_tensor = burn::tensor::Tensor::<B, 2>::from_data(
174 burn::tensor::TensorData::new(signal_batch, vec![bs * n_ch, n_t]),
175 &dev,
176 ).reshape([bs, n_ch, n_t]);
177
178 // Warmup
179 let _ = encoder.model().forward_encoding(signal_tensor.clone());
180
181 // Timed
182 let mut t_vec = Vec::new();
183 for _ in 0..5.max(args.runs / 2) {
184 let (_, ms) = timed(|| encoder.model().forward_encoding(signal_tensor.clone()));
185 t_vec.push(ms);
186 }
187 let avg = t_vec.iter().sum::<f64>() / t_vec.len() as f64;
188 let per_epoch = avg / bs as f64;
189 let bmin = t_vec.iter().cloned().fold(f64::INFINITY, f64::min);
190 let bmax = t_vec.iter().cloned().fold(0.0f64, f64::max);
191
192 if !json_mode {
193 eprintln!(" {:>6} {:>7.1} ms {:>9.1} ms", bs, avg, per_epoch);
194 }
195
196 batch_scaling.push(serde_json::json!({
197 "batch_size": bs,
198 "mean_ms": round2(avg),
199 "min_ms": round2(bmin),
200 "max_ms": round2(bmax),
201 "per_epoch_ms": round2(per_epoch),
202 "runs": t_vec,
203 }));
204 }
205
206 // ── 4. Throughput (epochs/sec) ──────────────────────────────────────────
207 let throughput_epochs_sec = 1000.0 / infer_mean;
208
209 if !json_mode {
210 eprintln!("\n ▸ Throughput: {throughput_epochs_sec:.1} epochs/sec (single-epoch inference)");
211 }
212
213 // ── 5. Channel subset scaling (robustness test with padding) ────────────
214 let channel_subsets = [2, 4, 6, 8, 10, 12];
215 let mut channel_scaling: Vec<serde_json::Value> = Vec::new();
216
217 if !json_mode {
218 eprintln!("\n ▸ Channel count scaling (zero-padded to 12ch, T={}):", n_t);
219 eprintln!(" {:>6} {:>10}", "Active", "Mean (ms)");
220 }
221
222 for &active_ch in &channel_subsets {
223 // Create signal with `active_ch` active channels, rest zero-padded to 12
224 let mut sig = vec![0.0f32; n_ch * n_t];
225 let active = generate_psg(active_ch, n_t, 100 + active_ch as u32);
226 for ch in 0..active_ch {
227 for t in 0..n_t {
228 sig[ch * n_t + t] = active[ch * n_t + t];
229 }
230 }
231 let b = osf_rs::build_batch::<B>(sig, n_ch, n_t, &dev);
232
233 let _ = encoder.run_batch(&b)?; // warmup
234 let mut t_vec = Vec::new();
235 for _ in 0..5 {
236 let (_, ms) = timed(|| encoder.run_batch(&b));
237 t_vec.push(ms);
238 }
239 let avg = t_vec.iter().sum::<f64>() / t_vec.len() as f64;
240 let cmin = t_vec.iter().cloned().fold(f64::INFINITY, f64::min);
241 let cmax = t_vec.iter().cloned().fold(0.0f64, f64::max);
242
243 if !json_mode {
244 eprintln!(" {:>6} {:>7.1} ms", active_ch, avg);
245 }
246
247 channel_scaling.push(serde_json::json!({
248 "active_channels": active_ch,
249 "total_channels": n_ch,
250 "mean_ms": round2(avg),
251 "min_ms": round2(cmin),
252 "max_ms": round2(cmax),
253 "runs": t_vec,
254 }));
255 }
256
257 if !json_mode { eprintln!(); }
258
259 // ── JSON output ─────────────────────────────────────────────────────────
260 let result = serde_json::json!({
261 "backend": backend::NAME,
262 "model": {
263 "encoder_name": model_cfg.encoder_name,
264 "width": model_cfg.width,
265 "depth": model_cfg.depth,
266 "heads": model_cfg.heads,
267 "num_leads": model_cfg.num_leads,
268 "patch_size_time": model_cfg.patch_size_time,
269 "patch_size_ch": model_cfg.patch_size_ch,
270 "seq_len": model_cfg.seq_len,
271 },
272 "load_ms": round2(ms_load),
273 "inference": {
274 "channels": n_ch,
275 "samples": n_t,
276 "warmup": args.warmup,
277 "runs": args.runs,
278 "mean_ms": round2(infer_mean),
279 "min_ms": round2(infer_min),
280 "max_ms": round2(infer_max),
281 "std_ms": round2(infer_std),
282 "all_ms": infer_times,
283 },
284 "batch_scaling": batch_scaling,
285 "channel_scaling": channel_scaling,
286 "throughput_epochs_sec": round2(throughput_epochs_sec),
287 });
288
289 if json_mode {
290 println!("{}", serde_json::to_string_pretty(&result)?);
291 }
292
293 Ok(())
294}

pub fn run_batch(&self, batch: &InputBatch<B>) -> Result<EpochEmbedding>

Runs inference on a prepared `InputBatch`.
Returns an `EpochEmbedding` containing the CLS embedding and the patch embeddings.
Examples found in repository?
examples/embed.rs (line 98)
58fn main() -> anyhow::Result<()> {
59 let args = Args::parse();
60 let t0 = Instant::now();
61 let device = backend::device();
62
63 println!("╔══════════════════════════════════════════════════════════════╗");
64 println!("║ OSF — PSG Embedding Extraction ║");
65 println!("╚══════════════════════════════════════════════════════════════╝\n");
66
67 // 1. Load model
68 let model_cfg = if let Some(ref cfg_path) = args.config {
69 let cfg_str = std::fs::read_to_string(cfg_path)?;
70 serde_json::from_str(&cfg_str)?
71 } else {
72 osf_rs::ModelConfig::default()
73 };
74
75 println!("▸ Loading OSF model …");
76 let (encoder, ms_load) = osf_rs::OsfEncoder::<B>::load_with_config(
77 model_cfg,
78 Path::new(&args.weights),
79 device.clone(),
80 )?;
81 println!(" {} ({ms_load:.0} ms)\n", encoder.describe());
82
83 // 2. Generate synthetic PSG epochs
84 let n_channels = osf_rs::NUM_PSG_CHANNELS;
85 let n_samples = osf_rs::EPOCH_SAMPLES;
86 let n_epochs = 3;
87
88 println!("▸ Generating {} synthetic PSG epochs ({} ch × {} samples each)\n",
89 n_epochs, n_channels, n_samples);
90
91 let mut all_outputs = Vec::new();
92
93 for epoch_idx in 0..n_epochs {
94 let signal = generate_synthetic_psg(n_channels, n_samples);
95 let batch = osf_rs::build_batch::<B>(signal, n_channels, n_samples, &device);
96
97 let t = Instant::now();
98 let result = encoder.run_batch(&batch)?;
99 let ms = t.elapsed().as_secs_f64() * 1000.0;
100
101 // Stats
102 let cls = &result.cls_emb;
103 let mean: f64 = cls.iter().map(|&v| v as f64).sum::<f64>() / cls.len() as f64;
104 let std: f64 = (cls.iter().map(|&v| { let d = v as f64 - mean; d * d }).sum::<f64>()
105 / cls.len() as f64).sqrt();
106
107 println!(" Epoch {epoch_idx}: cls=[{}] patches=[{},{}] mean={mean:+.4} std={std:.4} {ms:.1}ms",
108 result.embed_dim, result.num_patches, result.embed_dim);
109
110 if args.verbose && epoch_idx == 0 {
111 println!(" First 5 CLS values: {:?}", &cls[..5.min(cls.len())]);
112 }
113
114 all_outputs.push(result);
115 }
116
117 // 3. Save
118 let encoding = osf_rs::EncodingResult {
119 epochs: all_outputs,
120 ms_load,
121 ms_encode: t0.elapsed().as_secs_f64() * 1000.0,
122 };
123
124 if let Some(p) = Path::new(&args.output).parent() {
125 std::fs::create_dir_all(p)?;
126 }
127 encoding.save_safetensors(&args.output)?;
128 println!("\n▸ Saved {} epochs → {}", n_epochs, args.output);
129
130 let ms_total = t0.elapsed().as_secs_f64() * 1000.0;
131 println!("\n Total: {ms_total:.0} ms");
132 Ok(())
133}More examples
examples/benchmark.rs (line 140)
95fn main() -> anyhow::Result<()> {
96 let args = Args::parse();
97 let dev = device();
98 let json_mode = args.json;
99
100 if !json_mode {
101 eprintln!("╔══════════════════════════════════════════════════════════════╗");
102 eprintln!("║ OSF-RS — Inference Benchmark ║");
103 eprintln!("╚══════════════════════════════════════════════════════════════╝\n");
104 eprintln!(" Backend: {}", backend::NAME);
105 }
106
107 // Load config
108 let model_cfg = if let Some(ref cfg_path) = args.config {
109 let s = std::fs::read_to_string(cfg_path)?;
110 serde_json::from_str(&s)?
111 } else {
112 osf_rs::ModelConfig::default()
113 };
114
115 // ── 1. Weight loading benchmark ─────────────────────────────────────────
116 let (encoder, ms_load) = osf_rs::OsfEncoder::<B>::load_with_config(
117 model_cfg.clone(),
118 Path::new(&args.weights),
119 dev.clone(),
120 )?;
121
122 if !json_mode {
123 eprintln!(" Model: {}", encoder.describe());
124 eprintln!(" Load: {ms_load:.0} ms\n");
125 }
126
127 let n_ch = osf_rs::NUM_PSG_CHANNELS; // 12
128 let n_t = osf_rs::EPOCH_SAMPLES; // 1920
129
130 // ── 2. Standard inference benchmark (B=1) ───────────────────────────────
131 if !json_mode {
132 eprintln!(" ▸ Standard inference (B=1, {}ch × {} samples)", n_ch, n_t);
133 }
134
135 let signal = generate_psg(n_ch, n_t, 42);
136 let batch = osf_rs::build_batch::<B>(signal, n_ch, n_t, &dev);
137
138 // Warmup
139 for _ in 0..args.warmup {
140 let _ = encoder.run_batch(&batch)?;
141 }
142
143 // Timed runs
144 let mut infer_times = Vec::with_capacity(args.runs);
145 for _ in 0..args.runs {
146 let (_, ms) = timed(|| encoder.run_batch(&batch));
147 infer_times.push(ms);
148 }
149
150 let infer_mean = infer_times.iter().sum::<f64>() / infer_times.len() as f64;
151 let infer_min = infer_times.iter().cloned().fold(f64::INFINITY, f64::min);
152 let infer_max = infer_times.iter().cloned().fold(0.0f64, f64::max);
153 let infer_std = (infer_times.iter().map(|t| (t - infer_mean).powi(2)).sum::<f64>()
154 / infer_times.len() as f64).sqrt();
155
156 if !json_mode {
157 eprintln!(" mean={infer_mean:.1}ms min={infer_min:.1}ms max={infer_max:.1}ms std={infer_std:.1}ms (n={})",
158 args.runs);
159 }
160
161 // ── 3. Batch size scaling ───────────────────────────────────────────────
162 let batch_sizes = [1, 2, 4, 8, 16];
163 let mut batch_scaling: Vec<serde_json::Value> = Vec::new();
164
165 if !json_mode {
166 eprintln!("\n ▸ Batch size scaling ({}ch × {} samples):", n_ch, n_t);
167 eprintln!(" {:>6} {:>10} {:>12}", "Batch", "Mean (ms)", "Per-epoch (ms)");
168 }
169
170 for &bs in &batch_sizes {
171 // Build a batched input: [bs, 12, 1920]
172 let signal_batch: Vec<f32> = (0..bs).flat_map(|i| generate_psg(n_ch, n_t, 42 + i as u32)).collect();
173 let signal_tensor = burn::tensor::Tensor::<B, 2>::from_data(
174 burn::tensor::TensorData::new(signal_batch, vec![bs * n_ch, n_t]),
175 &dev,
176 ).reshape([bs, n_ch, n_t]);
177
178 // Warmup
179 let _ = encoder.model().forward_encoding(signal_tensor.clone());
180
181 // Timed
182 let mut t_vec = Vec::new();
183 for _ in 0..5.max(args.runs / 2) {
184 let (_, ms) = timed(|| encoder.model().forward_encoding(signal_tensor.clone()));
185 t_vec.push(ms);
186 }
187 let avg = t_vec.iter().sum::<f64>() / t_vec.len() as f64;
188 let per_epoch = avg / bs as f64;
189 let bmin = t_vec.iter().cloned().fold(f64::INFINITY, f64::min);
190 let bmax = t_vec.iter().cloned().fold(0.0f64, f64::max);
191
192 if !json_mode {
193 eprintln!(" {:>6} {:>7.1} ms {:>9.1} ms", bs, avg, per_epoch);
194 }
195
196 batch_scaling.push(serde_json::json!({
197 "batch_size": bs,
198 "mean_ms": round2(avg),
199 "min_ms": round2(bmin),
200 "max_ms": round2(bmax),
201 "per_epoch_ms": round2(per_epoch),
202 "runs": t_vec,
203 }));
204 }
205
206 // ── 4. Throughput (epochs/sec) ──────────────────────────────────────────
207 let throughput_epochs_sec = 1000.0 / infer_mean;
208
209 if !json_mode {
210 eprintln!("\n ▸ Throughput: {throughput_epochs_sec:.1} epochs/sec (single-epoch inference)");
211 }
212
213 // ── 5. Channel subset scaling (robustness test with padding) ────────────
214 let channel_subsets = [2, 4, 6, 8, 10, 12];
215 let mut channel_scaling: Vec<serde_json::Value> = Vec::new();
216
217 if !json_mode {
218 eprintln!("\n ▸ Channel count scaling (zero-padded to 12ch, T={}):", n_t);
219 eprintln!(" {:>6} {:>10}", "Active", "Mean (ms)");
220 }
221
222 for &active_ch in &channel_subsets {
223 // Create signal with `active_ch` active channels, rest zero-padded to 12
224 let mut sig = vec![0.0f32; n_ch * n_t];
225 let active = generate_psg(active_ch, n_t, 100 + active_ch as u32);
226 for ch in 0..active_ch {
227 for t in 0..n_t {
228 sig[ch * n_t + t] = active[ch * n_t + t];
229 }
230 }
231 let b = osf_rs::build_batch::<B>(sig, n_ch, n_t, &dev);
232
233 let _ = encoder.run_batch(&b)?; // warmup
234 let mut t_vec = Vec::new();
235 for _ in 0..5 {
236 let (_, ms) = timed(|| encoder.run_batch(&b));
237 t_vec.push(ms);
238 }
239 let avg = t_vec.iter().sum::<f64>() / t_vec.len() as f64;
240 let cmin = t_vec.iter().cloned().fold(f64::INFINITY, f64::min);
241 let cmax = t_vec.iter().cloned().fold(0.0f64, f64::max);
242
243 if !json_mode {
244 eprintln!(" {:>6} {:>7.1} ms", active_ch, avg);
245 }
246
247 channel_scaling.push(serde_json::json!({
248 "active_channels": active_ch,
249 "total_channels": n_ch,
250 "mean_ms": round2(avg),
251 "min_ms": round2(cmin),
252 "max_ms": round2(cmax),
253 "runs": t_vec,
254 }));
255 }
256
257 if !json_mode { eprintln!(); }
258
259 // ── JSON output ─────────────────────────────────────────────────────────
260 let result = serde_json::json!({
261 "backend": backend::NAME,
262 "model": {
263 "encoder_name": model_cfg.encoder_name,
264 "width": model_cfg.width,
265 "depth": model_cfg.depth,
266 "heads": model_cfg.heads,
267 "num_leads": model_cfg.num_leads,
268 "patch_size_time": model_cfg.patch_size_time,
269 "patch_size_ch": model_cfg.patch_size_ch,
270 "seq_len": model_cfg.seq_len,
271 },
272 "load_ms": round2(ms_load),
273 "inference": {
274 "channels": n_ch,
275 "samples": n_t,
276 "warmup": args.warmup,
277 "runs": args.runs,
278 "mean_ms": round2(infer_mean),
279 "min_ms": round2(infer_min),
280 "max_ms": round2(infer_max),
281 "std_ms": round2(infer_std),
282 "all_ms": infer_times,
283 },
284 "batch_scaling": batch_scaling,
285 "channel_scaling": channel_scaling,
286 "throughput_epochs_sec": round2(throughput_epochs_sec),
287 });
288
289 if json_mode {
290 println!("{}", serde_json::to_string_pretty(&result)?);
291 }
292
293 Ok(())
294}

pub fn run_batches(
    &self,
    batches: &[InputBatch<B>],
) -> Result<Vec<EpochEmbedding>>

Runs inference on multiple batches and returns the resulting embeddings as a `Vec<EpochEmbedding>`.
pub fn model(&self) -> &OsfViT<B>

Returns a reference to the underlying `OsfViT` model (used directly by the benchmark example to call `forward_encoding`).
Examples found in repository?
examples/benchmark.rs (line 179)
95fn main() -> anyhow::Result<()> {
96 let args = Args::parse();
97 let dev = device();
98 let json_mode = args.json;
99
100 if !json_mode {
101 eprintln!("╔══════════════════════════════════════════════════════════════╗");
102 eprintln!("║ OSF-RS — Inference Benchmark ║");
103 eprintln!("╚══════════════════════════════════════════════════════════════╝\n");
104 eprintln!(" Backend: {}", backend::NAME);
105 }
106
107 // Load config
108 let model_cfg = if let Some(ref cfg_path) = args.config {
109 let s = std::fs::read_to_string(cfg_path)?;
110 serde_json::from_str(&s)?
111 } else {
112 osf_rs::ModelConfig::default()
113 };
114
115 // ── 1. Weight loading benchmark ─────────────────────────────────────────
116 let (encoder, ms_load) = osf_rs::OsfEncoder::<B>::load_with_config(
117 model_cfg.clone(),
118 Path::new(&args.weights),
119 dev.clone(),
120 )?;
121
122 if !json_mode {
123 eprintln!(" Model: {}", encoder.describe());
124 eprintln!(" Load: {ms_load:.0} ms\n");
125 }
126
127 let n_ch = osf_rs::NUM_PSG_CHANNELS; // 12
128 let n_t = osf_rs::EPOCH_SAMPLES; // 1920
129
130 // ── 2. Standard inference benchmark (B=1) ───────────────────────────────
131 if !json_mode {
132 eprintln!(" ▸ Standard inference (B=1, {}ch × {} samples)", n_ch, n_t);
133 }
134
135 let signal = generate_psg(n_ch, n_t, 42);
136 let batch = osf_rs::build_batch::<B>(signal, n_ch, n_t, &dev);
137
138 // Warmup
139 for _ in 0..args.warmup {
140 let _ = encoder.run_batch(&batch)?;
141 }
142
143 // Timed runs
144 let mut infer_times = Vec::with_capacity(args.runs);
145 for _ in 0..args.runs {
146 let (_, ms) = timed(|| encoder.run_batch(&batch));
147 infer_times.push(ms);
148 }
149
150 let infer_mean = infer_times.iter().sum::<f64>() / infer_times.len() as f64;
151 let infer_min = infer_times.iter().cloned().fold(f64::INFINITY, f64::min);
152 let infer_max = infer_times.iter().cloned().fold(0.0f64, f64::max);
153 let infer_std = (infer_times.iter().map(|t| (t - infer_mean).powi(2)).sum::<f64>()
154 / infer_times.len() as f64).sqrt();
155
156 if !json_mode {
157 eprintln!(" mean={infer_mean:.1}ms min={infer_min:.1}ms max={infer_max:.1}ms std={infer_std:.1}ms (n={})",
158 args.runs);
159 }
160
161 // ── 3. Batch size scaling ───────────────────────────────────────────────
162 let batch_sizes = [1, 2, 4, 8, 16];
163 let mut batch_scaling: Vec<serde_json::Value> = Vec::new();
164
165 if !json_mode {
166 eprintln!("\n ▸ Batch size scaling ({}ch × {} samples):", n_ch, n_t);
167 eprintln!(" {:>6} {:>10} {:>12}", "Batch", "Mean (ms)", "Per-epoch (ms)");
168 }
169
170 for &bs in &batch_sizes {
171 // Build a batched input: [bs, 12, 1920]
172 let signal_batch: Vec<f32> = (0..bs).flat_map(|i| generate_psg(n_ch, n_t, 42 + i as u32)).collect();
173 let signal_tensor = burn::tensor::Tensor::<B, 2>::from_data(
174 burn::tensor::TensorData::new(signal_batch, vec![bs * n_ch, n_t]),
175 &dev,
176 ).reshape([bs, n_ch, n_t]);
177
178 // Warmup
179 let _ = encoder.model().forward_encoding(signal_tensor.clone());
180
181 // Timed
182 let mut t_vec = Vec::new();
183 for _ in 0..5.max(args.runs / 2) {
184 let (_, ms) = timed(|| encoder.model().forward_encoding(signal_tensor.clone()));
185 t_vec.push(ms);
186 }
187 let avg = t_vec.iter().sum::<f64>() / t_vec.len() as f64;
188 let per_epoch = avg / bs as f64;
189 let bmin = t_vec.iter().cloned().fold(f64::INFINITY, f64::min);
190 let bmax = t_vec.iter().cloned().fold(0.0f64, f64::max);
191
192 if !json_mode {
193 eprintln!(" {:>6} {:>7.1} ms {:>9.1} ms", bs, avg, per_epoch);
194 }
195
196 batch_scaling.push(serde_json::json!({
197 "batch_size": bs,
198 "mean_ms": round2(avg),
199 "min_ms": round2(bmin),
200 "max_ms": round2(bmax),
201 "per_epoch_ms": round2(per_epoch),
202 "runs": t_vec,
203 }));
204 }
205
206 // ── 4. Throughput (epochs/sec) ──────────────────────────────────────────
207 let throughput_epochs_sec = 1000.0 / infer_mean;
208
209 if !json_mode {
210 eprintln!("\n ▸ Throughput: {throughput_epochs_sec:.1} epochs/sec (single-epoch inference)");
211 }
212
213 // ── 5. Channel subset scaling (robustness test with padding) ────────────
214 let channel_subsets = [2, 4, 6, 8, 10, 12];
215 let mut channel_scaling: Vec<serde_json::Value> = Vec::new();
216
217 if !json_mode {
218 eprintln!("\n ▸ Channel count scaling (zero-padded to 12ch, T={}):", n_t);
219 eprintln!(" {:>6} {:>10}", "Active", "Mean (ms)");
220 }
221
222 for &active_ch in &channel_subsets {
223 // Create signal with `active_ch` active channels, rest zero-padded to 12
224 let mut sig = vec![0.0f32; n_ch * n_t];
225 let active = generate_psg(active_ch, n_t, 100 + active_ch as u32);
226 for ch in 0..active_ch {
227 for t in 0..n_t {
228 sig[ch * n_t + t] = active[ch * n_t + t];
229 }
230 }
231 let b = osf_rs::build_batch::<B>(sig, n_ch, n_t, &dev);
232
233 let _ = encoder.run_batch(&b)?; // warmup
234 let mut t_vec = Vec::new();
235 for _ in 0..5 {
236 let (_, ms) = timed(|| encoder.run_batch(&b));
237 t_vec.push(ms);
238 }
239 let avg = t_vec.iter().sum::<f64>() / t_vec.len() as f64;
240 let cmin = t_vec.iter().cloned().fold(f64::INFINITY, f64::min);
241 let cmax = t_vec.iter().cloned().fold(0.0f64, f64::max);
242
243 if !json_mode {
244 eprintln!(" {:>6} {:>7.1} ms", active_ch, avg);
245 }
246
247 channel_scaling.push(serde_json::json!({
248 "active_channels": active_ch,
249 "total_channels": n_ch,
250 "mean_ms": round2(avg),
251 "min_ms": round2(cmin),
252 "max_ms": round2(cmax),
253 "runs": t_vec,
254 }));
255 }
256
257 if !json_mode { eprintln!(); }
258
259 // ── JSON output ─────────────────────────────────────────────────────────
260 let result = serde_json::json!({
261 "backend": backend::NAME,
262 "model": {
263 "encoder_name": model_cfg.encoder_name,
264 "width": model_cfg.width,
265 "depth": model_cfg.depth,
266 "heads": model_cfg.heads,
267 "num_leads": model_cfg.num_leads,
268 "patch_size_time": model_cfg.patch_size_time,
269 "patch_size_ch": model_cfg.patch_size_ch,
270 "seq_len": model_cfg.seq_len,
271 },
272 "load_ms": round2(ms_load),
273 "inference": {
274 "channels": n_ch,
275 "samples": n_t,
276 "warmup": args.warmup,
277 "runs": args.runs,
278 "mean_ms": round2(infer_mean),
279 "min_ms": round2(infer_min),
280 "max_ms": round2(infer_max),
281 "std_ms": round2(infer_std),
282 "all_ms": infer_times,
283 },
284 "batch_scaling": batch_scaling,
285 "channel_scaling": channel_scaling,
286 "throughput_epochs_sec": round2(throughput_epochs_sec),
287 });
288
289 if json_mode {
290 println!("{}", serde_json::to_string_pretty(&result)?);
291 }
292
293 Ok(())
294}
pub fn device(&self) -> &B::Device
Auto Trait Implementations§
impl<B> !Freeze for OsfEncoder<B>
impl<B> !RefUnwindSafe for OsfEncoder<B>
impl<B> Send for OsfEncoder<B>
impl<B> !Sync for OsfEncoder<B>
impl<B> Unpin for OsfEncoder<B>
where
<B as Backend>::Device: Unpin,
<B as Backend>::FloatTensorPrimitive: Unpin,
<B as Backend>::QuantizedTensorPrimitive: Unpin,
impl<B> UnsafeUnpin for OsfEncoder<B>
where
<B as Backend>::Device: UnsafeUnpin,
<B as Backend>::FloatTensorPrimitive: UnsafeUnpin,
<B as Backend>::QuantizedTensorPrimitive: UnsafeUnpin,
impl<B> !UnwindSafe for OsfEncoder<B>
Blanket Implementations§
Source§impl<T> BorrowMut<T> for T
where
T: ?Sized,
impl<T> BorrowMut<T> for T
where
T: ?Sized,
Source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value. Read more
Source§impl<T> IntoEither for T
impl<T> IntoEither for T
Source§fn into_either(self, into_left: bool) -> Either<Self, Self>
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true.
Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true.
Converts self into a Right variant of Either<Self, Self> otherwise. Read more