use crate::domain::{Claim, Contract, ContractId, ContractSet, Evidence, Scope};
use crate::object::{ObjectKind, ObjectMeta};
use crate::{Error, Result};
use super::{LayerBehavior, OpInput, OpOutput, OpSignature, Operator};
fn binary_signature() -> OpSignature {
OpSignature {
inputs: vec![
OpInput {
name: "lhs".to_string(),
},
OpInput {
name: "rhs".to_string(),
},
],
outputs: vec![OpOutput {
name: "out".to_string(),
}],
}
}
/// Builds the signature shared by the one-operand operators in this file:
/// a single input named `input` and a single output named `out`.
fn unary_signature() -> OpSignature {
    let inputs = vec![OpInput {
        name: String::from("input"),
    }];
    let outputs = vec![OpOutput {
        name: String::from("out"),
    }];
    OpSignature { inputs, outputs }
}
/// Infers the output metadata of a two-operand operator whose operands must
/// agree with each other.
///
/// `op` is the operator name; it is only used as the prefix of error messages.
///
/// # Errors
///
/// Checks run in this order and the first failure is returned:
/// - operator error when `inputs` does not contain exactly two entries;
/// - operator error when either entry is not `ObjectKind::Tensor`;
/// - domain error when the two domains differ;
/// - whatever `Shape::ensure_same` reports for the two shapes.
///
/// On success the result holds one entry: a clone of the first input's
/// metadata.
fn infer_binary_same_shape(op: &str, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
    let (lhs, rhs) = match inputs {
        [lhs, rhs] => (lhs, rhs),
        _ => {
            return Err(Error::operator(format!(
                "{op} expects 2 inputs, got {}",
                inputs.len()
            )));
        }
    };

    let both_tensors =
        lhs.object_kind == ObjectKind::Tensor && rhs.object_kind == ObjectKind::Tensor;
    if !both_tensors {
        return Err(Error::operator(format!("{op} only supports tensor inputs")));
    }

    if lhs.domain != rhs.domain {
        return Err(Error::domain(format!(
            "{op} domain mismatch: left={:?}, right={:?}",
            lhs.domain, rhs.domain
        )));
    }

    lhs.shape.ensure_same(&rhs.shape)?;
    Ok(vec![lhs.clone()])
}
/// Infers the output metadata of `matmul`.
///
/// # Errors
///
/// Checks run in this order and the first failure is returned:
/// - operator error when `inputs` does not contain exactly two entries;
/// - operator error when either entry is not `ObjectKind::Tensor`;
/// - domain error when the two domains differ;
/// - shape error when either shape is not rank 2;
/// - whatever `Shape::ensure_dim_proves_equal` reports when relating
///   dimension 1 of the left shape to dimension 0 of the right shape.
///
/// On success the result holds one entry: a clone of the left input's
/// metadata whose shape is replaced by `[left.dims[0], right.dims[1]]`.
fn infer_matmul(inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
    let (left, right) = match inputs {
        [left, right] => (left, right),
        _ => {
            return Err(Error::operator(format!(
                "matmul expects 2 inputs, got {}",
                inputs.len()
            )));
        }
    };

    let both_tensors =
        left.object_kind == ObjectKind::Tensor && right.object_kind == ObjectKind::Tensor;
    if !both_tensors {
        return Err(Error::operator("matmul only supports tensor inputs"));
    }

    if left.domain != right.domain {
        return Err(Error::domain(format!(
            "matmul domain mismatch: left={:?}, right={:?}",
            left.domain, right.domain
        )));
    }

    let (left_rank, right_rank) = (left.shape.rank(), right.shape.rank());
    if left_rank != 2 || right_rank != 2 {
        return Err(Error::shape(format!(
            "matmul expects rank-2 tensors, got left rank {} and right rank {}",
            left_rank, right_rank
        )));
    }

    // The contracted dimensions must be shown equal before the output shape is built.
    left.shape.ensure_dim_proves_equal(1, &right.shape, 0)?;

    let rows = left.shape.dims[0].clone();
    let cols = right.shape.dims[1].clone();
    let mut product = left.clone();
    product.shape = crate::object::Shape::new(vec![rows, cols]);
    Ok(vec![product])
}
/// Builds the contract set shared by the arithmetic operators in this file.
///
/// The set holds two contracts over `scope`, both backed by `Evidence::Axiom`:
/// `ContractId(1)` claiming `Claim::Deterministic` and `ContractId(2)`
/// claiming `Claim::Exact`.
fn exact_arithmetic_contracts(scope: Scope) -> ContractSet {
    let deterministic = Contract::new(
        ContractId(1),
        Claim::Deterministic,
        scope.clone(),
        Evidence::Axiom,
    );
    let exact = Contract::new(ContractId(2), Claim::Exact, scope, Evidence::Axiom);
    ContractSet::from_iter([deterministic, exact])
}
/// Operator reported under the name `"add"`.
///
/// Uses the shared binary signature, validates its operands through
/// `infer_binary_same_shape`, requires no contracts and provides the set
/// built by `exact_arithmetic_contracts`.
#[derive(Debug, Clone, Copy, Default)]
pub struct AddOp;

