use crate::domain::{Claim, Contract, ContractId, ContractSet, Evidence, Scope};
use crate::object::{Dim, ObjectKind, ObjectMeta, Shape};
use crate::{Error, Result};
use super::{LayerBehavior, OpInput, OpOutput, OpSignature, Operator};
/// Builds the contract set shared by the shape-manipulation operators in this file.
///
/// The set holds two contracts bound to `scope`, both backed by `Evidence::Axiom`:
/// contract 30 claiming `Claim::Deterministic` and contract 31 claiming `Claim::Exact`.
fn shape_arithmetic_contracts(scope: Scope) -> ContractSet {
    // The first contract needs its own copy of the scope; the second takes ownership.
    let deterministic = Contract::new(
        ContractId(30),
        Claim::Deterministic,
        scope.clone(),
        Evidence::Axiom,
    );
    let exact = Contract::new(ContractId(31), Claim::Exact, scope, Evidence::Axiom);
    ContractSet::from_iter([deterministic, exact])
}
fn tensor_input_check(op: &str, inputs: &[ObjectMeta], expected: usize) -> Result<()> {
if inputs.len() != expected {
return Err(Error::operator(format!(
"{op} expects {expected} input(s), got {}",
inputs.len()
)));
}
for (i, m) in inputs.iter().enumerate() {
if m.object_kind != ObjectKind::Tensor {
return Err(Error::operator(format!(
"{op} only supports tensor inputs, input {i} is {:?}",
m.object_kind
)));
}
}
Ok(())
}
fn parse_target_shape(meta: &ObjectMeta) -> Result<Vec<Dim>> {
if meta.object_kind != ObjectKind::Tensor {
return Err(Error::operator(format!(
"target shape input must be a tensor, got {:?}",
meta.object_kind
)));
}
Ok(meta.shape.dims.clone())
}
/// Shape-inference operator registered under the name `"reshape"`.
///
/// Declares two inputs (`input`, `target_shape`) and one output (`out`).
#[derive(Debug, Clone, Copy, Default)]
pub struct ReshapeOp;

impl Operator for ReshapeOp {
    /// Returns the operator name, `"reshape"`.
    fn name(&self) -> &'static str {
        "reshape"
    }

    /// Declares the `input` and `target_shape` inputs and the single `out` output.
    fn signature(&self) -> OpSignature {
        let inputs = vec![
            OpInput {
                name: String::from("input"),
            },
            OpInput {
                name: String::from("target_shape"),
            },
        ];
        let outputs = vec![OpOutput {
            name: String::from("out"),
        }];
        OpSignature { inputs, outputs }
    }

    /// Infers the output metadata for a reshape.
    ///
    /// The output is a copy of the first input's metadata whose shape is replaced
    /// by the dims obtained from the second input via `parse_target_shape`.
    ///
    /// NOTE(review): no check is made here that the target dims describe the same
    /// number of elements as the input — confirm whether that is validated elsewhere.
    ///
    /// # Errors
    ///
    /// Fails when there are not exactly two inputs or when either is not a tensor.
    fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
        tensor_input_check(self.name(), inputs, 2)?;
        let source = &inputs[0];
        let requested = parse_target_shape(&inputs[1])?;
        let mut reshaped = source.clone();
        reshaped.shape = Shape::new(requested);
        Ok(vec![reshaped])
    }

    /// This operator requires no contracts; returns an empty set.
    fn required_contracts(&self) -> ContractSet {
        ContractSet::new()
    }

    /// Returns the shape-arithmetic contracts scoped to this operator's name.
    fn provided_contracts(&self) -> ContractSet {
        let scope = Scope::Operator(self.name().to_owned());
        shape_arithmetic_contracts(scope)
    }

    /// Reports the `Pointwise` and `Global` layer behaviors, in that order.
    fn layer_behavior(&self) -> Vec<LayerBehavior> {
        vec![
            LayerBehavior::Pointwise,
            LayerBehavior::Global,
        ]
    }
}
/// Shape-inference operator registered under the name `"transpose"`.
///
/// Declares two inputs (`input`, `axes`) and one output (`out`).
#[derive(Debug, Clone, Copy, Default)]
pub struct TransposeOp;

