cozo_ce/data/
expr.rs

1/*
2 * Copyright 2022, The Cozo Project Authors.
3 *
4 * This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0.
5 * If a copy of the MPL was not distributed with this file,
6 * You can obtain one at https://mozilla.org/MPL/2.0/.
7 */
8
9use std::cmp::{max, min};
10use std::collections::{BTreeMap, BTreeSet};
11use std::fmt::{Debug, Display, Formatter};
12use std::mem;
13
14use itertools::Itertools;
15use miette::{bail, miette, Diagnostic, Result};
16use serde::de::{Error, Visitor};
17use serde::{Deserializer, Serializer};
18use smartstring::{LazyCompact, SmartString};
19use thiserror::Error;
20
21use crate::data::functions::*;
22use crate::data::relation::NullableColType;
23use crate::data::symb::Symbol;
24use crate::data::value::{DataValue, LARGEST_UTF_CHAR};
25use crate::parse::expr::expr2bytecode;
26use crate::parse::SourceSpan;
27
/// One instruction of the stack machine that runs compiled expressions.
///
/// Programs are produced by `Expr::compile` and executed by [`eval_bytecode`]. The first
/// line of each variant's doc states its effect on the value stack. Spans are
/// `#[serde(skip)]`ped, so deserialized instructions carry default spans.
/// Variant and field order matter to serde for non-self-describing formats: do not reorder.
#[derive(Clone, PartialEq, Eq, serde_derive::Serialize, serde_derive::Deserialize, Debug)]
pub enum Bytecode {
    /// push 1
    ///
    /// Pushes a clone of the binding at `tuple_pos`; evaluation fails if `tuple_pos` is
    /// `None` (unbound variable) or lies past the end of the bindings.
    Binding {
        var: Symbol,
        tuple_pos: Option<usize>,
    },
    /// push 1
    ///
    /// Pushes a clone of `val`.
    Const {
        val: DataValue,
        #[serde(skip)]
        span: SourceSpan,
    },
    /// pop n, push 1
    ///
    /// Pops the top `arity` values, calls `op` on them in push order and pushes the result.
    Apply {
        op: &'static Op,
        arity: usize,
        #[serde(skip)]
        span: SourceSpan,
    },
    /// pop 1
    ///
    /// Pops a boolean: `true` continues with the next instruction, `false` jumps to
    /// instruction index `jump_to`. Any other value is a `PredicateTypeError`.
    JumpIfFalse {
        jump_to: usize,
        #[serde(skip)]
        span: SourceSpan,
    },
    /// unchanged
    ///
    /// Continues unconditionally at instruction index `jump_to`.
    Goto {
        jump_to: usize,
        #[serde(skip)]
        span: SourceSpan,
    },
}
61
/// Raised when a variable is evaluated before it has been assigned a tuple position.
///
/// Fields: variable name, span of the variable.
#[derive(Error, Diagnostic, Debug)]
#[error("The variable '{0}' is unbound")]
#[diagnostic(code(eval::unbound))]
struct UnboundVariableError(String, #[label] SourceSpan);

/// Raised when a variable's tuple position lies beyond the end of the bindings.
///
/// Fields: variable name, requested index, actual bindings length, span of the variable.
#[derive(Error, Diagnostic, Debug)]
#[error("The tuple bound by variable '{0}' is too short: index is {1}, length is {2}")]
#[diagnostic(help("This is definitely a bug. Please report it."))]
#[diagnostic(code(eval::tuple_too_short))]
struct TupleTooShortError(String, usize, usize, #[label] SourceSpan);
72
73pub fn eval_bytecode_pred(
74    bytecodes: &[Bytecode],
75    bindings: impl AsRef<[DataValue]>,
76    stack: &mut Vec<DataValue>,
77    span: SourceSpan,
78) -> Result<bool> {
79    match eval_bytecode(bytecodes, bindings, stack)? {
80        DataValue::Bool(b) => Ok(b),
81        v => bail!(PredicateTypeError(span, v)),
82    }
83}
84
85pub fn eval_bytecode(
86    bytecodes: &[Bytecode],
87    bindings: impl AsRef<[DataValue]>,
88    stack: &mut Vec<DataValue>,
89) -> Result<DataValue> {
90    stack.clear();
91    let mut pointer = 0;
92    // for (i, c) in bytecodes.iter().enumerate() {
93    //     println!("{i}  {c:?}");
94    // }
95    // println!();
96    loop {
97        // println!("{pointer}  {stack:?}");
98        if pointer == bytecodes.len() {
99            break;
100        }
101        let current_instruction = &bytecodes[pointer];
102        // println!("{current_instruction:?}");
103        match current_instruction {
104            Bytecode::Binding { var, tuple_pos, .. } => match tuple_pos {
105                None => {
106                    bail!(UnboundVariableError(var.name.to_string(), var.span))
107                }
108                Some(i) => {
109                    let val = bindings
110                        .as_ref()
111                        .get(*i)
112                        .ok_or_else(|| {
113                            TupleTooShortError(
114                                var.name.to_string(),
115                                *i,
116                                bindings.as_ref().len(),
117                                var.span,
118                            )
119                        })?
120                        .clone();
121                    stack.push(val);
122                    pointer += 1;
123                }
124            },
125            Bytecode::Const { val, .. } => {
126                stack.push(val.clone());
127                pointer += 1;
128            }
129            Bytecode::Apply { op, arity, span } => {
130                let frame_start = stack.len() - *arity;
131                let args_frame = &stack[frame_start..];
132                let result = (op.inner)(args_frame)
133                    .map_err(|err| EvalRaisedError(*span, err.to_string()))?;
134                stack.truncate(frame_start);
135                stack.push(result);
136                pointer += 1;
137            }
138            Bytecode::JumpIfFalse { jump_to, span } => {
139                let val = stack.pop().unwrap();
140                let cond = val
141                    .get_bool()
142                    .ok_or_else(|| PredicateTypeError(*span, val))?;
143                if cond {
144                    pointer += 1;
145                } else {
146                    pointer = *jump_to;
147                }
148            }
149            Bytecode::Goto { jump_to, .. } => {
150                pointer = *jump_to;
151            }
152        }
153    }
154    Ok(stack.pop().unwrap())
155}
156
/// Expression can be evaluated to yield a DataValue
///
/// Expressions are either interpreted directly (`Expr::eval`) or compiled to
/// [`Bytecode`] (`Expr::compile`). Variant and field order matter to serde for
/// non-self-describing formats: do not reorder.
#[derive(Clone, PartialEq, Eq, serde_derive::Serialize, serde_derive::Deserialize)]
pub enum Expr {
    /// Binding to variables
    Binding {
        /// The variable name to bind
        var: Symbol,
        /// When executing in the context of a tuple, the position of the binding within the tuple.
        /// Filled in by `fill_binding_indices`; evaluating while it is `None` is an error.
        tuple_pos: Option<usize>,
    },
    /// Constant expression containing a value
    Const {
        /// The value
        val: DataValue,
        /// Source span
        #[serde(skip)]
        span: SourceSpan,
    },
    /// Function application
    Apply {
        /// Op representing the function to apply
        op: &'static Op,
        /// Arguments to the application
        args: Box<[Expr]>,
        /// Source span
        #[serde(skip)]
        span: SourceSpan,
    },
    /// Unbound function application: the op is known only by name (no `Op` attached).
    /// Most operations on `Expr` reject it with `NoImplementationError`.
    UnboundApply {
        /// Op representing the function to apply
        op: SmartString<LazyCompact>,
        /// Arguments to the application
        args: Box<[Expr]>,
        /// Source span
        #[serde(skip)]
        span: SourceSpan,
    },
    /// Conditional expressions
    Cond {
        /// Conditional clauses, the first expression in each tuple should evaluate to a boolean.
        /// Clauses are tried in order; the result is the value of the first clause whose
        /// condition is `true`, or `Null` if none is.
        clauses: Vec<(Expr, Expr)>,
        /// Source span
        #[serde(skip)]
        span: SourceSpan,
    },
}
204
205impl Debug for Expr {
206    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
207        write!(f, "{self}")
208    }
209}
210
211impl Display for Expr {
212    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
213        match self {
214            Expr::Binding { var, .. } => {
215                write!(f, "{}", var.name)
216            }
217            Expr::Const { val, .. } => {
218                write!(f, "{val}")
219            }
220            Expr::Apply { op, args, .. } => {
221                let mut writer =
222                    f.debug_tuple(op.name.strip_prefix("OP_").unwrap().to_lowercase().as_str());
223                for arg in args.iter() {
224                    writer.field(arg);
225                }
226                writer.finish()
227            }
228            Expr::UnboundApply { op, args, .. } => {
229                let mut writer = f.debug_tuple(op);
230                for arg in args.iter() {
231                    writer.field(arg);
232                }
233                writer.finish()
234            }
235            Expr::Cond { clauses, .. } => {
236                let mut writer = f.debug_tuple("cond");
237                for (cond, expr) in clauses {
238                    writer.field(cond);
239                    writer.field(expr);
240                }
241                writer.finish()
242            }
243        }
244    }
245}
246
/// Raised when an expression applies an op that has no implementation attached
/// (`Expr::UnboundApply`).
///
/// Fields: span of the application, op name.
#[derive(Debug, Error, Diagnostic)]
#[error("No implementation found for op `{1}`")]
#[diagnostic(code(eval::no_implementation))]
pub(crate) struct NoImplementationError(#[label] pub(crate) SourceSpan, pub(crate) String);

/// Raised when a condition or predicate evaluates to something other than a boolean.
///
/// Fields: span of the condition, the offending value.
#[derive(Debug, Error, Diagnostic)]
#[error("Found value {1:?} where a boolean value is expected")]
#[diagnostic(code(eval::predicate_not_bool))]
pub(crate) struct PredicateTypeError(#[label] pub(crate) SourceSpan, pub(crate) DataValue);

/// Error for a value that cannot be turned into an entity ID.
// NOTE(review): not constructed anywhere in this file — confirm it is still needed.
#[derive(Debug, Error, Diagnostic)]
#[error("Cannot build entity ID from {0:?}")]
#[diagnostic(code(parser::bad_eid))]
#[diagnostic(help("Entity ID should be an integer satisfying certain constraints"))]
struct BadEntityId(DataValue, #[label] SourceSpan);

/// Wraps an error returned by an op implementation. The original error message is
/// attached as the diagnostic's help text.
#[derive(Error, Diagnostic, Debug)]
#[error("Evaluation of expression failed")]
#[diagnostic(code(eval::throw))]
struct EvalRaisedError(#[label] SourceSpan, #[help] String);
267
268impl Expr {
269    pub(crate) fn compile(&self) -> Result<Vec<Bytecode>> {
270        let mut collector = vec![];
271        expr2bytecode(self, &mut collector)?;
272        Ok(collector)
273    }
274    pub(crate) fn span(&self) -> SourceSpan {
275        match self {
276            Expr::Binding { var, .. } => var.span,
277            Expr::Const { span, .. } | Expr::Apply { span, .. } | Expr::Cond { span, .. } => *span,
278            Expr::UnboundApply { span, .. } => *span,
279        }
280    }
281    pub(crate) fn get_binding(&self) -> Option<&Symbol> {
282        if let Expr::Binding { var, .. } = self {
283            Some(var)
284        } else {
285            None
286        }
287    }
288    pub(crate) fn get_const(&self) -> Option<&DataValue> {
289        if let Expr::Const { val, .. } = self {
290            Some(val)
291        } else {
292            None
293        }
294    }
295    pub(crate) fn build_equate(exprs: Vec<Expr>, span: SourceSpan) -> Self {
296        Expr::Apply {
297            op: &OP_EQ,
298            args: exprs.into(),
299            span,
300        }
301    }
302    pub(crate) fn build_and(exprs: Vec<Expr>, span: SourceSpan) -> Self {
303        Expr::Apply {
304            op: &OP_AND,
305            args: exprs.into(),
306            span,
307        }
308    }
309    pub(crate) fn build_is_in(exprs: Vec<Expr>, span: SourceSpan) -> Self {
310        Expr::Apply {
311            op: &OP_IS_IN,
312            args: exprs.into(),
313            span,
314        }
315    }
316    pub(crate) fn negate(self, span: SourceSpan) -> Self {
317        Expr::Apply {
318            op: &OP_NEGATE,
319            args: Box::new([self]),
320            span,
321        }
322    }
323    pub(crate) fn to_conjunction(&self) -> Vec<Self> {
324        match self {
325            Expr::Apply { op, args, .. } if **op == OP_AND => args.to_vec(),
326            v => vec![v.clone()],
327        }
328    }
329    pub(crate) fn fill_binding_indices(
330        &mut self,
331        binding_map: &BTreeMap<Symbol, usize>,
332    ) -> Result<()> {
333        match self {
334            Expr::Binding { var, tuple_pos, .. } => {
335                #[derive(Debug, Error, Diagnostic)]
336                #[error("Cannot find binding {0}")]
337                #[diagnostic(code(eval::bad_binding))]
338                struct BadBindingError(String, #[label] SourceSpan);
339
340                let found_idx = *binding_map
341                    .get(var)
342                    .ok_or_else(|| BadBindingError(var.to_string(), var.span))?;
343                *tuple_pos = Some(found_idx)
344            }
345            Expr::Const { .. } => {}
346            Expr::Apply { args, .. } => {
347                for arg in args.iter_mut() {
348                    arg.fill_binding_indices(binding_map)?;
349                }
350            }
351            Expr::Cond { clauses, .. } => {
352                for (cond, val) in clauses {
353                    cond.fill_binding_indices(binding_map)?;
354                    val.fill_binding_indices(binding_map)?;
355                }
356            }
357            Expr::UnboundApply { op, span, .. } => {
358                bail!(NoImplementationError(*span, op.to_string()));
359            }
360        }
361        Ok(())
362    }
363    #[allow(dead_code)]
364    pub(crate) fn binding_indices(&self) -> Result<BTreeSet<usize>> {
365        let mut ret = BTreeSet::default();
366        self.do_binding_indices(&mut ret)?;
367        Ok(ret)
368    }
369    #[allow(dead_code)]
370    fn do_binding_indices(&self, coll: &mut BTreeSet<usize>) -> Result<()> {
371        match self {
372            Expr::Binding { tuple_pos, .. } => {
373                if let Some(idx) = tuple_pos {
374                    coll.insert(*idx);
375                }
376            }
377            Expr::Const { .. } => {}
378            Expr::Apply { args, .. } => {
379                for arg in args.iter() {
380                    arg.do_binding_indices(coll)?;
381                }
382            }
383            Expr::Cond { clauses, .. } => {
384                for (cond, val) in clauses {
385                    cond.do_binding_indices(coll)?;
386                    val.do_binding_indices(coll)?;
387                }
388            } // Expr::Try { clauses, .. } => {
389            //     for clause in clauses {
390            //         clause.do_binding_indices(coll)
391            //     }
392            // }
393            Expr::UnboundApply { op, span, .. } => {
394                bail!(NoImplementationError(*span, op.to_string()));
395            }
396        }
397        Ok(())
398    }
399    /// Evaluate the expression to a constant value if possible
400    pub fn eval_to_const(mut self) -> Result<DataValue> {
401        #[derive(Error, Diagnostic, Debug)]
402        #[error("Expression contains unevaluated constant")]
403        #[diagnostic(code(eval::not_constant))]
404        struct NotConstError;
405
406        self.partial_eval()?;
407        match self {
408            Expr::Const { val, .. } => Ok(val),
409            _ => bail!(NotConstError),
410        }
411    }
412    pub(crate) fn partial_eval(&mut self) -> Result<()> {
413        if let Expr::Apply { args, span, .. } = self {
414            let span = *span;
415            let mut all_evaluated = true;
416            for arg in args.iter_mut() {
417                arg.partial_eval()?;
418                all_evaluated = all_evaluated && matches!(arg, Expr::Const { .. });
419            }
420            if all_evaluated {
421                let result = self.eval(&vec![])?;
422                mem::swap(self, &mut Expr::Const { val: result, span });
423            }
424            // nested not's can accumulate during conversion to normal form
425            if let Expr::Apply {
426                op: op1,
427                args: arg1,
428                ..
429            } = self
430            {
431                if op1.name == OP_NEGATE.name {
432                    if let Some(Expr::Apply {
433                        op: op2,
434                        args: arg2,
435                        ..
436                    }) = arg1.first()
437                    {
438                        if op2.name == OP_NEGATE.name {
439                            let mut new_self = arg2[0].clone();
440                            mem::swap(self, &mut new_self);
441                        }
442                    }
443                }
444            }
445        }
446        Ok(())
447    }
448    pub(crate) fn bindings(&self) -> Result<BTreeSet<Symbol>> {
449        let mut ret = BTreeSet::new();
450        self.collect_bindings(&mut ret)?;
451        Ok(ret)
452    }
453    pub(crate) fn collect_bindings(&self, coll: &mut BTreeSet<Symbol>) -> Result<()> {
454        match self {
455            Expr::Binding { var, .. } => {
456                coll.insert(var.clone());
457            }
458            Expr::Const { .. } => {}
459            Expr::Apply { args, .. } => {
460                for arg in args.iter() {
461                    arg.collect_bindings(coll)?;
462                }
463            }
464            Expr::Cond { clauses, .. } => {
465                for (cond, val) in clauses {
466                    cond.collect_bindings(coll)?;
467                    val.collect_bindings(coll)?;
468                }
469            }
470            Expr::UnboundApply { op, span, .. } => {
471                bail!(NoImplementationError(*span, op.to_string()));
472            }
473        }
474        Ok(())
475    }
476    pub(crate) fn eval(&self, bindings: impl AsRef<[DataValue]>) -> Result<DataValue> {
477        match self {
478            Expr::Binding { var, tuple_pos, .. } => match tuple_pos {
479                None => {
480                    bail!(UnboundVariableError(var.name.to_string(), var.span))
481                }
482                Some(i) => Ok(bindings
483                    .as_ref()
484                    .get(*i)
485                    .ok_or_else(|| {
486                        TupleTooShortError(
487                            var.name.to_string(),
488                            *i,
489                            bindings.as_ref().len(),
490                            var.span,
491                        )
492                    })?
493                    .clone()),
494            },
495            Expr::Const { val, .. } => Ok(val.clone()),
496            Expr::Apply { op, args, .. } => {
497                let args: Box<[DataValue]> = args
498                    .iter()
499                    .map(|v| v.eval(bindings.as_ref()))
500                    .try_collect()?;
501                Ok((op.inner)(&args)
502                    .map_err(|err| EvalRaisedError(self.span(), err.to_string()))?)
503            }
504            Expr::Cond { clauses, .. } => {
505                for (cond, val) in clauses {
506                    let cond_val = cond.eval(bindings.as_ref())?;
507                    let cond_val = cond_val
508                        .get_bool()
509                        .ok_or_else(|| PredicateTypeError(cond.span(), cond_val))?;
510
511                    if cond_val {
512                        return val.eval(bindings.as_ref());
513                    }
514                }
515                Ok(DataValue::Null)
516            }
517            Expr::UnboundApply { op, span, .. } => {
518                bail!(NoImplementationError(*span, op.to_string()));
519            }
520        }
521    }
522    pub(crate) fn extract_bound(&self, target: &Symbol) -> Result<ValueRange> {
523        Ok(match self {
524            Expr::Binding { .. } | Expr::Const { .. } | Expr::Cond { .. } => ValueRange::default(),
525            Expr::Apply { op, args, .. } => match op.name {
526                n if n == OP_GE.name || n == OP_GT.name => {
527                    if let Some(symb) = args[0].get_binding() {
528                        if let Some(val) = args[1].get_const() {
529                            if target == symb {
530                                let tar_val = match val.get_int() {
531                                    Some(i) => DataValue::from(i),
532                                    None => val.clone(),
533                                };
534                                return Ok(ValueRange::lower_bound(tar_val));
535                            }
536                        }
537                    }
538                    if let Some(symb) = args[1].get_binding() {
539                        if let Some(val) = args[0].get_const() {
540                            if target == symb {
541                                let tar_val = match val.get_float() {
542                                    Some(i) => DataValue::from(i),
543                                    None => val.clone(),
544                                };
545                                return Ok(ValueRange::upper_bound(tar_val));
546                            }
547                        }
548                    }
549                    ValueRange::default()
550                }
551                n if n == OP_LE.name || n == OP_LT.name => {
552                    if let Some(symb) = args[0].get_binding() {
553                        if let Some(val) = args[1].get_const() {
554                            if target == symb {
555                                let tar_val = match val.get_float() {
556                                    Some(i) => DataValue::from(i),
557                                    None => val.clone(),
558                                };
559
560                                return Ok(ValueRange::upper_bound(tar_val));
561                            }
562                        }
563                    }
564                    if let Some(symb) = args[1].get_binding() {
565                        if let Some(val) = args[0].get_const() {
566                            if target == symb {
567                                let tar_val = match val.get_int() {
568                                    Some(i) => DataValue::from(i),
569                                    None => val.clone(),
570                                };
571
572                                return Ok(ValueRange::lower_bound(tar_val));
573                            }
574                        }
575                    }
576                    ValueRange::default()
577                }
578                n if n == OP_STARTS_WITH.name => {
579                    if let Some(symb) = args[0].get_binding() {
580                        if let Some(val) = args[1].get_const() {
581                            if target == symb {
582                                let s = val.get_str().ok_or_else(|| {
583                                    #[derive(Debug, Error, Diagnostic)]
584                                    #[error("Cannot prefix scan with {0:?}")]
585                                    #[diagnostic(code(eval::bad_string_range_scan))]
586                                    #[diagnostic(help("A string argument is required"))]
587                                    struct StrRangeScanError(DataValue, #[label] SourceSpan);
588
589                                    StrRangeScanError(val.clone(), symb.span)
590                                })?;
591                                let lower = DataValue::from(s);
592                                // let lower = DataValue::Str(s.to_string());
593                                let mut upper = SmartString::from(s);
594                                // let mut upper = s.to_string();
595                                upper.push(LARGEST_UTF_CHAR);
596                                let upper = DataValue::Str(upper);
597                                return Ok(ValueRange::new(lower, upper));
598                            }
599                        }
600                    }
601                    ValueRange::default()
602                }
603                _ => ValueRange::default(),
604            },
605            Expr::UnboundApply { op, span, .. } => {
606                bail!(NoImplementationError(*span, op.to_string()));
607            }
608        })
609    }
610    pub(crate) fn get_variables(&self) -> Result<BTreeSet<String>> {
611        let mut ret = BTreeSet::new();
612        self.do_get_variables(&mut ret)?;
613        Ok(ret)
614    }
615    fn do_get_variables(&self, coll: &mut BTreeSet<String>) -> Result<()> {
616        match self {
617            Expr::Binding { var, .. } => {
618                coll.insert(var.to_string());
619            }
620            Expr::Const { .. } => {}
621            Expr::Apply { args, .. } => {
622                for arg in args.iter() {
623                    arg.do_get_variables(coll)?;
624                }
625            }
626            Expr::Cond { clauses, .. } => {
627                for (cond, act) in clauses.iter() {
628                    cond.do_get_variables(coll)?;
629                    act.do_get_variables(coll)?;
630                }
631            }
632            Expr::UnboundApply { op, span, .. } => {
633                bail!(NoImplementationError(*span, op.to_string()));
634            }
635        }
636        Ok(())
637    }
638    pub(crate) fn to_var_list(&self) -> Result<Vec<SmartString<LazyCompact>>> {
639        match self {
640            Expr::Apply { op, args, .. } => {
641                if op.name != "OP_LIST" {
642                    Err(miette!("Invalid fields op: {} for {}", op.name, self))
643                } else {
644                    let mut collected = vec![];
645                    for field in args.iter() {
646                        match field {
647                            Expr::Binding { var, .. } => collected.push(var.name.clone()),
648                            _ => return Err(miette!("Invalid field element: {}", field)),
649                        }
650                    }
651                    Ok(collected)
652                }
653            }
654            Expr::Binding { var, .. } => Ok(vec![var.name.clone()]),
655            _ => Err(miette!("Invalid fields: {}", self)),
656        }
657    }
658}
659
660pub(crate) fn compute_bounds(
661    filters: &[Expr],
662    symbols: &[Symbol],
663) -> Result<(Vec<DataValue>, Vec<DataValue>)> {
664    let mut lowers = vec![];
665    let mut uppers = vec![];
666    for current in symbols {
667        let mut cur_bound = ValueRange::default();
668        for filter in filters {
669            let nxt = filter.extract_bound(current)?;
670            cur_bound = cur_bound.merge(nxt);
671        }
672        lowers.push(cur_bound.lower);
673        uppers.push(cur_bound.upper);
674    }
675
676    Ok((lowers, uppers))
677}
678
679#[derive(Clone, Debug, Eq, PartialEq)]
680pub(crate) struct ValueRange {
681    pub(crate) lower: DataValue,
682    pub(crate) upper: DataValue,
683}
684
685impl ValueRange {
686    fn merge(self, other: Self) -> Self {
687        let lower = max(self.lower, other.lower);
688        let upper = min(self.upper, other.upper);
689        if lower > upper {
690            Self::null()
691        } else {
692            Self { lower, upper }
693        }
694    }
695    fn null() -> Self {
696        Self {
697            lower: DataValue::Bot,
698            upper: DataValue::Bot,
699        }
700    }
701    fn new(lower: DataValue, upper: DataValue) -> Self {
702        Self { lower, upper }
703    }
704    fn lower_bound(val: DataValue) -> Self {
705        Self {
706            lower: val,
707            upper: DataValue::Bot,
708        }
709    }
710    fn upper_bound(val: DataValue) -> Self {
711        Self {
712            lower: DataValue::Null,
713            upper: val,
714        }
715    }
716}
717
718impl Default for ValueRange {
719    fn default() -> Self {
720        Self {
721            lower: DataValue::Null,
722            upper: DataValue::Bot,
723        }
724    }
725}
726
/// A built-in function that expressions can apply.
///
/// Equality, `Debug` output and serialization all use `name` alone.
#[derive(Clone)]
pub struct Op {
    /// Identifier such as `OP_ADD`. Deserialization expects `OP_` followed by the
    /// uppercase form of a name that `get_op` accepts.
    pub(crate) name: &'static str,
    /// Minimum number of arguments.
    pub(crate) min_arity: usize,
    /// Presumably whether arguments beyond `min_arity` are accepted — TODO confirm
    /// where arity is enforced.
    pub(crate) vararg: bool,
    /// Implementation, called with the already-evaluated arguments in order.
    pub(crate) inner: fn(&[DataValue]) -> Result<DataValue>,
}
734
/// Used as `Arc<dyn CustomOp>`
///
/// An op supplied from outside this module. The first three methods mirror the fields
/// of [`Op`]. NOTE(review): nothing in this file consumes the trait — confirm where
/// custom ops are registered and called.
pub trait CustomOp {
    /// Name of the op.
    fn name(&self) -> &'static str;
    /// Minimum number of arguments (compare `Op::min_arity`).
    fn min_arity(&self) -> usize;
    /// Whether extra arguments are accepted (compare `Op::vararg`).
    fn vararg(&self) -> bool;
    /// Declared type of the op's result.
    fn return_type(&self) -> NullableColType;
    /// Invokes the op on already-evaluated arguments.
    fn call(&self, args: &[DataValue]) -> Result<DataValue>;
}
743
744impl serde::Serialize for &'_ Op {
745    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
746    where
747        S: Serializer,
748    {
749        serializer.serialize_str(self.name)
750    }
751}
752
753impl<'de> serde::Deserialize<'de> for &'static Op {
754    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
755    where
756        D: Deserializer<'de>,
757    {
758        deserializer.deserialize_str(OpVisitor)
759    }
760}
761
762struct OpVisitor;
763
764impl<'de> Visitor<'de> for OpVisitor {
765    type Value = &'static Op;
766
767    fn expecting(&self, formatter: &mut Formatter<'_>) -> std::fmt::Result {
768        formatter.write_str("name of the op")
769    }
770
771    fn visit_str<E>(self, v: &str) -> std::result::Result<Self::Value, E>
772    where
773        E: Error,
774    {
775        let name = v.strip_prefix("OP_").unwrap().to_ascii_lowercase();
776        get_op(&name).ok_or_else(|| E::custom(format!("op not found in serialized data: {v}")))
777    }
778}
779
780impl PartialEq for Op {
781    fn eq(&self, other: &Self) -> bool {
782        self.name == other.name
783    }
784}
785
786impl Eq for Op {}
787
788impl Debug for Op {
789    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
790        write!(f, "{}", self.name)
791    }
792}
793
794pub(crate) fn get_op(name: &str) -> Option<&'static Op> {
795    Some(match name {
796        "coalesce" => &OP_COALESCE,
797        "list" => &OP_LIST,
798        "json" => &OP_JSON,
799        "set_json_path" => &OP_SET_JSON_PATH,
800        "remove_json_path" => &OP_REMOVE_JSON_PATH,
801        "parse_json" => &OP_PARSE_JSON,
802        "dump_json" => &OP_DUMP_JSON,
803        "json_object" => &OP_JSON_OBJECT,
804        "is_json" => &OP_IS_JSON,
805        "json_to_scalar" => &OP_JSON_TO_SCALAR,
806        "add" => &OP_ADD,
807        "sub" => &OP_SUB,
808        "mul" => &OP_MUL,
809        "div" => &OP_DIV,
810        "minus" => &OP_MINUS,
811        "abs" => &OP_ABS,
812        "signum" => &OP_SIGNUM,
813        "floor" => &OP_FLOOR,
814        "ceil" => &OP_CEIL,
815        "round" => &OP_ROUND,
816        "mod" => &OP_MOD,
817        "max" => &OP_MAX,
818        "min" => &OP_MIN,
819        "pow" => &OP_POW,
820        "sqrt" => &OP_SQRT,
821        "exp" => &OP_EXP,
822        "exp2" => &OP_EXP2,
823        "ln" => &OP_LN,
824        "log2" => &OP_LOG2,
825        "log10" => &OP_LOG10,
826        "sin" => &OP_SIN,
827        "cos" => &OP_COS,
828        "tan" => &OP_TAN,
829        "asin" => &OP_ASIN,
830        "acos" => &OP_ACOS,
831        "atan" => &OP_ATAN,
832        "atan2" => &OP_ATAN2,
833        "sinh" => &OP_SINH,
834        "cosh" => &OP_COSH,
835        "tanh" => &OP_TANH,
836        "asinh" => &OP_ASINH,
837        "acosh" => &OP_ACOSH,
838        "atanh" => &OP_ATANH,
839        "eq" => &OP_EQ,
840        "neq" => &OP_NEQ,
841        "gt" => &OP_GT,
842        "ge" => &OP_GE,
843        "lt" => &OP_LT,
844        "le" => &OP_LE,
845        "or" => &OP_OR,
846        "and" => &OP_AND,
847        "negate" => &OP_NEGATE,
848        "bit_and" => &OP_BIT_AND,
849        "bit_or" => &OP_BIT_OR,
850        "bit_not" => &OP_BIT_NOT,
851        "bit_xor" => &OP_BIT_XOR,
852        "pack_bits" => &OP_PACK_BITS,
853        "unpack_bits" => &OP_UNPACK_BITS,
854        "concat" => &OP_CONCAT,
855        "str_includes" => &OP_STR_INCLUDES,
856        "lowercase" => &OP_LOWERCASE,
857        "uppercase" => &OP_UPPERCASE,
858        "trim" => &OP_TRIM,
859        "trim_start" => &OP_TRIM_START,
860        "trim_end" => &OP_TRIM_END,
861        "starts_with" => &OP_STARTS_WITH,
862        "ends_with" => &OP_ENDS_WITH,
863        "is_null" => &OP_IS_NULL,
864        "is_int" => &OP_IS_INT,
865        "is_float" => &OP_IS_FLOAT,
866        "is_num" => &OP_IS_NUM,
867        "is_string" => &OP_IS_STRING,
868        "is_list" => &OP_IS_LIST,
869        "is_bytes" => &OP_IS_BYTES,
870        "is_in" => &OP_IS_IN,
871        "is_finite" => &OP_IS_FINITE,
872        "is_infinite" => &OP_IS_INFINITE,
873        "is_nan" => &OP_IS_NAN,
874        "is_uuid" => &OP_IS_UUID,
875        "is_vec" => &OP_IS_VEC,
876        "length" => &OP_LENGTH,
877        "sorted" => &OP_SORTED,
878        "reverse" => &OP_REVERSE,
879        "append" => &OP_APPEND,
880        "prepend" => &OP_PREPEND,
881        "unicode_normalize" => &OP_UNICODE_NORMALIZE,
882        "haversine" => &OP_HAVERSINE,
883        "haversine_deg_input" => &OP_HAVERSINE_DEG_INPUT,
884        "deg_to_rad" => &OP_DEG_TO_RAD,
885        "rad_to_deg" => &OP_RAD_TO_DEG,
886        "get" => &OP_GET,
887        "maybe_get" => &OP_MAYBE_GET,
888        "chars" => &OP_CHARS,
889        "slice_string" => &OP_SLICE_STRING,
890        "from_substrings" => &OP_FROM_SUBSTRINGS,
891        "slice" => &OP_SLICE,
892        "regex_matches" => &OP_REGEX_MATCHES,
893        "regex_replace" => &OP_REGEX_REPLACE,
894        "regex_replace_all" => &OP_REGEX_REPLACE_ALL,
895        "regex_extract" => &OP_REGEX_EXTRACT,
896        "regex_extract_first" => &OP_REGEX_EXTRACT_FIRST,
897        "t2s" => &OP_T2S,
898        "encode_base64" => &OP_ENCODE_BASE64,
899        "decode_base64" => &OP_DECODE_BASE64,
900        "first" => &OP_FIRST,
901        "last" => &OP_LAST,
902        "chunks" => &OP_CHUNKS,
903        "chunks_exact" => &OP_CHUNKS_EXACT,
904        "windows" => &OP_WINDOWS,
905        "to_int" => &OP_TO_INT,
906        "to_float" => &OP_TO_FLOAT,
907        "to_string" => &OP_TO_STRING,
908        "l2_dist" => &OP_L2_DIST,
909        "l2_normalize" => &OP_L2_NORMALIZE,
910        "ip_dist" => &OP_IP_DIST,
911        "cos_dist" => &OP_COS_DIST,
912        "int_range" => &OP_INT_RANGE,
913        "rand_float" => &OP_RAND_FLOAT,
914        "rand_bernoulli" => &OP_RAND_BERNOULLI,
915        "rand_int" => &OP_RAND_INT,
916        "rand_choose" => &OP_RAND_CHOOSE,
917        "assert" => &OP_ASSERT,
918        "union" => &OP_UNION,
919        "intersection" => &OP_INTERSECTION,
920        "difference" => &OP_DIFFERENCE,
921        "to_uuid" => &OP_TO_UUID,
922        "to_bool" => &OP_TO_BOOL,
923        "to_unity" => &OP_TO_UNITY,
924        "rand_uuid_v1" => &OP_RAND_UUID_V1,
925        "rand_uuid_v4" => &OP_RAND_UUID_V4,
926        "uuid_timestamp" => &OP_UUID_TIMESTAMP,
927        "validity" => &OP_VALIDITY,
928        "now" => &OP_NOW,
929        "format_timestamp" => &OP_FORMAT_TIMESTAMP,
930        "parse_timestamp" => &OP_PARSE_TIMESTAMP,
931        "vec" => &OP_VEC,
932        "rand_vec" => &OP_RAND_VEC,
933        _ => return None,
934    })
935}
936
937impl Op {
938    pub(crate) fn post_process_args(&self, args: &mut [Expr]) {
939        if self.name.starts_with("OP_REGEX_") {
940            args[1] = Expr::Apply {
941                op: &OP_REGEX,
942                args: [args[1].clone()].into(),
943                span: args[1].span(),
944            }
945        }
946    }
947}