use nam_rs::{ModelConfig, NamModel, DEFAULT_SAMPLE_RATE};
/// JSON fixture: a WaveNet model document with a single layer entry, a null
/// `head`, and three weights. It carries neither a `sample_rate` nor a
/// `metadata` key.
///
/// Tests in this file patch the text with `str::replace`, matching the
/// substrings `"weights": [0.1, -0.2, 0.3]` and `"WaveNet"` verbatim, so those
/// fragments must keep their exact spelling and spacing.
const MINIMAL_WAVENET: &str = r#"{
"version": "0.5.4",
"architecture": "WaveNet",
"config": {
"layers": [
{
"input_size": 1,
"condition_size": 1,
"channels": 2,
"head_size": 1,
"kernel_size": 3,
"dilations": [1, 2],
"activation": "Tanh",
"gated": false,
"head_bias": false
}
],
"head": null,
"head_scale": 0.5
},
"weights": [0.1, -0.2, 0.3]
}"#;
/// Parses the minimal WaveNet fixture and checks that the header fields, the
/// single layer's settings, the head scale and the weights match the JSON.
#[test]
fn parses_minimal_wavenet_config() {
    let parsed = NamModel::from_json_str(MINIMAL_WAVENET).expect("should parse");

    // Header fields of the document.
    assert_eq!(parsed.version, "0.5.4");
    assert_eq!(parsed.architecture, "WaveNet");

    // The config must come back as the WaveNet variant.
    let wavenet = match &parsed.config {
        ModelConfig::WaveNet(c) => c,
        other => panic!("expected WaveNet config, got {other:?}"),
    };

    // Exactly one layer, carrying the values written in the fixture.
    assert_eq!(wavenet.layers.len(), 1);
    let only_layer = &wavenet.layers[0];
    assert_eq!(only_layer.channels, 2);
    assert_eq!(only_layer.kernel_size, 3);
    assert_eq!(only_layer.dilations, vec![1, 2]);
    assert_eq!(only_layer.activation, "Tanh");
    assert!(!only_layer.gated);
    assert!(!only_layer.head_bias);

    // Floating-point field: compare with a tolerance.
    let head_scale_error = (wavenet.head_scale - 0.5).abs();
    assert!(head_scale_error < 1e-9);

    assert_eq!(parsed.weights, vec![0.1_f32, -0.2, 0.3]);
}
/// With no `sample_rate` key in the JSON, the field is `None` and the
/// `sample_rate()` accessor reports `DEFAULT_SAMPLE_RATE`.
#[test]
fn sample_rate_defaults_to_48k_when_absent() {
    let parsed = NamModel::from_json_str(MINIMAL_WAVENET).expect("should parse");

    // The raw field stays unset ...
    assert!(parsed.sample_rate.is_none());

    // ... while the accessor falls back to the crate default.
    let deviation = (parsed.sample_rate() - DEFAULT_SAMPLE_RATE).abs();
    assert!(deviation < 1e-9);
}
/// Splices a `sample_rate` entry in front of the weights of the minimal
/// WaveNet fixture and checks that `sample_rate()` returns that value.
#[test]
fn sample_rate_is_read_when_present() {
    // Text to locate in the fixture, and the same text prefixed with the rate.
    let weights_entry = "\"weights\": [0.1, -0.2, 0.3]";
    let weights_with_rate = "\"sample_rate\": 44100.0, \"weights\": [0.1, -0.2, 0.3]";
    let patched = MINIMAL_WAVENET.replace(weights_entry, weights_with_rate);

    let parsed = NamModel::from_json_str(&patched).expect("should parse");

    let deviation = (parsed.sample_rate() - 44100.0).abs();
    assert!(deviation < 1e-9);
}
/// Input that is not valid JSON must produce an `Err`.
#[test]
fn rejects_malformed_json() {
    let outcome = NamModel::from_json_str("{ not json");
    assert!(outcome.is_err());
}
/// JSON fixture: a single-layer WaveNet model document that also carries a
/// `metadata` object with `loudness`, `input_level_dbu`, `output_level_dbu`,
/// `name` and `gear_type` entries.
///
/// One test removes the three numeric metadata entries by matching their line
/// verbatim with `str::replace`, so that line must keep its exact spelling,
/// spacing and trailing comma.
const WITH_METADATA: &str = r#"{
"version": "0.5.4",
"architecture": "WaveNet",
"config": {
"layers": [{
"input_size": 1, "condition_size": 1, "channels": 1, "head_size": 1,
"kernel_size": 1, "dilations": [1], "activation": "ReLU",
"gated": false, "head_bias": false
}],
"head": null, "head_scale": 1.0
},
"weights": [1.0, 2.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0],
"metadata": {
"loudness": -20.02, "input_level_dbu": 18.3, "output_level_dbu": 12.3,
"name": "Test", "gear_type": "amp"
}
}"#;
/// The three numeric metadata accessors return the values written in the
/// `WITH_METADATA` fixture, within a tolerance of `1e-4`.
#[test]
fn parses_loudness_and_calibration_metadata() {
    /// True when `got` holds a value within `1e-4` of `want`; panics if `got`
    /// is `None`.
    fn is_close(got: Option<f32>, want: f32) -> bool {
        let value = got.expect("present");
        (value - want).abs() < 1e-4
    }

    let model = NamModel::from_json_str(WITH_METADATA).expect("parse");

    assert!(is_close(model.loudness(), -20.02));
    assert!(is_close(model.input_level_dbu(), 18.3));
    assert!(is_close(model.output_level_dbu(), 12.3));
}
/// A document without a `metadata` object yields `None` from all three
/// metadata accessors.
#[test]
fn metadata_absent_yields_none() {
    let model = NamModel::from_json_str(MINIMAL_WAVENET).expect("parse");

    let loudness = model.loudness();
    assert_eq!(loudness, None);

    let input_level = model.input_level_dbu();
    assert_eq!(input_level, None);

    let output_level = model.output_level_dbu();
    assert_eq!(output_level, None);
}
/// When the `metadata` object only holds the `name` and `gear_type` entries,
/// parsing still succeeds and the three numeric accessors return `None`.
#[test]
fn unrelated_metadata_keys_are_ignored() {
    // Strip the numeric entries from the fixture, leaving the other keys.
    let numeric_entries =
        "\"loudness\": -20.02, \"input_level_dbu\": 18.3, \"output_level_dbu\": 12.3,";
    let stripped = WITH_METADATA.replace(numeric_entries, "");

    let model = NamModel::from_json_str(&stripped).expect("parse");

    assert_eq!(model.loudness(), None);
    assert_eq!(model.input_level_dbu(), None);
    assert_eq!(model.output_level_dbu(), None);
}
/// JSON fixture: an LSTM model document whose config holds `input_size`,
/// `hidden_size` and `num_layers`, with a single weight and an explicit
/// `sample_rate` of 44100.0.
const MINIMAL_LSTM: &str = r#"{
"version": "0.5.4",
"architecture": "LSTM",
"config": { "input_size": 1, "hidden_size": 8, "num_layers": 1 },
"weights": [0.0],
"sample_rate": 44100.0
}"#;
/// Parses the LSTM fixture and checks the architecture string, the three
/// config fields and the explicit sample rate.
#[test]
fn parses_lstm_config() {
    let model = NamModel::from_json_str(MINIMAL_LSTM).expect("parse LSTM");
    assert_eq!(model.architecture, "LSTM");

    // The config must come back as the Lstm variant.
    let lstm = match &model.config {
        ModelConfig::Lstm(c) => c,
        other => panic!("expected Lstm config, got {other:?}"),
    };
    assert_eq!(lstm.input_size, 1);
    assert_eq!(lstm.hidden_size, 8);
    assert_eq!(lstm.num_layers, 1);

    assert_eq!(model.sample_rate(), 44100.0);
}
/// The minimal WaveNet fixture decodes to the `WaveNet` variant of
/// `ModelConfig` with one layer.
#[test]
fn wavenet_config_still_parses_through_enum() {
    let model = NamModel::from_json_str(MINIMAL_WAVENET).expect("parse");

    let layer_count = match &model.config {
        ModelConfig::WaveNet(c) => c.layers.len(),
        other => panic!("expected WaveNet config, got {other:?}"),
    };
    assert_eq!(layer_count, 1);
}
/// Rewrites the fixture's `"WaveNet"` value to `"Transformer"` and checks that
/// parsing fails with an error whose display text mentions `Transformer`.
#[test]
fn unknown_architecture_fails_to_parse() {
    let patched = MINIMAL_WAVENET.replace("\"WaveNet\"", "\"Transformer\"");

    let err = NamModel::from_json_str(&patched).unwrap_err();

    // Render the error once through `Display` and search the text.
    let rendered = err.to_string();
    assert!(
        rendered.contains("Transformer"),
        "error should name the bad architecture: {err}"
    );
}