nam_rs/model.rs
1//! Parsing of the on-disk `.nam` file format.
2//!
3//! A `.nam` file is a JSON object. The fields here mirror NAM's
4//! `export_config()` / `export_weights()` output (see crate-level attribution).
5//! Both the WaveNet and LSTM architectures are parsed here (see [`ModelConfig`]);
6//! the runtime forward passes live in their own modules.
7
8use serde::de::{self, Deserializer};
9use serde::Deserialize;
10
11use crate::error::Error;
12
/// Fallback sample rate for `.nam` files that carry no `sample_rate` field.
///
/// This is the default value NAM itself documents for that field.
pub const DEFAULT_SAMPLE_RATE: f64 = 48000.0;
17
18/// A parsed `.nam` model file.
19///
20/// This is the *file representation* — the raw config + flat weight blob. To run
21/// inference, build a [`crate::WaveNet`] from it.
22#[derive(Debug, Clone)]
23pub struct NamModel {
24 /// `.nam` format version string (e.g. `"0.5.4"`).
25 pub version: String,
26 /// Model architecture, e.g. `"WaveNet"`.
27 pub architecture: String,
28 /// Architecture-specific configuration (dispatched on [`Self::architecture`]).
29 pub config: ModelConfig,
30 /// Flat weight blob. The final element is `head_scale` (see NAM
31 /// `export_weights`). Stored as `f32` to match NAM Core's inference precision.
32 pub weights: Vec<f32>,
33 /// Training sample rate. Absent in older files; see [`Self::sample_rate`].
34 pub sample_rate: Option<f64>,
35 /// Opaque training/gear metadata. Not used for inference.
36 pub metadata: Option<serde_json::Value>,
37}
38
39/// LSTM configuration (NAM `_export_config`).
40#[derive(Debug, Clone, Deserialize)]
41pub struct LstmConfig {
42 /// Input width (1 for mono amp models).
43 pub input_size: usize,
44 /// Hidden state dimension `H`.
45 pub hidden_size: usize,
46 /// Number of stacked LSTM layers `L`.
47 pub num_layers: usize,
48}
49
50/// Architecture-specific configuration, tagged by `NamModel.architecture`.
51#[derive(Debug, Clone)]
52pub enum ModelConfig {
53 /// WaveNet: a stack of dilated-convolution layer-arrays. Runnable via
54 /// [`crate::WaveNet`].
55 WaveNet(WaveNetConfig),
56 /// LSTM: stacked recurrent layers plus a linear head.
57 Lstm(LstmConfig),
58}
59
60impl<'de> Deserialize<'de> for NamModel {
61 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
62 where
63 D: Deserializer<'de>,
64 {
65 // Parse the file shape with `config` left raw, then dispatch on
66 // `architecture` to type it. This reads the sibling `architecture` field,
67 // which `#[serde(deserialize_with)]` on a single field cannot do.
68 #[derive(Deserialize)]
69 struct Raw {
70 version: String,
71 architecture: String,
72 config: serde_json::Value,
73 weights: Vec<f32>,
74 #[serde(default)]
75 sample_rate: Option<f64>,
76 #[serde(default)]
77 metadata: Option<serde_json::Value>,
78 }
79
80 let raw = Raw::deserialize(deserializer)?;
81 let config = match raw.architecture.as_str() {
82 "WaveNet" => {
83 ModelConfig::WaveNet(serde_json::from_value(raw.config).map_err(de::Error::custom)?)
84 }
85 "LSTM" => {
86 ModelConfig::Lstm(serde_json::from_value(raw.config).map_err(de::Error::custom)?)
87 }
88 other => {
89 return Err(de::Error::custom(format!(
90 "unsupported model architecture: {other:?}"
91 )))
92 }
93 };
94
95 Ok(NamModel {
96 version: raw.version,
97 architecture: raw.architecture,
98 config,
99 weights: raw.weights,
100 sample_rate: raw.sample_rate,
101 metadata: raw.metadata,
102 })
103 }
104}
105
106/// Loudness/level-calibration fields NAM may write into `metadata`. All optional;
107/// older or minimal files omit them. Unknown metadata keys are ignored.
108#[derive(Debug, Clone, Default, Deserialize)]
109pub struct Metadata {
110 /// Perceived loudness of the model's output, in LUFS (NAM's `loudness`).
111 #[serde(default)]
112 pub loudness: Option<f32>,
113 /// Analog level (dBu) corresponding to 0 dBFS at the model input.
114 #[serde(default)]
115 pub input_level_dbu: Option<f32>,
116 /// Analog level (dBu) corresponding to 0 dBFS at the model output.
117 #[serde(default)]
118 pub output_level_dbu: Option<f32>,
119}
120
121impl NamModel {
122 /// Read and parse a `.nam` model from a file on disk.
123 ///
124 /// Convenience over [`std::fs::read_to_string`] + [`Self::from_json_str`].
125 /// Returns [`Error::Io`] if the file can't be read, or [`Error::Json`] if its
126 /// contents aren't valid `.nam` JSON.
127 pub fn from_file(path: impl AsRef<std::path::Path>) -> Result<Self, Error> {
128 Self::from_json_str(&std::fs::read_to_string(path)?)
129 }
130
131 /// Parse a `.nam` model from a JSON string already in memory.
132 pub fn from_json_str(json: &str) -> Result<Self, Error> {
133 Ok(serde_json::from_str(json)?)
134 }
135
136 /// The model's sample rate, falling back to [`DEFAULT_SAMPLE_RATE`] when the
137 /// file does not specify one.
138 #[must_use]
139 pub fn sample_rate(&self) -> f64 {
140 self.sample_rate.unwrap_or(DEFAULT_SAMPLE_RATE)
141 }
142
143 /// Parse the calibration subset of `metadata`. Returns defaults (all `None`)
144 /// when there is no metadata block.
145 ///
146 /// Private helper: clones and re-parses the raw `metadata` JSON on each call.
147 /// That's fine for these cold-path (load-time) accessors. A caller that wants all
148 /// fields from one parse can deserialize the public [`Metadata`] from
149 /// [`Self::metadata`] directly.
150 fn metadata_typed(&self) -> Metadata {
151 match &self.metadata {
152 Some(v) => serde_json::from_value(v.clone()).unwrap_or_default(),
153 None => Metadata::default(),
154 }
155 }
156
157 /// Output loudness in LUFS, if the file records it.
158 #[must_use]
159 pub fn loudness(&self) -> Option<f32> {
160 self.metadata_typed().loudness
161 }
162
163 /// Input calibration level in dBu (analog level at 0 dBFS in), if present.
164 #[must_use]
165 pub fn input_level_dbu(&self) -> Option<f32> {
166 self.metadata_typed().input_level_dbu
167 }
168
169 /// Output calibration level in dBu (analog level at 0 dBFS out), if present.
170 #[must_use]
171 pub fn output_level_dbu(&self) -> Option<f32> {
172 self.metadata_typed().output_level_dbu
173 }
174}
175
176/// WaveNet configuration: a sequence of layer-arrays plus a final output scale.
177#[derive(Debug, Clone, Deserialize)]
178pub struct WaveNetConfig {
179 /// One config per layer-array (NAM standard models have two).
180 pub layers: Vec<LayerArrayConfig>,
181 /// Optional separate head. `null` in standard models.
182 #[serde(default)]
183 pub head: Option<serde_json::Value>,
184 /// Output gain applied after the head.
185 pub head_scale: f32,
186}
187
188/// Configuration for a single WaveNet layer-array (a stack of dilated layers
189/// sharing channel/kernel parameters).
190#[derive(Debug, Clone, Deserialize)]
191pub struct LayerArrayConfig {
192 /// Number of input channels into the array (1 for the first array).
193 pub input_size: usize,
194 /// Conditioning signal width (1 for standard amp models).
195 pub condition_size: usize,
196 /// Hidden channel count.
197 pub channels: usize,
198 /// Output channels of each layer's head 1x1.
199 pub head_size: usize,
200 /// Dilated-convolution kernel size (typically 3).
201 pub kernel_size: usize,
202 /// Per-layer dilation factors, e.g. `[1, 2, 4, ..., 512]`.
203 pub dilations: Vec<usize>,
204 /// Activation function name, e.g. `"Tanh"`.
205 pub activation: String,
206 /// Whether the layer uses a gated activation (`tanh * sigmoid`).
207 pub gated: bool,
208 /// Whether the head 1x1 has a bias term.
209 pub head_bias: bool,
210}