use crate::domain::{Claim, Contract, ContractId, ContractSet, Evidence, Scope};
use crate::object::{Dim, ObjectKind, ObjectMeta, Shape};
use crate::{Error, Result};
use super::{LayerBehavior, OpInput, OpOutput, OpSignature, Operator};
/// Builds the signature used by the single-input operators in this file:
/// exactly one input named `"input"` and exactly one output named `"out"`.
fn unary_signature() -> OpSignature {
    let input = OpInput {
        name: String::from("input"),
    };
    let output = OpOutput {
        name: String::from("out"),
    };
    OpSignature {
        inputs: vec![input],
        outputs: vec![output],
    }
}
fn tensor_input_check(op: &str, inputs: &[ObjectMeta], expected: usize) -> Result<()> {
if inputs.len() != expected {
return Err(Error::operator(format!(
"{op} expects {expected} input(s), got {}",
inputs.len()
)));
}
for (i, m) in inputs.iter().enumerate() {
if m.object_kind != ObjectKind::Tensor {
return Err(Error::operator(format!(
"{op} only supports tensor inputs, input {i} is {:?}",
m.object_kind
)));
}
}
Ok(())
}
/// Builds the contract pair "deterministic + exact" for the given `scope`.
///
/// The set holds two contracts, both backed by `Evidence::Axiom`:
/// `ContractId(40)` claiming `Claim::Deterministic` and `ContractId(41)`
/// claiming `Claim::Exact`.
fn deterministic_exact_contracts(scope: Scope) -> ContractSet {
    // The scope is needed twice, so the first contract takes a clone and the
    // second consumes the original.
    let deterministic = Contract::new(
        ContractId(40),
        Claim::Deterministic,
        scope.clone(),
        Evidence::Axiom,
    );
    let exact = Contract::new(ContractId(41), Claim::Exact, scope, Evidence::Axiom);
    ContractSet::from_iter([deterministic, exact])
}
/// Builds the contract pair "deterministic + approximate" for the given `scope`.
///
/// The set holds two contracts, both backed by `Evidence::Axiom`:
/// `ContractId(42)` claiming `Claim::Deterministic` and `ContractId(43)`
/// claiming `Claim::Approximate`.
fn deterministic_approximate_contracts(scope: Scope) -> ContractSet {
    // The scope is needed twice, so the first contract takes a clone and the
    // second consumes the original.
    let deterministic = Contract::new(
        ContractId(42),
        Claim::Deterministic,
        scope.clone(),
        Evidence::Axiom,
    );
    let approximate = Contract::new(ContractId(43), Claim::Approximate, scope, Evidence::Axiom);
    ContractSet::from_iter([deterministic, approximate])
}
/// The `relu` operator: one tensor in, one output whose metadata is a copy of
/// the input's metadata.
#[derive(Debug, Clone, Copy, Default)]
pub struct ReluOp;

impl Operator for ReluOp {
    /// Operator identifier, also used in error messages and as contract scope.
    fn name(&self) -> &'static str {
        "relu"
    }

    /// One input (`"input"`) and one output (`"out"`).
    fn signature(&self) -> OpSignature {
        unary_signature()
    }

    /// Requires exactly one tensor input and returns a clone of its metadata.
    ///
    /// # Errors
    ///
    /// Propagates the error from `tensor_input_check` when the input count is
    /// not 1 or the input is not a tensor.
    fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
        tensor_input_check(self.name(), inputs, 1)?;
        let passthrough = inputs[0].clone();
        Ok(vec![passthrough])
    }

    /// Returns a freshly constructed `ContractSet`; nothing is added to it.
    fn required_contracts(&self) -> ContractSet {
        ContractSet::new()
    }

    /// Deterministic and exact contracts, scoped to this operator's name.
    fn provided_contracts(&self) -> ContractSet {
        let scope = Scope::Operator(self.name().to_string());
        deterministic_exact_contracts(scope)
    }

    /// Reports `LayerBehavior::Pointwise` only.
    fn layer_behavior(&self) -> Vec<LayerBehavior> {
        vec![LayerBehavior::Pointwise]
    }
}
/// The `sigmoid` operator: one tensor in, one output whose metadata is a copy
/// of the input's metadata.
#[derive(Debug, Clone, Copy, Default)]
pub struct SigmoidOp;

