use svod_dtype::DType;
use svod_tensor::Tensor;
use crate::wespeaker::WeSpeakerResNet34;
#[test]
fn forward_zero_weights_shape() {
let model = WeSpeakerResNet34::with_zero_weights(crate::wespeaker::WeSpeakerConfig::new().with_max_batch_size(1));
let feats = Tensor::zeros(&[1, 1598, 80], DType::Float32).unwrap();
let weights = Tensor::ones(&[1, 799], DType::Float32).unwrap();
let var = svod_tensor::Variable::new("b", 1, 1);
let b = var.bind(1).unwrap();
let out = model.forward(&feats, &weights, &b).unwrap();
let shape: Vec<usize> = out
.shape()
.unwrap()
.iter()
.map(|s| s.as_const().or_else(|| s.vmax()).expect("concrete or symbolic-max shape"))
.collect();
assert_eq!(shape, vec![1, 256]);
}
#[test]
#[ignore = "heavy: full WeSpeaker ResNet34 graph compile through the CPU backend"]
fn forward_zero_weights_realize() {
let model = WeSpeakerResNet34::with_zero_weights(crate::wespeaker::WeSpeakerConfig::new().with_max_batch_size(1));
let feats = Tensor::zeros(&[1, 1598, 80], DType::Float32).unwrap();
let weights = Tensor::ones(&[1, 799], DType::Float32).unwrap();
let var = svod_tensor::Variable::new("b", 1, 1);
let b = var.bind(1).unwrap();
let mut out = model.forward(&feats, &weights, &b).unwrap();
out.realize().unwrap();
let shape: Vec<usize> = out
.shape()
.unwrap()
.iter()
.map(|s| s.as_const().or_else(|| s.vmax()).expect("concrete or symbolic-max shape"))
.collect();
assert_eq!(shape, vec![1, 256]);
}