use crate::dtype::Float;
use crate::error::{FerrotorchError, FerrotorchResult};
use crate::storage::TensorStorage;
use crate::tensor::Tensor;
/// Returns a new tensor selecting the slice at `index` along dimension `dim`,
/// removing that dimension from the output shape (analogous to PyTorch's
/// `Tensor.select`).
///
/// # Errors
/// * `InvalidArgument` if `dim >= input.ndim()`.
/// * `IndexOutOfBounds` if `index >= shape[dim]`.
pub fn select<T: Float>(
    input: &Tensor<T>,
    dim: usize,
    index: usize,
) -> FerrotorchResult<Tensor<T>> {
    let shape = input.shape();
    let ndim = shape.len();
    if dim >= ndim {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "select: dim {} is out of bounds for tensor with {} dimensions",
                dim, ndim
            ),
        });
    }
    if index >= shape[dim] {
        return Err(FerrotorchError::IndexOutOfBounds {
            index,
            axis: dim,
            size: shape[dim],
        });
    }
    let mut out_shape: Vec<usize> = shape.to_vec();
    out_shape.remove(dim);
    let data = input.data()?;
    // Row-major (C-contiguous) layout: view the shape as
    // [outer dims] x [selected dim] x [inner dims].
    let outer: usize = shape[..dim].iter().product();
    // Product of an empty slice is 1, so this also covers dim == ndim - 1;
    // no special case needed.
    let inner: usize = shape[dim + 1..].iter().product();
    let dim_size = shape[dim];
    let mut out_data = Vec::with_capacity(outer * inner);
    for o in 0..outer {
        let src_base = o * dim_size * inner + index * inner;
        // Copy each contiguous inner run in one shot instead of pushing
        // element by element.
        out_data.extend_from_slice(&data[src_base..src_base + inner]);
    }
    Tensor::from_storage(TensorStorage::cpu(out_data), out_shape, false)
}
/// Stacks `tensors` along a new dimension `dim` (analogous to `torch.stack`).
///
/// All tensors must share the same shape; the output shape is that common
/// shape with `tensors.len()` inserted at position `dim`.
///
/// # Errors
/// * `InvalidArgument` if the list is empty or `dim > ndim`.
/// * `ShapeMismatch` if any tensor's shape differs from the first one's.
pub fn stack<T: Float>(tensors: &[Tensor<T>], dim: usize) -> FerrotorchResult<Tensor<T>> {
    if tensors.is_empty() {
        return Err(FerrotorchError::InvalidArgument {
            message: "stack: empty tensor list".into(),
        });
    }
    let base_shape = tensors[0].shape();
    let base_ndim = base_shape.len();
    // dim == base_ndim is valid: the new axis is appended at the end.
    if dim > base_ndim {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "stack: dim {} is out of bounds for tensors with {} dimensions (max = {})",
                dim, base_ndim, base_ndim
            ),
        });
    }
    for (i, t) in tensors.iter().enumerate().skip(1) {
        if t.shape() != base_shape {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "stack: tensor {} has shape {:?}, expected {:?}",
                    i,
                    t.shape(),
                    base_shape
                ),
            });
        }
    }
    let n = tensors.len();
    let mut out_shape = base_shape.to_vec();
    out_shape.insert(dim, n);
    // Row-major layout: everything before `dim` is "outer", everything from
    // `dim` onward is "inner". Empty-slice products are 1, so both edge cases
    // (dim == 0 and dim == base_ndim) fall out naturally without a branch.
    let outer: usize = base_shape[..dim].iter().product();
    let inner: usize = base_shape[dim..].iter().product();
    let out_numel: usize = out_shape.iter().product();
    let mut out_data = vec![<T as num_traits::Zero>::zero(); out_numel];
    for (t_idx, t) in tensors.iter().enumerate() {
        let t_data = t.data()?;
        for o in 0..outer {
            // Tensor t_idx occupies the t_idx-th inner-sized chunk inside each
            // outer group of the output.
            let dst_base = o * n * inner + t_idx * inner;
            let src_base = o * inner;
            out_data[dst_base..dst_base + inner]
                .copy_from_slice(&t_data[src_base..src_base + inner]);
        }
    }
    Tensor::from_storage(TensorStorage::cpu(out_data), out_shape, false)
}
/// Vectorizing map: lifts `f` (a function on one slice) to a function that
/// applies it independently to every slice of the input along `in_dim` and
/// stacks the per-slice results along `out_dim`.
pub fn vmap<T, F>(
    f: F,
    in_dim: usize,
    out_dim: usize,
) -> impl Fn(&Tensor<T>) -> FerrotorchResult<Tensor<T>>
where
    T: Float,
    F: Fn(&Tensor<T>) -> FerrotorchResult<Tensor<T>>,
{
    move |input: &Tensor<T>| {
        let ndim = input.shape().len();
        if in_dim >= ndim {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "vmap: in_dim {} is out of bounds for tensor with {} dimensions",
                    in_dim, ndim
                ),
            });
        }
        let batch = input.shape()[in_dim];
        // Apply `f` to each slice, short-circuiting on the first error.
        let outputs = (0..batch)
            .map(|i| f(&select(input, in_dim, i)?))
            .collect::<FerrotorchResult<Vec<_>>>()?;
        stack(&outputs, out_dim)
    }
}
/// Two-input vmap: applies `f` pairwise to matching slices of `a` and `b`
/// (taken along `in_dim_a` / `in_dim_b`) and stacks the results along
/// `out_dim`. Both inputs must have the same batch size along their dims.
pub fn vmap2<T, F>(
    f: F,
    in_dim_a: usize,
    in_dim_b: usize,
    out_dim: usize,
) -> impl Fn(&Tensor<T>, &Tensor<T>) -> FerrotorchResult<Tensor<T>>
where
    T: Float,
    F: Fn(&Tensor<T>, &Tensor<T>) -> FerrotorchResult<Tensor<T>>,
{
    move |a: &Tensor<T>, b: &Tensor<T>| {
        let a_shape = a.shape();
        let b_shape = b.shape();
        // Validate both batching dims; the generated messages are identical
        // to spelling out each check by hand.
        for (name, dim, nd) in [
            ("a", in_dim_a, a_shape.len()),
            ("b", in_dim_b, b_shape.len()),
        ] {
            if dim >= nd {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!(
                        "vmap2: in_dim_{name} {dim} is out of bounds for tensor {name} with {nd} dimensions"
                    ),
                });
            }
        }
        let batch_a = a_shape[in_dim_a];
        let batch_b = b_shape[in_dim_b];
        if batch_a != batch_b {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "vmap2: batch size mismatch: a has {} along dim {}, b has {} along dim {}",
                    batch_a, in_dim_a, batch_b, in_dim_b
                ),
            });
        }
        let mut results = Vec::with_capacity(batch_a);
        for i in 0..batch_a {
            let sa = select(a, in_dim_a, i)?;
            let sb = select(b, in_dim_b, i)?;
            results.push(f(&sa, &sb)?);
        }
        stack(&results, out_dim)
    }
}
/// Three-input vmap: applies `f` to matching slices of `a`, `b`, and `c`
/// (taken along their respective dims) and stacks the results along
/// `out_dim`. All three inputs must share the same batch size.
pub fn vmap3<T, F>(
    f: F,
    in_dim_a: usize,
    in_dim_b: usize,
    in_dim_c: usize,
    out_dim: usize,
) -> impl Fn(&Tensor<T>, &Tensor<T>, &Tensor<T>) -> FerrotorchResult<Tensor<T>>
where
    T: Float,
    F: Fn(&Tensor<T>, &Tensor<T>, &Tensor<T>) -> FerrotorchResult<Tensor<T>>,
{
    move |a: &Tensor<T>, b: &Tensor<T>, c: &Tensor<T>| {
        let a_shape = a.shape();
        let b_shape = b.shape();
        let c_shape = c.shape();
        // Validate each batching dim in order; messages match the per-tensor
        // hand-written variants exactly.
        for (name, dim, nd) in [
            ("a", in_dim_a, a_shape.len()),
            ("b", in_dim_b, b_shape.len()),
            ("c", in_dim_c, c_shape.len()),
        ] {
            if dim >= nd {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!(
                        "vmap3: in_dim_{name} {dim} is out of bounds for tensor {name} with {nd} dimensions"
                    ),
                });
            }
        }
        let batch_a = a_shape[in_dim_a];
        let batch_b = b_shape[in_dim_b];
        let batch_c = c_shape[in_dim_c];
        if batch_a != batch_b || batch_a != batch_c {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "vmap3: batch size mismatch: a={} dim {}, b={} dim {}, c={} dim {}",
                    batch_a, in_dim_a, batch_b, in_dim_b, batch_c, in_dim_c
                ),
            });
        }
        let mut stacked: Vec<Tensor<T>> = Vec::with_capacity(batch_a);
        for i in 0..batch_a {
            let sa = select(a, in_dim_a, i)?;
            let sb = select(b, in_dim_b, i)?;
            let sc = select(c, in_dim_c, i)?;
            stacked.push(f(&sa, &sb, &sc)?);
        }
        stack(&stacked, out_dim)
    }
}
/// N-ary vmap: applies `f` to matching slices of each input tensor (input `k`
/// is sliced along `in_dims[k]`) and stacks the per-batch results along
/// `out_dim`. All inputs must agree on the batch size.
pub fn vmap_many<T, F>(
    f: F,
    in_dims: Vec<usize>,
    out_dim: usize,
) -> impl Fn(&[&Tensor<T>]) -> FerrotorchResult<Tensor<T>>
where
    T: Float,
    F: Fn(&[Tensor<T>]) -> FerrotorchResult<Tensor<T>>,
{
    move |inputs: &[&Tensor<T>]| {
        if inputs.len() != in_dims.len() {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "vmap_many: got {} inputs but {} in_dims",
                    inputs.len(),
                    in_dims.len()
                ),
            });
        }
        if inputs.is_empty() {
            return Err(FerrotorchError::InvalidArgument {
                message: "vmap_many: at least one input required".into(),
            });
        }
        // Validate every dim and agree on a single batch size. Each input's
        // dim bound is checked *before* its shape is indexed.
        let mut batch_size: Option<usize> = None;
        for (i, (input, &dim)) in inputs.iter().zip(in_dims.iter()).enumerate() {
            if dim >= input.ndim() {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!(
                        "vmap_many: in_dims[{}] = {} is out of bounds for input with {} dims",
                        i,
                        dim,
                        input.ndim()
                    ),
                });
            }
            let bs = input.shape()[dim];
            if let Some(b) = batch_size {
                if b != bs {
                    return Err(FerrotorchError::ShapeMismatch {
                        message: format!(
                            "vmap_many: batch size mismatch: input[{}] has {} along dim {}, others have {}",
                            i, bs, dim, b
                        ),
                    });
                }
            } else {
                batch_size = Some(bs);
            }
        }
        let batch_size = batch_size.expect("inputs verified non-empty above");
        let mut results = Vec::with_capacity(batch_size);
        for i in 0..batch_size {
            // Slice every input at batch index i, short-circuiting on error.
            let slices = inputs
                .iter()
                .zip(in_dims.iter())
                .map(|(input, &dim)| select(input, dim, i))
                .collect::<FerrotorchResult<Vec<Tensor<T>>>>()?;
            results.push(f(&slices)?);
        }
        stack(&results, out_dim)
    }
}
/// vmap for closures returning several tensors: `f` runs once per slice along
/// `in_dim` and must return the same number of outputs every time; output `k`
/// of every call is then stacked along `out_dim` into result `k`.
pub fn vmap_multi_output<T, F>(
    f: F,
    in_dim: usize,
    out_dim: usize,
) -> impl Fn(&Tensor<T>) -> FerrotorchResult<Vec<Tensor<T>>>
where
    T: Float,
    F: Fn(&Tensor<T>) -> FerrotorchResult<Vec<Tensor<T>>>,
{
    move |input: &Tensor<T>| {
        let ndim = input.ndim();
        if in_dim >= ndim {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "vmap_multi_output: in_dim {} is out of bounds for tensor with {} dims",
                    in_dim, ndim
                ),
            });
        }
        let batch_size = input.shape()[in_dim];
        let mut per_call: Vec<Vec<Tensor<T>>> = Vec::with_capacity(batch_size);
        // Output count is fixed by the first call; later calls must match.
        let mut expected: Option<usize> = None;
        for i in 0..batch_size {
            let outs = f(&select(input, in_dim, i)?)?;
            if let Some(n) = expected {
                if n != outs.len() {
                    return Err(FerrotorchError::InvalidArgument {
                        message: format!(
                            "vmap_multi_output: closure returned {} outputs at batch index {} \
                             but {} at the first call -- output count must be fixed",
                            outs.len(),
                            i,
                            n
                        ),
                    });
                }
            } else {
                expected = Some(outs.len());
            }
            per_call.push(outs);
        }
        // Transpose [batch][output] into [output] of stacked tensors.
        let num_outputs = expected.unwrap_or(0);
        (0..num_outputs)
            .map(|out_idx| {
                let column: Vec<Tensor<T>> = per_call
                    .iter()
                    .map(|outs| outs[out_idx].clone())
                    .collect();
                stack(&column, out_dim)
            })
            .collect()
    }
}
/// Computes one gradient of `loss_fn` with respect to `param` per sample of
/// `inputs` along `in_dim`, and stacks those gradients along a new leading
/// batch dimension (output shape: [batch, ...param.shape()]).
///
/// # Errors
/// * `InvalidArgument` if `in_dim` is out of bounds, if `loss_fn` returns a
///   non-scalar loss, or if the parameter receives no gradient.
pub fn per_sample_grad<T, F>(
    loss_fn: F,
    inputs: &Tensor<T>,
    param: &Tensor<T>,
    in_dim: usize,
) -> FerrotorchResult<Tensor<T>>
where
    T: Float,
    F: Fn(&Tensor<T>, &Tensor<T>) -> FerrotorchResult<Tensor<T>>,
{
    if in_dim >= inputs.ndim() {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "per_sample_grad: in_dim {} out of bounds for input with {} dims",
                in_dim,
                inputs.ndim()
            ),
        });
    }
    let batch_size = inputs.shape()[in_dim];
    let mut grads: Vec<Tensor<T>> = Vec::with_capacity(batch_size);
    for i in 0..batch_size {
        let slice = select(inputs, in_dim, i)?;
        // Rebuild a fresh requires-grad leaf from the parameter's data on
        // every iteration so gradients from earlier samples cannot accumulate
        // into the current sample's gradient.
        let p_data = param.data_vec()?;
        let p_leaf = crate::tensor::Tensor::from_storage(
            crate::storage::TensorStorage::cpu(p_data),
            param.shape().to_vec(),
            true,
        )?;
        let loss = loss_fn(&slice, &p_leaf)?;
        // Accept true scalars as well as single-element tensors (e.g. shape [1]).
        if !loss.is_scalar() && loss.numel() != 1 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "per_sample_grad: loss_fn must return a scalar tensor, got shape {:?}",
                    loss.shape()
                ),
            });
        }
        loss.backward()?;
        let g = p_leaf.grad()?.ok_or_else(|| FerrotorchError::InvalidArgument {
            message: "per_sample_grad: parameter received no gradient -- check that loss_fn \
                uses the parameter argument and produces a differentiable result"
                .into(),
        })?;
        grads.push(g);
    }
    // New leading batch dimension: grads[i] is the gradient for sample i.
    stack(&grads, 0)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::creation::from_slice;

    // Shorthand: build an f32 tensor from a flat slice and a shape.
    fn t(data: &[f32], shape: &[usize]) -> Tensor<f32> {
        from_slice(data, shape).unwrap()
    }

    // --- select: slicing along each axis and error paths ---

    #[test]
    fn test_select_axis0() {
        let data: Vec<f32> = (1..=12).map(|x| x as f32).collect();
        let input = t(&data, &[3, 4]);
        let s0 = select(&input, 0, 0).unwrap();
        assert_eq!(s0.shape(), &[4]);
        assert_eq!(s0.data().unwrap(), &[1.0, 2.0, 3.0, 4.0]);
        let s1 = select(&input, 0, 1).unwrap();
        assert_eq!(s1.shape(), &[4]);
        assert_eq!(s1.data().unwrap(), &[5.0, 6.0, 7.0, 8.0]);
        let s2 = select(&input, 0, 2).unwrap();
        assert_eq!(s2.shape(), &[4]);
        assert_eq!(s2.data().unwrap(), &[9.0, 10.0, 11.0, 12.0]);
    }

    #[test]
    fn test_select_axis1() {
        // Selecting along the last axis picks strided (column) elements.
        let input = t(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let s0 = select(&input, 1, 0).unwrap();
        assert_eq!(s0.shape(), &[2]);
        assert_eq!(s0.data().unwrap(), &[1.0, 4.0]);
        let s1 = select(&input, 1, 1).unwrap();
        assert_eq!(s1.data().unwrap(), &[2.0, 5.0]);
        let s2 = select(&input, 1, 2).unwrap();
        assert_eq!(s2.data().unwrap(), &[3.0, 6.0]);
    }

    #[test]
    fn test_select_3d() {
        let data: Vec<f32> = (0..24).map(|x| x as f32).collect();
        let input = t(&data, &[2, 3, 4]);
        let s0 = select(&input, 0, 0).unwrap();
        assert_eq!(s0.shape(), &[3, 4]);
        let expected: Vec<f32> = (0..12).map(|x| x as f32).collect();
        assert_eq!(s0.data().unwrap(), &expected);
        let s1 = select(&input, 0, 1).unwrap();
        assert_eq!(s1.shape(), &[3, 4]);
        let expected: Vec<f32> = (12..24).map(|x| x as f32).collect();
        assert_eq!(s1.data().unwrap(), &expected);
    }

    #[test]
    fn test_select_from_1d() {
        // Selecting from a 1-D tensor yields a scalar (0-D) tensor.
        let input = t(&[10.0, 20.0, 30.0], &[3]);
        let s = select(&input, 0, 1).unwrap();
        assert!(s.is_scalar());
        assert_eq!(s.item().unwrap(), 20.0);
    }

    #[test]
    fn test_select_invalid_dim() {
        let input = t(&[1.0, 2.0], &[2]);
        assert!(select(&input, 1, 0).is_err());
    }

    #[test]
    fn test_select_invalid_index() {
        let input = t(&[1.0, 2.0, 3.0], &[3]);
        assert!(select(&input, 0, 3).is_err());
    }

    // --- stack: shapes, data layout, and error paths ---

    #[test]
    fn test_stack_axis0() {
        let a = t(&[1.0, 2.0], &[2]);
        let b = t(&[3.0, 4.0], &[2]);
        let c = t(&[5.0, 6.0], &[2]);
        let s = stack(&[a, b, c], 0).unwrap();
        assert_eq!(s.shape(), &[3, 2]);
        assert_eq!(s.data().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_stack_axis1() {
        // Stacking along dim 1 interleaves the inputs element-wise.
        let a = t(&[1.0, 2.0], &[2]);
        let b = t(&[3.0, 4.0], &[2]);
        let s = stack(&[a, b], 1).unwrap();
        assert_eq!(s.shape(), &[2, 2]);
        assert_eq!(s.data().unwrap(), &[1.0, 3.0, 2.0, 4.0]);
    }

    #[test]
    fn test_stack_2d_axis0() {
        let a = t(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let b = t(&[5.0, 6.0, 7.0, 8.0], &[2, 2]);
        let s = stack(&[a, b], 0).unwrap();
        assert_eq!(s.shape(), &[2, 2, 2]);
        assert_eq!(s.data().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
    }

    #[test]
    fn test_stack_empty_error() {
        let result: FerrotorchResult<Tensor<f32>> = stack(&[], 0);
        assert!(result.is_err());
    }

    #[test]
    fn test_stack_shape_mismatch() {
        let a = t(&[1.0, 2.0], &[2]);
        let b = t(&[3.0, 4.0, 5.0], &[3]);
        assert!(stack(&[a, b], 0).is_err());
    }

    #[test]
    fn test_stack_invalid_dim() {
        // dim may equal ndim (append), so ndim + 1 is the first invalid value.
        let a = t(&[1.0, 2.0], &[2]);
        let b = t(&[3.0, 4.0], &[2]);
        assert!(stack(&[a, b], 2).is_err());
    }

    #[test]
    fn test_select_stack_roundtrip() {
        // select-then-stack along the same axis must reconstruct the original.
        let data: Vec<f32> = (1..=12).map(|x| x as f32).collect();
        let input = t(&data, &[3, 4]);
        let rows: Vec<Tensor<f32>> = (0..3).map(|i| select(&input, 0, i).unwrap()).collect();
        let reconstructed = stack(&rows, 0).unwrap();
        assert_eq!(reconstructed.shape(), input.shape());
        assert_eq!(reconstructed.data().unwrap(), input.data().unwrap());
    }

    // --- vmap family ---

    #[test]
    fn test_vmap_double() {
        let data: Vec<f32> = (1..=12).map(|x| x as f32).collect();
        let input = t(&data, &[3, 4]);
        let doubled = vmap(
            |x| {
                let two = from_slice(&vec![2.0f32; x.numel()], x.shape())?;
                x * &two
            },
            0,
            0,
        )(&input)
        .unwrap();
        assert_eq!(doubled.shape(), &[3, 4]);
        let expected: Vec<f32> = data.iter().map(|x| x * 2.0).collect();
        assert_eq!(doubled.data().unwrap(), &expected);
    }

    #[test]
    fn test_vmap_sum_per_row() {
        // A reducing closure shrinks each slice to a scalar, giving shape [3].
        let data: Vec<f32> = (1..=12).map(|x| x as f32).collect();
        let input = t(&data, &[3, 4]);
        let sums = vmap(|x| x.sum_all(), 0, 0)(&input).unwrap();
        assert_eq!(sums.shape(), &[3]);
        let sums_data = sums.data().unwrap();
        assert!((sums_data[0] - 10.0).abs() < 1e-6);
        assert!((sums_data[1] - 26.0).abs() < 1e-6);
        assert!((sums_data[2] - 42.0).abs() < 1e-6);
    }

    #[test]
    fn test_vmap_invalid_in_dim() {
        let input = t(&[1.0, 2.0], &[2]);
        let result = vmap(|x: &Tensor<f32>| Ok(x.clone()), 1, 0)(&input);
        assert!(result.is_err());
    }

    #[test]
    fn test_vmap_matmul_matches_bmm() {
        // vmap over matmul should agree with the dedicated batched matmul.
        let a_data: Vec<f32> = (0..24).map(|x| x as f32).collect();
        let b_data: Vec<f32> = (0..16).map(|x| (x as f32) * 0.1).collect();
        let a = t(&a_data, &[2, 3, 4]);
        let b = t(&b_data, &[2, 4, 2]);
        let bmm_result = a.bmm(&b).unwrap();
        let vmap_result = vmap2(|x, y| x.matmul(y), 0, 0, 0)(&a, &b).unwrap();
        assert_eq!(bmm_result.shape(), vmap_result.shape());
        let bmm_data = bmm_result.data().unwrap();
        let vmap_data = vmap_result.data().unwrap();
        for (bv, vv) in bmm_data.iter().zip(vmap_data.iter()) {
            assert!((bv - vv).abs() < 1e-4, "bmm={bv}, vmap={vv}");
        }
    }

    #[test]
    fn test_vmap2_elementwise_add() {
        let a = t(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2]);
        let b = t(&[10.0, 20.0, 30.0, 40.0, 50.0, 60.0], &[3, 2]);
        let result = vmap2(|x, y| x + y, 0, 0, 0)(&a, &b).unwrap();
        assert_eq!(result.shape(), &[3, 2]);
        assert_eq!(
            result.data().unwrap(),
            &[11.0, 22.0, 33.0, 44.0, 55.0, 66.0]
        );
    }

    #[test]
    fn test_vmap2_batch_mismatch() {
        let a = t(&[1.0, 2.0, 3.0], &[3]);
        let b = t(&[1.0, 2.0], &[2]);
        let result = vmap2(|x, y| x + y, 0, 0, 0)(&a, &b);
        assert!(result.is_err());
    }

    #[test]
    fn test_vmap2_invalid_dim_a() {
        let a = t(&[1.0, 2.0], &[2]);
        let b = t(&[3.0, 4.0], &[2]);
        let result = vmap2(|x: &Tensor<f32>, y: &Tensor<f32>| x + y, 2, 0, 0)(&a, &b);
        assert!(result.is_err());
    }

    #[test]
    fn test_vmap2_invalid_dim_b() {
        let a = t(&[1.0, 2.0], &[2]);
        let b = t(&[3.0, 4.0], &[2]);
        let result = vmap2(|x: &Tensor<f32>, y: &Tensor<f32>| x + y, 0, 2, 0)(&a, &b);
        assert!(result.is_err());
    }

    #[test]
    fn test_stack_scalars() {
        // Stacking 0-D scalars produces a 1-D tensor.
        use crate::creation::scalar;
        let a = scalar(1.0f32).unwrap();
        let b = scalar(2.0f32).unwrap();
        let c = scalar(3.0f32).unwrap();
        let s = stack(&[a, b, c], 0).unwrap();
        assert_eq!(s.shape(), &[3]);
        assert_eq!(s.data().unwrap(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_vmap3_three_way_add() {
        let a = t(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let b = t(&[10.0, 20.0, 30.0, 40.0], &[2, 2]);
        let c = t(&[100.0, 200.0, 300.0, 400.0], &[2, 2]);
        let result = vmap3(
            |x, y, z| {
                use crate::grad_fns::arithmetic::add;
                let xy = add(x, y)?;
                add(&xy, z)
            },
            0,
            0,
            0,
            0,
        )(&a, &b, &c)
        .unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.data().unwrap(), &[111.0, 222.0, 333.0, 444.0]);
    }

    #[test]
    fn test_vmap3_batch_size_mismatch() {
        let a = t(&[1.0, 2.0], &[2]);
        let b = t(&[1.0, 2.0, 3.0], &[3]);
        let c = t(&[1.0, 2.0], &[2]);
        let result = vmap3(
            |x, _y, _z| Ok(x.clone()),
            0,
            0,
            0,
            0,
        )(&a, &b, &c);
        assert!(result.is_err());
    }

    #[test]
    fn test_vmap3_invalid_dim() {
        let a = t(&[1.0, 2.0], &[2]);
        let b = t(&[1.0, 2.0], &[2]);
        let c = t(&[1.0, 2.0], &[2]);
        let result =
            vmap3(|x, _y, _z| Ok(x.clone()), 5, 0, 0, 0)(&a, &b, &c);
        assert!(result.is_err());
    }

    #[test]
    fn test_vmap_many_four_inputs() {
        let a = t(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let b = t(&[10.0, 20.0, 30.0, 40.0], &[2, 2]);
        let c = t(&[100.0, 200.0, 300.0, 400.0], &[2, 2]);
        let d = t(&[1000.0, 2000.0, 3000.0, 4000.0], &[2, 2]);
        let result = vmap_many(
            |slices: &[Tensor<f32>]| {
                use crate::grad_fns::arithmetic::add;
                let mut acc = slices[0].clone();
                for s in &slices[1..] {
                    acc = add(&acc, s)?;
                }
                Ok(acc)
            },
            vec![0, 0, 0, 0],
            0,
        )(&[&a, &b, &c, &d])
        .unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.data().unwrap(), &[1111.0, 2222.0, 3333.0, 4444.0]);
    }

    #[test]
    fn test_vmap_many_input_count_mismatch() {
        let a = t(&[1.0, 2.0], &[2]);
        let b = t(&[1.0, 2.0], &[2]);
        let result = vmap_many(
            |slices: &[Tensor<f32>]| Ok(slices[0].clone()),
            vec![0, 0, 0],
            0,
        )(&[&a, &b]);
        assert!(result.is_err());
    }

    #[test]
    fn test_vmap_many_empty_inputs_errors() {
        let result = vmap_many(
            |slices: &[Tensor<f32>]| Ok(slices[0].clone()),
            vec![],
            0,
        )(&[]);
        assert!(result.is_err());
    }

    #[test]
    fn test_vmap_multi_output_two_outputs() {
        let x = t(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let outs = vmap_multi_output(
            |slice| {
                use crate::grad_fns::arithmetic::mul;
                let sq = mul(slice, slice)?;
                Ok(vec![slice.clone(), sq])
            },
            0,
            0,
        )(&x)
        .unwrap();
        assert_eq!(outs.len(), 2);
        assert_eq!(outs[0].shape(), &[2, 2]);
        assert_eq!(outs[0].data().unwrap(), &[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(outs[1].shape(), &[2, 2]);
        assert_eq!(outs[1].data().unwrap(), &[1.0, 4.0, 9.0, 16.0]);
    }

    #[test]
    fn test_vmap_multi_output_single_output_acts_like_vmap() {
        let x = t(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let outs =
            vmap_multi_output(|s| Ok(vec![s.clone()]), 0, 0)(&x).unwrap();
        assert_eq!(outs.len(), 1);
        assert_eq!(outs[0].shape(), &[2, 2]);
        assert_eq!(outs[0].data().unwrap(), &[1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_vmap_nested_composability() {
        // vmap of an identity vmap must round-trip the tensor unchanged.
        let x_data: Vec<f32> = (0..24).map(|v| v as f32).collect();
        let x = t(&x_data, &[2, 3, 4]);
        let result = vmap(
            |outer_slice| {
                vmap(|inner_slice| Ok(inner_slice.clone()), 0, 0)(outer_slice)
            },
            0,
            0,
        )(&x)
        .unwrap();
        assert_eq!(result.shape(), &[2, 3, 4]);
        let r = result.data().unwrap();
        for (i, &v) in r.iter().enumerate() {
            assert!((v - x_data[i]).abs() < 1e-6, "mismatch at {i}: {v} vs {}", x_data[i]);
        }
    }

    #[test]
    fn test_vmap_nested_double_negation() {
        let x = t(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let result = vmap(
            |outer_slice| {
                vmap(
                    |inner_slice| crate::grad_fns::arithmetic::neg(inner_slice),
                    0,
                    0,
                )(outer_slice)
            },
            0,
            0,
        )(&x)
        .unwrap();
        assert_eq!(result.shape(), &[2, 3]);
        assert_eq!(result.data().unwrap(), &[-1.0, -2.0, -3.0, -4.0, -5.0, -6.0]);
    }

    // --- per_sample_grad ---

    #[test]
    fn test_per_sample_grad_simple_quadratic() {
        // loss = sum((x*p)^2); d/dp = 2*x^2*p, so with p = 0.5 the per-sample
        // gradients are simply x^2.
        let inputs = t(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let param = t(&[0.5, 0.5, 0.5], &[3]);
        let grads = per_sample_grad(
            |x: &Tensor<f32>, p: &Tensor<f32>| {
                use crate::grad_fns::arithmetic::mul;
                use crate::grad_fns::reduction::sum;
                let xp = mul(x, p)?;
                let sq = mul(&xp, &xp)?;
                sum(&sq)
            },
            &inputs,
            &param,
            0,
        )
        .unwrap();
        assert_eq!(grads.shape(), &[2, 3]);
        let g = grads.data().unwrap();
        assert!((g[0] - 1.0).abs() < 1e-4);
        assert!((g[1] - 4.0).abs() < 1e-4);
        assert!((g[2] - 9.0).abs() < 1e-4);
        assert!((g[3] - 16.0).abs() < 1e-4);
        assert!((g[4] - 25.0).abs() < 1e-4);
        assert!((g[5] - 36.0).abs() < 1e-4);
    }

    #[test]
    fn test_per_sample_grad_invalid_dim() {
        let x = t(&[1.0, 2.0], &[2]);
        let p = t(&[0.5], &[1]);
        let result = per_sample_grad(
            |_x, _p| Ok(_p.clone()),
            &x,
            &p,
            5,
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_per_sample_grad_non_scalar_loss_errors() {
        let x = t(&[1.0, 2.0], &[2]);
        let p = t(&[0.5], &[1]);
        let result = per_sample_grad(
            |x: &Tensor<f32>, _p: &Tensor<f32>| Ok(x.clone()),
            &x,
            &p,
            0,
        );
        assert!(result.is_err());
    }
}