impl Operator for TransposeOp {
    /// Returns the operator name, `"transpose"`.
    fn name(&self) -> &'static str {
        "transpose"
    }

    /// Declares the `input` and `axes` inputs and the single `out` output.
    fn signature(&self) -> OpSignature {
        let inputs = vec![
            OpInput {
                name: String::from("input"),
            },
            OpInput {
                name: String::from("axes"),
            },
        ];
        let outputs = vec![OpOutput {
            name: String::from("out"),
        }];
        OpSignature { inputs, outputs }
    }

    /// Infers the output metadata for a transpose.
    ///
    /// The output is a copy of the first input's metadata with its shape rebuilt
    /// from the first `rank` dims of the input, in their original order.
    ///
    /// NOTE(review): the `axes` input is only validated for count/kind and the
    /// output dims are not reordered — confirm whether this is an intentional
    /// placeholder.
    ///
    /// # Errors
    ///
    /// Fails when there are not exactly two inputs or when either is not a tensor.
    fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
        tensor_input_check(self.name(), inputs, 2)?;
        let source = &inputs[0];
        let rank = source.shape.rank();
        let mut transposed = source.clone();
        transposed.shape = Shape::new(source.shape.dims[..rank].to_vec());
        Ok(vec![transposed])
    }

    /// This operator requires no contracts; returns an empty set.
    fn required_contracts(&self) -> ContractSet {
        ContractSet::new()
    }

    /// Returns the shape-arithmetic contracts scoped to this operator's name.
    fn provided_contracts(&self) -> ContractSet {
        let scope = Scope::Operator(self.name().to_owned());
        shape_arithmetic_contracts(scope)
    }

    /// Reports the `Pointwise` and `Global` layer behaviors, in that order.
    fn layer_behavior(&self) -> Vec<LayerBehavior> {
        vec![
            LayerBehavior::Pointwise,
            LayerBehavior::Global,
        ]
    }
}
/// Shape-inference operator registered under the name `"permute"`.
///
/// Declares two inputs (`input`, `permutation`) and one output (`out`).
#[derive(Debug, Clone, Copy, Default)]
pub struct PermuteOp;

impl Operator for PermuteOp {
    /// Returns the operator name, `"permute"`.
    fn name(&self) -> &'static str {
        "permute"
    }

    /// Declares the `input` and `permutation` inputs and the single `out` output.
    fn signature(&self) -> OpSignature {
        let inputs = vec![
            OpInput {
                name: String::from("input"),
            },
            OpInput {
                name: String::from("permutation"),
            },
        ];
        let outputs = vec![OpOutput {
            name: String::from("out"),
        }];
        OpSignature { inputs, outputs }
    }

    /// Infers the output metadata for a permute.
    ///
    /// The output is a copy of the first input's metadata with its shape rebuilt
    /// from the first `rank` dims of the input, in their original order.
    ///
    /// NOTE(review): the `permutation` input is only validated for count/kind and
    /// is not applied to the dims — confirm whether this is an intentional
    /// placeholder.
    ///
    /// # Errors
    ///
    /// Fails when there are not exactly two inputs or when either is not a tensor.
    fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
        tensor_input_check(self.name(), inputs, 2)?;
        let source = &inputs[0];
        let rank = source.shape.rank();
        let mut permuted = source.clone();
        let dims: Vec<Dim> = source.shape.dims[..rank].to_vec();
        permuted.shape = Shape::new(dims);
        Ok(vec![permuted])
    }

    /// This operator requires no contracts; returns an empty set.
    fn required_contracts(&self) -> ContractSet {
        ContractSet::new()
    }

    /// Returns the shape-arithmetic contracts scoped to this operator's name.
    fn provided_contracts(&self) -> ContractSet {
        let scope = Scope::Operator(self.name().to_owned());
        shape_arithmetic_contracts(scope)
    }

    /// Reports the `Pointwise` and `Global` layer behaviors, in that order.
    fn layer_behavior(&self) -> Vec<LayerBehavior> {
        vec![
            LayerBehavior::Pointwise,
            LayerBehavior::Global,
        ]
    }
}
/// Shape-inference operator registered under the name `"slice"`.
///
/// Declares two inputs (`input`, `bounds`) and one output (`out`).
#[derive(Debug, Clone, Copy, Default)]
pub struct SliceOp;

