Skip to main content

nam_rs/
model.rs

1//! Parsing of the on-disk `.nam` file format.
2//!
3//! A `.nam` file is a JSON object. The fields here mirror NAM's
4//! `export_config()` / `export_weights()` output (see crate-level attribution).
//! WaveNet, LSTM, and SlimmableContainer architectures are parsed here (see
//! [`ModelConfig`]); the runtime forward passes live in their own modules.
8
9use serde::de::{self, Deserializer};
10use serde::Deserialize;
11
12use crate::error::Error;
13
14/// How a layer-array's `activation` field was specified in the `.nam`.
15///
16/// NAM A1 writes a bare string (`"Tanh"`); A2 may write a dict
17/// (`{"type": "LeakyReLU", "negative_slope": 0.01}`). A per-layer *list* (a
18/// distinct activation per layer) is not modeled and is captured as
19/// [`ActivationSpec::Unsupported`], which the runtime rejects with
20/// [`crate::Error::UnsupportedFeature`] rather than silently mis-running.
21#[derive(Debug, Clone, PartialEq)]
22pub enum ActivationSpec {
23    /// A single named activation, with an optional negative slope (LeakyReLU).
24    Named {
25        /// Activation name, e.g. `"Tanh"`, `"ReLU"`, `"LeakyReLU"`.
26        name: String,
27        /// LeakyReLU negative slope, if the file specified one. `None` → the
28        /// runtime applies NAM's default of `0.01`.
29        negative_slope: Option<f32>,
30    },
31    /// A shape this crate does not model (e.g. a per-layer activation list).
32    Unsupported(serde_json::Value),
33}
34
35impl<'de> Deserialize<'de> for ActivationSpec {
36    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
37    where
38        D: Deserializer<'de>,
39    {
40        let v = serde_json::Value::deserialize(deserializer)?;
41        Ok(match &v {
42            serde_json::Value::String(s) => ActivationSpec::Named {
43                name: s.clone(),
44                negative_slope: None,
45            },
46            serde_json::Value::Object(map) => match map.get("type") {
47                Some(serde_json::Value::String(t)) => match map.get("negative_slope") {
48                    // Absent or explicit-null slope → runtime default (0.01).
49                    None | Some(serde_json::Value::Null) => ActivationSpec::Named {
50                        name: t.clone(),
51                        negative_slope: None,
52                    },
53                    // Present and numeric → use it.
54                    Some(slope) if slope.as_f64().is_some() => ActivationSpec::Named {
55                        name: t.clone(),
56                        negative_slope: slope.as_f64().map(|x| x as f32),
57                    },
58                    // Present but not a number → malformed; reject rather than silently
59                    // defaulting (a corrupt/upstream-format error must not pass silently).
60                    Some(_) => ActivationSpec::Unsupported(v.clone()),
61                },
62                _ => ActivationSpec::Unsupported(v),
63            },
64            _ => ActivationSpec::Unsupported(v),
65        })
66    }
67}
68
/// Fallback training sample rate, in Hz, for `.nam` files that omit the
/// `sample_rate` field.
///
/// This is NAM's documented default.
pub const DEFAULT_SAMPLE_RATE: f64 = 48000.0;
73
74/// A parsed `.nam` model file.
75///
76/// This is the *file representation* — the raw config + flat weight blob. To run
77/// inference, build a [`crate::WaveNet`] from it.
78#[derive(Debug, Clone)]
79pub struct NamModel {
80    /// `.nam` format version string (e.g. `"0.5.4"`).
81    pub version: String,
82    /// Model architecture, e.g. `"WaveNet"`.
83    pub architecture: String,
84    /// Architecture-specific configuration (dispatched on [`Self::architecture`]).
85    pub config: ModelConfig,
86    /// Flat weight blob. The final element is `head_scale` (see NAM
87    /// `export_weights`). Stored as `f32` to match NAM Core's inference precision.
88    pub weights: Vec<f32>,
89    /// Training sample rate. Absent in older files; see [`Self::expected_sample_rate`].
90    pub sample_rate: Option<f64>,
91    /// Opaque training/gear metadata. Not used for inference.
92    pub metadata: Option<serde_json::Value>,
93}
94
95/// LSTM configuration (NAM `_export_config`).
96#[derive(Debug, Clone, Deserialize)]
97pub struct LstmConfig {
98    /// Input width (1 for mono amp models).
99    pub input_size: usize,
100    /// Hidden state dimension `H`.
101    pub hidden_size: usize,
102    /// Number of stacked LSTM layers `L`.
103    pub num_layers: usize,
104}
105
106/// One entry in a [`SlimmableConfig`]: a complete standalone submodel plus the
107/// width-dial threshold at which it becomes active.
108#[derive(Debug, Clone, Deserialize)]
109pub struct SlimmableSubmodel {
110    /// Upper width-dial value this submodel covers (NAM Core `max_value`).
111    pub max_value: f32,
112    /// The submodel itself — a full standalone `.nam` of any architecture.
113    pub model: NamModel,
114}
115
116/// `SlimmableContainer` configuration: an ordered list of standalone submodels
117/// selected at runtime by a width dial. The container holds no weights of its own.
118#[derive(Debug, Clone, Deserialize)]
119pub struct SlimmableConfig {
120    /// Submodels in ascending `max_value` order; the last is the full-width model.
121    pub submodels: Vec<SlimmableSubmodel>,
122}
123
124/// Architecture-specific configuration, tagged by `NamModel.architecture`.
125#[derive(Debug, Clone)]
126pub enum ModelConfig {
127    /// WaveNet: a stack of dilated-convolution layer-arrays. Runnable via
128    /// [`crate::WaveNet`].
129    WaveNet(WaveNetConfig),
130    /// LSTM: stacked recurrent layers plus a linear head.
131    Lstm(LstmConfig),
132    /// SlimmableContainer: a width-selectable set of standalone submodels.
133    Slimmable(SlimmableConfig),
134}
135
136impl<'de> Deserialize<'de> for NamModel {
137    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
138    where
139        D: Deserializer<'de>,
140    {
141        // Parse the file shape with `config` left raw, then dispatch on
142        // `architecture` to type it. This reads the sibling `architecture` field,
143        // which `#[serde(deserialize_with)]` on a single field cannot do.
144        #[derive(Deserialize)]
145        struct Raw {
146            version: String,
147            architecture: String,
148            config: serde_json::Value,
149            weights: Vec<f32>,
150            #[serde(default)]
151            sample_rate: Option<f64>,
152            #[serde(default)]
153            metadata: Option<serde_json::Value>,
154        }
155
156        let raw = Raw::deserialize(deserializer)?;
157        let config = match raw.architecture.as_str() {
158            "WaveNet" => {
159                let raw_wn: RawWaveNetConfig =
160                    serde_json::from_value(raw.config).map_err(de::Error::custom)?;
161                ModelConfig::WaveNet(raw_wn.normalize().map_err(de::Error::custom)?)
162            }
163            "LSTM" => {
164                ModelConfig::Lstm(serde_json::from_value(raw.config).map_err(de::Error::custom)?)
165            }
166            "SlimmableContainer" => ModelConfig::Slimmable(
167                serde_json::from_value(raw.config).map_err(de::Error::custom)?,
168            ),
169            other => {
170                return Err(de::Error::custom(format!(
171                    "unsupported model architecture: {other:?}"
172                )))
173            }
174        };
175
176        Ok(NamModel {
177            version: raw.version,
178            architecture: raw.architecture,
179            config,
180            weights: raw.weights,
181            sample_rate: raw.sample_rate,
182            metadata: raw.metadata,
183        })
184    }
185}
186
187/// Loudness/level-calibration fields NAM may write into `metadata`. All optional;
188/// older or minimal files omit them. Unknown metadata keys are ignored.
189#[derive(Debug, Clone, Default, Deserialize)]
190pub struct Metadata {
191    /// Perceived loudness of the model's output, in LUFS (NAM's `loudness`).
192    #[serde(default)]
193    pub loudness: Option<f32>,
194    /// Analog level (dBu) corresponding to 0 dBFS at the model input.
195    #[serde(default)]
196    pub input_level_dbu: Option<f32>,
197    /// Analog level (dBu) corresponding to 0 dBFS at the model output.
198    #[serde(default)]
199    pub output_level_dbu: Option<f32>,
200}
201
202impl NamModel {
203    /// Read and parse a `.nam` model from a file on disk.
204    ///
205    /// Convenience over [`std::fs::read_to_string`] + [`Self::from_json_str`].
206    /// Returns [`Error::Io`] if the file can't be read, or [`Error::Json`] if its
207    /// contents aren't valid `.nam` JSON.
208    pub fn from_file(path: impl AsRef<std::path::Path>) -> Result<Self, Error> {
209        Self::from_json_str(&std::fs::read_to_string(path)?)
210    }
211
212    /// Parse a `.nam` model from a JSON string already in memory.
213    pub fn from_json_str(json: &str) -> Result<Self, Error> {
214        Ok(serde_json::from_str(json)?)
215    }
216
217    /// The sample rate, in Hz, the model expects its input to be at — falling back
218    /// to [`DEFAULT_SAMPLE_RATE`] when the file does not specify one.
219    ///
220    /// **You must feed the model audio at this rate.** `nam-rs` runs the forward pass
221    /// at whatever rate you hand it and does *not* resample. A model captured at one
222    /// rate fed audio at another produces silently wrong output: its dilations and
223    /// recurrence are defined in samples, not seconds. If your host runs at a
224    /// different rate, resample to this rate before [`crate::Model::process_buffer`]
225    /// and back afterwards — resampling is the caller's responsibility. Mirrors NAM
226    /// Core's `GetExpectedSampleRate()`.
227    #[must_use]
228    pub fn expected_sample_rate(&self) -> f64 {
229        self.sample_rate.unwrap_or(DEFAULT_SAMPLE_RATE)
230    }
231
232    /// The typed [`Metadata`] (loudness + calibration levels), parsed from the raw
233    /// `metadata` block in one shot. Returns defaults (all `None`) when there is no
234    /// metadata block or it lacks these keys; unknown keys are ignored.
235    ///
236    /// Prefer this over the single-field accessors ([`Self::loudness`], etc.) when you
237    /// want several fields: each single-field accessor re-clones and re-parses the raw
238    /// JSON, whereas this parses once. (All are cold-path / load-time, so neither is on
239    /// the audio thread.)
240    #[must_use]
241    pub fn metadata_typed(&self) -> Metadata {
242        match &self.metadata {
243            Some(v) => serde_json::from_value(v.clone()).unwrap_or_default(),
244            None => Metadata::default(),
245        }
246    }
247
248    /// Output loudness in LUFS, if the file records it.
249    #[must_use]
250    pub fn loudness(&self) -> Option<f32> {
251        self.metadata_typed().loudness
252    }
253
254    /// Input calibration level in dBu (analog level at 0 dBFS in), if present.
255    #[must_use]
256    pub fn input_level_dbu(&self) -> Option<f32> {
257        self.metadata_typed().input_level_dbu
258    }
259
260    /// Output calibration level in dBu (analog level at 0 dBFS out), if present.
261    #[must_use]
262    pub fn output_level_dbu(&self) -> Option<f32> {
263        self.metadata_typed().output_level_dbu
264    }
265}
266
/// How a WaveNet layer combines its activation branches (NAMCore `GatingMode`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GatingMode {
    /// Ungated: `out = activation(z)`.
    None,
    /// Gated product: `out = primary(z_a) * secondary(z_b)` (the classic `tanh*sigmoid`).
    Gated,
    /// Blend: `out = α·primary(z_a) + (1-α)·z_a` with `α = secondary(z_b)`.
    Blended,
}

