csv_managed/
expr.rs

1use anyhow::{Context, Result};
2use chrono::{Duration, NaiveDate, NaiveDateTime, NaiveTime};
3use evalexpr::{
4    ContextWithMutableFunctions, ContextWithMutableVariables, Function, HashMapContext,
5    Value as EvalValue, eval_with_context,
6};
7
8use crate::data::{
9    Value, normalize_column_name, parse_naive_date, parse_naive_datetime, parse_naive_time,
10    value_to_evalexpr,
11};
12
13fn register_temporal_functions(context: &mut HashMapContext) -> Result<()> {
14    context
15        .set_function(
16            "date_add".into(),
17            Function::new(|arguments| {
18                let args = expect_args(arguments, 2, "date_add")?;
19                let date = parse_date_arg(&args[0])?;
20                let days = parse_i64_arg(&args[1], "days")?;
21                let result = date
22                    .checked_add_signed(Duration::days(days))
23                    .ok_or_else(|| eval_error("date_add overflow"))?;
24                Ok(EvalValue::String(result.format("%Y-%m-%d").to_string()))
25            }),
26        )
27        .map_err(anyhow::Error::from)?;
28
29    context
30        .set_function(
31            "date_sub".into(),
32            Function::new(|arguments| {
33                let args = expect_args(arguments, 2, "date_sub")?;
34                let date = parse_date_arg(&args[0])?;
35                let days = parse_i64_arg(&args[1], "days")?;
36                let result = date
37                    .checked_sub_signed(Duration::days(days))
38                    .ok_or_else(|| eval_error("date_sub overflow"))?;
39                Ok(EvalValue::String(result.format("%Y-%m-%d").to_string()))
40            }),
41        )
42        .map_err(anyhow::Error::from)?;
43
44    context
45        .set_function(
46            "date_diff_days".into(),
47            Function::new(|arguments| {
48                let args = expect_args(arguments, 2, "date_diff_days")?;
49                let end = parse_date_arg(&args[0])?;
50                let start = parse_date_arg(&args[1])?;
51                let diff = (end - start).num_days();
52                Ok(EvalValue::Int(diff))
53            }),
54        )
55        .map_err(anyhow::Error::from)?;
56
57    context
58        .set_function(
59            "datetime_add_seconds".into(),
60            Function::new(|arguments| {
61                let args = expect_args(arguments, 2, "datetime_add_seconds")?;
62                let dt = parse_datetime_arg(&args[0])?;
63                let seconds = parse_i64_arg(&args[1], "seconds")?;
64                let result = dt
65                    .checked_add_signed(Duration::seconds(seconds))
66                    .ok_or_else(|| eval_error("datetime_add_seconds overflow"))?;
67                Ok(EvalValue::String(
68                    result.format("%Y-%m-%d %H:%M:%S").to_string(),
69                ))
70            }),
71        )
72        .map_err(anyhow::Error::from)?;
73
74    context
75        .set_function(
76            "datetime_diff_seconds".into(),
77            Function::new(|arguments| {
78                let args = expect_args(arguments, 2, "datetime_diff_seconds")?;
79                let end = parse_datetime_arg(&args[0])?;
80                let start = parse_datetime_arg(&args[1])?;
81                Ok(EvalValue::Int((end - start).num_seconds()))
82            }),
83        )
84        .map_err(anyhow::Error::from)?;
85
86    context
87        .set_function(
88            "datetime_to_date".into(),
89            Function::new(|arguments| {
90                let args = expect_args(arguments, 1, "datetime_to_date")?;
91                let dt = parse_datetime_arg(&args[0])?;
92                Ok(EvalValue::String(dt.date().format("%Y-%m-%d").to_string()))
93            }),
94        )
95        .map_err(anyhow::Error::from)?;
96
97    context
98        .set_function(
99            "datetime_to_time".into(),
100            Function::new(|arguments| {
101                let args = expect_args(arguments, 1, "datetime_to_time")?;
102                let dt = parse_datetime_arg(&args[0])?;
103                Ok(EvalValue::String(dt.time().format("%H:%M:%S").to_string()))
104            }),
105        )
106        .map_err(anyhow::Error::from)?;
107
108    context
109        .set_function(
110            "time_add_seconds".into(),
111            Function::new(|arguments| {
112                let args = expect_args(arguments, 2, "time_add_seconds")?;
113                let time = parse_time_arg(&args[0])?;
114                let seconds = parse_i64_arg(&args[1], "seconds")?;
115                let (result, overflow_days) =
116                    time.overflowing_add_signed(Duration::seconds(seconds));
117                if overflow_days != 0 {
118                    return Err(eval_error("time_add_seconds overflow"));
119                }
120                Ok(EvalValue::String(result.format("%H:%M:%S").to_string()))
121            }),
122        )
123        .map_err(anyhow::Error::from)?;
124
125    context
126        .set_function(
127            "time_diff_seconds".into(),
128            Function::new(|arguments| {
129                let args = expect_args(arguments, 2, "time_diff_seconds")?;
130                let end = parse_time_arg(&args[0])?;
131                let start = parse_time_arg(&args[1])?;
132                Ok(EvalValue::Int((end - start).num_seconds()))
133            }),
134        )
135        .map_err(anyhow::Error::from)?;
136
137    context
138        .set_function(
139            "date_format".into(),
140            Function::new(|arguments| {
141                let args = expect_args(arguments, 2, "date_format")?;
142                let date = parse_date_arg(&args[0])?;
143                let fmt = expect_string(&args[1], "format")?;
144                Ok(EvalValue::String(date.format(fmt).to_string()))
145            }),
146        )
147        .map_err(anyhow::Error::from)?;
148
149    context
150        .set_function(
151            "datetime_format".into(),
152            Function::new(|arguments| {
153                let args = expect_args(arguments, 2, "datetime_format")?;
154                let dt = parse_datetime_arg(&args[0])?;
155                let fmt = expect_string(&args[1], "format")?;
156                Ok(EvalValue::String(dt.format(fmt).to_string()))
157            }),
158        )
159        .map_err(anyhow::Error::from)?;
160
161    Ok(())
162}
163
164fn expect_args(
165    arguments: &EvalValue,
166    expected: usize,
167    name: &str,
168) -> Result<Vec<EvalValue>, evalexpr::EvalexprError> {
169    match arguments {
170        EvalValue::Empty if expected == 0 => Ok(Vec::new()),
171        value if expected == 1 && !matches!(value, EvalValue::Tuple(_)) => Ok(vec![value.clone()]),
172        EvalValue::Tuple(values) => {
173            if values.len() != expected {
174                return Err(evalexpr::EvalexprError::wrong_function_argument_amount(
175                    values.len(),
176                    expected,
177                ));
178            }
179            Ok(values.clone())
180        }
181        _ => Err(eval_error(&format!(
182            "{name} expects {expected} arguments provided as a tuple"
183        ))),
184    }
185}
186
187fn eval_error(message: &str) -> evalexpr::EvalexprError {
188    evalexpr::EvalexprError::CustomMessage(message.to_string())
189}
190
191fn parse_date_arg(value: &EvalValue) -> Result<NaiveDate, evalexpr::EvalexprError> {
192    let raw = expect_string(value, "date")?;
193    parse_naive_date(raw).map_err(|err| eval_error(&err.to_string()))
194}
195
196fn parse_datetime_arg(value: &EvalValue) -> Result<NaiveDateTime, evalexpr::EvalexprError> {
197    let raw = expect_string(value, "datetime")?;
198    parse_naive_datetime(raw).map_err(|err| eval_error(&err.to_string()))
199}
200
201fn parse_time_arg(value: &EvalValue) -> Result<NaiveTime, evalexpr::EvalexprError> {
202    let raw = expect_string(value, "time")?;
203    parse_naive_time(raw).map_err(|err| eval_error(&err.to_string()))
204}
205
206fn parse_i64_arg(value: &EvalValue, name: &str) -> Result<i64, evalexpr::EvalexprError> {
207    match value {
208        EvalValue::Int(i) => Ok(*i),
209        EvalValue::Float(f) => Ok(*f as i64),
210        other => Err(eval_error(&format!(
211            "Expected integer for {name}, got {other:?}",
212        ))),
213    }
214}
215
216fn expect_string<'a>(value: &'a EvalValue, name: &str) -> Result<&'a str, evalexpr::EvalexprError> {
217    if let EvalValue::String(s) = value {
218        Ok(s)
219    } else {
220        Err(eval_error(&format!("Expected string for {name}")))
221    }
222}
223
224pub fn build_context(
225    headers: &[String],
226    raw_row: &[String],
227    typed_row: &[Option<Value>],
228    row_number: Option<usize>,
229) -> Result<HashMapContext> {
230    let mut context = HashMapContext::new();
231    register_temporal_functions(&mut context)?;
232    for (idx, header) in headers.iter().enumerate() {
233        let canon = normalize_column_name(header);
234        let key = format!("c{idx}");
235        if let Some(Some(value)) = typed_row.get(idx) {
236            let eval_value = value_to_evalexpr(value);
237            context
238                .set_value(canon.clone(), eval_value.clone())
239                .with_context(|| format!("Binding column '{header}'"))?;
240            context
241                .set_value(key, eval_value)
242                .with_context(|| format!("Binding column index {idx}"))?;
243        } else if let Some(raw) = raw_row.get(idx) {
244            context
245                .set_value(canon.clone(), EvalValue::String(raw.clone()))
246                .with_context(|| format!("Binding raw column '{header}'"))?;
247            context
248                .set_value(key, EvalValue::String(raw.clone()))
249                .with_context(|| format!("Binding raw column index {idx}"))?;
250        }
251    }
252
253    if let Some(number) = row_number {
254        context
255            .set_value("row_number".to_string(), EvalValue::Int(number as i64))
256            .context("Binding row_number")?;
257    }
258
259    Ok(context)
260}
261
262pub fn evaluate_expression_to_bool(expr: &str, context: &HashMapContext) -> Result<bool> {
263    let result = eval_with_context(expr, context)
264        .with_context(|| format!("Evaluating expression '{expr}'"))?;
265    Ok(eval_value_truthy(result))
266}
267
268pub fn eval_value_truthy(value: EvalValue) -> bool {
269    match value {
270        EvalValue::Boolean(b) => b,
271        EvalValue::Int(i) => i != 0,
272        EvalValue::Float(f) => f != 0.0,
273        EvalValue::String(s) => !s.is_empty(),
274        EvalValue::Tuple(values) => values.into_iter().any(eval_value_truthy),
275        EvalValue::Empty => false,
276    }
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Returns a fresh context with only the temporal functions registered.
    fn temporal_context() -> HashMapContext {
        let mut ctx = HashMapContext::new();
        register_temporal_functions(&mut ctx).unwrap();
        ctx
    }

    /// Evaluates `expr` and returns its string result, panicking on any failure.
    fn eval_string(expr: &str, ctx: &HashMapContext) -> String {
        eval_with_context(expr, ctx)
            .unwrap()
            .as_string()
            .unwrap()
            .to_string()
    }

    /// Evaluates `expr` and returns its integer result, panicking on any failure.
    fn eval_int(expr: &str, ctx: &HashMapContext) -> i64 {
        eval_with_context(expr, ctx).unwrap().as_int().unwrap()
    }

    #[test]
    fn date_add_and_diff_work() {
        let ctx = temporal_context();
        assert_eq!(eval_string("date_add(\"2024-01-01\", 5)", &ctx), "2024-01-06");
        assert_eq!(
            eval_int("date_diff_days(\"2024-01-10\", \"2024-01-01\")", &ctx),
            9
        );
        // Arguments are (end, start), so swapping them flips the sign.
        assert_eq!(
            eval_int("date_diff_days(\"2024-01-01\", \"2024-01-10\")", &ctx),
            -9
        );
    }

    #[test]
    fn date_sub_crosses_leap_day() {
        let ctx = temporal_context();
        assert_eq!(eval_string("date_sub(\"2024-03-01\", 1)", &ctx), "2024-02-29");
    }

    #[test]
    fn datetime_functions_roundtrip() {
        let ctx = temporal_context();
        assert_eq!(
            eval_string("datetime_add_seconds(\"2024-01-01 00:00:00\", 3661)", &ctx),
            "2024-01-01 01:01:01"
        );
        assert_eq!(
            eval_int(
                "datetime_diff_seconds(\"2024-01-01 01:01:01\", \"2024-01-01 00:00:00\")",
                &ctx
            ),
            3661
        );
    }

    #[test]
    fn datetime_splits_into_date_and_time() {
        let ctx = temporal_context();
        assert_eq!(
            eval_string("datetime_to_date(\"2024-03-05 13:14:15\")", &ctx),
            "2024-03-05"
        );
        assert_eq!(
            eval_string("datetime_to_time(\"2024-03-05 13:14:15\")", &ctx),
            "13:14:15"
        );
    }

    #[test]
    fn time_functions_behave() {
        let ctx = temporal_context();
        assert_eq!(eval_string("time_add_seconds(\"08:00:00\", 90)", &ctx), "08:01:30");
        assert_eq!(
            eval_int("time_diff_seconds(\"08:01:30\", \"08:00:00\")", &ctx),
            90
        );
    }

    #[test]
    fn time_add_rejects_wrap_past_midnight() {
        let ctx = temporal_context();
        assert!(eval_with_context("time_add_seconds(\"23:59:59\", 1)", &ctx).is_err());
    }

    #[test]
    fn custom_formats_render() {
        let ctx = temporal_context();
        assert_eq!(
            eval_string("date_format(\"2024-03-05\", \"%d/%m/%Y\")", &ctx),
            "05/03/2024"
        );
        assert_eq!(
            eval_string(
                "datetime_format(\"2024-03-05 13:14:15\", \"%Y%m%dT%H%M%S\")",
                &ctx
            ),
            "20240305T131415"
        );
    }

    #[test]
    fn wrong_argument_count_is_an_error() {
        let ctx = temporal_context();
        assert!(eval_with_context("date_add(\"2024-01-01\")", &ctx).is_err());
        assert!(eval_with_context(
            "date_diff_days(\"2024-01-01\", \"2024-01-02\", \"2024-01-03\")",
            &ctx
        )
        .is_err());
    }

    #[test]
    fn truthiness_depends_on_variant() {
        assert!(!eval_value_truthy(EvalValue::Empty));
        assert!(!eval_value_truthy(EvalValue::Boolean(false)));
        assert!(!eval_value_truthy(EvalValue::Int(0)));
        assert!(eval_value_truthy(EvalValue::Float(0.5)));
        assert!(!eval_value_truthy(EvalValue::String(String::new())));
        assert!(eval_value_truthy(EvalValue::Tuple(vec![
            EvalValue::Boolean(false),
            EvalValue::Int(3),
        ])));
    }

    proptest! {
        // Named and positional (`c0`, `c1`) bindings must evaluate identically.
        // NOTE(review): this relies on `normalize_column_name` turning headers such as
        // "123" or "true" into names evalexpr parses as identifiers rather than
        // literals — confirm against its implementation.
        #[test]
        fn evaluate_expression_handles_random_numeric_context(
            a in -10_000i64..=10_000,
            b in -10_000i64..=10_000,
            header0 in "[A-Za-z0-9_ ]{3,12}",
            header1 in "[A-Za-z0-9_ ]{3,12}"
        ) {
            let headers = vec![header0.clone(), header1.clone()];
            let raw = vec![a.to_string(), b.to_string()];
            let typed = vec![Some(Value::Integer(a)), Some(Value::Integer(b))];
            let context = build_context(&headers, &raw, &typed, None).expect("build context");
            let name0 = normalize_column_name(&header0);
            let name1 = normalize_column_name(&header1);
            let expr_named = format!("({name0} + {name1}) > {name0}");
            let expr_indexed = "(c0 + c1) > c0";
            let lhs = evaluate_expression_to_bool(&expr_named, &context).expect("named expression");
            let rhs = evaluate_expression_to_bool(expr_indexed, &context).expect("indexed expression");
            prop_assert_eq!(lhs, rhs);
        }
    }
}