Skip to main content

svod_model/gigaam/
config.rs

//! GigaAM JSON config parsing.
//!
//! `GigaAmConfig::from_json` reads `config.json` and produces a fully-resolved
//! [`GigaAmConfig`]. Parsing is via `serde_json` with mirror `Raw*` structs
//! that match the on-disk shape (`cfg.model.cfg.{preprocessor,encoder,head,
//! decoding}`); a thin `From`-style projection then validates cross-field
//! invariants and dispatches the `decoding._target_` union substring-style.
//! When the config describes an RNN-T model (`_target_` mentioning `RNNT`, or
//! a `head` carrying both `decoder` and `joint` blocks), a [`TransducerConfig`]
//! is extracted as well.
//!
//! Substring dispatch (rather than `#[serde(tag = "_target_")]`) is intentional:
//! the `_target_` paths drift across upstream NeMo versions, so exact-rename
//! enum variants would break each release. The leaf decoder types
//! (`GreedyDecoder`, `BeamDecoder`) are themselves `Deserialize`, so the
//! within-variant fields still ride serde.
14
15use std::path::Path;
16
17use serde::Deserialize;
18use snafu::ResultExt;
19use svod_arch::ctc::{CtcDecoder, GreedyDecoder};
20
21use super::error::{ConfigIoSnafu, ConfigSnafu, Error, Result};
22
23#[derive(Clone, Deserialize)]
24#[serde(rename_all = "snake_case")]
25pub enum SubsamplingMode {
26    Conv1d,
27    Conv2d,
28}
29
30#[derive(Clone, Deserialize)]
31#[serde(rename_all = "snake_case")]
32pub enum ConvNormType {
33    LayerNorm,
34    BatchNorm,
35}
36
37#[derive(Clone)]
38pub struct GigaAmConfig {
39    pub max_batch_size: usize,
40    pub n_mels: usize,
41    pub d_model: usize,
42    pub n_heads: usize,
43    pub n_layers: usize,
44    pub d_ff: usize,
45    pub conv_kernel: usize,
46    pub subsampling_factor: usize,
47    pub subsampling_mode: SubsamplingMode,
48    pub subs_kernel_size: usize,
49    pub conv_norm_type: ConvNormType,
50    pub vocab_size: usize,
51    pub sample_rate: usize,
52    pub n_fft: usize,
53    pub hop_length: usize,
54    pub win_length: usize,
55    pub mel_center: bool,
56    pub max_mel_frames: usize,
57    pub max_encoder_frames: usize,
58    /// CTC decoder built from the `decoding` section of the config, or an
59    /// empty-vocabulary greedy decoder for synthetic configs that don't
60    /// declare one.
61    pub decoder: CtcDecoder,
62    /// Transducer-specific config, populated when `decoding._target_` ends
63    /// in `RNNTGreedyDecoding` (or the head config has predictor/joint
64    /// blocks). `None` for CTC checkpoints.
65    pub transducer: Option<TransducerConfig>,
66}
67
/// RNN-T-specific config extracted from the JSON `head.decoder` /
/// `head.joint` / `decoding` blocks. See `submodules/GigaAM/gigaam/decoder.py`
/// for the reference shape.
///
/// All fields are plain data, so the type also derives `PartialEq`/`Eq` for
/// direct comparison (e.g. in tests).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct TransducerConfig {
    /// Prediction-network hidden size (`head.decoder.pred_hidden`).
    pub pred_hidden: usize,
    /// Prediction-network RNN layer count (`head.decoder.pred_rnn_layers`).
    pub pred_rnn_layers: usize,
    /// Joint-network hidden size (`head.joint.joint_hidden`).
    pub joint_hidden: usize,
    /// Output classes including the trailing blank token. Equals
    /// `vocabulary.len() + 1` once the vocabulary is known; for SentencePiece
    /// checkpoints `vocabulary` may still be empty at parse time and is
    /// supplied later from `tokenizer.model`.
    pub num_classes: usize,
    /// Maximum symbols emitted per decoding step
    /// (`decoding.max_symbols_per_step`, default 10).
    pub max_symbols_per_step: usize,
    /// String entries of `decoding.vocabulary`; empty when absent or `null`.
    pub vocabulary: Vec<String>,
    /// True when `decoding.model_path` is a non-empty string, i.e. the
    /// vocabulary entries are SentencePiece pieces (apply `▁ → space`
    /// post-processing on the decoded string).
    pub sentencepiece: bool,
}
84
85impl GigaAmConfig {
86    pub fn from_json(path: &Path) -> Result<Self> {
87        let data = std::fs::read_to_string(path).context(ConfigIoSnafu)?;
88        let root: serde_json::Value = serde_json::from_str(&data).context(ConfigSnafu)?;
89        let leaf = root.pointer("/cfg/model/cfg").ok_or_else(|| Error::DecoderConfig {
90            message: "config.json missing required path /cfg/model/cfg".into(),
91        })?;
92        let raw: RawModelCfg = serde_json::from_value(leaf.clone()).context(ConfigSnafu)?;
93        Self::from_raw(raw)
94    }
95
96    fn from_raw(raw: RawModelCfg) -> Result<Self> {
97        validate_preprocessor(&raw.preprocessor)?;
98        validate_encoder(&raw.encoder)?;
99
100        // `max_mel_frames` is the pre-subsampling sequence-length bound. Configs that
101        // only specify `pos_emb_max_len` (the post-subsampling encoder bound) need it
102        // multiplied by `subsampling_factor` so audio approaching the encoder cap
103        // isn't rejected at the JIT input stage.
104        let max_encoder_frames = raw.encoder.pos_emb_max_len;
105        let max_mel_frames = raw
106            .encoder
107            .max_mel_frames
108            .or(raw.encoder.max_seq_len)
109            .unwrap_or(max_encoder_frames * raw.encoder.subsampling_factor);
110        let subs_kernel = match &raw.encoder.subsampling {
111            SubsamplingMode::Conv1d => raw.encoder.subs_kernel_size,
112            SubsamplingMode::Conv2d => 3,
113        };
114        let max_sub_frames = subsampled_len(subs_kernel, max_mel_frames);
115        if max_sub_frames > max_encoder_frames {
116            return Err(Error::DecoderConfig {
117                message: format!(
118                    "max_mel_frames ({max_mel_frames}) subsamples to {max_sub_frames} encoder frames, exceeding pos_emb_max_len ({max_encoder_frames})"
119                ),
120            });
121        }
122        // CTC configs put `num_classes` directly on `head`; RNN-T configs nest
123        // it under `head.decoder.num_classes` / `head.joint.num_classes`.
124        let vocab_size = raw
125            .head
126            .num_classes
127            .or_else(|| raw.head.decoder.as_ref().and_then(|d| d.num_classes))
128            .or_else(|| raw.head.joint.as_ref().and_then(|j| j.num_classes))
129            .ok_or_else(|| Error::DecoderConfig {
130                message: "missing num_classes (head.num_classes or head.{decoder,joint}.num_classes)".into(),
131            })?;
132        let decoder = raw_to_decoder(raw.decoding.as_ref(), vocab_size)?;
133        let transducer = raw_to_transducer(&raw.head, raw.decoding.as_ref(), vocab_size)?;
134        Ok(Self {
135            max_batch_size: raw.encoder.max_batch_size,
136            n_mels: raw.preprocessor.features,
137            d_model: raw.encoder.d_model,
138            n_heads: raw.encoder.n_heads,
139            n_layers: raw.encoder.n_layers,
140            d_ff: raw.encoder.d_model * raw.encoder.ff_expansion_factor,
141            conv_kernel: raw.encoder.conv_kernel_size,
142            subsampling_factor: raw.encoder.subsampling_factor,
143            subsampling_mode: raw.encoder.subsampling,
144            subs_kernel_size: raw.encoder.subs_kernel_size,
145            conv_norm_type: raw.encoder.conv_norm_type,
146            vocab_size,
147            sample_rate: raw.preprocessor.sample_rate,
148            n_fft: raw.preprocessor.n_fft,
149            hop_length: raw.preprocessor.hop_length,
150            win_length: raw.preprocessor.win_length,
151            mel_center: raw.preprocessor.center,
152            max_mel_frames,
153            max_encoder_frames,
154            decoder,
155            transducer,
156        })
157    }
158}
159
// ─── Serde mirror structs (private) ───────────────────────────────────────
//
// On-disk shape is `cfg.model.cfg.{preprocessor,encoder,head,decoding}`; the
// outer wrappers are navigated via `serde_json::Value::pointer` in `from_json`
// rather than mirrored here so this file stays focused on the leaf shape.
165
/// Mirror of the `/cfg/model/cfg` object in `config.json`.
#[derive(Deserialize)]
struct RawModelCfg {
    /// Mel frontend settings; checked by `validate_preprocessor`.
    preprocessor: RawPreprocessor,
    /// Encoder settings; checked by `validate_encoder`.
    encoder: RawEncoder,
    /// Output head. CTC and RNN-T heads share this mirror.
    head: RawHead,
    /// Decoding section kept as raw JSON because it is dispatched on
    /// substrings of `_target_` rather than via a serde tag. A missing key
    /// and an explicit `null` both deserialize to `None`.
    #[serde(default)]
    decoding: Option<serde_json::Value>,
}
174
/// Mirror of `preprocessor`; validated by `validate_preprocessor`.
#[derive(Deserialize)]
struct RawPreprocessor {
    /// Number of mel bins (surfaced as `GigaAmConfig::n_mels`).
    features: usize,
    sample_rate: usize,
    /// Must equal `win_length`.
    n_fft: usize,
    hop_length: usize,
    /// Must equal `n_fft`.
    win_length: usize,
    /// Defaults to `true` when absent (surfaced as `GigaAmConfig::mel_center`).
    #[serde(default = "default_true")]
    center: bool,
    /// Only `"htk"` (or absent/`null`) is accepted.
    #[serde(default)]
    mel_scale: Option<String>,
    /// Must be absent or `null`; any value is rejected.
    #[serde(default)]
    mel_norm: Option<String>,
}
189
/// Mirror of `encoder`; validated by `validate_encoder`.
#[derive(Deserialize)]
struct RawEncoder {
    /// Encoder width; `d_ff` is derived as `d_model * ff_expansion_factor`.
    d_model: usize,
    /// Multiplier applied to `d_model` to obtain the feed-forward width.
    ff_expansion_factor: usize,
    n_heads: usize,
    n_layers: usize,
    /// Surfaced as `GigaAmConfig::conv_kernel`.
    conv_kernel_size: usize,
    /// Total subsampling factor; only 4 is accepted.
    subsampling_factor: usize,
    /// Attention variant; defaults to `"rotary"`, the only accepted value.
    #[serde(default = "default_self_attention_model")]
    self_attention_model: String,
    /// Subsampling kernel size (default 3). The frame-length check in
    /// `from_raw` uses it only for `Conv1d`; `Conv2d` is checked with 3.
    #[serde(default = "default_subs_kernel_size")]
    subs_kernel_size: usize,
    /// Subsampling layout; defaults to `Conv2d`.
    #[serde(default = "default_subsampling_mode")]
    subsampling: SubsamplingMode,
    /// Convolution normalization type; defaults to `BatchNorm`.
    #[serde(default = "default_conv_norm_type")]
    conv_norm_type: ConvNormType,
    /// Post-subsampling length cap (default 5000); becomes
    /// `GigaAmConfig::max_encoder_frames`.
    #[serde(default = "default_pos_emb_max_len")]
    pos_emb_max_len: usize,
    /// Explicit pre-subsampling mel-frame cap; takes precedence over
    /// `max_seq_len`.
    #[serde(default)]
    max_mel_frames: Option<usize>,
    /// Fallback pre-subsampling cap used when `max_mel_frames` is absent.
    #[serde(default)]
    max_seq_len: Option<usize>,
    /// Defaults to 32.
    #[serde(default = "default_max_batch_size")]
    max_batch_size: usize,
}
215
/// Mirror of `head`. CTC heads declare `num_classes` directly; RNN-T heads
/// nest it under `decoder` / `joint`, which also carry the network sizes.
#[derive(Deserialize)]
struct RawHead {
    /// CTC-style class count (vocabulary plus blank).
    #[serde(default)]
    num_classes: Option<usize>,
    /// RNN-T prediction-network block.
    #[serde(default)]
    decoder: Option<RawHeadDecoder>,
    /// RNN-T joint-network block.
    #[serde(default)]
    joint: Option<RawHeadJoint>,
}
225
/// Mirror of `head.decoder` (RNN-T prediction network).
#[derive(Deserialize)]
struct RawHeadDecoder {
    /// Surfaced as `TransducerConfig::pred_hidden`.
    pred_hidden: usize,
    /// Surfaced as `TransducerConfig::pred_rnn_layers`.
    pred_rnn_layers: usize,
    /// Fallback source for the resolved class count.
    #[serde(default)]
    num_classes: Option<usize>,
}
233
/// Mirror of `head.joint` (RNN-T joint network).
#[derive(Deserialize)]
struct RawHeadJoint {
    /// Surfaced as `TransducerConfig::joint_hidden`.
    joint_hidden: usize,
    /// Last-resort source for the resolved class count.
    #[serde(default)]
    num_classes: Option<usize>,
}
240
/// Serde default for flags that are on unless the config says otherwise
/// (used by `RawPreprocessor::center`).
const fn default_true() -> bool {
    true
}