impl Operator for SigmoidOp {
    /// Operator identifier, also used in error messages and as contract scope.
    fn name(&self) -> &'static str {
        "sigmoid"
    }

    /// One input (`"input"`) and one output (`"out"`).
    fn signature(&self) -> OpSignature {
        unary_signature()
    }

    /// Requires exactly one tensor input and returns a clone of its metadata.
    ///
    /// # Errors
    ///
    /// Propagates the error from `tensor_input_check` when the input count is
    /// not 1 or the input is not a tensor.
    fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
        tensor_input_check(self.name(), inputs, 1)?;
        let passthrough = inputs[0].clone();
        Ok(vec![passthrough])
    }

    /// Returns a freshly constructed `ContractSet`; nothing is added to it.
    fn required_contracts(&self) -> ContractSet {
        ContractSet::new()
    }

    /// Deterministic and approximate contracts, scoped to this operator's name.
    fn provided_contracts(&self) -> ContractSet {
        let scope = Scope::Operator(self.name().to_string());
        deterministic_approximate_contracts(scope)
    }

    /// Reports `LayerBehavior::Pointwise` only.
    fn layer_behavior(&self) -> Vec<LayerBehavior> {
        vec![LayerBehavior::Pointwise]
    }
}
/// The `tanh` operator: one tensor in, one output whose metadata is a copy of
/// the input's metadata.
#[derive(Debug, Clone, Copy, Default)]
pub struct TanhOp;

impl Operator for TanhOp {
    /// Operator identifier, also used in error messages and as contract scope.
    fn name(&self) -> &'static str {
        "tanh"
    }

    /// One input (`"input"`) and one output (`"out"`).
    fn signature(&self) -> OpSignature {
        unary_signature()
    }

    /// Requires exactly one tensor input and returns a clone of its metadata.
    ///
    /// # Errors
    ///
    /// Propagates the error from `tensor_input_check` when the input count is
    /// not 1 or the input is not a tensor.
    fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
        tensor_input_check(self.name(), inputs, 1)?;
        let passthrough = inputs[0].clone();
        Ok(vec![passthrough])
    }

    /// Returns a freshly constructed `ContractSet`; nothing is added to it.
    fn required_contracts(&self) -> ContractSet {
        ContractSet::new()
    }

    /// Deterministic and approximate contracts, scoped to this operator's name.
    fn provided_contracts(&self) -> ContractSet {
        let scope = Scope::Operator(self.name().to_string());
        deterministic_approximate_contracts(scope)
    }

    /// Reports `LayerBehavior::Pointwise` only.
    fn layer_behavior(&self) -> Vec<LayerBehavior> {
        vec![LayerBehavior::Pointwise]
    }
}
/// The `gelu` operator: one tensor in, one output whose metadata is a copy of
/// the input's metadata.
#[derive(Debug, Clone, Copy, Default)]
pub struct GeluOp;

impl Operator for GeluOp {
    /// Operator identifier, also used in error messages and as contract scope.
    fn name(&self) -> &'static str {
        "gelu"
    }

    /// One input (`"input"`) and one output (`"out"`).
    fn signature(&self) -> OpSignature {
        unary_signature()
    }

    /// Requires exactly one tensor input and returns a clone of its metadata.
    ///
    /// # Errors
    ///
    /// Propagates the error from `tensor_input_check` when the input count is
    /// not 1 or the input is not a tensor.
    fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
        tensor_input_check(self.name(), inputs, 1)?;
        let passthrough = inputs[0].clone();
        Ok(vec![passthrough])
    }

    /// Returns a freshly constructed `ContractSet`; nothing is added to it.
    fn required_contracts(&self) -> ContractSet {
        ContractSet::new()
    }

    /// Deterministic and approximate contracts, scoped to this operator's name.
    fn provided_contracts(&self) -> ContractSet {
        let scope = Scope::Operator(self.name().to_string());
        deterministic_approximate_contracts(scope)
    }

    /// Reports `LayerBehavior::Pointwise` only.
    fn layer_behavior(&self) -> Vec<LayerBehavior> {
        vec![LayerBehavior::Pointwise]
    }
}
/// The `softmax` operator: one tensor in, one output whose metadata is a copy
/// of the input's metadata.
#[derive(Debug, Clone, Copy, Default)]
pub struct SoftmaxOp;

