use serde::de::{self, Deserializer};
use serde::Deserialize;
use crate::error::Error;
/// Activation specification as read from a model's JSON configuration.
///
/// Values are produced by the hand-written `Deserialize` impl for this type,
/// which keeps any JSON it cannot interpret as [`ActivationSpec::Unsupported`]
/// instead of failing.
#[derive(Debug, Clone, PartialEq)]
pub enum ActivationSpec {
/// An activation identified by name, with an optional `negative_slope`.
Named {
/// Activation name: a bare JSON string, or the `"type"` string of an object.
name: String,
/// The object's `"negative_slope"` when present and numeric, otherwise `None`.
negative_slope: Option<f32>,
},
/// A JSON value that matched none of the recognised forms, kept verbatim.
Unsupported(serde_json::Value),
}
impl<'de> Deserialize<'de> for ActivationSpec {
    /// Builds an [`ActivationSpec`] from an arbitrary JSON value.
    ///
    /// Accepted forms:
    /// * a bare string `"Name"` becomes `Named` without a slope;
    /// * an object whose `"type"` is a string becomes `Named`, carrying
    ///   `"negative_slope"` when that key holds a number (absent or `null`
    ///   means no slope).
    ///
    /// Everything else, including an object whose `"negative_slope"` is not
    /// numeric, is returned as `Unsupported` holding the original value. Only
    /// the underlying deserializer can make this fail.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let value = serde_json::Value::deserialize(deserializer)?;

        // Shorthand form: the whole value is just the activation name.
        if let serde_json::Value::String(name) = &value {
            return Ok(ActivationSpec::Named {
                name: name.clone(),
                negative_slope: None,
            });
        }

        // Object form. `Value::get` yields `None` for non-objects, so every
        // other JSON kind falls through to `Unsupported` below. The pieces are
        // extracted as owned data so `value` can be moved afterwards.
        let recognised: Option<(String, Option<f32>)> = match value.get("type") {
            Some(serde_json::Value::String(kind)) => match value.get("negative_slope") {
                None | Some(serde_json::Value::Null) => Some((kind.clone(), None)),
                Some(slope) => slope
                    .as_f64()
                    .map(|number| (kind.clone(), Some(number as f32))),
            },
            _ => None,
        };

        Ok(match recognised {
            Some((name, negative_slope)) => ActivationSpec::Named {
                name,
                negative_slope,
            },
            None => ActivationSpec::Unsupported(value),
        })
    }
}
/// Fallback returned by `NamModel::expected_sample_rate` when the model's
/// `sample_rate` is `None`.
pub const DEFAULT_SAMPLE_RATE: f64 = 48_000.0;
/// A parsed model: top-level fields plus an architecture-specific [`ModelConfig`].
///
/// `Deserialize` is implemented by hand for this type so that `config` can be
/// interpreted according to `architecture`.
#[derive(Debug, Clone)]
pub struct NamModel {
/// The `version` string, stored as given.
pub version: String,
/// Architecture name; the deserializer accepts `"WaveNet"`, `"LSTM"` and
/// `"SlimmableContainer"` and rejects anything else.
pub architecture: String,
/// Configuration parsed from the `config` value for the given architecture.
pub config: ModelConfig,
/// The `weights` array.
pub weights: Vec<f32>,
/// Optional `sample_rate`; `expected_sample_rate` supplies the fallback.
pub sample_rate: Option<f64>,
/// Optional untyped `metadata`; `metadata_typed` extracts the known fields.
pub metadata: Option<serde_json::Value>,
}
/// Configuration for the `"LSTM"` architecture, deserialized directly from the
/// model's `config` value. All three fields are required.
#[derive(Debug, Clone, Deserialize)]
pub struct LstmConfig {
/// The config's `input_size`.
pub input_size: usize,
/// The config's `hidden_size`.
pub hidden_size: usize,
/// The config's `num_layers`.
pub num_layers: usize,
}
/// One entry of [`SlimmableConfig::submodels`]: a complete nested model paired
/// with a `max_value`.
#[derive(Debug, Clone, Deserialize)]
pub struct SlimmableSubmodel {
/// The entry's `max_value`.
// NOTE(review): presumably the upper bound of the range this submodel
// serves — the meaning is not visible in this file; confirm against the loader.
pub max_value: f32,
/// The nested model, parsed with `NamModel`'s own deserializer.
pub model: NamModel,
}
/// Configuration for the `"SlimmableContainer"` architecture: a list of nested
/// submodels.
#[derive(Debug, Clone, Deserialize)]
pub struct SlimmableConfig {
/// The container's `submodels`, in file order.
pub submodels: Vec<SlimmableSubmodel>,
}
/// Architecture-specific configuration of a [`NamModel`], selected by the
/// model's `architecture` string during deserialization.
#[derive(Debug, Clone)]
pub enum ModelConfig {
/// Built for `architecture == "WaveNet"` (normalized from the raw config).
WaveNet(WaveNetConfig),
/// Built for `architecture == "LSTM"`.
Lstm(LstmConfig),
/// Built for `architecture == "SlimmableContainer"`.
Slimmable(SlimmableConfig),
}
impl<'de> Deserialize<'de> for NamModel {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize)]
struct Raw {
version: String,
architecture: String,
config: serde_json::Value,
weights: Vec<f32>,
#[serde(default)]
sample_rate: Option<f64>,
#[serde(default)]
metadata: Option<serde_json::Value>,
}
let raw = Raw::deserialize(deserializer)?;
let config = match raw.architecture.as_str() {
"WaveNet" => {
let raw_wn: RawWaveNetConfig =
serde_json::from_value(raw.config).map_err(de::Error::custom)?;
ModelConfig::WaveNet(raw_wn.normalize().map_err(de::Error::custom)?)
}
"LSTM" => {
ModelConfig::Lstm(serde_json::from_value(raw.config).map_err(de::Error::custom)?)
}
"SlimmableContainer" => ModelConfig::Slimmable(
serde_json::from_value(raw.config).map_err(de::Error::custom)?,
),
other => {
return Err(de::Error::custom(format!(
"unsupported model architecture: {other:?}"
)))
}
};
Ok(NamModel {
version: raw.version,
architecture: raw.architecture,
config,
weights: raw.weights,
sample_rate: raw.sample_rate,
metadata: raw.metadata,
})
}
}
/// Typed view of the metadata entries this crate reads.
///
/// Every field is optional and defaults to `None` when the key is absent.
#[derive(Debug, Clone, Default, Deserialize)]
pub struct Metadata {
/// The `loudness` entry, if present.
#[serde(default)]
pub loudness: Option<f32>,
/// The `input_level_dbu` entry, if present.
#[serde(default)]
pub input_level_dbu: Option<f32>,
/// The `output_level_dbu` entry, if present.
#[serde(default)]
pub output_level_dbu: Option<f32>,
}
impl NamModel {
    /// Reads the file at `path` as text and parses it with
    /// [`NamModel::from_json_str`].
    ///
    /// # Errors
    /// Returns an error if the file cannot be read or its contents do not
    /// deserialize into a model.
    pub fn from_file(path: impl AsRef<std::path::Path>) -> Result<Self, Error> {
        let text = std::fs::read_to_string(path)?;
        Self::from_json_str(&text)
    }

