1use serde::de::{self, Deserializer};
10use serde::Deserialize;
11
12use crate::error::Error;
13
14#[derive(Debug, Clone, PartialEq)]
22pub enum ActivationSpec {
23 Named {
25 name: String,
27 negative_slope: Option<f32>,
30 },
31 Unsupported(serde_json::Value),
33}
34
35impl<'de> Deserialize<'de> for ActivationSpec {
36 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
37 where
38 D: Deserializer<'de>,
39 {
40 let v = serde_json::Value::deserialize(deserializer)?;
41 Ok(match &v {
42 serde_json::Value::String(s) => ActivationSpec::Named {
43 name: s.clone(),
44 negative_slope: None,
45 },
46 serde_json::Value::Object(map) => match map.get("type") {
47 Some(serde_json::Value::String(t)) => match map.get("negative_slope") {
48 None | Some(serde_json::Value::Null) => ActivationSpec::Named {
50 name: t.clone(),
51 negative_slope: None,
52 },
53 Some(slope) if slope.as_f64().is_some() => ActivationSpec::Named {
55 name: t.clone(),
56 negative_slope: slope.as_f64().map(|x| x as f32),
57 },
58 Some(_) => ActivationSpec::Unsupported(v.clone()),
61 },
62 _ => ActivationSpec::Unsupported(v),
63 },
64 _ => ActivationSpec::Unsupported(v),
65 })
66 }
67}
68
69pub const DEFAULT_SAMPLE_RATE: f64 = 48_000.0;
73
74#[derive(Debug, Clone)]
79pub struct NamModel {
80 pub version: String,
82 pub architecture: String,
84 pub config: ModelConfig,
86 pub weights: Vec<f32>,
89 pub sample_rate: Option<f64>,
91 pub metadata: Option<serde_json::Value>,
93}
94
95#[derive(Debug, Clone, Deserialize)]
97pub struct LstmConfig {
98 pub input_size: usize,
100 pub hidden_size: usize,
102 pub num_layers: usize,
104}
105
106#[derive(Debug, Clone, Deserialize)]
109pub struct SlimmableSubmodel {
110 pub max_value: f32,
112 pub model: NamModel,
114}
115
116#[derive(Debug, Clone, Deserialize)]
119pub struct SlimmableConfig {
120 pub submodels: Vec<SlimmableSubmodel>,
122}
123
124#[derive(Debug, Clone)]
126pub enum ModelConfig {
127 WaveNet(WaveNetConfig),
130 Lstm(LstmConfig),
132 Slimmable(SlimmableConfig),
134}
135
136impl<'de> Deserialize<'de> for NamModel {
137 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
138 where
139 D: Deserializer<'de>,
140 {
141 #[derive(Deserialize)]
145 struct Raw {
146 version: String,
147 architecture: String,
148 config: serde_json::Value,
149 weights: Vec<f32>,
150 #[serde(default)]
151 sample_rate: Option<f64>,
152 #[serde(default)]
153 metadata: Option<serde_json::Value>,
154 }
155
156 let raw = Raw::deserialize(deserializer)?;
157 let config = match raw.architecture.as_str() {
158 "WaveNet" => {
159 let raw_wn: RawWaveNetConfig =
160 serde_json::from_value(raw.config).map_err(de::Error::custom)?;
161 ModelConfig::WaveNet(raw_wn.normalize().map_err(de::Error::custom)?)
162 }
163 "LSTM" => {
164 ModelConfig::Lstm(serde_json::from_value(raw.config).map_err(de::Error::custom)?)
165 }
166 "SlimmableContainer" => ModelConfig::Slimmable(
167 serde_json::from_value(raw.config).map_err(de::Error::custom)?,
168 ),
169 other => {
170 return Err(de::Error::custom(format!(
171 "unsupported model architecture: {other:?}"
172 )))
173 }
174 };
175
176 Ok(NamModel {
177 version: raw.version,
178 architecture: raw.architecture,
179 config,
180 weights: raw.weights,
181 sample_rate: raw.sample_rate,
182 metadata: raw.metadata,
183 })
184 }
185}
186
187#[derive(Debug, Clone, Default, Deserialize)]
190pub struct Metadata {
191 #[serde(default)]
193 pub loudness: Option<f32>,
194 #[serde(default)]
196 pub input_level_dbu: Option<f32>,
197 #[serde(default)]
199 pub output_level_dbu: Option<f32>,
200}
201
202impl NamModel {
203 pub fn from_file(path: impl AsRef<std::path::Path>) -> Result<Self, Error> {
209 Self::from_json_str(&std::fs::read_to_string(path)?)
210 }
211
212 pub fn from_json_str(json: &str) -> Result<Self, Error> {
214 Ok(serde_json::from_str(json)?)
215 }
216
217 #[must_use]
228 pub fn expected_sample_rate(&self) -> f64 {
229 self.sample_rate.unwrap_or(DEFAULT_SAMPLE_RATE)
230 }
231
232 #[must_use]
241 pub fn metadata_typed(&self) -> Metadata {
242 match &self.metadata {
243 Some(v) => serde_json::from_value(v.clone()).unwrap_or_default(),
244 None => Metadata::default(),
245 }
246 }
247
248 #[must_use]
250 pub fn loudness(&self) -> Option<f32> {
251 self.metadata_typed().loudness
252 }
253
254 #[must_use]
256 pub fn input_level_dbu(&self) -> Option<f32> {
257 self.metadata_typed().input_level_dbu
258 }
259
260 #[must_use]
262 pub fn output_level_dbu(&self) -> Option<f32> {
263 self.metadata_typed().output_level_dbu
264 }
265}
266
/// Per-layer gating mode, read from `gating_mode` or from the legacy `gated` flag.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GatingMode {
    /// Spelled `"none"` in the config.
    None,
    /// Spelled `"gated"` in the config.
    Gated,
    /// Spelled `"blended"` in the config.
    Blended,
}