impl Operator for AddOp {
    /// Name used for error messages and contract scopes.
    fn name(&self) -> &'static str {
        "add"
    }

    /// Two inputs (`lhs`, `rhs`) and one output (`out`).
    fn signature(&self) -> OpSignature {
        binary_signature()
    }

    /// Delegates to `infer_binary_same_shape` with this operator's name.
    fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
        let op_name = self.name();
        infer_binary_same_shape(op_name, inputs)
    }

    /// No contracts are required.
    fn required_contracts(&self) -> ContractSet {
        ContractSet::new()
    }

    /// Deterministic and exact contracts scoped to this operator.
    fn provided_contracts(&self) -> ContractSet {
        let scope = Scope::Operator(self.name().to_owned());
        exact_arithmetic_contracts(scope)
    }

    /// Reports `Pointwise` and `CoverLocal`.
    fn layer_behavior(&self) -> Vec<LayerBehavior> {
        Vec::from([LayerBehavior::Pointwise, LayerBehavior::CoverLocal])
    }
}
/// Operator reported under the name `"mul"`.
///
/// Uses the shared binary signature, validates its operands through
/// `infer_binary_same_shape`, requires no contracts and provides the set
/// built by `exact_arithmetic_contracts`.
#[derive(Debug, Clone, Copy, Default)]
pub struct MulOp;

impl Operator for MulOp {
    /// Name used for error messages and contract scopes.
    fn name(&self) -> &'static str {
        "mul"
    }

    /// Two inputs (`lhs`, `rhs`) and one output (`out`).
    fn signature(&self) -> OpSignature {
        binary_signature()
    }

    /// Delegates to `infer_binary_same_shape` with this operator's name.
    fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
        let op_name = self.name();
        infer_binary_same_shape(op_name, inputs)
    }

    /// No contracts are required.
    fn required_contracts(&self) -> ContractSet {
        ContractSet::new()
    }

    /// Deterministic and exact contracts scoped to this operator.
    fn provided_contracts(&self) -> ContractSet {
        let scope = Scope::Operator(self.name().to_owned());
        exact_arithmetic_contracts(scope)
    }

    /// Reports `Pointwise`, `CoverLocal` and `ValuationFiltered`.
    fn layer_behavior(&self) -> Vec<LayerBehavior> {
        Vec::from([
            LayerBehavior::Pointwise,
            LayerBehavior::CoverLocal,
            LayerBehavior::ValuationFiltered,
        ])
    }
}
#[derive(Debug, Clone, Copy, Default)]
pub struct MapOp;
impl Operator for MapOp {
fn name(&self) -> &'static str {
"map"
}
fn signature(&self) -> OpSignature {
unary_signature()
}
fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
if inputs.len() != 1 {
return Err(Error::operator(format!(
"map expects 1 input, got {}",
inputs.len()
)));
}
Ok(vec![inputs[0].clone()])
}
fn required_contracts(&self) -> ContractSet {
ContractSet::new()
}
fn provided_contracts(&self) -> ContractSet {
ContractSet::from_iter([Contract::new(
ContractId(3),
Claim::Local,
Scope::Operator(self.name().to_string()),
Evidence::Axiom,
)])
}
fn layer_behavior(&self) -> Vec<LayerBehavior> {
vec![LayerBehavior::Pointwise, LayerBehavior::CoverLocal]
}
}
#[derive(Debug, Clone, Copy, Default)]
pub struct ReduceOp;
impl Operator for ReduceOp {
fn name(&self) -> &'static str {
"reduce"
}
fn signature(&self) -> OpSignature {
unary_signature()
}
fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
if inputs.len() != 1 {
return Err(Error::operator(format!(
"reduce expects 1 input, got {}",
inputs.len()
)));
}
let mut output = inputs[0].clone();
output.shape = crate::object::Shape::scalar();
Ok(vec![output])
}
fn required_contracts(&self) -> ContractSet {
ContractSet::from_iter([Contract::new(
ContractId(4),
Claim::Associative,
Scope::Operator(self.name().to_string()),
Evidence::Axiom,
)])
}
fn provided_contracts(&self) -> ContractSet {
ContractSet::from_iter([Contract::new(
ContractId(5),
Claim::Deterministic,
Scope::Operator(self.name().to_string()),
Evidence::Axiom,
)])
}
fn layer_behavior(&self) -> Vec<LayerBehavior> {
vec![LayerBehavior::Global]
}
}
/// Operator reported under the name `"matmul"`.
///
/// Uses the shared binary signature and delegates shape inference to
/// `infer_matmul`. It requires no contracts and provides the set built by
/// `exact_arithmetic_contracts`.
#[derive(Debug, Clone, Copy, Default)]
pub struct MatmulOp;