impl Operator for SliceOp {
    /// Returns the operator name, `"slice"`.
    fn name(&self) -> &'static str {
        "slice"
    }

    /// Declares the `input` and `bounds` inputs and the single `out` output.
    fn signature(&self) -> OpSignature {
        let inputs = vec![
            OpInput {
                name: String::from("input"),
            },
            OpInput {
                name: String::from("bounds"),
            },
        ];
        let outputs = vec![OpOutput {
            name: String::from("out"),
        }];
        OpSignature { inputs, outputs }
    }

    /// Infers the output metadata for a slice.
    ///
    /// The output is a copy of the first input's metadata with its shape rebuilt
    /// from the first `rank` dims of the input, unchanged.
    ///
    /// NOTE(review): the `bounds` input is only validated for count/kind and does
    /// not shrink any dim here — confirm whether this is an intentional placeholder.
    ///
    /// # Errors
    ///
    /// Fails when there are not exactly two inputs or when either is not a tensor.
    fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
        tensor_input_check(self.name(), inputs, 2)?;
        let source = &inputs[0];
        let mut sliced = source.clone();
        let rank = source.shape.rank();
        sliced.shape = Shape::new(source.shape.dims[..rank].to_vec());
        Ok(vec![sliced])
    }

    /// This operator requires no contracts; returns an empty set.
    fn required_contracts(&self) -> ContractSet {
        ContractSet::new()
    }

    /// Returns the shape-arithmetic contracts scoped to this operator's name.
    fn provided_contracts(&self) -> ContractSet {
        let scope = Scope::Operator(self.name().to_owned());
        shape_arithmetic_contracts(scope)
    }

    /// Reports the `Pointwise` and `Global` layer behaviors, in that order.
    fn layer_behavior(&self) -> Vec<LayerBehavior> {
        vec![
            LayerBehavior::Pointwise,
            LayerBehavior::Global,
        ]
    }
}
#[derive(Debug, Clone, Copy, Default)]
pub struct ConcatOp;
impl Operator for ConcatOp {
fn name(&self) -> &'static str {
"concat"
}
fn signature(&self) -> OpSignature {
OpSignature {
inputs: vec![OpInput {
name: "first".to_string(),
}],
outputs: vec![OpOutput {
name: "out".to_string(),
}],
}
}
fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
if inputs.is_empty() {
return Err(Error::operator("concat expects at least 1 input"));
}
for (i, m) in inputs.iter().enumerate() {
if m.object_kind != ObjectKind::Tensor {
return Err(Error::operator(format!(
"concat only supports tensor inputs, input {i} is {:?}",
m.object_kind
)));
}
}
Ok(vec![inputs[0].clone()])
}
fn required_contracts(&self) -> ContractSet {
ContractSet::new()
}
fn provided_contracts(&self) -> ContractSet {
shape_arithmetic_contracts(Scope::Operator(self.name().to_string()))
}
fn layer_behavior(&self) -> Vec<LayerBehavior> {
vec![LayerBehavior::Pointwise, LayerBehavior::Global]
}
}
/// Shape-inference operator registered under the name `"broadcast"`.
///
/// Declares two inputs (`input`, `target_shape`) and one output (`out`).
#[derive(Debug, Clone, Copy, Default)]
pub struct BroadcastOp;