impl Operator for SoftmaxOp {
    /// Operator identifier, also used in error messages and as contract scope.
    fn name(&self) -> &'static str {
        "softmax"
    }

    /// One input (`"input"`) and one output (`"out"`).
    fn signature(&self) -> OpSignature {
        unary_signature()
    }

    /// Requires exactly one tensor input and returns a clone of its metadata.
    ///
    /// No rank check is performed here; any tensor metadata is accepted.
    ///
    /// # Errors
    ///
    /// Propagates the error from `tensor_input_check` when the input count is
    /// not 1 or the input is not a tensor.
    fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
        tensor_input_check(self.name(), inputs, 1)?;
        let passthrough = inputs[0].clone();
        Ok(vec![passthrough])
    }

    /// Returns a freshly constructed `ContractSet`; nothing is added to it.
    fn required_contracts(&self) -> ContractSet {
        ContractSet::new()
    }

    /// Deterministic and approximate contracts, scoped to this operator's name.
    fn provided_contracts(&self) -> ContractSet {
        let scope = Scope::Operator(self.name().to_string());
        deterministic_approximate_contracts(scope)
    }

    /// Reports both `LayerBehavior::Pointwise` and `LayerBehavior::Global`,
    /// in that order.
    fn layer_behavior(&self) -> Vec<LayerBehavior> {
        let behaviors = vec![LayerBehavior::Pointwise, LayerBehavior::Global];
        behaviors
    }
}
/// The `layer_norm` operator: three tensor inputs (`input`, `gamma`, `beta`)
/// and one output whose metadata is derived from the first input.
#[derive(Debug, Clone, Copy, Default)]
pub struct LayerNormOp;

impl Operator for LayerNormOp {
    /// Operator identifier, also used in error messages and as contract scope.
    fn name(&self) -> &'static str {
        "layer_norm"
    }

    /// Inputs `"input"`, `"gamma"`, `"beta"` (in that order); output `"out"`.
    fn signature(&self) -> OpSignature {
        let inputs = ["input", "gamma", "beta"]
            .iter()
            .map(|name| OpInput {
                name: name.to_string(),
            })
            .collect();
        let outputs = vec![OpOutput {
            name: "out".to_string(),
        }];
        OpSignature { inputs, outputs }
    }

    /// Infers the output metadata from the first input.
    ///
    /// The output is a clone of `inputs[0]` whose shape is rebuilt from the
    /// first `rank` dimensions of that input. `gamma` and `beta` are only
    /// checked for being tensors; their shapes are not inspected here.
    ///
    /// # Errors
    ///
    /// Propagates the error from `tensor_input_check` when the input count is
    /// not 3 or any input is not a tensor, and returns `Error::operator` when
    /// the first input has rank 0.
    fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
        tensor_input_check(self.name(), inputs, 3)?;

        let source = &inputs[0];
        let rank = source.shape.rank();
        if rank == 0 {
            return Err(Error::operator(
                "layer_norm requires at least 1-D input (last axis is the normalization axis)",
            ));
        }

        // NOTE(review): this copies every dimension of the input one by one;
        // it looks equivalent to keeping the cloned shape as-is, provided
        // `rank()` equals the number of dims and `Shape::new` does no
        // normalization — TODO confirm before simplifying.
        let dims: Vec<Dim> = (0..rank)
            .map(|axis| source.shape.dims[axis].clone())
            .collect();

        let mut output = source.clone();
        output.shape = Shape::new(dims);
        Ok(vec![output])
    }

    /// Returns a freshly constructed `ContractSet`; nothing is added to it.
    fn required_contracts(&self) -> ContractSet {
        ContractSet::new()
    }

    /// Deterministic and approximate contracts, scoped to this operator's name.
    fn provided_contracts(&self) -> ContractSet {
        let scope = Scope::Operator(self.name().to_string());
        deterministic_approximate_contracts(scope)
    }

    /// Reports both `LayerBehavior::Pointwise` and `LayerBehavior::Global`,
    /// in that order.
    fn layer_behavior(&self) -> Vec<LayerBehavior> {
        let behaviors = vec![LayerBehavior::Pointwise, LayerBehavior::Global];
        behaviors
    }
}