impl GatingMode {
    /// Maps a config string to a [`GatingMode`]. Matching is case-sensitive.
    ///
    /// # Errors
    ///
    /// Any unrecognised string produces `unknown gating_mode: "<s>"`.
    pub(crate) fn from_name(s: &str) -> Result<Self, String> {
        let table = [
            ("none", Self::None),
            ("gated", Self::Gated),
            ("blended", Self::Blended),
        ];
        table
            .iter()
            .find(|(name, _)| *name == s)
            .map(|&(_, mode)| mode)
            .ok_or_else(|| format!("unknown gating_mode: {s:?}"))
    }
}
289
/// Settings for a layer array's `layer1x1` block.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Layer1x1Config {
    /// Whether the block is enabled. It is `true` when the config omits it.
    pub active: bool,
    /// Group count. It is 1 when omitted, and layer-array normalization rejects 0.
    pub groups: usize,
}
299
300fn opt_usize(o: &serde_json::Value, key: &str) -> Option<usize> {
304 o.get(key).and_then(|x| x.as_u64()).map(|x| x as usize)
305}
306
307impl Layer1x1Config {
308 pub(crate) fn from_json(v: Option<&serde_json::Value>) -> Self {
309 match v {
310 None => Self {
311 active: true,
312 groups: 1,
313 },
314 Some(o) => Self {
315 active: o.get("active").and_then(|x| x.as_bool()).unwrap_or(true),
316 groups: opt_usize(o, "groups").unwrap_or(1),
317 },
318 }
319 }
320}
321
322#[derive(Debug, Clone, Copy, PartialEq, Eq)]
326pub struct Head1x1Config {
327 pub active: bool,
329 pub out_channels: Option<usize>,
331 pub groups: usize,
333}
334
335impl Head1x1Config {
336 pub(crate) fn from_json(v: Option<&serde_json::Value>) -> Self {
337 match v {
338 None => Self {
339 active: false,
340 out_channels: None,
341 groups: 1,
342 },
343 Some(o) => Self {
344 active: o.get("active").and_then(|x| x.as_bool()).unwrap_or(false),
345 out_channels: opt_usize(o, "out_channels"),
346 groups: opt_usize(o, "groups").unwrap_or(1),
347 },
348 }
349 }
350}
351
352#[derive(Debug, Clone, Copy, PartialEq, Eq)]
355pub struct FilmConfig {
356 pub active: bool,
358 pub shift: bool,
360 pub groups: usize,
362}
363
364impl FilmConfig {
365 pub const INACTIVE: Self = Self {
367 active: false,
368 shift: false,
369 groups: 1,
370 };
371
372 pub(crate) fn from_json(v: Option<&serde_json::Value>) -> Self {
373 match v {
374 None => Self::INACTIVE,
375 Some(serde_json::Value::Bool(false)) => Self::INACTIVE,
376 Some(o) => Self {
377 active: o.get("active").and_then(|x| x.as_bool()).unwrap_or(true),
378 shift: o.get("shift").and_then(|x| x.as_bool()).unwrap_or(true),
379 groups: opt_usize(o, "groups").unwrap_or(1),
380 },
381 }
382 }
383}
384
385#[derive(Debug, Clone)]
388pub struct PostStackHeadConfig {
389 pub channels: usize,
391 pub out_channels: usize,
393 pub kernel_sizes: Vec<usize>,
395 pub activation: ActivationSpec,
397}
398
399#[derive(Debug, Clone)]
402pub struct WaveNetConfig {
403 pub layers: Vec<LayerArrayConfig>,
405 pub post_stack_head: Option<PostStackHeadConfig>,
407 pub head_scale: f32,
409 pub in_channels: usize,
411 pub condition_dsp: Option<Box<NamModel>>,
413}
414
415#[derive(serde::Deserialize)]
416struct RawWaveNetConfig {
417 layers: Vec<RawLayerArrayConfig>,
418 #[serde(default)]
419 head: Option<serde_json::Value>,
420 head_scale: f32,
421 #[serde(default)]
422 in_channels: Option<usize>,
423 #[serde(default)]
424 condition_dsp: Option<serde_json::Value>,
425}
426
427impl RawWaveNetConfig {
428 fn normalize(self) -> Result<WaveNetConfig, String> {
429 let layers = self
430 .layers
431 .into_iter()
432 .map(RawLayerArrayConfig::normalize)
433 .collect::<Result<Vec<_>, _>>()?;
434
435 let post_stack_head = match self.head {
436 Some(h) if !h.is_null() => {
437 let channels =
438 h.get("channels")
439 .and_then(|x| x.as_u64())
440 .ok_or("post-stack head missing channels")? as usize;
441 let out_channels = h
442 .get("out_channels")
443 .and_then(|x| x.as_u64())
444 .ok_or("post-stack head missing out_channels")?
445 as usize;
446 let kernel_sizes: Vec<usize> = h
447 .get("kernel_sizes")
448 .and_then(|x| x.as_array())
449 .ok_or("post-stack head missing kernel_sizes")?
450 .iter()
451 .map(|k| {
452 k.as_u64()
453 .map(|v| v as usize)
454 .ok_or("kernel_sizes entry not an int".to_string())
455 })
456 .collect::<Result<_, _>>()?;
457 let activation = serde_json::from_value::<ActivationSpec>(
458 h.get("activation")
459 .cloned()
460 .unwrap_or(serde_json::Value::Null),
461 )
462 .map_err(|e| e.to_string())?;
463 Some(PostStackHeadConfig {
464 channels,
465 out_channels,
466 kernel_sizes,
467 activation,
468 })
469 }
470 _ => None,
471 };
472
473 let condition_dsp = match self.condition_dsp {
474 Some(v) if !v.is_null() => {
475 let m = serde_json::from_value::<NamModel>(v).map_err(|e| e.to_string())?;
476 Some(Box::new(m))
477 }
478 _ => None,
479 };
480
481 Ok(WaveNetConfig {
482 layers,
483 post_stack_head,
484 head_scale: self.head_scale,
485 in_channels: self.in_channels.unwrap_or(1),
486 condition_dsp,
487 })
488 }
489}
490
491#[derive(Debug, Clone)]
496pub struct LayerArrayConfig {
497 pub input_size: usize,
499 pub condition_size: usize,
501 pub channels: usize,
503 pub bottleneck: usize,
505 pub dilations: Vec<usize>,
507 pub kernel_sizes: Vec<usize>,
509 pub activations: Vec<ActivationSpec>,
511 pub gating_modes: Vec<GatingMode>,
513 pub secondary_activations: Vec<ActivationSpec>,
516 pub groups_input: usize,
518 pub groups_input_mixin: usize,
520 pub head_size: usize,
522 pub head_kernel_size: usize,
524 pub head_bias: bool,
526 pub layer1x1: Layer1x1Config,
528 pub head1x1: Head1x1Config,
530 pub conv_pre_film: FilmConfig,
532 pub conv_post_film: FilmConfig,
534 pub input_mixin_pre_film: FilmConfig,
536 pub input_mixin_post_film: FilmConfig,
538 pub activation_pre_film: FilmConfig,
540 pub activation_post_film: FilmConfig,
542 pub layer1x1_post_film: FilmConfig,
544 pub head1x1_post_film: FilmConfig,
546}
547
548impl LayerArrayConfig {
549 pub fn gating_mode(&self) -> GatingMode {
558 self.gating_modes
559 .first()
560 .copied()
561 .unwrap_or(GatingMode::None)
562 }
563}
564
565#[derive(Debug, Clone, serde::Deserialize)]
568pub(crate) struct RawLayerArrayConfig {
569 input_size: usize,
570 condition_size: usize,
571 channels: usize,
572 #[serde(default)]
573 bottleneck: Option<usize>,
574 dilations: Vec<usize>,
575 #[serde(default)]
576 kernel_size: Option<usize>,
577 #[serde(default)]
578 kernel_sizes: Option<Vec<usize>>,
579 activation: serde_json::Value,
580 #[serde(default)]
581 gating_mode: Option<serde_json::Value>,
582 #[serde(default)]
583 gated: Option<bool>,
584 #[serde(default)]
585 secondary_activation: Option<serde_json::Value>,
586 #[serde(default)]
587 groups_input: Option<usize>,
588 #[serde(default)]
589 groups_input_mixin: Option<usize>,
590 #[serde(default)]
591 head: Option<serde_json::Value>,
592 #[serde(default)]
593 head_size: Option<usize>,
594 #[serde(default)]
595 head_bias: Option<bool>,
596 #[serde(default)]
597 layer1x1: Option<serde_json::Value>,
598 #[serde(default)]
599 head1x1: Option<serde_json::Value>,
600 #[serde(default)]
601 conv_pre_film: Option<serde_json::Value>,
602 #[serde(default)]
603 conv_post_film: Option<serde_json::Value>,
604 #[serde(default)]
605 input_mixin_pre_film: Option<serde_json::Value>,
606 #[serde(default)]
607 input_mixin_post_film: Option<serde_json::Value>,
608 #[serde(default)]
609 activation_pre_film: Option<serde_json::Value>,
610 #[serde(default)]
611 activation_post_film: Option<serde_json::Value>,
612 #[serde(default)]
613 layer1x1_post_film: Option<serde_json::Value>,
614 #[serde(default)]
615 head1x1_post_film: Option<serde_json::Value>,
616}
617
618impl RawLayerArrayConfig {
619 pub(crate) fn normalize(self) -> Result<LayerArrayConfig, String> {
620 let n = self.dilations.len();
621 if n == 0 {
622 return Err("layer-array has no dilations".into());
623 }
624
625 let kernel_sizes = match (self.kernel_size, self.kernel_sizes) {
626 (Some(_), Some(_)) => {
627 return Err("layer-array specifies both kernel_size and kernel_sizes".into())
628 }
629 (Some(k), None) => vec![k; n],
630 (None, Some(ks)) => {
631 if ks.len() != n {
632 return Err(format!(
633 "kernel_sizes length {} != number of layers {n}",
634 ks.len()
635 ));
636 }
637 ks
638 }
639 (None, None) => {
640 return Err("layer-array specifies neither kernel_size nor kernel_sizes".into())
641 }
642 };
643
644 let activations = broadcast_activations(&self.activation, n)?;
645
646 let gating_modes = match (&self.gating_mode, self.gated) {
650 (Some(v), _) => broadcast_gating(v, n)?,
651 (None, Some(true)) => vec![GatingMode::Gated; n],
652 (None, _) => vec![GatingMode::None; n],
653 };
654
655 let secondary_activations = match &self.secondary_activation {
656 Some(v) => broadcast_secondary(v, n)?,
657 None => vec![default_sigmoid(); n],
658 };
659
660 let (head_size, head_kernel_size, head_bias) = match &self.head {
661 Some(h) if !h.is_null() => {
662 let out = h
663 .get("out_channels")
664 .and_then(|x| x.as_u64())
665 .ok_or("layer head missing out_channels")? as usize;
666 let k = h
667 .get("kernel_size")
668 .and_then(|x| x.as_u64())
669 .ok_or("layer head missing kernel_size")? as usize;
670 let bias = h.get("bias").and_then(|x| x.as_bool()).unwrap_or(true);
675 (out, k, bias)
676 }
677 _ => {
678 let hs = self
679 .head_size
680 .ok_or("layer-array missing head_size (and no head object)")?;
681 (hs, 1, self.head_bias.unwrap_or(false))
682 }
683 };
684
685 if head_kernel_size == 0 {
691 return Err("layer-array head_kernel_size must be >= 1".into());
692 }
693 if self.channels == 0 {
694 return Err("layer-array channels must be >= 1".into());
695 }
696 if head_size == 0 {
697 return Err("layer-array head_size must be >= 1".into());
698 }
699 if kernel_sizes.contains(&0) {
700 return Err("layer-array kernel_sizes entries must be >= 1".into());
701 }
702 if self.dilations.contains(&0) {
703 return Err("layer-array dilations entries must be >= 1".into());
704 }
705 let bottleneck = self.bottleneck.unwrap_or(self.channels);
706 if bottleneck == 0 {
707 return Err("layer-array bottleneck must be >= 1".into());
708 }
709
710 let groups_input = self.groups_input.unwrap_or(1);
711 let groups_input_mixin = self.groups_input_mixin.unwrap_or(1);
712 let layer1x1 = Layer1x1Config::from_json(self.layer1x1.as_ref());
713 let head1x1 = Head1x1Config::from_json(self.head1x1.as_ref());
714 let films = [
715 FilmConfig::from_json(self.conv_pre_film.as_ref()),
716 FilmConfig::from_json(self.conv_post_film.as_ref()),
717 FilmConfig::from_json(self.input_mixin_pre_film.as_ref()),
718 FilmConfig::from_json(self.input_mixin_post_film.as_ref()),
719 FilmConfig::from_json(self.activation_pre_film.as_ref()),
720 FilmConfig::from_json(self.activation_post_film.as_ref()),
721 FilmConfig::from_json(self.layer1x1_post_film.as_ref()),
722 FilmConfig::from_json(self.head1x1_post_film.as_ref()),
723 ];
724 let group_counts = [
730 ("groups_input", groups_input),
731 ("groups_input_mixin", groups_input_mixin),
732 ("layer1x1.groups", layer1x1.groups),
733 ("head1x1.groups", head1x1.groups),
734 (
735 "film.groups",
736 films.iter().map(|f| f.groups).min().unwrap_or(1),
737 ),
738 ];
739 for (name, g) in group_counts {
740 if g == 0 {
741 return Err(format!("layer-array {name} must be >= 1"));
742 }
743 }
744 let [conv_pre_film, conv_post_film, input_mixin_pre_film, input_mixin_post_film, activation_pre_film, activation_post_film, layer1x1_post_film, head1x1_post_film] =
745 films;
746
747 Ok(LayerArrayConfig {
748 input_size: self.input_size,
749 condition_size: self.condition_size,
750 channels: self.channels,
751 bottleneck,
752 dilations: self.dilations,
753 kernel_sizes,
754 activations,
755 gating_modes,
756 secondary_activations,
757 groups_input,
758 groups_input_mixin,
759 head_size,
760 head_kernel_size,
761 head_bias,
762 layer1x1,
763 head1x1,
764 conv_pre_film,
765 conv_post_film,
766 input_mixin_pre_film,
767 input_mixin_post_film,
768 activation_pre_film,
769 activation_post_film,
770 layer1x1_post_film,
771 head1x1_post_film,
772 })
773 }
774}
775
776fn default_sigmoid() -> ActivationSpec {
778 ActivationSpec::Named {
779 name: "Sigmoid".into(),
780 negative_slope: None,
781 }
782}
783
784fn broadcast<T: Clone>(
789 v: &serde_json::Value,
790 n: usize,
791 kind: &str,
792 parse: impl Fn(&serde_json::Value) -> Result<T, String>,
793) -> Result<Vec<T>, String> {
794 match v {
795 serde_json::Value::Array(items) => {
796 if items.len() != n {
797 return Err(format!(
798 "{kind} list length {} != number of layers {n}",
799 items.len()
800 ));
801 }
802 items.iter().map(&parse).collect()
803 }
804 other => Ok(vec![parse(other)?; n]),
805 }
806}
807
808fn parse_activation(e: &serde_json::Value) -> Result<ActivationSpec, String> {
809 serde_json::from_value::<ActivationSpec>(e.clone()).map_err(|e| e.to_string())
810}
811
812fn broadcast_activations(v: &serde_json::Value, n: usize) -> Result<Vec<ActivationSpec>, String> {
813 broadcast(v, n, "activation", parse_activation)
814}
815
816fn broadcast_secondary(v: &serde_json::Value, n: usize) -> Result<Vec<ActivationSpec>, String> {
819 broadcast(v, n, "secondary_activation", |e| {
820 if e.is_null() {
821 Ok(default_sigmoid())
822 } else {
823 parse_activation(e)
824 }
825 })
826}
827
828fn broadcast_gating(v: &serde_json::Value, n: usize) -> Result<Vec<GatingMode>, String> {
830 broadcast(v, n, "gating_mode", |e| {
831 e.as_str()
832 .ok_or_else(|| "gating_mode entry is not a string".to_string())
833 .and_then(GatingMode::from_name)
834 })
835}
836
// Unit tests for `RawLayerArrayConfig::normalize`: broadcasting, legacy vs nested
// fields, and rejection of inconsistent or zero-sized configs.
#[cfg(test)]
mod layer_array_normalize_tests {
    use super::*;