    /// Parses a model from a JSON string.
    ///
    /// # Errors
    /// Returns an error if `json` is not valid JSON or does not describe a
    /// supported model.
    pub fn from_json_str(json: &str) -> Result<Self, Error> {
        Ok(serde_json::from_str(json)?)
    }

    /// The model's `sample_rate`, or [`DEFAULT_SAMPLE_RATE`] when it is unset.
    #[must_use]
    pub fn expected_sample_rate(&self) -> f64 {
        self.sample_rate.unwrap_or(DEFAULT_SAMPLE_RATE)
    }

    /// Extracts the known metadata fields.
    ///
    /// Best-effort: a missing `metadata` value, or one that does not
    /// deserialize into [`Metadata`], yields `Metadata::default()`.
    #[must_use]
    pub fn metadata_typed(&self) -> Metadata {
        // Deserialize straight from the borrowed value: `&serde_json::Value`
        // is itself a `Deserializer`, so the (potentially large) metadata tree
        // does not have to be deep-cloned on every accessor call.
        self.metadata
            .as_ref()
            .and_then(|raw| Metadata::deserialize(raw).ok())
            .unwrap_or_default()
    }

    /// The metadata's `loudness`, if available.
    #[must_use]
    pub fn loudness(&self) -> Option<f32> {
        self.metadata_typed().loudness
    }

    /// The metadata's `input_level_dbu`, if available.
    #[must_use]
    pub fn input_level_dbu(&self) -> Option<f32> {
        self.metadata_typed().input_level_dbu
    }

    /// The metadata's `output_level_dbu`, if available.
    #[must_use]
    pub fn output_level_dbu(&self) -> Option<f32> {
        self.metadata_typed().output_level_dbu
    }
}
/// Gating mode of a layer, parsed from its lowercase configuration name.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GatingMode {
    /// Configuration name `"none"`.
    None,
    /// Configuration name `"gated"`.
    Gated,
    /// Configuration name `"blended"`.
    Blended,
}

impl GatingMode {
    /// Maps a configuration name to its [`GatingMode`].
    ///
    /// Matching is exact and case-sensitive.
    ///
    /// # Errors
    /// Returns a message quoting the offending name when it is not one of
    /// `"none"`, `"gated"` or `"blended"`.
    pub(crate) fn from_name(s: &str) -> Result<Self, String> {
        /// Every accepted spelling paired with the mode it selects.
        const KNOWN: [(&str, GatingMode); 3] = [
            ("none", GatingMode::None),
            ("gated", GatingMode::Gated),
            ("blended", GatingMode::Blended),
        ];

        if let Some(entry) = KNOWN.iter().find(|entry| entry.0 == s) {
            return Ok(entry.1);
        }
        let other = s;
        Err(format!("unknown gating_mode: {other:?}"))
    }
}
/// Settings read from a layer-array's optional `layer1x1` object.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Layer1x1Config {
/// The object's `active` flag; `true` when the object or key is absent.
pub active: bool,
/// The object's `groups`; `1` when the object or key is absent.
pub groups: usize,
}
fn opt_usize(o: &serde_json::Value, key: &str) -> Option<usize> {
o.get(key).and_then(|x| x.as_u64()).map(|x| x as usize)
}
impl Layer1x1Config {
    /// Builds the config from an optional JSON object.
    ///
    /// Each field falls back to its default (`active = true`, `groups = 1`)
    /// when `v` is `None`, when the key is missing, or when the key holds a
    /// value of the wrong type.
    pub(crate) fn from_json(v: Option<&serde_json::Value>) -> Self {
        let active = v
            .and_then(|cfg| cfg.get("active"))
            .and_then(serde_json::Value::as_bool)
            .unwrap_or(true);
        let groups = v.and_then(|cfg| opt_usize(cfg, "groups")).unwrap_or(1);
        Self { active, groups }
    }
}
/// Settings read from a layer-array's optional `head1x1` object.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Head1x1Config {
    /// The object's `active` flag; `false` when the object or key is absent.
    pub active: bool,
    /// The object's `out_channels`, if given.
    pub out_channels: Option<usize>,
    /// The object's `groups`; `1` when the object or key is absent.
    pub groups: usize,
}

