vyre-conform 0.1.0

Conformance suite for vyre backends — proves byte-identical output to CPU reference
Documentation
//! Pairwise composition strategy.
//!
//! Generates all valid two-op chains where the output type of the first
//! op matches the first input type of the second.

use super::CompositionStrategy;
use crate::spec::types::{ChainSpec, ConstructionTime, OpSignature, OpSpec, ProofToken};

/// Composition strategy that emits every valid two-op chain from a spec registry.
///
/// A chain `first → second` is produced for each ordered pair of specs
/// (including an op chained with itself) where `first`'s output type equals
/// the type of `second`'s first input.
pub struct PairwiseComposer;

impl CompositionStrategy for PairwiseComposer {
    /// Builds one `ChainSpec` per ordered pair of `specs` that can be chained.
    ///
    /// Pairs are visited `first`-major, so the result is ordered by the position
    /// of `first` in `specs`, then by the position of `second`. Each chain:
    /// - is named `"{first.id}__then__{second.id}"`,
    /// - takes `first`'s inputs and produces `second`'s output,
    /// - carries both specs (in order) and a `ProofToken` minted from them.
    ///
    /// Pairs for which `ProofToken::from_specs` returns an error are skipped and
    /// not reported.
    fn generate_chains(&self, specs: &[OpSpec]) -> Vec<ChainSpec> {
        let mut chains = Vec::new();
        for first in specs {
            for second in specs {
                if !can_chain(first, second) {
                    continue;
                }
                // Clone the pair once: the same array backs the proof token and is
                // then moved into the chain, instead of cloning each spec twice.
                let pair = [first.clone(), second.clone()];
                let Ok(proof_token) =
                    ProofToken::from_specs(&pair, ConstructionTime::PairwiseComposer)
                else {
                    // A pair the proof token rejects is simply not emitted.
                    continue;
                };
                chains.push(ChainSpec {
                    id: format!("{}__then__{}", first.id, second.id),
                    ops: vec![first.id, second.id],
                    // NOTE(review): only `second`'s first input is fed by `first`;
                    // any further inputs of `second` are not part of this
                    // signature — confirm that is intended.
                    signature: OpSignature {
                        inputs: first.signature.inputs.clone(),
                        output: second.signature.output.clone(),
                    },
                    specs: Vec::from(pair),
                    cpu_chain: None,
                    proof_token,
                });
            }
        }
        chains
    }
}

/// Reports whether `first` can feed `second` in a chain.
///
/// This holds exactly when `second` has at least one input and the type of that
/// first input equals `first`'s output type. Specs with no inputs never qualify
/// as the second link.
fn can_chain(first: &OpSpec, second: &OpSpec) -> bool {
    let produced = &first.signature.output;
    second.signature.inputs.first() == Some(produced)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::spec::builder::BuildError;
    use crate::spec::types::DataType;

    fn dummy_cpu(_input: &[u8]) -> Vec<u8> {
        // Bounded-law verification requires a decodable u32 output word, so
        // emit four zero bytes instead of an empty buffer.
        vec![0, 0, 0, 0]
    }

    fn dummy_wgsl() -> String {
        String::new()
    }

    /// Builds a minimal Category-A spec with the given signature, for use as a
    /// composition fixture.
    fn make_spec(
        id: &'static str,
        inputs: Vec<DataType>,
        output: DataType,
    ) -> Result<OpSpec, BuildError> {
        OpSpec::builder(id)
            .signature(OpSignature { inputs, output })
            .cpu_fn(dummy_cpu)
            .wgsl_fn(dummy_wgsl)
            .category(crate::Category::A {
                composition_of: vec![id],
            })
            .laws(vec![crate::spec::law::AlgebraicLaw::Bounded {
                lo: 0,
                hi: u32::MAX,
            }])
            .strictness(crate::spec::types::Strictness::Strict)
            .version(1)
            .build()
    }

    /// Returns the chain with the given id, or an actionable error message when
    /// the composer did not emit it.
    fn find_chain<'a>(chains: &'a [ChainSpec], id: &str) -> Result<&'a ChainSpec, String> {
        chains
            .iter()
            .find(|chain| chain.id == id)
            .ok_or_else(|| format!("Fix: pairwise composer must emit {id} for matching signatures"))
    }

    #[test]
    fn can_chain_matching_types() -> Result<(), BuildError> {
        let a = make_spec("a", vec![DataType::U32], DataType::U32)?;
        let b = make_spec("b", vec![DataType::U32], DataType::U32)?;
        assert!(can_chain(&a, &b));
        Ok(())
    }

    #[test]
    fn cannot_chain_mismatched_types() -> Result<(), BuildError> {
        let a = make_spec("a", vec![DataType::U32], DataType::U32)?;
        let b = make_spec("b", vec![DataType::Bytes], DataType::Bytes)?;
        assert!(!can_chain(&a, &b));
        Ok(())
    }

    #[test]
    fn pairwise_generates_self_chains() -> Result<(), BuildError> {
        let specs = vec![make_spec(
            "xor",
            vec![DataType::U32, DataType::U32],
            DataType::U32,
        )?];
        let chains = PairwiseComposer.generate_chains(&specs);
        assert_eq!(chains.len(), 1);
        assert_eq!(chains[0].id, "xor__then__xor");
        Ok(())
    }

    #[test]
    fn pairwise_cross_type_excluded() -> Result<(), BuildError> {
        let specs = vec![
            make_spec("a", vec![DataType::U32], DataType::U32)?,
            make_spec("b", vec![DataType::Bytes], DataType::Bytes)?,
        ];
        let chains = PairwiseComposer.generate_chains(&specs);
        // a→a and b→b type-check; a→b and b→a do not. Pin the exact ids so a
        // composer that emitted the cross pairs instead would not pass.
        let ids: Vec<&str> = chains.iter().map(|chain| chain.id.as_str()).collect();
        assert_eq!(ids, ["a__then__a", "b__then__b"]);
        Ok(())
    }

    #[test]
    fn pairwise_preserves_output_type() -> Result<(), String> {
        let a = make_spec("a", vec![DataType::U32, DataType::U32], DataType::U32)
            .map_err(|err| format!("Fix: valid pairwise fixture must build: {err:?}"))?;
        let b = make_spec("b", vec![DataType::U32], DataType::Bytes)
            .map_err(|err| format!("Fix: valid pairwise fixture must build: {err:?}"))?;
        let specs = vec![a, b];
        let chains = PairwiseComposer.generate_chains(&specs);
        let a_then_b = find_chain(&chains, "a__then__b")?;
        assert_eq!(a_then_b.signature.output, DataType::Bytes);
        Ok(())
    }

    #[test]
    fn pairwise_preserves_input_types() -> Result<(), String> {
        let a = make_spec("a", vec![DataType::U32, DataType::U32], DataType::U32)
            .map_err(|err| format!("Fix: valid pairwise fixture must build: {err:?}"))?;
        let b = make_spec("b", vec![DataType::U32], DataType::U32)
            .map_err(|err| format!("Fix: valid pairwise fixture must build: {err:?}"))?;
        let specs = vec![a, b];
        let chains = PairwiseComposer.generate_chains(&specs);
        let a_then_b = find_chain(&chains, "a__then__b")?;
        assert_eq!(
            a_then_b.signature.inputs,
            vec![DataType::U32, DataType::U32]
        );
        Ok(())
    }

    #[test]
    fn pairwise_empty_specs_produces_nothing() {
        let chains = PairwiseComposer.generate_chains(&[]);
        assert!(chains.is_empty());
    }
}