impl GatingMode {
    /// Map a NAMCore gating-mode string (`"none"`, `"gated"`, `"blended"`) onto
    /// the enum. Matching is exact and case-sensitive. Any other string is an
    /// error that names the rejected value.
    pub(crate) fn from_name(s: &str) -> Result<Self, String> {
        let mode = match s {
            "gated" => Self::Gated,
            "blended" => Self::Blended,
            "none" => Self::None,
            unknown => return Err(format!("unknown gating_mode: {unknown:?}")),
        };
        Ok(mode)
    }
}
289
/// Residual 1×1 (`layer1x1`) of a layer: projects the activated bottleneck back
/// to `channels`. Present by default, matching the A1 `_1x1`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Layer1x1Config {
    /// `true` when the 1×1 exists. When `false` the residual is the identity,
    /// which requires `bottleneck == channels`.
    pub active: bool,
    /// Group count of the grouped convolution (`1` means dense).
    pub groups: usize,
}
299
300/// Read an optional unsigned-int field off a JSON object as `usize`. `None` when
301/// the key is absent or not a non-negative integer. Centralizes the
302/// `get(key).and_then(as_u64).map(as usize)` shape used across the config decoders.
303fn opt_usize(o: &serde_json::Value, key: &str) -> Option<usize> {
304    o.get(key).and_then(|x| x.as_u64()).map(|x| x as usize)
305}
306
307impl Layer1x1Config {
308    pub(crate) fn from_json(v: Option<&serde_json::Value>) -> Self {
309        match v {
310            None => Self {
311                active: true,
312                groups: 1,
313            },
314            Some(o) => Self {
315                active: o.get("active").and_then(|x| x.as_bool()).unwrap_or(true),
316                groups: opt_usize(o, "groups").unwrap_or(1),
317            },
318        }
319    }
320}
321
/// Optional head 1×1 (`head1x1`) of a layer, producing that layer's
/// contribution to the head. Absent by default, in which case the activated
/// bottleneck feeds the head directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Head1x1Config {
    /// `true` when the head 1×1 exists.
    pub active: bool,
    /// Output width. When active but unspecified, `channels` is used.
    pub out_channels: Option<usize>,
    /// Group count of the grouped convolution.
    pub groups: usize,
}
334
335impl Head1x1Config {
336    pub(crate) fn from_json(v: Option<&serde_json::Value>) -> Self {
337        match v {
338            None => Self {
339                active: false,
340                out_channels: None,
341                groups: 1,
342            },
343            Some(o) => Self {
344                active: o.get("active").and_then(|x| x.as_bool()).unwrap_or(false),
345                out_channels: opt_usize(o, "out_channels"),
346                groups: opt_usize(o, "groups").unwrap_or(1),
347            },
348        }
349    }
350}
351
/// A single FiLM site (`*_pre_film` / `*_post_film`): derives a scale, and
/// optionally a shift, from the conditioning signal. A missing key or `false`
/// means the site is off.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FilmConfig {
    /// `true` when this FiLM site is applied.
    pub active: bool,
    /// `true` to add a shift term; `false` for scale-only.
    pub shift: bool,
    /// Group count of the conditioning 1×1.
    pub groups: usize,
}
363
364impl FilmConfig {
365    /// The inactive default (absent key or explicit `false`).
366    pub const INACTIVE: Self = Self {
367        active: false,
368        shift: false,
369        groups: 1,
370    };
371
372    pub(crate) fn from_json(v: Option<&serde_json::Value>) -> Self {
373        match v {
374            None => Self::INACTIVE,
375            Some(serde_json::Value::Bool(false)) => Self::INACTIVE,
376            Some(o) => Self {
377                active: o.get("active").and_then(|x| x.as_bool()).unwrap_or(true),
378                shift: o.get("shift").and_then(|x| x.as_bool()).unwrap_or(true),
379                groups: opt_usize(o, "groups").unwrap_or(1),
380            },
381        }
382    }
383}
384
385/// Post-stack head (`config.head`): a stack of `activation → Conv1D` applied after
386/// the layer-arrays. `None` for A1 / current A2 defaults.
387#[derive(Debug, Clone)]
388pub struct PostStackHeadConfig {
389    /// Hidden channel count between head convs.
390    pub channels: usize,
391    /// Final output channels.
392    pub out_channels: usize,
393    /// Per-conv kernel sizes (one conv per entry).
394    pub kernel_sizes: Vec<usize>,
395    /// Activation applied before each head conv.
396    pub activation: ActivationSpec,
397}
398
399/// WaveNet configuration: layer-arrays, optional post-stack head + condition DSP,
400/// and the output scale. Per-layer quantities are normalized into `Vec`s.
401#[derive(Debug, Clone)]
402pub struct WaveNetConfig {
403    /// One config per layer-array.
404    pub layers: Vec<LayerArrayConfig>,
405    /// Optional post-stack head (`config.head`).
406    pub post_stack_head: Option<PostStackHeadConfig>,
407    /// Output gain (note: the runtime value is the trailing weight).
408    pub head_scale: f32,
409    /// Input channels (default 1).
410    pub in_channels: usize,
411    /// Optional nested conditioning DSP.
412    pub condition_dsp: Option<Box<NamModel>>,
413}
414
415#[derive(serde::Deserialize)]
416struct RawWaveNetConfig {
417    layers: Vec<RawLayerArrayConfig>,
418    #[serde(default)]
419    head: Option<serde_json::Value>,
420    head_scale: f32,
421    #[serde(default)]
422    in_channels: Option<usize>,
423    #[serde(default)]
424    condition_dsp: Option<serde_json::Value>,
425}
426
427impl RawWaveNetConfig {
428    fn normalize(self) -> Result<WaveNetConfig, String> {
429        let layers = self
430            .layers
431            .into_iter()
432            .map(RawLayerArrayConfig::normalize)
433            .collect::<Result<Vec<_>, _>>()?;
434
435        let post_stack_head = match self.head {
436            Some(h) if !h.is_null() => {
437                let channels =
438                    h.get("channels")
439                        .and_then(|x| x.as_u64())
440                        .ok_or("post-stack head missing channels")? as usize;
441                let out_channels = h
442                    .get("out_channels")
443                    .and_then(|x| x.as_u64())
444                    .ok_or("post-stack head missing out_channels")?
445                    as usize;
446                let kernel_sizes: Vec<usize> = h
447                    .get("kernel_sizes")
448                    .and_then(|x| x.as_array())
449                    .ok_or("post-stack head missing kernel_sizes")?
450                    .iter()
451                    .map(|k| {
452                        k.as_u64()
453                            .map(|v| v as usize)
454                            .ok_or("kernel_sizes entry not an int".to_string())
455                    })
456                    .collect::<Result<_, _>>()?;
457                let activation = serde_json::from_value::<ActivationSpec>(
458                    h.get("activation")
459                        .cloned()
460                        .unwrap_or(serde_json::Value::Null),
461                )
462                .map_err(|e| e.to_string())?;
463                Some(PostStackHeadConfig {
464                    channels,
465                    out_channels,
466                    kernel_sizes,
467                    activation,
468                })
469            }
470            _ => None,
471        };
472
473        let condition_dsp = match self.condition_dsp {
474            Some(v) if !v.is_null() => {
475                let m = serde_json::from_value::<NamModel>(v).map_err(|e| e.to_string())?;
476                Some(Box::new(m))
477            }
478            _ => None,
479        };
480
481        Ok(WaveNetConfig {
482            layers,
483            post_stack_head,
484            head_scale: self.head_scale,
485            in_channels: self.in_channels.unwrap_or(1),
486            condition_dsp,
487        })
488    }
489}
490
491/// Configuration for one WaveNet layer-array, normalized so every per-layer
492/// quantity is a `Vec` of length `dilations.len()`. Built from the on-disk JSON by
493/// the internal `RawLayerArrayConfig::normalize`; A1 files fill the A2 fields with
494/// defaults.
495#[derive(Debug, Clone)]
496pub struct LayerArrayConfig {
497    /// Input channels into the array (1 for the first array).
498    pub input_size: usize,
499    /// Conditioning signal width.
500    pub condition_size: usize,
501    /// Hidden channel count between layers.
502    pub channels: usize,
503    /// Internal per-layer width (defaults to `channels`).
504    pub bottleneck: usize,
505    /// Per-layer dilation factors; its length defines the number of layers.
506    pub dilations: Vec<usize>,
507    /// Per-layer dilated-conv kernel sizes (length == `dilations.len()`).
508    pub kernel_sizes: Vec<usize>,
509    /// Per-layer primary activations (length == `dilations.len()`).
510    pub activations: Vec<ActivationSpec>,
511    /// Per-layer gating modes (length == `dilations.len()`).
512    pub gating_modes: Vec<GatingMode>,
513    /// Per-layer secondary activations (for gating); element may be the default
514    /// (a `Named{"Sigmoid"}`) where unspecified. Length == `dilations.len()`.
515    pub secondary_activations: Vec<ActivationSpec>,
516    /// Grouped-conv groups for the dilated conv.
517    pub groups_input: usize,
518    /// Grouped-conv groups for the input mixer.
519    pub groups_input_mixin: usize,
520    /// Head rechannel output width.
521    pub head_size: usize,
522    /// Head rechannel kernel size (1 for A1; e.g. 16 for A2 conv heads).
523    pub head_kernel_size: usize,
524    /// Whether the head rechannel has a bias.
525    pub head_bias: bool,
526    /// Residual 1×1 config.
527    pub layer1x1: Layer1x1Config,
528    /// Head 1×1 config.
529    pub head1x1: Head1x1Config,
530    /// FiLM: applied to the layer input before the dilated conv.
531    pub conv_pre_film: FilmConfig,
532    /// FiLM: applied to the dilated-conv output.
533    pub conv_post_film: FilmConfig,
534    /// FiLM: applied to the conditioning before the input mixer.
535    pub input_mixin_pre_film: FilmConfig,
536    /// FiLM: applied to the input-mixer output.
537    pub input_mixin_post_film: FilmConfig,
538    /// FiLM: applied to the conv+mixin sum before activation.
539    pub activation_pre_film: FilmConfig,
540    /// FiLM: applied to the activation output.
541    pub activation_post_film: FilmConfig,
542    /// FiLM: applied to the layer1x1 output (BLENDED branch only, per NAMCore).
543    pub layer1x1_post_film: FilmConfig,
544    /// FiLM: applied to the head1x1 output.
545    pub head1x1_post_film: FilmConfig,
546}
547
548impl LayerArrayConfig {
549    /// The array's uniform gating mode.
550    ///
551    /// `normalize()` produces one `gating_modes` entry per layer, and the runtime
552    /// guards that they are all equal before building (mixed modes are an
553    /// `UnsupportedFeature`). This accessor encapsulates that post-guard invariant
554    /// — the single uniform mode — instead of indexing `gating_modes[0]` at each use
555    /// site (which would panic on a directly-constructed empty-vec config, since the
556    /// struct is `pub`). Returns [`GatingMode::None`] for an empty list.
557    pub fn gating_mode(&self) -> GatingMode {
558        self.gating_modes
559            .first()
560            .copied()
561            .unwrap_or(GatingMode::None)
562    }
563}
564
565/// On-disk shape of a layer-array config: optional / either-or fields exactly as
566/// NAM writes them. Converted to [`LayerArrayConfig`] by [`Self::normalize`].
567#[derive(Debug, Clone, serde::Deserialize)]
568pub(crate) struct RawLayerArrayConfig {
569    input_size: usize,
570    condition_size: usize,
571    channels: usize,
572    #[serde(default)]
573    bottleneck: Option<usize>,
574    dilations: Vec<usize>,
575    #[serde(default)]
576    kernel_size: Option<usize>,
577    #[serde(default)]
578    kernel_sizes: Option<Vec<usize>>,
579    activation: serde_json::Value,
580    #[serde(default)]
581    gating_mode: Option<serde_json::Value>,
582    #[serde(default)]
583    gated: Option<bool>,
584    #[serde(default)]
585    secondary_activation: Option<serde_json::Value>,
586    #[serde(default)]
587    groups_input: Option<usize>,
588    #[serde(default)]
589    groups_input_mixin: Option<usize>,
590    #[serde(default)]
591    head: Option<serde_json::Value>,
592    #[serde(default)]
593    head_size: Option<usize>,
594    #[serde(default)]
595    head_bias: Option<bool>,
596    #[serde(default)]
597    layer1x1: Option<serde_json::Value>,
598    #[serde(default)]
599    head1x1: Option<serde_json::Value>,
600    #[serde(default)]
601    conv_pre_film: Option<serde_json::Value>,
602    #[serde(default)]
603    conv_post_film: Option<serde_json::Value>,
604    #[serde(default)]
605    input_mixin_pre_film: Option<serde_json::Value>,
606    #[serde(default)]
607    input_mixin_post_film: Option<serde_json::Value>,
608    #[serde(default)]
609    activation_pre_film: Option<serde_json::Value>,
610    #[serde(default)]
611    activation_post_film: Option<serde_json::Value>,
612    #[serde(default)]
613    layer1x1_post_film: Option<serde_json::Value>,
614    #[serde(default)]
615    head1x1_post_film: Option<serde_json::Value>,
616}
617
618impl RawLayerArrayConfig {
619    pub(crate) fn normalize(self) -> Result<LayerArrayConfig, String> {
620        let n = self.dilations.len();
621        if n == 0 {
622            return Err("layer-array has no dilations".into());
623        }
624
625        let kernel_sizes = match (self.kernel_size, self.kernel_sizes) {
626            (Some(_), Some(_)) => {
627                return Err("layer-array specifies both kernel_size and kernel_sizes".into())
628            }
629            (Some(k), None) => vec![k; n],
630            (None, Some(ks)) => {
631                if ks.len() != n {
632                    return Err(format!(
633                        "kernel_sizes length {} != number of layers {n}",
634                        ks.len()
635                    ));
636                }
637                ks
638            }
639            (None, None) => {
640                return Err("layer-array specifies neither kernel_size nor kernel_sizes".into())
641            }
642        };
643
644        let activations = broadcast_activations(&self.activation, n)?;
645
646        // `gating_mode` (A2, per-layer enum) supersedes the legacy boolean `gated`
647        // (A1) when both are present: the richer field wins silently. In practice a
648        // file carries one or the other, so the conflict is theoretical.
649        let gating_modes = match (&self.gating_mode, self.gated) {
650            (Some(v), _) => broadcast_gating(v, n)?,
651            (None, Some(true)) => vec![GatingMode::Gated; n],
652            (None, _) => vec![GatingMode::None; n],
653        };
654
655        let secondary_activations = match &self.secondary_activation {
656            Some(v) => broadcast_secondary(v, n)?,
657            None => vec![default_sigmoid(); n],
658        };
659
660        let (head_size, head_kernel_size, head_bias) = match &self.head {
661            Some(h) if !h.is_null() => {
662                let out = h
663                    .get("out_channels")
664                    .and_then(|x| x.as_u64())
665                    .ok_or("layer head missing out_channels")? as usize;
666                let k = h
667                    .get("kernel_size")
668                    .and_then(|x| x.as_u64())
669                    .ok_or("layer head missing kernel_size")? as usize;
670                // NAMCore requires `bias` on a nested head object (`.at("bias")`
671                // throws if absent). We're leniently defaulting to `true` — the value
672                // every real exporter writes — so a hand-edited file missing it still
673                // loads with the NAMCore-default behavior rather than erroring.
674                let bias = h.get("bias").and_then(|x| x.as_bool()).unwrap_or(true);
675                (out, k, bias)
676            }
677            _ => {
678                let hs = self
679                    .head_size
680                    .ok_or("layer-array missing head_size (and no head object)")?;
681                (hs, 1, self.head_bias.unwrap_or(false))
682            }
683        };
684
685        // Reject zero/degenerate dimensions before they reach the runtime: a
686        // `head_kernel_size == 0` underflows `head_kernel_size - 1` and overflows
687        // the `Conv1d` ring buffer; zero kernel/dilation/channel counts likewise
688        // produce nonsense buffers. NAMCore rejects `head_kernel_size < 1`
689        // (`wavenet/model.cpp`); mirror that here as a clean `Err`, not a panic.
690        if head_kernel_size == 0 {
691            return Err("layer-array head_kernel_size must be >= 1".into());
692        }
693        if self.channels == 0 {
694            return Err("layer-array channels must be >= 1".into());
695        }
696        if head_size == 0 {
697            return Err("layer-array head_size must be >= 1".into());
698        }
699        if kernel_sizes.contains(&0) {
700            return Err("layer-array kernel_sizes entries must be >= 1".into());
701        }
702        if self.dilations.contains(&0) {
703            return Err("layer-array dilations entries must be >= 1".into());
704        }
705        let bottleneck = self.bottleneck.unwrap_or(self.channels);
706        if bottleneck == 0 {
707            return Err("layer-array bottleneck must be >= 1".into());
708        }
709
710        let groups_input = self.groups_input.unwrap_or(1);
711        let groups_input_mixin = self.groups_input_mixin.unwrap_or(1);
712        let layer1x1 = Layer1x1Config::from_json(self.layer1x1.as_ref());
713        let head1x1 = Head1x1Config::from_json(self.head1x1.as_ref());
714        let films = [
715            FilmConfig::from_json(self.conv_pre_film.as_ref()),
716            FilmConfig::from_json(self.conv_post_film.as_ref()),
717            FilmConfig::from_json(self.input_mixin_pre_film.as_ref()),
718            FilmConfig::from_json(self.input_mixin_post_film.as_ref()),
719            FilmConfig::from_json(self.activation_pre_film.as_ref()),
720            FilmConfig::from_json(self.activation_post_film.as_ref()),
721            FilmConfig::from_json(self.layer1x1_post_film.as_ref()),
722            FilmConfig::from_json(self.head1x1_post_film.as_ref()),
723        ];
724        // Grouped-conv group counts must be >= 1: a zero divides by zero when the
725        // runtime lays out the block-diagonal weight tensor. (Divisibility of the
726        // channel dims by the group count is checked in `array_weight_count`, which
727        // knows every dim.) Reject here as a clean `Err`, not a panic — mirroring
728        // NAMCore's `% groups` precondition.
729        let group_counts = [
730            ("groups_input", groups_input),
731            ("groups_input_mixin", groups_input_mixin),
732            ("layer1x1.groups", layer1x1.groups),
733            ("head1x1.groups", head1x1.groups),
734            (
735                "film.groups",
736                films.iter().map(|f| f.groups).min().unwrap_or(1),
737            ),
738        ];
739        for (name, g) in group_counts {
740            if g == 0 {
741                return Err(format!("layer-array {name} must be >= 1"));
742            }
743        }
744        let [conv_pre_film, conv_post_film, input_mixin_pre_film, input_mixin_post_film, activation_pre_film, activation_post_film, layer1x1_post_film, head1x1_post_film] =
745            films;
746
747        Ok(LayerArrayConfig {
748            input_size: self.input_size,
749            condition_size: self.condition_size,
750            channels: self.channels,
751            bottleneck,
752            dilations: self.dilations,
753            kernel_sizes,
754            activations,
755            gating_modes,
756            secondary_activations,
757            groups_input,
758            groups_input_mixin,
759            head_size,
760            head_kernel_size,
761            head_bias,
762            layer1x1,
763            head1x1,
764            conv_pre_film,
765            conv_post_film,
766            input_mixin_pre_film,
767            input_mixin_post_film,
768            activation_pre_film,
769            activation_post_film,
770            layer1x1_post_film,
771            head1x1_post_film,
772        })
773    }
774}
775
776/// A `Named{"Sigmoid"}` activation, the gating secondary default.
777fn default_sigmoid() -> ActivationSpec {
778    ActivationSpec::Named {
779        name: "Sigmoid".into(),
780        negative_slope: None,
781    }
782}
783
784/// Broadcast a single activation or expand a per-layer list to length `n`.
785/// Expand a per-layer field to length `n`: a JSON array must already be exactly
786/// `n` long (each element parsed by `parse`); any other (scalar) value is parsed
787/// once and broadcast to all `n` layers. `kind` names the field in length errors.
788fn broadcast<T: Clone>(
789    v: &serde_json::Value,
790    n: usize,
791    kind: &str,
792    parse: impl Fn(&serde_json::Value) -> Result<T, String>,
793) -> Result<Vec<T>, String> {
794    match v {
795        serde_json::Value::Array(items) => {
796            if items.len() != n {
797                return Err(format!(
798                    "{kind} list length {} != number of layers {n}",
799                    items.len()
800                ));
801            }
802            items.iter().map(&parse).collect()
803        }
804        other => Ok(vec![parse(other)?; n]),
805    }
806}
807
808fn parse_activation(e: &serde_json::Value) -> Result<ActivationSpec, String> {
809    serde_json::from_value::<ActivationSpec>(e.clone()).map_err(|e| e.to_string())
810}
811
812fn broadcast_activations(v: &serde_json::Value, n: usize) -> Result<Vec<ActivationSpec>, String> {
813    broadcast(v, n, "activation", parse_activation)
814}
815
816/// Broadcast/expand `secondary_activation`; JSON `null` elements become the
817/// Sigmoid default (NAMCore's default secondary).
818fn broadcast_secondary(v: &serde_json::Value, n: usize) -> Result<Vec<ActivationSpec>, String> {
819    broadcast(v, n, "secondary_activation", |e| {
820        if e.is_null() {
821            Ok(default_sigmoid())
822        } else {
823            parse_activation(e)
824        }
825    })
826}
827
828/// Broadcast a single gating name or expand a per-layer list to length `n`.
829fn broadcast_gating(v: &serde_json::Value, n: usize) -> Result<Vec<GatingMode>, String> {
830    broadcast(v, n, "gating_mode", |e| {
831        e.as_str()
832            .ok_or_else(|| "gating_mode entry is not a string".to_string())
833            .and_then(GatingMode::from_name)
834    })
835}
836
#[cfg(test)]
mod layer_array_normalize_tests {
    //! Tests for `RawLayerArrayConfig::normalize`: A1/A2 field broadcasting and
    //! rejection of degenerate (zero-sized or inconsistent) dimensions.
    use super::*;

