Skip to main content

flutmax_cli/
sim.rs

1//! `flutmax sim` subcommand: run a compiled .maxpat through flutmax-sim
2//! and assert audio properties without Max.
3//!
4//! Example:
5//!   flutmax sim build/synth.maxpat \
6//!     --param mode=1 --param bow_pressure=0.5 \
7//!     --note-on 60 100 \
8//!     --duration 1.0 \
9//!     --assert-peak '>0.05' \
10//!     --assert-not-silent
11
12use flutmax_sim::{AudioOutput, GenSimulator, RnboSimulator};
13use std::fs;
14use std::process;
15
16/// CLI entry point for `flutmax sim`.
17pub fn run(args: &[String]) -> i32 {
18    let mut input_path: Option<String> = None;
19    let mut params: Vec<(String, f64)> = Vec::new();
20    let mut note_on: Vec<(u8, u8)> = Vec::new();
21    let mut note_off: Vec<u8> = Vec::new();
22    let mut signal_input: Option<f64> = None; // Constant input value
23    let mut sample_rate: f64 = 48000.0;
24    let mut duration: f64 = 0.5;
25    let mut sim_mode: SimMode = SimMode::Auto;
26    let mut assertions: Vec<Assertion> = Vec::new();
27    let mut print_metrics = false;
28
29    let mut i = 0;
30    while i < args.len() {
31        match args[i].as_str() {
32            "--param" => {
33                let v = require_arg(args, i, "--param");
34                let (name, value) = parse_kv(&v).unwrap_or_else(|e| {
35                    eprintln!("error: invalid --param {}: {}", v, e);
36                    process::exit(1);
37                });
38                params.push((name, value));
39                i += 2;
40            }
41            "--note-on" => {
42                if i + 2 >= args.len() {
43                    eprintln!("error: --note-on requires <note> <vel>");
44                    process::exit(1);
45                }
46                let n: u8 = args[i + 1].parse().unwrap_or_else(|_| {
47                    eprintln!("error: --note-on note must be 0-127");
48                    process::exit(1);
49                });
50                let v: u8 = args[i + 2].parse().unwrap_or_else(|_| {
51                    eprintln!("error: --note-on velocity must be 0-127");
52                    process::exit(1);
53                });
54                note_on.push((n, v));
55                i += 3;
56            }
57            "--note-off" => {
58                let v = require_arg(args, i, "--note-off");
59                let n: u8 = v.parse().unwrap_or_else(|_| {
60                    eprintln!("error: --note-off note must be 0-127");
61                    process::exit(1);
62                });
63                note_off.push(n);
64                i += 2;
65            }
66            "--signal-input" => {
67                let v = require_arg(args, i, "--signal-input");
68                signal_input = Some(v.parse().unwrap_or_else(|_| {
69                    eprintln!("error: --signal-input must be a number");
70                    process::exit(1);
71                }));
72                i += 2;
73            }
74            "--sample-rate" | "--sr" => {
75                let v = require_arg(args, i, "--sample-rate");
76                sample_rate = v.parse().unwrap_or_else(|_| {
77                    eprintln!("error: --sample-rate must be a number");
78                    process::exit(1);
79                });
80                i += 2;
81            }
82            "--duration" | "-d" => {
83                let v = require_arg(args, i, "--duration");
84                duration = v.parse().unwrap_or_else(|_| {
85                    eprintln!("error: --duration must be a number");
86                    process::exit(1);
87                });
88                i += 2;
89            }
90            "--mode" => {
91                let v = require_arg(args, i, "--mode");
92                sim_mode = match v.as_str() {
93                    "rnbo" => SimMode::Rnbo,
94                    "gen" => SimMode::Gen,
95                    "auto" => SimMode::Auto,
96                    other => {
97                        eprintln!("error: --mode must be rnbo|gen|auto, got '{}'", other);
98                        process::exit(1);
99                    }
100                };
101                i += 2;
102            }
103            "--assert-peak" => {
104                let v = require_arg(args, i, "--assert-peak");
105                let cmp = parse_comparison(&v).unwrap_or_else(|e| {
106                    eprintln!("error: --assert-peak: {}", e);
107                    process::exit(1);
108                });
109                assertions.push(Assertion::Peak(cmp));
110                i += 2;
111            }
112            "--assert-rms" => {
113                let v = require_arg(args, i, "--assert-rms");
114                let cmp = parse_comparison(&v).unwrap_or_else(|e| {
115                    eprintln!("error: --assert-rms: {}", e);
116                    process::exit(1);
117                });
118                assertions.push(Assertion::Rms(cmp));
119                i += 2;
120            }
121            "--assert-silent" => {
122                assertions.push(Assertion::Silent);
123                i += 1;
124            }
125            "--assert-not-silent" => {
126                assertions.push(Assertion::NotSilent);
127                i += 1;
128            }
129            "--assert-frequency" | "--assert-freq" => {
130                if i + 2 >= args.len() {
131                    eprintln!("error: --assert-frequency requires <target> <tolerance>");
132                    process::exit(1);
133                }
134                let target: f64 = args[i + 1].parse().unwrap_or_else(|_| {
135                    eprintln!("error: --assert-frequency target must be a number");
136                    process::exit(1);
137                });
138                let tolerance: f64 = args[i + 2].parse().unwrap_or_else(|_| {
139                    eprintln!("error: --assert-frequency tolerance must be a number");
140                    process::exit(1);
141                });
142                assertions.push(Assertion::Frequency(target, tolerance));
143                i += 3;
144            }
145            "--print-metrics" | "-p" => {
146                print_metrics = true;
147                i += 1;
148            }
149            "--help" | "-h" => {
150                print_help();
151                return 0;
152            }
153            arg if arg.starts_with('-') => {
154                eprintln!("error: unknown option '{}'", arg);
155                print_help();
156                return 1;
157            }
158            arg => {
159                if input_path.is_some() {
160                    eprintln!("error: multiple input paths specified");
161                    return 1;
162                }
163                input_path = Some(arg.to_string());
164                i += 1;
165            }
166        }
167    }
168
169    let input = match input_path {
170        Some(p) => p,
171        None => {
172            eprintln!("error: no input .maxpat file specified");
173            print_help();
174            return 1;
175        }
176    };
177
178    // Load JSON
179    let json = match fs::read_to_string(&input) {
180        Ok(s) => s,
181        Err(e) => {
182            eprintln!("error: failed to read {}: {}", input, e);
183            return 1;
184        }
185    };
186
187    // Auto-detect simulator mode if needed
188    let resolved_mode = match sim_mode {
189        SimMode::Auto => detect_mode(&json),
190        m => m,
191    };
192
193    // Run simulation
194    let output = match resolved_mode {
195        SimMode::Rnbo => run_rnbo(
196            &json,
197            &params,
198            &note_on,
199            &note_off,
200            signal_input,
201            sample_rate,
202            duration,
203        ),
204        SimMode::Gen => run_gen(&json, &params, signal_input, sample_rate, duration),
205        SimMode::Auto => unreachable!(),
206    };
207
208    let output = match output {
209        Ok(o) => o,
210        Err(e) => {
211            eprintln!("error: simulation failed: {}", e);
212            return 1;
213        }
214    };
215
216    // Print metrics if requested
217    if print_metrics || assertions.is_empty() {
218        let peak = output.peak();
219        let rms = output.rms();
220        let freq = output.freq_estimate();
221        println!("peak: {:.6}", peak);
222        println!("rms:  {:.6}", rms);
223        println!("freq: {:.1}", freq);
224        println!(
225            "samples: {}",
226            output.channels.first().map(|c| c.len()).unwrap_or(0)
227        );
228        println!("channels: {}", output.channels.len());
229    }
230
231    // Run assertions
232    let mut failed = 0;
233    for assertion in &assertions {
234        match check_assertion(assertion, &output) {
235            Ok(()) => {}
236            Err(msg) => {
237                eprintln!("FAIL: {}", msg);
238                failed += 1;
239            }
240        }
241    }
242
243    if failed > 0 {
244        eprintln!();
245        eprintln!("{} assertion(s) failed", failed);
246        1
247    } else {
248        if !assertions.is_empty() {
249            println!("All {} assertions passed", assertions.len());
250        }
251        0
252    }
253}
254
/// Prints the usage text for `flutmax sim` to stderr.
///
/// Every option and alias accepted by `run` is listed here, so the help
/// stays in step with the argument parser.
fn print_help() {
    /// Prints one titled section of `(option, description)` rows followed
    /// by a blank line. Descriptions are aligned at a fixed column.
    fn section(title: &str, rows: &[(&str, &str)]) {
        eprintln!("{}:", title);
        for (option, description) in rows {
            eprintln!("    {:<30}{}", option, description);
        }
        eprintln!();
    }

    eprintln!("flutmax sim - run a compiled .maxpat through DSP simulator");
    eprintln!();
    eprintln!("USAGE:");
    eprintln!("    flutmax sim <input.maxpat> [options]");
    eprintln!();

    section(
        "OPTIONS",
        &[
            (
                "--param <name=value>",
                "Set RNBO param (gen mode: in1, in2, ...)",
            ),
            ("--note-on <note> <vel>", "Send MIDI Note On (RNBO mode)"),
            ("--note-off <note>", "Send MIDI Note Off (RNBO mode)"),
            (
                "--signal-input <value>",
                "Constant value for the first signal input",
            ),
            ("--sample-rate <hz>", "Sample rate (default 48000)"),
            ("--sr <hz>", "Alias for --sample-rate"),
            ("--duration <seconds>", "Run duration (default 0.5)"),
            ("-d <seconds>", "Alias for --duration"),
            (
                "--mode rnbo|gen|auto",
                "Force simulator mode (default auto)",
            ),
            ("--help, -h", "Show this help"),
        ],
    );

    section(
        "ASSERTIONS",
        &[
            (
                "--assert-peak <op N>",
                "op is >, >=, <, <= or =; e.g. '>0.05', '<1.0', '=0.5'",
            ),
            ("--assert-rms <op N>", "Same syntax as --assert-peak"),
            (
                "--assert-silent",
                "Output should be silent (peak < 1e-6)",
            ),
            ("--assert-not-silent", "Output should produce sound"),
            (
                "--assert-frequency <hz> <tol>",
                "Frequency within ±tolerance Hz",
            ),
            ("--assert-freq <hz> <tol>", "Alias for --assert-frequency"),
        ],
    );

    section(
        "OUTPUT",
        &[(
            "--print-metrics, -p",
            "Print peak/rms/freq even with assertions",
        )],
    );

    eprintln!("EXAMPLES:");
    eprintln!("    flutmax sim build/synth.maxpat --param freq=440 --duration 1.0 -p");
    eprintln!("    flutmax sim build/synth.maxpat --param mode=1 --note-on 60 100 \\");
    eprintln!("        --assert-peak '>0.05' --assert-not-silent");
}
285
/// Simulator backend selection for `flutmax sim` (`--mode`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SimMode {
    /// Choose the backend by inspecting the patch JSON (see `detect_mode`).
    Auto,
    /// Run the patch through the RNBO simulator.
    Rnbo,
    /// Run the patch through the gen~ simulator.
    Gen,
}
292
/// A threshold test applied to a measured value (`--assert-peak`,
/// `--assert-rms`). The payload is the threshold the measurement is
/// compared against.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Comparison {
    /// Measured value must be strictly greater than the threshold (`>`).
    Gt(f64),
    /// Measured value must be greater than or equal to the threshold (`>=`).
    Gte(f64),
    /// Measured value must be strictly less than the threshold (`<`).
    Lt(f64),
    /// Measured value must be less than or equal to the threshold (`<=`).
    Lte(f64),
    /// Measured value must equal the threshold (`=`); `check_cmp` applies a
    /// small tolerance rather than exact floating-point equality.
    Eq(f64),
}