    /// Deserializes `v` as a raw layer array and normalizes it, panicking on any error.
    fn norm(v: serde_json::Value) -> LayerArrayConfig {
        let raw: RawLayerArrayConfig = serde_json::from_value(v).unwrap();
        raw.normalize().unwrap()
    }

    /// Legacy layer: scalar `kernel_size` and string `activation` are broadcast across
    /// all dilations, the flat `head_size`/`head_bias` fields apply (head kernel 1), and
    /// the sub-config defaults are used.
    #[test]
    fn a1_layer_broadcasts_scalar_kernel_and_string_activation() {
        let la = norm(serde_json::json!({
            "input_size": 1, "condition_size": 1, "channels": 2, "head_size": 1,
            "kernel_size": 3, "dilations": [1, 2, 4], "activation": "Tanh",
            "gated": false, "head_bias": false
        }));
        assert_eq!(la.channels, 2);
        // `bottleneck` defaults to `channels`.
        assert_eq!(la.bottleneck, 2);
        assert_eq!(la.kernel_sizes, vec![3, 3, 3]);
        assert_eq!(la.gating_modes, vec![GatingMode::None; 3]);
        assert_eq!(la.head_size, 1);
        assert_eq!(la.head_kernel_size, 1);
        assert!(!la.head_bias);
        assert!(la.layer1x1.active);
        assert!(!la.head1x1.active);
        assert_eq!(la.groups_input, 1);
        assert_eq!(la.activations.len(), 3);
        assert!(matches!(&la.activations[0], ActivationSpec::Named { name, .. } if name == "Tanh"));
        // Legacy `gated: true` (no `gating_mode`) maps to `GatingMode::Gated`.
        let g = norm(serde_json::json!({
            "input_size": 1, "condition_size": 1, "channels": 2, "head_size": 1,
            "kernel_size": 3, "dilations": [1], "activation": "Tanh",
            "gated": true, "head_bias": true
        }));
        assert_eq!(g.gating_modes, vec![GatingMode::Gated]);
    }

