use crate::spec::types::OpSpec;
use crate::spec::types::{DataType, OpSignature};
use crate::spec::OverflowContract;
/// Certificate that the composition `right ∘ left` (apply `left`, then `right`)
/// is closed. Produced by [`closure_cert_for`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ComposedCertificate {
    /// Id of the op applied first.
    pub left: &'static str,
    /// Id of the op applied second, to the output of `left`.
    pub right: &'static str,
    /// Input type of the composition (the single input type of `left`).
    pub input_type: DataType,
    /// Output type of the composition (the output type of `right`).
    pub output_type: DataType,
    /// Overflow contract of the composition, or `None` when neither op declares one.
    pub overflow_contract: Option<OverflowContract>,
}
/// Reasons why two ops, or a chain of ops, fail the 1-in-1-out composition-closure check.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ClosureError {
    /// The output type of `left` is not the input type of `right`.
    TypeMismatch {
        /// Id of the op applied first.
        left: String,
        /// Id of the op applied second.
        right: String,
        /// Output type declared by `left`.
        left_output: DataType,
        /// Single input type declared by `right`.
        right_input: DataType,
    },
    /// The op applied second does not take exactly one input.
    NonUnaryRight {
        /// Id of the offending op.
        right: String,
        /// Number of inputs the op declares.
        arity: usize,
    },
    /// The op applied first does not take exactly one input.
    NonUnaryLeft {
        /// Id of the offending op.
        left: String,
        /// Number of inputs the op declares.
        arity: usize,
    },
    /// The two sides declare overflow contracts that cannot be reconciled.
    OverflowContractConflict {
        /// Id of the op (applied earlier) whose contract conflicts.
        left: String,
        /// Id of the op (applied later) whose contract conflicts.
        right: String,
        /// Contract declared by `left`.
        left_contract: OverflowContract,
        /// Contract declared by `right`.
        right_contract: OverflowContract,
    },
}
impl ClosureError {
    /// Returns a one-sentence remediation hint for this error.
    ///
    /// Every hint begins with `Fix:`; the `Display` impl appends it after the
    /// diagnostic so rendered errors are always actionable.
    #[must_use]
    #[inline]
    pub fn fix_hint(&self) -> String {
        match self {
            // Both arity failures get the same advice; only the offending op id differs.
            Self::NonUnaryLeft { left: op, arity } | Self::NonUnaryRight { right: op, arity } => {
                format!(
                    "Fix: `{op}` takes {arity} inputs; the 1-in-1-out closure rule does not apply. Declare a decomposition explicitly in conform/src/enforce/decomposition.rs."
                )
            }
            Self::TypeMismatch {
                left: producer,
                right: consumer,
                left_output: produced,
                right_input: expected,
            } => format!(
                "Fix: change `{producer}` to output {expected:?} or change `{consumer}` to accept {produced:?}."
            ),
            Self::OverflowContractConflict {
                left: first,
                right: second,
                left_contract: first_contract,
                right_contract: second_contract,
            } => format!(
                "Fix: align the overflow contracts (`{first}` declares {first_contract}, `{second}` declares {second_contract}). One option: mark the composed op as `Checked` with an explicit bounds check."
            ),
        }
    }
}
impl std::fmt::Display for ClosureError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::TypeMismatch {
left,
right,
left_output,
right_input,
} => write!(
f,
"composition `{right} ∘ {left}`: output type {left_output:?} of `{left}` does not match input type {right_input:?} of `{right}`. {}",
self.fix_hint()
),
Self::NonUnaryLeft { left, arity } => write!(
f,
"composition left op `{left}` has arity {arity}, closure requires unary. {}",
self.fix_hint()
),
Self::NonUnaryRight { right, arity } => write!(
f,
"composition right op `{right}` has arity {arity}, closure requires unary. {}",
self.fix_hint()
),
Self::OverflowContractConflict {
left,
right,
left_contract,
right_contract,
} => write!(
f,
"composition `{right} ∘ {left}`: overflow contracts conflict ({left}={left_contract}, {right}={right_contract}). {}",
self.fix_hint()
),
}
}
}
/// Certifies that `right ∘ left` (apply `left`, then `right`) is a closed
/// 1-in-1-out composition.
///
/// The composed overflow contract is derived as follows:
/// - identical contracts are kept;
/// - if `right` declares `Checked` or `Unchecked`, that contract wins;
/// - otherwise, if `left` declares `Checked` or `Unchecked`, that contract wins;
/// - if only one op declares a contract, it is used; if neither does, the result is `None`.
///
/// # Errors
///
/// Checks run in this order, returning the first failure:
/// - [`ClosureError::NonUnaryLeft`] if `left` does not take exactly one input;
/// - [`ClosureError::NonUnaryRight`] if `right` does not take exactly one input;
/// - [`ClosureError::TypeMismatch`] if `left`'s output type differs from `right`'s input type;
/// - [`ClosureError::OverflowContractConflict`] if the ops declare different contracts and
///   neither is `Checked` or `Unchecked`.
#[inline]
pub fn closure_cert_for(
    left: &OpSpec,
    right: &OpSpec,
) -> Result<ComposedCertificate, ClosureError> {
    // Slice patterns bind the single input by reference, so no later index can panic.
    let left_input = match left.signature.inputs.as_slice() {
        [only] => only,
        inputs => {
            return Err(ClosureError::NonUnaryLeft {
                left: left.id.to_string(),
                arity: inputs.len(),
            })
        }
    };
    let right_input = match right.signature.inputs.as_slice() {
        [only] => only,
        inputs => {
            return Err(ClosureError::NonUnaryRight {
                right: right.id.to_string(),
                arity: inputs.len(),
            })
        }
    };
    // Compare by reference; the types are cloned only when building the error.
    if left.signature.output != *right_input {
        return Err(ClosureError::TypeMismatch {
            left: left.id.to_string(),
            right: right.id.to_string(),
            left_output: left.signature.output.clone(),
            right_input: right_input.clone(),
        });
    }
    let overflow_contract = match (left.overflow_contract, right.overflow_contract) {
        (Some(l), Some(r)) if l == r => Some(l),
        // `Checked` / `Unchecked` on the right takes precedence over the left's contract.
        (Some(_), Some(r @ (OverflowContract::Checked | OverflowContract::Unchecked))) => Some(r),
        (Some(l @ (OverflowContract::Checked | OverflowContract::Unchecked)), Some(_)) => Some(l),
        (Some(l), Some(r)) => {
            return Err(ClosureError::OverflowContractConflict {
                left: left.id.to_string(),
                right: right.id.to_string(),
                left_contract: l,
                right_contract: r,
            });
        }
        // At most one side declares a contract: keep whichever one exists.
        (l, r) => l.or(r),
    };
    Ok(ComposedCertificate {
        left: left.id,
        right: right.id,
        input_type: left_input.clone(),
        output_type: right.signature.output.clone(),
        overflow_contract,
    })
}
#[inline]
pub fn chain_is_closed(chain: &[&OpSpec]) -> Result<ChainCertificate, ClosureError> {
if chain.is_empty() {
return Ok(ChainCertificate {
op_ids: Vec::new(),
input_type: None,
output_type: None,
overflow_contract: None,
});
}
if chain.len() == 1 {
let only = chain[0];
return Ok(ChainCertificate {
op_ids: vec![only.id],
input_type: only.signature.inputs.first().cloned(),
output_type: Some(only.signature.output.clone()),
overflow_contract: only.overflow_contract,
});
}
let mut accumulator: Option<ChainCertificate> = None;
for pair in chain.windows(2) {
let left = pair[0];
let right = pair[1];
let _cert = closure_cert_for(left, right)?;
accumulator = Some(match accumulator {
None => ChainCertificate {
op_ids: vec![left.id, right.id],
input_type: left.signature.inputs.first().cloned(),
output_type: Some(right.signature.output.clone()),
overflow_contract: match (left.overflow_contract, right.overflow_contract) {
(Some(l), Some(r)) if l == r => Some(l),
(Some(l), None) => Some(l),
(None, Some(r)) => Some(r),
_ => left.overflow_contract.or(right.overflow_contract),
},
},
Some(mut acc) => {
acc.op_ids.push(right.id);
acc.output_type = Some(right.signature.output.clone());
if let Some(r) = right.overflow_contract {
acc.overflow_contract = Some(r);
}
acc
}
});
}
Ok(accumulator.expect("non-empty chain always produces an accumulator"))
}
/// Certificate that a sequence of ops composes end to end. Produced by [`chain_is_closed`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChainCertificate {
    /// Ids of the ops in application order (empty for an empty chain).
    pub op_ids: Vec<&'static str>,
    /// First input type of the first op, or `None` for an empty chain or an input-less op.
    pub input_type: Option<DataType>,
    /// Output type of the last op, or `None` for an empty chain.
    pub output_type: Option<DataType>,
    /// Overflow contract derived for the whole chain, or `None` when no op declares one.
    pub overflow_contract: Option<OverflowContract>,
}
/// Returns `true` when `spec` can be composed with itself: it takes exactly one input,
/// and its output type equals that input type.
#[must_use]
#[inline]
pub fn is_self_composable(spec: &OpSpec) -> bool {
    let sig = &spec.signature;
    match sig.inputs.as_slice() {
        [sole_input] => signature_output_is(sig, sole_input),
        _ => false,
    }
}
/// Reports whether `sig` declares exactly `ty` as its output type.
fn signature_output_is(sig: &OpSignature, ty: &DataType) -> bool {
    sig.output == *ty
}
/// Enforce gate for the composition-closure rule; its gate id is `composition_closure`.
pub struct CompositionClosureEnforcer;
impl crate::enforce::EnforceGate for CompositionClosureEnforcer {
    fn id(&self) -> &'static str {
        "composition_closure"
    }
    // The display name is the same string as the gate id.
    fn name(&self) -> &'static str {
        "composition_closure"
    }
    // NOTE(review): this gate does not consult `ctx` and always hands an empty finding
    // list to `finding_result`. None of the closure checks in this module are run from
    // here yet — confirm this is intentional.
    fn run(&self, _ctx: &crate::enforce::EnforceCtx<'_>) -> Vec<crate::enforce::Finding> {
        let findings = Vec::new();
        crate::enforce::finding_result(self.id(), findings)
    }
}
/// Exported instance of [`CompositionClosureEnforcer`].
// NOTE(review): presumably collected by the enforce-gate registry under this name — confirm.
pub const REGISTERED: CompositionClosureEnforcer = CompositionClosureEnforcer;
#[cfg(test)]
mod tests {
    use super::*;
    use crate::spec::types::conform::Strictness;
    use crate::spec::types::OpSpec;
    use crate::spec::types::{DataType, OpSignature};
    use crate::spec::{AlgebraicLaw, OverflowContract};
    use vyre_spec::Category;