impl Operator for MatmulOp {
    /// Name used for contract scopes.
    fn name(&self) -> &'static str {
        "matmul"
    }

    /// Two inputs (`lhs`, `rhs`) and one output (`out`).
    fn signature(&self) -> OpSignature {
        binary_signature()
    }

    /// Delegates to `infer_matmul`.
    fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
        infer_matmul(inputs)
    }

    /// No contracts are required.
    fn required_contracts(&self) -> ContractSet {
        ContractSet::new()
    }

    /// Deterministic and exact contracts scoped to this operator.
    fn provided_contracts(&self) -> ContractSet {
        let scope = Scope::Operator(self.name().to_owned());
        exact_arithmetic_contracts(scope)
    }

    /// Reports `Global` and `PrecisionLayered`.
    fn layer_behavior(&self) -> Vec<LayerBehavior> {
        Vec::from([LayerBehavior::Global, LayerBehavior::PrecisionLayered])
    }
}
#[derive(Debug, Clone, Copy, Default)]
pub struct FmaOp;
impl Operator for FmaOp {
fn name(&self) -> &'static str {
"fma"
}
fn signature(&self) -> OpSignature {
OpSignature {
inputs: vec![
OpInput {
name: "a".to_string(),
},
OpInput {
name: "b".to_string(),
},
OpInput {
name: "c".to_string(),
},
],
outputs: vec![OpOutput {
name: "out".to_string(),
}],
}
}
fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
if inputs.len() != 3 {
return Err(Error::operator(format!(
"fma expects 3 inputs, got {}",
inputs.len()
)));
}
for (i, m) in inputs.iter().enumerate() {
if m.object_kind != ObjectKind::Tensor {
return Err(Error::operator(format!(
"fma only supports tensor inputs, input {i} is {:?}",
m.object_kind
)));
}
}
if inputs[0].domain != inputs[1].domain || inputs[1].domain != inputs[2].domain {
return Err(Error::domain(format!(
"fma domain mismatch: a={:?}, b={:?}, c={:?}",
inputs[0].domain, inputs[1].domain, inputs[2].domain
)));
}
inputs[0].shape.ensure_same(&inputs[1].shape)?;
inputs[1].shape.ensure_same(&inputs[2].shape)?;
Ok(vec![inputs[0].clone()])
}
fn required_contracts(&self) -> ContractSet {
ContractSet::new()
}
fn provided_contracts(&self) -> ContractSet {
exact_arithmetic_contracts(Scope::Operator(self.name().to_string()))
}
fn layer_behavior(&self) -> Vec<LayerBehavior> {
vec![LayerBehavior::Pointwise, LayerBehavior::CoverLocal]
}
}
#[derive(Debug, Clone, Copy, Default)]
pub struct PAdicMatmulFmaOp;
impl Operator for PAdicMatmulFmaOp {
fn name(&self) -> &'static str {
"p_pad_fma"
}
fn signature(&self) -> OpSignature {
OpSignature {
inputs: vec![
OpInput {
name: "a".to_string(),
},
OpInput {
name: "b".to_string(),
},
OpInput {
name: "c".to_string(),
},
],
outputs: vec![OpOutput {
name: "out".to_string(),
}],
}
}
fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
if inputs.len() != 3 {
return Err(Error::operator(format!(
"p_pad_fma expects 3 inputs, got {}",
inputs.len()
)));
}
for (i, m) in inputs.iter().enumerate() {
if m.object_kind != ObjectKind::Tensor {
return Err(Error::operator(format!(
"p_pad_fma only supports tensor inputs, input {i} is {:?}",
m.object_kind
)));
}
}
if inputs[0].domain != inputs[1].domain || inputs[1].domain != inputs[2].domain {
return Err(Error::domain(format!(
"p_pad_fma domain mismatch: a={:?}, b={:?}, c={:?}",
inputs[0].domain, inputs[1].domain, inputs[2].domain
)));
}
if inputs[0].shape.rank() != 2 || inputs[1].shape.rank() != 2 || inputs[2].shape.rank() != 2
{
return Err(Error::shape(format!(
"p_pad_fma expects rank-2 tensors, got left rank {} right rank {} bias rank {}",
inputs[0].shape.rank(),
inputs[1].shape.rank(),
inputs[2].shape.rank()
)));
}
let lhs_dims = &inputs[0].shape.dims;
let rhs_dims = &inputs[1].shape.dims;
let c_dims = &inputs[2].shape.dims;
inputs[0]
.shape
.ensure_dim_proves_equal(1, &inputs[1].shape, 0)?;
if c_dims[0] != lhs_dims[0] || c_dims[1] != rhs_dims[1] {
return Err(Error::shape(format!(
"p_pad_fma bias shape {:?} does not match output shape [{}, {}]",
inputs[2].shape, lhs_dims[0], rhs_dims[1]
)));
}
let mut output = inputs[0].clone();
output.shape = crate::object::Shape::new(vec![lhs_dims[0].clone(), rhs_dims[1].clone()]);
Ok(vec![output])
}
fn required_contracts(&self) -> ContractSet {
ContractSet::new()
}
fn provided_contracts(&self) -> ContractSet {
exact_arithmetic_contracts(Scope::Operator(self.name().to_string()))
}
fn layer_behavior(&self) -> Vec<LayerBehavior> {
vec![LayerBehavior::Global, LayerBehavior::PrecisionLayered]
}
}
#[derive(Debug, Clone, Copy, Default)]
pub struct PAdicDotOp;
impl Operator for PAdicDotOp {
fn name(&self) -> &'static str {
"p_dot"
}
fn signature(&self) -> OpSignature {
OpSignature {
inputs: vec![
OpInput {
name: "a".to_string(),
},
OpInput {
name: "b".to_string(),
},
],
outputs: vec![OpOutput {
name: "out".to_string(),
}],
}
}
fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
if inputs.len() != 2 {
return Err(Error::operator(format!(
"p_dot expects 2 inputs, got {}",
inputs.len()
)));
}
for (i, m) in inputs.iter().enumerate() {
if m.object_kind != ObjectKind::Tensor {
return Err(Error::operator(format!(
"p_dot only supports tensor inputs, input {i} is {:?}",
m.object_kind
)));
}
}
if inputs[0].domain != inputs[1].domain {
return Err(Error::domain(format!(
"p_dot domain mismatch: a={:?}, b={:?}",
inputs[0].domain, inputs[1].domain
)));
}
if inputs[0].shape.rank() != 1 || inputs[1].shape.rank() != 1 {
return Err(Error::shape(format!(
"p_dot expects rank-1 vector inputs, got rank {} and rank {}",
inputs[0].shape.rank(),
inputs[1].shape.rank()
)));
}
inputs[0].shape.ensure_same(&inputs[1].shape)?;
let mut output = inputs[0].clone();
output.shape = crate::object::Shape::scalar();
Ok(vec![output])
}
fn required_contracts(&self) -> ContractSet {
ContractSet::new()
}
fn provided_contracts(&self) -> ContractSet {
exact_arithmetic_contracts(Scope::Operator(self.name().to_string()))
}
fn layer_behavior(&self) -> Vec<LayerBehavior> {
vec![LayerBehavior::Global, LayerBehavior::PrecisionLayered]
}
}
#[derive(Debug, Clone, Copy, Default)]
pub struct ClampOp;
impl Operator for ClampOp {
fn name(&self) -> &'static str {
"clamp"
}
fn signature(&self) -> OpSignature {
OpSignature {
inputs: vec![
OpInput {
name: "input".to_string(),
},
OpInput {
name: "lo".to_string(),
},
OpInput {
name: "hi".to_string(),
},
],
outputs: vec![OpOutput {
name: "out".to_string(),
}],
}
}
fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
if inputs.len() != 3 {
return Err(Error::operator(format!(
"clamp expects 3 inputs (data, lo, hi), got {}",
inputs.len()
)));
}
for (i, m) in inputs.iter().enumerate() {
if m.object_kind != ObjectKind::Tensor {
return Err(Error::operator(format!(
"clamp only supports tensor inputs, input {i} is {:?}",
m.object_kind
)));
}
}
Ok(vec![inputs[0].clone()])
}
fn required_contracts(&self) -> ContractSet {
ContractSet::new()
}
fn provided_contracts(&self) -> ContractSet {
exact_arithmetic_contracts(Scope::Operator(self.name().to_string()))
}
fn layer_behavior(&self) -> Vec<LayerBehavior> {
vec![LayerBehavior::Pointwise]
}
}
#[derive(Debug, Clone, Copy, Default)]
pub struct NegOp;
impl Operator for NegOp {
fn name(&self) -> &'static str {
"neg"
}
fn signature(&self) -> OpSignature {
unary_signature()
}
fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
if inputs.len() != 1 {
return Err(Error::operator(format!(
"neg expects 1 input, got {}",
inputs.len()
)));
}
let m = &inputs[0];
if m.object_kind != ObjectKind::Tensor {
return Err(Error::operator(format!(
"neg only supports tensor inputs, got {:?}",
m.object_kind
)));
}
Ok(vec![m.clone()])
}
fn required_contracts(&self) -> ContractSet {
ContractSet::new()
}
fn provided_contracts(&self) -> ContractSet {
exact_arithmetic_contracts(Scope::Operator(self.name().to_string()))
}
fn layer_behavior(&self) -> Vec<LayerBehavior> {
vec![LayerBehavior::Pointwise, LayerBehavior::CoverLocal]
}
}
#[derive(Debug, Clone, Copy, Default)]
pub struct AbsOp;
impl Operator for AbsOp {
fn name(&self) -> &'static str {
"abs"
}
fn signature(&self) -> OpSignature {
unary_signature()
}
fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
if inputs.len() != 1 {
return Err(Error::operator(format!(
"abs expects 1 input, got {}",
inputs.len()
)));
}
let m = &inputs[0];
if m.object_kind != ObjectKind::Tensor {
return Err(Error::operator(format!(
"abs only supports tensor inputs, got {:?}",
m.object_kind
)));
}
Ok(vec![m.clone()])
}
fn required_contracts(&self) -> ContractSet {
ContractSet::new()
}
fn provided_contracts(&self) -> ContractSet {
exact_arithmetic_contracts(Scope::Operator(self.name().to_string()))
}
fn layer_behavior(&self) -> Vec<LayerBehavior> {
vec![LayerBehavior::Pointwise, LayerBehavior::CoverLocal]
}
}
#[derive(Debug, Clone, Copy, Default)]
pub struct SquareOp;
impl Operator for SquareOp {
fn name(&self) -> &'static str {
"square"
}
fn signature(&self) -> OpSignature {
unary_signature()
}
fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
if inputs.len() != 1 {
return Err(Error::operator(format!(
"square expects 1 input, got {}",
inputs.len()
)));
}
let m = &inputs[0];
if m.object_kind != ObjectKind::Tensor {
return Err(Error::operator(format!(
"square only supports tensor inputs, got {:?}",
m.object_kind
)));
}
Ok(vec![m.clone()])
}
fn required_contracts(&self) -> ContractSet {
ContractSet::new()
}
fn provided_contracts(&self) -> ContractSet {
exact_arithmetic_contracts(Scope::Operator(self.name().to_string()))
}
fn layer_behavior(&self) -> Vec<LayerBehavior> {
vec![LayerBehavior::Pointwise, LayerBehavior::CoverLocal]
}
}
#[derive(Debug, Clone, Copy, Default)]
pub struct MulByTwoOp;
impl Operator for MulByTwoOp {
fn name(&self) -> &'static str {
"mul_by_two"
}
fn signature(&self) -> OpSignature {
unary_signature()
}
fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
if inputs.len() != 1 {
return Err(Error::operator(format!(
"mul_by_two expects 1 input, got {}",
inputs.len()
)));
}
let m = &inputs[0];
if m.object_kind != ObjectKind::Tensor {
return Err(Error::operator(format!(
"mul_by_two only supports tensor inputs, got {:?}",
m.object_kind
)));
}
Ok(vec![m.clone()])
}
fn required_contracts(&self) -> ContractSet {
ContractSet::new()
}
fn provided_contracts(&self) -> ContractSet {
exact_arithmetic_contracts(Scope::Operator(self.name().to_string()))
}
fn layer_behavior(&self) -> Vec<LayerBehavior> {
vec![LayerBehavior::Pointwise, LayerBehavior::CoverLocal]
}
}
/// Operator reported under the name `"sub"`.
///
/// Uses the shared binary signature, validates its operands through
/// `infer_binary_same_shape`, requires no contracts and provides the set
/// built by `exact_arithmetic_contracts`.
#[derive(Debug, Clone, Copy, Default)]
pub struct SubOp;