/// Serde default for `RawEncoder::subs_kernel_size`.
const fn default_subs_kernel_size() -> usize {
    3
}
247fn default_subsampling_mode() -> SubsamplingMode {
248    SubsamplingMode::Conv2d
249}
250fn default_conv_norm_type() -> ConvNormType {
251    ConvNormType::BatchNorm
252}
/// Serde default for `RawEncoder::pos_emb_max_len` (post-subsampling cap).
const fn default_pos_emb_max_len() -> usize {
    5_000
}

/// Serde default for `RawEncoder::self_attention_model`; it is also the only
/// value `validate_encoder` accepts.
fn default_self_attention_model() -> String {
    String::from("rotary")
}

/// Serde default for `RawEncoder::max_batch_size`.
const fn default_max_batch_size() -> usize {
    32
}
262
263fn validate_preprocessor(pre: &RawPreprocessor) -> Result<()> {
264    if let Some(scale) = pre.mel_scale.as_deref()
265        && scale != "htk"
266    {
267        return Err(Error::DecoderConfig {
268            message: format!(
269                "unsupported mel_scale {scale:?}; Svod GigaAM currently matches torchaudio's HTK mel frontend"
270            ),
271        });
272    }
273    if let Some(norm) = pre.mel_norm.as_deref() {
274        return Err(Error::DecoderConfig {
275            message: format!(
276                "unsupported mel_norm {norm:?}; Svod GigaAM currently supports only null/no mel normalization"
277            ),
278        });
279    }
280    if pre.n_fft != pre.win_length {
281        return Err(Error::DecoderConfig {
282            message: format!(
283                "unsupported mel frontend n_fft ({}) != win_length ({}); current GigaAM parity path requires equal FFT/window lengths",
284                pre.n_fft, pre.win_length
285            ),
286        });
287    }
288    Ok(())
289}
290
291fn validate_encoder(encoder: &RawEncoder) -> Result<()> {
292    if encoder.self_attention_model != "rotary" {
293        return Err(Error::DecoderConfig {
294            message: format!(
295                "unsupported self_attention_model {:?}; Svod GigaAM currently implements rotary attention only",
296                encoder.self_attention_model
297            ),
298        });
299    }
300    if encoder.subsampling_factor != 4 {
301        return Err(Error::DecoderConfig {
302            message: format!(
303                "unsupported subsampling_factor {}; Svod GigaAM currently implements exactly two stride-2 subsampling layers",
304                encoder.subsampling_factor
305            ),
306        });
307    }
308    Ok(())
309}
310
/// Length of a `mel_frames`-long sequence after the two stride-2 subsampling
/// layers implied by `subsampling_factor == 4`.
///
/// Each layer applies the convolution output-length formula
/// `(len + 2 * pad - kernel_size) / 2 + 1` (floor division) with
/// `pad = (kernel_size - 1) / 2`.
///
/// All arithmetic saturates. An input shorter than the window yields 1
/// rather than underflowing, and a `kernel_size` of 0 (reachable from a
/// malformed config) no longer underflows when the padding is computed.
fn subsampled_len(kernel_size: usize, mel_frames: usize) -> usize {
    /// Stride-2 layers in the only subsampling layout that validation accepts.
    const STRIDE2_LAYERS: usize = 2;
    let pad = kernel_size.saturating_sub(1) / 2;
    (0..STRIDE2_LAYERS).fold(mel_frames, |len, _| {
        len.saturating_add(2 * pad).saturating_sub(kernel_size) / 2 + 1
    })
}
319
// ─── Decoder + transducer dispatch ────────────────────────────────────────
321
322fn raw_to_decoder(decoding: Option<&serde_json::Value>, vocab_size: usize) -> Result<CtcDecoder> {
323    let Some(decoding) = decoding else {
324        return Ok(CtcDecoder::Greedy(GreedyDecoder::new(Vec::new())));
325    };
326    if decoding.is_null() {
327        return Ok(CtcDecoder::Greedy(GreedyDecoder::new(Vec::new())));
328    }
329    let target = decoding["_target_"].as_str().unwrap_or("");
330    let decoder: CtcDecoder = if target.contains("CTCGreedyDecoding") {
331        let g: GreedyDecoder = serde_json::from_value(decoding.clone()).context(ConfigSnafu)?;
332        CtcDecoder::Greedy(g)
333    } else if target.contains("CTCBeamDecoding") {
334        let b: svod_arch::ctc::BeamDecoder = serde_json::from_value(decoding.clone()).context(ConfigSnafu)?;
335        CtcDecoder::Beam(Box::new(b))
336    } else {
337        // Unknown / missing target. If there's a vocabulary array, default to
338        // greedy; otherwise empty.
339        let vocab: Vec<String> = decoding["vocabulary"]
340            .as_array()
341            .map(|arr| arr.iter().filter_map(|v| v.as_str().map(String::from)).collect())
342            .unwrap_or_default();
343        CtcDecoder::Greedy(GreedyDecoder::new(vocab))
344    };
345    if !decoder.vocabulary().is_empty() && decoder.total_vocab() != vocab_size {
346        return Err(Error::DecoderConfig {
347            message: format!(
348                "decoder vocabulary length + 1 ({}) != head.num_classes ({}); \
349                 CTC convention is one blank token appended after the vocabulary",
350                decoder.total_vocab(),
351                vocab_size
352            ),
353        });
354    }
355    Ok(decoder)
356}
357
358fn raw_to_transducer(
359    head: &RawHead,
360    decoding: Option<&serde_json::Value>,
361    vocab_size: usize,
362) -> Result<Option<TransducerConfig>> {
363    let target = decoding.and_then(|d| d["_target_"].as_str()).unwrap_or("");
364    let has_decoder = head.decoder.is_some();
365    let has_joint = head.joint.is_some();
366    if !(target.contains("RNNT") || (has_decoder && has_joint)) {
367        return Ok(None);
368    }
369    let dec = head
370        .decoder
371        .as_ref()
372        .ok_or_else(|| Error::DecoderConfig { message: "RNN-T config: missing head.decoder block".into() })?;
373    let joint = head
374        .joint
375        .as_ref()
376        .ok_or_else(|| Error::DecoderConfig { message: "RNN-T config: missing head.joint block".into() })?;
377    let max_symbols_per_step = decoding.and_then(|d| d["max_symbols_per_step"].as_u64()).unwrap_or(10) as usize;
378    // Vocabulary preference: `decoding.vocabulary` (CTC convention reused for
379    // RNN-T configs). For SentencePiece RNN-T checkpoints (e.g. v3_e2e_rnnt)
380    // this is `null` in the JSON config; the actual pieces ship as
381    // `tokenizer.model` and are loaded via `from_safetensors_with_tokenizer`.
382    // Empty here is fine — `from_state_dict` will splice in the override.
383    let vocabulary: Vec<String> = decoding
384        .and_then(|d| d["vocabulary"].as_array())
385        .map(|arr| arr.iter().filter_map(|v| v.as_str().map(String::from)).collect())
386        .unwrap_or_default();
387    // SentencePiece iff `decoding.model_path` is a non-empty string, else
388    // char-wise.
389    let sentencepiece =
390        decoding.and_then(|d| d.get("model_path")).and_then(|v| v.as_str()).map(|s| !s.is_empty()).unwrap_or(false);
391    Ok(Some(TransducerConfig {
392        pred_hidden: dec.pred_hidden,
393        pred_rnn_layers: dec.pred_rnn_layers,
394        joint_hidden: joint.joint_hidden,
395        num_classes: vocab_size,
396        max_symbols_per_step,
397        vocabulary,
398        sentencepiece,
399    }))
400}