    /// Newer layer: per-layer `kernel_sizes`, activation list, nested `head` object,
    /// per-layer `gating_mode`/`secondary_activation`, and an explicit FiLM slot.
    #[test]
    fn a2_flexible_layer_parses_per_layer_vectors_and_nested_head() {
        let la = norm(serde_json::json!({
            "input_size": 1, "condition_size": 1, "channels": 3, "bottleneck": 3,
            "dilations": [1, 3, 7],
            "kernel_sizes": [6, 6, 15],
            "activation": [
                {"type": "LeakyReLU", "negative_slope": 0.01},
                {"type": "LeakyReLU", "negative_slope": 0.01},
                {"type": "LeakyReLU", "negative_slope": 0.01}
            ],
            "head": {"out_channels": 1, "kernel_size": 16, "bias": true},
            "head1x1": {"active": false, "out_channels": 1, "groups": 1},
            "layer1x1": {"active": true, "groups": 1},
            "groups_input": 1, "groups_input_mixin": 1,
            "gating_mode": ["none", "none", "none"],
            "secondary_activation": [null, null, null],
            "conv_pre_film": {"active": false, "shift": true, "groups": 1}
        }));
        assert_eq!(la.kernel_sizes, vec![6, 6, 15]);
        assert_eq!(la.gating_modes, vec![GatingMode::None; 3]);
        // Head values come from the nested `head` object.
        assert_eq!(la.head_size, 1);
        assert_eq!(la.head_kernel_size, 16);
        assert!(la.head_bias);
        assert_eq!(la.bottleneck, 3);
        assert_eq!(la.activations.len(), 3);
        assert!(!la.conv_pre_film.active);
    }

