Skip to main content

ganit_core/eval/functions/database/
mod.rs

1use super::super::{FunctionMeta, Registry};
2use crate::eval::functions::math::criterion::{matches_criterion, parse_criterion};
3use crate::types::{ErrorKind, Value};
4
5// ── Helpers ───────────────────────────────────────────────────────────────────
6
7/// Extract rows from a 2D Value::Array.
8/// The outer array contains rows (each row is a Value::Array).
9/// Returns None if the structure is not a valid 2D array.
10fn extract_rows(v: &Value) -> Option<Vec<&[Value]>> {
11    match v {
12        Value::Array(outer) => {
13            let rows: Option<Vec<&[Value]>> = outer
14                .iter()
15                .map(|row| match row {
16                    Value::Array(r) => Some(r.as_slice()),
17                    _ => None,
18                })
19                .collect();
20            rows
21        }
22        _ => None,
23    }
24}
25
26/// Resolve field argument to a 0-based column index.
27/// field can be a 1-based number or a column name string matching headers.
28fn resolve_field(field: &Value, headers: &[Value]) -> Option<usize> {
29    match field {
30        Value::Number(n) => {
31            let idx = *n as usize;
32            if idx >= 1 && idx <= headers.len() {
33                Some(idx - 1)
34            } else {
35                None
36            }
37        }
38        Value::Text(name) => {
39            let name_lower = name.to_lowercase();
40            headers.iter().position(|h| match h {
41                Value::Text(s) => s.to_lowercase() == name_lower,
42                _ => false,
43            })
44        }
45        _ => None,
46    }
47}
48
49/// Check whether a data row matches all criteria.
50/// criteria_rows: [header_row, criteria_row]
51/// data_headers: the database headers
52/// data_row: the data row to check
53fn row_matches_criteria(
54    criteria_rows: &[&[Value]],
55    data_headers: &[Value],
56    data_row: &[Value],
57) -> bool {
58    if criteria_rows.len() < 2 {
59        return false;
60    }
61    let crit_headers = criteria_rows[0];
62    let crit_values = criteria_rows[1];
63
64    for (crit_col, crit_val) in crit_headers.iter().zip(crit_values.iter()) {
65        // Skip empty criteria
66        match crit_val {
67            Value::Empty => continue,
68            Value::Text(s) if s.is_empty() => continue,
69            _ => {}
70        }
71
72        // Find the matching data column
73        let col_name = match crit_col {
74            Value::Text(s) => s.to_lowercase(),
75            _ => continue,
76        };
77        let col_idx = data_headers.iter().position(|h| match h {
78            Value::Text(s) => s.to_lowercase() == col_name,
79            _ => false,
80        });
81        let col_idx = match col_idx {
82            Some(i) => i,
83            None => return false, // criteria column not found in database
84        };
85
86        let cell_val = data_row.get(col_idx).unwrap_or(&Value::Empty);
87        let criterion = parse_criterion(crit_val);
88        if !matches_criterion(cell_val, &criterion) {
89            return false;
90        }
91    }
92    true
93}
94
95/// Parse the database and criteria arguments, collect the field values for
96/// matching rows.
97///
98/// Returns `Err(Value)` on structural errors (wrong types, field not found).
99fn collect_matching_values(args: &[Value]) -> Result<Vec<Value>, Value> {
100    if args.len() != 3 {
101        return Err(Value::Error(ErrorKind::NA));
102    }
103    let db_rows = extract_rows(&args[0])
104        .ok_or(Value::Error(ErrorKind::Value))?;
105
106    if db_rows.len() < 2 {
107        // Need at least header row + one data row
108        return Ok(vec![]);
109    }
110
111    let headers = db_rows[0];
112    let field_idx = resolve_field(&args[1], headers)
113        .ok_or(Value::Error(ErrorKind::Value))?;
114
115    let crit_rows = extract_rows(&args[2])
116        .ok_or(Value::Error(ErrorKind::Value))?;
117
118    if crit_rows.len() < 2 {
119        return Err(Value::Error(ErrorKind::Value));
120    }
121
122    let mut values = Vec::new();
123    for data_row in &db_rows[1..] {
124        if row_matches_criteria(&crit_rows, headers, data_row) {
125            let val = data_row.get(field_idx).cloned().unwrap_or(Value::Empty);
126            values.push(val);
127        }
128    }
129    Ok(values)
130}
131
132// ── D* functions ──────────────────────────────────────────────────────────────
133
134/// `DSUM(database, field, criteria)` — sum of field values for matching rows.
135pub fn dsum_fn(args: &[Value]) -> Value {
136    match collect_matching_values(args) {
137        Err(e) => e,
138        Ok(values) => {
139            let mut sum = 0.0_f64;
140            for v in &values {
141                if let Value::Number(n) = v {
142                    sum += n;
143                }
144            }
145            Value::Number(sum)
146        }
147    }
148}
149
150/// `DAVERAGE(database, field, criteria)` — average of field values for matching rows.
151pub fn daverage_fn(args: &[Value]) -> Value {
152    match collect_matching_values(args) {
153        Err(e) => e,
154        Ok(values) => {
155            let nums: Vec<f64> = values
156                .iter()
157                .filter_map(|v| if let Value::Number(n) = v { Some(*n) } else { None })
158                .collect();
159            if nums.is_empty() {
160                return Value::Error(ErrorKind::DivByZero);
161            }
162            Value::Number(nums.iter().sum::<f64>() / nums.len() as f64)
163        }
164    }
165}
166
167/// `DCOUNT(database, field, criteria)` — count of numeric field values for matching rows.
168pub fn dcount_fn(args: &[Value]) -> Value {
169    match collect_matching_values(args) {
170        Err(e) => e,
171        Ok(values) => {
172            let count = values
173                .iter()
174                .filter(|v| matches!(v, Value::Number(_)))
175                .count();
176            Value::Number(count as f64)
177        }
178    }
179}
180
181/// `DCOUNTA(database, field, criteria)` — count of non-empty field values for matching rows.
182pub fn dcounta_fn(args: &[Value]) -> Value {
183    match collect_matching_values(args) {
184        Err(e) => e,
185        Ok(values) => {
186            let count = values
187                .iter()
188                .filter(|v| !matches!(v, Value::Empty))
189                .count();
190            Value::Number(count as f64)
191        }
192    }
193}
194
195/// `DGET(database, field, criteria)` — returns the single matching value, or error.
196pub fn dget_fn(args: &[Value]) -> Value {
197    match collect_matching_values(args) {
198        Err(e) => e,
199        Ok(values) => {
200            if values.len() == 1 {
201                values.into_iter().next().unwrap()
202            } else if values.is_empty() {
203                Value::Error(ErrorKind::Value)
204            } else {
205                // Multiple matches
206                Value::Error(ErrorKind::Num)
207            }
208        }
209    }
210}
211
212/// `DMAX(database, field, criteria)` — max of field values for matching rows.
213pub fn dmax_fn(args: &[Value]) -> Value {
214    match collect_matching_values(args) {
215        Err(e) => e,
216        Ok(values) => {
217            let nums: Vec<f64> = values
218                .iter()
219                .filter_map(|v| if let Value::Number(n) = v { Some(*n) } else { None })
220                .collect();
221            if nums.is_empty() {
222                return Value::Number(0.0);
223            }
224            Value::Number(nums.iter().cloned().fold(f64::NEG_INFINITY, f64::max))
225        }
226    }
227}
228
229/// `DMIN(database, field, criteria)` — min of field values for matching rows.
230pub fn dmin_fn(args: &[Value]) -> Value {
231    match collect_matching_values(args) {
232        Err(e) => e,
233        Ok(values) => {
234            let nums: Vec<f64> = values
235                .iter()
236                .filter_map(|v| if let Value::Number(n) = v { Some(*n) } else { None })
237                .collect();
238            if nums.is_empty() {
239                return Value::Number(0.0);
240            }
241            Value::Number(nums.iter().cloned().fold(f64::INFINITY, f64::min))
242        }
243    }
244}
245
246/// `DPRODUCT(database, field, criteria)` — product of field values for matching rows.
247pub fn dproduct_fn(args: &[Value]) -> Value {
248    match collect_matching_values(args) {
249        Err(e) => e,
250        Ok(values) => {
251            let nums: Vec<f64> = values
252                .iter()
253                .filter_map(|v| if let Value::Number(n) = v { Some(*n) } else { None })
254                .collect();
255            if nums.is_empty() {
256                return Value::Number(0.0);
257            }
258            Value::Number(nums.iter().product())
259        }
260    }
261}
262
263/// `DSTDEV(database, field, criteria)` — sample standard deviation of matching numeric values.
264pub fn dstdev_fn(args: &[Value]) -> Value {
265    match collect_matching_values(args) {
266        Err(e) => e,
267        Ok(values) => {
268            let nums: Vec<f64> = values
269                .iter()
270                .filter_map(|v| if let Value::Number(n) = v { Some(*n) } else { None })
271                .collect();
272            let n = nums.len();
273            if n < 2 {
274                return Value::Error(ErrorKind::DivByZero);
275            }
276            let mean = nums.iter().sum::<f64>() / n as f64;
277            let var = nums.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (n - 1) as f64;
278            Value::Number(var.sqrt())
279        }
280    }
281}
282
283/// `DSTDEVP(database, field, criteria)` — population standard deviation of matching numeric values.
284pub fn dstdevp_fn(args: &[Value]) -> Value {
285    match collect_matching_values(args) {
286        Err(e) => e,
287        Ok(values) => {
288            let nums: Vec<f64> = values
289                .iter()
290                .filter_map(|v| if let Value::Number(n) = v { Some(*n) } else { None })
291                .collect();
292            let n = nums.len();
293            if n == 0 {
294                return Value::Error(ErrorKind::DivByZero);
295            }
296            let mean = nums.iter().sum::<f64>() / n as f64;
297            let var = nums.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n as f64;
298            Value::Number(var.sqrt())
299        }
300    }
301}
302
303/// `DVAR(database, field, criteria)` — sample variance of matching numeric values.
304pub fn dvar_fn(args: &[Value]) -> Value {
305    match collect_matching_values(args) {
306        Err(e) => e,
307        Ok(values) => {
308            let nums: Vec<f64> = values
309                .iter()
310                .filter_map(|v| if let Value::Number(n) = v { Some(*n) } else { None })
311                .collect();
312            let n = nums.len();
313            if n < 2 {
314                return Value::Error(ErrorKind::DivByZero);
315            }
316            let mean = nums.iter().sum::<f64>() / n as f64;
317            let var = nums.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (n - 1) as f64;
318            Value::Number(var)
319        }
320    }
321}
322
323/// `DVARP(database, field, criteria)` — population variance of matching numeric values.
324pub fn dvarp_fn(args: &[Value]) -> Value {
325    match collect_matching_values(args) {
326        Err(e) => e,
327        Ok(values) => {
328            let nums: Vec<f64> = values
329                .iter()
330                .filter_map(|v| if let Value::Number(n) = v { Some(*n) } else { None })
331                .collect();
332            let n = nums.len();
333            if n == 0 {
334                return Value::Error(ErrorKind::DivByZero);
335            }
336            let mean = nums.iter().sum::<f64>() / n as f64;
337            let var = nums.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n as f64;
338            Value::Number(var)
339        }
340    }
341}
342
343// ── Registration ──────────────────────────────────────────────────────────────
344
345pub fn register_database(registry: &mut Registry) {
346    registry.register_eager("DSUM",     dsum_fn,     FunctionMeta { category: "database", signature: "DSUM(database, field, criteria)",     description: "Sum of field values for rows matching criteria" });
347    registry.register_eager("DAVERAGE", daverage_fn, FunctionMeta { category: "database", signature: "DAVERAGE(database, field, criteria)", description: "Average of field values for rows matching criteria" });
348    registry.register_eager("DCOUNT",   dcount_fn,   FunctionMeta { category: "database", signature: "DCOUNT(database, field, criteria)",   description: "Count of numeric field values for rows matching criteria" });
349    registry.register_eager("DCOUNTA",  dcounta_fn,  FunctionMeta { category: "database", signature: "DCOUNTA(database, field, criteria)",  description: "Count of non-empty field values for rows matching criteria" });
350    registry.register_eager("DGET",     dget_fn,     FunctionMeta { category: "database", signature: "DGET(database, field, criteria)",     description: "Single field value for rows matching criteria" });
351    registry.register_eager("DMAX",     dmax_fn,     FunctionMeta { category: "database", signature: "DMAX(database, field, criteria)",     description: "Maximum field value for rows matching criteria" });
352    registry.register_eager("DMIN",     dmin_fn,     FunctionMeta { category: "database", signature: "DMIN(database, field, criteria)",     description: "Minimum field value for rows matching criteria" });
353    registry.register_eager("DPRODUCT", dproduct_fn, FunctionMeta { category: "database", signature: "DPRODUCT(database, field, criteria)", description: "Product of field values for rows matching criteria" });
354    registry.register_eager("DSTDEV",   dstdev_fn,   FunctionMeta { category: "database", signature: "DSTDEV(database, field, criteria)",   description: "Sample standard deviation of field values for rows matching criteria" });
355    registry.register_eager("DSTDEVP",  dstdevp_fn,  FunctionMeta { category: "database", signature: "DSTDEVP(database, field, criteria)",  description: "Population standard deviation of field values for rows matching criteria" });
356    registry.register_eager("DVAR",     dvar_fn,     FunctionMeta { category: "database", signature: "DVAR(database, field, criteria)",     description: "Sample variance of field values for rows matching criteria" });
357    registry.register_eager("DVARP",    dvarp_fn,    FunctionMeta { category: "database", signature: "DVARP(database, field, criteria)",    description: "Population variance of field values for rows matching criteria" });
358}
359
360#[cfg(test)]
361mod tests;