    /// Deserialize `v` as a raw layer array and normalize it, panicking on error.
    fn norm(v: serde_json::Value) -> LayerArrayConfig {
        let raw: RawLayerArrayConfig = serde_json::from_value(v).unwrap();
        raw.normalize().unwrap()
    }

    #[test]
    fn a1_layer_broadcasts_scalar_kernel_and_string_activation() {
        let la = norm(serde_json::json!({
            "input_size": 1, "condition_size": 1, "channels": 2, "head_size": 1,
            "kernel_size": 3, "dilations": [1, 2, 4], "activation": "Tanh",
            "gated": false, "head_bias": false
        }));
        assert_eq!(la.channels, 2);
        assert_eq!(la.bottleneck, 2);
        assert_eq!(la.kernel_sizes, vec![3, 3, 3]);
        assert_eq!(la.gating_modes, vec![GatingMode::None; 3]);
        assert_eq!(la.head_size, 1);
        assert_eq!(la.head_kernel_size, 1);
        assert!(!la.head_bias);
        assert!(la.layer1x1.active);
        assert!(!la.head1x1.active);
        assert_eq!(la.groups_input, 1);
        assert_eq!(la.activations.len(), 3);
        assert!(matches!(&la.activations[0], ActivationSpec::Named { name, .. } if name == "Tanh"));
        let g = norm(serde_json::json!({
            "input_size": 1, "condition_size": 1, "channels": 2, "head_size": 1,
            "kernel_size": 3, "dilations": [1], "activation": "Tanh",
            "gated": true, "head_bias": true
        }));
        assert_eq!(g.gating_modes, vec![GatingMode::Gated]);
    }

