use crate::error::Error;
use crate::lstm::Lstm;
use crate::model::{ModelConfig, NamModel};
use crate::wavenet::WaveNet;
#[non_exhaustive]
#[derive(Debug)]
pub enum Model {
WaveNet(Box<WaveNet>),
Lstm(Lstm),
Slimmable(Slimmable),
}
#[derive(Debug)]
pub struct Slimmable {
submodels: Vec<Model>,
max_values: Vec<f32>,
active: usize,
}
impl Slimmable {
pub fn len(&self) -> usize {
self.submodels.len()
}
pub fn is_empty(&self) -> bool {
self.submodels.is_empty()
}
pub fn active_index(&self) -> usize {
self.active
}
pub fn select(&mut self, index: usize) {
self.active = index.min(self.submodels.len() - 1);
}
pub fn set_slim_size(&mut self, value: f32) {
self.active = self
.max_values
.iter()
.position(|&m| m > value)
.unwrap_or(self.submodels.len() - 1);
}
}
impl Model {
pub fn from_nam(model: &NamModel) -> Result<Self, Error> {
match &model.config {
ModelConfig::WaveNet(_) => Ok(Model::WaveNet(Box::new(WaveNet::new(model)?))),
ModelConfig::Lstm(_) => Ok(Model::Lstm(Lstm::new(model)?)),
ModelConfig::Slimmable(cfg) => {
if cfg.submodels.is_empty() {
return Err(Error::UnsupportedFeature("empty SlimmableContainer".into()));
}
let mut submodels = Vec::with_capacity(cfg.submodels.len());
let mut max_values = Vec::with_capacity(cfg.submodels.len());
for sm in &cfg.submodels {
submodels.push(Model::from_nam(&sm.model)?);
max_values.push(sm.max_value);
}
let active = submodels.len() - 1; Ok(Model::Slimmable(Slimmable {
submodels,
max_values,
active,
}))
}
}
}
pub(crate) fn from_nam_conditioning(model: &NamModel) -> Result<Self, Error> {
match &model.config {
ModelConfig::WaveNet(_) => {
Ok(Model::WaveNet(Box::new(WaveNet::new_conditioning(model)?)))
}
_ => Model::from_nam(model),
}
}
pub fn process_buffer(&mut self, io: &mut [f32]) {
match self {
Model::WaveNet(w) => w.process_buffer(io),
Model::Lstm(l) => l.process_buffer(io),
Model::Slimmable(s) => s.submodels[s.active].process_buffer(io),
}
}
pub fn process_sample(&mut self, x: f32) -> f32 {
match self {
Model::WaveNet(w) => w.process_sample(x),
Model::Lstm(l) => l.process_sample(x),
Model::Slimmable(s) => s.submodels[s.active].process_sample(x),
}
}
pub fn reset(&mut self) {
match self {
Model::WaveNet(w) => w.reset(),
Model::Lstm(l) => l.reset(),
Model::Slimmable(s) => s.submodels.iter_mut().for_each(Model::reset),
}
}
pub fn receptive_field(&self) -> usize {
match self {
Model::WaveNet(w) => w.receptive_field(),
Model::Lstm(_) => 0,
Model::Slimmable(s) => s.submodels[s.active].receptive_field(),
}
}
pub(crate) fn num_output_channels(&self) -> usize {
match self {
Model::WaveNet(w) => w.num_output_channels(),
Model::Lstm(_) => 1,
Model::Slimmable(s) => s.submodels[s.active].num_output_channels(),
}
}
pub(crate) fn process_block_multi(&mut self, input: &[f32], out: &mut [f32], n: usize) {
match self {
Model::WaveNet(w) => w.process_block_multi(input, out, n),
Model::Lstm(l) => {
out[..n].copy_from_slice(&input[..n]);
l.process_buffer(&mut out[..n]);
}
Model::Slimmable(s) => s.submodels[s.active].process_block_multi(input, out, n),
}
}
pub fn as_slimmable(&self) -> Option<&Slimmable> {
match self {
Model::Slimmable(s) => Some(s),
_ => None,
}
}
pub fn as_slimmable_mut(&mut self) -> Option<&mut Slimmable> {
match self {
Model::Slimmable(s) => Some(s),
_ => None,
}
}
}
const _: () = {
fn assert_send_sync<T: Send + Sync>() {}
let _ = assert_send_sync::<Model>;
let _ = assert_send_sync::<WaveNet>;
let _ = assert_send_sync::<Lstm>;
let _ = assert_send_sync::<Slimmable>;
};
#[cfg(test)]
mod tests {
use super::*;
const TINY_WAVENET: &str = r#"{
"version": "0.5.4", "architecture": "WaveNet",
"config": { "layers": [{
"input_size": 1, "condition_size": 1, "channels": 1, "head_size": 1,
"kernel_size": 1, "dilations": [1], "activation": "ReLU",
"gated": false, "head_bias": false
}], "head": null, "head_scale": 10.0 },
"weights": [1.0, 2.0, 0.5, 1.0, 3.0, 0.1, 0.5, 10.0]
}"#;
const TINY_LSTM: &str = r#"{
"version": "0.5.4", "architecture": "LSTM",
"config": { "input_size": 1, "hidden_size": 1, "num_layers": 1 },
"weights": [1.0,0.0, 0.0,0.0, 2.0,0.0, 0.0,0.0, 0.0,0.0,0.0,0.0, 0.0, 0.0, 3.0, 0.5]
}"#;
#[test]
fn from_nam_builds_wavenet() {
let m = NamModel::from_json_str(TINY_WAVENET).unwrap();
let mut model = Model::from_nam(&m).unwrap();
assert!(matches!(model, Model::WaveNet(_)));
let mut buf = [0.5_f32];
model.process_buffer(&mut buf);
assert!((buf[0] - 10.0).abs() < 1e-5, "got {}", buf[0]);
}
#[test]
fn receptive_field_zero_for_lstm_warmup_for_wavenet() {
let wn = Model::from_nam(&NamModel::from_json_str(TINY_WAVENET).unwrap()).unwrap();
assert_eq!(wn.receptive_field(), 1);
let lstm = Model::from_nam(&NamModel::from_json_str(TINY_LSTM).unwrap()).unwrap();
assert_eq!(lstm.receptive_field(), 0);
}
#[test]
fn from_nam_builds_lstm() {
let m = NamModel::from_json_str(TINY_LSTM).unwrap();
let mut model = Model::from_nam(&m).unwrap();
assert!(matches!(model, Model::Lstm(_)));
let mut buf = [0.5_f32];
model.process_buffer(&mut buf);
assert!((buf[0] - 1.1623).abs() < 1e-3, "got {}", buf[0]);
}
fn container() -> Model {
let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
.join("tests/fixtures/slimmable_container.nam");
let json = std::fs::read_to_string(path).expect("read container");
let m = NamModel::from_json_str(&json).expect("parse container");
Model::from_nam(&m).expect("build container")
}
#[test]
fn from_nam_builds_slimmable_default_full() {
let mut model = container();
let s = model.as_slimmable_mut().expect("is slimmable");
assert_eq!(s.len(), 3);
assert_eq!(s.active_index(), 2, "default = last/full submodel");
}
#[test]
fn select_clamps_out_of_range() {
let mut model = container();
let s = model.as_slimmable_mut().unwrap();
s.select(0);
assert_eq!(s.active_index(), 0);
s.select(99);
assert_eq!(s.active_index(), 2, "clamped to last");
}
#[test]
fn set_slim_size_picks_first_threshold_above_value() {
let mut model = container();
let s = model.as_slimmable_mut().unwrap();
s.set_slim_size(0.0);
assert_eq!(s.active_index(), 0); s.set_slim_size(0.5);
assert_eq!(s.active_index(), 1); s.set_slim_size(0.99);
assert_eq!(s.active_index(), 2); s.set_slim_size(5.0);
assert_eq!(s.active_index(), 2); }
#[test]
fn reset_clears_all_submodels_not_just_active() {
let mut model = container();
let mut fresh = container();
fresh.as_slimmable_mut().unwrap().select(0);
let mut probe_fresh = vec![0.3_f32; 8];
fresh.process_buffer(&mut probe_fresh);
model.as_slimmable_mut().unwrap().select(0);
let mut warm = vec![0.5_f32; 16];
model.process_buffer(&mut warm);
model.as_slimmable_mut().unwrap().select(2);
model.reset();
model.as_slimmable_mut().unwrap().select(0);
let mut probe = vec![0.3_f32; 8];
model.process_buffer(&mut probe);
for (i, (got, want)) in probe.iter().zip(&probe_fresh).enumerate() {
assert!(
(got - want).abs() < 1e-6,
"reset left submodel 0 dirty at sample {i}: {got} vs fresh {want}"
);
}
}
#[test]
fn slimmable_processes_through_active_submodel() {
let mut model = container();
model.as_slimmable_mut().unwrap().select(0); let mut a = vec![0.1_f32; 32];
model.process_buffer(&mut a);
model.as_slimmable_mut().unwrap().select(2); let mut b = vec![0.1_f32; 32];
model.process_buffer(&mut b);
}
#[test]
fn as_slimmable_none_for_plain_models() {
let mut wn = Model::from_nam(&NamModel::from_json_str(TINY_WAVENET).unwrap()).unwrap();
assert!(wn.as_slimmable_mut().is_none());
}
}