use crate::dtype::{DType, Float};
use crate::error::{FerrotorchError, FerrotorchResult};
use crate::storage::TensorStorage;
use crate::tensor::Tensor;
pub fn scatter_add_segments<T: Float>(
src: &Tensor<T>,
index: &[i64],
dim_size: usize,
) -> FerrotorchResult<Tensor<T>> {
let shape = src.shape();
if shape.len() != 2 {
return Err(FerrotorchError::ShapeMismatch {
message: format!("scatter_add_segments: src must be 2-D [E, D], got shape {shape:?}"),
});
}
let e = shape[0];
let d = shape[1];
if index.len() != e {
return Err(FerrotorchError::ShapeMismatch {
message: format!(
"scatter_add_segments: index length {} != src.shape()[0] {e}",
index.len()
),
});
}
for (e_idx, &dst_i64) in index.iter().enumerate() {
if dst_i64 < 0 {
return Err(FerrotorchError::InvalidArgument {
message: format!("scatter_add_segments: index[{e_idx}] = {dst_i64} is negative"),
});
}
if dst_i64 as usize >= dim_size {
return Err(FerrotorchError::InvalidArgument {
message: format!(
"scatter_add_segments: index[{e_idx}] = {dst_i64} >= dim_size {dim_size}"
),
});
}
}
if src.is_cuda() {
return scatter_add_segments_cuda(src, index, e, d, dim_size);
}
let zero = <T as num_traits::Zero>::zero();
let mut out = vec![zero; dim_size * d];
let src_data = src.data_vec()?;
for (e_idx, &dst_i64) in index.iter().enumerate() {
let dst = dst_i64 as usize;
let src_row = &src_data[e_idx * d..(e_idx + 1) * d];
let out_row = &mut out[dst * d..(dst + 1) * d];
for (o, &v) in out_row.iter_mut().zip(src_row.iter()) {
*o += v;
}
}
Tensor::from_storage(TensorStorage::cpu(out), vec![dim_size, d], false)
}
/// GPU path for `scatter_add_segments`.
///
/// Makes `src` contiguous, uploads `index` to the device that holds `src` as
/// raw native-endian `i64` bytes, and calls the backend's
/// `scatter_add_segments_f32` / `scatter_add_segments_f64` kernel entry point
/// with the `[e, d]` source extent and `dim_size` output rows. The returned
/// tensor wraps the backend's output handle with shape `[dim_size, d]`.
///
/// This function does not re-validate `index`; the public entry point checks
/// every entry before dispatching here.
///
/// # Errors
///
/// * `FerrotorchError::NotImplementedOnCuda` if `T` is neither `f32` nor `f64`.
/// * `FerrotorchError::DeviceUnavailable` if no GPU backend is registered.
/// * Any error propagated from `contiguous`, `gpu_handle`, the host-to-device
///   copy, the backend kernel call, or `Tensor::from_storage`.
fn scatter_add_segments_cuda<T: Float>(
    src: &Tensor<T>,
    index: &[i64],
    e: usize,
    d: usize,
    dim_size: usize,
) -> FerrotorchResult<Tensor<T>> {
    if !matches!(T::dtype(), DType::F32 | DType::F64) {
        return Err(FerrotorchError::NotImplementedOnCuda {
            op: "scatter_add_segments",
        });
    }
    let src = src.contiguous()?;
    let src_handle = src.gpu_handle()?;
    let ordinal = src_handle.device_ordinal();
    let backend = crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;

    // SAFETY: `index` is a valid, initialized `&[i64]`; `i64` has no padding
    // bytes, `u8` has alignment 1, and `size_of_val(index)` is exactly the
    // number of bytes the slice occupies. The byte view borrows `index`, is
    // only read, and does not outlive this function.
    let idx_bytes: &[u8] = unsafe {
        std::slice::from_raw_parts(index.as_ptr().cast::<u8>(), std::mem::size_of_val(index))
    };
    let idx_handle = backend.cpu_to_gpu(idx_bytes, DType::I64, ordinal)?;

    // Only F32 and F64 reach this point (checked above), so "not F32" means F64.
    let h = if T::dtype() == DType::F32 {
        backend.scatter_add_segments_f32(src_handle, &idx_handle, e, d, dim_size)?
    } else {
        backend.scatter_add_segments_f64(src_handle, &idx_handle, e, d, dim_size)?
    };
    Tensor::from_storage(TensorStorage::gpu(h), vec![dim_size, d], false)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a CPU `f32` tensor from flat `values` laid out as `dims`.
    fn cpu_tensor(values: &[f32], dims: &[usize]) -> Tensor<f32> {
        let storage = TensorStorage::cpu(values.to_vec());
        Tensor::from_storage(storage, dims.to_vec(), false).unwrap()
    }

    /// Checks that the leading elements of `actual` match `expected` within
    /// `1e-6`; elements of `actual` past `expected.len()` are not inspected.
    fn assert_leading_close(actual: &[f32], expected: &[f32]) {
        for (pos, &want) in expected.iter().enumerate() {
            assert!((actual[pos] - want).abs() < 1e-6);
        }
    }

    // Rows 0 and 2 share segment 0, row 1 lands alone in segment 1.
    #[test]
    fn segments_basic_aggregation() {
        let src = cpu_tensor(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2]);
        let index = [0_i64, 1, 0];
        let out = scatter_add_segments(&src, &index, 2).unwrap();
        assert_eq!(out.shape(), &[2, 2]);
        let data = out.data().unwrap();
        assert_leading_close(&data[0..], &[6.0, 8.0, 3.0, 4.0]);
    }

    // Segments 1 and 2 receive no rows and must remain zero.
    #[test]
    fn segments_empty_rows_are_zero() {
        let src = cpu_tensor(&[7.0, 0.5, 8.0, 0.25], &[2, 2]);
        let index = [0_i64, 0];
        let out = scatter_add_segments(&src, &index, 3).unwrap();
        assert_eq!(out.shape(), &[3, 2]);
        let data = out.data().unwrap();
        assert_leading_close(&data[0..], &[15.0, 0.75]);
        for v in data[2..].iter().copied() {
            assert!(v.abs() < 1e-12, "expected exact zero, got {v}");
        }
    }

    // Every segment gets exactly one row, so the output is a row permutation.
    #[test]
    fn segments_single_edge_per_segment() {
        let src = cpu_tensor(&[1.0, 1.5, 2.0, 2.5, 3.0, 3.5], &[3, 2]);
        let index = [2_i64, 0, 1];
        let out = scatter_add_segments(&src, &index, 3).unwrap();
        let data = out.data().unwrap();
        assert_leading_close(&data[0..], &[2.0, 2.5, 3.0, 3.5, 1.0, 1.5]);
    }

    #[test]
    fn segments_rejects_non_2d_src() {
        let src = cpu_tensor(&[1.0, 2.0, 3.0], &[3]);
        let index = [0_i64, 1, 0];
        let result = scatter_add_segments(&src, &index, 2);
        assert!(result.is_err());
    }

    #[test]
    fn segments_rejects_index_length_mismatch() {
        let src = cpu_tensor(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let index = [0_i64, 1, 0];
        let result = scatter_add_segments(&src, &index, 2);
        assert!(result.is_err());
    }

    #[test]
    fn segments_rejects_negative_index() {
        let src = cpu_tensor(&[1.0, 2.0], &[1, 2]);
        let index = [-1_i64];
        let result = scatter_add_segments(&src, &index, 2);
        assert!(result.is_err());
    }

    #[test]
    fn segments_rejects_oob_index() {
        let src = cpu_tensor(&[1.0, 2.0], &[1, 2]);
        let index = [2_i64];
        let result = scatter_add_segments(&src, &index, 2);
        assert!(result.is_err());
    }
}