use crate::error::{Error, Result};
use crate::nn::BiLstm;
use numr::dtype::DType;
use numr::ops::{
ActivationOps, BinaryOps, MatmulOps, ReduceOps, ScalarOps, TensorOps, UnaryOps, UtilityOps,
};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
/// Per-frame scalar predictor: a bidirectional LSTM followed by a linear
/// projection that yields one value per frame.
///
/// `forward` appends a per-batch-item style vector to every frame, runs the
/// result through the BiLSTM and projects each timestep down to a single
/// scalar, returning a `[B, T]` tensor.
pub struct FramePredictor<R: Runtime> {
/// Bidirectional LSTM applied to the concatenated `[frames, style]` input.
/// `new` requires `2 * lstm.hidden_size() == d_model`.
lstm: BiLstm<R>,
/// Projection weight; `new` requires shape `[1, d_model]`.
proj_weight: Tensor<R>,
/// Projection bias; `new` requires shape `[1]`.
proj_bias: Tensor<R>,
/// Width of the last axis of `frames` accepted by `forward`, and the BiLSTM
/// output width.
d_model: usize,
/// Width of the last axis of `style` accepted by `forward`.
style_dim: usize,
}
impl<R: Runtime> FramePredictor<R> {
pub fn new(
lstm: BiLstm<R>,
proj_weight: Tensor<R>,
proj_bias: Tensor<R>,
d_model: usize,
style_dim: usize,
) -> Result<Self> {
if 2 * lstm.hidden_size() != d_model {
return Err(Error::InvalidArgument {
arg: "lstm",
reason: format!(
"BiLSTM output width must equal d_model ({d_model}), got 2 * {}",
lstm.hidden_size()
),
});
}
if proj_weight.shape() != [1, d_model] {
return Err(Error::InvalidArgument {
arg: "proj_weight",
reason: format!("expected [1, {d_model}], got {:?}", proj_weight.shape()),
});
}
if proj_bias.shape() != [1] {
return Err(Error::InvalidArgument {
arg: "proj_bias",
reason: format!("expected [1], got {:?}", proj_bias.shape()),
});
}
Ok(Self {
lstm,
proj_weight,
proj_bias,
d_model,
style_dim,
})
}
pub fn d_model(&self) -> usize {
self.d_model
}
pub fn style_dim(&self) -> usize {
self.style_dim
}
pub fn forward<C>(&self, client: &C, frames: &Tensor<R>, style: &Tensor<R>) -> Result<Tensor<R>>
where
R: Runtime<DType = DType>,
C: RuntimeClient<R>
+ MatmulOps<R>
+ TensorOps<R>
+ BinaryOps<R>
+ UnaryOps<R>
+ ActivationOps<R>
+ ReduceOps<R>
+ ScalarOps<R>
+ UtilityOps<R>,
{
let h_shape = frames.shape();
if h_shape.len() != 3 || h_shape[2] != self.d_model {
return Err(Error::InvalidArgument {
arg: "frames",
reason: format!("expected [B, T, {}], got {h_shape:?}", self.d_model),
});
}
let s_shape = style.shape();
if s_shape != [h_shape[0], self.style_dim] {
return Err(Error::InvalidArgument {
arg: "style",
reason: format!(
"expected [{}, {}], got {s_shape:?}",
h_shape[0], self.style_dim
),
});
}
let (b, t, _) = (h_shape[0], h_shape[1], h_shape[2]);
let style_bc = style
.reshape(&[b, 1, self.style_dim])
.map_err(Error::Numr)?
.broadcast_to(&[b, t, self.style_dim])
.map_err(Error::Numr)?
.contiguous()?;
let cat = client.cat(&[frames, &style_bc], 2).map_err(Error::Numr)?;
let lstm_out = self.lstm.forward(client, &cat)?;
let flat = lstm_out
.reshape(&[b * t, self.d_model])
.map_err(Error::Numr)?;
let w_t = self.proj_weight.transpose(0, 1).map_err(Error::Numr)?;
let proj = client
.matmul_bias(&flat, &w_t, &self.proj_bias)
.map_err(Error::Numr)?;
proj.reshape(&[b, t]).map_err(Error::Numr)
}
}
/// Alias of [`FramePredictor`]; the type is identical and only the name
/// signals that the per-frame scalar is intended to be pitch.
pub type PitchPredictor<R> = FramePredictor<R>;
/// Alias of [`FramePredictor`]; the type is identical and only the name
/// signals that the per-frame scalar is intended to be energy.
pub type EnergyPredictor<R> = FramePredictor<R>;
#[cfg(test)]
mod tests {
    use super::*;
    use crate::nn::{BiLstm, Lstm};
    use crate::test_utils::cpu_setup;
    use numr::runtime::cpu::CpuRuntime;

    /// Frame feature width used by the default fixture.
    const D_MODEL: usize = 4;
    /// Style feature width used by the default fixture.
    const STYLE_DIM: usize = 3;