    #[test]
    fn a2_flexible_layer_parses_per_layer_vectors_and_nested_head() {
        let la = norm(serde_json::json!({
            "input_size": 1, "condition_size": 1, "channels": 3, "bottleneck": 3,
            "dilations": [1, 3, 7],
            "kernel_sizes": [6, 6, 15],
            "activation": [
                {"type": "LeakyReLU", "negative_slope": 0.01},
                {"type": "LeakyReLU", "negative_slope": 0.01},
                {"type": "LeakyReLU", "negative_slope": 0.01}
            ],
            "head": {"out_channels": 1, "kernel_size": 16, "bias": true},
            "head1x1": {"active": false, "out_channels": 1, "groups": 1},
            "layer1x1": {"active": true, "groups": 1},
            "groups_input": 1, "groups_input_mixin": 1,
            "gating_mode": ["none", "none", "none"],
            "secondary_activation": [null, null, null],
            "conv_pre_film": {"active": false, "shift": true, "groups": 1}
        }));
        assert_eq!(la.kernel_sizes, vec![6, 6, 15]);
        assert_eq!(la.gating_modes, vec![GatingMode::None; 3]);
        assert_eq!(la.head_size, 1);
        assert_eq!(la.head_kernel_size, 16);
        assert!(la.head_bias);
        assert_eq!(la.bottleneck, 3);
        assert_eq!(la.activations.len(), 3);
        assert!(!la.conv_pre_film.active);
    }