impl Operator for SubOp {
    /// Name used for error messages and contract scopes.
    fn name(&self) -> &'static str {
        "sub"
    }

    /// Two inputs (`lhs`, `rhs`) and one output (`out`).
    fn signature(&self) -> OpSignature {
        binary_signature()
    }

    /// Delegates to `infer_binary_same_shape` with this operator's name.
    fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
        let op_name = self.name();
        infer_binary_same_shape(op_name, inputs)
    }

    /// No contracts are required.
    fn required_contracts(&self) -> ContractSet {
        ContractSet::new()
    }

    /// Deterministic and exact contracts scoped to this operator.
    fn provided_contracts(&self) -> ContractSet {
        let scope = Scope::Operator(self.name().to_owned());
        exact_arithmetic_contracts(scope)
    }

    /// Reports `Pointwise` and `CoverLocal`.
    fn layer_behavior(&self) -> Vec<LayerBehavior> {
        Vec::from([LayerBehavior::Pointwise, LayerBehavior::CoverLocal])
    }
}
/// Operator reported under the name `"div"`.
///
/// Uses the shared binary signature, validates its operands through
/// `infer_binary_same_shape`, requires no contracts and provides the set
/// built by `exact_arithmetic_contracts`.
///
/// NOTE(review): this advertises `Claim::Exact` exactly like the other
/// arithmetic operators; confirm that the claim holds for division in every
/// supported domain.
#[derive(Debug, Clone, Copy, Default)]
pub struct DivOp;