    /// Allocates an all-zero `f32` tensor of the given shape.
    fn zeros(shape: &[usize], device: &<CpuRuntime as Runtime>::Device) -> Tensor<CpuRuntime> {
        let n: usize = shape.iter().product();
        Tensor::<CpuRuntime>::from_slice(&vec![0.0f32; n], shape, device)
    }

    /// Builds a zero-weight BiLSTM whose two directions share the same
    /// `input` and `hidden` sizes.
    fn zero_bilstm(
        input: usize,
        hidden: usize,
        device: &<CpuRuntime as Runtime>::Device,
    ) -> BiLstm<CpuRuntime> {
        // Each direction takes the four gate blocks stacked on the first axis.
        let direction = || {
            Lstm::new(
                zeros(&[4 * hidden, input], device),
                zeros(&[4 * hidden, hidden], device),
                zeros(&[4 * hidden], device),
                zeros(&[4 * hidden], device),
            )
            .unwrap()
        };
        BiLstm::new(direction(), direction()).unwrap()
    }

    /// Builds a valid predictor with `D_MODEL` / `STYLE_DIM` and zero weights.
    fn build(device: &<CpuRuntime as Runtime>::Device) -> FramePredictor<CpuRuntime> {
        FramePredictor::new(
            zero_bilstm(D_MODEL + STYLE_DIM, D_MODEL / 2, device),
            zeros(&[1, D_MODEL], device),
            zeros(&[1], device),
            D_MODEL,
            STYLE_DIM,
        )
        .unwrap()
    }

    #[test]
    fn forward_shape_is_b_t() {
        let (client, device) = cpu_setup();
        let pred = build(&device);
        let frames = zeros(&[2, 7, 4], &device);
        let style = zeros(&[2, 3], &device);
        let out = pred.forward(&client, &frames, &style).unwrap();
        assert_eq!(out.shape(), &[2, 7]);
    }

    #[test]
    fn accessors_report_configured_dims() {
        let (_client, device) = cpu_setup();
        let pred = build(&device);
        assert_eq!(pred.d_model(), D_MODEL);
        assert_eq!(pred.style_dim(), STYLE_DIM);
    }

    #[test]
    fn pitch_and_energy_aliases_compile() {
        let (_client, device) = cpu_setup();
        let _pitch: PitchPredictor<CpuRuntime> = build(&device);
        let _energy: EnergyPredictor<CpuRuntime> = build(&device);
    }

    #[test]
    fn rejects_wrong_frames_rank() {
        let (client, device) = cpu_setup();
        let pred = build(&device);
        let frames = zeros(&[2, 4], &device);
        let style = zeros(&[2, 3], &device);
        assert!(pred.forward(&client, &frames, &style).is_err());
    }

    #[test]
    fn rejects_wrong_frames_width() {
        let (client, device) = cpu_setup();
        let pred = build(&device);
        // Rank is correct but the feature axis is D_MODEL + 1.
        let frames = zeros(&[2, 7, D_MODEL + 1], &device);
        let style = zeros(&[2, 3], &device);
        assert!(pred.forward(&client, &frames, &style).is_err());
    }

    #[test]
    fn rejects_wrong_style_shape() {
        let (client, device) = cpu_setup();
        let pred = build(&device);
        let frames = zeros(&[1, 4, 4], &device);
        let style = zeros(&[1, 5], &device);
        assert!(pred.forward(&client, &frames, &style).is_err());
    }

    #[test]
    fn rejects_style_batch_mismatch() {
        let (client, device) = cpu_setup();
        let pred = build(&device);
        // Style width is right, but its batch size differs from the frames'.
        let frames = zeros(&[2, 4, 4], &device);
        let style = zeros(&[3, 3], &device);
        assert!(pred.forward(&client, &frames, &style).is_err());
    }

    #[test]
    fn new_rejects_lstm_width_mismatch() {
        let (_client, device) = cpu_setup();
        // 2 * 3 = 6 does not match d_model = 4.
        let bi = zero_bilstm(7, 3, &device);
        assert!(
            FramePredictor::new(bi, zeros(&[1, 4], &device), zeros(&[1], &device), 4, 3).is_err()
        );
    }

    #[test]
    fn new_rejects_bad_proj_weight_shape() {
        let (_client, device) = cpu_setup();
        let bi = zero_bilstm(D_MODEL + STYLE_DIM, D_MODEL / 2, &device);
        let result = FramePredictor::new(
            bi,
            zeros(&[2, D_MODEL], &device),
            zeros(&[1], &device),
            D_MODEL,
            STYLE_DIM,
        );
        assert!(result.is_err());
    }

    #[test]
    fn new_rejects_bad_proj_bias_shape() {
        let (_client, device) = cpu_setup();
        let bi = zero_bilstm(D_MODEL + STYLE_DIM, D_MODEL / 2, &device);
        let result = FramePredictor::new(
            bi,
            zeros(&[1, D_MODEL], &device),
            zeros(&[2], &device),
            D_MODEL,
            STYLE_DIM,
        );
        assert!(result.is_err());
    }
}