1use alloy::primitives::U256;
2use anyhow::{bail, Result};
3use serde::{Deserialize, Serialize};
4use std::str::FromStr;
5
6use self::integer::Operator;
7
8pub mod integer;
9pub mod rand;
10pub mod string;
11
/// Identifier of an aggregation function that can be applied to a slice of `U256` values.
///
/// With `#[serde(rename_all = "lowercase")]` each variant is (de)serialized under its
/// lowercase name (`"avg"`, `"sum"`, ...), which matches the strings produced by the
/// `Display` implementation.
#[derive(Debug, PartialEq, Eq, Clone, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum AggregationFunction {
    /// Average; `operation` dispatches to `integer::average`.
    AVG,
    /// Sum; `operation` dispatches to `integer::sum`.
    SUM,
    /// Minimum; `operation` dispatches to `integer::find_min`.
    MIN,
    /// Maximum; `operation` dispatches to `integer::find_max`.
    MAX,
    /// Conditional count; `operation` dispatches to `integer::count` and needs a
    /// `FunctionContext`.
    COUNT,
    /// Merkle aggregation; `operation` has no implementation for it yet.
    MERKLE,
    /// Simple linear regression; `operation` dispatches to
    /// `integer::simple_linear_regression`.
    SLR,
}
33
34impl FromStr for AggregationFunction {
36 type Err = anyhow::Error;
37
38 fn from_str(function_id: &str) -> Result<Self, Self::Err> {
39 match function_id.to_uppercase().as_str() {
40 "AVG" => Ok(Self::AVG),
41 "SUM" => Ok(Self::SUM),
42 "MIN" => Ok(Self::MIN),
43 "MAX" => Ok(Self::MAX),
44 "COUNT" => Ok(Self::COUNT),
45 "MERKLE" => Ok(Self::MERKLE),
46 "SLR" => Ok(Self::SLR),
47 _ => bail!("Unknown aggregation function"),
48 }
49 }
50}
51
52impl std::fmt::Display for AggregationFunction {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 match self {
55 AggregationFunction::AVG => write!(f, "avg"),
56 AggregationFunction::SUM => write!(f, "sum"),
57 AggregationFunction::MIN => write!(f, "min"),
58 AggregationFunction::MAX => write!(f, "max"),
59 AggregationFunction::COUNT => write!(f, "count"),
60 AggregationFunction::MERKLE => write!(f, "merkle"),
61 AggregationFunction::SLR => write!(f, "slr"),
62 }
63 }
64}
/// Extra parameters for aggregation functions that compare values against a reference,
/// currently consumed by `AggregationFunction::COUNT`.
///
/// With `#[serde(rename_all = "camelCase")]` the fields are (de)serialized as
/// `operator` and `valueToCompare`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionContext {
    /// Comparison operator to apply.
    pub operator: Operator,
    /// Reference value handed to the comparison together with `operator`.
    pub value_to_compare: U256,
}
71
72impl Default for FunctionContext {
73 fn default() -> Self {
74 Self {
75 operator: Operator::None,
76 value_to_compare: U256::ZERO,
77 }
78 }
79}
80
81impl FromStr for FunctionContext {
82 type Err = anyhow::Error;
83
84 fn from_str(context: &str) -> Result<Self, Self::Err> {
85 let parts: Vec<&str> = context.split('.').collect();
86 if parts.len() != 2 {
87 bail!("Invalid FnContext format");
88 }
89 let operator = parts[0].to_string();
90 let value_to_compare = parts[1].to_string();
91
92 Ok(Self {
93 operator: Operator::from_str(&operator).unwrap(),
94 value_to_compare: U256::from_str(&value_to_compare)?,
95 })
96 }
97}
98
99impl std::fmt::Display for FunctionContext {
100 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101 write!(f, "{}.{}", self.operator, self.value_to_compare)
102 }
103}
104
105impl FunctionContext {
106 pub fn new(operator: Operator, value_to_compare: U256) -> Self {
107 Self {
108 operator,
109 value_to_compare,
110 }
111 }
112}
113
114impl AggregationFunction {
115 pub fn to_index(function_id: &Self) -> u8 {
116 match function_id {
117 AggregationFunction::AVG => 0,
118 AggregationFunction::SUM => 1,
119 AggregationFunction::MIN => 2,
120 AggregationFunction::MAX => 3,
121 AggregationFunction::COUNT => 4,
122 AggregationFunction::MERKLE => 5,
123 AggregationFunction::SLR => 6,
124 }
125 }
126
127 pub fn from_index(index: u8) -> Result<Self> {
128 match index {
129 0 => Ok(AggregationFunction::AVG),
130 1 => Ok(AggregationFunction::SUM),
131 2 => Ok(AggregationFunction::MIN),
132 3 => Ok(AggregationFunction::MAX),
133 4 => Ok(AggregationFunction::COUNT),
134 5 => Ok(AggregationFunction::MERKLE),
135 6 => Ok(AggregationFunction::SLR),
136 _ => bail!("Unknown aggregation function index"),
137 }
138 }
139
140 pub fn operation(&self, values: &[U256], ctx: Option<FunctionContext>) -> Result<U256> {
141 match self {
142 AggregationFunction::AVG => integer::average(values),
144 AggregationFunction::MAX => integer::find_max(values),
145 AggregationFunction::MIN => integer::find_min(values),
146 AggregationFunction::SUM => integer::sum(values),
147 AggregationFunction::COUNT => {
148 if let Some(ctx) = ctx {
149 integer::count(values, &ctx)
150 } else {
151 bail!("Context not provided for COUNT")
152 }
153 }
154 AggregationFunction::MERKLE => todo!("Merkleize not implemented yet"),
156 AggregationFunction::SLR => integer::simple_linear_regression(values),
157 }
158 }
159
160 pub fn is_pre_processable(&self) -> bool {
162 match self {
163 AggregationFunction::AVG
164 | AggregationFunction::SUM
165 | AggregationFunction::MIN
166 | AggregationFunction::MAX
167 | AggregationFunction::COUNT => true,
168 AggregationFunction::SLR | AggregationFunction::MERKLE => false,
169 }
170 }
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    /// Hex literal used throughout the tests; equals 10_000_000_000_000 in decimal.
    const TEN_TRILLION_HEX: &str = "9184e72a000";
    /// Larger of the two big decimal fixtures.
    const LARGE_HI: &str = "41697298409483537348";
    /// Smaller of the two big decimal fixtures.
    const LARGE_LO: &str = "41697095938570171564";

    /// Parses a base-10 literal into a `U256`, panicking on a malformed literal.
    fn dec(literal: &str) -> U256 {
        U256::from_str_radix(literal, 10).unwrap()
    }

    /// Parses a base-16 literal (no `0x` prefix) into a `U256`, panicking on a malformed literal.
    fn hex(literal: &str) -> U256 {
        U256::from_str_radix(literal, 16).unwrap()
    }

    /// Returns `first` repeated `first_len` times followed by `second` repeated `second_len` times.
    fn runs(first: U256, first_len: usize, second: U256, second_len: usize) -> Vec<U256> {
        let mut values = vec![first; first_len];
        values.resize(first_len + second_len, second);
        values
    }

    /// Three times 6776 followed by eight times 6777.
    fn small_mixed() -> Vec<U256> {
        runs(dec("6776"), 3, dec("6777"), 8)
    }

    /// Three times `LARGE_HI` followed by eight times `LARGE_LO`.
    fn large_mixed() -> Vec<U256> {
        runs(dec(LARGE_HI), 3, dec(LARGE_LO), 8)
    }

    /// Four copies of the ten-trillion hex fixture.
    fn hex_quad() -> Vec<U256> {
        vec![hex(TEN_TRILLION_HEX); 4]
    }

    /// Runs `function` without a context over every `(input, expected)` pair.
    fn assert_cases(function: AggregationFunction, cases: Vec<(Vec<U256>, U256)>) {
        for (case_no, (values, expected)) in cases.into_iter().enumerate() {
            let actual = function.operation(&values, None).unwrap();
            assert_eq!(actual, expected, "{} case #{}", function, case_no);
        }
    }

    /// Runs `COUNT` over `values` with the given operator and comparison value.
    fn count_matching(values: &[U256], operator: Operator, value_to_compare: U256) -> U256 {
        AggregationFunction::COUNT
            .operation(values, Some(FunctionContext::new(operator, value_to_compare)))
            .unwrap()
    }

    #[test]
    fn test_sum() {
        assert_cases(
            AggregationFunction::SUM,
            vec![
                (vec![dec("6776")], U256::from(6776)),
                (runs(dec("6776"), 3, dec("6777"), 1), U256::from(27105)),
                (vec![hex(TEN_TRILLION_HEX)], dec("10000000000000")),
                (hex_quad(), dec("40000000000000")),
                (
                    runs(dec(LARGE_HI), 3, dec(LARGE_LO), 1),
                    dec("166788991167020783608"),
                ),
            ],
        );
    }

    #[test]
    fn test_avg() {
        assert_cases(
            AggregationFunction::AVG,
            vec![
                (vec![dec("6776")], U256::from(6776)),
                (small_mixed(), U256::from(6777)),
                (vec![hex(TEN_TRILLION_HEX)], U256::from(10000000000000u64)),
                (hex_quad(), U256::from(10000000000000u64)),
                (large_mixed(), U256::from(41697151157910180414u128)),
            ],
        );
    }

    #[test]
    fn test_max() {
        assert_cases(
            AggregationFunction::MAX,
            vec![
                (vec![dec("6776")], U256::from(6776)),
                (small_mixed(), U256::from(6777)),
                (vec![hex(TEN_TRILLION_HEX)], U256::from(10000000000000u64)),
                (hex_quad(), U256::from(10000000000000u64)),
                (large_mixed(), U256::from(41697298409483537348u128)),
            ],
        );
    }

    #[test]
    fn test_min() {
        assert_cases(
            AggregationFunction::MIN,
            vec![
                (vec![dec("6776")], U256::from(6776)),
                (small_mixed(), U256::from(6776)),
                (vec![hex(TEN_TRILLION_HEX)], U256::from(10000000000000u64)),
                (hex_quad(), U256::from(10000000000000u64)),
                (large_mixed(), U256::from(41697095938570171564u128)),
            ],
        );
    }

    #[test]
    fn test_count() {
        // Single small value.
        let single = vec![dec("6776")];
        assert_eq!(
            count_matching(&single, Operator::GreaterThanOrEqual, U256::from(4095)),
            U256::from(1)
        );
        assert_eq!(
            count_matching(&single, Operator::Equal, U256::from(6776)),
            U256::from(1)
        );

        // Eleven small values: only the eight 6777 entries differ from / exceed 6776.
        let mixed = small_mixed();
        assert_eq!(
            count_matching(&mixed, Operator::NotEqual, U256::from(6776)),
            U256::from(8)
        );
        assert_eq!(
            count_matching(&mixed, Operator::GreaterThan, U256::from(6776)),
            U256::from(8)
        );

        // Hex-encoded fixtures compared against decimal thresholds.
        let single_hex = vec![hex(TEN_TRILLION_HEX)];
        assert_eq!(
            count_matching(
                &single_hex,
                Operator::Equal,
                U256::from_str("10000000000000").unwrap()
            ),
            U256::from(1)
        );
        assert_eq!(
            count_matching(
                &hex_quad(),
                Operator::LessThanOrEqual,
                U256::from_str("10000000000001").unwrap()
            ),
            U256::from(4)
        );

        // Eleven large values: only the eight smaller entries are <= the threshold.
        assert_eq!(
            count_matching(
                &large_mixed(),
                Operator::LessThanOrEqual,
                U256::from_str("41697095938570171564").unwrap()
            ),
            U256::from(8)
        );
    }
}