1use crate::error::Error;
5use crate::lstm::Lstm;
6use crate::model::{ModelConfig, NamModel};
7use crate::wavenet::WaveNet;
8
9#[non_exhaustive]
15#[derive(Debug)]
16pub enum Model {
17 WaveNet(WaveNet),
19 Lstm(Lstm),
21}
22
23impl Model {
24 pub fn from_nam(model: &NamModel) -> Result<Self, Error> {
26 match &model.config {
27 ModelConfig::WaveNet(_) => Ok(Model::WaveNet(WaveNet::new(model)?)),
28 ModelConfig::Lstm(_) => Ok(Model::Lstm(Lstm::new(model)?)),
29 }
30 }
31
32 pub fn process_buffer(&mut self, io: &mut [f32]) {
34 match self {
35 Model::WaveNet(w) => w.process_buffer(io),
36 Model::Lstm(l) => l.process_buffer(io),
37 }
38 }
39
40 pub fn process_sample(&mut self, x: f32) -> f32 {
42 match self {
43 Model::WaveNet(w) => w.process_sample(x),
44 Model::Lstm(l) => l.process_sample(x),
45 }
46 }
47
48 pub fn sample_rate(&self) -> f64 {
50 match self {
51 Model::WaveNet(w) => w.sample_rate(),
52 Model::Lstm(l) => l.sample_rate(),
53 }
54 }
55
56 pub fn reset(&mut self) {
58 match self {
59 Model::WaveNet(w) => w.reset(),
60 Model::Lstm(l) => l.reset(),
61 }
62 }
63}
64
#[cfg(test)]
mod tests {
    use super::*;

    // Smallest single-layer WaveNet description accepted by the parser.
    const TINY_WAVENET: &str = r#"{
        "version": "0.5.4", "architecture": "WaveNet",
        "config": { "layers": [{
            "input_size": 1, "condition_size": 1, "channels": 1, "head_size": 1,
            "kernel_size": 1, "dilations": [1], "activation": "ReLU",
            "gated": false, "head_bias": false
        }], "head": null, "head_scale": 10.0 },
        "weights": [1.0, 2.0, 0.5, 1.0, 3.0, 0.1, 0.5, 10.0]
    }"#;

    // One-layer, one-unit LSTM. The expected output below assumes the usual NAM
    // weight layout (TODO confirm against `Lstm::new`):
    // - a 4x2 gate matrix in i, f, g, o order;
    // - 4 gate biases;
    // - initial h and c;
    // - the head weight (3.0) and head bias (0.5).
    // For x = 0.5:
    //   c = sigmoid(0.5) * tanh(1.0) ~= 0.4741
    //   h = sigmoid(0) * tanh(c)     ~= 0.2207
    //   y = 3 * h + 0.5              ~= 1.1622
    const TINY_LSTM: &str = r#"{
        "version": "0.5.4", "architecture": "LSTM",
        "config": { "input_size": 1, "hidden_size": 1, "num_layers": 1 },
        "weights": [1.0,0.0, 0.0,0.0, 2.0,0.0, 0.0,0.0, 0.0,0.0,0.0,0.0, 0.0, 0.0, 3.0, 0.5]
    }"#;

    /// Parses `json` and builds a [`Model`] from it, panicking on any failure.
    fn load(json: &str) -> Model {
        let parsed = NamModel::from_json_str(json).unwrap();
        Model::from_nam(&parsed).unwrap()
    }

    /// Pushes a one-sample buffer holding `input` through `model` and returns the result.
    fn run_one(model: &mut Model, input: f32) -> f32 {
        let mut frame = [input];
        model.process_buffer(&mut frame);
        frame[0]
    }

    #[test]
    fn from_nam_builds_wavenet() {
        let mut model = load(TINY_WAVENET);
        assert!(matches!(model, Model::WaveNet(_)));
        let out = run_one(&mut model, 0.5);
        assert!((out - 10.0).abs() < 1e-5, "got {}", out);
    }

    #[test]
    fn from_nam_builds_lstm() {
        let mut model = load(TINY_LSTM);
        assert!(matches!(model, Model::Lstm(_)));
        let out = run_one(&mut model, 0.5);
        assert!((out - 1.1623).abs() < 1e-3, "got {}", out);
    }
}