impl Head1x1Config {
    /// Builds the config from an optional JSON object.
    ///
    /// Missing input, missing keys and wrongly typed values all fall back to
    /// the defaults: inactive, no explicit `out_channels`, one group.
    pub(crate) fn from_json(v: Option<&serde_json::Value>) -> Self {
        // Integer lookup that tolerates the whole object being absent.
        let usize_field = |key: &str| v.and_then(|cfg| opt_usize(cfg, key));

        let active = v
            .and_then(|cfg| cfg.get("active"))
            .and_then(serde_json::Value::as_bool)
            .unwrap_or(false);

        Self {
            active,
            out_channels: usize_field("out_channels"),
            groups: usize_field("groups").unwrap_or(1),
        }
    }
}
/// Settings read from one of a layer-array's optional `*_film` values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FilmConfig {
    /// Whether this FiLM slot is enabled.
    pub active: bool,
    /// The value's `shift` flag.
    pub shift: bool,
    /// The value's `groups`.
    pub groups: usize,
}

impl FilmConfig {
    /// The disabled configuration used when a FiLM slot is absent or `false`.
    pub const INACTIVE: Self = Self {
        active: false,
        shift: false,
        groups: 1,
    };

    /// Builds the config from an optional JSON value.
    ///
    /// `None` and the literal `false` both give [`FilmConfig::INACTIVE`]. Any
    /// other value is read as an object whose missing or wrongly typed keys
    /// default to `active = true`, `shift = true`, `groups = 1`.
    pub(crate) fn from_json(v: Option<&serde_json::Value>) -> Self {
        let cfg = match v {
            // Anything except an explicit `false` is treated as a settings object.
            Some(cfg) if cfg.as_bool() != Some(false) => cfg,
            _ => return Self::INACTIVE,
        };

        let flag = |key: &str| {
            cfg.get(key)
                .and_then(serde_json::Value::as_bool)
                .unwrap_or(true)
        };

        Self {
            active: flag("active"),
            shift: flag("shift"),
            groups: opt_usize(cfg, "groups").unwrap_or(1),
        }
    }
}
/// The optional head described by a WaveNet config's top-level `head` object.
#[derive(Debug, Clone)]
pub struct PostStackHeadConfig {
/// The head's `channels` (required in the JSON).
pub channels: usize,
/// The head's `out_channels` (required in the JSON).
pub out_channels: usize,
/// The head's `kernel_sizes` array (required in the JSON).
pub kernel_sizes: Vec<usize>,
/// The head's `activation`; a missing key is parsed as JSON `null`.
pub activation: ActivationSpec,
}
/// Normalized configuration for the `"WaveNet"` architecture, produced by
/// `RawWaveNetConfig::normalize`.
#[derive(Debug, Clone)]
pub struct WaveNetConfig {
/// The normalized layer-arrays, in config order.
pub layers: Vec<LayerArrayConfig>,
/// The top-level `head`, or `None` when it is absent or `null`.
pub post_stack_head: Option<PostStackHeadConfig>,
/// The config's `head_scale`.
pub head_scale: f32,
/// The config's `in_channels`; `1` when not given.
pub in_channels: usize,
/// A nested model parsed from `condition_dsp`, or `None` when absent or `null`.
pub condition_dsp: Option<Box<NamModel>>,
}
/// WaveNet configuration exactly as it appears in JSON, before validation.
///
/// Converted into a [`WaveNetConfig`] by `normalize`.
#[derive(serde::Deserialize)]
struct RawWaveNetConfig {
/// Raw layer-array configs; each is normalized individually.
layers: Vec<RawLayerArrayConfig>,
/// Optional head, kept as untyped JSON until `normalize`.
#[serde(default)]
head: Option<serde_json::Value>,
/// Required `head_scale`.
head_scale: f32,
/// Optional `in_channels`; `normalize` substitutes `1` when missing.
#[serde(default)]
in_channels: Option<usize>,
/// Optional nested model, kept as untyped JSON until `normalize`.
#[serde(default)]
condition_dsp: Option<serde_json::Value>,
}
impl RawWaveNetConfig {
fn normalize(self) -> Result<WaveNetConfig, String> {
let layers = self
.layers
.into_iter()
.map(RawLayerArrayConfig::normalize)
.collect::<Result<Vec<_>, _>>()?;
let post_stack_head = match self.head {
Some(h) if !h.is_null() => {
let channels =
h.get("channels")
.and_then(|x| x.as_u64())
.ok_or("post-stack head missing channels")? as usize;
let out_channels = h
.get("out_channels")
.and_then(|x| x.as_u64())
.ok_or("post-stack head missing out_channels")?
as usize;
let kernel_sizes: Vec<usize> = h
.get("kernel_sizes")
.and_then(|x| x.as_array())
.ok_or("post-stack head missing kernel_sizes")?
.iter()
.map(|k| {
k.as_u64()
.map(|v| v as usize)
.ok_or("kernel_sizes entry not an int".to_string())
})
.collect::<Result<_, _>>()?;
let activation = serde_json::from_value::<ActivationSpec>(
h.get("activation")
.cloned()
.unwrap_or(serde_json::Value::Null),
)
.map_err(|e| e.to_string())?;
Some(PostStackHeadConfig {
channels,
out_channels,
kernel_sizes,
activation,
})
}
_ => None,
};
let condition_dsp = match self.condition_dsp {
Some(v) if !v.is_null() => {
let m = serde_json::from_value::<NamModel>(v).map_err(|e| e.to_string())?;
Some(Box::new(m))
}
_ => None,
};
Ok(WaveNetConfig {
layers,
post_stack_head,
head_scale: self.head_scale,
in_channels: self.in_channels.unwrap_or(1),
condition_dsp,
})
}
}
/// A validated, fully expanded layer-array configuration.
///
/// The per-field notes describe values as produced by
/// `RawLayerArrayConfig::normalize`.
#[derive(Debug, Clone)]
pub struct LayerArrayConfig {
    /// Copied from the raw config's `input_size`.
    pub input_size: usize,
    /// Copied from the raw config's `condition_size`.
    pub condition_size: usize,
    /// Copied from the raw config's `channels`; at least 1.
    pub channels: usize,
    /// The raw `bottleneck`, or `channels` when not given; at least 1.
    pub bottleneck: usize,
    /// One dilation per layer; non-empty, every entry at least 1.
    pub dilations: Vec<usize>,
    /// One kernel size per dilation; every entry at least 1.
    pub kernel_sizes: Vec<usize>,
    /// One activation per dilation.
    pub activations: Vec<ActivationSpec>,
    /// One gating mode per dilation.
    pub gating_modes: Vec<GatingMode>,
    /// One secondary activation per dilation.
    pub secondary_activations: Vec<ActivationSpec>,
    /// The raw `groups_input`, or 1 when not given; at least 1.
    pub groups_input: usize,
    /// The raw `groups_input_mixin`, or 1 when not given; at least 1.
    pub groups_input_mixin: usize,
    /// Head size: the `head` object's `out_channels`, or the legacy `head_size`.
    pub head_size: usize,
    /// Head kernel size: the `head` object's `kernel_size`, or 1 in the legacy form.
    pub head_kernel_size: usize,
    /// Whether the head has a bias.
    pub head_bias: bool,
    /// Settings from the raw `layer1x1` object.
    pub layer1x1: Layer1x1Config,
    /// Settings from the raw `head1x1` object.
    pub head1x1: Head1x1Config,
    /// Settings from the raw `conv_pre_film` value.
    pub conv_pre_film: FilmConfig,
    /// Settings from the raw `conv_post_film` value.
    pub conv_post_film: FilmConfig,
    /// Settings from the raw `input_mixin_pre_film` value.
    pub input_mixin_pre_film: FilmConfig,
    /// Settings from the raw `input_mixin_post_film` value.
    pub input_mixin_post_film: FilmConfig,
    /// Settings from the raw `activation_pre_film` value.
    pub activation_pre_film: FilmConfig,
    /// Settings from the raw `activation_post_film` value.
    pub activation_post_film: FilmConfig,
    /// Settings from the raw `layer1x1_post_film` value.
    pub layer1x1_post_film: FilmConfig,
    /// Settings from the raw `head1x1_post_film` value.
    pub head1x1_post_film: FilmConfig,
}

