hdp_primitives/aggregate_fn/
integer.rs

1use std::str::FromStr;
2
3use alloy::primitives::U256;
4use anyhow::{bail, Result};
5use serde::{Deserialize, Serialize};
6
7use super::FunctionContext;
8
9/// Returns the average of the values: [`AVG`](https://en.wikipedia.org/wiki/Average)
10pub fn average(values: &[U256]) -> Result<U256> {
11    if values.is_empty() {
12        bail!("No values found");
13    }
14
15    let sum = values
16        .iter()
17        .try_fold(U256::from(0), |acc, val| acc.checked_add(*val))
18        .unwrap();
19
20    divide(sum, U256::from(values.len()))
21}
22
23// TODO: Implement bloom_filterize
24pub fn bloom_filterize(_values: &[U256]) -> Result<U256> {
25    Ok(U256::from(0))
26}
27
28/// Find the maximum value: [`MAX`](https://en.wikipedia.org/wiki/Maxima_and_minima)
29pub fn find_max(values: &[U256]) -> Result<U256> {
30    if values.is_empty() {
31        bail!("No values found");
32    }
33
34    let mut max = U256::from(0);
35
36    for value in values {
37        if value > &max {
38            max = *value;
39        }
40    }
41
42    Ok(max)
43}
44
45/// Find the minimum value: [`MIN`](https://en.wikipedia.org/wiki/Maxima_and_minima)
46pub fn find_min(values: &[U256]) -> Result<U256> {
47    if values.is_empty() {
48        bail!("No values found");
49    }
50
51    let mut min = U256::MAX;
52    for value in values {
53        if value < &min {
54            min = *value;
55        }
56    }
57
58    Ok(min)
59}
60
61/// Standard deviation
62/// wip
63pub fn standard_deviation(values: &[U256]) -> Result<U256> {
64    if values.is_empty() {
65        bail!("No values found");
66    }
67
68    let mut sum = U256::from(0);
69    let count = U256::from(values.len());
70
71    for value in values {
72        sum += value;
73    }
74
75    let avg = divide(sum, count)
76        .expect("Division have failed")
77        .to_string()
78        .parse::<f64>()
79        .unwrap();
80
81    let mut variance_sum = 0.0;
82    for value in values {
83        let value = value.to_string().parse::<f64>().unwrap();
84        variance_sum += (value - avg).powi(2);
85    }
86
87    let variance: f64 = divide(U256::from(variance_sum), U256::from(count))
88        .expect("Division have failed")
89        .to_string()
90        .parse::<f64>()
91        .unwrap();
92    Ok(U256::from(roundup(variance.sqrt().to_string())))
93}
94
95/// Sum of values: [`SUM`](https://en.wikipedia.org/wiki/Summation)
96pub fn sum(values: &[U256]) -> Result<U256> {
97    if values.is_empty() {
98        bail!("No values found");
99    }
100
101    let mut sum = U256::from(0);
102
103    for value in values {
104        sum += value;
105    }
106
107    Ok(sum)
108}
109
110/// Count number of values that satisfy a condition
111///
112/// The context is a string of 4 characters:
113/// - The first two characters are the logical operator
114/// - The last two characters are the value to compare
115///
116/// The logical operators are:
117/// - 00: Equal (=)
118/// - 01: Not equal (!=)
119/// - 02: Greater than (>)
120/// - 03: Greater than or equal (>=)
121/// - 04: Less than (<)
122/// - 05: Less than or equal (<=)
123pub fn count(values: &[U256], ctx: &FunctionContext) -> Result<U256> {
124    let logical_operator = &ctx.operator;
125    let value_to_compare = ctx.value_to_compare;
126
127    let mut condition_satisfiability_count = 0;
128
129    for value in values {
130        match logical_operator {
131            Operator::Equal => {
132                if value == &value_to_compare {
133                    condition_satisfiability_count += 1;
134                }
135            }
136            Operator::NotEqual => {
137                if value != &value_to_compare {
138                    condition_satisfiability_count += 1;
139                }
140            }
141            Operator::GreaterThan => {
142                if value > &value_to_compare {
143                    condition_satisfiability_count += 1;
144                }
145            }
146            Operator::GreaterThanOrEqual => {
147                if value >= &value_to_compare {
148                    condition_satisfiability_count += 1;
149                }
150            }
151            Operator::LessThan => {
152                if value < &value_to_compare {
153                    condition_satisfiability_count += 1;
154                }
155            }
156            Operator::LessThanOrEqual => {
157                if value <= &value_to_compare {
158                    condition_satisfiability_count += 1;
159                }
160            }
161            Operator::None => {
162                bail!("Count need logical operator");
163            }
164        }
165    }
166
167    Ok(U256::from(condition_satisfiability_count))
168}
169
170pub fn simple_linear_regression(values: &[U256]) -> Result<U256> {
171    // if value is empty or has only one value, return error
172    if values.is_empty() || values.len() == 1 {
173        bail!("At least 2 values are needed to compute SLR");
174    }
175    // TODO: handle custom compute module
176    Ok(U256::from(0))
177}
178
179#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
180#[serde(try_from = "String")]
181pub enum Operator {
182    None,
183    Equal,
184    NotEqual,
185    GreaterThan,
186    GreaterThanOrEqual,
187    LessThan,
188    LessThanOrEqual,
189}
190
191impl FromStr for Operator {
192    type Err = anyhow::Error;
193
194    fn from_str(operator: &str) -> Result<Self> {
195        match operator {
196            "eq" => Ok(Self::Equal),
197            "nq" => Ok(Self::NotEqual),
198            "gt" => Ok(Self::GreaterThan),
199            "gteq" => Ok(Self::GreaterThanOrEqual),
200            "lt" => Ok(Self::LessThan),
201            "lteq=" => Ok(Self::LessThanOrEqual),
202            "none" => Ok(Self::None),
203            _ => bail!("Unknown logical operator"),
204        }
205    }
206}
207
208impl TryFrom<String> for Operator {
209    type Error = anyhow::Error;
210
211    fn try_from(value: String) -> Result<Self> {
212        Operator::from_str(&value)
213    }
214}
215
216impl std::fmt::Display for Operator {
217    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
218        let operator = match self {
219            Operator::Equal => "eq",
220            Operator::NotEqual => "nq",
221            Operator::GreaterThan => "gt",
222            Operator::GreaterThanOrEqual => "gteq",
223            Operator::LessThan => "lt",
224            Operator::LessThanOrEqual => "lteq",
225            Operator::None => "none",
226        };
227        write!(f, "{}", operator)
228    }
229}
230
231impl Operator {
232    pub fn from_symbol(symbol: &str) -> Result<Self> {
233        match symbol {
234            "=" => Ok(Self::Equal),
235            "!=" => Ok(Self::NotEqual),
236            ">" => Ok(Self::GreaterThan),
237            ">=" => Ok(Self::GreaterThanOrEqual),
238            "<" => Ok(Self::LessThan),
239            "<=" => Ok(Self::LessThanOrEqual),
240            "none" => Ok(Self::None),
241            _ => bail!("Unknown logical operator"),
242        }
243    }
244    // Convert operator to bytes
245    pub fn to_index(operator: &Self) -> u8 {
246        match operator {
247            Operator::Equal => 1,
248            Operator::NotEqual => 2,
249            Operator::GreaterThan => 3,
250            Operator::GreaterThanOrEqual => 4,
251            Operator::LessThan => 5,
252            Operator::LessThanOrEqual => 6,
253            Operator::None => 0,
254        }
255    }
256
257    pub fn from_index(bytes: u8) -> Result<Self> {
258        match bytes {
259            0 => Ok(Operator::None),
260            1 => Ok(Operator::Equal),
261            2 => Ok(Operator::NotEqual),
262            3 => Ok(Operator::GreaterThan),
263            4 => Ok(Operator::GreaterThanOrEqual),
264            5 => Ok(Operator::LessThan),
265            6 => Ok(Operator::LessThanOrEqual),
266            _ => bail!("Unknown logical operator"),
267        }
268    }
269}
270
271// Handle division properly using U256 type
272fn divide(a: U256, b: U256) -> Result<U256> {
273    if b.is_zero() {
274        bail!("Division by zero error");
275    }
276
277    let quotient = a / b;
278    let remainder = a % b;
279    let divisor_half = b / U256::from(2);
280
281    if remainder > divisor_half || (remainder == divisor_half && b % U256::from(2) == U256::from(0))
282    {
283        Ok(quotient + U256::from(1))
284    } else {
285        Ok(quotient)
286    }
287}
288
/// Parses `value` as an `f64` and rounds it to the nearest integer (halves
/// away from zero), then converts it with a saturating `as u128` cast:
/// negative results become 0 and results above `u128::MAX` become `u128::MAX`.
///
/// # Panics
///
/// Panics if `value` is not a valid `f64` literal.
fn roundup(value: String) -> u128 {
    value.parse::<f64>().map(f64::round).unwrap() as u128
}
293
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use super::*;

    #[test]
    fn test_avg() {
        let values = vec![U256::from(1), U256::from(2), U256::from(3)];
        assert_eq!(average(&values).unwrap(), U256::from(2));

        // 1.5 rounds half up to 2.
        let values = vec![U256::from(1), U256::from(2)];
        assert_eq!(average(&values).unwrap(), U256::from(2));

        let values = vec![U256::from_str("1000000000000").unwrap()];
        assert_eq!(
            average(&values).unwrap(),
            U256::from_str("1000000000000").unwrap()
        );

        let values = vec![U256::from_str("41697298409483537348").unwrap()];
        assert_eq!(
            average(&values).unwrap(),
            U256::from_str("41697298409483537348").unwrap()
        );
    }

    #[test]
    fn test_sum() {
        let values = vec![U256::from(1), U256::from(2), U256::from(3)];
        assert_eq!(sum(&values).unwrap(), U256::from(6));

        let values = vec![U256::from(1), U256::from(2)];
        assert_eq!(sum(&values).unwrap(), U256::from(3));

        let values = vec![U256::from_str("6776").unwrap()];
        assert_eq!(sum(&values).unwrap(), U256::from(6776));

        let values = vec![U256::from_str("41697298409483537348").unwrap()];
        assert_eq!(
            sum(&values).unwrap(),
            U256::from_str("41697298409483537348").unwrap()
        );
    }

    #[test]
    fn test_avg_multi() {
        let values = vec![
            U256::from_str("41697095938570171564").unwrap(),
            U256::from_str("41697095938570171564").unwrap(),
            U256::from_str("41697095938570171564").unwrap(),
            U256::from_str("41697095938570171564").unwrap(),
            U256::from_str("41697095938570171564").unwrap(),
            U256::from_str("41697095938570171564").unwrap(),
            U256::from_str("41697095938570171564").unwrap(),
            U256::from_str("41697095938570171564").unwrap(),
            U256::from_str("41697298409483537348").unwrap(),
            U256::from_str("41697298409483537348").unwrap(),
            U256::from_str("41697298409483537348").unwrap(),
        ];
        assert_eq!(
            average(&values).unwrap(),
            U256::from_str("41697151157910180414").unwrap()
        );
    }

    #[test]
    fn test_avg_empty() {
        let values: Vec<U256> = vec![];
        assert!(average(&values).is_err());
    }

    #[test]
    fn test_empty_inputs_are_rejected() {
        let empty: &[U256] = &[];
        assert!(sum(empty).is_err());
        assert!(find_max(empty).is_err());
        assert!(find_min(empty).is_err());
        assert!(standard_deviation(empty).is_err());
        assert!(simple_linear_regression(empty).is_err());
        assert!(simple_linear_regression(&[U256::from(1)]).is_err());
    }

    #[test]
    fn test_find_max() {
        let values = vec![U256::from(1), U256::from(2), U256::from(3)];
        assert_eq!(find_max(&values).unwrap(), U256::from(3));

        let values = vec![U256::from(1), U256::from(2)];
        assert_eq!(find_max(&values).unwrap(), U256::from(2));

        let values = vec![U256::from(0)];
        assert_eq!(find_max(&values).unwrap(), U256::from(0));

        let values = vec![U256::MAX, U256::from(0)];
        assert_eq!(find_max(&values).unwrap(), U256::MAX);
    }

    #[test]
    fn test_find_min() {
        let values = vec![U256::from(1), U256::from(2), U256::from(3)];
        assert_eq!(find_min(&values).unwrap(), U256::from(1));

        let values = vec![U256::from(1), U256::from(2)];
        assert_eq!(find_min(&values).unwrap(), U256::from(1));

        let values = vec![U256::MAX];
        assert_eq!(find_min(&values).unwrap(), U256::MAX);

        let values = vec![U256::MAX, U256::from(0)];
        assert_eq!(find_min(&values).unwrap(), U256::from(0));
    }

    #[test]
    fn test_std() {
        let values = vec![U256::from(1), U256::from(2), U256::from(3)];
        assert_eq!(standard_deviation(&values).unwrap(), U256::from(1));

        let values = vec![
            U256::from(0),
            U256::from(2),
            U256::from(10),
            U256::from(2),
            U256::from(100),
        ];
        assert_eq!(standard_deviation(&values).unwrap(), U256::from(39));
    }

    #[test]
    fn test_count() {
        let values = vec![U256::from(1), U256::from(165), U256::from(3)];
        assert_eq!(
            count(
                &values,
                &FunctionContext::new(Operator::GreaterThanOrEqual, U256::from(2))
            )
            .unwrap(),
            U256::from(2)
        );

        let values = vec![U256::from(1), U256::from(10)];
        assert_eq!(
            count(
                &values,
                &FunctionContext::new(Operator::GreaterThan, U256::from(1))
            )
            .unwrap(),
            U256::from(1)
        );
    }

    #[test]
    fn test_count_each_operator() {
        let values = vec![U256::from(1), U256::from(2), U256::from(3)];
        let threshold = U256::from(2);
        let cases = [
            (Operator::Equal, 1u64),
            (Operator::NotEqual, 2),
            (Operator::GreaterThan, 1),
            (Operator::GreaterThanOrEqual, 2),
            (Operator::LessThan, 1),
            (Operator::LessThanOrEqual, 2),
        ];
        for (operator, expected) in cases {
            let ctx = FunctionContext::new(operator, threshold);
            assert_eq!(count(&values, &ctx).unwrap(), U256::from(expected));
        }

        let ctx = FunctionContext::new(Operator::None, threshold);
        assert!(count(&values, &ctx).is_err());
    }

    #[test]
    fn test_operator_index_round_trip() {
        let all = [
            Operator::None,
            Operator::Equal,
            Operator::NotEqual,
            Operator::GreaterThan,
            Operator::GreaterThanOrEqual,
            Operator::LessThan,
            Operator::LessThanOrEqual,
        ];
        for (expected_index, operator) in all.iter().enumerate() {
            let index = Operator::to_index(operator);
            assert_eq!(usize::from(index), expected_index);
            assert_eq!(&Operator::from_index(index).unwrap(), operator);
        }
        assert!(Operator::from_index(7).is_err());
    }

    #[test]
    fn test_operator_symbols_and_names() {
        assert_eq!(
            Operator::from_symbol("<=").unwrap(),
            Operator::LessThanOrEqual
        );
        assert_eq!(Operator::from_symbol("!=").unwrap(), Operator::NotEqual);
        assert!(Operator::from_symbol("=>").is_err());

        assert_eq!(Operator::GreaterThanOrEqual.to_string(), "gteq");
        assert_eq!(
            Operator::from_str("gteq").unwrap(),
            Operator::GreaterThanOrEqual
        );
        assert_eq!(
            Operator::try_from("nq".to_string()).unwrap(),
            Operator::NotEqual
        );
        assert!(Operator::from_str("gte").is_err());
    }

    #[test]
    fn test_divide_rounding() {
        assert_eq!(
            divide(U256::from(5), U256::from(2)).unwrap(),
            U256::from(3)
        );
        assert_eq!(
            divide(U256::from(7), U256::from(3)).unwrap(),
            U256::from(2)
        );
        assert_eq!(
            divide(U256::from(8), U256::from(3)).unwrap(),
            U256::from(3)
        );
        assert_eq!(
            divide(U256::from(6), U256::from(3)).unwrap(),
            U256::from(2)
        );
        assert!(divide(U256::from(1), U256::from(0)).is_err());
    }
}
418        );
419    }
420}