impl Operator for DivOp {
    /// Name used for error messages and contract scopes.
    fn name(&self) -> &'static str {
        "div"
    }

    /// Two inputs (`lhs`, `rhs`) and one output (`out`).
    fn signature(&self) -> OpSignature {
        binary_signature()
    }

    /// Delegates to `infer_binary_same_shape` with this operator's name.
    fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
        let op_name = self.name();
        infer_binary_same_shape(op_name, inputs)
    }

    /// No contracts are required.
    fn required_contracts(&self) -> ContractSet {
        ContractSet::new()
    }

    /// Deterministic and exact contracts scoped to this operator.
    fn provided_contracts(&self) -> ContractSet {
        let scope = Scope::Operator(self.name().to_owned());
        exact_arithmetic_contracts(scope)
    }

    /// Reports `Pointwise` and `CoverLocal`.
    fn layer_behavior(&self) -> Vec<LayerBehavior> {
        Vec::from([LayerBehavior::Pointwise, LayerBehavior::CoverLocal])
    }
}
#[derive(Debug, Clone, Copy, Default)]
pub struct ScalarAddOp;
impl Operator for ScalarAddOp {
fn name(&self) -> &'static str {
"scalar_add"
}
fn signature(&self) -> OpSignature {
OpSignature {
inputs: vec![
OpInput {
name: "tensor".to_string(),
},
OpInput {
name: "scalar".to_string(),
},
],
outputs: vec![OpOutput {
name: "out".to_string(),
}],
}
}
fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
if inputs.len() != 2 {
return Err(Error::operator(format!(
"scalar_add expects 2 inputs (tensor, scalar), got {}",
inputs.len()
)));
}
for (i, m) in inputs.iter().enumerate() {
if m.object_kind != ObjectKind::Tensor {
return Err(Error::operator(format!(
"scalar_add only supports tensor inputs, input {i} is {:?}",
m.object_kind
)));
}
}
if inputs[0].domain != inputs[1].domain {
return Err(Error::domain(format!(
"scalar_add domain mismatch: tensor={:?}, scalar={:?}",
inputs[0].domain, inputs[1].domain
)));
}
Ok(vec![inputs[0].clone()])
}
fn required_contracts(&self) -> ContractSet {
ContractSet::new()
}
fn provided_contracts(&self) -> ContractSet {
exact_arithmetic_contracts(Scope::Operator(self.name().to_string()))
}
fn layer_behavior(&self) -> Vec<LayerBehavior> {
vec![LayerBehavior::Pointwise, LayerBehavior::CoverLocal]
}
}
#[derive(Debug, Clone, Copy, Default)]
pub struct ScalarMulOp;
impl Operator for ScalarMulOp {
fn name(&self) -> &'static str {
"scalar_mul"
}
fn signature(&self) -> OpSignature {
OpSignature {
inputs: vec![
OpInput {
name: "tensor".to_string(),
},
OpInput {
name: "scalar".to_string(),
},
],
outputs: vec![OpOutput {
name: "out".to_string(),
}],
}
}
fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
if inputs.len() != 2 {
return Err(Error::operator(format!(
"scalar_mul expects 2 inputs (tensor, scalar), got {}",
inputs.len()
)));
}
for (i, m) in inputs.iter().enumerate() {
if m.object_kind != ObjectKind::Tensor {
return Err(Error::operator(format!(
"scalar_mul only supports tensor inputs, input {i} is {:?}",
m.object_kind
)));
}
}
if inputs[0].domain != inputs[1].domain {
return Err(Error::domain(format!(
"scalar_mul domain mismatch: tensor={:?}, scalar={:?}",
inputs[0].domain, inputs[1].domain
)));
}
Ok(vec![inputs[0].clone()])
}
fn required_contracts(&self) -> ContractSet {
ContractSet::new()
}
fn provided_contracts(&self) -> ContractSet {
exact_arithmetic_contracts(Scope::Operator(self.name().to_string()))
}
fn layer_behavior(&self) -> Vec<LayerBehavior> {
vec![LayerBehavior::Pointwise, LayerBehavior::CoverLocal]
}
}
#[derive(Debug, Clone, Copy, Default)]
pub struct PowOp;
impl Operator for PowOp {
fn name(&self) -> &'static str {
"pow"
}
fn signature(&self) -> OpSignature {
OpSignature {
inputs: vec![
OpInput {
name: "tensor".to_string(),
},
OpInput {
name: "exp".to_string(),
},
],
outputs: vec![OpOutput {
name: "out".to_string(),
}],
}
}
fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
if inputs.len() != 2 {
return Err(Error::operator(format!(
"pow expects 2 inputs (tensor, exp), got {}",
inputs.len()
)));
}
for (i, m) in inputs.iter().enumerate() {
if m.object_kind != ObjectKind::Tensor {
return Err(Error::operator(format!(
"pow only supports tensor inputs, input {i} is {:?}",
m.object_kind
)));
}
}
if inputs[0].domain != inputs[1].domain {
return Err(Error::domain(format!(
"pow domain mismatch: tensor={:?}, exp={:?}",
inputs[0].domain, inputs[1].domain
)));
}
Ok(vec![inputs[0].clone()])
}
fn required_contracts(&self) -> ContractSet {
ContractSet::new()
}
fn provided_contracts(&self) -> ContractSet {
exact_arithmetic_contracts(Scope::Operator(self.name().to_string()))
}
fn layer_behavior(&self) -> Vec<LayerBehavior> {
vec![LayerBehavior::Pointwise, LayerBehavior::CoverLocal]
}
}
#[derive(Debug, Clone, Copy, Default)]
pub struct SqrtOp;
impl Operator for SqrtOp {
fn name(&self) -> &'static str {
"sqrt"
}
fn signature(&self) -> OpSignature {
unary_signature()
}
fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
if inputs.len() != 1 {
return Err(Error::operator(format!(
"sqrt expects 1 input, got {}",
inputs.len()
)));
}
let m = &inputs[0];
if m.object_kind != ObjectKind::Tensor {
return Err(Error::operator(format!(
"sqrt only supports tensor inputs, got {:?}",
m.object_kind
)));
}
Ok(vec![m.clone()])
}
fn required_contracts(&self) -> ContractSet {
ContractSet::new()
}
fn provided_contracts(&self) -> ContractSet {
ContractSet::from_iter([Contract::new(
ContractId(20),
Claim::Deterministic,
Scope::Operator(self.name().to_string()),
Evidence::Axiom,
)])
}
fn layer_behavior(&self) -> Vec<LayerBehavior> {
vec![LayerBehavior::Pointwise, LayerBehavior::CoverLocal]
}
}
#[derive(Debug, Clone, Copy, Default)]
pub struct Exp2Op;
impl Operator for Exp2Op {
fn name(&self) -> &'static str {
"exp2"
}
fn signature(&self) -> OpSignature {
unary_signature()
}
fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
if inputs.len() != 1 {
return Err(Error::operator(format!(
"exp2 expects 1 input, got {}",
inputs.len()
)));
}
let m = &inputs[0];
if m.object_kind != ObjectKind::Tensor {
return Err(Error::operator(format!(
"exp2 only supports tensor inputs, got {:?}",
m.object_kind
)));
}
Ok(vec![m.clone()])
}
fn required_contracts(&self) -> ContractSet {
ContractSet::new()
}
fn provided_contracts(&self) -> ContractSet {
ContractSet::from_iter([Contract::new(
ContractId(21),
Claim::Deterministic,
Scope::Operator(self.name().to_string()),
Evidence::Axiom,
)])
}
fn layer_behavior(&self) -> Vec<LayerBehavior> {
vec![LayerBehavior::Pointwise, LayerBehavior::CoverLocal]
}
}
#[derive(Debug, Clone, Copy, Default)]
pub struct Log2Op;
impl Operator for Log2Op {
fn name(&self) -> &'static str {
"log2"
}
fn signature(&self) -> OpSignature {
unary_signature()
}
fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
if inputs.len() != 1 {
return Err(Error::operator(format!(
"log2 expects 1 input, got {}",
inputs.len()
)));
}
let m = &inputs[0];
if m.object_kind != ObjectKind::Tensor {
return Err(Error::operator(format!(
"log2 only supports tensor inputs, got {:?}",
m.object_kind
)));
}
Ok(vec![m.clone()])
}
fn required_contracts(&self) -> ContractSet {
ContractSet::new()
}
fn provided_contracts(&self) -> ContractSet {
ContractSet::from_iter([Contract::new(
ContractId(22),
Claim::Deterministic,
Scope::Operator(self.name().to_string()),
Evidence::Axiom,
)])
}
fn layer_behavior(&self) -> Vec<LayerBehavior> {
vec![LayerBehavior::Pointwise, LayerBehavior::CoverLocal]
}
}