1use anyhow::{Context, Result};
2use chrono::{Duration, NaiveDate, NaiveDateTime, NaiveTime};
3use evalexpr::{
4 ContextWithMutableFunctions, ContextWithMutableVariables, Function, HashMapContext,
5 Value as EvalValue, eval_with_context,
6};
7
8use crate::data::{
9 Value, normalize_column_name, parse_naive_date, parse_naive_datetime, parse_naive_time,
10 value_to_evalexpr,
11};
12
13fn register_temporal_functions(context: &mut HashMapContext) -> Result<()> {
14 context
15 .set_function(
16 "date_add".into(),
17 Function::new(|arguments| {
18 let args = expect_args(arguments, 2, "date_add")?;
19 let date = parse_date_arg(&args[0])?;
20 let days = parse_i64_arg(&args[1], "days")?;
21 let result = date
22 .checked_add_signed(Duration::days(days))
23 .ok_or_else(|| eval_error("date_add overflow"))?;
24 Ok(EvalValue::String(result.format("%Y-%m-%d").to_string()))
25 }),
26 )
27 .map_err(anyhow::Error::from)?;
28
29 context
30 .set_function(
31 "date_sub".into(),
32 Function::new(|arguments| {
33 let args = expect_args(arguments, 2, "date_sub")?;
34 let date = parse_date_arg(&args[0])?;
35 let days = parse_i64_arg(&args[1], "days")?;
36 let result = date
37 .checked_sub_signed(Duration::days(days))
38 .ok_or_else(|| eval_error("date_sub overflow"))?;
39 Ok(EvalValue::String(result.format("%Y-%m-%d").to_string()))
40 }),
41 )
42 .map_err(anyhow::Error::from)?;
43
44 context
45 .set_function(
46 "date_diff_days".into(),
47 Function::new(|arguments| {
48 let args = expect_args(arguments, 2, "date_diff_days")?;
49 let end = parse_date_arg(&args[0])?;
50 let start = parse_date_arg(&args[1])?;
51 let diff = (end - start).num_days();
52 Ok(EvalValue::Int(diff))
53 }),
54 )
55 .map_err(anyhow::Error::from)?;
56
57 context
58 .set_function(
59 "datetime_add_seconds".into(),
60 Function::new(|arguments| {
61 let args = expect_args(arguments, 2, "datetime_add_seconds")?;
62 let dt = parse_datetime_arg(&args[0])?;
63 let seconds = parse_i64_arg(&args[1], "seconds")?;
64 let result = dt
65 .checked_add_signed(Duration::seconds(seconds))
66 .ok_or_else(|| eval_error("datetime_add_seconds overflow"))?;
67 Ok(EvalValue::String(
68 result.format("%Y-%m-%d %H:%M:%S").to_string(),
69 ))
70 }),
71 )
72 .map_err(anyhow::Error::from)?;
73
74 context
75 .set_function(
76 "datetime_diff_seconds".into(),
77 Function::new(|arguments| {
78 let args = expect_args(arguments, 2, "datetime_diff_seconds")?;
79 let end = parse_datetime_arg(&args[0])?;
80 let start = parse_datetime_arg(&args[1])?;
81 Ok(EvalValue::Int((end - start).num_seconds()))
82 }),
83 )
84 .map_err(anyhow::Error::from)?;
85
86 context
87 .set_function(
88 "datetime_to_date".into(),
89 Function::new(|arguments| {
90 let args = expect_args(arguments, 1, "datetime_to_date")?;
91 let dt = parse_datetime_arg(&args[0])?;
92 Ok(EvalValue::String(dt.date().format("%Y-%m-%d").to_string()))
93 }),
94 )
95 .map_err(anyhow::Error::from)?;
96
97 context
98 .set_function(
99 "datetime_to_time".into(),
100 Function::new(|arguments| {
101 let args = expect_args(arguments, 1, "datetime_to_time")?;
102 let dt = parse_datetime_arg(&args[0])?;
103 Ok(EvalValue::String(dt.time().format("%H:%M:%S").to_string()))
104 }),
105 )
106 .map_err(anyhow::Error::from)?;
107
108 context
109 .set_function(
110 "time_add_seconds".into(),
111 Function::new(|arguments| {
112 let args = expect_args(arguments, 2, "time_add_seconds")?;
113 let time = parse_time_arg(&args[0])?;
114 let seconds = parse_i64_arg(&args[1], "seconds")?;
115 let (result, overflow_days) =
116 time.overflowing_add_signed(Duration::seconds(seconds));
117 if overflow_days != 0 {
118 return Err(eval_error("time_add_seconds overflow"));
119 }
120 Ok(EvalValue::String(result.format("%H:%M:%S").to_string()))
121 }),
122 )
123 .map_err(anyhow::Error::from)?;
124
125 context
126 .set_function(
127 "time_diff_seconds".into(),
128 Function::new(|arguments| {
129 let args = expect_args(arguments, 2, "time_diff_seconds")?;
130 let end = parse_time_arg(&args[0])?;
131 let start = parse_time_arg(&args[1])?;
132 Ok(EvalValue::Int((end - start).num_seconds()))
133 }),
134 )
135 .map_err(anyhow::Error::from)?;
136
137 context
138 .set_function(
139 "date_format".into(),
140 Function::new(|arguments| {
141 let args = expect_args(arguments, 2, "date_format")?;
142 let date = parse_date_arg(&args[0])?;
143 let fmt = expect_string(&args[1], "format")?;
144 Ok(EvalValue::String(date.format(fmt).to_string()))
145 }),
146 )
147 .map_err(anyhow::Error::from)?;
148
149 context
150 .set_function(
151 "datetime_format".into(),
152 Function::new(|arguments| {
153 let args = expect_args(arguments, 2, "datetime_format")?;
154 let dt = parse_datetime_arg(&args[0])?;
155 let fmt = expect_string(&args[1], "format")?;
156 Ok(EvalValue::String(dt.format(fmt).to_string()))
157 }),
158 )
159 .map_err(anyhow::Error::from)?;
160
161 Ok(())
162}
163
164fn expect_args(
165 arguments: &EvalValue,
166 expected: usize,
167 name: &str,
168) -> Result<Vec<EvalValue>, evalexpr::EvalexprError> {
169 match arguments {
170 EvalValue::Empty if expected == 0 => Ok(Vec::new()),
171 value if expected == 1 && !matches!(value, EvalValue::Tuple(_)) => Ok(vec![value.clone()]),
172 EvalValue::Tuple(values) => {
173 if values.len() != expected {
174 return Err(evalexpr::EvalexprError::wrong_function_argument_amount(
175 values.len(),
176 expected,
177 ));
178 }
179 Ok(values.clone())
180 }
181 _ => Err(eval_error(&format!(
182 "{name} expects {expected} arguments provided as a tuple"
183 ))),
184 }
185}
186
187fn eval_error(message: &str) -> evalexpr::EvalexprError {
188 evalexpr::EvalexprError::CustomMessage(message.to_string())
189}
190
191fn parse_date_arg(value: &EvalValue) -> Result<NaiveDate, evalexpr::EvalexprError> {
192 let raw = expect_string(value, "date")?;
193 parse_naive_date(raw).map_err(|err| eval_error(&err.to_string()))
194}
195
196fn parse_datetime_arg(value: &EvalValue) -> Result<NaiveDateTime, evalexpr::EvalexprError> {
197 let raw = expect_string(value, "datetime")?;
198 parse_naive_datetime(raw).map_err(|err| eval_error(&err.to_string()))
199}
200
201fn parse_time_arg(value: &EvalValue) -> Result<NaiveTime, evalexpr::EvalexprError> {
202 let raw = expect_string(value, "time")?;
203 parse_naive_time(raw).map_err(|err| eval_error(&err.to_string()))
204}
205
206fn parse_i64_arg(value: &EvalValue, name: &str) -> Result<i64, evalexpr::EvalexprError> {
207 match value {
208 EvalValue::Int(i) => Ok(*i),
209 EvalValue::Float(f) => Ok(*f as i64),
210 other => Err(eval_error(&format!(
211 "Expected integer for {name}, got {other:?}",
212 ))),
213 }
214}
215
216fn expect_string<'a>(value: &'a EvalValue, name: &str) -> Result<&'a str, evalexpr::EvalexprError> {
217 if let EvalValue::String(s) = value {
218 Ok(s)
219 } else {
220 Err(eval_error(&format!("Expected string for {name}")))
221 }
222}
223
224pub fn build_context(
225 headers: &[String],
226 raw_row: &[String],
227 typed_row: &[Option<Value>],
228 row_number: Option<usize>,
229) -> Result<HashMapContext> {
230 let mut context = HashMapContext::new();
231 register_temporal_functions(&mut context)?;
232 for (idx, header) in headers.iter().enumerate() {
233 let canon = normalize_column_name(header);
234 let key = format!("c{idx}");
235 if let Some(Some(value)) = typed_row.get(idx) {
236 let eval_value = value_to_evalexpr(value);
237 context
238 .set_value(canon.clone(), eval_value.clone())
239 .with_context(|| format!("Binding column '{header}'"))?;
240 context
241 .set_value(key, eval_value)
242 .with_context(|| format!("Binding column index {idx}"))?;
243 } else if let Some(raw) = raw_row.get(idx) {
244 context
245 .set_value(canon.clone(), EvalValue::String(raw.clone()))
246 .with_context(|| format!("Binding raw column '{header}'"))?;
247 context
248 .set_value(key, EvalValue::String(raw.clone()))
249 .with_context(|| format!("Binding raw column index {idx}"))?;
250 }
251 }
252
253 if let Some(number) = row_number {
254 context
255 .set_value("row_number".to_string(), EvalValue::Int(number as i64))
256 .context("Binding row_number")?;
257 }
258
259 Ok(context)
260}
261
262pub fn evaluate_expression_to_bool(expr: &str, context: &HashMapContext) -> Result<bool> {
263 let result = eval_with_context(expr, context)
264 .with_context(|| format!("Evaluating expression '{expr}'"))?;
265 Ok(eval_value_truthy(result))
266}
267
268pub fn eval_value_truthy(value: EvalValue) -> bool {
269 match value {
270 EvalValue::Boolean(b) => b,
271 EvalValue::Int(i) => i != 0,
272 EvalValue::Float(f) => f != 0.0,
273 EvalValue::String(s) => !s.is_empty(),
274 EvalValue::Tuple(values) => values.into_iter().any(eval_value_truthy),
275 EvalValue::Empty => false,
276 }
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Builds a context that has only the temporal functions registered.
    fn temporal_context() -> HashMapContext {
        let mut ctx = HashMapContext::new();
        register_temporal_functions(&mut ctx).unwrap();
        ctx
    }

    /// Evaluates `expr` and unwraps the result as a string.
    fn eval_to_string(expr: &str, ctx: &HashMapContext) -> String {
        eval_with_context(expr, ctx)
            .unwrap()
            .as_string()
            .unwrap()
            .to_string()
    }

    /// Evaluates `expr` and unwraps the result as an integer.
    fn eval_to_int(expr: &str, ctx: &HashMapContext) -> i64 {
        eval_with_context(expr, ctx).unwrap().as_int().unwrap()
    }

    #[test]
    fn date_add_and_diff_work() {
        let ctx = temporal_context();

        let shifted = eval_to_string(r#"date_add("2024-01-01", 5)"#, &ctx);
        assert_eq!(shifted, "2024-01-06");

        let gap = eval_to_int(r#"date_diff_days("2024-01-10", "2024-01-01")"#, &ctx);
        assert_eq!(gap, 9);
    }

    #[test]
    fn datetime_functions_roundtrip() {
        let ctx = temporal_context();

        let shifted = eval_to_string(
            r#"datetime_add_seconds("2024-01-01 00:00:00", 3661)"#,
            &ctx,
        );
        assert_eq!(shifted, "2024-01-01 01:01:01");

        let gap = eval_to_int(
            r#"datetime_diff_seconds("2024-01-01 01:01:01", "2024-01-01 00:00:00")"#,
            &ctx,
        );
        assert_eq!(gap, 3661);
    }

    #[test]
    fn time_functions_behave() {
        let ctx = temporal_context();

        let shifted = eval_to_string(r#"time_add_seconds("08:00:00", 90)"#, &ctx);
        assert_eq!(shifted, "08:01:30");

        let gap = eval_to_int(r#"time_diff_seconds("08:01:30", "08:00:00")"#, &ctx);
        assert_eq!(gap, 90);
    }

    proptest! {
        // Named and positional column references must evaluate identically.
        #[test]
        fn evaluate_expression_handles_random_numeric_context(
            a in -10_000i64..=10_000,
            b in -10_000i64..=10_000,
            header0 in "[A-Za-z0-9_ ]{3,12}",
            header1 in "[A-Za-z0-9_ ]{3,12}"
        ) {
            let headers = vec![header0.clone(), header1.clone()];
            let raw_cells = vec![a.to_string(), b.to_string()];
            let typed_cells = vec![Some(Value::Integer(a)), Some(Value::Integer(b))];
            let context =
                build_context(&headers, &raw_cells, &typed_cells, None).expect("build context");

            let first = normalize_column_name(&header0);
            let second = normalize_column_name(&header1);
            let by_name = format!("({first} + {second}) > {first}");
            let by_index = "(c0 + c1) > c0";

            let named_result =
                evaluate_expression_to_bool(&by_name, &context).expect("named expression");
            let indexed_result =
                evaluate_expression_to_bool(by_index, &context).expect("indexed expression");
            prop_assert_eq!(named_result, indexed_result);
        }
    }
}