    /// Builds a strict category-A unary op with no overflow contract.
    fn unary_op(id: &'static str, input: DataType, output: DataType) -> OpSpec {
        OpSpec::builder(id)
            .signature(OpSignature {
                inputs: vec![input],
                output,
            })
            .cpu_fn(|i| i.to_vec())
            .wgsl_fn(|| "fn main() {}".to_string())
            .category(Category::A {
                composition_of: vec![id],
            })
            .laws(vec![AlgebraicLaw::Bounded {
                lo: 0,
                hi: u32::MAX,
            }])
            .strictness(Strictness::Strict)
            .version(1)
            .build()
            .unwrap()
    }

    /// Builds a strict category-A `(U32, U32) -> U32` op.
    fn binary_op(id: &'static str) -> OpSpec {
        OpSpec::builder(id)
            .signature(OpSignature {
                inputs: vec![DataType::U32, DataType::U32],
                output: DataType::U32,
            })
            .cpu_fn(|i| i.to_vec())
            .wgsl_fn(|| "fn main() {}".to_string())
            .category(Category::A {
                composition_of: vec![id],
            })
            .laws(vec![AlgebraicLaw::Bounded {
                lo: 0,
                hi: u32::MAX,
            }])
            .strictness(Strictness::Strict)
            .version(1)
            .build()
            .unwrap()
    }