    #[test]
    fn both_kernel_forms_is_an_error() {
        let raw: RawLayerArrayConfig = serde_json::from_value(serde_json::json!({
            "input_size": 1, "condition_size": 1, "channels": 1, "head_size": 1,
            "kernel_size": 3, "kernel_sizes": [3], "dilations": [1],
            "activation": "Tanh", "gated": false, "head_bias": false
        }))
        .unwrap();
        assert!(raw.normalize().is_err());
    }

    #[test]
    fn kernel_sizes_length_mismatch_is_an_error() {
        let raw: RawLayerArrayConfig = serde_json::from_value(serde_json::json!({
            "input_size": 1, "condition_size": 1, "channels": 1, "head_size": 1,
            "kernel_sizes": [3, 3], "dilations": [1],
            "activation": "Tanh", "gated": false, "head_bias": false
        }))
        .unwrap();
        assert!(raw.normalize().is_err());
    }

    #[test]
    fn activation_list_length_mismatch_is_an_error() {
        let raw: RawLayerArrayConfig = serde_json::from_value(serde_json::json!({
            "input_size": 1, "condition_size": 1, "channels": 1, "head_size": 1,
            "kernel_size": 3, "dilations": [1, 2],
            "activation": ["Tanh"], "gated": false, "head_bias": false
        }))
        .unwrap();
        assert!(raw.normalize().is_err());
    }