impl Operator for BroadcastOp {
    /// Returns the operator name, `"broadcast"`.
    fn name(&self) -> &'static str {
        "broadcast"
    }

    /// Declares the `input` and `target_shape` inputs and the single `out` output.
    fn signature(&self) -> OpSignature {
        let inputs = vec![
            OpInput {
                name: String::from("input"),
            },
            OpInput {
                name: String::from("target_shape"),
            },
        ];
        let outputs = vec![OpOutput {
            name: String::from("out"),
        }];
        OpSignature { inputs, outputs }
    }

    /// Infers the output metadata for a broadcast.
    ///
    /// The output is a copy of the first input's metadata whose shape is replaced
    /// by the dims obtained from the second input via `parse_target_shape`.
    ///
    /// NOTE(review): the input dims are not checked for broadcast compatibility
    /// with the target dims here — confirm whether that is validated elsewhere.
    ///
    /// # Errors
    ///
    /// Fails when there are not exactly two inputs or when either is not a tensor.
    fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
        tensor_input_check(self.name(), inputs, 2)?;
        let source = &inputs[0];
        let requested = parse_target_shape(&inputs[1])?;
        let mut broadcasted = source.clone();
        broadcasted.shape = Shape::new(requested);
        Ok(vec![broadcasted])
    }

    /// This operator requires no contracts; returns an empty set.
    fn required_contracts(&self) -> ContractSet {
        ContractSet::new()
    }

    /// Returns the shape-arithmetic contracts scoped to this operator's name.
    fn provided_contracts(&self) -> ContractSet {
        let scope = Scope::Operator(self.name().to_owned());
        shape_arithmetic_contracts(scope)
    }

    /// Reports the `Pointwise` and `Global` layer behaviors, in that order.
    fn layer_behavior(&self) -> Vec<LayerBehavior> {
        vec![
            LayerBehavior::Pointwise,
            LayerBehavior::Global,
        ]
    }
}
#[derive(Debug, Clone, Copy, Default)]
pub struct FlattenOp;
impl Operator for FlattenOp {
fn name(&self) -> &'static str {
"flatten"
}
fn signature(&self) -> OpSignature {
unary_signature_like()
}
fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
if inputs.len() != 1 {
return Err(Error::operator(format!(
"flatten expects 1 input, got {}",
inputs.len()
)));
}
let m = &inputs[0];
if m.object_kind != ObjectKind::Tensor {
return Err(Error::operator(format!(
"flatten only supports tensor inputs, got {:?}",
m.object_kind
)));
}
let mut output = m.clone();
let product: Option<usize> = if m.shape.dims.is_empty() {
Some(1)
} else {
m.shape
.dims
.iter()
.try_fold(1usize, |acc, d| d.value().map(|v| acc * v))
};
let flat_dim = match product {
Some(n) => Dim::Static(n),
None => Dim::DataDependent("flatten_product".to_string()),
};
output.shape = Shape::new(vec![flat_dim]);
Ok(vec![output])
}
fn required_contracts(&self) -> ContractSet {
ContractSet::new()
}
fn provided_contracts(&self) -> ContractSet {
shape_arithmetic_contracts(Scope::Operator(self.name().to_string()))
}
fn layer_behavior(&self) -> Vec<LayerBehavior> {
vec![LayerBehavior::Pointwise, LayerBehavior::Global]
}
}
/// Shape-inference operator registered under the name `"squeeze"`.
///
/// Takes one tensor input and produces one output with every `Dim::Static(1)`
/// removed from the shape.
#[derive(Debug, Clone, Copy, Default)]
pub struct SqueezeOp;