    /// `kernel_size` and `kernel_sizes` are mutually exclusive.
    #[test]
    fn both_kernel_forms_is_an_error() {
        let raw: RawLayerArrayConfig = serde_json::from_value(serde_json::json!({
            "input_size": 1, "condition_size": 1, "channels": 1, "head_size": 1,
            "kernel_size": 3, "kernel_sizes": [3], "dilations": [1],
            "activation": "Tanh", "gated": false, "head_bias": false
        }))
        .unwrap();
        assert!(raw.normalize().is_err());
    }

    /// `kernel_sizes` must have one entry per dilation.
    #[test]
    fn kernel_sizes_length_mismatch_is_an_error() {
        let raw: RawLayerArrayConfig = serde_json::from_value(serde_json::json!({
            "input_size": 1, "condition_size": 1, "channels": 1, "head_size": 1,
            "kernel_sizes": [3, 3], "dilations": [1],
            "activation": "Tanh", "gated": false, "head_bias": false
        }))
        .unwrap();
        assert!(raw.normalize().is_err());
    }

    /// A list-form `activation` must have one entry per dilation.
    #[test]
    fn activation_list_length_mismatch_is_an_error() {
        let raw: RawLayerArrayConfig = serde_json::from_value(serde_json::json!({
            "input_size": 1, "condition_size": 1, "channels": 1, "head_size": 1,
            "kernel_size": 3, "dilations": [1, 2],
            "activation": ["Tanh"], "gated": false, "head_bias": false
        }))
        .unwrap();
        assert!(raw.normalize().is_err());
    }

