Skip to main content

plonkish_cat/
expr.rs

1//! Symbolic polynomial expressions over wires.
2//!
3//! [`Expression<F>`] is the free algebra over wire references and
4//! field constants.  Expressions are built symbolically during
5//! constraint generation and evaluated against wire assignments
6//! for satisfaction checking.
7
8use crate::error::Error;
9use crate::field::Field;
10use crate::wire::Wire;
11
12/// A symbolic polynomial expression over field `F` and wire indices.
13///
14/// Used to build constraints: an expression that must equal zero.
15#[derive(Debug, Clone)]
16pub enum Expression<F: Field> {
17    /// A field constant.
18    Constant(F),
19    /// A wire reference (a variable).
20    Wire(Wire),
21    /// Negation of an expression.
22    Neg(Box<Expression<F>>),
23    /// Sum of two expressions.
24    Sum(Box<Expression<F>>, Box<Expression<F>>),
25    /// Product of two expressions.
26    Product(Box<Expression<F>>, Box<Expression<F>>),
27}
28
29impl<F: Field> Expression<F> {
30    /// A constant expression.
31    #[must_use]
32    pub fn constant(c: F) -> Self {
33        Self::Constant(c)
34    }
35
36    /// A wire reference.
37    #[must_use]
38    pub fn wire(w: Wire) -> Self {
39        Self::Wire(w)
40    }
41
42    /// Evaluate this expression given a wire-value assignment.
43    ///
44    /// # Errors
45    ///
46    /// Returns [`Error::WireOutOfBounds`] if a referenced wire
47    /// is not in the assignment.
48    pub fn evaluate(&self, assignment: &dyn Fn(Wire) -> Result<F, Error>) -> Result<F, Error> {
49        match self {
50            Self::Constant(c) => Ok(c.clone()),
51            Self::Wire(w) => assignment(*w),
52            Self::Neg(inner) => inner.evaluate(assignment).map(|v| -v),
53            Self::Sum(left, right) => {
54                let l = left.evaluate(assignment)?;
55                let r = right.evaluate(assignment)?;
56                Ok(l + r)
57            }
58            Self::Product(left, right) => {
59                let l = left.evaluate(assignment)?;
60                let r = right.evaluate(assignment)?;
61                Ok(l * r)
62            }
63        }
64    }
65}
66
67impl<F: Field> std::ops::Add for Expression<F> {
68    type Output = Self;
69    fn add(self, rhs: Self) -> Self {
70        Self::Sum(Box::new(self), Box::new(rhs))
71    }
72}
73
74impl<F: Field> std::ops::Sub for Expression<F> {
75    type Output = Self;
76    fn sub(self, rhs: Self) -> Self {
77        self + (-rhs)
78    }
79}
80
81impl<F: Field> std::ops::Mul for Expression<F> {
82    type Output = Self;
83    fn mul(self, rhs: Self) -> Self {
84        Self::Product(Box::new(self), Box::new(rhs))
85    }
86}
87
88impl<F: Field> std::ops::Neg for Expression<F> {
89    type Output = Self;
90    fn neg(self) -> Self {
91        Self::Neg(Box::new(self))
92    }
93}
94
#[cfg(test)]
mod tests {
    use super::*;
    use crate::field::F101;

    /// Assigns w0 = 3, w1 = 5, w2 = 7; any other wire is out of bounds.
    fn test_assignment(w: Wire) -> Result<F101, Error> {
        match w.index() {
            0 => Ok(F101::new(3)),
            1 => Ok(F101::new(5)),
            2 => Ok(F101::new(7)),
            _ => Err(Error::WireOutOfBounds {
                wire_index: w.index(),
                allocated: 3,
            }),
        }
    }

    #[test]
    fn constant_evaluates_to_itself() -> Result<(), Error> {
        let e = Expression::constant(F101::new(42));
        assert_eq!(e.evaluate(&test_assignment)?, F101::new(42));
        Ok(())
    }

    #[test]
    fn wire_evaluates_to_assignment() -> Result<(), Error> {
        let e = Expression::wire(Wire::new(1));
        assert_eq!(e.evaluate(&test_assignment)?, F101::new(5));
        Ok(())
    }

    #[test]
    fn sum_evaluates_correctly() -> Result<(), Error> {
        let e = Expression::wire(Wire::new(0)) + Expression::wire(Wire::new(1));
        assert_eq!(e.evaluate(&test_assignment)?, F101::new(8));
        Ok(())
    }

    #[test]
    fn product_evaluates_correctly() -> Result<(), Error> {
        let e = Expression::wire(Wire::new(0)) * Expression::wire(Wire::new(1));
        assert_eq!(e.evaluate(&test_assignment)?, F101::new(15));
        Ok(())
    }

    #[test]
    fn subtraction_evaluates_correctly() -> Result<(), Error> {
        // 3 - 5 = -2 = 99 (mod 101)
        let e = Expression::wire(Wire::new(0)) - Expression::wire(Wire::new(1));
        assert_eq!(e.evaluate(&test_assignment)?, F101::new(99));
        Ok(())
    }

    #[test]
    fn negation_evaluates_correctly() -> Result<(), Error> {
        // -3 = 98 (mod 101)
        let e = -Expression::wire(Wire::new(0));
        assert_eq!(e.evaluate(&test_assignment)?, F101::new(98));
        Ok(())
    }

    #[test]
    fn double_negation_cancels() -> Result<(), Error> {
        let e = -(-Expression::wire(Wire::new(2)));
        assert_eq!(e.evaluate(&test_assignment)?, F101::new(7));
        Ok(())
    }

    #[test]
    fn complex_expression() -> Result<(), Error> {
        // w0 * w1 + w2 = 3 * 5 + 7 = 22
        let e = Expression::wire(Wire::new(0)) * Expression::wire(Wire::new(1))
            + Expression::wire(Wire::new(2));
        assert_eq!(e.evaluate(&test_assignment)?, F101::new(22));
        Ok(())
    }

    #[test]
    fn missing_wire_propagates_error() {
        let e = Expression::wire(Wire::new(0)) * Expression::wire(Wire::new(3));
        let result = e.evaluate(&test_assignment);
        assert!(matches!(
            result,
            Err(Error::WireOutOfBounds {
                wire_index: 3,
                allocated: 3,
            })
        ));
    }

    #[test]
    fn evaluation_stops_at_first_error() {
        // The left operand fails, so the right operand's wire is never queried.
        let queried = std::cell::Cell::new(0_usize);
        let counting = |w: Wire| {
            queried.set(queried.get() + 1);
            test_assignment(w)
        };
        let e = Expression::wire(Wire::new(9)) + Expression::wire(Wire::new(0));
        assert!(e.evaluate(&counting).is_err());
        assert_eq!(queried.get(), 1);
    }
}