Skip to main content

sda_lib/
eval.rs

1use crate::ast::*;
2use crate::stdlib;
3use crate::{Env, Value};
4use thiserror::Error;
5
6#[derive(Debug, Error)]
7pub enum EvalError {
8    #[error("Unbound variable: {0}")]
9    UnboundVar(String),
10    #[error("Type error: {0}")]
11    TypeError(String),
12    #[error("Missing key: {0}")]
13    MissingKey(String),
14    #[error("Wrong shape: {0}")]
15    WrongShape(String),
16    #[error("Duplicate key: {0}")]
17    DuplicateKey(String),
18    #[error("Not callable: {0}")]
19    NotCallable(String),
20    #[error("Arity mismatch: expected {expected}, got {got}")]
21    ArityMismatch { expected: usize, got: usize },
22}
23
24fn fail_value(code: &str, msg: &str) -> Value {
25    Value::Fail_(code.to_string(), msg.to_string())
26}
27
28fn wrong_shape_value() -> Value {
29    fail_value("t_sda_wrong_shape", "wrong shape")
30}
31
32fn div_by_zero_value() -> Value {
33    fail_value("t_sda_div_by_zero", "division by zero")
34}
35
36fn unbound_name_value() -> Value {
37    fail_value("t_sda_unbound_name", "unbound name")
38}
39
40fn not_callable_value() -> Value {
41    fail_value("t_sda_not_callable", "not callable")
42}
43
44fn arity_mismatch_value() -> Value {
45    fail_value("t_sda_arity_mismatch", "arity mismatch")
46}
47
48pub(crate) fn ensure_comparable(value: &Value) -> Result<(), EvalError> {
49    match value {
50        Value::Null
51        | Value::Bool(_)
52        | Value::Num(_)
53        | Value::Str(_)
54        | Value::Bytes(_)
55        | Value::None_
56        | Value::Fail_(_, _) => Ok(()),
57        Value::Seq(items) | Value::Set(items) | Value::Bag(items) => {
58            for item in items {
59                ensure_comparable(item)?;
60            }
61            Ok(())
62        }
63        Value::Map(entries) | Value::Prod(entries) => {
64            for (_, value) in entries {
65                ensure_comparable(value)?;
66            }
67            Ok(())
68        }
69        Value::BagKV(pairs) => {
70            for (key, value) in pairs {
71                ensure_comparable(key)?;
72                ensure_comparable(value)?;
73            }
74            Ok(())
75        }
76        Value::Bind(key, value) => {
77            ensure_comparable(key)?;
78            ensure_comparable(value)
79        }
80        Value::Some_(inner) | Value::Ok_(inner) => ensure_comparable(inner),
81        Value::Lambda(_, _, _) => Err(EvalError::TypeError(
82            "function values are not comparable".to_string(),
83        )),
84    }
85}
86
87pub(crate) fn values_equal(a: &Value, b: &Value) -> bool {
88    a == b
89}
90
91pub fn eval_expr(expr: &Expr, env: &Env) -> Result<Value, EvalError> {
92    match expr {
93        Expr::Null => Ok(Value::Null),
94        Expr::Bool(b) => Ok(Value::Bool(*b)),
95        Expr::Num(n) => Ok(Value::Num(n.clone())),
96        Expr::Str(s) => Ok(Value::Str(s.clone())),
97        Expr::Bytes(bytes) => Ok(Value::Bytes(bytes.clone())),
98        Expr::Placeholder => Ok(env.get("_").cloned().unwrap_or_else(|| {
99            Value::Fail_(
100                "t_sda_unbound_placeholder".to_string(),
101                "unbound placeholder".to_string(),
102            )
103        })),
104        Expr::Ident(name) => env
105            .get(name)
106            .cloned()
107            .map_or_else(|| Ok(unbound_name_value()), Ok),
108        Expr::Seq(items) => {
109            let values: Result<Vec<Value>, EvalError> =
110                items.iter().map(|item| eval_expr(item, env)).collect();
111            Ok(Value::Seq(values?))
112        }
113        Expr::Set(items) => {
114            let mut values = Vec::new();
115            for item in items {
116                let value = eval_expr(item, env)?;
117                if ensure_comparable(&value).is_err() {
118                    return Ok(wrong_shape_value());
119                }
120                if !values.iter().any(|existing| values_equal(existing, &value)) {
121                    values.push(value);
122                }
123            }
124            Ok(Value::Set(values))
125        }
126        Expr::Bag(items) => {
127            let values: Result<Vec<Value>, EvalError> =
128                items.iter().map(|item| eval_expr(item, env)).collect();
129            Ok(Value::Bag(values?))
130        }
131        Expr::Map(entries) => {
132            let mut result = Vec::new();
133            for (k, v) in entries {
134                result.push((k.clone(), eval_expr(v, env)?));
135            }
136            Ok(Value::Map(result))
137        }
138        Expr::Prod(fields) => {
139            let mut result = Vec::new();
140            for (k, v) in fields {
141                result.push((k.clone(), eval_expr(v, env)?));
142            }
143            Ok(Value::Prod(result))
144        }
145        Expr::BagKV(entries) => {
146            let mut result = Vec::new();
147            for (k, v) in entries {
148                result.push((Value::Str(k.clone()), eval_expr(v, env)?));
149            }
150            Ok(Value::BagKV(result))
151        }
152        Expr::Some_(inner) => Ok(Value::Some_(Box::new(eval_expr(inner, env)?))),
153        Expr::None_ => Ok(Value::None_),
154        Expr::Ok_(inner) => Ok(Value::Ok_(Box::new(eval_expr(inner, env)?))),
155        Expr::Fail_(code_expr, msg_expr) => {
156            let code_value = eval_expr(code_expr, env)?;
157            let msg_value = eval_expr(msg_expr, env)?;
158            let code = match code_value {
159                Value::Str(s) => s,
160                other => format!("{other:?}"),
161            };
162            let msg = match msg_value {
163                Value::Str(s) => s,
164                other => format!("{other:?}"),
165            };
166            Ok(Value::Fail_(code, msg))
167        }
168        Expr::Lambda(param, body) => Ok(Value::Lambda(
169            param.clone(),
170            body.clone(),
171            Box::new(env.clone()),
172        )),
173        Expr::Call(func_expr, args) => {
174            let arg_vals: Result<Vec<Value>, EvalError> =
175                args.iter().map(|arg| eval_expr(arg, env)).collect();
176            let arg_vals = arg_vals?;
177
178            if let Expr::Ident(name) = func_expr.as_ref() {
179                if let Some(result) = stdlib::call_stdlib(name, arg_vals.clone()) {
180                    return match result {
181                        Err(EvalError::ArityMismatch { .. }) => Ok(arity_mismatch_value()),
182                        other => other,
183                    };
184                }
185                let func = if let Some(func) = env.get(name).cloned() {
186                    func
187                } else {
188                    return Ok(unbound_name_value());
189                };
190                return apply_lambda(func, arg_vals);
191            }
192
193            let func = eval_expr(func_expr, env)?;
194            apply_lambda(func, arg_vals)
195        }
196        Expr::Pipe(lhs, rhs) => {
197            let lhs_value = eval_expr(lhs, env)?;
198            let mut child_env = env.clone();
199            child_env.insert("_".to_string(), lhs_value);
200            eval_expr(rhs, &child_env)
201        }
202        Expr::Select(obj_expr, field, mode) => {
203            let obj = eval_expr(obj_expr, env)?;
204            eval_select(obj, field, mode)
205        }
206        Expr::UnOp(op, expr) => {
207            let value = eval_expr(expr, env)?;
208            match op {
209                UnOpKind::Neg => match value {
210                    Value::Num(n) => Ok(Value::Num(n.neg())),
211                    _ => Ok(wrong_shape_value()),
212                },
213                UnOpKind::Not => match value {
214                    Value::Bool(b) => Ok(Value::Bool(!b)),
215                    _ => Ok(wrong_shape_value()),
216                },
217            }
218        }
219        Expr::BinOp(op, lhs_expr, rhs_expr) => {
220            let lhs = eval_expr(lhs_expr, env)?;
221            let rhs = eval_expr(rhs_expr, env)?;
222            eval_binop(op, lhs, rhs)
223        }
224        Expr::Comprehension {
225            yield_expr,
226            binding,
227            collection,
228            pred,
229        } => {
230            enum Carrier {
231                Seq,
232                Set,
233                Bag,
234            }
235
236            let coll_val = eval_expr(collection, env)?;
237            let (items, carrier) = match coll_val {
238                Value::Seq(items) => (items, Carrier::Seq),
239                Value::Set(items) => (items, Carrier::Set),
240                Value::Bag(items) => (items, Carrier::Bag),
241                Value::BagKV(entries) => (
242                    entries
243                        .into_iter()
244                        .map(|(key, value)| Value::Bind(Box::new(key), Box::new(value)))
245                        .collect(),
246                    Carrier::Bag,
247                ),
248                _ => return Ok(wrong_shape_value()),
249            };
250
251            let mut results = Vec::new();
252            for item in items {
253                let mut child_env = env.clone();
254                child_env.insert(binding.clone(), item.clone());
255
256                if let Some(pred_expr) = pred {
257                    let pred_val = eval_expr(pred_expr, &child_env)?;
258                    match pred_val {
259                        Value::Bool(false) => continue,
260                        Value::Bool(true) => {}
261                        _ => return Ok(wrong_shape_value()),
262                    }
263                }
264
265                let result = if let Some(yield_expr) = yield_expr {
266                    eval_expr(yield_expr, &child_env)?
267                } else {
268                    item
269                };
270                results.push(result);
271            }
272
273            match carrier {
274                Carrier::Seq => Ok(Value::Seq(results)),
275                Carrier::Bag => Ok(Value::Bag(results)),
276                Carrier::Set => {
277                    let mut dedup = Vec::new();
278                    for value in results {
279                        if ensure_comparable(&value).is_err() {
280                            return Ok(wrong_shape_value());
281                        }
282                        if !dedup.iter().any(|existing| values_equal(existing, &value)) {
283                            dedup.push(value);
284                        }
285                    }
286                    Ok(Value::Set(dedup))
287                }
288            }
289        }
290    }
291}
292
293fn eval_select(obj: Value, field: &str, mode: &SelectMode) -> Result<Value, EvalError> {
294    match &obj {
295        Value::Map(entries) => {
296            let found = entries.iter().find(|(k, _)| k == field).map(|(_, v)| v.clone());
297            match mode {
298                SelectMode::Plain => Ok(Value::Fail_(
299                    "t_sda_wrong_shape".to_string(),
300                    "wrong shape".to_string(),
301                )),
302                SelectMode::Optional => Ok(found
303                    .map(|v| Value::Some_(Box::new(v)))
304                    .unwrap_or(Value::None_)),
305                SelectMode::Required => Ok(found
306                    .map(|v| Value::Ok_(Box::new(v)))
307                    .unwrap_or_else(|| {
308                        Value::Fail_(
309                            "t_sda_missing_key".to_string(),
310                            "missing key".to_string(),
311                        )
312                    })),
313            }
314        }
315        Value::Prod(fields) => {
316            let found = fields.iter().find(|(k, _)| k == field).map(|(_, v)| v.clone());
317            match mode {
318                SelectMode::Plain => Ok(found.unwrap_or_else(|| {
319                    Value::Fail_(
320                        "t_sda_unknown_field".to_string(),
321                        "unknown field".to_string(),
322                    )
323                })),
324                SelectMode::Optional | SelectMode::Required => Ok(Value::Fail_(
325                    "t_sda_wrong_shape".to_string(),
326                    "wrong shape".to_string(),
327                )),
328            }
329        }
330        Value::Bind(key, value) => {
331            let found = match field {
332                "key" => Some((**key).clone()),
333                "val" => Some((**value).clone()),
334                _ => None,
335            };
336            match mode {
337                SelectMode::Plain => Ok(found.unwrap_or(Value::Null)),
338                SelectMode::Optional => Ok(found
339                    .map(|v| Value::Some_(Box::new(v)))
340                    .unwrap_or(Value::None_)),
341                SelectMode::Required => Ok(found
342                    .map(|v| Value::Ok_(Box::new(v)))
343                    .unwrap_or_else(|| {
344                        Value::Fail_(
345                            "t_sda_missing_key".to_string(),
346                            "missing key".to_string(),
347                        )
348                    })),
349            }
350        }
351        Value::BagKV(entries) => {
352            let matches: Vec<_> = entries
353                .iter()
354                .filter(|(k, _)| matches!(k, Value::Str(s) if s == field))
355                .collect();
356            match mode {
357                SelectMode::Plain => Ok(Value::Fail_(
358                    "t_sda_wrong_shape".to_string(),
359                    "wrong shape".to_string(),
360                )),
361                SelectMode::Optional => match matches.len() {
362                    0 => Ok(Value::None_),
363                    1 => Ok(Value::Some_(Box::new(matches[0].1.clone()))),
364                    _ => Ok(Value::None_),
365                },
366                SelectMode::Required => match matches.len() {
367                    0 => Ok(Value::Fail_(
368                        "t_sda_missing_key".to_string(),
369                        "missing key".to_string(),
370                    )),
371                    1 => Ok(Value::Ok_(Box::new(matches[0].1.clone()))),
372                    _ => Ok(Value::Fail_(
373                        "t_sda_duplicate_key".to_string(),
374                        "duplicate key".to_string(),
375                    )),
376                },
377            }
378        }
379        _ => match mode {
380            SelectMode::Optional => Ok(Value::Fail_(
381                "t_sda_wrong_shape".to_string(),
382                "wrong shape".to_string(),
383            )),
384            SelectMode::Required => Ok(Value::Fail_(
385                "t_sda_wrong_shape".to_string(),
386                "wrong shape".to_string(),
387            )),
388            SelectMode::Plain => Ok(Value::Fail_(
389                "t_sda_wrong_shape".to_string(),
390                "wrong shape".to_string(),
391            )),
392        },
393    }
394}
395
396fn eval_binop(op: &BinOpKind, lhs: Value, rhs: Value) -> Result<Value, EvalError> {
397    match op {
398        BinOpKind::Add => match (lhs, rhs) {
399            (Value::Num(a), Value::Num(b)) => Ok(Value::Num(a.add(&b))),
400            _ => Ok(wrong_shape_value()),
401        },
402        BinOpKind::Sub => match (lhs, rhs) {
403            (Value::Num(a), Value::Num(b)) => Ok(Value::Num(a.sub(&b))),
404            _ => Ok(wrong_shape_value()),
405        },
406        BinOpKind::Mul => match (lhs, rhs) {
407            (Value::Num(a), Value::Num(b)) => Ok(Value::Num(a.mul(&b))),
408            _ => Ok(wrong_shape_value()),
409        },
410        BinOpKind::Div => match (lhs, rhs) {
411            (Value::Num(a), Value::Num(b)) => {
412                if b.is_zero() {
413                    Ok(div_by_zero_value())
414                } else {
415                    Ok(Value::Num(a.div(&b)))
416                }
417            }
418            _ => Ok(wrong_shape_value()),
419        },
420        BinOpKind::Concat => match (lhs, rhs) {
421            (Value::Str(a), Value::Str(b)) => Ok(Value::Str(a + &b)),
422            (Value::Seq(mut a), Value::Seq(b)) => {
423                a.extend(b);
424                Ok(Value::Seq(a))
425            }
426            _ => Ok(wrong_shape_value()),
427        },
428        BinOpKind::Eq => {
429            if ensure_comparable(&lhs).is_err() || ensure_comparable(&rhs).is_err() {
430                return Ok(wrong_shape_value());
431            }
432            Ok(Value::Bool(values_equal(&lhs, &rhs)))
433        }
434        BinOpKind::Neq => {
435            if ensure_comparable(&lhs).is_err() || ensure_comparable(&rhs).is_err() {
436                return Ok(wrong_shape_value());
437            }
438            Ok(Value::Bool(!values_equal(&lhs, &rhs)))
439        }
440        BinOpKind::Lt => match (lhs, rhs) {
441            (Value::Num(a), Value::Num(b)) => Ok(Value::Bool(a < b)),
442            (Value::Str(a), Value::Str(b)) => Ok(Value::Bool(a < b)),
443            _ => Ok(wrong_shape_value()),
444        },
445        BinOpKind::Le => match (lhs, rhs) {
446            (Value::Num(a), Value::Num(b)) => Ok(Value::Bool(a <= b)),
447            (Value::Str(a), Value::Str(b)) => Ok(Value::Bool(a <= b)),
448            _ => Ok(wrong_shape_value()),
449        },
450        BinOpKind::Gt => match (lhs, rhs) {
451            (Value::Num(a), Value::Num(b)) => Ok(Value::Bool(a > b)),
452            (Value::Str(a), Value::Str(b)) => Ok(Value::Bool(a > b)),
453            _ => Ok(wrong_shape_value()),
454        },
455        BinOpKind::Ge => match (lhs, rhs) {
456            (Value::Num(a), Value::Num(b)) => Ok(Value::Bool(a >= b)),
457            (Value::Str(a), Value::Str(b)) => Ok(Value::Bool(a >= b)),
458            _ => Ok(wrong_shape_value()),
459        },
460        BinOpKind::And => match (lhs, rhs) {
461            (Value::Bool(a), Value::Bool(b)) => Ok(Value::Bool(a && b)),
462            _ => Ok(wrong_shape_value()),
463        },
464        BinOpKind::Or => match (lhs, rhs) {
465            (Value::Bool(a), Value::Bool(b)) => Ok(Value::Bool(a || b)),
466            _ => Ok(wrong_shape_value()),
467        },
468        BinOpKind::Union => match (lhs, rhs) {
469            (Value::Set(mut a), Value::Set(b)) => {
470                for item in b {
471                    if !a.iter().any(|existing| values_equal(existing, &item)) {
472                        a.push(item);
473                    }
474                }
475                Ok(Value::Set(a))
476            }
477            _ => Ok(wrong_shape_value()),
478        },
479        BinOpKind::Inter => match (lhs, rhs) {
480            (Value::Set(a), Value::Set(b)) => {
481                let result = a
482                    .into_iter()
483                    .filter(|x| b.iter().any(|y| values_equal(x, y)))
484                    .collect();
485                Ok(Value::Set(result))
486            }
487            _ => Ok(wrong_shape_value()),
488        },
489        BinOpKind::Diff => match (lhs, rhs) {
490            (Value::Set(a), Value::Set(b)) => {
491                let result = a
492                    .into_iter()
493                    .filter(|x| !b.iter().any(|y| values_equal(x, y)))
494                    .collect();
495                Ok(Value::Set(result))
496            }
497            _ => Ok(wrong_shape_value()),
498        },
499        BinOpKind::BUnion => match (lhs, rhs) {
500            (Value::Bag(mut a), Value::Bag(b)) => {
501                a.extend(b);
502                Ok(Value::Bag(a))
503            }
504            _ => Ok(wrong_shape_value()),
505        },
506        BinOpKind::BDiff => match (lhs, rhs) {
507            (Value::Bag(a), Value::Bag(b)) => {
508                let mut remaining = b.clone();
509                let result = a
510                    .into_iter()
511                    .filter(|x| {
512                        if let Some(idx) = remaining.iter().position(|y| values_equal(x, y)) {
513                            remaining.remove(idx);
514                            false
515                        } else {
516                            true
517                        }
518                    })
519                    .collect();
520                Ok(Value::Bag(result))
521            }
522            _ => Ok(wrong_shape_value()),
523        },
524        BinOpKind::In => match rhs {
525            Value::Seq(items) => {
526                if ensure_comparable(&lhs).is_err() {
527                    return Ok(wrong_shape_value());
528                }
529                for item in &items {
530                    if ensure_comparable(item).is_err() {
531                        return Ok(wrong_shape_value());
532                    }
533                }
534                Ok(Value::Bool(items.iter().any(|x| values_equal(x, &lhs))))
535            }
536            Value::Set(items) => {
537                if ensure_comparable(&lhs).is_err() {
538                    return Ok(wrong_shape_value());
539                }
540                for item in &items {
541                    if ensure_comparable(item).is_err() {
542                        return Ok(wrong_shape_value());
543                    }
544                }
545                Ok(Value::Bool(items.iter().any(|x| values_equal(x, &lhs))))
546            }
547            Value::Bag(items) => {
548                if ensure_comparable(&lhs).is_err() {
549                    return Ok(wrong_shape_value());
550                }
551                for item in &items {
552                    if ensure_comparable(item).is_err() {
553                        return Ok(wrong_shape_value());
554                    }
555                }
556                Ok(Value::Bool(items.iter().any(|x| values_equal(x, &lhs))))
557            }
558            Value::Map(entries) => {
559                if let Value::Str(key) = &lhs {
560                    Ok(Value::Bool(entries.iter().any(|(k, _)| k == key)))
561                } else {
562                    Ok(wrong_shape_value())
563                }
564            }
565            Value::Prod(fields) => {
566                if let Value::Str(key) = &lhs {
567                    Ok(Value::Bool(fields.iter().any(|(k, _)| k == key)))
568                } else {
569                    Ok(wrong_shape_value())
570                }
571            }
572            _ => Ok(wrong_shape_value()),
573        },
574    }
575}
576
577pub(crate) fn apply_lambda(func: Value, args: Vec<Value>) -> Result<Value, EvalError> {
578    match func {
579        Value::Lambda(param, body, captured_env) => {
580            if args.len() != 1 {
581                return Ok(arity_mismatch_value());
582            }
583            let mut new_env = *captured_env;
584            new_env.insert(param, args.into_iter().next().unwrap());
585            eval_expr(&body, &new_env)
586        }
587        _ => Ok(not_callable_value()),
588    }
589}
590
591pub fn eval_program(program: &Program, env: &mut Env) -> Result<Option<Value>, EvalError> {
592    let mut last = None;
593    for stmt in &program.stmts {
594        match stmt {
595            Stmt::Let(name, expr) => {
596                let value = eval_expr(expr, env)?;
597                env.insert(name.clone(), value);
598                last = None;
599            }
600            Stmt::Expr(expr) => {
601                last = Some(eval_expr(expr, env)?);
602            }
603        }
604    }
605    Ok(last)
606}