impl Operator for SqueezeOp {
    /// Returns the operator name, `"squeeze"`.
    fn name(&self) -> &'static str {
        "squeeze"
    }

    /// Declares one input (`input`) and one output (`out`).
    fn signature(&self) -> OpSignature {
        let inputs = vec![OpInput {
            name: String::from("input"),
        }];
        let outputs = vec![OpOutput {
            name: String::from("out"),
        }];
        OpSignature { inputs, outputs }
    }

    /// Infers the output metadata for a squeeze.
    ///
    /// The output is a copy of the input's metadata whose shape keeps, in order,
    /// every input dim that is not exactly `Dim::Static(1)`.
    ///
    /// # Errors
    ///
    /// Fails when there is not exactly one input or when the input is not a tensor.
    fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
        let m = match inputs {
            [only] => only,
            _ => {
                return Err(Error::operator(format!(
                    "squeeze expects 1 input, got {}",
                    inputs.len()
                )))
            }
        };
        if m.object_kind != ObjectKind::Tensor {
            return Err(Error::operator(format!(
                "squeeze only supports tensor inputs, got {:?}",
                m.object_kind
            )));
        }
        let mut squeezed = m.clone();
        // Copy the dims, then drop the unit-sized static ones in place.
        let mut kept: Vec<Dim> = m.shape.dims.clone();
        kept.retain(|d| !matches!(d, Dim::Static(1)));
        squeezed.shape = Shape::new(kept);
        Ok(vec![squeezed])
    }

    /// This operator requires no contracts; returns an empty set.
    fn required_contracts(&self) -> ContractSet {
        ContractSet::new()
    }

    /// Returns the shape-arithmetic contracts scoped to this operator's name.
    fn provided_contracts(&self) -> ContractSet {
        let scope = Scope::Operator(self.name().to_owned());
        shape_arithmetic_contracts(scope)
    }

    /// Reports the `Pointwise` and `Global` layer behaviors, in that order.
    fn layer_behavior(&self) -> Vec<LayerBehavior> {
        vec![
            LayerBehavior::Pointwise,
            LayerBehavior::Global,
        ]
    }
}
/// Shape-inference operator registered under the name `"unsqueeze"`.
///
/// Declares two inputs (`input`, `axis`) and one output (`out`).
#[derive(Debug, Clone, Copy, Default)]
pub struct UnsqueezeOp;

impl Operator for UnsqueezeOp {
    /// Returns the operator name, `"unsqueeze"`.
    fn name(&self) -> &'static str {
        "unsqueeze"
    }

    /// Declares the `input` and `axis` inputs and the single `out` output.
    fn signature(&self) -> OpSignature {
        let inputs = vec![
            OpInput {
                name: String::from("input"),
            },
            OpInput {
                name: String::from("axis"),
            },
        ];
        let outputs = vec![OpOutput {
            name: String::from("out"),
        }];
        OpSignature { inputs, outputs }
    }

    /// Infers the output metadata for an unsqueeze.
    ///
    /// The output is a copy of the first input's metadata whose shape is the
    /// input's dims with one `Dim::Static(1)` inserted at index `rank`.
    ///
    /// NOTE(review): the `axis` input is only validated for count/kind; the new
    /// dim is always inserted at index `rank` — confirm whether this is an
    /// intentional placeholder.
    ///
    /// # Errors
    ///
    /// Fails when there are not exactly two inputs or when either is not a tensor.
    fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
        tensor_input_check(self.name(), inputs, 2)?;
        let source = &inputs[0];
        let insert_at = source.shape.rank();
        let mut expanded = source.clone();
        let mut dims = source.shape.dims.clone();
        dims.insert(insert_at, Dim::Static(1));
        expanded.shape = Shape::new(dims);
        Ok(vec![expanded])
    }

    /// This operator requires no contracts; returns an empty set.
    fn required_contracts(&self) -> ContractSet {
        ContractSet::new()
    }

    /// Returns the shape-arithmetic contracts scoped to this operator's name.
    fn provided_contracts(&self) -> ContractSet {
        let scope = Scope::Operator(self.name().to_owned());
        shape_arithmetic_contracts(scope)
    }

    /// Reports the `Pointwise` and `Global` layer behaviors, in that order.
    fn layer_behavior(&self) -> Vec<LayerBehavior> {
        vec![
            LayerBehavior::Pointwise,
            LayerBehavior::Global,
        ]
    }
}
/// Builds the signature of a unary operator: one input named `input` and one
/// output named `out`.
fn unary_signature_like() -> OpSignature {
    let inputs = vec![OpInput {
        name: String::from("input"),
    }];
    let outputs = vec![OpOutput {
        name: String::from("out"),
    }];
    OpSignature { inputs, outputs }
}