use std::path::Path;
use serde::Deserialize;
use snafu::ResultExt;
use svod_arch::ctc::{CtcDecoder, GreedyDecoder};
use super::error::{ConfigIoSnafu, ConfigSnafu, Error, Result};
/// Subsampling layer kind, deserialized from `encoder.subsampling`
/// (`"conv1d"` or `"conv2d"` via `rename_all = "snake_case"`).
#[derive(Clone, Debug, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SubsamplingMode {
    /// One-dimensional convolution subsampling (`"conv1d"`).
    Conv1d,
    /// Two-dimensional convolution subsampling (`"conv2d"`).
    Conv2d,
}
/// Normalization kind for the encoder convolution, deserialized from
/// `encoder.conv_norm_type` (`"layer_norm"` or `"batch_norm"` via `rename_all = "snake_case"`).
#[derive(Clone, Debug, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ConvNormType {
    /// `"layer_norm"`.
    LayerNorm,
    /// `"batch_norm"`.
    BatchNorm,
}
/// Runtime configuration for a GigaAM model.
///
/// Produced by `GigaAmConfig::from_json`, which reads the object at `/cfg/model/cfg`
/// of a `config.json`, validates it, and derives the fields below.
#[derive(Clone)]
pub struct GigaAmConfig {
    /// `encoder.max_batch_size` (defaults to 32 when absent).
    pub max_batch_size: usize,
    /// Number of mel features, from `preprocessor.features`.
    pub n_mels: usize,
    /// `encoder.d_model`.
    pub d_model: usize,
    /// `encoder.n_heads`.
    pub n_heads: usize,
    /// `encoder.n_layers`.
    pub n_layers: usize,
    /// Feed-forward size: `encoder.d_model * encoder.ff_expansion_factor`.
    pub d_ff: usize,
    /// `encoder.conv_kernel_size`.
    pub conv_kernel: usize,
    /// `encoder.subsampling_factor`; only 4 passes validation.
    pub subsampling_factor: usize,
    /// `encoder.subsampling` (defaults to `Conv2d`).
    pub subsampling_mode: SubsamplingMode,
    /// `encoder.subs_kernel_size` as configured (defaults to 3).
    pub subs_kernel_size: usize,
    /// `encoder.conv_norm_type` (defaults to `BatchNorm`).
    pub conv_norm_type: ConvNormType,
    /// Output class count: `head.num_classes`, else `head.decoder.num_classes`,
    /// else `head.joint.num_classes`.
    pub vocab_size: usize,
    /// `preprocessor.sample_rate`.
    pub sample_rate: usize,
    /// `preprocessor.n_fft`; validation requires it to equal `win_length`.
    pub n_fft: usize,
    /// `preprocessor.hop_length`.
    pub hop_length: usize,
    /// `preprocessor.win_length`.
    pub win_length: usize,
    /// `preprocessor.center` (defaults to `true`).
    pub mel_center: bool,
    /// Mel-frame limit: `encoder.max_mel_frames`, else `encoder.max_seq_len`, else
    /// `pos_emb_max_len * subsampling_factor`. After the two stride-2 subsampling layers
    /// it must not exceed `max_encoder_frames`.
    pub max_mel_frames: usize,
    /// `encoder.pos_emb_max_len` (defaults to 5000).
    pub max_encoder_frames: usize,
    /// CTC decoder built from the optional `decoding` block (greedy with an empty
    /// vocabulary when the block is absent).
    pub decoder: CtcDecoder,
    /// RNN-T settings; `Some` when the decoding `_target_` contains `RNNT` or the head
    /// defines both `decoder` and `joint` blocks.
    pub transducer: Option<TransducerConfig>,
}
/// RNN-T (transducer) head and decoding settings.
///
/// Built only when the decoding `_target_` contains `RNNT` or the head defines both
/// `decoder` and `joint` blocks.
#[derive(Clone, Debug)]
pub struct TransducerConfig {
    /// `head.decoder.pred_hidden`.
    pub pred_hidden: usize,
    /// `head.decoder.pred_rnn_layers`.
    pub pred_rnn_layers: usize,
    /// `head.joint.joint_hidden`.
    pub joint_hidden: usize,
    /// Same value as `GigaAmConfig::vocab_size`.
    pub num_classes: usize,
    /// `decoding.max_symbols_per_step` (defaults to 10); presumably the per-frame symbol
    /// limit for greedy decoding.
    pub max_symbols_per_step: usize,
    /// String entries of `decoding.vocabulary` (non-strings skipped; empty when absent).
    pub vocabulary: Vec<String>,
    /// `true` when `decoding.model_path` is a non-empty string (presumably a
    /// SentencePiece model path — confirm against the tokenizer loader).
    pub sentencepiece: bool,
}
impl GigaAmConfig {
    /// Loads and validates a GigaAM model configuration from `config.json`.
    ///
    /// The model settings live in the JSON object at `/cfg/model/cfg`; that subtree is
    /// deserialized and then checked and converted by `from_raw`.
    ///
    /// # Errors
    ///
    /// - Failure to read `path` (reported through the `ConfigIoSnafu` context).
    /// - Invalid JSON, or a `/cfg/model/cfg` subtree that does not match the expected
    ///   schema (reported through the `ConfigSnafu` context).
    /// - [`Error::DecoderConfig`] if `/cfg/model/cfg` is missing or describes an
    ///   unsupported model.
    pub fn from_json(path: &Path) -> Result<Self> {
        let data = std::fs::read_to_string(path).context(ConfigIoSnafu)?;
        let mut root: serde_json::Value = serde_json::from_str(&data).context(ConfigSnafu)?;
        // Move the model subtree out of the parsed document instead of deep-cloning it;
        // `root` is not used afterwards.
        let leaf = root
            .pointer_mut("/cfg/model/cfg")
            .map(serde_json::Value::take)
            .ok_or_else(|| Error::DecoderConfig {
                message: "config.json missing required path /cfg/model/cfg".into(),
            })?;
        let raw: RawModelCfg = serde_json::from_value(leaf).context(ConfigSnafu)?;
        Self::from_raw(raw)
    }

