Skip to main content

nam_rs/
model.rs

1//! Parsing of the on-disk `.nam` file format.
2//!
3//! A `.nam` file is a JSON object. The fields here mirror NAM's
4//! `export_config()` / `export_weights()` output (see crate-level attribution).
5//! Both the WaveNet and LSTM architectures are parsed here (see [`ModelConfig`]);
6//! the runtime forward passes live in their own modules.
7
8use serde::de::{self, Deserializer};
9use serde::Deserialize;
10
11use crate::error::Error;
12
/// Sample rate assumed when a `.nam` file omits the `sample_rate` field.
///
/// Matches NAM's documented default. This is the value
/// [`NamModel::sample_rate()`] returns when the parsed field is `None`.
pub const DEFAULT_SAMPLE_RATE: f64 = 48_000.0;
17
/// A parsed `.nam` model file.
///
/// This is the *file representation* — the raw config + flat weight blob. To run
/// inference, build a [`crate::WaveNet`] from it.
///
/// `Deserialize` is implemented by hand rather than derived, because the type of
/// `config` depends on the value of the sibling `architecture` field.
#[derive(Debug, Clone)]
pub struct NamModel {
    /// `.nam` format version string (e.g. `"0.5.4"`). Kept exactly as it appears
    /// in the file.
    pub version: String,
    /// Model architecture, e.g. `"WaveNet"`. Parsing only accepts `"WaveNet"` and
    /// `"LSTM"`; any other value is rejected as unsupported.
    pub architecture: String,
    /// Architecture-specific configuration (dispatched on [`Self::architecture`]).
    pub config: ModelConfig,
    /// Flat weight blob. The final element is `head_scale` (see NAM
    /// `export_weights`). Stored as `f32` to match NAM Core's inference precision.
    pub weights: Vec<f32>,
    /// Training sample rate as written in the file. Absent in older files; use
    /// the [`Self::sample_rate()`] accessor to get a value with the default applied.
    pub sample_rate: Option<f64>,
    /// Opaque training/gear metadata, kept as raw JSON. Not used for inference.
    /// The calibration subset can be read through [`Metadata`].
    pub metadata: Option<serde_json::Value>,
}
38
/// LSTM configuration (NAM `_export_config`).
///
/// All three fields are required when parsing; none has a serde default.
#[derive(Debug, Clone, Deserialize)]
pub struct LstmConfig {
    /// Input width (1 for mono amp models).
    pub input_size: usize,
    /// Hidden state dimension `H`.
    pub hidden_size: usize,
    /// Number of stacked LSTM layers `L`.
    pub num_layers: usize,
}
49
/// Architecture-specific configuration, tagged by `NamModel.architecture`.
///
/// This enum has no `Deserialize` impl of its own: the variant is chosen while
/// deserializing [`NamModel`], based on the file's `architecture` string.
#[derive(Debug, Clone)]
pub enum ModelConfig {
    /// WaveNet: a stack of dilated-convolution layer-arrays. Runnable via
    /// [`crate::WaveNet`]. Selected when `architecture` is `"WaveNet"`.
    WaveNet(WaveNetConfig),
    /// LSTM: stacked recurrent layers plus a linear head. Selected when
    /// `architecture` is `"LSTM"`.
    Lstm(LstmConfig),
}
59
60impl<'de> Deserialize<'de> for NamModel {
61    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
62    where
63        D: Deserializer<'de>,
64    {
65        // Parse the file shape with `config` left raw, then dispatch on
66        // `architecture` to type it. This reads the sibling `architecture` field,
67        // which `#[serde(deserialize_with)]` on a single field cannot do.
68        #[derive(Deserialize)]
69        struct Raw {
70            version: String,
71            architecture: String,
72            config: serde_json::Value,
73            weights: Vec<f32>,
74            #[serde(default)]
75            sample_rate: Option<f64>,
76            #[serde(default)]
77            metadata: Option<serde_json::Value>,
78        }
79
80        let raw = Raw::deserialize(deserializer)?;
81        let config = match raw.architecture.as_str() {
82            "WaveNet" => {
83                ModelConfig::WaveNet(serde_json::from_value(raw.config).map_err(de::Error::custom)?)
84            }
85            "LSTM" => {
86                ModelConfig::Lstm(serde_json::from_value(raw.config).map_err(de::Error::custom)?)
87            }
88            other => {
89                return Err(de::Error::custom(format!(
90                    "unsupported model architecture: {other:?}"
91                )))
92            }
93        };
94
95        Ok(NamModel {
96            version: raw.version,
97            architecture: raw.architecture,
98            config,
99            weights: raw.weights,
100            sample_rate: raw.sample_rate,
101            metadata: raw.metadata,
102        })
103    }
104}
105
/// Loudness/level-calibration fields NAM may write into `metadata`. All optional;
/// older or minimal files omit them. Unknown metadata keys are ignored.
///
/// `Default` yields a value with every field set to `None`.
#[derive(Debug, Clone, Default, Deserialize)]
pub struct Metadata {
    /// Perceived loudness of the model's output, in LUFS (NAM's `loudness`).
    /// `None` when the key is missing.
    #[serde(default)]
    pub loudness: Option<f32>,
    /// Analog level (dBu) corresponding to 0 dBFS at the model input.
    /// `None` when the key is missing.
    #[serde(default)]
    pub input_level_dbu: Option<f32>,
    /// Analog level (dBu) corresponding to 0 dBFS at the model output.
    /// `None` when the key is missing.
    #[serde(default)]
    pub output_level_dbu: Option<f32>,
}
120
121impl NamModel {
122    /// Read and parse a `.nam` model from a file on disk.
123    ///
124    /// Convenience over [`std::fs::read_to_string`] + [`Self::from_json_str`].
125    /// Returns [`Error::Io`] if the file can't be read, or [`Error::Json`] if its
126    /// contents aren't valid `.nam` JSON.
127    pub fn from_file(path: impl AsRef<std::path::Path>) -> Result<Self, Error> {
128        Self::from_json_str(&std::fs::read_to_string(path)?)
129    }
130
131    /// Parse a `.nam` model from a JSON string already in memory.
132    pub fn from_json_str(json: &str) -> Result<Self, Error> {
133        Ok(serde_json::from_str(json)?)
134    }
135
136    /// The model's sample rate, falling back to [`DEFAULT_SAMPLE_RATE`] when the
137    /// file does not specify one.
138    #[must_use]
139    pub fn sample_rate(&self) -> f64 {
140        self.sample_rate.unwrap_or(DEFAULT_SAMPLE_RATE)
141    }
142
143    /// Parse the calibration subset of `metadata`. Returns defaults (all `None`)
144    /// when there is no metadata block.
145    ///
146    /// Private helper: clones and re-parses the raw `metadata` JSON on each call.
147    /// That's fine for these cold-path (load-time) accessors. A caller that wants all
148    /// fields from one parse can deserialize the public [`Metadata`] from
149    /// [`Self::metadata`] directly.
150    fn metadata_typed(&self) -> Metadata {
151        match &self.metadata {
152            Some(v) => serde_json::from_value(v.clone()).unwrap_or_default(),
153            None => Metadata::default(),
154        }
155    }
156
157    /// Output loudness in LUFS, if the file records it.
158    #[must_use]
159    pub fn loudness(&self) -> Option<f32> {
160        self.metadata_typed().loudness
161    }
162
163    /// Input calibration level in dBu (analog level at 0 dBFS in), if present.
164    #[must_use]
165    pub fn input_level_dbu(&self) -> Option<f32> {
166        self.metadata_typed().input_level_dbu
167    }
168
169    /// Output calibration level in dBu (analog level at 0 dBFS out), if present.
170    #[must_use]
171    pub fn output_level_dbu(&self) -> Option<f32> {
172        self.metadata_typed().output_level_dbu
173    }
174}
175
/// WaveNet configuration: a sequence of layer-arrays plus a final output scale.
///
/// `layers` and `head_scale` are required when parsing; `head` may be omitted.
#[derive(Debug, Clone, Deserialize)]
pub struct WaveNetConfig {
    /// One config per layer-array (NAM standard models have two).
    pub layers: Vec<LayerArrayConfig>,
    /// Optional separate head. `null` in standard models. Kept as raw JSON;
    /// this module does not interpret it.
    #[serde(default)]
    pub head: Option<serde_json::Value>,
    /// Output gain applied after the head.
    pub head_scale: f32,
}
187
/// Configuration for a single WaveNet layer-array (a stack of dilated layers
/// sharing channel/kernel parameters).
///
/// Every field is required when parsing; none has a serde default.
#[derive(Debug, Clone, Deserialize)]
pub struct LayerArrayConfig {
    /// Number of input channels into the array (1 for the first array).
    pub input_size: usize,
    /// Conditioning signal width (1 for standard amp models).
    pub condition_size: usize,
    /// Hidden channel count.
    pub channels: usize,
    /// Output channels of each layer's head 1x1.
    pub head_size: usize,
    /// Dilated-convolution kernel size (typically 3).
    pub kernel_size: usize,
    /// Per-layer dilation factors, e.g. `[1, 2, 4, ..., 512]`.
    pub dilations: Vec<usize>,
    /// Activation function name, e.g. `"Tanh"`. Stored as the raw string from
    /// the file; it is not checked against a known set here.
    pub activation: String,
    /// Whether the layer uses a gated activation (`tanh * sigmoid`).
    pub gated: bool,
    /// Whether the head 1x1 has a bias term.
    pub head_bias: bool,
}