impl LayerArrayConfig {
    /// The gating mode of the first layer, or [`GatingMode::None`] when
    /// `gating_modes` is empty.
    pub fn gating_mode(&self) -> GatingMode {
        match self.gating_modes.as_slice() {
            [leading, ..] => *leading,
            [] => GatingMode::None,
        }
    }
}
/// One layer-array exactly as it appears in JSON, before validation.
///
/// Several settings have two spellings; `normalize` reconciles them:
/// `kernel_size` vs `kernel_sizes`, `gating_mode` vs `gated`, and the `head`
/// object vs `head_size`/`head_bias`.
#[derive(Debug, Clone, serde::Deserialize)]
pub(crate) struct RawLayerArrayConfig {
/// Required `input_size`.
input_size: usize,
/// Required `condition_size`.
condition_size: usize,
/// Required `channels`.
channels: usize,
/// Optional `bottleneck`; `normalize` falls back to `channels`.
#[serde(default)]
bottleneck: Option<usize>,
/// Required `dilations`; its length is the number of layers.
dilations: Vec<usize>,
/// Single kernel size applied to every layer. Exactly one of `kernel_size`
/// and `kernel_sizes` must be given.
#[serde(default)]
kernel_size: Option<usize>,
/// Per-layer kernel sizes; must have one entry per dilation.
#[serde(default)]
kernel_sizes: Option<Vec<usize>>,
/// Required activation: a single spec, or an array with one spec per layer.
activation: serde_json::Value,
/// Optional gating mode name, or an array of names; takes precedence over `gated`.
#[serde(default)]
gating_mode: Option<serde_json::Value>,
/// Boolean gating flag, consulted only when `gating_mode` is absent.
#[serde(default)]
gated: Option<bool>,
/// Optional secondary activation: a single spec, or one per layer.
#[serde(default)]
secondary_activation: Option<serde_json::Value>,
/// Optional `groups_input`; `normalize` falls back to 1.
#[serde(default)]
groups_input: Option<usize>,
/// Optional `groups_input_mixin`; `normalize` falls back to 1.
#[serde(default)]
groups_input_mixin: Option<usize>,
/// Optional head object; when present it replaces `head_size`/`head_bias`.
#[serde(default)]
head: Option<serde_json::Value>,
/// Head size, required when no `head` object is given.
#[serde(default)]
head_size: Option<usize>,
/// Head bias flag, used only when no `head` object is given.
#[serde(default)]
head_bias: Option<bool>,
/// Optional `layer1x1` settings object.
#[serde(default)]
layer1x1: Option<serde_json::Value>,
/// Optional `head1x1` settings object.
#[serde(default)]
head1x1: Option<serde_json::Value>,
/// Optional FiLM settings for the `conv_pre_film` slot.
#[serde(default)]
conv_pre_film: Option<serde_json::Value>,
/// Optional FiLM settings for the `conv_post_film` slot.
#[serde(default)]
conv_post_film: Option<serde_json::Value>,
/// Optional FiLM settings for the `input_mixin_pre_film` slot.
#[serde(default)]
input_mixin_pre_film: Option<serde_json::Value>,
/// Optional FiLM settings for the `input_mixin_post_film` slot.
#[serde(default)]
input_mixin_post_film: Option<serde_json::Value>,
/// Optional FiLM settings for the `activation_pre_film` slot.
#[serde(default)]
activation_pre_film: Option<serde_json::Value>,
/// Optional FiLM settings for the `activation_post_film` slot.
#[serde(default)]
activation_post_film: Option<serde_json::Value>,
/// Optional FiLM settings for the `layer1x1_post_film` slot.
#[serde(default)]
layer1x1_post_film: Option<serde_json::Value>,
/// Optional FiLM settings for the `head1x1_post_film` slot.
#[serde(default)]
head1x1_post_film: Option<serde_json::Value>,
}
impl RawLayerArrayConfig {
pub(crate) fn normalize(self) -> Result<LayerArrayConfig, String> {
let n = self.dilations.len();
if n == 0 {
return Err("layer-array has no dilations".into());
}
let kernel_sizes = match (self.kernel_size, self.kernel_sizes) {
(Some(_), Some(_)) => {
return Err("layer-array specifies both kernel_size and kernel_sizes".into())
}
(Some(k), None) => vec![k; n],
(None, Some(ks)) => {
if ks.len() != n {
return Err(format!(
"kernel_sizes length {} != number of layers {n}",
ks.len()
));
}
ks
}
(None, None) => {
return Err("layer-array specifies neither kernel_size nor kernel_sizes".into())
}
};
let activations = broadcast_activations(&self.activation, n)?;
let gating_modes = match (&self.gating_mode, self.gated) {
(Some(v), _) => broadcast_gating(v, n)?,
(None, Some(true)) => vec![GatingMode::Gated; n],
(None, _) => vec![GatingMode::None; n],
};
let secondary_activations = match &self.secondary_activation {
Some(v) => broadcast_secondary(v, n)?,
None => vec![default_sigmoid(); n],
};
let (head_size, head_kernel_size, head_bias) = match &self.head {
Some(h) if !h.is_null() => {
let out = h
.get("out_channels")
.and_then(|x| x.as_u64())
.ok_or("layer head missing out_channels")? as usize;
let k = h
.get("kernel_size")
.and_then(|x| x.as_u64())
.ok_or("layer head missing kernel_size")? as usize;
let bias = h.get("bias").and_then(|x| x.as_bool()).unwrap_or(true);
(out, k, bias)
}
_ => {
let hs = self
.head_size
.ok_or("layer-array missing head_size (and no head object)")?;
(hs, 1, self.head_bias.unwrap_or(false))
}
};
if head_kernel_size == 0 {
return Err("layer-array head_kernel_size must be >= 1".into());
}
if self.channels == 0 {
return Err("layer-array channels must be >= 1".into());
}
if head_size == 0 {
return Err("layer-array head_size must be >= 1".into());
}
if kernel_sizes.contains(&0) {
return Err("layer-array kernel_sizes entries must be >= 1".into());
}
if self.dilations.contains(&0) {
return Err("layer-array dilations entries must be >= 1".into());
}
let bottleneck = self.bottleneck.unwrap_or(self.channels);
if bottleneck == 0 {
return Err("layer-array bottleneck must be >= 1".into());
}
let groups_input = self.groups_input.unwrap_or(1);
let groups_input_mixin = self.groups_input_mixin.unwrap_or(1);
let layer1x1 = Layer1x1Config::from_json(self.layer1x1.as_ref());
let head1x1 = Head1x1Config::from_json(self.head1x1.as_ref());
let films = [
FilmConfig::from_json(self.conv_pre_film.as_ref()),
FilmConfig::from_json(self.conv_post_film.as_ref()),
FilmConfig::from_json(self.input_mixin_pre_film.as_ref()),
FilmConfig::from_json(self.input_mixin_post_film.as_ref()),
FilmConfig::from_json(self.activation_pre_film.as_ref()),
FilmConfig::from_json(self.activation_post_film.as_ref()),
FilmConfig::from_json(self.layer1x1_post_film.as_ref()),
FilmConfig::from_json(self.head1x1_post_film.as_ref()),
];
let group_counts = [
("groups_input", groups_input),
("groups_input_mixin", groups_input_mixin),
("layer1x1.groups", layer1x1.groups),
("head1x1.groups", head1x1.groups),
(
"film.groups",
films.iter().map(|f| f.groups).min().unwrap_or(1),
),
];
for (name, g) in group_counts {
if g == 0 {
return Err(format!("layer-array {name} must be >= 1"));
}
}
let [conv_pre_film, conv_post_film, input_mixin_pre_film, input_mixin_post_film, activation_pre_film, activation_post_film, layer1x1_post_film, head1x1_post_film] =
films;
Ok(LayerArrayConfig {
input_size: self.input_size,
condition_size: self.condition_size,
channels: self.channels,
bottleneck,
dilations: self.dilations,
kernel_sizes,
activations,
gating_modes,
secondary_activations,
groups_input,
groups_input_mixin,
head_size,
head_kernel_size,
head_bias,
layer1x1,
head1x1,
conv_pre_film,
conv_post_film,
input_mixin_pre_film,
input_mixin_post_film,
activation_pre_film,
activation_post_film,
layer1x1_post_film,
head1x1_post_film,
})
}
}
/// The activation used when a secondary activation is absent or `null`:
/// a named `"Sigmoid"` with no slope.
fn default_sigmoid() -> ActivationSpec {
    let name = String::from("Sigmoid");
    ActivationSpec::Named {
        name,
        negative_slope: None,
    }
}
/// Expands a "scalar or per-layer array" JSON setting into exactly `n` values.
///
/// * A JSON array must contain exactly `n` items; each is run through `parse`.
/// * Any other value is parsed once and repeated `n` times.
///
/// `kind` is only used to label the length-mismatch error.
///
/// # Errors
/// Returns the length-mismatch message, or the first error produced by
/// `parse`.
fn broadcast<T: Clone>(
    v: &serde_json::Value,
    n: usize,
    kind: &str,
    parse: impl Fn(&serde_json::Value) -> Result<T, String>,
) -> Result<Vec<T>, String> {
    let items = match v.as_array() {
        Some(items) => items,
        // Scalar form: one parse, cloned for every layer.
        None => return parse(v).map(|single| vec![single; n]),
    };

    if items.len() != n {
        return Err(format!(
            "{kind} list length {} != number of layers {n}",
            items.len()
        ));
    }

    let mut expanded = Vec::with_capacity(n);
    for item in items {
        expanded.push(parse(item)?);
    }
    Ok(expanded)
}
/// Parses one JSON value as an [`ActivationSpec`], turning any
/// deserialization error into its string form.
fn parse_activation(e: &serde_json::Value) -> Result<ActivationSpec, String> {
    let owned = e.clone();
    match serde_json::from_value::<ActivationSpec>(owned) {
        Ok(spec) => Ok(spec),
        Err(err) => Err(err.to_string()),
    }
}
/// Expands the `activation` setting (single spec or per-layer array) into `n`
/// activation specs.
fn broadcast_activations(v: &serde_json::Value, n: usize) -> Result<Vec<ActivationSpec>, String> {
    let per_layer: Vec<ActivationSpec> = broadcast(v, n, "activation", parse_activation)?;
    Ok(per_layer)
}
/// Expands the `secondary_activation` setting into `n` activation specs.
///
/// A `null` entry (or a `null` scalar) selects the default sigmoid; every
/// other entry is parsed as a regular activation spec.
fn broadcast_secondary(v: &serde_json::Value, n: usize) -> Result<Vec<ActivationSpec>, String> {
    broadcast(v, n, "secondary_activation", |entry| match entry {
        serde_json::Value::Null => Ok(default_sigmoid()),
        explicit => parse_activation(explicit),
    })
}
/// Expands the `gating_mode` setting (single name or per-layer array of
/// names) into `n` gating modes.
///
/// # Errors
/// Fails when an entry is not a string or is not a known gating-mode name.
fn broadcast_gating(v: &serde_json::Value, n: usize) -> Result<Vec<GatingMode>, String> {
    broadcast(v, n, "gating_mode", |entry| match entry.as_str() {
        Some(name) => GatingMode::from_name(name),
        None => Err("gating_mode entry is not a string".to_string()),
    })
}
// Tests for `RawLayerArrayConfig::normalize`: broadcasting of scalar settings,
// parsing of the per-layer/nested forms, and rejection of invalid configs.
#[cfg(test)]
mod layer_array_normalize_tests {
use super::*;
// Deserializes `v` as a raw layer-array and normalizes it, panicking on any failure.
fn norm(v: serde_json::Value) -> LayerArrayConfig {
let raw: RawLayerArrayConfig = serde_json::from_value(v).unwrap();
raw.normalize().unwrap()
}
// Flat form: scalar `kernel_size`, string `activation`, boolean `gated`,
// `head_size`/`head_bias` instead of a `head` object.
#[test]
fn a1_layer_broadcasts_scalar_kernel_and_string_activation() {
let la = norm(serde_json::json!({
"input_size": 1, "condition_size": 1, "channels": 2, "head_size": 1,
"kernel_size": 3, "dilations": [1, 2, 4], "activation": "Tanh",
"gated": false, "head_bias": false
}));
assert_eq!(la.channels, 2);
// `bottleneck` is not given, so it falls back to `channels`.
assert_eq!(la.bottleneck, 2);
assert_eq!(la.kernel_sizes, vec![3, 3, 3]);
assert_eq!(la.gating_modes, vec![GatingMode::None; 3]);
assert_eq!(la.head_size, 1);
assert_eq!(la.head_kernel_size, 1);
assert!(!la.head_bias);
assert!(la.layer1x1.active);
assert!(!la.head1x1.active);
assert_eq!(la.groups_input, 1);
assert_eq!(la.activations.len(), 3);
assert!(matches!(&la.activations[0], ActivationSpec::Named { name, .. } if name == "Tanh"));
// `gated: true` without `gating_mode` selects `Gated` for every layer.
let g = norm(serde_json::json!({
"input_size": 1, "condition_size": 1, "channels": 2, "head_size": 1,
"kernel_size": 3, "dilations": [1], "activation": "Tanh",
"gated": true, "head_bias": true
}));
assert_eq!(g.gating_modes, vec![GatingMode::Gated]);
}
// Per-layer form: arrays for kernel sizes, activations, gating modes and
// secondary activations, plus a nested `head` object.
#[test]
fn a2_flexible_layer_parses_per_layer_vectors_and_nested_head() {
let la = norm(serde_json::json!({
"input_size": 1, "condition_size": 1, "channels": 3, "bottleneck": 3,
"dilations": [1, 3, 7],
"kernel_sizes": [6, 6, 15],
"activation": [
{"type": "LeakyReLU", "negative_slope": 0.01},
{"type": "LeakyReLU", "negative_slope": 0.01},
{"type": "LeakyReLU", "negative_slope": 0.01}
],
"head": {"out_channels": 1, "kernel_size": 16, "bias": true},
"head1x1": {"active": false, "out_channels": 1, "groups": 1},
"layer1x1": {"active": true, "groups": 1},
"groups_input": 1, "groups_input_mixin": 1,
"gating_mode": ["none", "none", "none"],
"secondary_activation": [null, null, null],
"conv_pre_film": {"active": false, "shift": true, "groups": 1}
}));
assert_eq!(la.kernel_sizes, vec![6, 6, 15]);
assert_eq!(la.gating_modes, vec![GatingMode::None; 3]);
assert_eq!(la.head_size, 1);
assert_eq!(la.head_kernel_size, 16);
assert!(la.head_bias);
assert_eq!(la.bottleneck, 3);
assert_eq!(la.activations.len(), 3);
assert!(!la.conv_pre_film.active);
}
// Giving both `kernel_size` and `kernel_sizes` is rejected.
#[test]
fn both_kernel_forms_is_an_error() {
let raw: RawLayerArrayConfig = serde_json::from_value(serde_json::json!({
"input_size": 1, "condition_size": 1, "channels": 1, "head_size": 1,
"kernel_size": 3, "kernel_sizes": [3], "dilations": [1],
"activation": "Tanh", "gated": false, "head_bias": false
}))
.unwrap();
assert!(raw.normalize().is_err());
}
// `kernel_sizes` must have one entry per dilation.
#[test]
fn kernel_sizes_length_mismatch_is_an_error() {
let raw: RawLayerArrayConfig = serde_json::from_value(serde_json::json!({
"input_size": 1, "condition_size": 1, "channels": 1, "head_size": 1,
"kernel_sizes": [3, 3], "dilations": [1],
"activation": "Tanh", "gated": false, "head_bias": false
}))
.unwrap();
assert!(raw.normalize().is_err());
}
// An `activation` array must have one entry per dilation.
#[test]
fn activation_list_length_mismatch_is_an_error() {
let raw: RawLayerArrayConfig = serde_json::from_value(serde_json::json!({
"input_size": 1, "condition_size": 1, "channels": 1, "head_size": 1,
"kernel_size": 3, "dilations": [1, 2],
"activation": ["Tanh"], "gated": false, "head_bias": false
}))
.unwrap();
assert!(raw.normalize().is_err());
}
// Builds a minimal valid raw layer-array, lets `mutate` edit the JSON, then
// deserializes it (without normalizing).
fn raw_layer_array(mutate: impl FnOnce(&mut serde_json::Value)) -> RawLayerArrayConfig {
let mut v = serde_json::json!({
"input_size": 1, "condition_size": 1, "channels": 1, "head_size": 1,
"kernel_size": 3, "dilations": [1],
"activation": "Tanh", "gated": false, "head_bias": false
});
mutate(&mut v);
serde_json::from_value(v).unwrap()
}
// Sanity check: the unmodified baseline is accepted, so the failures below
// are caused by the individual mutation.
#[test]
fn baseline_raw_layer_array_normalizes() {
assert!(raw_layer_array(|_| {}).normalize().is_ok());
}
#[test]
fn zero_channels_is_an_error() {
let raw = raw_layer_array(|v| v["channels"] = serde_json::json!(0));
assert!(raw.normalize().is_err());
}
#[test]
fn zero_head_size_is_an_error() {
let raw = raw_layer_array(|v| v["head_size"] = serde_json::json!(0));
assert!(raw.normalize().is_err());
}
#[test]
fn zero_kernel_size_is_an_error() {
let raw = raw_layer_array(|v| v["kernel_size"] = serde_json::json!(0));
assert!(raw.normalize().is_err());
}
#[test]
fn zero_dilation_is_an_error() {
let raw = raw_layer_array(|v| v["dilations"] = serde_json::json!([0]));
assert!(raw.normalize().is_err());
}
#[test]
fn zero_bottleneck_is_an_error() {
let raw = raw_layer_array(|v| v["bottleneck"] = serde_json::json!(0));
assert!(raw.normalize().is_err());
}
// Covers the two flat group counts and the nested `layer1x1.groups`.
#[test]
fn zero_groups_is_an_error() {
for field in ["groups_input", "groups_input_mixin"] {
let raw = raw_layer_array(|v| v[field] = serde_json::json!(0));
assert!(raw.normalize().is_err(), "{field} == 0 must error");
}
let raw = raw_layer_array(|v| {
v["layer1x1"] = serde_json::json!({ "active": true, "groups": 0 });
});
assert!(raw.normalize().is_err(), "layer1x1.groups == 0 must error");
}
// Replaces the flat `head_size` with a `head` object whose `kernel_size` is 0.
#[test]
fn zero_head_kernel_size_is_an_error() {
let raw = raw_layer_array(|v| {
v.as_object_mut().unwrap().remove("head_size");
v["head"] = serde_json::json!({
"out_channels": 1, "kernel_size": 0, "activation": "ReLU"
});
});
assert!(raw.normalize().is_err());
}
}
// Tests for the small sub-config parsers: `GatingMode::from_name` and the
// `from_json` constructors of `FilmConfig`, `Layer1x1Config`, `Head1x1Config`.
#[cfg(test)]
mod a2_subconfig_tests {
use super::*;
// The three known names map to their variants; anything else is an error.
#[test]
fn gating_mode_from_str() {
assert_eq!(GatingMode::from_name("none").unwrap(), GatingMode::None);
assert_eq!(GatingMode::from_name("gated").unwrap(), GatingMode::Gated);
assert_eq!(
GatingMode::from_name("blended").unwrap(),
GatingMode::Blended
);
assert!(GatingMode::from_name("wat").is_err());
}
// A missing value and a literal `false` both yield `FilmConfig::INACTIVE`.
#[test]
fn film_absent_or_false_is_inactive() {
assert_eq!(FilmConfig::from_json(None), FilmConfig::INACTIVE);
assert_eq!(
FilmConfig::from_json(Some(&serde_json::json!(false))),
FilmConfig::INACTIVE
);
}
// An empty object enables the slot with defaults; explicit keys override them.
#[test]
fn film_object_defaults_active_shift_groups() {
let v = serde_json::json!({});
let f = FilmConfig::from_json(Some(&v));
assert_eq!(
f,
FilmConfig {
active: true,
shift: true,
groups: 1
}
);
let v = serde_json::json!({"active": false, "shift": false, "groups": 2});
assert_eq!(
FilmConfig::from_json(Some(&v)),
FilmConfig {
active: false,
shift: false,
groups: 2
}
);
}
// `Layer1x1Config` is active with one group both by default and when spelled out.
#[test]
fn layer1x1_defaults_active_true_groups_1() {
assert_eq!(
Layer1x1Config::from_json(None),
Layer1x1Config {
active: true,
groups: 1
}
);
let v = serde_json::json!({"active": true, "groups": 1});
assert_eq!(
Layer1x1Config::from_json(Some(&v)),
Layer1x1Config {
active: true,
groups: 1
}
);
}
// `Head1x1Config` is inactive by default; `out_channels` is carried through when given.
#[test]
fn head1x1_defaults_inactive() {
let h = Head1x1Config::from_json(None);
assert_eq!(
h,
Head1x1Config {
active: false,
out_channels: None,
groups: 1
}
);
let v = serde_json::json!({"active": false, "out_channels": 1, "groups": 1});
assert_eq!(
Head1x1Config::from_json(Some(&v)),
Head1x1Config {
active: false,
out_channels: Some(1),
groups: 1
}
);
}
}
// End-to-end tests: parse complete model JSON through `NamModel::from_json_str`
// and inspect the resulting `WaveNetConfig`.
#[cfg(test)]
mod wavenet_config_tests {
use super::*;
// Parses `json` as a model and returns its WaveNet config, panicking if
// parsing fails or the architecture is not WaveNet.
fn parse(json: &str) -> WaveNetConfig {
match NamModel::from_json_str(json).unwrap().config {
ModelConfig::WaveNet(c) => c,
other => panic!("expected WaveNet, got {other:?}"),
}
}
// Flat layer form with `head: null`: no post-stack head, no condition DSP,
// scalar kernel size repeated per dilation.
#[test]
fn a1_config_parses_unchanged() {
let c = parse(
r#"{
"version":"0.5.4","architecture":"WaveNet","config":{
"layers":[{"input_size":1,"condition_size":1,"channels":2,"head_size":1,
"kernel_size":3,"dilations":[1,2],"activation":"Tanh",
"gated":false,"head_bias":false}],
"head":null,"head_scale":2.0},
"weights":[]}"#,
);
assert_eq!(c.layers.len(), 1);
assert_eq!(c.head_scale, 2.0);
assert!(c.post_stack_head.is_none());
assert!(c.condition_dsp.is_none());
assert_eq!(c.layers[0].kernel_sizes, vec![3, 3]);
}
// Per-layer form with a nested layer `head` object and per-layer arrays.
#[test]
fn a2_flexible_container_submodel_config_parses() {
let c = parse(
r#"{
"version":"0.7.0","architecture":"WaveNet","config":{
"layers":[{"input_size":1,"condition_size":1,"channels":3,"bottleneck":3,
"dilations":[1,3,7],"kernel_sizes":[6,6,15],
"activation":[{"type":"LeakyReLU"},{"type":"LeakyReLU"},{"type":"LeakyReLU"}],
"head":{"out_channels":1,"kernel_size":16,"bias":true},
"head1x1":{"active":false},"layer1x1":{"active":true,"groups":1},
"gating_mode":["none","none","none"]}],
"head":null,"head_scale":0.5},
"weights":[]}"#,
);
assert_eq!(c.layers[0].head_kernel_size, 16);
assert_eq!(c.layers[0].kernel_sizes, vec![6, 6, 15]);
assert!(c.post_stack_head.is_none());
}
// A non-null top-level `head` object becomes the post-stack head.
#[test]
fn post_stack_head_parses() {
let c = parse(
r#"{
"version":"0.6.0","architecture":"WaveNet","config":{
"layers":[{"input_size":1,"condition_size":1,"channels":2,"head_size":2,
"kernel_size":3,"dilations":[1],"activation":"Tanh",
"gated":false,"head_bias":false}],
"head":{"channels":4,"out_channels":1,"kernel_sizes":[1,1],"activation":"ReLU"},
"head_scale":1.0},
"weights":[]}"#,
);
let h = c.post_stack_head.expect("post-stack head present");
assert_eq!(h.channels, 4);
assert_eq!(h.out_channels, 1);
assert_eq!(h.kernel_sizes, vec![1, 1]);
}
// `condition_dsp` holds a complete nested model, parsed recursively.
#[test]
fn condition_dsp_parses_as_nested_model() {
let c = parse(
r#"{
"version":"0.6.0","architecture":"WaveNet","config":{
"layers":[{"input_size":1,"condition_size":1,"channels":2,"head_size":1,
"kernel_size":3,"dilations":[1],"activation":"Tanh",
"gated":false,"head_bias":false}],
"head":null,"head_scale":1.0,
"condition_dsp":{"version":"0.5.4","architecture":"WaveNet","config":{
"layers":[{"input_size":1,"condition_size":1,"channels":1,"head_size":1,
"kernel_size":1,"dilations":[1],"activation":"Tanh",
"gated":false,"head_bias":false}],
"head":null,"head_scale":1.0},"weights":[]}},
"weights":[]}"#,
);
let dsp = c.condition_dsp.expect("condition_dsp present");
assert_eq!(dsp.architecture, "WaveNet");
}
}