use super::canonical::Mlp;
/// Audio encoder that flattens a fixed-size window of per-frame features into
/// `mel_dim * window_size` values and maps them through an [`Mlp`] to a
/// latent vector.
pub struct AudioEncoder {
    /// Feature values per frame (presumably mel bins, per the name — confirm
    /// against the feature extractor).
    pub mel_dim: usize,
    /// Frames per input window; together with `mel_dim` it fixes the
    /// flattened input length `mel_dim * window_size`.
    pub window_size: usize,
    /// Size of the MLP's final layer, i.e. the width of the encoded vector.
    pub latent_dim: usize,
    /// Network applied to the flattened window.
    pub mlp: Mlp,
}
impl AudioEncoder {
    /// Builds an encoder whose MLP has layer sizes
    /// `[mel_dim * window_size, 256, 128, latent_dim]`.
    ///
    /// # Panics
    ///
    /// Panics if `mel_dim * window_size` overflows `usize`.
    pub fn new(mel_dim: usize, window_size: usize, latent_dim: usize) -> Self {
        // Checked so an absurd configuration fails loudly in release builds too,
        // instead of silently wrapping to a tiny input layer.
        let input_dim = mel_dim
            .checked_mul(window_size)
            .expect("AudioEncoder::new: mel_dim * window_size overflows usize");
        Self {
            mlp: Mlp::new(&[input_dim, 256, 128, latent_dim]),
            window_size,
            mel_dim,
            latent_dim,
        }
    }

    /// Number of `f32` values [`encode`](Self::encode) expects in one
    /// flattened window (`mel_dim * window_size`).
    pub fn input_dim(&self) -> usize {
        self.mel_dim * self.window_size
    }

    /// Encodes one flattened window by running it through the MLP.
    ///
    /// # Panics
    ///
    /// Panics if `window.len()` differs from [`input_dim`](Self::input_dim).
    pub fn encode(&self, window: &[f32]) -> Vec<f32> {
        assert_eq!(
            window.len(),
            self.input_dim(),
            "AudioEncoder::encode: expected window_size ({}) * mel_dim ({}) values",
            self.window_size,
            self.mel_dim,
        );
        self.mlp.forward(window)
    }

    /// Parameter count, as reported by the inner MLP.
    pub fn num_params(&self) -> usize {
        self.mlp.num_params()
    }

    /// Flat copy of the inner MLP's parameters (layout defined by
    /// `Mlp::params_flat`).
    pub fn params_flat(&self) -> Vec<f32> {
        self.mlp.params_flat()
    }

    /// Overwrites the inner MLP's parameters from a flat slice. The expected
    /// length and layout are whatever `Mlp::set_params_flat` requires
    /// (presumably the same layout as [`params_flat`](Self::params_flat) —
    /// confirm in `Mlp`).
    pub fn set_params_flat(&mut self, params: &[f32]) {
        self.mlp.set_params_flat(params);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_audio_encoder() {
        let enc = AudioEncoder::new(80, 16, 64);
        let window = vec![0.0f32; 80 * 16];
        let latent = enc.encode(&window);
        assert_eq!(latent.len(), 64);
    }

    /// `encode` must reject a window whose length is not `window_size * mel_dim`.
    #[test]
    #[should_panic]
    fn test_audio_encoder_rejects_wrong_window_len() {
        let enc = AudioEncoder::new(80, 16, 64);
        let window = vec![0.0f32; 80 * 16 - 1];
        let _ = enc.encode(&window);
    }
}