1use crate::error::Error;
9use crate::field::Field;
10use crate::wire::Wire;
11
12#[derive(Debug, Clone)]
16pub enum Expression<F: Field> {
17 Constant(F),
19 Wire(Wire),
21 Neg(Box<Expression<F>>),
23 Sum(Box<Expression<F>>, Box<Expression<F>>),
25 Product(Box<Expression<F>>, Box<Expression<F>>),
27}
28
29impl<F: Field> Expression<F> {
30 #[must_use]
32 pub fn constant(c: F) -> Self {
33 Self::Constant(c)
34 }
35
36 #[must_use]
38 pub fn wire(w: Wire) -> Self {
39 Self::Wire(w)
40 }
41
42 pub fn evaluate(&self, assignment: &dyn Fn(Wire) -> Result<F, Error>) -> Result<F, Error> {
49 match self {
50 Self::Constant(c) => Ok(c.clone()),
51 Self::Wire(w) => assignment(*w),
52 Self::Neg(inner) => inner.evaluate(assignment).map(|v| -v),
53 Self::Sum(left, right) => {
54 let l = left.evaluate(assignment)?;
55 let r = right.evaluate(assignment)?;
56 Ok(l + r)
57 }
58 Self::Product(left, right) => {
59 let l = left.evaluate(assignment)?;
60 let r = right.evaluate(assignment)?;
61 Ok(l * r)
62 }
63 }
64 }
65}
66
67impl<F: Field> std::ops::Add for Expression<F> {
68 type Output = Self;
69 fn add(self, rhs: Self) -> Self {
70 Self::Sum(Box::new(self), Box::new(rhs))
71 }
72}
73
74impl<F: Field> std::ops::Sub for Expression<F> {
75 type Output = Self;
76 fn sub(self, rhs: Self) -> Self {
77 self + (-rhs)
78 }
79}
80
81impl<F: Field> std::ops::Mul for Expression<F> {
82 type Output = Self;
83 fn mul(self, rhs: Self) -> Self {
84 Self::Product(Box::new(self), Box::new(rhs))
85 }
86}
87
88impl<F: Field> std::ops::Neg for Expression<F> {
89 type Output = Self;
90 fn neg(self) -> Self {
91 Self::Neg(Box::new(self))
92 }
93}
94
#[cfg(test)]
mod tests {
    use super::*;
    use crate::field::F101;

    /// Assigns 3, 5 and 7 to wires 0, 1 and 2. Any other wire yields
    /// `WireOutOfBounds`, which lets the tests exercise the error path.
    fn test_assignment(w: Wire) -> Result<F101, Error> {
        match w.index() {
            0 => Ok(F101::new(3)),
            1 => Ok(F101::new(5)),
            2 => Ok(F101::new(7)),
            _ => Err(Error::WireOutOfBounds {
                wire_index: w.index(),
                allocated: 3,
            }),
        }
    }

    #[test]
    fn constant_evaluates_to_itself() -> Result<(), Error> {
        let e = Expression::constant(F101::new(42));
        assert_eq!(e.evaluate(&test_assignment)?, F101::new(42));
        Ok(())
    }

    #[test]
    fn wire_evaluates_to_assignment() -> Result<(), Error> {
        let e = Expression::wire(Wire::new(1));
        assert_eq!(e.evaluate(&test_assignment)?, F101::new(5));
        Ok(())
    }

    #[test]
    fn sum_evaluates_correctly() -> Result<(), Error> {
        // 3 + 5 = 8
        let e = Expression::wire(Wire::new(0)) + Expression::wire(Wire::new(1));
        assert_eq!(e.evaluate(&test_assignment)?, F101::new(8));
        Ok(())
    }

    #[test]
    fn product_evaluates_correctly() -> Result<(), Error> {
        // 3 * 5 = 15
        let e = Expression::wire(Wire::new(0)) * Expression::wire(Wire::new(1));
        assert_eq!(e.evaluate(&test_assignment)?, F101::new(15));
        Ok(())
    }

    #[test]
    fn subtraction_evaluates_correctly() -> Result<(), Error> {
        let e = Expression::wire(Wire::new(0)) - Expression::wire(Wire::new(1));
        // 3 - 5 = -2, which wraps to 99 modulo 101.
        assert_eq!(e.evaluate(&test_assignment)?, F101::new(99));
        Ok(())
    }

    #[test]
    fn negation_evaluates_correctly() -> Result<(), Error> {
        let e = -Expression::wire(Wire::new(0));
        // -3 wraps to 98 modulo 101.
        assert_eq!(e.evaluate(&test_assignment)?, F101::new(98));
        Ok(())
    }

    #[test]
    fn complex_expression() -> Result<(), Error> {
        // Operator precedence gives (3 * 5) + 7 = 22.
        let e = Expression::wire(Wire::new(0)) * Expression::wire(Wire::new(1))
            + Expression::wire(Wire::new(2));
        assert_eq!(e.evaluate(&test_assignment)?, F101::new(22));
        Ok(())
    }

    #[test]
    fn unassigned_wire_is_an_error() {
        let e = Expression::<F101>::wire(Wire::new(3));
        assert!(matches!(
            e.evaluate(&test_assignment),
            Err(Error::WireOutOfBounds {
                wire_index: 3,
                allocated: 3
            })
        ));
    }

    #[test]
    fn errors_propagate_through_compound_expressions() {
        // A failing leaf must surface its error from either operand position
        // of every compound node kind.
        let bad = || Expression::<F101>::wire(Wire::new(9));
        let good = || Expression::<F101>::wire(Wire::new(0));
        for e in [
            -bad(),
            bad() + good(),
            good() + bad(),
            bad() * good(),
            good() * bad(),
            good() - bad(),
        ] {
            assert!(matches!(
                e.evaluate(&test_assignment),
                Err(Error::WireOutOfBounds { wire_index: 9, .. })
            ));
        }
    }
}