1use std::ops::{Add, Sub};
2
3use ff::PrimeField;
4use serde::{Deserialize, Serialize};
5
6#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
8pub struct Variable(pub Index);
9
10impl Variable {
11 pub fn new_unchecked(idx: Index) -> Variable {
14 Variable(idx)
15 }
16
17 pub fn get_unchecked(&self) -> Index {
20 self.0
21 }
22}
23
24#[derive(Copy, Clone, PartialEq, Debug, Eq, Hash, Serialize, Deserialize)]
27pub enum Index {
28 Input(usize),
30 Aux(usize),
32}
33
34#[derive(Clone, Debug, PartialEq)]
37pub struct LinearCombination<Scalar: PrimeField> {
38 inputs: Indexer<Scalar>,
39 aux: Indexer<Scalar>,
40}
41
/// A sparse map from `usize` keys to values of type `T`, stored as a `Vec` of
/// `(key, value)` pairs sorted by strictly increasing key.
///
/// Because the vector is sorted, lookups use binary search. `last_inserted` caches the
/// position of the most recently inserted or updated entry so that repeated or consecutive
/// keys (the common pattern when building combinations) skip the binary search.
#[derive(Clone, Debug)]
struct Indexer<T> {
    /// `(key, value)` pairs, sorted by strictly increasing key (no duplicates).
    values: Vec<(usize, T)>,
    /// `(position in values, key)` of the most recently inserted or updated entry.
    /// Invariant: when `Some((i, k))`, `values[i].0 == k`. This is only a lookup hint.
    last_inserted: Option<(usize, usize)>,
}

/// Two indexers are equal when they hold the same entries. The `last_inserted` hint depends
/// on the order of past operations, so it must not make otherwise identical indexers (and
/// therefore linear combinations) compare unequal.
impl<T: PartialEq> PartialEq for Indexer<T> {
    fn eq(&self, other: &Self) -> bool {
        self.values == other.values
    }
}

impl<T> Default for Indexer<T> {
    /// An empty indexer with no cached position.
    fn default() -> Self {
        Indexer {
            values: Vec::new(),
            last_inserted: None,
        }
    }
}

impl<T> Indexer<T> {
    /// Creates an indexer holding exactly the entry `(index, value)`.
    pub fn from_value(index: usize, value: T) -> Self {
        Indexer {
            values: vec![(index, value)],
            last_inserted: Some((0, index)),
        }
    }