    /// Validates a deserialized model config and derives the runtime settings.
    ///
    /// # Errors
    ///
    /// Returns [`Error::DecoderConfig`] (or a `ConfigSnafu`-context error from decoder
    /// deserialization) when:
    /// - the preprocessor or encoder settings are unsupported;
    /// - `max_mel_frames` subsamples to more encoder frames than `pos_emb_max_len`;
    /// - `d_model * ff_expansion_factor` overflows `usize`;
    /// - no `num_classes` is present in the head;
    /// - the `decoding` / RNN-T settings are inconsistent.
    fn from_raw(raw: RawModelCfg) -> Result<Self> {
        validate_preprocessor(&raw.preprocessor)?;
        validate_encoder(&raw.encoder)?;

        let max_encoder_frames = raw.encoder.pos_emb_max_len;
        // Explicit limits win. Otherwise allow as many mel frames as the positional table
        // covers after subsampling. `saturating_mul` keeps an oversized `pos_emb_max_len`
        // from overflowing (a panic in debug builds, silent wraparound in release).
        let max_mel_frames = raw
            .encoder
            .max_mel_frames
            .or(raw.encoder.max_seq_len)
            .unwrap_or_else(|| max_encoder_frames.saturating_mul(raw.encoder.subsampling_factor));
        // Kernel used for the length check: the configured size for Conv1d, a fixed 3 for Conv2d.
        // NOTE(review): the fixed 3 presumably mirrors the Conv2d subsampling layers — confirm.
        let subs_kernel = match &raw.encoder.subsampling {
            SubsamplingMode::Conv1d => raw.encoder.subs_kernel_size,
            SubsamplingMode::Conv2d => 3,
        };
        let max_sub_frames = subsampled_len(subs_kernel, max_mel_frames);
        if max_sub_frames > max_encoder_frames {
            return Err(Error::DecoderConfig {
                message: format!(
                    "max_mel_frames ({max_mel_frames}) subsamples to {max_sub_frames} encoder frames, exceeding pos_emb_max_len ({max_encoder_frames})"
                ),
            });
        }
        // Checked so an absurd config is reported instead of panicking (debug) or wrapping (release).
        let d_ff = raw
            .encoder
            .d_model
            .checked_mul(raw.encoder.ff_expansion_factor)
            .ok_or_else(|| Error::DecoderConfig {
                message: format!(
                    "d_model ({}) * ff_expansion_factor ({}) overflows usize",
                    raw.encoder.d_model, raw.encoder.ff_expansion_factor
                ),
            })?;
        // Class count: the top-level head value first, then the decoder block, then the joint block.
        let vocab_size = raw
            .head
            .num_classes
            .or_else(|| raw.head.decoder.as_ref().and_then(|d| d.num_classes))
            .or_else(|| raw.head.joint.as_ref().and_then(|j| j.num_classes))
            .ok_or_else(|| Error::DecoderConfig {
                message: "missing num_classes (head.num_classes or head.{decoder,joint}.num_classes)".into(),
            })?;
        let decoder = raw_to_decoder(raw.decoding.as_ref(), vocab_size)?;
        let transducer = raw_to_transducer(&raw.head, raw.decoding.as_ref(), vocab_size)?;
        Ok(Self {
            max_batch_size: raw.encoder.max_batch_size,
            n_mels: raw.preprocessor.features,
            d_model: raw.encoder.d_model,
            n_heads: raw.encoder.n_heads,
            n_layers: raw.encoder.n_layers,
            d_ff,
            conv_kernel: raw.encoder.conv_kernel_size,
            subsampling_factor: raw.encoder.subsampling_factor,
            subsampling_mode: raw.encoder.subsampling,
            subs_kernel_size: raw.encoder.subs_kernel_size,
            conv_norm_type: raw.encoder.conv_norm_type,
            vocab_size,
            sample_rate: raw.preprocessor.sample_rate,
            n_fft: raw.preprocessor.n_fft,
            hop_length: raw.preprocessor.hop_length,
            win_length: raw.preprocessor.win_length,
            mel_center: raw.preprocessor.center,
            max_mel_frames,
            max_encoder_frames,
            decoder,
            transducer,
        })
    }
}
/// Serde mirror of the `/cfg/model/cfg` object in `config.json`.
#[derive(Deserialize)]
struct RawModelCfg {
    /// `preprocessor` block (mel frontend settings).
    preprocessor: RawPreprocessor,
    /// `encoder` block.
    encoder: RawEncoder,
    /// `head` block: class count and optional RNN-T sub-blocks.
    head: RawHead,
    /// Optional `decoding` block, kept as raw JSON; `raw_to_decoder` and
    /// `raw_to_transducer` interpret it based on its `_target_` string.
    #[serde(default)]
    decoding: Option<serde_json::Value>,
}
/// Serde mirror of the `preprocessor` block (mel frontend settings).
#[derive(Deserialize)]
struct RawPreprocessor {
    /// Number of mel features; becomes `GigaAmConfig::n_mels`.
    features: usize,
    /// Copied to `GigaAmConfig::sample_rate`.
    sample_rate: usize,
    /// FFT size; `validate_preprocessor` requires it to equal `win_length`.
    n_fft: usize,
    /// Copied to `GigaAmConfig::hop_length`.
    hop_length: usize,
    /// Copied to `GigaAmConfig::win_length`.
    win_length: usize,
    /// `center` flag of the frontend; defaults to `true` when absent.
    #[serde(default = "default_true")]
    center: bool,
    /// Mel scale name; only absent or `"htk"` passes `validate_preprocessor`.
    #[serde(default)]
    mel_scale: Option<String>,
    /// Mel normalization; must be absent/null to pass `validate_preprocessor`.
    #[serde(default)]
    mel_norm: Option<String>,
}
/// Serde mirror of the `encoder` block.
#[derive(Deserialize)]
struct RawEncoder {
    /// Copied to `GigaAmConfig::d_model`.
    d_model: usize,
    /// `GigaAmConfig::d_ff` is `d_model * ff_expansion_factor`.
    ff_expansion_factor: usize,
    /// Copied to `GigaAmConfig::n_heads`.
    n_heads: usize,
    /// Copied to `GigaAmConfig::n_layers`.
    n_layers: usize,
    /// Copied to `GigaAmConfig::conv_kernel`.
    conv_kernel_size: usize,
    /// `validate_encoder` accepts only 4 (two stride-2 layers).
    subsampling_factor: usize,
    /// Defaults to `"rotary"`, the only value `validate_encoder` accepts.
    #[serde(default = "default_self_attention_model")]
    self_attention_model: String,
    /// Defaults to 3. `from_raw` uses it for the frame-length check only in Conv1d mode.
    #[serde(default = "default_subs_kernel_size")]
    subs_kernel_size: usize,
    /// Defaults to `SubsamplingMode::Conv2d`.
    #[serde(default = "default_subsampling_mode")]
    subsampling: SubsamplingMode,
    /// Defaults to `ConvNormType::BatchNorm`.
    #[serde(default = "default_conv_norm_type")]
    conv_norm_type: ConvNormType,
    /// Defaults to 5000; becomes `GigaAmConfig::max_encoder_frames`.
    #[serde(default = "default_pos_emb_max_len")]
    pos_emb_max_len: usize,
    /// Explicit mel-frame limit; takes precedence over `max_seq_len`.
    #[serde(default)]
    max_mel_frames: Option<usize>,
    /// Fallback mel-frame limit used when `max_mel_frames` is absent.
    #[serde(default)]
    max_seq_len: Option<usize>,
    /// Defaults to 32.
    #[serde(default = "default_max_batch_size")]
    max_batch_size: usize,
}
/// Serde mirror of the `head` block.
///
/// The class count may appear at the top level or inside `decoder` / `joint`;
/// `from_raw` checks them in that order.
#[derive(Deserialize)]
struct RawHead {
    /// Top-level class count (checked first).
    #[serde(default)]
    num_classes: Option<usize>,
    /// RNN-T prediction-network block, if present.
    #[serde(default)]
    decoder: Option<RawHeadDecoder>,
    /// RNN-T joint-network block, if present.
    #[serde(default)]
    joint: Option<RawHeadJoint>,
}
/// Serde mirror of `head.decoder` (RNN-T prediction network).
#[derive(Deserialize)]
struct RawHeadDecoder {
    /// Copied to `TransducerConfig::pred_hidden`.
    pred_hidden: usize,
    /// Copied to `TransducerConfig::pred_rnn_layers`.
    pred_rnn_layers: usize,
    /// Class count; used when `head.num_classes` is absent.
    #[serde(default)]
    num_classes: Option<usize>,
}
/// Serde mirror of `head.joint` (RNN-T joint network).
#[derive(Deserialize)]
struct RawHeadJoint {
    /// Copied to `TransducerConfig::joint_hidden`.
    joint_hidden: usize,
    /// Class count; used when neither `head.num_classes` nor `head.decoder.num_classes` is set.
    #[serde(default)]
    num_classes: Option<usize>,
}
/// Serde default helper that yields `true` (used for `RawPreprocessor::center`).
fn default_true() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
/// Serde default for `RawEncoder::subs_kernel_size`.
fn default_subs_kernel_size() -> usize {
    const SUBS_KERNEL_SIZE: usize = 3;
    SUBS_KERNEL_SIZE
}
fn default_subsampling_mode() -> SubsamplingMode {
SubsamplingMode::Conv2d
}
fn default_conv_norm_type() -> ConvNormType {
ConvNormType::BatchNorm
}
/// Serde default for `RawEncoder::pos_emb_max_len`.
fn default_pos_emb_max_len() -> usize {
    5_000
}
/// Serde default for `RawEncoder::self_attention_model`.
fn default_self_attention_model() -> String {
    String::from("rotary")
}
/// Serde default for `RawEncoder::max_batch_size`.
fn default_max_batch_size() -> usize {
    const MAX_BATCH_SIZE: usize = 32;
    MAX_BATCH_SIZE
}
/// Rejects preprocessor settings the Svod GigaAM mel frontend does not support.
///
/// The following are accepted:
/// - `mel_scale` absent or `"htk"`;
/// - `mel_norm` absent/null;
/// - `n_fft == win_length`;
/// - non-zero `features`, `sample_rate`, `n_fft`, `hop_length` and `win_length`.
///
/// # Errors
///
/// Returns [`Error::DecoderConfig`] describing the first unsupported setting found.
fn validate_preprocessor(pre: &RawPreprocessor) -> Result<()> {
    if let Some(scale) = pre.mel_scale.as_deref()
        && scale != "htk"
    {
        return Err(Error::DecoderConfig {
            message: format!(
                "unsupported mel_scale {scale:?}; Svod GigaAM currently matches torchaudio's HTK mel frontend"
            ),
        });
    }
    if let Some(norm) = pre.mel_norm.as_deref() {
        return Err(Error::DecoderConfig {
            message: format!(
                "unsupported mel_norm {norm:?}; Svod GigaAM currently supports only null/no mel normalization"
            ),
        });
    }
    if pre.n_fft != pre.win_length {
        return Err(Error::DecoderConfig {
            message: format!(
                "unsupported mel frontend n_fft ({}) != win_length ({}); current GigaAM parity path requires equal FFT/window lengths",
                pre.n_fft, pre.win_length
            ),
        });
    }
    // Zero-sized frontend parameters cannot describe a usable STFT/mel frontend. Report
    // them here as config errors instead of letting them reach the feature-extraction math.
    for (name, value) in [
        ("features", pre.features),
        ("sample_rate", pre.sample_rate),
        ("n_fft", pre.n_fft),
        ("hop_length", pre.hop_length),
        ("win_length", pre.win_length),
    ] {
        if value == 0 {
            return Err(Error::DecoderConfig {
                message: format!("unsupported preprocessor {name} = 0; mel frontend sizes must be non-zero"),
            });
        }
    }
    Ok(())
}
/// Rejects encoder settings the Svod GigaAM encoder does not implement.
///
/// The following are required:
/// - rotary self-attention;
/// - `subsampling_factor == 4` (two stride-2 layers);
/// - a non-zero `d_model` divisible by a non-zero `n_heads`;
/// - non-zero `conv_kernel_size` and `subs_kernel_size`.
///
/// # Errors
///
/// Returns [`Error::DecoderConfig`] describing the first unsupported setting found.
fn validate_encoder(encoder: &RawEncoder) -> Result<()> {
    if encoder.self_attention_model != "rotary" {
        return Err(Error::DecoderConfig {
            message: format!(
                "unsupported self_attention_model {:?}; Svod GigaAM currently implements rotary attention only",
                encoder.self_attention_model
            ),
        });
    }
    if encoder.subsampling_factor != 4 {
        return Err(Error::DecoderConfig {
            message: format!(
                "unsupported subsampling_factor {}; Svod GigaAM currently implements exactly two stride-2 subsampling layers",
                encoder.subsampling_factor
            ),
        });
    }
    // Multi-head attention splits d_model evenly across heads. `n_heads == 0` is tested
    // first so the modulo below cannot divide by zero.
    if encoder.n_heads == 0 || encoder.d_model == 0 || encoder.d_model % encoder.n_heads != 0 {
        return Err(Error::DecoderConfig {
            message: format!(
                "unsupported attention shape: d_model ({}) must be a non-zero multiple of n_heads ({})",
                encoder.d_model, encoder.n_heads
            ),
        });
    }
    // A zero kernel is meaningless, and `subs_kernel_size - 1` would underflow in the
    // subsampling length math.
    if encoder.conv_kernel_size == 0 || encoder.subs_kernel_size == 0 {
        return Err(Error::DecoderConfig {
            message: format!(
                "unsupported kernel size: conv_kernel_size ({}) and subs_kernel_size ({}) must be non-zero",
                encoder.conv_kernel_size, encoder.subs_kernel_size
            ),
        });
    }
    Ok(())
}
/// Returns the number of frames left after two stride-2 subsampling convolutions.
///
/// Each layer applies the convolution length formula
/// `out = (len + 2 * pad - kernel_size) / 2 + 1` with `pad = (kernel_size - 1) / 2`.
/// When the padded input is shorter than the kernel, the layer yields 0 frames. Clamping
/// the numerator at zero would report a phantom frame instead. A zero `kernel_size` is
/// treated as padding 0 rather than underflowing.
fn subsampled_len(kernel_size: usize, mel_frames: usize) -> usize {
    const STRIDE: usize = 2;
    const LAYERS: usize = 2;
    let pad = kernel_size.saturating_sub(1) / 2;
    (0..LAYERS).fold(mel_frames, |len, _| {
        match len.saturating_add(2 * pad).checked_sub(kernel_size) {
            Some(span) => span / STRIDE + 1,
            None => 0,
        }
    })
}
fn raw_to_decoder(decoding: Option<&serde_json::Value>, vocab_size: usize) -> Result<CtcDecoder> {
let Some(decoding) = decoding else {
return Ok(CtcDecoder::Greedy(GreedyDecoder::new(Vec::new())));
};
if decoding.is_null() {
return Ok(CtcDecoder::Greedy(GreedyDecoder::new(Vec::new())));
}
let target = decoding["_target_"].as_str().unwrap_or("");
let decoder: CtcDecoder = if target.contains("CTCGreedyDecoding") {
let g: GreedyDecoder = serde_json::from_value(decoding.clone()).context(ConfigSnafu)?;
CtcDecoder::Greedy(g)
} else if target.contains("CTCBeamDecoding") {
let b: svod_arch::ctc::BeamDecoder = serde_json::from_value(decoding.clone()).context(ConfigSnafu)?;
CtcDecoder::Beam(Box::new(b))
} else {
let vocab: Vec<String> = decoding["vocabulary"]
.as_array()
.map(|arr| arr.iter().filter_map(|v| v.as_str().map(String::from)).collect())
.unwrap_or_default();
CtcDecoder::Greedy(GreedyDecoder::new(vocab))
};
if !decoder.vocabulary().is_empty() && decoder.total_vocab() != vocab_size {
return Err(Error::DecoderConfig {
message: format!(
"decoder vocabulary length + 1 ({}) != head.num_classes ({}); \
CTC convention is one blank token appended after the vocabulary",
decoder.total_vocab(),
vocab_size
),
});
}
Ok(decoder)
}
fn raw_to_transducer(
head: &RawHead,
decoding: Option<&serde_json::Value>,
vocab_size: usize,
) -> Result<Option<TransducerConfig>> {
let target = decoding.and_then(|d| d["_target_"].as_str()).unwrap_or("");
let has_decoder = head.decoder.is_some();
let has_joint = head.joint.is_some();
if !(target.contains("RNNT") || (has_decoder && has_joint)) {
return Ok(None);
}
let dec = head
.decoder
.as_ref()
.ok_or_else(|| Error::DecoderConfig { message: "RNN-T config: missing head.decoder block".into() })?;
let joint = head
.joint
.as_ref()
.ok_or_else(|| Error::DecoderConfig { message: "RNN-T config: missing head.joint block".into() })?;
let max_symbols_per_step = decoding.and_then(|d| d["max_symbols_per_step"].as_u64()).unwrap_or(10) as usize;
let vocabulary: Vec<String> = decoding
.and_then(|d| d["vocabulary"].as_array())
.map(|arr| arr.iter().filter_map(|v| v.as_str().map(String::from)).collect())
.unwrap_or_default();
let sentencepiece =
decoding.and_then(|d| d.get("model_path")).and_then(|v| v.as_str()).map(|s| !s.is_empty()).unwrap_or(false);
Ok(Some(TransducerConfig {
pred_hidden: dec.pred_hidden,
pred_rnn_layers: dec.pred_rnn_layers,
joint_hidden: joint.joint_hidden,
num_classes: vocab_size,
max_symbols_per_step,
vocabulary,
sentencepiece,
}))
}