//! Indexing and selection operators: boolean masking, gathering (`take`),
//! scattering, element-wise selection (`where`), and coordinate-based
//! gathering, together with graph-building helper functions for each.

use crate::op::{ComputeContext, GradientContext, Op, OpError};
use crate::tensor::Tensor;
use crate::Float;
use scirs2_core::ndarray::{Array, Axis, Ix1, IxDyn};
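
/// Selects the elements of a data tensor whose corresponding entry in a
/// same-shaped mask tensor is non-zero, flattening the selection into a 1-D
/// output.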
pub struct BooleanMaskOp;
impl<F: Float> Op<F> for BooleanMaskOp {
fn name(&self) -> &'static str {
"BooleanMask"
}
    fn compute(&self, ctx: &mut ComputeContext<F>) -> Result<(), OpError> {
        let data = ctx.input(0);
        let mask = ctx.input(1);
        if data.shape() != mask.shape() {
            return Err(OpError::IncompatibleShape(
                "Data and mask must have the same shape".into(),
            ));
        }
        // Walk data and mask in lockstep, keeping every element whose mask
        // entry is non-zero. The output is always 1-D, whatever the input rank.
        let data_view = data.view();
        let mask_view = mask.view();
        let mut selected_elements = Vec::new();
        for (data_elem, mask_elem) in data_view.iter().zip(mask_view.iter()) {
            if *mask_elem != F::zero() {
                selected_elements.push(*data_elem);
            }
        }
        let result = Array::from_vec(selected_elements);
        ctx.append_output(result.into_dyn());
        Ok(())
    }
    fn grad(&self, ctx: &mut GradientContext<F>) {
        // The number of selected elements depends on the runtime mask values,
        // so the gradient cannot be routed back symbolically here: a zero
        // tensor of the input's shape is emitted for the data, and the mask
        // receives no gradient.
        let g = ctx.graph();
        let input_shape = crate::tensor_ops::shape(ctx.input(0));
        let zeros = crate::tensor_ops::zeros(&input_shape, g);
        ctx.append_input_grad(0, Some(zeros));
        ctx.append_input_grad(1, None);
    }
}
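
/// Gathers slices of a tensor along one axis at the positions given by a
/// 1-D index tensor (a `take`/`gather` operation).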
pub struct TakeOp {
    /// Axis along which to gather; negative values count from the last axis.
    pub axis: isize,
}
impl<F: Float> Op<F> for TakeOp {
fn name(&self) -> &'static str {
"Take"
}
    fn compute(&self, ctx: &mut ComputeContext<F>) -> Result<(), OpError> {
        let data = ctx.input(0);
        let indices = ctx.input(1);
        let data_shape = data.shape();
        let indices_shape = indices.shape();
        // Normalize a negative axis (counting from the end) to an absolute one.
        let axis = if self.axis < 0 {
            (data_shape.len() as isize + self.axis) as usize
        } else {
            self.axis as usize
        };
        if axis >= data_shape.len() {
            return Err(OpError::IncompatibleShape("Axis out of bounds".into()));
        }
        if indices_shape.len() != 1 {
            return Err(OpError::IncompatibleShape(
                "Only 1D indices supported for now".into(),
            ));
        }
        let indices_view = indices.view();
        let data_view = data.view();
        let axis_size = data_shape[axis];
        // Convert and bounds-check every index up front so the copy loop
        // below cannot fail halfway through.
        let index_values: Result<Vec<usize>, OpError> = indices_view
            .iter()
            .map(|&idx| {
                let idx_int = idx
                    .to_usize()
                    .ok_or_else(|| OpError::Other("Index must be a non-negative integer".into()))?;
                if idx_int >= axis_size {
                    Err(OpError::Other("Index out of bounds".into()))
                } else {
                    Ok(idx_int)
                }
            })
            .collect();
        let index_values = index_values?;
        // The output keeps the input shape except along `axis`, which is
        // resized to the number of gathered indices.
        let mut output_shape = data_shape.to_vec();
        output_shape[axis] = index_values.len();
        let mut output = Array::<F, IxDyn>::zeros(IxDyn(&output_shape));
        for (out_idx, &src_idx) in index_values.iter().enumerate() {
            let src_slice = data_view.index_axis(Axis(axis), src_idx);
            let mut out_slice = output.index_axis_mut(Axis(axis), out_idx);
            out_slice.assign(&src_slice);
        }
        ctx.append_output(output);
        Ok(())
    }
    fn grad(&self, ctx: &mut GradientContext<F>) {
        // Placeholder gradient: the true gradient of `take` scatter-adds the
        // incoming gradient back along the gathered axis; here a zero tensor
        // of the input's shape is emitted instead, and the integer indices
        // receive no gradient.
        let g = ctx.graph();
        let input_shape = crate::tensor_ops::shape(ctx.input(0));
        let zeros = crate::tensor_ops::zeros(&input_shape, g);
        ctx.append_input_grad(0, Some(zeros));
        ctx.append_input_grad(1, None);
    }
}
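
/// Writes a 1-D tensor of updates into a zero-initialized 1-D output at the
/// positions given by a matching 1-D index tensor.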
pub struct ScatterOp {
    /// Axis along which to scatter (currently unused: only 1-D scatter is
    /// implemented).
    #[allow(dead_code)]
    pub axis: isize,
    /// Length of the 1-D output tensor.
    pub output_size: usize,
}
impl<F: Float> Op<F> for ScatterOp {
fn name(&self) -> &'static str {
"Scatter"
}
    fn compute(&self, ctx: &mut ComputeContext<F>) -> Result<(), OpError> {
        let indices = ctx.input(0);
        let updates = ctx.input(1);
        let indices_shape = indices.shape();
        let updates_shape = updates.shape();
        if indices_shape.len() != 1 || updates_shape.len() != 1 {
            return Err(OpError::IncompatibleShape(
                "Only 1D scatter supported for now".into(),
            ));
        }
        if indices_shape[0] != updates_shape[0] {
            return Err(OpError::IncompatibleShape(
                "Indices and updates must have the same length".into(),
            ));
        }
        let indices_view = indices.view();
        let updates_view = updates.view();
        // Start from zeros and write each update at its target position.
        // Duplicate indices are not accumulated: a later update overwrites
        // an earlier one.
        let mut output = Array::<F, Ix1>::zeros(self.output_size);
        for (idx_val, update_val) in indices_view.iter().zip(updates_view.iter()) {
            let idx = idx_val
                .to_usize()
                .ok_or_else(|| OpError::Other("Index must be a non-negative integer".into()))?;
            if idx >= self.output_size {
                return Err(OpError::Other("Index out of bounds".into()));
            }
            output[idx] = *update_val;
        }
        ctx.append_output(output.into_dyn());
        Ok(())
    }
    fn grad(&self, ctx: &mut GradientContext<F>) {
        // Placeholder gradient: the true gradient of scatter gathers the
        // incoming gradient at the scattered positions; here a zero tensor of
        // the updates' shape is emitted instead, and the integer indices
        // receive no gradient.
        let g = ctx.graph();
        let updates_shape = crate::tensor_ops::shape(ctx.input(1));
        let zeros_updates = crate::tensor_ops::zeros(&updates_shape, g);
        ctx.append_input_grad(0, None);
        ctx.append_input_grad(1, Some(zeros_updates));
    }
}
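
/// Element-wise selection: picks from the first value tensor where the
/// condition is non-zero and from the second elsewhere.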
pub struct WhereOp;
impl<F: Float> Op<F> for WhereOp {
fn name(&self) -> &'static str {
"Where"
}
    fn compute(&self, ctx: &mut ComputeContext<F>) -> Result<(), OpError> {
        let condition = ctx.input(0);
        let x = ctx.input(1);
        let y = ctx.input(2);
        if condition.shape() != x.shape() || x.shape() != y.shape() {
            return Err(OpError::IncompatibleShape(
                "All inputs must have the same shape".into(),
            ));
        }
        let condition_view = condition.view();
        let x_view = x.view();
        let y_view = y.view();
        // Element-wise select: a non-zero condition entry picks the value
        // from `x`, a zero entry picks the value from `y`.
        let mut output = Array::<F, IxDyn>::zeros(x.shape());
        for ((out_elem, &cond), (&x_elem, &y_elem)) in output
            .iter_mut()
            .zip(condition_view.iter())
            .zip(x_view.iter().zip(y_view.iter()))
        {
            *out_elem = if cond != F::zero() { x_elem } else { y_elem };
        }
        ctx.append_output(output);
        Ok(())
    }
    fn grad(&self, ctx: &mut GradientContext<F>) {
        // Placeholder gradient: the true gradient routes the incoming
        // gradient to `x` where the condition is non-zero and to `y`
        // elsewhere; here zero tensors of the matching shapes are emitted
        // instead, and the condition receives no gradient.
        let g = ctx.graph();
        let x_shape = crate::tensor_ops::shape(ctx.input(1));
        let y_shape = crate::tensor_ops::shape(ctx.input(2));
        let zeros_x = crate::tensor_ops::zeros(&x_shape, g);
        let zeros_y = crate::tensor_ops::zeros(&y_shape, g);
        ctx.append_input_grad(0, None);
        ctx.append_input_grad(1, Some(zeros_x));
        ctx.append_input_grad(2, Some(zeros_y));
    }
}
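
/// Gathers individual elements of a 2-D tensor at the coordinate pairs given
/// by two index tensors of equal shape.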
pub struct AdvancedGatherOp {
    /// Axes being indexed. Only two axes are supported, and the values are
    /// currently not consulted: the first index tensor always addresses
    /// axis 0 and the second axis 1.
    pub axes: Vec<isize>,
}
impl<F: Float> Op<F> for AdvancedGatherOp {
fn name(&self) -> &'static str {
"AdvancedGather"
}
    fn compute(&self, ctx: &mut ComputeContext<F>) -> Result<(), OpError> {
        let data = ctx.input(0);
        let data_shape = data.shape();
        if self.axes.len() != 2 || ctx.inputs().len() != 3 {
            return Err(OpError::IncompatibleShape(
                "Advanced gather currently supports 2D indexing only".into(),
            ));
        }
        // Reject non-2-D data up front; the `data_view[[i0, i1]]` lookup
        // below would otherwise panic on a rank mismatch.
        if data_shape.len() != 2 {
            return Err(OpError::IncompatibleShape(
                "Advanced gather currently requires 2D data".into(),
            ));
        }
        let indices0 = ctx.input(1);
        let indices1 = ctx.input(2);
        if indices0.shape() != indices1.shape() {
            return Err(OpError::IncompatibleShape(
                "Index arrays must have the same shape".into(),
            ));
        }
        let indices0_view = indices0.view();
        let indices1_view = indices1.view();
        let data_view = data.view();
        // The output has the shape of the index arrays: each output element
        // is `data[[i0, i1]]` for the corresponding pair of indices.
        let output_shape = indices0.shape();
        let mut output = Array::<F, IxDyn>::zeros(output_shape);
        for ((out_elem, &idx0), &idx1) in output
            .iter_mut()
            .zip(indices0_view.iter())
            .zip(indices1_view.iter())
        {
            let i0 = idx0
                .to_usize()
                .ok_or_else(|| OpError::Other("Index must be a non-negative integer".into()))?;
            let i1 = idx1
                .to_usize()
                .ok_or_else(|| OpError::Other("Index must be a non-negative integer".into()))?;
            if i0 >= data_shape[0] || i1 >= data_shape[1] {
                return Err(OpError::Other("Index out of bounds".into()));
            }
            *out_elem = data_view[[i0, i1]];
        }
        ctx.append_output(output);
        Ok(())
    }
    fn grad(&self, ctx: &mut GradientContext<F>) {
        // Placeholder gradient: the true gradient scatter-adds the incoming
        // gradient back to the gathered coordinates; here a zero tensor of
        // the data's shape is emitted instead, and the index arrays receive
        // no gradient.
        let g = ctx.graph();
        let input_shape = crate::tensor_ops::shape(ctx.input(0));
        let zeros = crate::tensor_ops::zeros(&input_shape, g);
        ctx.append_input_grad(0, Some(zeros));
        ctx.append_input_grad(1, None);
        ctx.append_input_grad(2, None);
    }
}
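
/// Builds a [`BooleanMaskOp`] node selecting the elements of `data` whose
/// corresponding entry in `mask` is non-zero. The result is always 1-D.
///
/// Illustrative semantics (values shown as plain arrays, not a doc-test):
///
/// ```text
/// data = [1.0, 2.0, 3.0, 4.0]
/// mask = [1.0, 0.0, 1.0, 0.0]
/// boolean_mask(data, mask) == [1.0, 3.0]
/// ```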
#[allow(dead_code)]
pub fn boolean_mask<'g, F: Float>(data: &Tensor<'g, F>, mask: &Tensor<'g, F>) -> Tensor<'g, F> {
let g = data.graph();
Tensor::builder(g)
.append_input(data, false)
.append_input(mask, false)
.build(BooleanMaskOp)
}
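
/// Builds a [`TakeOp`] node gathering slices of `data` along `axis` at the
/// positions given by the 1-D `indices` tensor.
///
/// Illustrative semantics (not a doc-test):
///
/// ```text
/// data    = [[1.0, 2.0],
///            [3.0, 4.0]]
/// indices = [1, 0]
/// take(data, indices, 0) == [[3.0, 4.0],
///                            [1.0, 2.0]]
/// ```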
#[allow(dead_code)]
pub fn take<'g, F: Float>(
data: &Tensor<'g, F>,
indices: &Tensor<'g, F>,
axis: isize,
) -> Tensor<'g, F> {
let g = data.graph();
Tensor::builder(g)
.append_input(data, false)
.append_input(indices, false)
.build(TakeOp { axis })
}
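
/// Builds a [`ScatterOp`] node writing the 1-D `updates` into a zero-filled
/// 1-D output of length `output_size` at the positions given by `indices`.
/// Duplicate indices overwrite rather than accumulate.
///
/// Illustrative semantics (not a doc-test):
///
/// ```text
/// indices = [0, 2]
/// updates = [5.0, 7.0]
/// scatter(indices, updates, 4, 0) == [5.0, 0.0, 7.0, 0.0]
/// ```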
#[allow(dead_code)]
pub fn scatter<'g, F: Float>(
indices: &Tensor<'g, F>,
updates: &Tensor<'g, F>,
output_size: usize,
axis: isize,
) -> Tensor<'g, F> {
let g = indices.graph();
Tensor::builder(g)
.append_input(indices, false)
.append_input(updates, false)
.build(ScatterOp { axis, output_size })
}
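
/// Builds a [`WhereOp`] node selecting, element-wise, from `x` where
/// `condition` is non-zero and from `y` elsewhere. All three tensors must
/// share a shape.
///
/// Illustrative semantics (not a doc-test):
///
/// ```text
/// condition = [1.0, 0.0, 1.0]
/// x         = [10.0, 20.0, 30.0]
/// y         = [ 1.0,  2.0,  3.0]
/// where_op(condition, x, y) == [10.0, 2.0, 30.0]
/// ```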
#[allow(dead_code)]
pub fn where_op<'g, F: Float>(
condition: &Tensor<'g, F>,
x: &Tensor<'g, F>,
y: &Tensor<'g, F>,
) -> Tensor<'g, F> {
let g = condition.graph();
Tensor::builder(g)
.append_input(condition, false)
.append_input(x, false)
.append_input(y, false)
.build(WhereOp)
}
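
/// Builds an [`AdvancedGatherOp`] node. Currently exactly two index tensors
/// over a 2-D `data` tensor are supported (`axes` must have length 2), and
/// each output element is `data[[i0, i1]]` for the corresponding pair of
/// indices.
///
/// Illustrative semantics (not a doc-test):
///
/// ```text
/// data = [[1.0, 2.0],
///         [3.0, 4.0]]
/// rows = [0, 1]
/// cols = [1, 0]
/// advanced_gather(data, &[rows, cols], &[0, 1]) == [2.0, 3.0]
/// ```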
#[allow(dead_code)]
pub fn advanced_gather<'g, F: Float>(
data: &Tensor<'g, F>,
indices: &[&Tensor<'g, F>],
axes: &[isize],
) -> Tensor<'g, F> {
let g = data.graph();
let mut builder = Tensor::builder(g);
builder = builder.append_input(data, false);
for idx_tensor in indices {
builder = builder.append_input(*idx_tensor, false);
}
builder.build(AdvancedGatherOp {
axes: axes.to_vec(),
})
}
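
/// Gathers `data[[row, col]]` for each pair drawn element-wise from
/// `row_indices` and `col_indices`; a convenience wrapper over
/// [`advanced_gather`] with `axes = [0, 1]`.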
#[allow(dead_code)]
pub fn get_at_coords<'g, F: Float>(
data: &Tensor<'g, F>,
row_indices: &Tensor<'g, F>,
col_indices: &Tensor<'g, F>,
) -> Tensor<'g, F> {
advanced_gather(data, &[row_indices, col_indices], &[0, 1])
}
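
/// Selects whole rows of a 2-D `data` tensor; a convenience wrapper over
/// [`take`] along axis 0.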
#[allow(dead_code)]
pub fn select_rows<'g, F: Float>(
data: &Tensor<'g, F>,
row_indices: &Tensor<'g, F>,
) -> Tensor<'g, F> {
take(data, row_indices, 0)
}
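
/// Selects whole columns of a 2-D `data` tensor; a convenience wrapper over
/// [`take`] along axis 1.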
#[allow(dead_code)]
pub fn select_columns<'g, F: Float>(
data: &Tensor<'g, F>,
col_indices: &Tensor<'g, F>,
) -> Tensor<'g, F> {
take(data, col_indices, 1)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_boolean_mask_op_creation() {
let op = BooleanMaskOp;
assert_eq!(<BooleanMaskOp as Op<f32>>::name(&op), "BooleanMask");
}
#[test]
fn test_take_op_creation() {
let op = TakeOp { axis: 0 };
assert_eq!(<TakeOp as Op<f32>>::name(&op), "Take");
assert_eq!(op.axis, 0);
}
#[test]
fn test_scatter_op_creation() {
let op = ScatterOp {
axis: 0,
output_size: 10,
};
assert_eq!(<ScatterOp as Op<f32>>::name(&op), "Scatter");
assert_eq!(op.axis, 0);
assert_eq!(op.output_size, 10);
}
#[test]
fn test_where_op_creation() {
let op = WhereOp;
assert_eq!(<WhereOp as Op<f32>>::name(&op), "Where");
}
#[test]
fn test_advanced_gather_op_creation() {
let op = AdvancedGatherOp { axes: vec![0, 1] };
assert_eq!(<AdvancedGatherOp as Op<f32>>::name(&op), "AdvancedGather");
assert_eq!(op.axes, vec![0, 1]);
}
}