insta_fun/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::fmt::Write;
4
5use fundsp::prelude::*;
6use insta::assert_binary_snapshot;
7
/// Default height, in SVG `viewBox` units, of one channel lane; used when
/// `SnapshotConfig::svg_height_per_channel` is `None`.
const DEFAULT_HEIGHT: usize = 100;
9
10#[derive(Debug, Clone, Copy)]
11/// Configuration for snapshotting an audio node.
12pub struct SnapshotConfig {
13    /// Number of samples to generate.
14    ///
15    /// `Default` is 44100 - 1s
16    pub num_samples: usize,
17    /// Sample rate of the audio node.
18    ///
19    /// `Default` is 44100.0
20    pub sample_rate: f64,
21    /// Optional width of the SVG `viewBox`
22    ///
23    /// `None` means proportional to num_samples
24    pub svg_width: Option<usize>,
25    /// Height of **one** channel in the SVG `viewBox`
26    ///
27    /// `None` fallbacks to default - 100
28    pub svg_height_per_channel: Option<usize>,
29    /// Processing mode for snapshotting an audio node.
30    pub processing_mode: Processing,
31}
32
/// Processing mode for snapshotting an audio node.
#[derive(Debug, Clone, Copy, Default)]
pub enum Processing {
    /// Feed the node one sample per call (`tick`). This is the default mode.
    #[default]
    Tick,
    /// Feed the node blocks of up to the given number of samples (`process`).
    ///
    /// The batch size must be between 1 and 64.
    Batch(u8),
}
44
45impl Default for SnapshotConfig {
46    fn default() -> Self {
47        Self {
48            num_samples: 44100,
49            sample_rate: 44100.0,
50            svg_width: None,
51            svg_height_per_channel: Some(DEFAULT_HEIGHT),
52            processing_mode: Processing::default(),
53        }
54    }
55}
56
57impl SnapshotConfig {
58    pub fn with_samples(num_samples: usize) -> Self {
59        Self {
60            num_samples,
61            ..Default::default()
62        }
63    }
64}
65
/// Source of the samples fed into the audio node's inputs.
pub enum InputSource {
    /// Every input channel receives silence (`0.0`).
    None,
    /// Channel-major data, indexed as `data[channel][sample]`.
    ///
    /// - The outer vec holds one entry per **channel**
    /// - Each inner vec holds that channel's **samples**
    VecByChannel(Vec<Vec<f32>>),
    /// Tick-major data, indexed as `data[tick][channel]`.
    ///
    /// - The outer vec holds one entry per **tick**
    /// - Each inner vec holds the samples of all **channels** for that tick
    VecByTick(Vec<Vec<f32>>),
    /// A single frame that is **repeated** on every tick.
    ///
    /// - The vec holds one sample per **channel**
    Flat(Vec<f32>),
    /// Samples computed on demand by a closure.
    ///
    /// - The first argument is the sample index
    /// - The second argument is the channel index
    Generator(Box<dyn Fn(usize, usize) -> f32>),
}
90
91impl InputSource {
92    pub fn impulse() -> Self {
93        Self::Generator(Box::new(|i, _| if i == 0 { 1.0 } else { 0.0 }))
94    }
95    pub fn sine(freq: f32, sr: f32) -> Self {
96        Self::Generator(Box::new(move |i, _| {
97            let phase = 2.0 * std::f32::consts::PI * freq * i as f32 / sr;
98            phase.sin()
99        }))
100    }
101}
102
/// Stroke and label colors for output channels. Channel `ch` uses entry
/// `ch % OUTPUT_CHANNEL_COLORS.len()`, so the palette repeats after 32 channels.
const OUTPUT_CHANNEL_COLORS: &[&str] = &[
    "#4285F4", "#EA4335", "#FBBC04", "#34A853", "#FF6D00", "#AB47BC", "#00ACC1", "#7CB342",
    "#9C27B0", "#3F51B5", "#009688", "#8BC34A", "#FFEB3B", "#FF9800", "#795548", "#607D8B",
    "#E91E63", "#673AB7", "#2196F3", "#00BCD4", "#4CAF50", "#CDDC39", "#FFC107", "#FF5722",
    "#9E9E9E", "#03A9F4", "#8D6E63", "#78909C", "#880E4F", "#4A148C", "#0D47A1", "#004D40",
];
109
/// Stroke and label colors for input channels. Channel `ch` uses entry
/// `ch % INPUT_CHANNEL_COLORS.len()`.
///
/// NOTE(review): `#B39DDB` (indices 0 and 17), `#A5D6A7` (3 and 20) and
/// `#80DEEA` (6 and 19) each appear twice, so those channel pairs share a
/// color. Changing entries would alter existing snapshots, so it is left as is.
const INPUT_CHANNEL_COLORS: &[&str] = &[
    "#B39DDB", "#FFAB91", "#FFF59D", "#A5D6A7", "#FFCC80", "#CE93D8", "#80DEEA", "#C5E1A5",
    "#BA68C8", "#9FA8DA", "#80CBC4", "#DCE775", "#FFF176", "#FFB74D", "#BCAAA4", "#B0BEC5",
    "#F48FB1", "#B39DDB", "#90CAF9", "#80DEEA", "#A5D6A7", "#E6EE9C", "#FFD54F", "#FF8A65",
    "#BDBDBD", "#81D4FA", "#A1887F", "#90A4AE", "#C2185B", "#7B1FA2", "#1976D2", "#00796B",
];
116
/// Margin, in SVG `viewBox` units, in front of the plot. The `viewBox` starts
/// at `(-PADDING, -PADDING)` and the background rect is enlarged by
/// `2 * PADDING` in each dimension.
const PADDING: isize = 10;
118
119/// Create an SVG snapshot of audio node outputs
120/// ## Example
121///
122/// ```
123/// use insta_fun::*;
124/// use fundsp::hacker::prelude::*;
125///
126/// let node = sine_hz::<f32>(440.0);
127/// snapshot_audio_node("sine_hz_4", node);
128/// ```
129pub fn snapshot_audio_node<N>(name: &str, node: N)
130where
131    N: AudioUnit,
132{
133    snapshot_audionode_with_input_and_options(
134        name,
135        node,
136        InputSource::None,
137        SnapshotConfig::default(),
138    )
139}
140
141/// Create an SVG snapshot of audio node outputs, with options
142///
143/// ## Example
144///
145/// ```
146/// use insta_fun::*;
147/// use fundsp::hacker::prelude::*;
148///
149/// let node = sine_hz::<f32>(440.0);
150/// snapshot_audio_node_with_options("sine_hz_3", node, SnapshotConfig::default());
151/// ```
152pub fn snapshot_audio_node_with_options<N>(name: &str, node: N, options: SnapshotConfig)
153where
154    N: AudioUnit,
155{
156    snapshot_audionode_with_input_and_options(name, node, InputSource::None, options)
157}
158
159/// Create an SVG snapshot of audio node inputs and outputs
160///
161/// ## Example
162///
163/// ```
164/// use insta_fun::*;
165/// use fundsp::hacker::prelude::*;
166///
167/// let node = sine_hz::<f32>(440.0);
168/// snapshot_audio_node_with_input("sine_hz_2", node, InputSource::None);
169/// ```
170pub fn snapshot_audio_node_with_input<N>(name: &str, node: N, input_source: InputSource)
171where
172    N: AudioUnit,
173{
174    snapshot_audionode_with_input_and_options(name, node, input_source, SnapshotConfig::default())
175}
176
177/// Create an SVG snapshot of audio node inputs and outputs, with options
178///
179/// ## Example
180///
181/// ```
182/// use insta_fun::*;
183/// use fundsp::hacker::prelude::*;
184///
185/// let config = SnapshotConfig::default();
186/// let node = sine_hz::<f32>(440.0);
187/// snapshot_audionode_with_input_and_options("sine_hz_1", node, InputSource::None, config);
188/// ```
189pub fn snapshot_audionode_with_input_and_options<N>(
190    name: &str,
191    mut node: N,
192    input_source: InputSource,
193    config: SnapshotConfig,
194) where
195    N: AudioUnit,
196{
197    let num_inputs = N::inputs(&node);
198    let num_outputs = N::outputs(&node);
199
200    node.set_sample_rate(config.sample_rate);
201    node.reset();
202    node.allocate();
203
204    let input_data = match input_source {
205        InputSource::None => vec![vec![0.0; config.num_samples]; num_inputs],
206        InputSource::VecByChannel(data) => {
207            assert_eq!(
208                data.len(),
209                num_inputs,
210                "Input vec size mismatch. Expected {} channels, got {}",
211                num_inputs,
212                data.len()
213            );
214            assert!(
215                data.iter().all(|v| v.len() == config.num_samples),
216                "Input vec size mismatch. Expected {} samples per channel, got {}",
217                config.num_samples,
218                data.iter().map(|v| v.len()).max().unwrap_or(0)
219            );
220            data
221        }
222        InputSource::VecByTick(data) => {
223            assert!(
224                data.iter().all(|v| v.len() == num_inputs),
225                "Input vec size mismatch. Expected {} channels, got {}",
226                num_inputs,
227                data.iter().map(|v| v.len()).max().unwrap_or(0)
228            );
229            assert_eq!(
230                data.len(),
231                config.num_samples,
232                "Input vec size mismatch. Expected {} samples, got {}",
233                config.num_samples,
234                data.len()
235            );
236            (0..num_inputs)
237                .map(|ch| (0..config.num_samples).map(|i| data[i][ch]).collect())
238                .collect()
239        }
240        InputSource::Flat(data) => {
241            assert_eq!(
242                data.len(),
243                num_inputs,
244                "Input vec size mismatch. Expected {} channels, got {}",
245                num_inputs,
246                data.len()
247            );
248            (0..num_inputs)
249                .map(|ch| (0..config.num_samples).map(|_| data[ch]).collect())
250                .collect()
251        }
252        InputSource::Generator(generator_fn) => (0..num_inputs)
253            .map(|ch| {
254                (0..config.num_samples)
255                    .map(|i| generator_fn(i, ch))
256                    .collect()
257            })
258            .collect(),
259    };
260
261    let mut output_data: Vec<Vec<f32>> = vec![vec![]; num_outputs];
262
263    match config.processing_mode {
264        Processing::Tick => {
265            (0..config.num_samples).for_each(|i| {
266                let mut input_frame = vec![0.0; num_inputs];
267                for ch in 0..num_inputs {
268                    input_frame[ch] = input_data[ch][i] as f32;
269                }
270                let mut output_frame = vec![0.0; num_outputs];
271                node.tick(&input_frame, &mut output_frame);
272                for ch in 0..num_outputs {
273                    output_data[ch].push(output_frame[ch]);
274                }
275            });
276        }
277        Processing::Batch(batch_size) => {
278            assert!(
279                batch_size <= 64,
280                "Batch size must be less than or equal to 64"
281            );
282
283            let samples_index = (0..config.num_samples).collect::<Vec<_>>();
284            for chunk in samples_index.chunks(batch_size as usize) {
285                let mut input_buff = BufferVec::new(num_inputs);
286                for i in chunk {
287                    for (ch, input_data) in input_data.iter().enumerate() {
288                        let value: f32 = input_data[*i];
289                        input_buff.set_f32(ch, *i, value);
290                    }
291                }
292                let input_ref = input_buff.buffer_ref();
293                let mut output_buf = BufferVec::new(num_outputs);
294                let mut output_ref = output_buf.buffer_mut();
295
296                node.process(chunk.len(), &input_ref, &mut output_ref);
297
298                for (ch, data) in output_data.iter_mut().enumerate() {
299                    data.extend_from_slice(output_buf.channel_f32(ch));
300                }
301            }
302        }
303    }
304
305    let svg = generate_svg(&input_data, &output_data, &config);
306
307    assert_binary_snapshot!(&format!("{name}.svg"), svg.as_bytes().to_vec());
308}
309
310fn generate_svg(
311    input_data: &[Vec<f32>],
312    output_data: &[Vec<f32>],
313    config: &SnapshotConfig,
314) -> String {
315    let height_per_channel = config.svg_height_per_channel.unwrap_or(DEFAULT_HEIGHT);
316    let num_channels = output_data.len() + input_data.len();
317    let num_samples = output_data.first().map(|c| c.len()).unwrap_or(0);
318    if num_samples == 0 || num_channels == 0 {
319        return "<svg xmlns=\"http://www.w3.org/2000/svg\" viewBox=\"0 0 100 100\" preserveAspectRatio=\"none\"><text>Empty</text></svg>".to_string();
320    }
321
322    let svg_width = config.svg_width.unwrap_or(config.num_samples);
323    let total_height = height_per_channel * num_channels;
324    let y_scale = (height_per_channel as f32 / 2.0) * 0.9;
325    let x_scale = config
326        .svg_width
327        .map(|width| width as f32 / config.num_samples as f32);
328    let stroke_width = if let Some(scale) = x_scale {
329        (2.0 / scale).clamp(0.5, 5.0)
330    } else {
331        2.0
332    };
333
334    let mut svg = String::new();
335    let mut y_offset = 0;
336
337    writeln!(
338        &mut svg,
339        r#"<svg xmlns="http://www.w3.org/2000/svg" viewBox="{start_x} {start_y} {width} {height}" preserveAspectRatio="none">
340        <rect x="{start_x}" y="{start_y}" width="{background_width}" height="{background_height}" fill="black" />"#,
341        start_x = -PADDING,
342        start_y = -PADDING,
343        width = svg_width as isize + PADDING,
344        height = total_height as isize + PADDING,
345        background_width = svg_width as isize + PADDING * 2,
346        background_height = total_height as isize + PADDING * 2
347    ).unwrap();
348
349    let mut write_data = |all_channels_data: &[Vec<f32>], is_input: bool| {
350        for (ch, data) in all_channels_data.iter().enumerate() {
351            let color = if is_input {
352                INPUT_CHANNEL_COLORS[ch % INPUT_CHANNEL_COLORS.len()]
353            } else {
354                OUTPUT_CHANNEL_COLORS[ch % OUTPUT_CHANNEL_COLORS.len()]
355            };
356            let y_center = y_offset + height_per_channel / 2;
357
358            let min_val = data.iter().cloned().fold(f32::INFINITY, f32::min);
359            let max_val = data.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
360            let range = (max_val - min_val).max(f32::EPSILON);
361
362            let mut path_data = String::from("M ");
363            for (i, &sample) in data.iter().enumerate() {
364                let x = if let Some(scale) = x_scale {
365                    scale * i as f32
366                } else {
367                    i as f32
368                };
369                let normalized = (sample.clamp(min_val, max_val) - min_val) / range * 2.0 - 1.0;
370                let y = y_center as f32 - normalized * y_scale;
371                if i == 0 {
372                    write!(&mut path_data, "{:.6},{:.6} ", x, y).unwrap();
373                } else {
374                    write!(&mut path_data, "L {:.6},{:.6} ", x, y).unwrap();
375                }
376            }
377
378            writeln!(
379                &mut svg,
380                r#"  <path d="{path_data}" fill="none" stroke="{color}" stroke-width="{stroke_width}"/>"#,
381            )
382            .unwrap();
383
384            writeln!(
385                &mut svg,
386                r#"  <text x="5" y="{y}" font-family="monospace" font-size="12" fill="{color}">{label} Ch#{ch}</text>"#,
387                y = y_offset + 15,
388                color = color,
389                label = if is_input {"Input"} else {"Output"},
390                ch=ch
391            )
392            .unwrap();
393
394            y_offset += height_per_channel
395        }
396    };
397
398    write_data(input_data, true);
399    write_data(output_data, false);
400
401    svg.push_str("</svg>");
402    svg
403}
404
// Example tests
//
// Each test renders a node and compares the SVG against the stored `insta`
// snapshot named by the first argument of the snapshot call.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sine() {
        let config = SnapshotConfig::default();
        let node = sine_hz::<f32>(440.0);
        snapshot_audionode_with_input_and_options("sine_hz", node, InputSource::None, config);
    }

    #[test]
    fn test_custom_input() {
        let config = SnapshotConfig::with_samples(100);
        let input = (0..100).map(|i| (i as f32 / 50.0).sin()).collect();

        snapshot_audionode_with_input_and_options(
            "filter_sine",
            lowpass_hz(500.0, 0.7),
            InputSource::VecByChannel(vec![input]),
            config,
        );
    }

    #[test]
    fn test_stereo() {
        let config = SnapshotConfig::default();
        let node = sine_hz::<f32>(440.0) | sine_hz::<f32>(880.0);

        snapshot_audionode_with_input_and_options("stereo", node, InputSource::None, config);
    }

    #[test]
    fn test_lowpass_impulse() {
        let config = SnapshotConfig::with_samples(300);
        let node = lowpass_hz(1000.0, 1.0);

        snapshot_audionode_with_input_and_options(
            "lowpass_impulse",
            node,
            InputSource::impulse(),
            config,
        );
    }

    #[test]
    fn test_net() {
        let config = SnapshotConfig::with_samples(420);
        let node = sine_hz::<f32>(440.0) >> lowpass_hz(500.0, 0.7);
        let mut net = Net::new(0, 1);
        let node_id = net.push(Box::new(node));
        net.pipe_input(node_id);
        net.pipe_output(node_id);

        snapshot_audionode_with_input_and_options("net", net, InputSource::None, config);
    }

    #[test]
    fn test_batch_processing() {
        let config = SnapshotConfig {
            processing_mode: Processing::Batch(64),
            ..Default::default()
        };

        let node = sine_hz::<f32>(440.0);

        snapshot_audio_node_with_options("process_64", node, config);
    }

    #[test]
    fn test_vec_by_tick() {
        let config = SnapshotConfig::with_samples(100);
        // Create input data organized by ticks (100 ticks, 1 channel each)
        let input_data: Vec<Vec<f32>> = (0..100).map(|i| vec![(i as f32 / 50.0).cos()]).collect();

        snapshot_audionode_with_input_and_options(
            "vec_by_tick",
            lowpass_hz(800.0, 0.5),
            InputSource::VecByTick(input_data),
            config,
        );
    }

    #[test]
    fn test_flat_input() {
        let config = SnapshotConfig::with_samples(200);
        // Flat input repeated for every tick
        let flat_input = vec![0.5];

        snapshot_audionode_with_input_and_options(
            "flat_input",
            highpass_hz(200.0, 0.7),
            InputSource::Flat(flat_input),
            config,
        );
    }

    #[test]
    fn test_sine_input_source() {
        let config = SnapshotConfig::with_samples(200);

        snapshot_audionode_with_input_and_options(
            "sine_input_source",
            bandpass_hz(1000.0, 500.0),
            InputSource::sine(100.0, 44100.0),
            config,
        );
    }

    #[test]
    fn test_multi_channel_vec_by_channel() {
        let config = SnapshotConfig::with_samples(150);
        // Create stereo input data
        let left_channel: Vec<f32> = (0..150)
            .map(|i| (i as f32 / 75.0 * std::f32::consts::PI).sin())
            .collect();
        let right_channel: Vec<f32> = (0..150)
            .map(|i| (i as f32 / 75.0 * std::f32::consts::PI).cos())
            .collect();

        let node = resonator_hz(440.0, 100.0) | resonator_hz(440.0, 100.0);

        snapshot_audionode_with_input_and_options(
            "multi_channel_vec",
            node,
            InputSource::VecByChannel(vec![left_channel, right_channel]),
            config,
        );
    }
}