    /// Build a minimal-but-valid raw layer-array, then apply `mutate` so each
    /// degenerate-dimension test only has to express the one field it breaks.
    fn raw_layer_array(mutate: impl FnOnce(&mut serde_json::Value)) -> RawLayerArrayConfig {
        let mut v = serde_json::json!({
            "input_size": 1, "condition_size": 1, "channels": 1, "head_size": 1,
            "kernel_size": 3, "dilations": [1],
            "activation": "Tanh", "gated": false, "head_bias": false
        });
        mutate(&mut v);
        serde_json::from_value(v).unwrap()
    }

    #[test]
    fn baseline_raw_layer_array_normalizes() {
        // Guard: the helper's unmutated form must be valid. Otherwise the
        // negative tests below could pass for the wrong reason.
        assert!(raw_layer_array(|_| {}).normalize().is_ok());
    }

    #[test]
    fn zero_channels_is_an_error() {
        let raw = raw_layer_array(|v| v["channels"] = serde_json::json!(0));
        assert!(raw.normalize().is_err());
    }

    #[test]
    fn zero_head_size_is_an_error() {
        let raw = raw_layer_array(|v| v["head_size"] = serde_json::json!(0));
        assert!(raw.normalize().is_err());
    }

    #[test]
    fn zero_kernel_size_is_an_error() {
        let raw = raw_layer_array(|v| v["kernel_size"] = serde_json::json!(0));
        assert!(raw.normalize().is_err());
    }

