use serde::de::{self, Deserializer};
use serde::Deserialize;
use crate::error::Error;
/// Fallback value returned by `NamModel::sample_rate()` when the model's
/// `sample_rate` field is `None`.
// NOTE(review): presumably expressed in Hz (48 kHz) — confirm against callers.
pub const DEFAULT_SAMPLE_RATE: f64 = 48_000.0;
/// A fully parsed model: version and architecture strings, the typed
/// architecture configuration, the flat weight list, and optional extras.
///
/// Built by the hand-written `Deserialize` impl for this type, which picks
/// the `ModelConfig` variant based on `architecture`.
#[derive(Debug, Clone)]
pub struct NamModel {
/// Version string exactly as it appeared in the input.
pub version: String,
/// Architecture name exactly as it appeared in the input; the deserializer
/// accepts `"WaveNet"` and `"LSTM"`.
pub architecture: String,
/// Architecture-specific configuration, already decoded into its typed form.
pub config: ModelConfig,
/// Flat list of weights.
// NOTE(review): the layout/ordering of the weights is not defined in this file.
pub weights: Vec<f32>,
/// Sample rate stored with the model, if any. Prefer `NamModel::sample_rate()`,
/// which substitutes `DEFAULT_SAMPLE_RATE` when this is `None`.
pub sample_rate: Option<f64>,
/// Untyped metadata kept as raw JSON; the `loudness()`, `input_level_dbu()`
/// and `output_level_dbu()` accessors read their values from here.
pub metadata: Option<serde_json::Value>,
}
/// Configuration payload decoded when a model's `architecture` is `"LSTM"`.
///
/// No serde defaults are declared, so all three keys must be present in the
/// `config` object for deserialization to succeed.
#[derive(Debug, Clone, Deserialize)]
pub struct LstmConfig {
// NOTE(review): presumably the number of input features per step — confirm
// against the LSTM implementation that consumes this config.
pub input_size: usize,
// NOTE(review): presumably the width of the hidden state — confirm.
pub hidden_size: usize,
// NOTE(review): presumably the number of stacked LSTM layers — confirm.
pub num_layers: usize,
}
/// Typed, architecture-specific configuration of a `NamModel`.
///
/// This enum does not derive `Deserialize`; the variant is chosen by the
/// `Deserialize` impl of `NamModel` from the model's `architecture` string.
#[derive(Debug, Clone)]
pub enum ModelConfig {
/// Selected when `architecture` is `"WaveNet"`.
WaveNet(WaveNetConfig),
/// Selected when `architecture` is `"LSTM"`.
Lstm(LstmConfig),
}
impl<'de> Deserialize<'de> for NamModel {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize)]
struct Raw {
version: String,
architecture: String,
config: serde_json::Value,
weights: Vec<f32>,
#[serde(default)]
sample_rate: Option<f64>,
#[serde(default)]
metadata: Option<serde_json::Value>,
}
let raw = Raw::deserialize(deserializer)?;
let config = match raw.architecture.as_str() {
"WaveNet" => {
ModelConfig::WaveNet(serde_json::from_value(raw.config).map_err(de::Error::custom)?)
}
"LSTM" => {
ModelConfig::Lstm(serde_json::from_value(raw.config).map_err(de::Error::custom)?)
}
other => {
return Err(de::Error::custom(format!(
"unsupported model architecture: {other:?}"
)))
}
};
Ok(NamModel {
version: raw.version,
architecture: raw.architecture,
config,
weights: raw.weights,
sample_rate: raw.sample_rate,
metadata: raw.metadata,
})
}
}
/// Typed view over the subset of a model's `metadata` JSON that this crate
/// reads. Every field is optional and falls back to `None` when its key is
/// absent; `Default` yields a value with all fields `None`.
// NOTE(review): units and reference levels of these values are not defined in
// this file — the `_dbu` suffix suggests dBu; confirm against the model format.
#[derive(Debug, Clone, Default, Deserialize)]
pub struct Metadata {
/// Value of the `loudness` key, if present.
#[serde(default)]
pub loudness: Option<f32>,
/// Value of the `input_level_dbu` key, if present.
#[serde(default)]
pub input_level_dbu: Option<f32>,
/// Value of the `output_level_dbu` key, if present.
#[serde(default)]
pub output_level_dbu: Option<f32>,
}
impl NamModel {
    /// Reads the file at `path` into a string and parses it as a model.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be read as UTF-8 text, or if its
    /// contents are rejected by [`NamModel::from_json_str`].
    pub fn from_file(path: impl AsRef<std::path::Path>) -> Result<Self, Error> {
        Self::from_json_str(&std::fs::read_to_string(path)?)
    }