    /// Builds a strict category-A `U32 -> U32` op declaring `contract`.
    fn u32_op_with_contract(id: &'static str, contract: OverflowContract) -> OpSpec {
        OpSpec::builder(id)
            .signature(OpSignature {
                inputs: vec![DataType::U32],
                output: DataType::U32,
            })
            .cpu_fn(|i| i.to_vec())
            .wgsl_fn(|| "fn main() {}".to_string())
            .category(Category::A {
                composition_of: vec![id],
            })
            .laws(vec![AlgebraicLaw::Bounded {
                lo: 0,
                hi: u32::MAX,
            }])
            .strictness(Strictness::Strict)
            .version(1)
            .overflow_contract(contract)
            .build()
            .unwrap()
    }

    #[test]
    fn compatible_unary_ops_compose() {
        let left = unary_op("a.to_u32", DataType::U32, DataType::U32);
        let right = unary_op("u32.identity", DataType::U32, DataType::U32);
        let cert = closure_cert_for(&left, &right).unwrap();
        assert_eq!(cert.left, "a.to_u32");
        assert_eq!(cert.right, "u32.identity");
        assert_eq!(cert.input_type, DataType::U32);
        assert_eq!(cert.output_type, DataType::U32);
        assert!(cert.overflow_contract.is_none());
    }

    #[test]
    fn type_mismatch_is_rejected() {
        let left = unary_op("bytes.len", DataType::Bytes, DataType::U32);
        let right = unary_op("bytes.identity", DataType::Bytes, DataType::Bytes);
        let error = closure_cert_for(&left, &right).unwrap_err();
        assert!(
            matches!(error, ClosureError::TypeMismatch { .. }),
            "{error:?}"
        );
        assert!(error.fix_hint().starts_with("Fix:"));
    }

    #[test]
    fn binary_left_is_rejected() {
        let left = binary_op("u32.add");
        let right = unary_op("u32.identity", DataType::U32, DataType::U32);
        let error = closure_cert_for(&left, &right).unwrap_err();
        assert!(
            matches!(error, ClosureError::NonUnaryLeft { arity: 2, .. }),
            "{error:?}"
        );
    }

