use serde::Deserialize;
use crate::error::Error;
/// Sample rate used by [`NamModel::sample_rate`] when the model file does not declare one.
pub const DEFAULT_SAMPLE_RATE: f64 = 48e3;
/// A model as stored in a `.nam` JSON document.
///
/// Instances are produced by `serde` deserialization; see [`NamModel::from_file`]
/// and [`NamModel::from_json_str`].
#[derive(Debug, Clone, Deserialize)]
pub struct NamModel {
    /// Format version string, kept verbatim from the file.
    pub version: String,
    /// Architecture name, kept verbatim from the file. Only the WaveNet-style
    /// configuration layout ([`WaveNetConfig`]) is modelled by this type.
    pub architecture: String,
    /// Network configuration.
    pub config: WaveNetConfig,
    /// All model weights as one flat list.
    pub weights: Vec<f32>,
    /// Sample rate declared by the file, if any. Prefer [`NamModel::sample_rate`],
    /// which substitutes [`DEFAULT_SAMPLE_RATE`] when this is `None`.
    #[serde(default)]
    pub sample_rate: Option<f64>,
    /// Free-form metadata, kept as raw JSON; `None` when absent.
    #[serde(default)]
    pub metadata: Option<serde_json::Value>,
}
impl NamModel {
    /// Reads a model file from `path` and parses it.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be read as UTF-8 text, or if its
    /// contents fail to parse (see [`NamModel::from_json_str`]).
    pub fn from_file(path: impl AsRef<std::path::Path>) -> Result<Self, Error> {
        let json = std::fs::read_to_string(path)?;
        Self::from_json_str(&json)
    }

    /// Parses a model from a JSON string.
    ///
    /// # Errors
    ///
    /// Returns an error if `json` is not valid JSON or does not match the
    /// structure of [`NamModel`] (e.g. a required field is missing).
    pub fn from_json_str(json: &str) -> Result<Self, Error> {
        Ok(serde_json::from_str(json)?)
    }

    /// Sample rate the model should be run at.
    ///
    /// Returns the value declared in the file when it is usable. Falls back to
    /// [`DEFAULT_SAMPLE_RATE`] when the file omits it, or when the stored value is
    /// not a finite, strictly positive number (`0`, negative, NaN, infinity).
    #[must_use]
    pub fn sample_rate(&self) -> f64 {
        self.sample_rate
            // A non-positive or non-finite rate would cause divide-by-zero or
            // garbage timing downstream, so treat it like a missing value.
            .filter(|rate| rate.is_finite() && *rate > 0.0)
            .unwrap_or(DEFAULT_SAMPLE_RATE)
    }
}
/// Configuration for a WaveNet-style model: a list of layer arrays plus head settings.
#[derive(Debug, Clone, Deserialize)]
pub struct WaveNetConfig {
    /// Layer-array configurations, in the order they appear in the file.
    pub layers: Vec<LayerArrayConfig>,
    /// Optional head configuration, kept as raw JSON; `None` when absent or null.
    #[serde(default)]
    pub head: Option<serde_json::Value>,
    /// The file's `head_scale` value. Presumably a multiplier on the head output;
    /// TODO confirm against the inference code.
    pub head_scale: f32,
}
/// Hyper-parameters for one layer array inside a [`WaveNetConfig`].
///
/// Every field is required and maps one-to-one onto the JSON key of the same name.
#[derive(Debug, Clone, Deserialize)]
pub struct LayerArrayConfig {
    /// Size of the array's input (presumably a channel count; confirm).
    pub input_size: usize,
    /// Size of the conditioning input (presumably a channel count; confirm).
    pub condition_size: usize,
    /// Number of channels used inside the array.
    pub channels: usize,
    /// Size of this array's head output.
    pub head_size: usize,
    /// Convolution kernel size.
    pub kernel_size: usize,
    /// Dilation per layer. Its length presumably equals the number of layers
    /// in the array; verify against the model builder.
    pub dilations: Vec<usize>,
    /// Activation function name, stored exactly as written in the file.
    pub activation: String,
    /// Whether the layers use gating.
    pub gated: bool,
    /// Whether the head projection has a bias term.
    pub head_bias: bool,
}