1#![doc = include_str!("../README.md")]
2
3use std::fmt::Write;
4
5use fundsp::prelude::*;
6use insta::assert_binary_snapshot;
7
/// Default height, in SVG user units, of each channel's row in the rendered plot.
/// Used when `SnapshotConfig::svg_height_per_channel` is `None`.
const DEFAULT_HEIGHT: usize = 100;
9
/// Rendering and plotting options for a node snapshot.
///
/// `SnapshotConfig::default()` renders 44100 samples at 44100 Hz, one `tick` per sample.
#[derive(Debug, Clone, Copy)]
pub struct SnapshotConfig {
    /// Number of samples rendered for every input and output channel (default: `44100`).
    pub num_samples: usize,
    /// Sample rate handed to the node via `set_sample_rate` before rendering
    /// (default: `44100.0`).
    pub sample_rate: f64,
    /// Width of the SVG plot in user units. `None` places one unit per sample (the plot is
    /// `num_samples` wide); `Some(w)` scales the x axis by `w / num_samples` and adjusts the
    /// stroke width accordingly.
    pub svg_width: Option<usize>,
    /// Height of each channel's row in the SVG. `None` falls back to `DEFAULT_HEIGHT`
    /// (the default is `Some(DEFAULT_HEIGHT)`).
    pub svg_height_per_channel: Option<usize>,
    /// How the node is driven while rendering (default: [`Processing::Tick`]).
    pub processing_mode: Processing,
}
32
/// How the node under test is driven while rendering a snapshot.
#[derive(Debug, Clone, Copy, Default)]
pub enum Processing {
    /// Call `AudioUnit::tick` once per sample (the default).
    #[default]
    Tick,
    /// Call `AudioUnit::process` on consecutive blocks of this many samples.
    /// Values above 64 are rejected with a panic when the snapshot is taken.
    Batch(u8),
}
44
45impl Default for SnapshotConfig {
46 fn default() -> Self {
47 Self {
48 num_samples: 44100,
49 sample_rate: 44100.0,
50 svg_width: None,
51 svg_height_per_channel: Some(DEFAULT_HEIGHT),
52 processing_mode: Processing::default(),
53 }
54 }
55}
56
57impl SnapshotConfig {
58 pub fn with_samples(num_samples: usize) -> Self {
59 Self {
60 num_samples,
61 ..Default::default()
62 }
63 }
64}
65
/// Source of the samples fed to the node's inputs while rendering a snapshot.
///
/// `VecByChannel`, `VecByTick` and `Flat` are checked against the node's input channel
/// count (and, where applicable, `SnapshotConfig::num_samples`); a mismatch panics.
pub enum InputSource {
    /// Every input channel receives `0.0` for every sample.
    None,
    /// Channel-major data, indexed `data[channel][sample]`: one `Vec` per input channel,
    /// each exactly `num_samples` long.
    VecByChannel(Vec<Vec<f32>>),
    /// Sample-major data, indexed `data[sample][channel]`: exactly `num_samples` frames,
    /// each holding one value per input channel.
    VecByTick(Vec<Vec<f32>>),
    /// One constant value per input channel, held for every sample.
    Flat(Vec<f32>),
    /// Called as `f(sample_index, channel)` to produce every input sample.
    Generator(Box<dyn Fn(usize, usize) -> f32>),
}
90
91impl InputSource {
92 pub fn impulse() -> Self {
93 Self::Generator(Box::new(|i, _| if i == 0 { 1.0 } else { 0.0 }))
94 }
95 pub fn sine(freq: f32, sr: f32) -> Self {
96 Self::Generator(Box::new(move |i, _| {
97 let phase = 2.0 * std::f32::consts::PI * freq * i as f32 / sr;
98 phase.sin()
99 }))
100 }
101}
102
/// Stroke and label colours for output channels, picked by `channel % len`.
const OUTPUT_CHANNEL_COLORS: &[&str] = &[
    "#4285F4", "#EA4335", "#FBBC04", "#34A853", "#FF6D00", "#AB47BC", "#00ACC1", "#7CB342",
    "#9C27B0", "#3F51B5", "#009688", "#8BC34A", "#FFEB3B", "#FF9800", "#795548", "#607D8B",
    "#E91E63", "#673AB7", "#2196F3", "#00BCD4", "#4CAF50", "#CDDC39", "#FFC107", "#FF5722",
    "#9E9E9E", "#03A9F4", "#8D6E63", "#78909C", "#880E4F", "#4A148C", "#0D47A1", "#004D40",
];
109
/// Stroke and label colours for input channels, picked by `channel % len`.
// NOTE(review): a few entries repeat ("#B39DDB" at 0 and 17, "#A5D6A7" at 3 and 20,
// "#80DEEA" at 6 and 19), so some input channels share a colour even below 32 channels.
const INPUT_CHANNEL_COLORS: &[&str] = &[
    "#B39DDB", "#FFAB91", "#FFF59D", "#A5D6A7", "#FFCC80", "#CE93D8", "#80DEEA", "#C5E1A5",
    "#BA68C8", "#9FA8DA", "#80CBC4", "#DCE775", "#FFF176", "#FFB74D", "#BCAAA4", "#B0BEC5",
    "#F48FB1", "#B39DDB", "#90CAF9", "#80DEEA", "#A5D6A7", "#E6EE9C", "#FFD54F", "#FF8A65",
    "#BDBDBD", "#81D4FA", "#A1887F", "#90A4AE", "#C2185B", "#7B1FA2", "#1976D2", "#00796B",
];
116
/// Padding, in SVG user units, around the plotted area; sets the viewBox origin and
/// the size of the black background rectangle.
const PADDING: isize = 10;
118
119pub fn snapshot_audio_node<N>(name: &str, node: N)
130where
131 N: AudioUnit,
132{
133 snapshot_audionode_with_input_and_options(
134 name,
135 node,
136 InputSource::None,
137 SnapshotConfig::default(),
138 )
139}
140
141pub fn snapshot_audio_node_with_options<N>(name: &str, node: N, options: SnapshotConfig)
153where
154 N: AudioUnit,
155{
156 snapshot_audionode_with_input_and_options(name, node, InputSource::None, options)
157}
158
159pub fn snapshot_audio_node_with_input<N>(name: &str, node: N, input_source: InputSource)
171where
172 N: AudioUnit,
173{
174 snapshot_audionode_with_input_and_options(name, node, input_source, SnapshotConfig::default())
175}
176
177pub fn snapshot_audionode_with_input_and_options<N>(
190 name: &str,
191 mut node: N,
192 input_source: InputSource,
193 config: SnapshotConfig,
194) where
195 N: AudioUnit,
196{
197 let num_inputs = N::inputs(&node);
198 let num_outputs = N::outputs(&node);
199
200 node.set_sample_rate(config.sample_rate);
201 node.reset();
202 node.allocate();
203
204 let input_data = match input_source {
205 InputSource::None => vec![vec![0.0; config.num_samples]; num_inputs],
206 InputSource::VecByChannel(data) => {
207 assert_eq!(
208 data.len(),
209 num_inputs,
210 "Input vec size mismatch. Expected {} channels, got {}",
211 num_inputs,
212 data.len()
213 );
214 assert!(
215 data.iter().all(|v| v.len() == config.num_samples),
216 "Input vec size mismatch. Expected {} samples per channel, got {}",
217 config.num_samples,
218 data.iter().map(|v| v.len()).max().unwrap_or(0)
219 );
220 data
221 }
222 InputSource::VecByTick(data) => {
223 assert!(
224 data.iter().all(|v| v.len() == num_inputs),
225 "Input vec size mismatch. Expected {} channels, got {}",
226 num_inputs,
227 data.iter().map(|v| v.len()).max().unwrap_or(0)
228 );
229 assert_eq!(
230 data.len(),
231 config.num_samples,
232 "Input vec size mismatch. Expected {} samples, got {}",
233 config.num_samples,
234 data.len()
235 );
236 (0..num_inputs)
237 .map(|ch| (0..config.num_samples).map(|i| data[i][ch]).collect())
238 .collect()
239 }
240 InputSource::Flat(data) => {
241 assert_eq!(
242 data.len(),
243 num_inputs,
244 "Input vec size mismatch. Expected {} channels, got {}",
245 num_inputs,
246 data.len()
247 );
248 (0..num_inputs)
249 .map(|ch| (0..config.num_samples).map(|_| data[ch]).collect())
250 .collect()
251 }
252 InputSource::Generator(generator_fn) => (0..num_inputs)
253 .map(|ch| {
254 (0..config.num_samples)
255 .map(|i| generator_fn(i, ch))
256 .collect()
257 })
258 .collect(),
259 };
260
261 let mut output_data: Vec<Vec<f32>> = vec![vec![]; num_outputs];
262
263 match config.processing_mode {
264 Processing::Tick => {
265 (0..config.num_samples).for_each(|i| {
266 let mut input_frame = vec![0.0; num_inputs];
267 for ch in 0..num_inputs {
268 input_frame[ch] = input_data[ch][i] as f32;
269 }
270 let mut output_frame = vec![0.0; num_outputs];
271 node.tick(&input_frame, &mut output_frame);
272 for ch in 0..num_outputs {
273 output_data[ch].push(output_frame[ch]);
274 }
275 });
276 }
277 Processing::Batch(batch_size) => {
278 assert!(
279 batch_size <= 64,
280 "Batch size must be less than or equal to 64"
281 );
282
283 let samples_index = (0..config.num_samples).collect::<Vec<_>>();
284 for chunk in samples_index.chunks(batch_size as usize) {
285 let mut input_buff = BufferVec::new(num_inputs);
286 for i in chunk {
287 for (ch, input_data) in input_data.iter().enumerate() {
288 let value: f32 = input_data[*i];
289 input_buff.set_f32(ch, *i, value);
290 }
291 }
292 let input_ref = input_buff.buffer_ref();
293 let mut output_buf = BufferVec::new(num_outputs);
294 let mut output_ref = output_buf.buffer_mut();
295
296 node.process(chunk.len(), &input_ref, &mut output_ref);
297
298 for (ch, data) in output_data.iter_mut().enumerate() {
299 data.extend_from_slice(output_buf.channel_f32(ch));
300 }
301 }
302 }
303 }
304
305 let svg = generate_svg(&input_data, &output_data, &config);
306
307 assert_binary_snapshot!(&format!("{name}.svg"), svg.as_bytes().to_vec());
308}
309
310fn generate_svg(
311 input_data: &[Vec<f32>],
312 output_data: &[Vec<f32>],
313 config: &SnapshotConfig,
314) -> String {
315 let height_per_channel = config.svg_height_per_channel.unwrap_or(DEFAULT_HEIGHT);
316 let num_channels = output_data.len() + input_data.len();
317 let num_samples = output_data.first().map(|c| c.len()).unwrap_or(0);
318 if num_samples == 0 || num_channels == 0 {
319 return "<svg xmlns=\"http://www.w3.org/2000/svg\" viewBox=\"0 0 100 100\" preserveAspectRatio=\"none\"><text>Empty</text></svg>".to_string();
320 }
321
322 let svg_width = config.svg_width.unwrap_or(config.num_samples);
323 let total_height = height_per_channel * num_channels;
324 let y_scale = (height_per_channel as f32 / 2.0) * 0.9;
325 let x_scale = config
326 .svg_width
327 .map(|width| width as f32 / config.num_samples as f32);
328 let stroke_width = if let Some(scale) = x_scale {
329 (2.0 / scale).clamp(0.5, 5.0)
330 } else {
331 2.0
332 };
333
334 let mut svg = String::new();
335 let mut y_offset = 0;
336
337 writeln!(
338 &mut svg,
339 r#"<svg xmlns="http://www.w3.org/2000/svg" viewBox="{start_x} {start_y} {width} {height}" preserveAspectRatio="none">
340 <rect x="{start_x}" y="{start_y}" width="{background_width}" height="{background_height}" fill="black" />"#,
341 start_x = -PADDING,
342 start_y = -PADDING,
343 width = svg_width as isize + PADDING,
344 height = total_height as isize + PADDING,
345 background_width = svg_width as isize + PADDING * 2,
346 background_height = total_height as isize + PADDING * 2
347 ).unwrap();
348
349 let mut write_data = |all_channels_data: &[Vec<f32>], is_input: bool| {
350 for (ch, data) in all_channels_data.iter().enumerate() {
351 let color = if is_input {
352 INPUT_CHANNEL_COLORS[ch % INPUT_CHANNEL_COLORS.len()]
353 } else {
354 OUTPUT_CHANNEL_COLORS[ch % OUTPUT_CHANNEL_COLORS.len()]
355 };
356 let y_center = y_offset + height_per_channel / 2;
357
358 let min_val = data.iter().cloned().fold(f32::INFINITY, f32::min);
359 let max_val = data.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
360 let range = (max_val - min_val).max(f32::EPSILON);
361
362 let mut path_data = String::from("M ");
363 for (i, &sample) in data.iter().enumerate() {
364 let x = if let Some(scale) = x_scale {
365 scale * i as f32
366 } else {
367 i as f32
368 };
369 let normalized = (sample.clamp(min_val, max_val) - min_val) / range * 2.0 - 1.0;
370 let y = y_center as f32 - normalized * y_scale;
371 if i == 0 {
372 write!(&mut path_data, "{:.6},{:.6} ", x, y).unwrap();
373 } else {
374 write!(&mut path_data, "L {:.6},{:.6} ", x, y).unwrap();
375 }
376 }
377
378 writeln!(
379 &mut svg,
380 r#" <path d="{path_data}" fill="none" stroke="{color}" stroke-width="{stroke_width}"/>"#,
381 )
382 .unwrap();
383
384 writeln!(
385 &mut svg,
386 r#" <text x="5" y="{y}" font-family="monospace" font-size="12" fill="{color}">{label} Ch#{ch}</text>"#,
387 y = y_offset + 15,
388 color = color,
389 label = if is_input {"Input"} else {"Output"},
390 ch=ch
391 )
392 .unwrap();
393
394 y_offset += height_per_channel
395 }
396 };
397
398 write_data(input_data, true);
399 write_data(output_data, false);
400
401 svg.push_str("</svg>");
402 svg
403}
404
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sine() {
        let config = SnapshotConfig::default();
        let node = sine_hz::<f32>(440.0);
        snapshot_audionode_with_input_and_options("sine_hz", node, InputSource::None, config);
    }

    #[test]
    fn test_custom_input() {
        let config = SnapshotConfig::with_samples(100);
        let input = (0..100).map(|i| (i as f32 / 50.0).sin()).collect();

        snapshot_audionode_with_input_and_options(
            "filter_sine",
            lowpass_hz(500.0, 0.7),
            InputSource::VecByChannel(vec![input]),
            config,
        );
    }

    #[test]
    fn test_stereo() {
        let config = SnapshotConfig::default();
        let node = sine_hz::<f32>(440.0) | sine_hz::<f32>(880.0);

        snapshot_audionode_with_input_and_options("stereo", node, InputSource::None, config);
    }

    #[test]
    fn test_lowpass_impulse() {
        let config = SnapshotConfig::with_samples(300);
        let node = lowpass_hz(1000.0, 1.0);

        snapshot_audionode_with_input_and_options(
            "lowpass_impulse",
            node,
            InputSource::impulse(),
            config,
        );
    }

    #[test]
    fn test_net() {
        let config = SnapshotConfig::with_samples(420);
        let node = sine_hz::<f32>(440.0) >> lowpass_hz(500.0, 0.7);
        let mut net = Net::new(0, 1);
        let node_id = net.push(Box::new(node));
        net.pipe_input(node_id);
        net.pipe_output(node_id);

        snapshot_audionode_with_input_and_options("net", net, InputSource::None, config);
    }

    // The snapshot name comes from the string argument, not the test function name,
    // so fixing the function-name typo does not affect stored snapshots.
    #[test]
    fn test_batch_processing() {
        let config = SnapshotConfig {
            processing_mode: Processing::Batch(64),
            ..Default::default()
        };

        let node = sine_hz::<f32>(440.0);

        snapshot_audio_node_with_options("process_64", node, config);
    }

    #[test]
    #[should_panic(expected = "Batch size must be less than or equal to 64")]
    fn test_batch_size_too_large_panics() {
        let config = SnapshotConfig {
            processing_mode: Processing::Batch(65),
            ..Default::default()
        };

        snapshot_audio_node_with_options("batch_too_large", sine_hz::<f32>(440.0), config);
    }

    #[test]
    fn test_vec_by_tick() {
        let config = SnapshotConfig::with_samples(100);
        let input_data: Vec<Vec<f32>> = (0..100).map(|i| vec![(i as f32 / 50.0).cos()]).collect();

        snapshot_audionode_with_input_and_options(
            "vec_by_tick",
            lowpass_hz(800.0, 0.5),
            InputSource::VecByTick(input_data),
            config,
        );
    }

    #[test]
    fn test_flat_input() {
        let config = SnapshotConfig::with_samples(200);
        let flat_input = vec![0.5];

        snapshot_audionode_with_input_and_options(
            "flat_input",
            highpass_hz(200.0, 0.7),
            InputSource::Flat(flat_input),
            config,
        );
    }

    #[test]
    fn test_sine_input_source() {
        let config = SnapshotConfig::with_samples(200);

        snapshot_audionode_with_input_and_options(
            "sine_input_source",
            bandpass_hz(1000.0, 500.0),
            InputSource::sine(100.0, 44100.0),
            config,
        );
    }

    #[test]
    fn test_multi_channel_vec_by_channel() {
        let config = SnapshotConfig::with_samples(150);
        let left_channel: Vec<f32> = (0..150)
            .map(|i| (i as f32 / 75.0 * std::f32::consts::PI).sin())
            .collect();
        let right_channel: Vec<f32> = (0..150)
            .map(|i| (i as f32 / 75.0 * std::f32::consts::PI).cos())
            .collect();

        let node = resonator_hz(440.0, 100.0) | resonator_hz(440.0, 100.0);

        snapshot_audionode_with_input_and_options(
            "multi_channel_vec",
            node,
            InputSource::VecByChannel(vec![left_channel, right_channel]),
            config,
        );
    }

    // Plain unit tests below do not create snapshots.

    #[test]
    fn test_with_samples_keeps_other_defaults() {
        let config = SnapshotConfig::with_samples(123);
        let defaults = SnapshotConfig::default();
        assert_eq!(config.num_samples, 123);
        assert_eq!(config.sample_rate, defaults.sample_rate);
        assert_eq!(config.svg_width, defaults.svg_width);
        assert_eq!(config.svg_height_per_channel, defaults.svg_height_per_channel);
        assert!(matches!(config.processing_mode, Processing::Tick));
    }

    #[test]
    fn test_impulse_generator() {
        match InputSource::impulse() {
            InputSource::Generator(generator) => {
                assert_eq!(generator(0, 0), 1.0);
                assert_eq!(generator(0, 1), 1.0);
                assert_eq!(generator(1, 0), 0.0);
                assert_eq!(generator(99, 1), 0.0);
            }
            _ => panic!("impulse() should produce a generator"),
        }
    }

    #[test]
    fn test_generate_svg_empty() {
        let svg = generate_svg(&[], &[], &SnapshotConfig::default());
        assert!(svg.contains("<text>Empty</text>"));
    }
}