use snafu::ResultExt;
use svod_ir::SInt;
use svod_tensor::Tensor;
use super::error::{Result, TensorSnafu};
/// Small stabilising constant. `tstp_forward` adds it to the summed weights
/// and to the variance denominator before dividing by them.
const EPS: f64 = 1e-8;
pub fn tstp_forward(features: &Tensor, weights: &Tensor) -> Result<Tensor> {
let shape = features.shape().context(TensorSnafu)?;
if shape.len() != 4 {
return Err(super::error::Error::Tensor {
source: Box::new(svod_tensor::error::Error::IrConstruction {
details: format!("TSTP expects 4D features, got {}D", shape.len()),
}),
});
}
let t_back = shape[3].as_const().ok_or_else(|| super::error::Error::Tensor {
source: Box::new(svod_tensor::error::Error::IrConstruction {
details: "TSTP requires concrete T (backbone time dim) — symbolic not supported".into(),
}),
})?;
let t_w = weights.shape().context(TensorSnafu)?[1].as_const().ok_or_else(|| super::error::Error::Tensor {
source: Box::new(svod_tensor::error::Error::IrConstruction {
details: "TSTP requires concrete T_w (weight time dim)".into(),
}),
})?;
let mat = nearest_interp_matrix(t_w, t_back);
let w = weights.linear().weight(&mat).call().context(TensorSnafu)?;
let w = w.try_unsqueeze(1).context(TensorSnafu)?;
let w = w.try_unsqueeze(2).context(TensorSnafu)?;
let dtype = features.uop().dtype();
let eps = Tensor::const_(EPS, dtype.clone());
let v1_raw = w.sum_with().axes(3isize).keepdim(true).call().context(TensorSnafu)?;
let v1 = v1_raw.try_add(&eps).context(TensorSnafu)?;
let xw = features.try_mul(&w).context(TensorSnafu)?;
let xw_sum = xw.sum_with().axes(3isize).keepdim(true).call().context(TensorSnafu)?;
let mean = xw_sum.try_div(&v1).context(TensorSnafu)?;
let centered = features.try_sub(&mean).context(TensorSnafu)?;
let dx2 = centered.square().context(TensorSnafu)?;
let w_sq = w.square().context(TensorSnafu)?;
let v2 = w_sq.sum_with().axes(3isize).keepdim(true).call().context(TensorSnafu)?;
let denom = v1.try_sub(&v2.try_div(&v1).context(TensorSnafu)?).context(TensorSnafu)?;
let denom = denom.try_add(&eps).context(TensorSnafu)?;
let var_num = dx2.try_mul(&w).context(TensorSnafu)?;
let var_num = var_num.sum_with().axes(3isize).keepdim(true).call().context(TensorSnafu)?;
let var = var_num.try_div(&denom).context(TensorSnafu)?;
let std = var.try_sqrt().context(TensorSnafu)?;
let b = shape[0].clone();
let c = shape[1].as_const().ok_or_else(|| super::error::Error::Tensor {
source: Box::new(svod_tensor::error::Error::IrConstruction {
details: "TSTP requires concrete C (channel dim)".into(),
}),
})?;
let h = shape[2].as_const().ok_or_else(|| super::error::Error::Tensor {
source: Box::new(svod_tensor::error::Error::IrConstruction {
details: "TSTP requires concrete H (freq dim)".into(),
}),
})?;
let stats_dim = SInt::Const(c * h);
let mean_flat = mean.try_reshape([b.clone(), stats_dim.clone()]).context(TensorSnafu)?;
let std_flat = std.try_reshape([b, stats_dim]).context(TensorSnafu)?;
Tensor::cat(&[&mean_flat, &std_flat], 1).context(TensorSnafu)
}
/// Builds a `[t_out, t_in]` one-hot selection matrix as an `f32` tensor.
///
/// Row `r` contains a single `1.0` at column `r * t_in / t_out` (integer
/// division) and `0.0` everywhere else.
///
/// # Panics
///
/// Panics if the final reshape fails, or if a row's selected column falls
/// outside the buffer (e.g. `t_in == 0` while `t_out > 0`).
fn nearest_interp_matrix(t_in: usize, t_out: usize) -> Tensor {
    // Row-major backing buffer, initially all zeros.
    let mut cells = vec![0.0f32; t_out * t_in];
    (0..t_out).for_each(|row| {
        let picked = row * t_in / t_out;
        cells[row * t_in + picked] = 1.0;
    });

    let dims = [t_out as isize, t_in as isize];
    Tensor::from_slice(&cells)
        .try_reshape(dims)
        .expect("nearest interp matrix reshape")
}