    /// Builds a minimal valid single-layer raw config and lets `mutate` alter the JSON
    /// before deserialization.
    fn raw_layer_array(mutate: impl FnOnce(&mut serde_json::Value)) -> RawLayerArrayConfig {
        let mut v = serde_json::json!({
            "input_size": 1, "condition_size": 1, "channels": 1, "head_size": 1,
            "kernel_size": 3, "dilations": [1],
            "activation": "Tanh", "gated": false, "head_bias": false
        });
        mutate(&mut v);
        serde_json::from_value(v).unwrap()
    }

    /// Sanity check: the unmodified fixture normalizes, so the zero-value tests below
    /// fail only because of their mutation.
    #[test]
    fn baseline_raw_layer_array_normalizes() {
        assert!(raw_layer_array(|_| {}).normalize().is_ok());
    }

    #[test]
    fn zero_channels_is_an_error() {
        let raw = raw_layer_array(|v| v["channels"] = serde_json::json!(0));
        assert!(raw.normalize().is_err());
    }

    #[test]
    fn zero_head_size_is_an_error() {
        let raw = raw_layer_array(|v| v["head_size"] = serde_json::json!(0));
        assert!(raw.normalize().is_err());
    }

    #[test]
    fn zero_kernel_size_is_an_error() {
        let raw = raw_layer_array(|v| v["kernel_size"] = serde_json::json!(0));
        assert!(raw.normalize().is_err());
    }

