1use field_cat::Field;
9
10use crate::error::Error;
11use crate::wire::Wire;
12
13#[derive(Debug, Clone)]
17pub enum Expression<F: Field> {
18 Constant(F),
20 Wire(Wire),
22 Neg(Box<Expression<F>>),
24 Sum(Box<Expression<F>>, Box<Expression<F>>),
26 Product(Box<Expression<F>>, Box<Expression<F>>),
28}
29
30impl<F: Field> Expression<F> {
31 #[must_use]
33 pub fn constant(c: F) -> Self {
34 Self::Constant(c)
35 }
36
37 #[must_use]
39 pub fn wire(w: Wire) -> Self {
40 Self::Wire(w)
41 }
42
43 pub fn evaluate(&self, assignment: &dyn Fn(Wire) -> Result<F, Error>) -> Result<F, Error> {
50 match self {
51 Self::Constant(c) => Ok(c.clone()),
52 Self::Wire(w) => assignment(*w),
53 Self::Neg(inner) => inner.evaluate(assignment).map(|v| -v),
54 Self::Sum(left, right) => {
55 let l = left.evaluate(assignment)?;
56 let r = right.evaluate(assignment)?;
57 Ok(l + r)
58 }
59 Self::Product(left, right) => {
60 let l = left.evaluate(assignment)?;
61 let r = right.evaluate(assignment)?;
62 Ok(l * r)
63 }
64 }
65 }
66}
67
68impl<F: Field> std::ops::Add for Expression<F> {
69 type Output = Self;
70 fn add(self, rhs: Self) -> Self {
71 Self::Sum(Box::new(self), Box::new(rhs))
72 }
73}
74
75impl<F: Field> std::ops::Sub for Expression<F> {
76 type Output = Self;
77 fn sub(self, rhs: Self) -> Self {
78 self + (-rhs)
79 }
80}
81
82impl<F: Field> std::ops::Mul for Expression<F> {
83 type Output = Self;
84 fn mul(self, rhs: Self) -> Self {
85 Self::Product(Box::new(self), Box::new(rhs))
86 }
87}
88
89impl<F: Field> std::ops::Neg for Expression<F> {
90 type Output = Self;
91 fn neg(self) -> Self {
92 Self::Neg(Box::new(self))
93 }
94}
95
#[cfg(test)]
mod tests {
    use super::*;
    use field_cat::F101;

    /// Assigns 3, 5 and 7 to wires 0, 1 and 2; any other wire reports
    /// `WireOutOfBounds` with `allocated: 3`.
    fn test_assignment(w: Wire) -> Result<F101, Error> {
        match w.index() {
            0 => Ok(F101::new(3)),
            1 => Ok(F101::new(5)),
            2 => Ok(F101::new(7)),
            _ => Err(Error::WireOutOfBounds {
                wire_index: w.index(),
                allocated: 3,
            }),
        }
    }

    #[test]
    fn constant_evaluates_to_itself() -> Result<(), Error> {
        let e = Expression::constant(F101::new(42));
        assert_eq!(e.evaluate(&test_assignment)?, F101::new(42));
        Ok(())
    }

    #[test]
    fn wire_evaluates_to_assignment() -> Result<(), Error> {
        let e = Expression::wire(Wire::new(1));
        assert_eq!(e.evaluate(&test_assignment)?, F101::new(5));
        Ok(())
    }

    #[test]
    fn sum_evaluates_correctly() -> Result<(), Error> {
        let e = Expression::wire(Wire::new(0)) + Expression::wire(Wire::new(1));
        assert_eq!(e.evaluate(&test_assignment)?, F101::new(8));
        Ok(())
    }

    #[test]
    fn product_evaluates_correctly() -> Result<(), Error> {
        let e = Expression::wire(Wire::new(0)) * Expression::wire(Wire::new(1));
        assert_eq!(e.evaluate(&test_assignment)?, F101::new(15));
        Ok(())
    }

    #[test]
    fn subtraction_evaluates_correctly() -> Result<(), Error> {
        let e = Expression::wire(Wire::new(0)) - Expression::wire(Wire::new(1));
        // 3 - 5 = -2, which is 99 modulo 101.
        assert_eq!(e.evaluate(&test_assignment)?, F101::new(99));
        Ok(())
    }

    #[test]
    fn negation_evaluates_correctly() -> Result<(), Error> {
        let e = -Expression::wire(Wire::new(0));
        // -3 is 98 modulo 101.
        assert_eq!(e.evaluate(&test_assignment)?, F101::new(98));
        Ok(())
    }

    #[test]
    fn complex_expression() -> Result<(), Error> {
        // Operator precedence makes this (w0 * w1) + w2 = 3 * 5 + 7 = 22.
        let e = Expression::wire(Wire::new(0)) * Expression::wire(Wire::new(1))
            + Expression::wire(Wire::new(2));
        assert_eq!(e.evaluate(&test_assignment)?, F101::new(22));
        Ok(())
    }

    #[test]
    fn out_of_bounds_wire_propagates_error_through_sum() {
        let e = Expression::wire(Wire::new(0)) + Expression::wire(Wire::new(9));
        assert!(matches!(
            e.evaluate(&test_assignment),
            Err(Error::WireOutOfBounds {
                wire_index: 9,
                allocated: 3
            })
        ));
    }

    #[test]
    fn error_propagates_through_negation_and_product() {
        let e = -(Expression::wire(Wire::new(1)) * Expression::wire(Wire::new(42)));
        assert!(matches!(
            e.evaluate(&test_assignment),
            Err(Error::WireOutOfBounds {
                wire_index: 42,
                allocated: 3
            })
        ));
    }

    #[test]
    fn left_operand_error_is_reported_first() {
        // Both wires are unassigned; evaluation is left-to-right and stops at
        // the first failure, so the left wire's index is the one reported.
        let e = Expression::wire(Wire::new(7)) * Expression::wire(Wire::new(8));
        assert!(matches!(
            e.evaluate(&test_assignment),
            Err(Error::WireOutOfBounds { wire_index: 7, .. })
        ));
    }

    #[test]
    fn constants_never_consult_assignment() -> Result<(), Error> {
        // An assignment that rejects every wire: a constant-only tree must
        // still evaluate successfully.
        let reject_all = |w: Wire| -> Result<F101, Error> {
            Err(Error::WireOutOfBounds {
                wire_index: w.index(),
                allocated: 0,
            })
        };
        let e = Expression::constant(F101::new(4)) * Expression::constant(F101::new(5));
        assert_eq!(e.evaluate(&reject_all)?, F101::new(20));
        Ok(())
    }
}