/// One audio property to verify after the simulation has run.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Assertion {
    /// Compare the output peak against a threshold.
    Peak(Comparison),
    /// Compare the output RMS against a threshold.
    Rms(Comparison),
    /// The output must be silent.
    Silent,
    /// The output must not be silent.
    NotSilent,
    /// The estimated frequency must lie within `tolerance` of `target`;
    /// fields are `(target, tolerance)`.
    Frequency(f64, f64),
}
310
/// Returns an owned copy of the argument that follows the option at index
/// `i`.
///
/// When no argument follows, prints `error: <name> requires an argument`
/// to stderr and terminates the process with exit status 1.
fn require_arg(args: &[String], i: usize, name: &str) -> String {
    match args.get(i + 1) {
        Some(value) => value.to_owned(),
        None => {
            eprintln!("error: {} requires an argument", name);
            process::exit(1)
        }
    }
}
318
/// Parses a `name=value` pair as used by `--param`.
///
/// The string is split at the first `=`. Whitespace around the name and the
/// value is ignored, the name must be non-empty, and the value must parse
/// as an `f64`.
///
/// # Errors
///
/// Returns a human-readable message when there is no `=`, the name is
/// empty, or the value is not a number.
fn parse_kv(s: &str) -> Result<(String, f64), String> {
    let (name, raw_value) = s
        .split_once('=')
        .ok_or_else(|| "expected name=value".to_string())?;

    let name = name.trim();
    if name.is_empty() {
        // "=1.0" used to produce a parameter with an empty name.
        return Err("expected name=value (name is empty)".to_string());
    }

    let value = raw_value
        .trim()
        .parse::<f64>()
        .map_err(|e| e.to_string())?;
    Ok((name.to_string(), value))
}
327
328fn parse_comparison(s: &str) -> Result<Comparison, String> {
329    let s = s.trim();
330    let (op, num_str) = if let Some(rest) = s.strip_prefix(">=") {
331        (">=", rest)
332    } else if let Some(rest) = s.strip_prefix("<=") {
333        ("<=", rest)
334    } else if let Some(rest) = s.strip_prefix('>') {
335        (">", rest)
336    } else if let Some(rest) = s.strip_prefix('<') {
337        ("<", rest)
338    } else if let Some(rest) = s.strip_prefix('=') {
339        ("=", rest)
340    } else {
341        return Err(format!(
342            "expected comparison operator (>, <, >=, <=, =), got '{}'",
343            s
344        ));
345    };
346    let value: f64 = num_str
347        .trim()
348        .parse()
349        .map_err(|e| format!("invalid number: {}", e))?;
350    Ok(match op {
351        ">" => Comparison::Gt(value),
352        ">=" => Comparison::Gte(value),
353        "<" => Comparison::Lt(value),
354        "<=" => Comparison::Lte(value),
355        "=" => Comparison::Eq(value),
356        _ => unreachable!(),
357    })
358}
359
360fn detect_mode(json: &str) -> SimMode {
361    // Parse the JSON once and inspect classnamespace fields directly. Falls
362    // back to RNBO mode for top-level patchers (which contain `rnbo~` /
363    // `gen~` boxes but typically no top-level classnamespace).
364    let value: serde_json::Value = match serde_json::from_str(json) {
365        Ok(v) => v,
366        Err(_) => return SimMode::Rnbo,
367    };
368    let top_ns = value
369        .pointer("/patcher/classnamespace")
370        .and_then(|v| v.as_str());
371
372    match top_ns {
373        Some("rnbo") => SimMode::Rnbo,
374        Some("dsp.gen") => {
375            // A standalone gen~ patch unless it is somehow nested inside an
376            // rnbo patcher (rare in practice but handle it).
377            if patcher_contains_rnbo_namespace(&value) {
378                SimMode::Rnbo
379            } else {
380                SimMode::Gen
381            }
382        }
383        _ => SimMode::Rnbo,
384    }
385}
386
387/// Recursively walk a patcher tree and return true if any nested
388/// `classnamespace` field equals `"rnbo"`.
389fn patcher_contains_rnbo_namespace(value: &serde_json::Value) -> bool {
390    fn walk(v: &serde_json::Value) -> bool {
391        if let Some(ns) = v.get("classnamespace").and_then(|n| n.as_str()) {
392            if ns == "rnbo" {
393                return true;
394            }
395        }
396        if let Some(boxes) = v.get("boxes").and_then(|b| b.as_array()) {
397            for b in boxes {
398                let inner = b.get("box").unwrap_or(b);
399                if let Some(p) = inner.get("patcher") {
400                    if walk(p) {
401                        return true;
402                    }
403                }
404            }
405        }
406        false
407    }
408    value.get("patcher").map(walk).unwrap_or(false)
409}
410
411fn run_rnbo(
412    json: &str,
413    params: &[(String, f64)],
414    note_on: &[(u8, u8)],
415    note_off: &[u8],
416    signal_input: Option<f64>,
417    sample_rate: f64,
418    duration: f64,
419) -> Result<AudioOutput, String> {
420    let mut sim = RnboSimulator::from_json_with_sr(json, sample_rate)
421        .map_err(|e| format!("RnboSimulator parse error: {:?}", e))?;
422
423    for (name, value) in params {
424        sim.set_param(name, *value);
425    }
426
427    for &(n, v) in note_on {
428        sim.send_note_on(n, v);
429    }
430    for &n in note_off {
431        sim.send_note_off(n);
432    }
433
434    // CLI exposes a single scalar; apply it to signal input 0.
435    if let Some(v) = signal_input {
436        sim.set_signal_input(0, v);
437    }
438
439    Ok(sim.run_seconds(duration))
440}
441
442fn run_gen(
443    json: &str,
444    params: &[(String, f64)],
445    signal_input: Option<f64>,
446    sample_rate: f64,
447    duration: f64,
448) -> Result<AudioOutput, String> {
449    let mut sim = GenSimulator::from_json_with_sr(json, sample_rate)
450        .map_err(|e| format!("GenSimulator parse error: {:?}", e))?;
451
452    // For gen~, params are positional inputs (in 1, in 2, ...)
453    // params named "in1", "in2", etc., map to indices
454    for (name, value) in params {
455        if let Some(idx_str) = name.strip_prefix("in") {
456            if let Ok(idx) = idx_str.parse::<usize>() {
457                if idx > 0 && idx <= sim.num_inputs() {
458                    sim.set_input(idx - 1, *value);
459                }
460            }
461        }
462    }
463
464    if let Some(v) = signal_input {
465        sim.set_input(0, v);
466    }
467
468    Ok(sim.run_seconds(duration))
469}
470
471fn check_assertion(assertion: &Assertion, output: &AudioOutput) -> Result<(), String> {
472    match assertion {
473        Assertion::Peak(cmp) => {
474            let v = output.peak();
475            check_cmp("peak", v, cmp)
476        }
477        Assertion::Rms(cmp) => {
478            let v = output.rms();
479            check_cmp("rms", v, cmp)
480        }
481        Assertion::Silent => {
482            if output.is_silent() {
483                Ok(())
484            } else {
485                Err(format!("expected silent, got peak={:.6}", output.peak()))
486            }
487        }
488        Assertion::NotSilent => {
489            if !output.is_silent() {
490                Ok(())
491            } else {
492                Err("expected sound, got silence".to_string())
493            }
494        }
495        Assertion::Frequency(target, tolerance) => {
496            let measured = output.freq_estimate();
497            if (measured - target).abs() <= *tolerance {
498                Ok(())
499            } else {
500                Err(format!(
501                    "frequency {:.1} not within ±{} of target {:.1}",
502                    measured, tolerance, target
503                ))
504            }
505        }
506    }
507}
508
509fn check_cmp(name: &str, value: f64, cmp: &Comparison) -> Result<(), String> {
510    let (passed, op_str, target) = match cmp {
511        Comparison::Gt(t) => (value > *t, ">", *t),
512        Comparison::Gte(t) => (value >= *t, ">=", *t),
513        Comparison::Lt(t) => (value < *t, "<", *t),
514        Comparison::Lte(t) => (value <= *t, "<=", *t),
515        Comparison::Eq(t) => {
516            // Peak/RMS values from a multi-thousand-sample DSP run rarely match
517            // a target to floating-point exactness. Use an absolute tolerance of
518            // 1e-6 with a 1e-4 relative tolerance for larger targets (e.g.
519            // frequencies in Hz).
520            let tol = 1e-6_f64.max(t.abs() * 1e-4);
521            ((value - t).abs() <= tol, "=", *t)
522        }
523    };
524    if passed {
525        Ok(())
526    } else {
527        Err(format!("{} {:.6} not {} {}", name, value, op_str, target))
528    }
529}
530
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_eq_passes_within_relative_tolerance() {
        // 0.5 vs 0.50001 — the old 1e-9 tolerance would fail here.
        assert!(check_cmp("peak", 0.500_01, &Comparison::Eq(0.5)).is_ok());
    }

    #[test]
    fn test_eq_fails_outside_tolerance() {
        // 0.5 vs 0.6 — clearly not equal, must fail.
        assert!(check_cmp("peak", 0.6, &Comparison::Eq(0.5)).is_err());
    }

    #[test]
    fn test_eq_freq_within_relative_tolerance() {
        // Larger targets such as a frequency estimate: 440 Hz vs 440.04 Hz is in range.
        assert!(check_cmp("freq", 440.04, &Comparison::Eq(440.0)).is_ok());
        // 440 vs 441 is out of range.
        assert!(check_cmp("freq", 441.0, &Comparison::Eq(440.0)).is_err());
    }

    #[test]
    fn test_eq_zero_target_uses_absolute_tolerance() {
        // When the target is 0 the relative tolerance vanishes; fall back to absolute 1e-6.
        assert!(check_cmp("rms", 1e-7, &Comparison::Eq(0.0)).is_ok());
        assert!(check_cmp("rms", 1e-3, &Comparison::Eq(0.0)).is_err());
    }

    /// The ordering operators are strict or inclusive exactly as written.
    #[test]
    fn test_ordering_operators_respect_boundaries() {
        assert!(check_cmp("peak", 0.5, &Comparison::Gt(0.5)).is_err());
        assert!(check_cmp("peak", 0.5, &Comparison::Gte(0.5)).is_ok());
        assert!(check_cmp("peak", 0.5, &Comparison::Lt(0.5)).is_err());
        assert!(check_cmp("peak", 0.5, &Comparison::Lte(0.5)).is_ok());
        assert!(check_cmp("peak", 0.6, &Comparison::Gt(0.5)).is_ok());
        assert!(check_cmp("peak", 0.4, &Comparison::Lt(0.5)).is_ok());
    }

    /// Every documented operator maps to its `Comparison` variant, and
    /// two-character operators are not mistaken for their one-character
    /// prefix.
    #[test]
    fn test_parse_comparison_operators() {
        assert!(matches!(parse_comparison(">0.05"), Ok(Comparison::Gt(v)) if v == 0.05));
        assert!(matches!(parse_comparison(">=0.05"), Ok(Comparison::Gte(v)) if v == 0.05));
        assert!(matches!(parse_comparison("<1.0"), Ok(Comparison::Lt(v)) if v == 1.0));
        assert!(matches!(parse_comparison("<=1.0"), Ok(Comparison::Lte(v)) if v == 1.0));
        assert!(matches!(parse_comparison("=0.5"), Ok(Comparison::Eq(v)) if v == 0.5));
        // Whitespace around the expression and after the operator is ignored.
        assert!(matches!(parse_comparison(" > 0.5 "), Ok(Comparison::Gt(v)) if v == 0.5));
    }

    #[test]
    fn test_parse_comparison_rejects_malformed_input() {
        assert!(parse_comparison("0.5").is_err());
        assert!(parse_comparison(">abc").is_err());
        assert!(parse_comparison("").is_err());
    }

    #[test]
    fn test_parse_kv() {
        let (name, value) = parse_kv("freq=440").expect("valid pair");
        assert_eq!(name, "freq");
        assert!((value - 440.0).abs() < 1e-12);

        assert!(parse_kv("freq").is_err());
        assert!(parse_kv("freq=abc").is_err());
        // Only the first '=' splits; the remainder is not a number.
        assert!(parse_kv("a=b=1").is_err());
    }

    /// Verify that running a gen~ patch with `--sample-rate` produces an
    /// `AudioOutput.sample_rate` and sample count that match the request.
    ///
    /// Regression test for F1: the previous implementation hard-coded 48000.0,
    /// so passing any other value left the output `sample_rate` unchanged.
    #[test]
    fn test_run_gen_honors_sample_rate() {
        // Pass-through gen~ patch: in 1 → out 1
        let json = r#"{
            "patcher": {
                "boxes": [
                    {"box": {"id": "a", "text": "in 1"}},
                    {"box": {"id": "b", "text": "out 1"}}
                ],
                "lines": [
                    {"patchline": {"source": ["a", 0], "destination": ["b", 0]}}
                ]
            }
        }"#;

        // 32000 Hz x 0.1s = 3200 samples
        let out = run_gen(json, &[], None, 32000.0, 0.1).expect("run_gen");
        assert!(
            (out.sample_rate - 32000.0).abs() < 1e-9,
            "sample_rate mismatch: {}",
            out.sample_rate
        );
        assert_eq!(out.channels[0].len(), 3200);

        // 96000 Hz x 0.1s = 9600 samples
        let out = run_gen(json, &[], None, 96000.0, 0.1).expect("run_gen");
        assert!((out.sample_rate - 96000.0).abs() < 1e-9);
        assert_eq!(out.channels[0].len(), 9600);
    }

    /// Verify the RNBO simulator also honours `--sample-rate`.
    #[test]
    fn test_run_rnbo_honors_sample_rate() {
        // Minimal RNBO patch: param val 1.0 → out~ 1
        let json = r#"{
            "patcher": {
                "classnamespace": "rnbo",
                "boxes": [
                    {"box": {"id": "p", "maxclass": "newobj", "text": "param val 1.0"}},
                    {"box": {"id": "o", "maxclass": "newobj", "text": "out~ 1"}}
                ],
                "lines": [
                    {"patchline": {"source": ["p", 0], "destination": ["o", 0]}}
                ]
            }
        }"#;

        let out = run_rnbo(json, &[], &[], &[], None, 32000.0, 0.1).expect("run_rnbo");
        assert!((out.sample_rate - 32000.0).abs() < 1e-9);
        assert_eq!(out.channels[0].len(), 3200);
    }

    /// `detect_mode` should distinguish gen~, RNBO, and top-level patches by
    /// parsing the JSON rather than substring-matching the raw text. F4
    /// regression test.
    #[test]
    fn test_detect_mode_gen() {
        let json = r#"{"patcher": {"classnamespace": "dsp.gen", "boxes": [], "lines": []}}"#;
        assert!(matches!(detect_mode(json), SimMode::Gen));
    }

    #[test]
    fn test_detect_mode_rnbo() {
        let json = r#"{"patcher": {"classnamespace": "rnbo", "boxes": [], "lines": []}}"#;
        assert!(matches!(detect_mode(json), SimMode::Rnbo));
    }

    #[test]
    fn test_detect_mode_top_level_defaults_to_rnbo() {
        let json = r#"{"patcher": {"boxes": [], "lines": []}}"#;
        assert!(matches!(detect_mode(json), SimMode::Rnbo));
    }

    #[test]
    fn test_detect_mode_invalid_json_defaults_to_rnbo() {
        // Unparsable input falls back to RNBO mode.
        assert!(matches!(detect_mode("not json"), SimMode::Rnbo));
    }

    #[test]
    fn test_detect_mode_gen_with_nested_rnbo_is_rnbo() {
        // A dsp.gen top level that contains an rnbo sub-patcher is treated
        // as RNBO.
        let json = r#"{
            "patcher": {
                "classnamespace": "dsp.gen",
                "boxes": [
                    {"box": {"id": "sub", "patcher": {
                        "classnamespace": "rnbo", "boxes": [], "lines": []
                    }}}
                ],
                "lines": []
            }
        }"#;
        assert!(matches!(detect_mode(json), SimMode::Rnbo));
    }

    #[test]
    fn test_detect_mode_ignores_string_field_collisions() {
        // The phrase "classnamespace": "dsp.gen" appears here only inside a
        // string-literal text field of a comment box, not as the patcher's
        // own classnamespace. The old substring-based detection would have
        // mis-classified this as Gen mode.
        let json = r#"{
            "patcher": {
                "classnamespace": "rnbo",
                "boxes": [
                    {"box": {"id": "c", "maxclass": "comment",
                             "text": "see \"classnamespace\": \"dsp.gen\" docs"}}
                ],
                "lines": []
            }
        }"#;
        assert!(matches!(detect_mode(json), SimMode::Rnbo));
    }

    /// Verify `--signal-input` actually reaches `in~` in RNBO mode (F2).
    #[test]
    fn test_run_rnbo_applies_signal_input() {
        // in~ 1 → * 2 → out~ 1
        let json = r#"{
            "patcher": {
                "classnamespace": "rnbo",
                "boxes": [
                    {"box": {"id": "i", "maxclass": "newobj", "text": "in~ 1"}},
                    {"box": {"id": "m", "maxclass": "newobj", "text": "* 2"}},
                    {"box": {"id": "o", "maxclass": "newobj", "text": "out~ 1"}}
                ],
                "lines": [
                    {"patchline": {"source": ["i", 0], "destination": ["m", 0]}},
                    {"patchline": {"source": ["m", 0], "destination": ["o", 0]}}
                ]
            }
        }"#;

        let out = run_rnbo(json, &[], &[], &[], Some(0.3), 44100.0, 0.001).expect("run_rnbo");
        // 0.3 input * 2 = 0.6 should appear at out~
        assert!(
            (out.channels[0][0] - 0.6).abs() < 1e-9,
            "expected 0.6, got {}",
            out.channels[0][0]
        );
    }
}