    #[test]
    fn zero_dilation_is_an_error() {
        let raw = raw_layer_array(|v| v["dilations"] = serde_json::json!([0]));
        assert!(raw.normalize().is_err());
    }

    /// An explicit `bottleneck: 0` is rejected (it does not fall back to `channels`).
    #[test]
    fn zero_bottleneck_is_an_error() {
        let raw = raw_layer_array(|v| v["bottleneck"] = serde_json::json!(0));
        assert!(raw.normalize().is_err());
    }

    /// Zero group counts are rejected for the scalar group fields and for the groups
    /// inside `layer1x1`.
    #[test]
    fn zero_groups_is_an_error() {
        for field in ["groups_input", "groups_input_mixin"] {
            let raw = raw_layer_array(|v| v[field] = serde_json::json!(0));
            assert!(raw.normalize().is_err(), "{field} == 0 must error");
        }
        let raw = raw_layer_array(|v| {
            v["layer1x1"] = serde_json::json!({ "active": true, "groups": 0 });
        });
        assert!(raw.normalize().is_err(), "layer1x1.groups == 0 must error");
    }

    /// A nested `head` with `kernel_size: 0` is rejected. `head_size` is removed so the
    /// nested head is the only head source.
    #[test]
    fn zero_head_kernel_size_is_an_error() {
        let raw = raw_layer_array(|v| {
            v.as_object_mut().unwrap().remove("head_size");
            v["head"] = serde_json::json!({
                "out_channels": 1, "kernel_size": 0, "activation": "ReLU"
            });
        });
        assert!(raw.normalize().is_err());
    }
}
1011
// Unit tests for the small sub-config parsers: `GatingMode::from_name` and the
// `from_json` constructors of the FiLM, layer1x1 and head1x1 configs.
#[cfg(test)]
mod a2_subconfig_tests {
    use super::*;

    /// The three recognised names map to their variants; anything else is an error.
    #[test]
    fn gating_mode_from_str() {
        assert_eq!(GatingMode::from_name("none").unwrap(), GatingMode::None);
        assert_eq!(GatingMode::from_name("gated").unwrap(), GatingMode::Gated);
        assert_eq!(
            GatingMode::from_name("blended").unwrap(),
            GatingMode::Blended
        );
        assert!(GatingMode::from_name("wat").is_err());
    }

    /// An absent FiLM slot and a literal `false` both yield `FilmConfig::INACTIVE`.
    #[test]
    fn film_absent_or_false_is_inactive() {
        assert_eq!(FilmConfig::from_json(None), FilmConfig::INACTIVE);
        assert_eq!(
            FilmConfig::from_json(Some(&serde_json::json!(false))),
            FilmConfig::INACTIVE
        );
    }

    /// An empty object takes the object defaults (active, shift, 1 group); explicit
    /// fields override them.
    #[test]
    fn film_object_defaults_active_shift_groups() {
        let v = serde_json::json!({});
        let f = FilmConfig::from_json(Some(&v));
        assert_eq!(
            f,
            FilmConfig {
                active: true,
                shift: true,
                groups: 1
            }
        );
        let v = serde_json::json!({"active": false, "shift": false, "groups": 2});
        assert_eq!(
            FilmConfig::from_json(Some(&v)),
            FilmConfig {
                active: false,
                shift: false,
                groups: 2
            }
        );
    }

    /// An absent `layer1x1` is active with 1 group, as is the explicit equivalent object.
    #[test]
    fn layer1x1_defaults_active_true_groups_1() {
        assert_eq!(
            Layer1x1Config::from_json(None),
            Layer1x1Config {
                active: true,
                groups: 1
            }
        );
        let v = serde_json::json!({"active": true, "groups": 1});
        assert_eq!(
            Layer1x1Config::from_json(Some(&v)),
            Layer1x1Config {
                active: true,
                groups: 1
            }
        );
    }