    #[test]
    fn binary_right_is_rejected() {
        let left = unary_op("u32.identity", DataType::U32, DataType::U32);
        let right = binary_op("u32.add");
        let error = closure_cert_for(&left, &right).unwrap_err();
        assert!(
            matches!(error, ClosureError::NonUnaryRight { arity: 2, .. }),
            "{error:?}"
        );
    }

    #[test]
    fn wrapping_left_and_saturating_right_conflict() {
        let left = u32_op_with_contract("left", OverflowContract::Wrapping);
        let right = u32_op_with_contract("right", OverflowContract::Saturating);
        let error = closure_cert_for(&left, &right).unwrap_err();
        assert!(
            matches!(error, ClosureError::OverflowContractConflict { .. }),
            "{error:?}"
        );
    }

    #[test]
    fn unchecked_right_overrides_left_contract() {
        let left = u32_op_with_contract("left", OverflowContract::Wrapping);
        let right = u32_op_with_contract("right", OverflowContract::Unchecked);
        let cert = closure_cert_for(&left, &right).unwrap();
        assert_eq!(cert.overflow_contract, Some(OverflowContract::Unchecked));
    }

    #[test]
    fn empty_chain_is_trivially_closed() {
        let cert = chain_is_closed(&[]).unwrap();
        assert!(cert.op_ids.is_empty());
        assert!(cert.input_type.is_none());
        assert!(cert.output_type.is_none());
        assert!(cert.overflow_contract.is_none());
    }

    #[test]
    fn single_op_chain_is_trivially_closed() {
        let only = unary_op("solo", DataType::U32, DataType::U32);
        let cert = chain_is_closed(&[&only]).unwrap();
        assert_eq!(cert.op_ids, vec!["solo"]);
        assert_eq!(cert.input_type, Some(DataType::U32));
        assert_eq!(cert.output_type, Some(DataType::U32));
    }

    #[test]
    fn three_op_compatible_chain_closes() {
        let a = unary_op("a", DataType::U32, DataType::U32);
        let b = unary_op("b", DataType::U32, DataType::U32);
        let c = unary_op("c", DataType::U32, DataType::U32);
        let cert = chain_is_closed(&[&a, &b, &c]).unwrap();
        assert_eq!(cert.op_ids, vec!["a", "b", "c"]);
        assert_eq!(cert.input_type, Some(DataType::U32));
        assert_eq!(cert.output_type, Some(DataType::U32));
    }

    #[test]
    fn three_op_chain_with_middle_mismatch_fails() {
        let a = unary_op("a", DataType::U32, DataType::U32);
        let b = unary_op("b", DataType::Bytes, DataType::U32);
        let c = unary_op("c", DataType::U32, DataType::U32);
        let error = chain_is_closed(&[&a, &b, &c]).unwrap_err();
        assert!(matches!(error, ClosureError::TypeMismatch { .. }));
    }

    #[test]
    fn is_self_composable_true_for_u32_to_u32() {
        let op = unary_op("u32.negate", DataType::U32, DataType::U32);
        assert!(is_self_composable(&op));
    }

    #[test]
    fn is_self_composable_false_for_type_changing() {
        let op = unary_op("bytes.len", DataType::Bytes, DataType::U32);
        assert!(!is_self_composable(&op));
    }

    #[test]
    fn is_self_composable_false_for_binary() {
        assert!(!is_self_composable(&binary_op("u32.add")));
    }

    #[test]
    fn error_display_is_actionable() {
        let error = ClosureError::TypeMismatch {
            left: "a".to_string(),
            right: "b".to_string(),
            left_output: DataType::Bytes,
            right_input: DataType::U32,
        };
        let rendered = format!("{error}");
        assert!(rendered.contains("Fix:"));
        // Backtick-quoted ids: a bare "a" would match almost any English text.
        assert!(rendered.contains("`a`"), "{rendered}");
        assert!(rendered.contains("`b`"), "{rendered}");
    }

    #[test]
    fn display_ends_with_fix_hint_for_every_variant() {
        let errors = [
            ClosureError::TypeMismatch {
                left: "l".to_string(),
                right: "r".to_string(),
                left_output: DataType::Bytes,
                right_input: DataType::U32,
            },
            ClosureError::NonUnaryLeft {
                left: "l".to_string(),
                arity: 2,
            },
            ClosureError::NonUnaryRight {
                right: "r".to_string(),
                arity: 0,
            },
            ClosureError::OverflowContractConflict {
                left: "l".to_string(),
                right: "r".to_string(),
                left_contract: OverflowContract::Wrapping,
                right_contract: OverflowContract::Saturating,
            },
        ];
        for error in &errors {
            let hint = error.fix_hint();
            assert!(hint.starts_with("Fix:"), "{hint}");
            assert!(error.to_string().ends_with(hint.as_str()), "{error}");
        }
    }
}