    #[test]
    fn zero_dilation_is_an_error() {
        let raw = raw_layer_array(|v| v["dilations"] = serde_json::json!([0]));
        assert!(raw.normalize().is_err());
    }

    #[test]
    fn zero_bottleneck_is_an_error() {
        // An explicit `bottleneck == 0` yields `mid == 0` and degenerate conv
        // buffers. Reject it like the other zero dims; it is otherwise
        // unguarded because it defaults to `channels` when absent.
        let raw = raw_layer_array(|v| v["bottleneck"] = serde_json::json!(0));
        assert!(raw.normalize().is_err());
    }

    #[test]
    fn zero_groups_is_an_error() {
        // `groups == 0` would divide-by-zero in the block-diagonal weight layout,
        // so every group count `normalize` checks must be rejected.
        for field in ["groups_input", "groups_input_mixin"] {
            let raw = raw_layer_array(|v| v[field] = serde_json::json!(0));
            assert!(raw.normalize().is_err(), "{field} == 0 must error");
        }
        let raw = raw_layer_array(|v| {
            v["layer1x1"] = serde_json::json!({ "active": true, "groups": 0 });
        });
        assert!(raw.normalize().is_err(), "layer1x1.groups == 0 must error");
        let raw = raw_layer_array(|v| {
            v["head1x1"] =
                serde_json::json!({ "active": true, "out_channels": 1, "groups": 0 });
        });
        assert!(raw.normalize().is_err(), "head1x1.groups == 0 must error");
        let raw = raw_layer_array(|v| {
            v["conv_pre_film"] =
                serde_json::json!({ "active": true, "shift": true, "groups": 0 });
        });
        assert!(raw.normalize().is_err(), "FiLM groups == 0 must error");
    }

    #[test]
    fn zero_head_kernel_size_is_an_error() {
        let raw = raw_layer_array(|v| {
            v.as_object_mut().unwrap().remove("head_size");
            v["head"] = serde_json::json!({
                "out_channels": 1, "kernel_size": 0, "activation": "ReLU"
            });
        });
        assert!(raw.normalize().is_err());
    }
}
1011
#[cfg(test)]
mod a2_subconfig_tests {
    //! Tests for the A2 sub-config parsers (`GatingMode`, `FilmConfig`,
    //! `Layer1x1Config`, `Head1x1Config`) on their own, outside a layer array.
    use super::*;

    #[test]
    fn gating_mode_from_str() {
        let known = [
            ("none", GatingMode::None),
            ("gated", GatingMode::Gated),
            ("blended", GatingMode::Blended),
        ];
        for (name, mode) in known {
            assert_eq!(GatingMode::from_name(name).unwrap(), mode);
        }
        // An unknown name must be an error, not a silent fallback.
        assert!(GatingMode::from_name("wat").is_err());
    }

