use crate::domain::{Claim, Contract, ContractId, ContractSet, Evidence, Scope};
use crate::object::{Dim, ObjectKind, ObjectMeta, Shape};
use crate::{Error, Result};
use super::{LayerBehavior, OpInput, OpOutput, OpSignature, Operator};
/// Builds the contract set shared by every operator in this file.
///
/// The set holds two axiomatic contracts, both scoped to `scope`:
/// a `Claim::Deterministic` contract (id 50) and a `Claim::Exact` contract (id 51).
fn deterministic_exact_contracts(scope: Scope) -> ContractSet {
    let deterministic = Contract::new(
        ContractId(50),
        Claim::Deterministic,
        scope.clone(),
        Evidence::Axiom,
    );
    // The last contract takes ownership of `scope`, so no second clone is needed.
    let exact = Contract::new(ContractId(51), Claim::Exact, scope, Evidence::Axiom);
    ContractSet::from_iter([deterministic, exact])
}
/// Validates the operand list passed to an operator's `infer`.
///
/// `op` is used only to label error messages.
///
/// # Errors
///
/// Returns an operator error when `inputs.len()` differs from `expected`.
/// Otherwise, returns an operator error naming the first operand whose
/// `object_kind` is not `ObjectKind::Tensor`.
fn tensor_input_check(op: &str, inputs: &[ObjectMeta], expected: usize) -> Result<()> {
    let got = inputs.len();
    if got != expected {
        return Err(Error::operator(format!(
            "{op} expects {expected} input(s), got {}",
            got
        )));
    }
    let first_non_tensor = inputs
        .iter()
        .enumerate()
        .find(|(_, meta)| meta.object_kind != ObjectKind::Tensor);
    if let Some((i, meta)) = first_non_tensor {
        return Err(Error::operator(format!(
            "{op} only supports tensor inputs, input {i} is {:?}",
            meta.object_kind
        )));
    }
    Ok(())
}
/// `gather` operator.
///
/// Operands (all must be tensors): `input`, `indices`, `axis`. There is one result, `out`.
/// Shape inference copies the `indices` operand's metadata and replaces its
/// `domain` with that of `input`. The intended runtime semantics are presumably
/// `torch.gather`-like (output shaped like `indices`); confirm against the executor.
#[derive(Debug, Clone, Copy, Default)]
pub struct GatherOp;
impl Operator for GatherOp {
    fn name(&self) -> &'static str {
        "gather"
    }

    /// Declares the operands `input`, `indices`, `axis` and the single result `out`.
    fn signature(&self) -> OpSignature {
        let operand_names = ["input", "indices", "axis"];
        OpSignature {
            inputs: operand_names
                .iter()
                .map(|&name| OpInput {
                    name: name.to_string(),
                })
                .collect(),
            outputs: vec![OpOutput {
                name: "out".to_string(),
            }],
        }
    }

    /// Requires exactly three tensor operands.
    ///
    /// Returns metadata for the single output.
    fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
        tensor_input_check(self.name(), inputs, 3)?;
        let (data, indices) = (&inputs[0], &inputs[1]);
        // Everything except `domain` comes from `indices`.
        let mut gathered = indices.clone();
        gathered.domain = data.domain.clone();
        Ok(vec![gathered])
    }

    /// No contracts are required from upstream.
    fn required_contracts(&self) -> ContractSet {
        ContractSet::new()
    }

    /// Provides the deterministic and exact contracts, scoped to this operator's name.
    fn provided_contracts(&self) -> ContractSet {
        let scope = Scope::Operator(String::from(self.name()));
        deterministic_exact_contracts(scope)
    }

    /// Reports both `Pointwise` and `Global` layer behavior.
    fn layer_behavior(&self) -> Vec<LayerBehavior> {
        Vec::from([LayerBehavior::Pointwise, LayerBehavior::Global])
    }
}
/// `scatter` operator.
///
/// Operands (all must be tensors): `input`, `indices`, `axis`. There is one result, `out`.
/// The output metadata is an unchanged copy of the `input` operand's metadata.
///
/// NOTE(review): no source/updates operand is declared here, unlike
/// `IndexAddOp`'s `source`. Confirm where the scattered values come from.
#[derive(Debug, Clone, Copy, Default)]
pub struct ScatterOp;
impl Operator for ScatterOp {
    fn name(&self) -> &'static str {
        "scatter"
    }

    /// Declares the operands `input`, `indices`, `axis` and the single result `out`.
    fn signature(&self) -> OpSignature {
        let operand = |name: &str| OpInput {
            name: name.to_owned(),
        };
        OpSignature {
            inputs: vec![operand("input"), operand("indices"), operand("axis")],
            outputs: vec![OpOutput {
                name: "out".to_owned(),
            }],
        }
    }

    /// Requires exactly three tensor operands.
    ///
    /// The output mirrors `input`.
    fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
        tensor_input_check(self.name(), inputs, 3)?;
        let updated = inputs[0].clone();
        Ok(vec![updated])
    }

    /// No contracts are required from upstream.
    fn required_contracts(&self) -> ContractSet {
        ContractSet::new()
    }

    /// Provides the deterministic and exact contracts, scoped to this operator's name.
    fn provided_contracts(&self) -> ContractSet {
        let scope = Scope::Operator(String::from(self.name()));
        deterministic_exact_contracts(scope)
    }

    /// Reports both `Pointwise` and `Global` layer behavior.
    fn layer_behavior(&self) -> Vec<LayerBehavior> {
        Vec::from([LayerBehavior::Pointwise, LayerBehavior::Global])
    }
}
/// `index_select` operator.
///
/// Operands (all must be tensors): `input`, `indices`, `axis`. There is one result, `out`.
/// The output keeps the `input` operand's metadata and rank. Every dimension is
/// marked data-dependent, because the selection axis is a runtime operand.
#[derive(Debug, Clone, Copy, Default)]
pub struct IndexSelectOp;
impl Operator for IndexSelectOp {
    fn name(&self) -> &'static str {
        "index_select"
    }

    /// Declares the operands `input`, `indices`, `axis` and the single result `out`.
    fn signature(&self) -> OpSignature {
        let operand_names = ["input", "indices", "axis"];
        OpSignature {
            inputs: operand_names
                .iter()
                .map(|&name| OpInput {
                    name: String::from(name),
                })
                .collect(),
            outputs: vec![OpOutput {
                name: String::from("out"),
            }],
        }
    }

    /// Requires exactly three tensor operands.
    ///
    /// Returns `input`'s metadata with a same-rank shape whose dimensions are all
    /// `Dim::DataDependent("index_select_axis")`.
    fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
        tensor_input_check(self.name(), inputs, 3)?;
        let data = &inputs[0];
        let rank = data.shape.rank();
        let mut selected = data.clone();
        // Which axis shrinks or grows is only known at runtime, so no individual
        // dimension can be kept. Only the rank carries over.
        let mut dims: Vec<Dim> = Vec::new();
        for _ in 0..rank {
            dims.push(Dim::DataDependent("index_select_axis".to_string()));
        }
        selected.shape = Shape::new(dims);
        Ok(vec![selected])
    }

    /// No contracts are required from upstream.
    fn required_contracts(&self) -> ContractSet {
        ContractSet::new()
    }

    /// Provides the deterministic and exact contracts, scoped to this operator's name.
    fn provided_contracts(&self) -> ContractSet {
        let scope = Scope::Operator(String::from(self.name()));
        deterministic_exact_contracts(scope)
    }

    /// Reports both `Pointwise` and `Global` layer behavior.
    fn layer_behavior(&self) -> Vec<LayerBehavior> {
        Vec::from([LayerBehavior::Pointwise, LayerBehavior::Global])
    }
}
/// `index_add` operator.
///
/// Operands (all must be tensors): `input`, `indices`, `source`, `axis`.
/// There is one result, `out`, whose metadata is an unchanged copy of `input`'s.
#[derive(Debug, Clone, Copy, Default)]
pub struct IndexAddOp;
impl Operator for IndexAddOp {
    fn name(&self) -> &'static str {
        "index_add"
    }

    /// Declares the operands `input`, `indices`, `source`, `axis` and the single result `out`.
    fn signature(&self) -> OpSignature {
        let operand = |name: &str| OpInput {
            name: name.to_owned(),
        };
        let inputs = vec![
            operand("input"),
            operand("indices"),
            operand("source"),
            operand("axis"),
        ];
        let outputs = vec![OpOutput {
            name: "out".to_owned(),
        }];
        OpSignature { inputs, outputs }
    }

    /// Requires exactly four tensor operands.
    ///
    /// The output mirrors `input`.
    fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
        tensor_input_check(self.name(), inputs, 4)?;
        // A one-element slice to_vec is exactly `vec![inputs[0].clone()]`.
        Ok(inputs[..1].to_vec())
    }

    /// No contracts are required from upstream.
    fn required_contracts(&self) -> ContractSet {
        ContractSet::new()
    }

    /// Provides the deterministic and exact contracts, scoped to this operator's name.
    fn provided_contracts(&self) -> ContractSet {
        let scope = Scope::Operator(String::from(self.name()));
        deterministic_exact_contracts(scope)
    }

    /// Reports both `Pointwise` and `Global` layer behavior.
    fn layer_behavior(&self) -> Vec<LayerBehavior> {
        Vec::from([LayerBehavior::Pointwise, LayerBehavior::Global])
    }
}
/// `nonzero` operator.
///
/// It takes a single tensor operand, `input`, and produces one result, `out`.
/// The output is a 2-D shape `[count, coords]`:
/// - `count` is data-dependent (`Dim::DataDependent("nonzero_count")`).
/// - `coords` is the number of coordinates reported per match.
#[derive(Debug, Clone, Copy, Default)]
pub struct NonzeroOp;
impl Operator for NonzeroOp {
    fn name(&self) -> &'static str {
        "nonzero"
    }

    /// Declares the single operand `input` and the single result `out`.
    fn signature(&self) -> OpSignature {
        OpSignature {
            inputs: vec![OpInput {
                name: "input".to_string(),
            }],
            outputs: vec![OpOutput {
                name: "out".to_string(),
            }],
        }
    }

    /// Requires exactly one tensor operand.
    ///
    /// Returns `input`'s metadata with its shape replaced by `[count, coords]`.
    fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
        tensor_input_check(self.name(), inputs, 1)?;
        // NOTE(review): apart from `shape`, the output inherits `input`'s metadata
        // (including `domain`). Confirm whether coordinate outputs should use an
        // index domain instead.
        let mut output = inputs[0].clone();
        let rank = inputs[0].shape.rank();
        // One coordinate is emitted per input axis, so the column count equals the
        // input rank. That is known statically here, even though the row count is
        // not. A rank-0 input keeps the previous single-column layout.
        let coords_per_match = rank.max(1);
        output.shape = Shape::new(vec![
            Dim::DataDependent("nonzero_count".to_string()),
            // `as _` adapts the rank to whichever integer type `Dim::Static` stores.
            Dim::Static(coords_per_match as _),
        ]);
        Ok(vec![output])
    }

    /// No contracts are required from upstream.
    fn required_contracts(&self) -> ContractSet {
        ContractSet::new()
    }

    /// Provides the deterministic and exact contracts, scoped to this operator's name.
    fn provided_contracts(&self) -> ContractSet {
        deterministic_exact_contracts(Scope::Operator(self.name().to_string()))
    }

    /// Reports both `Pointwise` and `Global` layer behavior.
    fn layer_behavior(&self) -> Vec<LayerBehavior> {
        vec![LayerBehavior::Pointwise, LayerBehavior::Global]
    }
}