use crate::domain::{Claim, Contract, ContractId, ContractSet, Evidence, Scope};
use crate::object::{Dim, ObjectKind, ObjectMeta, Shape};
use crate::{Error, Result};
use super::{LayerBehavior, OpInput, OpOutput, OpSignature, Operator};
/// The family of reductions this module can build operators for.
///
/// Each variant has a fixed lowercase operator name, available through
/// [`ReductionKind::op_name`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReductionKind {
    Sum,
    Mean,
    Max,
    Min,
    ArgMax,
    ArgMin,
    Prod,
    Any,
    All,
}

impl ReductionKind {
    /// Returns the operator name for this reduction.
    ///
    /// Examples: `"sum"` for [`ReductionKind::Sum`] and `"argmax"` for
    /// [`ReductionKind::ArgMax`].
    pub fn op_name(&self) -> &'static str {
        match *self {
            Self::Sum => "sum",
            Self::Mean => "mean",
            Self::Max => "max",
            Self::Min => "min",
            Self::ArgMax => "argmax",
            Self::ArgMin => "argmin",
            Self::Prod => "prod",
            Self::Any => "any",
            Self::All => "all",
        }
    }
}
/// Builds the contract set provided by every reduction operator.
///
/// The set holds two axiom-backed contracts over `scope`:
/// `ContractId(60)` claiming [`Claim::Deterministic`], and
/// `ContractId(61)` claiming [`Claim::Exact`].
fn deterministic_exact_contracts(scope: Scope) -> ContractSet {
    let deterministic = Contract::new(
        ContractId(60),
        Claim::Deterministic,
        scope.clone(),
        Evidence::Axiom,
    );
    // The last use of `scope` takes ownership, so only one clone is needed.
    let exact = Contract::new(ContractId(61), Claim::Exact, scope, Evidence::Axiom);
    ContractSet::from_iter([deterministic, exact])
}
/// Returns the port signature shared by all single-axis reductions.
///
/// There are two inputs, `input` followed by `axis`, and one output, `out`.
fn binary_signature_with_axis() -> OpSignature {
    const INPUT_NAMES: [&str; 2] = ["input", "axis"];

    let inputs = INPUT_NAMES
        .iter()
        .map(|port| OpInput {
            name: String::from(*port),
        })
        .collect();
    let outputs = vec![OpOutput {
        name: String::from("out"),
    }];

    OpSignature { inputs, outputs }
}
fn check_two_tensor_inputs(op: &str, inputs: &[ObjectMeta]) -> Result<()> {
if inputs.len() != 2 {
return Err(Error::operator(format!(
"{op} expects 2 inputs (input, axis), got {}",
inputs.len()
)));
}
for (i, m) in inputs.iter().enumerate() {
if m.object_kind != ObjectKind::Tensor {
return Err(Error::operator(format!(
"{op} only supports tensor inputs, input {i} is {:?}",
m.object_kind
)));
}
}
Ok(())
}
macro_rules! reduction_op {
($name:ident, $kind:ident) => {
#[doc = concat!("Named reduction: `", stringify!($kind), "` over a single axis.")]
#[derive(Debug, Clone, Copy, Default)]
pub struct $name;
impl Operator for $name {
fn name(&self) -> &'static str {
ReductionKind::$kind.op_name()
}
fn signature(&self) -> OpSignature {
binary_signature_with_axis()
}
fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
check_two_tensor_inputs(self.name(), inputs)?;
let rank = inputs[0].shape.rank();
if rank == 0 {
return Err(Error::operator(format!(
"{}: rank-0 input cannot be reduced over an axis",
self.name()
)));
}
let mut output = inputs[0].clone();
let mut dims = Vec::with_capacity(rank.saturating_sub(1));
for d in &inputs[0].shape.dims {
dims.push(Dim::DataDependent(format!("{}_reduced", self.name())));
let _ = d; }
dims.clear();
output.shape = Shape::new(dims);
Ok(vec![output])
}
fn required_contracts(&self) -> ContractSet {
ContractSet::new()
}
fn provided_contracts(&self) -> ContractSet {
deterministic_exact_contracts(Scope::Operator(self.name().to_string()))
}
fn layer_behavior(&self) -> Vec<LayerBehavior> {
vec![LayerBehavior::Pointwise, LayerBehavior::Global]
}
}
};
}
reduction_op!(SumOp, Sum);
reduction_op!(MeanOp, Mean);
reduction_op!(MaxOp, Max);
reduction_op!(MinOp, Min);
reduction_op!(ArgMaxOp, ArgMax);
reduction_op!(ArgMinOp, ArgMin);
reduction_op!(ProdOp, Prod);
reduction_op!(AnyOp, Any);
reduction_op!(AllOp, All);