    /// Parses a model from a JSON document held in memory.
    ///
    /// # Errors
    ///
    /// Returns an error if `json` is not valid JSON, lacks a required field,
    /// names an unsupported architecture, or carries a `config` that does not
    /// match the declared architecture.
    pub fn from_json_str(json: &str) -> Result<Self, Error> {
        Ok(serde_json::from_str(json)?)
    }

    /// Sample rate of the model, falling back to [`DEFAULT_SAMPLE_RATE`] when
    /// the model does not specify one.
    #[must_use]
    pub fn sample_rate(&self) -> f64 {
        self.sample_rate.unwrap_or(DEFAULT_SAMPLE_RATE)
    }

    /// Best-effort typed view of the raw `metadata` JSON.
    ///
    /// Missing metadata, or metadata that does not fit [`Metadata`] (for
    /// example a field holding a non-numeric value), yields
    /// `Metadata::default()`, i.e. every field `None`.
    fn metadata_typed(&self) -> Metadata {
        // `&serde_json::Value` is itself a `Deserializer`, so the typed view is
        // decoded by reference instead of deep-cloning the whole JSON tree on
        // every accessor call.
        self.metadata
            .as_ref()
            .and_then(|raw| Metadata::deserialize(raw).ok())
            .unwrap_or_default()
    }

    /// `loudness` entry of the metadata, if present and decodable.
    #[must_use]
    pub fn loudness(&self) -> Option<f32> {
        self.metadata_typed().loudness
    }

    /// `input_level_dbu` entry of the metadata, if present and decodable.
    #[must_use]
    pub fn input_level_dbu(&self) -> Option<f32> {
        self.metadata_typed().input_level_dbu
    }

    /// `output_level_dbu` entry of the metadata, if present and decodable.
    #[must_use]
    pub fn output_level_dbu(&self) -> Option<f32> {
        self.metadata_typed().output_level_dbu
    }
}
/// Configuration payload decoded when a model's `architecture` is `"WaveNet"`.
#[derive(Debug, Clone, Deserialize)]
pub struct WaveNetConfig {
/// One entry per layer array, in the order given in the input. Required.
pub layers: Vec<LayerArrayConfig>,
/// Optional head description, kept as uninterpreted JSON; `None` when the
/// key is absent.
#[serde(default)]
pub head: Option<serde_json::Value>,
/// Required scalar.
// NOTE(review): presumably a gain applied to the head output — confirm
// against the WaveNet implementation.
pub head_scale: f32,
}
/// Configuration of a single entry in `WaveNetConfig::layers`.
///
/// No serde defaults are declared, so every key must be present in the input.
// NOTE(review): the field meanings below are inferred from their names only;
// confirm each against the WaveNet implementation that consumes this config.
#[derive(Debug, Clone, Deserialize)]
pub struct LayerArrayConfig {
// Presumably the channel count of the signal entering this layer array.
pub input_size: usize,
// Presumably the channel count of the conditioning signal.
pub condition_size: usize,
// Presumably the internal channel count of the layers.
pub channels: usize,
// Presumably the channel count of this array's head output.
pub head_size: usize,
// Presumably the convolution kernel width.
pub kernel_size: usize,
// Presumably one dilation factor per layer in this array.
pub dilations: Vec<usize>,
// Activation function name as written in the input; not validated here.
pub activation: String,
// Presumably selects the gated activation variant.
pub gated: bool,
// Presumably whether the head convolution carries a bias term.
pub head_bias: bool,
}