    /// An absent `head1x1` is inactive with no `out_channels`; the object form keeps
    /// `out_channels`.
    #[test]
    fn head1x1_defaults_inactive() {
        let h = Head1x1Config::from_json(None);
        assert_eq!(
            h,
            Head1x1Config {
                active: false,
                out_channels: None,
                groups: 1
            }
        );
        let v = serde_json::json!({"active": false, "out_channels": 1, "groups": 1});
        assert_eq!(
            Head1x1Config::from_json(Some(&v)),
            Head1x1Config {
                active: false,
                out_channels: Some(1),
                groups: 1
            }
        );
    }
}
1100
// End-to-end tests: full model documents go through `NamModel::from_json_str` and the
// resulting `WaveNetConfig` is inspected.
#[cfg(test)]
mod wavenet_config_tests {
    use super::*;

    /// Parses a full model document and returns its WaveNet config, panicking if
    /// parsing fails or the architecture is not WaveNet.
    fn parse(json: &str) -> WaveNetConfig {
        match NamModel::from_json_str(json).unwrap().config {
            ModelConfig::WaveNet(c) => c,
            other => panic!("expected WaveNet, got {other:?}"),
        }
    }

    /// A legacy-style document parses. A `null` head means no post-stack head, and the
    /// scalar `kernel_size` is broadcast.
    #[test]
    fn a1_config_parses_unchanged() {
        let c = parse(
            r#"{
            "version":"0.5.4","architecture":"WaveNet","config":{
                "layers":[{"input_size":1,"condition_size":1,"channels":2,"head_size":1,
                    "kernel_size":3,"dilations":[1,2],"activation":"Tanh",
                    "gated":false,"head_bias":false}],
                "head":null,"head_scale":2.0},
            "weights":[]}"#,
        );
        assert_eq!(c.layers.len(), 1);
        assert_eq!(c.head_scale, 2.0);
        assert!(c.post_stack_head.is_none());
        assert!(c.condition_dsp.is_none());
        assert_eq!(c.layers[0].kernel_sizes, vec![3, 3]);
    }

    /// A newer-style layer (per-layer kernels, activation objects, nested head,
    /// gating-mode list) parses through the full document path.
    #[test]
    fn a2_flexible_container_submodel_config_parses() {
        let c = parse(
            r#"{
            "version":"0.7.0","architecture":"WaveNet","config":{
                "layers":[{"input_size":1,"condition_size":1,"channels":3,"bottleneck":3,
                    "dilations":[1,3,7],"kernel_sizes":[6,6,15],
                    "activation":[{"type":"LeakyReLU"},{"type":"LeakyReLU"},{"type":"LeakyReLU"}],
                    "head":{"out_channels":1,"kernel_size":16,"bias":true},
                    "head1x1":{"active":false},"layer1x1":{"active":true,"groups":1},
                    "gating_mode":["none","none","none"]}],
                "head":null,"head_scale":0.5},
            "weights":[]}"#,
        );
        assert_eq!(c.layers[0].head_kernel_size, 16);
        assert_eq!(c.layers[0].kernel_sizes, vec![6, 6, 15]);
        assert!(c.post_stack_head.is_none());
    }

    /// A non-null top-level `head` object is parsed into `post_stack_head`.
    #[test]
    fn post_stack_head_parses() {
        let c = parse(
            r#"{
            "version":"0.6.0","architecture":"WaveNet","config":{
                "layers":[{"input_size":1,"condition_size":1,"channels":2,"head_size":2,
                    "kernel_size":3,"dilations":[1],"activation":"Tanh",
                    "gated":false,"head_bias":false}],
                "head":{"channels":4,"out_channels":1,"kernel_sizes":[1,1],"activation":"ReLU"},
                "head_scale":1.0},
            "weights":[]}"#,
        );
        let h = c.post_stack_head.expect("post-stack head present");
        assert_eq!(h.channels, 4);
        assert_eq!(h.out_channels, 1);
        assert_eq!(h.kernel_sizes, vec![1, 1]);
    }

    /// `condition_dsp` holds a complete nested model document, parsed recursively.
    #[test]
    fn condition_dsp_parses_as_nested_model() {
        let c = parse(
            r#"{
            "version":"0.6.0","architecture":"WaveNet","config":{
                "layers":[{"input_size":1,"condition_size":1,"channels":2,"head_size":1,
                    "kernel_size":3,"dilations":[1],"activation":"Tanh",
                    "gated":false,"head_bias":false}],
                "head":null,"head_scale":1.0,
                "condition_dsp":{"version":"0.5.4","architecture":"WaveNet","config":{
                    "layers":[{"input_size":1,"condition_size":1,"channels":1,"head_size":1,
                        "kernel_size":1,"dilations":[1],"activation":"Tanh",
                        "gated":false,"head_bias":false}],
                    "head":null,"head_scale":1.0},"weights":[]}},
            "weights":[]}"#,
        );
        let dsp = c.condition_dsp.expect("condition_dsp present");
        assert_eq!(dsp.architecture, "WaveNet");
    }
}