    #[test]
    fn film_absent_or_false_is_inactive() {
        let explicit_false = serde_json::json!(false);
        for input in [None, Some(&explicit_false)] {
            assert_eq!(FilmConfig::from_json(input), FilmConfig::INACTIVE);
        }
    }

    #[test]
    fn film_object_defaults_active_shift_groups() {
        // An empty object enables FiLM with shift on and a single group.
        let empty = serde_json::json!({});
        let defaults = FilmConfig {
            active: true,
            shift: true,
            groups: 1,
        };
        assert_eq!(FilmConfig::from_json(Some(&empty)), defaults);

        // Explicit fields override every default.
        let explicit = serde_json::json!({"active": false, "shift": false, "groups": 2});
        let overridden = FilmConfig {
            active: false,
            shift: false,
            groups: 2,
        };
        assert_eq!(FilmConfig::from_json(Some(&explicit)), overridden);
    }

    #[test]
    fn layer1x1_defaults_active_true_groups_1() {
        let want = Layer1x1Config {
            active: true,
            groups: 1,
        };
        assert_eq!(Layer1x1Config::from_json(None), want);
        let explicit = serde_json::json!({"active": true, "groups": 1});
        assert_eq!(Layer1x1Config::from_json(Some(&explicit)), want);
    }

    #[test]
    fn head1x1_defaults_inactive() {
        let inactive = Head1x1Config {
            active: false,
            out_channels: None,
            groups: 1,
        };
        assert_eq!(Head1x1Config::from_json(None), inactive);
        let explicit = serde_json::json!({"active": false, "out_channels": 1, "groups": 1});
        assert_eq!(
            Head1x1Config::from_json(Some(&explicit)),
            Head1x1Config {
                out_channels: Some(1),
                ..inactive
            }
        );
    }
}
1100
#[cfg(test)]
mod wavenet_config_tests {
    //! End-to-end parsing of WaveNet `.nam` documents through
    //! `NamModel::from_json_str`, covering A1 and A2 layer forms, the post-stack
    //! head, and a nested `condition_dsp` model.
    use super::*;

    /// Parse a complete `.nam` JSON document and return its WaveNet config.
    /// Panics if parsing fails or the architecture is not WaveNet.
    fn parse(json: &str) -> WaveNetConfig {
        let model = NamModel::from_json_str(json).unwrap();
        match model.config {
            ModelConfig::WaveNet(cfg) => cfg,
            other => panic!("expected WaveNet, got {other:?}"),
        }
    }

    #[test]
    fn a1_config_parses_unchanged() {
        let json = r#"{
            "version":"0.5.4","architecture":"WaveNet","config":{
                "layers":[{"input_size":1,"condition_size":1,"channels":2,"head_size":1,
                    "kernel_size":3,"dilations":[1,2],"activation":"Tanh",
                    "gated":false,"head_bias":false}],
                "head":null,"head_scale":2.0},
            "weights":[]}"#;
        let cfg = parse(json);
        assert_eq!(cfg.layers.len(), 1);
        assert_eq!(cfg.layers[0].kernel_sizes, vec![3, 3]);
        assert_eq!(cfg.head_scale, 2.0);
        assert!(cfg.post_stack_head.is_none());
        assert!(cfg.condition_dsp.is_none());
    }

    #[test]
    fn a2_flexible_container_submodel_config_parses() {
        let json = r#"{
            "version":"0.7.0","architecture":"WaveNet","config":{
                "layers":[{"input_size":1,"condition_size":1,"channels":3,"bottleneck":3,
                    "dilations":[1,3,7],"kernel_sizes":[6,6,15],
                    "activation":[{"type":"LeakyReLU"},{"type":"LeakyReLU"},{"type":"LeakyReLU"}],
                    "head":{"out_channels":1,"kernel_size":16,"bias":true},
                    "head1x1":{"active":false},"layer1x1":{"active":true,"groups":1},
                    "gating_mode":["none","none","none"]}],
                "head":null,"head_scale":0.5},
            "weights":[]}"#;
        let cfg = parse(json);
        let first = &cfg.layers[0];
        assert_eq!(first.head_kernel_size, 16);
        assert_eq!(first.kernel_sizes, vec![6, 6, 15]);
        assert!(cfg.post_stack_head.is_none());
    }

    #[test]
    fn post_stack_head_parses() {
        let json = r#"{
            "version":"0.6.0","architecture":"WaveNet","config":{
                "layers":[{"input_size":1,"condition_size":1,"channels":2,"head_size":2,
                    "kernel_size":3,"dilations":[1],"activation":"Tanh",
                    "gated":false,"head_bias":false}],
                "head":{"channels":4,"out_channels":1,"kernel_sizes":[1,1],"activation":"ReLU"},
                "head_scale":1.0},
            "weights":[]}"#;
        let head = parse(json)
            .post_stack_head
            .expect("post-stack head present");
        assert_eq!(head.channels, 4);
        assert_eq!(head.out_channels, 1);
        assert_eq!(head.kernel_sizes, vec![1, 1]);
    }

    #[test]
    fn condition_dsp_parses_as_nested_model() {
        let json = r#"{
            "version":"0.6.0","architecture":"WaveNet","config":{
                "layers":[{"input_size":1,"condition_size":1,"channels":2,"head_size":1,
                    "kernel_size":3,"dilations":[1],"activation":"Tanh",
                    "gated":false,"head_bias":false}],
                "head":null,"head_scale":1.0,
                "condition_dsp":{"version":"0.5.4","architecture":"WaveNet","config":{
                    "layers":[{"input_size":1,"condition_size":1,"channels":1,"head_size":1,
                        "kernel_size":1,"dilations":[1],"activation":"Tanh",
                        "gated":false,"head_bias":false}],
                    "head":null,"head_scale":1.0},"weights":[]}},
            "weights":[]}"#;
        let dsp = parse(json).condition_dsp.expect("condition_dsp present");
        assert_eq!(dsp.architecture, "WaveNet");
    }
}