Skip to main content

fret_ui_headless/table/
aggregation_fns.rs

1use std::sync::Arc;
2
3use super::TanStackValue;
4
/// One of TanStack Table's built-in aggregation functions.
///
/// Each variant corresponds to one key of TanStack's built-in `aggregationFns` registry; the
/// doc comment on each variant gives that key exactly as [`BuiltInAggregationFn::from_tanstack_key`]
/// expects it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BuiltInAggregationFn {
    /// Key `"sum"`.
    Sum,
    /// Key `"min"`.
    Min,
    /// Key `"max"`.
    Max,
    /// Key `"extent"`.
    Extent,
    /// Key `"mean"`.
    Mean,
    /// Key `"median"`.
    Median,
    /// Key `"unique"`.
    Unique,
    /// Key `"uniqueCount"`.
    UniqueCount,
    /// Key `"count"`.
    Count,
}

impl BuiltInAggregationFn {
    /// Every built-in paired with its TanStack registry key.
    const TANSTACK_KEYS: [(&'static str, Self); 9] = [
        ("sum", Self::Sum),
        ("min", Self::Min),
        ("max", Self::Max),
        ("extent", Self::Extent),
        ("mean", Self::Mean),
        ("median", Self::Median),
        ("unique", Self::Unique),
        ("uniqueCount", Self::UniqueCount),
        ("count", Self::Count),
    ];

    /// Looks up a built-in aggregation by its TanStack key (for example `"uniqueCount"`).
    ///
    /// Matching is exact and case-sensitive; any unrecognised key yields `None`.
    pub fn from_tanstack_key(key: &str) -> Option<Self> {
        Self::TANSTACK_KEYS
            .iter()
            .find(|(name, _)| *name == key)
            .map(|&(_, agg)| agg)
    }
}
34
35#[derive(Debug, Clone, PartialEq, Eq, Default)]
36pub enum AggregationFnSpec {
37    /// TanStack `aggregationFn: 'auto'` (default).
38    #[default]
39    Auto,
40    BuiltIn(BuiltInAggregationFn),
41    /// TanStack `aggregationFn: <string>` resolved via `options.aggregationFns[key] ?? builtIn[key]`.
42    Named(Arc<str>),
43    /// No aggregation for this column.
44    None,
45}
46
47pub type AggregationFn = Arc<dyn Fn(&str, &[TanStackValue]) -> TanStackValue + Send + Sync>;
48
49fn value_key(v: &TanStackValue) -> String {
50    match v {
51        TanStackValue::Undefined => "u".into(),
52        TanStackValue::Null => "n".into(),
53        TanStackValue::Bool(b) => format!("b:{b}"),
54        TanStackValue::Number(n) => {
55            if n.is_nan() {
56                "num:nan".into()
57            } else if *n == 0.0 {
58                // SameValueZero: -0 and 0 are treated as equal.
59                "num:0".into()
60            } else {
61                format!("num:{:016x}", n.to_bits())
62            }
63        }
64        TanStackValue::String(s) => format!("s:{s}"),
65        TanStackValue::Array(arr) => {
66            let inner = arr.iter().map(value_key).collect::<Vec<_>>().join(",");
67            format!("a:[{inner}]")
68        }
69        TanStackValue::DateTime(n) => {
70            if n.is_nan() {
71                "dt:nan".into()
72            } else if *n == 0.0 {
73                "dt:0".into()
74            } else {
75                format!("dt:{:016x}", n.to_bits())
76            }
77        }
78    }
79}
80
81fn as_f64_strict(v: &TanStackValue) -> Option<f64> {
82    match v {
83        TanStackValue::Number(n) => Some(*n),
84        TanStackValue::DateTime(n) => Some(*n),
85        _ => None,
86    }
87}
88
89fn to_number_like(v: &TanStackValue) -> Option<f64> {
90    match v {
91        TanStackValue::Number(n) => Some(*n),
92        TanStackValue::DateTime(n) => Some(*n),
93        TanStackValue::Bool(b) => Some(if *b { 1.0 } else { 0.0 }),
94        TanStackValue::Null => Some(0.0),
95        TanStackValue::String(s) => s.parse::<f64>().ok(),
96        TanStackValue::Undefined | TanStackValue::Array(_) => None,
97    }
98}
99
100pub fn apply_builtin_aggregation(
101    agg: BuiltInAggregationFn,
102    values: &[TanStackValue],
103) -> TanStackValue {
104    match agg {
105        BuiltInAggregationFn::Count => TanStackValue::Number(values.len() as f64),
106        BuiltInAggregationFn::Sum => {
107            let mut sum = 0.0;
108            for v in values {
109                let Some(n) = as_f64_strict(v) else {
110                    continue;
111                };
112                if n.is_nan() {
113                    continue;
114                }
115                sum += n;
116            }
117            TanStackValue::Number(sum)
118        }
119        BuiltInAggregationFn::Min => {
120            let mut min: Option<f64> = None;
121            for v in values {
122                let Some(n) = as_f64_strict(v) else {
123                    continue;
124                };
125                if n.is_nan() {
126                    continue;
127                }
128                min = Some(match min {
129                    Some(acc) => acc.min(n),
130                    None => n,
131                });
132            }
133            min.map(TanStackValue::Number)
134                .unwrap_or(TanStackValue::Undefined)
135        }
136        BuiltInAggregationFn::Max => {
137            let mut max: Option<f64> = None;
138            for v in values {
139                let Some(n) = as_f64_strict(v) else {
140                    continue;
141                };
142                if n.is_nan() {
143                    continue;
144                }
145                max = Some(match max {
146                    Some(acc) => acc.max(n),
147                    None => n,
148                });
149            }
150            max.map(TanStackValue::Number)
151                .unwrap_or(TanStackValue::Undefined)
152        }
153        BuiltInAggregationFn::Extent => {
154            let mut min: Option<f64> = None;
155            let mut max: Option<f64> = None;
156            for v in values {
157                let Some(n) = as_f64_strict(v) else {
158                    continue;
159                };
160                if n.is_nan() {
161                    continue;
162                }
163                if min.is_none() {
164                    min = Some(n);
165                    max = Some(n);
166                } else {
167                    min = Some(min.unwrap().min(n));
168                    max = Some(max.unwrap().max(n));
169                }
170            }
171            TanStackValue::Array(vec![
172                min.map(TanStackValue::Number)
173                    .unwrap_or(TanStackValue::Undefined),
174                max.map(TanStackValue::Number)
175                    .unwrap_or(TanStackValue::Undefined),
176            ])
177        }
178        BuiltInAggregationFn::Mean => {
179            let mut sum = 0.0;
180            let mut count = 0usize;
181            for v in values {
182                let Some(n) = to_number_like(v) else {
183                    continue;
184                };
185                if n.is_nan() {
186                    continue;
187                }
188                count += 1;
189                sum += n;
190            }
191            if count == 0 {
192                return TanStackValue::Undefined;
193            }
194            TanStackValue::Number(sum / (count as f64))
195        }
196        BuiltInAggregationFn::Median => {
197            if values.is_empty() {
198                return TanStackValue::Undefined;
199            }
200            let mut nums: Vec<f64> = Vec::with_capacity(values.len());
201            for v in values {
202                let Some(n) = as_f64_strict(v) else {
203                    return TanStackValue::Undefined;
204                };
205                if n.is_nan() {
206                    return TanStackValue::Undefined;
207                }
208                nums.push(n);
209            }
210
211            if nums.len() == 1 {
212                return TanStackValue::Number(nums[0]);
213            }
214
215            nums.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
216
217            let mid = nums.len() / 2;
218            if nums.len() % 2 == 1 {
219                TanStackValue::Number(nums[mid])
220            } else {
221                TanStackValue::Number((nums[mid - 1] + nums[mid]) / 2.0)
222            }
223        }
224        BuiltInAggregationFn::Unique => {
225            let mut seen: std::collections::HashSet<String> = Default::default();
226            let mut out: Vec<TanStackValue> = Vec::new();
227            for v in values {
228                let key = value_key(v);
229                if !seen.insert(key) {
230                    continue;
231                }
232                out.push(v.clone());
233            }
234            TanStackValue::Array(out)
235        }
236        BuiltInAggregationFn::UniqueCount => {
237            let mut seen: std::collections::HashSet<String> = Default::default();
238            for v in values {
239                seen.insert(value_key(v));
240            }
241            TanStackValue::Number(seen.len() as f64)
242        }
243    }
244}
245
246pub fn resolve_auto_aggregation(values: &[TanStackValue]) -> Option<BuiltInAggregationFn> {
247    for v in values {
248        match v {
249            TanStackValue::Undefined | TanStackValue::Null => continue,
250            TanStackValue::Number(_) => return Some(BuiltInAggregationFn::Sum),
251            TanStackValue::DateTime(_) => return Some(BuiltInAggregationFn::Extent),
252            _ => return None,
253        }
254    }
255    None
256}