use crate::dtype::Float;
use crate::error::{FerrotorchError, FerrotorchResult};
use crate::storage::TensorStorage;
use crate::tensor::Tensor;
/// Sums rows of a 2-D tensor into output rows selected by `index`.
///
/// `src` must have shape `[E, D]` and `index` must hold exactly `E` entries.
/// Row `i` of `src` is added element-wise into row `index[i]` of the result,
/// whose shape is `[dim_size, D]`. Output rows that no entry of `index`
/// refers to stay at zero.
///
/// # Errors
///
/// * [`FerrotorchError::NotImplementedOnCuda`] if `src.is_cuda()` is true.
/// * [`FerrotorchError::ShapeMismatch`] if `src` is not 2-D, or if
///   `index.len()` differs from `src.shape()[0]`.
/// * [`FerrotorchError::InvalidArgument`] if any index is negative or
///   `>= dim_size`, or if `dim_size * D` does not fit in `usize`.
/// * Any error propagated from `src.data_vec()` or `Tensor::from_storage`.
pub fn scatter_add_segments<T: Float>(
    src: &Tensor<T>,
    index: &[i64],
    dim_size: usize,
) -> FerrotorchResult<Tensor<T>> {
    if src.is_cuda() {
        return Err(FerrotorchError::NotImplementedOnCuda {
            op: "scatter_add_segments",
        });
    }

    let shape = src.shape();
    if shape.len() != 2 {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!("scatter_add_segments: src must be 2-D [E, D], got shape {shape:?}"),
        });
    }
    let e = shape[0];
    let d = shape[1];
    if index.len() != e {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "scatter_add_segments: index length {} != src.shape()[0] {e}",
                index.len()
            ),
        });
    }

    // A plain `dim_size * d` panics in debug builds and wraps in release
    // builds, which would yield an undersized buffer; report it instead.
    let out_len = dim_size
        .checked_mul(d)
        .ok_or_else(|| FerrotorchError::InvalidArgument {
            message: format!(
                "scatter_add_segments: output size {dim_size} * {d} overflows usize"
            ),
        })?;
    let zero = <T as num_traits::Zero>::zero();
    let mut out = vec![zero; out_len];
    let src_data = src.data_vec()?;

    for (e_idx, &dst_i64) in index.iter().enumerate() {
        if dst_i64 < 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!("scatter_add_segments: index[{e_idx}] = {dst_i64} is negative"),
            });
        }
        // Range-check in u64 *before* narrowing. `dst_i64` is non-negative
        // here and `usize` is at most 64 bits wide, so both widening casts are
        // lossless. Casting to `usize` first would truncate an oversized index
        // on targets with a narrower `usize` and could let it slip into range.
        if dst_i64 as u64 >= dim_size as u64 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "scatter_add_segments: index[{e_idx}] = {dst_i64} >= dim_size {dim_size}"
                ),
            });
        }
        // Lossless: the check above guarantees `dst_i64 < dim_size <= usize::MAX`.
        let dst = dst_i64 as usize;

        // Rows are addressed with a stride of `d` elements in both buffers.
        let src_row = &src_data[e_idx * d..(e_idx + 1) * d];
        let out_row = &mut out[dst * d..(dst + 1) * d];
        for (o, &v) in out_row.iter_mut().zip(src_row.iter()) {
            *o += v;
        }
    }

    Tensor::from_storage(TensorStorage::cpu(out), vec![dim_size, d], false)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparing accumulated `f32` sums.
    const TOL: f32 = 1e-6;

    /// Builds an `f32` tensor on CPU storage from a flat buffer and a shape.
    fn t(values: &[f32], dims: &[usize]) -> Tensor<f32> {
        let storage = TensorStorage::cpu(values.to_vec());
        Tensor::from_storage(storage, dims.to_vec(), false).unwrap()
    }

    /// Asserts that the leading `expected.len()` entries of `actual` match
    /// `expected` within `TOL`. Indexing (rather than zipping) keeps a
    /// too-short `actual` a hard failure.
    fn assert_leading_close(actual: &[f32], expected: &[f32]) {
        for (i, &want) in expected.iter().enumerate() {
            assert!((actual[i] - want).abs() < TOL);
        }
    }

    /// Two source rows mapped to the same segment are summed together.
    #[test]
    fn segments_basic_aggregation() {
        let src = t(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2]);
        let out = scatter_add_segments(&src, &[0, 1, 0], 2).unwrap();
        assert_eq!(out.shape(), &[2, 2]);
        let data = out.data().unwrap();
        assert_leading_close(&data[..], &[6.0, 8.0, 3.0, 4.0]);
    }

    /// Segments that receive no source row are left at zero.
    #[test]
    fn segments_empty_rows_are_zero() {
        let src = t(&[7.0, 0.5, 8.0, 0.25], &[2, 2]);
        let out = scatter_add_segments(&src, &[0, 0], 3).unwrap();
        assert_eq!(out.shape(), &[3, 2]);
        let data = out.data().unwrap();
        assert_leading_close(&data[..], &[15.0, 0.75]);
        for v in data[2..].iter().copied() {
            assert!(v.abs() < 1e-12, "expected exact zero, got {v}");
        }
    }

    /// With one source row per segment the output is a row permutation of the input.
    #[test]
    fn segments_single_edge_per_segment() {
        let src = t(&[1.0, 1.5, 2.0, 2.5, 3.0, 3.5], &[3, 2]);
        let out = scatter_add_segments(&src, &[2, 0, 1], 3).unwrap();
        let data = out.data().unwrap();
        assert_leading_close(&data[..], &[2.0, 2.5, 3.0, 3.5, 1.0, 1.5]);
    }

    /// A 1-D source tensor is refused.
    #[test]
    fn segments_rejects_non_2d_src() {
        let src = t(&[1.0, 2.0, 3.0], &[3]);
        let result = scatter_add_segments(&src, &[0, 1, 0], 2);
        assert!(result.is_err());
    }

    /// An index slice whose length differs from the source row count is refused.
    #[test]
    fn segments_rejects_index_length_mismatch() {
        let src = t(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let result = scatter_add_segments(&src, &[0, 1, 0], 2);
        assert!(result.is_err());
    }

    /// A negative destination index is refused.
    #[test]
    fn segments_rejects_negative_index() {
        let src = t(&[1.0, 2.0], &[1, 2]);
        let result = scatter_add_segments(&src, &[-1], 2);
        assert!(result.is_err());
    }

    /// A destination index equal to `dim_size` is refused.
    #[test]
    fn segments_rejects_oob_index() {
        let src = t(&[1.0, 2.0], &[1, 2]);
        let result = scatter_add_segments(&src, &[2], 2);
        assert!(result.is_err());
    }
}