    /// Iterates over `(key, value)` pairs in increasing key order.
    pub fn iter(&self) -> impl Iterator<Item = (&usize, &T)> + '_ {
        self.values.iter().map(|entry| (&entry.0, &entry.1))
    }

    /// Iterates over `(key, value)` pairs in increasing key order, with mutable values.
    /// Keys are handed out immutably so the sort order cannot be broken.
    pub fn iter_mut(&mut self) -> impl Iterator<Item = (&usize, &mut T)> + '_ {
        self.values.iter_mut().map(|(key, value)| (&*key, value))
    }

    /// Applies `update` to the value stored under `key` if one exists; otherwise stores the
    /// value produced by `insert` at the sorted position for `key`.
    ///
    /// Exactly one of `insert` / `update` is invoked.
    pub fn insert_or_update<F, G>(&mut self, key: usize, insert: F, update: G)
    where
        F: FnOnce() -> T,
        G: FnOnce(&mut T),
    {
        if let Some((last_index, last_key)) = self.last_inserted {
            if last_key == key {
                // Same key as the previous operation: update in place.
                update(&mut self.values[last_index].1);
                return;
            }
            // `checked_add` prevents an overflow panic (debug) or a wrap-around to key 0
            // (release) that would append key 0 after `usize::MAX` and unsort the vector.
            if last_key.checked_add(1) == Some(key) {
                // Consecutive key: if present it must be the entry right after the cached
                // one; if absent, that slot is exactly its sorted insertion point.
                let i = last_index + 1;
                match self.values.get(i).map(|(k, _)| *k) {
                    Some(existing) if existing == key => update(&mut self.values[i].1),
                    // `insert` at `i == len` appends, so this also covers the end of the vec.
                    _ => self.values.insert(i, (key, insert())),
                }
                self.last_inserted = Some((i, key));
                return;
            }
        }

        let i = match self.values.binary_search_by_key(&key, |(k, _)| *k) {
            Ok(i) => {
                update(&mut self.values[i].1);
                i
            }
            Err(i) => {
                self.values.insert(i, (key, insert()));
                i
            }
        };
        // Updates never move entries and every insertion resets the hint, so the cached
        // position stays valid.
        self.last_inserted = Some((i, key));
    }

    /// Number of stored entries.
    pub fn len(&self) -> usize {
        self.values.len()
    }

    /// Whether no entries are stored.
    pub fn is_empty(&self) -> bool {
        self.values.is_empty()
    }
}
125
126impl<Scalar: PrimeField> Default for LinearCombination<Scalar> {
127 fn default() -> Self {
128 Self::zero()
129 }
130}
131
132impl<Scalar: PrimeField> LinearCombination<Scalar> {
133 pub fn zero() -> LinearCombination<Scalar> {
135 LinearCombination {
136 inputs: Default::default(),
137 aux: Default::default(),
138 }
139 }
140
141 pub fn from_coeff(var: Variable, coeff: Scalar) -> Self {
143 match var {
144 Variable(Index::Input(i)) => Self {
145 inputs: Indexer::from_value(i, coeff),
146 aux: Default::default(),
147 },
148 Variable(Index::Aux(i)) => Self {
149 inputs: Default::default(),
150 aux: Indexer::from_value(i, coeff),
151 },
152 }
153 }
154
155 pub fn from_variable(var: Variable) -> Self {
157 Self::from_coeff(var, Scalar::ONE)
158 }
159
160 pub fn iter(&self) -> impl Iterator<Item = (Variable, &Scalar)> + '_ {
162 self
163 .inputs
164 .iter()
165 .map(|(k, v)| (Variable(Index::Input(*k)), v))
166 .chain(self.aux.iter().map(|(k, v)| (Variable(Index::Aux(*k)), v)))
167 }
168
169 #[inline]
171 pub fn iter_inputs(&self) -> impl Iterator<Item = (&usize, &Scalar)> + '_ {
172 self.inputs.iter()
173 }
174
175 #[inline]
177 pub fn iter_aux(&self) -> impl Iterator<Item = (&usize, &Scalar)> + '_ {
178 self.aux.iter()
179 }
180
181 pub fn iter_mut(&mut self) -> impl Iterator<Item = (Variable, &mut Scalar)> + '_ {
183 self
184 .inputs
185 .iter_mut()
186 .map(|(k, v)| (Variable(Index::Input(*k)), v))
187 .chain(
188 self
189 .aux
190 .iter_mut()
191 .map(|(k, v)| (Variable(Index::Aux(*k)), v)),
192 )
193 }
194
195 #[inline]
196 fn add_assign_unsimplified_input(&mut self, new_var: usize, coeff: Scalar) {
197 self
198 .inputs
199 .insert_or_update(new_var, || coeff, |val| *val += coeff);
200 }
201
202 #[inline]
203 fn add_assign_unsimplified_aux(&mut self, new_var: usize, coeff: Scalar) {
204 self
205 .aux
206 .insert_or_update(new_var, || coeff, |val| *val += coeff);
207 }
208
209 pub fn add_unsimplified(mut self, (coeff, var): (Scalar, Variable)) -> LinearCombination<Scalar> {
211 match var.0 {
212 Index::Input(new_var) => {
213 self.add_assign_unsimplified_input(new_var, coeff);
214 }
215 Index::Aux(new_var) => {
216 self.add_assign_unsimplified_aux(new_var, coeff);
217 }
218 }
219
220 self
221 }
222
223 #[inline]
224 fn sub_assign_unsimplified_input(&mut self, new_var: usize, coeff: Scalar) {
225 self.add_assign_unsimplified_input(new_var, -coeff);
226 }
227
228 #[inline]
229 fn sub_assign_unsimplified_aux(&mut self, new_var: usize, coeff: Scalar) {
230 self.add_assign_unsimplified_aux(new_var, -coeff);
231 }
232
233 pub fn sub_unsimplified(mut self, (coeff, var): (Scalar, Variable)) -> LinearCombination<Scalar> {
235 match var.0 {
236 Index::Input(new_var) => {
237 self.sub_assign_unsimplified_input(new_var, coeff);
238 }
239 Index::Aux(new_var) => {
240 self.sub_assign_unsimplified_aux(new_var, coeff);
241 }
242 }
243
244 self
245 }
246
247 pub fn len(&self) -> usize {
249 self.inputs.len() + self.aux.len()
250 }
251
252 pub fn is_empty(&self) -> bool {
254 self.inputs.is_empty() && self.aux.is_empty()
255 }
256
257 pub fn eval(&self, input_assignment: &[Scalar], aux_assignment: &[Scalar]) -> Scalar {
259 let mut acc = Scalar::ZERO;
260
261 let one = Scalar::ONE;
262
263 for (index, coeff) in self.iter_inputs() {
264 let mut tmp = input_assignment[*index];
265 if coeff != &one {
266 tmp *= coeff;
267 }
268 acc += tmp;
269 }
270
271 for (index, coeff) in self.iter_aux() {
272 let mut tmp = aux_assignment[*index];
273 if coeff != &one {
274 tmp *= coeff;
275 }
276 acc += tmp;
277 }
278
279 acc
280 }
281}
282
283impl<Scalar: PrimeField> Add<(Scalar, Variable)> for LinearCombination<Scalar> {
284 type Output = LinearCombination<Scalar>;
285
286 fn add(self, (coeff, var): (Scalar, Variable)) -> LinearCombination<Scalar> {
287 self.add_unsimplified((coeff, var))
288 }
289}
290
291impl<Scalar: PrimeField> Sub<(Scalar, Variable)> for LinearCombination<Scalar> {
292 type Output = LinearCombination<Scalar>;
293
294 #[allow(clippy::suspicious_arithmetic_impl)]
295 fn sub(self, (coeff, var): (Scalar, Variable)) -> LinearCombination<Scalar> {
296 self.sub_unsimplified((coeff, var))
297 }
298}
299
300impl<Scalar: PrimeField> Add<Variable> for LinearCombination<Scalar> {
301 type Output = LinearCombination<Scalar>;
302
303 fn add(self, other: Variable) -> LinearCombination<Scalar> {
304 self + (Scalar::ONE, other)
305 }
306}
307
308impl<Scalar: PrimeField> Sub<Variable> for LinearCombination<Scalar> {
309 type Output = LinearCombination<Scalar>;
310
311 fn sub(self, other: Variable) -> LinearCombination<Scalar> {
312 self - (Scalar::ONE, other)
313 }
314}
315
316impl<'a, Scalar: PrimeField> Add<&'a LinearCombination<Scalar>> for LinearCombination<Scalar> {
317 type Output = LinearCombination<Scalar>;
318
319 fn add(mut self, other: &'a LinearCombination<Scalar>) -> LinearCombination<Scalar> {
320 for (var, val) in other.inputs.iter() {
321 self.add_assign_unsimplified_input(*var, *val);
322 }
323
324 for (var, val) in other.aux.iter() {
325 self.add_assign_unsimplified_aux(*var, *val);
326 }
327
328 self
329 }
330}
331
332impl<'a, Scalar: PrimeField> Sub<&'a LinearCombination<Scalar>> for LinearCombination<Scalar> {
333 type Output = LinearCombination<Scalar>;
334
335 fn sub(mut self, other: &'a LinearCombination<Scalar>) -> LinearCombination<Scalar> {
336 for (var, val) in other.inputs.iter() {
337 self.sub_assign_unsimplified_input(*var, *val);
338 }
339
340 for (var, val) in other.aux.iter() {
341 self.sub_assign_unsimplified_aux(*var, *val);
342 }
343
344 self
345 }
346}
347
348impl<'a, Scalar: PrimeField> Add<(Scalar, &'a LinearCombination<Scalar>)>
349 for LinearCombination<Scalar>
350{
351 type Output = LinearCombination<Scalar>;
352
353 fn add(
354 mut self,
355 (coeff, other): (Scalar, &'a LinearCombination<Scalar>),
356 ) -> LinearCombination<Scalar> {
357 for (var, val) in other.inputs.iter() {
358 self.add_assign_unsimplified_input(*var, *val * coeff);
359 }
360
361 for (var, val) in other.aux.iter() {
362 self.add_assign_unsimplified_aux(*var, *val * coeff);
363 }
364
365 self
366 }
367}
368
369impl<'a, Scalar: PrimeField> Sub<(Scalar, &'a LinearCombination<Scalar>)>
370 for LinearCombination<Scalar>
371{
372 type Output = LinearCombination<Scalar>;
373
374 fn sub(
375 mut self,
376 (coeff, other): (Scalar, &'a LinearCombination<Scalar>),
377 ) -> LinearCombination<Scalar> {
378 for (var, val) in other.inputs.iter() {
379 self.sub_assign_unsimplified_input(*var, *val * coeff);
380 }
381
382 for (var, val) in other.aux.iter() {
383 self.sub_assign_unsimplified_aux(*var, *val * coeff);
384 }
385
386 self
387 }
388}