Skip to main content

plonkish_cat/
expr.rs

1//! Symbolic polynomial expressions over wires.
2//!
3//! [`Expression<F>`] is the free algebra over wire references and
4//! field constants.  Expressions are built symbolically during
5//! constraint generation and evaluated against wire assignments
6//! for satisfaction checking.
7
8use field_cat::Field;
9
10use crate::error::Error;
11use crate::wire::Wire;
12
13/// A symbolic polynomial expression over field `F` and wire indices.
14///
15/// Used to build constraints: an expression that must equal zero.
16#[derive(Debug, Clone)]
17pub enum Expression<F: Field> {
18    /// A field constant.
19    Constant(F),
20    /// A wire reference (a variable).
21    Wire(Wire),
22    /// Negation of an expression.
23    Neg(Box<Expression<F>>),
24    /// Sum of two expressions.
25    Sum(Box<Expression<F>>, Box<Expression<F>>),
26    /// Product of two expressions.
27    Product(Box<Expression<F>>, Box<Expression<F>>),
28}
29
30impl<F: Field> Expression<F> {
31    /// A constant expression.
32    #[must_use]
33    pub fn constant(c: F) -> Self {
34        Self::Constant(c)
35    }
36
37    /// A wire reference.
38    #[must_use]
39    pub fn wire(w: Wire) -> Self {
40        Self::Wire(w)
41    }
42
43    /// Evaluate this expression given a wire-value assignment.
44    ///
45    /// # Errors
46    ///
47    /// Returns [`Error::WireOutOfBounds`] if a referenced wire
48    /// is not in the assignment.
49    pub fn evaluate(&self, assignment: &dyn Fn(Wire) -> Result<F, Error>) -> Result<F, Error> {
50        match self {
51            Self::Constant(c) => Ok(c.clone()),
52            Self::Wire(w) => assignment(*w),
53            Self::Neg(inner) => inner.evaluate(assignment).map(|v| -v),
54            Self::Sum(left, right) => {
55                let l = left.evaluate(assignment)?;
56                let r = right.evaluate(assignment)?;
57                Ok(l + r)
58            }
59            Self::Product(left, right) => {
60                let l = left.evaluate(assignment)?;
61                let r = right.evaluate(assignment)?;
62                Ok(l * r)
63            }
64        }
65    }
66}
67
68impl<F: Field> std::ops::Add for Expression<F> {
69    type Output = Self;
70    fn add(self, rhs: Self) -> Self {
71        Self::Sum(Box::new(self), Box::new(rhs))
72    }
73}
74
75impl<F: Field> std::ops::Sub for Expression<F> {
76    type Output = Self;
77    fn sub(self, rhs: Self) -> Self {
78        self + (-rhs)
79    }
80}
81
82impl<F: Field> std::ops::Mul for Expression<F> {
83    type Output = Self;
84    fn mul(self, rhs: Self) -> Self {
85        Self::Product(Box::new(self), Box::new(rhs))
86    }
87}
88
89impl<F: Field> std::ops::Neg for Expression<F> {
90    type Output = Self;
91    fn neg(self) -> Self {
92        Self::Neg(Box::new(self))
93    }
94}
95
#[cfg(test)]
mod tests {
    use super::*;
    use field_cat::F101;

    /// Assigns w0 = 3, w1 = 5, w2 = 7; any other wire is out of bounds.
    fn test_assignment(w: Wire) -> Result<F101, Error> {
        match w.index() {
            0 => Ok(F101::new(3)),
            1 => Ok(F101::new(5)),
            2 => Ok(F101::new(7)),
            _ => Err(Error::WireOutOfBounds {
                wire_index: w.index(),
                allocated: 3,
            }),
        }
    }

    #[test]
    fn constant_evaluates_to_itself() -> Result<(), Error> {
        let e = Expression::constant(F101::new(42));
        assert_eq!(e.evaluate(&test_assignment)?, F101::new(42));
        Ok(())
    }

    #[test]
    fn wire_evaluates_to_assignment() -> Result<(), Error> {
        let e = Expression::wire(Wire::new(1));
        assert_eq!(e.evaluate(&test_assignment)?, F101::new(5));
        Ok(())
    }

    #[test]
    fn sum_evaluates_correctly() -> Result<(), Error> {
        let e = Expression::wire(Wire::new(0)) + Expression::wire(Wire::new(1));
        assert_eq!(e.evaluate(&test_assignment)?, F101::new(8));
        Ok(())
    }

    #[test]
    fn product_evaluates_correctly() -> Result<(), Error> {
        let e = Expression::wire(Wire::new(0)) * Expression::wire(Wire::new(1));
        assert_eq!(e.evaluate(&test_assignment)?, F101::new(15));
        Ok(())
    }

    #[test]
    fn subtraction_evaluates_correctly() -> Result<(), Error> {
        // 3 - 5 = -2 = 99 (mod 101)
        let e = Expression::wire(Wire::new(0)) - Expression::wire(Wire::new(1));
        assert_eq!(e.evaluate(&test_assignment)?, F101::new(99));
        Ok(())
    }

    #[test]
    fn negation_evaluates_correctly() -> Result<(), Error> {
        // -3 = 98 (mod 101)
        let e = -Expression::wire(Wire::new(0));
        assert_eq!(e.evaluate(&test_assignment)?, F101::new(98));
        Ok(())
    }

    #[test]
    fn complex_expression() -> Result<(), Error> {
        // w0 * w1 + w2 = 3 * 5 + 7 = 22
        let e = Expression::wire(Wire::new(0)) * Expression::wire(Wire::new(1))
            + Expression::wire(Wire::new(2));
        assert_eq!(e.evaluate(&test_assignment)?, F101::new(22));
        Ok(())
    }

    #[test]
    fn constants_mix_with_wires() -> Result<(), Error> {
        // (w2 - 10) * w0 = (7 - 10) * 3 = -9 = 92 (mod 101)
        let e: Expression<F101> = (Expression::wire(Wire::new(2))
            - Expression::constant(F101::new(10)))
            * Expression::wire(Wire::new(0));
        assert_eq!(e.evaluate(&test_assignment)?, F101::new(92));
        Ok(())
    }

    #[test]
    fn closure_assignment_is_accepted() -> Result<(), Error> {
        // Every wire evaluates to 2, so w4 * w4 = 4.
        let all_twos = |_: Wire| -> Result<F101, Error> { Ok(F101::new(2)) };
        let e: Expression<F101> =
            Expression::wire(Wire::new(4)) * Expression::wire(Wire::new(4));
        assert_eq!(e.evaluate(&all_twos)?, F101::new(4));
        Ok(())
    }

    #[test]
    fn unassigned_wire_propagates_error() {
        // w9 is outside the assignment, so its lookup error surfaces.
        let e: Expression<F101> =
            Expression::wire(Wire::new(0)) + Expression::wire(Wire::new(9));
        assert!(matches!(
            e.evaluate(&test_assignment),
            Err(Error::WireOutOfBounds {
                wire_index: 9,
                allocated: 3
            })
        ));
    }

    #[test]
    fn leftmost_error_is_reported() {
        // Both operands are unassigned; the left one is looked up first.
        let e: Expression<F101> =
            Expression::wire(Wire::new(8)) * Expression::wire(Wire::new(9));
        assert!(matches!(
            e.evaluate(&test_assignment),
            Err(Error::WireOutOfBounds { wire_index: 8, .. })
        ));
    }
}