use svod_dtype::DType;
use svod_ir::{ConstValue, SInt};
use crate::Tensor;
use crate::test::helpers::*;
crate::codegen_tests! {
// Two unrelated constant tensors (lengths 4 and 3) are realized by a single
// batch call; each must come back with its own fill value.
fn test_batch_two_independent(config) {
    test_setup();

    let mut twos = Tensor::full(&[4], 2.0f32, DType::Float32).unwrap();
    let mut threes = Tensor::full(&[3], 3.0f32, DType::Float32).unwrap();

    let batch = [&mut twos, &mut threes];
    Tensor::realize_batch_with(batch, &config).unwrap();

    let twos_data = twos.as_vec::<f32>().unwrap();
    assert_eq!(twos_data, vec![2.0; 4]);

    let threes_data = threes.as_vec::<f32>().unwrap();
    assert_eq!(threes_data, vec![3.0; 3]);
}
// One source tensor feeds two different element-wise expressions
// (add 10, multiply by 2) that are realized together in one batch.
fn test_batch_shared_input(config) {
    test_setup();

    let shared = Tensor::from_slice([1.0f32, 2.0, 3.0]);
    let offset = Tensor::full(&[3], 10.0f32, DType::Float32).unwrap();
    let factor = Tensor::full(&[3], 2.0f32, DType::Float32).unwrap();

    let mut shifted = &shared + &offset;
    let mut scaled = &shared * &factor;

    let batch = [&mut shifted, &mut scaled];
    Tensor::realize_batch_with(batch, &config).unwrap();

    let shifted_data = shifted.as_vec::<f32>().unwrap();
    assert_close_f32(&shifted_data, &[11.0, 12.0, 13.0], 1e-6);

    let scaled_data = scaled.as_vec::<f32>().unwrap();
    assert_close_f32(&scaled_data, &[2.0, 4.0, 6.0], 1e-6);
}
// A batch holding exactly one tensor goes through the batch entry point.
fn test_batch_single(config) {
    test_setup();

    let mut sevens = Tensor::full(&[5], 7.0f32, DType::Float32).unwrap();

    let batch = [&mut sevens];
    Tensor::realize_batch_with(batch, &config).unwrap();

    let sevens_data = sevens.as_vec::<f32>().unwrap();
    assert_eq!(sevens_data, vec![7.0; 5]);
}
// The batch mixes a tensor that was already realized on its own with one
// that has not been realized yet; both must read back correctly afterwards.
fn test_batch_mixed_realized(config) {
    test_setup();

    // First tensor is realized individually before it joins the batch.
    let mut already_done = Tensor::full(&[3], 1.0f32, DType::Float32).unwrap();
    already_done.realize_with(&config).unwrap();

    let mut pending = Tensor::full(&[3], 2.0f32, DType::Float32).unwrap();

    let batch = [&mut already_done, &mut pending];
    Tensor::realize_batch_with(batch, &config).unwrap();

    let done_data = already_done.as_vec::<f32>().unwrap();
    assert_eq!(done_data, vec![1.0; 3]);

    let pending_data = pending.as_vec::<f32>().unwrap();
    assert_eq!(pending_data, vec![2.0; 3]);
}
// A Float32 tensor and an Int32 tensor (built via `full_dynamic`) are
// realized in the same batch and read back with their respective types.
fn test_batch_mixed_dtypes(config) {
    test_setup();

    let mut floats = Tensor::full(&[3], 5.0f32, DType::Float32).unwrap();
    let mut ints = Tensor::full_dynamic(
        &[SInt::from(4)],
        ConstValue::Int(42),
        DType::Int32,
    )
    .unwrap();

    let batch = [&mut floats, &mut ints];
    Tensor::realize_batch_with(batch, &config).unwrap();

    let float_data = floats.as_vec::<f32>().unwrap();
    assert_eq!(float_data, vec![5.0; 3]);

    let int_data = ints.as_vec::<i32>().unwrap();
    assert_eq!(int_data, vec![42; 4]);
}
// Three constant tensors of lengths 2, 3 and 4 are realized in one batch.
fn test_batch_three_outputs(config) {
    test_setup();

    let mut ones = Tensor::full(&[2], 1.0f32, DType::Float32).unwrap();
    let mut twos = Tensor::full(&[3], 2.0f32, DType::Float32).unwrap();
    let mut threes = Tensor::full(&[4], 3.0f32, DType::Float32).unwrap();

    let batch = [&mut ones, &mut twos, &mut threes];
    Tensor::realize_batch_with(batch, &config).unwrap();

    let ones_data = ones.as_vec::<f32>().unwrap();
    assert_eq!(ones_data, vec![1.0; 2]);

    let twos_data = twos.as_vec::<f32>().unwrap();
    assert_eq!(twos_data, vec![2.0; 3]);

    let threes_data = threes.as_vec::<f32>().unwrap();
    assert_eq!(threes_data, vec![3.0; 4]);
}
// Both an element-wise product and the sum computed from that product are
// requested as outputs of the same batch.
fn test_batch_shared_with_reduce(config) {
    test_setup();

    let input = Tensor::from_slice([1.0f32, 2.0, 3.0, 4.0]);
    let factor = Tensor::full(&[4], 2.0f32, DType::Float32).unwrap();

    let mut product = &input * &factor;
    // The reduction is built on top of `product`, which is itself an output.
    let mut total = product.sum(()).unwrap();

    let batch = [&mut product, &mut total];
    Tensor::realize_batch_with(batch, &config).unwrap();

    let product_data = product.as_vec::<f32>().unwrap();
    assert_close_f32(&product_data, &[2.0, 4.0, 6.0, 8.0], 1e-6);

    let total_data = total.as_vec::<f32>().unwrap();
    assert_close_f32(&total_data, &[20.0], 1e-5);
}
// `squared` uses `shifted` as both of its operands, and `shifted` is also a
// batch output in its own right.
fn test_batch_diamond_output(config) {
    test_setup();

    let input = Tensor::from_slice([1.0f32, 2.0, 3.0]);
    let offset = Tensor::full(&[3], 10.0f32, DType::Float32).unwrap();

    let mut shifted = &input + &offset;
    let mut squared = &shifted * &shifted;

    let batch = [&mut shifted, &mut squared];
    Tensor::realize_batch_with(batch, &config).unwrap();

    let shifted_data = shifted.as_vec::<f32>().unwrap();
    assert_close_f32(&shifted_data, &[11.0, 12.0, 13.0], 1e-6);

    let squared_data = squared.as_vec::<f32>().unwrap();
    assert_close_f32(&squared_data, &[121.0, 144.0, 169.0], 1e-6);
}
// Two-step path: build a plan with `prepare_batch_with`, run it with
// `execute`, then check the plan's output count and both results.
fn test_prepare_batch_execute(config) {
    test_setup();

    let mut added =
        &Tensor::from_slice([1.0f32, 2.0]) + &Tensor::from_slice([3.0f32, 4.0]);
    let mut multiplied =
        &Tensor::from_slice([10.0f32, 20.0]) * &Tensor::from_slice([2.0f32, 3.0]);

    let plan = Tensor::prepare_batch_with([&mut added, &mut multiplied], &config).unwrap();
    plan.execute().unwrap();
    assert_eq!(plan.num_outputs(), 2);

    let added_data = added.as_vec::<f32>().unwrap();
    assert_close_f32(&added_data, &[4.0, 6.0], 1e-6);

    let multiplied_data = multiplied.as_vec::<f32>().unwrap();
    assert_close_f32(&multiplied_data, &[20.0, 60.0], 1e-6);
}
}
/// Passing an empty iterator of tensors to `Tensor::realize_batch` must
/// return `Ok` (the result is unwrapped).
#[test]
fn test_batch_empty() {
    let no_tensors = std::iter::empty::<&mut Tensor>();
    Tensor::realize_batch(no_tensors).unwrap();
}
/// Realizes three constant tensors through `Tensor::realize_batch` and checks
/// that each of the three outputs can be read back with its expected contents.
///
/// Previously this test only unwrapped the batch result, so it could not
/// detect a batch that left one of its outputs missing or holding wrong data.
#[test]
fn test_batch_output_count() {
    // Same setup call the other realizing tests in this file make first.
    test_setup();

    let mut a = Tensor::full(&[2], 1.0f32, DType::Float32).unwrap();
    let mut b = Tensor::full(&[3], 2.0f32, DType::Float32).unwrap();
    let mut c = Tensor::full(&[4], 3.0f32, DType::Float32).unwrap();

    Tensor::realize_batch([&mut a, &mut b, &mut c]).unwrap();

    // Every tensor in the batch must be readable with the right length and value.
    assert_eq!(a.as_vec::<f32>().unwrap(), vec![1.0; 2]);
    assert_eq!(b.as_vec::<f32>().unwrap(), vec![2.0; 3]);
    assert_eq!(c.as_vec::<f32>().unwrap(), vec![3.0; 4]);
}