cozo_ce/data/
program.rs

1/*
2 * Copyright 2022, The Cozo Project Authors.
3 *
4 * This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0.
5 * If a copy of the MPL was not distributed with this file,
6 * You can obtain one at https://mozilla.org/MPL/2.0/.
7 */
8
9use std::collections::btree_map::Entry;
10use std::collections::{BTreeMap, BTreeSet};
11use std::fmt::{Debug, Display, Formatter};
12use std::sync::Arc;
13
14use miette::{bail, ensure, miette, Diagnostic, Result};
15use smallvec::SmallVec;
16use smartstring::{LazyCompact, SmartString};
17use thiserror::Error;
18
19use crate::data::aggr::Aggregation;
20use crate::data::expr::Expr;
21use crate::data::relation::StoredRelationMetadata;
22use crate::data::symb::{Symbol, PROG_ENTRY};
23use crate::data::value::{DataValue, ValidityTs};
24use crate::fixed_rule::{FixedRule, FixedRuleHandle};
25use crate::fts::FtsIndexManifest;
26use crate::parse::SourceSpan;
27use crate::query::compile::ContainedRuleMultiplicity;
28use crate::query::logical::{Disjunction, NamedFieldNotFound};
29use crate::runtime::hnsw::HnswIndexManifest;
30use crate::runtime::minhash_lsh::{LshSearch, MinHashLshIndexManifest};
31use crate::runtime::relation::{
32    AccessLevel, InputRelationHandle, InsufficientAccessLevel, RelationHandle,
33};
34use crate::runtime::temp_store::EpochStore;
35use crate::runtime::transact::SessionTx;
36
/// Assertion that can be attached to a query through its output options.
///
/// `QueryOutOptions`' `Display` impl renders the variants as `:assert none;` and
/// `:assert some;`. Each variant carries a `SourceSpan` (presumably the location
/// of the assertion in the script — confirm against the parser).
#[derive(Debug, Clone, Eq, PartialEq)]
pub(crate) enum QueryAssertion {
    /// Rendered as `:assert none;`.
    AssertNone(SourceSpan),
    /// Rendered as `:assert some;`.
    AssertSome(SourceSpan),
}
42
/// Flag stored alongside a relation operation in `QueryOutOptions::store_relation`.
///
/// When the value is `Returning`, the `Display` impl of `QueryOutOptions` emits a
/// `:returning` line before the relation operation.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub(crate) enum ReturnMutation {
    /// No `:returning` line is rendered.
    NotReturning,
    /// A `:returning` line is rendered.
    Returning,
}
48
/// Options that accompany the rules of a single query.
///
/// `Debug` for this type delegates to its `Display` rendering.
#[derive(Clone, PartialEq, Default)]
pub struct QueryOutOptions {
    /// Rendered as `:limit N;`. Combined with `offset` by `num_to_take`.
    pub limit: Option<usize>,
    /// Rendered as `:offset N;`.
    pub offset: Option<usize>,
    /// Terminate query with an error if it exceeds this many seconds.
    pub timeout: Option<f64>,
    /// Sleep after performing the query for this number of seconds. Ignored in WASM.
    pub sleep: Option<f64>,
    /// Sort keys, each rendered as an `:order` line; `SortDir::Dsc` gets a `-` prefix.
    pub sorters: Vec<(Symbol, SortDir)>,
    /// Target relation handle, the operation to apply to it, and the
    /// `:returning` flag.
    pub store_relation: Option<(InputRelationHandle, RelationOp, ReturnMutation)>,
    /// Optional `:assert none` / `:assert some` assertion.
    pub assertion: Option<QueryAssertion>,
}
61
62impl Debug for QueryOutOptions {
63    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
64        write!(f, "{self}")
65    }
66}
67
68impl Display for QueryOutOptions {
69    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
70        if let Some(l) = self.limit {
71            writeln!(f, ":limit {l};")?;
72        }
73        if let Some(l) = self.offset {
74            writeln!(f, ":offset {l};")?;
75        }
76        if let Some(l) = self.timeout {
77            writeln!(f, ":timeout {l};")?;
78        }
79        for (symb, dir) in &self.sorters {
80            write!(f, ":order ")?;
81            if *dir == SortDir::Dsc {
82                write!(f, "-")?;
83            }
84            writeln!(f, "{symb};")?;
85        }
86        if let Some((
87            InputRelationHandle {
88                name,
89                metadata: StoredRelationMetadata { keys, non_keys },
90                key_bindings,
91                dep_bindings,
92                ..
93            },
94            op,
95            return_mutation,
96        )) = &self.store_relation
97        {
98            if *return_mutation == ReturnMutation::Returning {
99                writeln!(f, ":returning")?;
100            }
101            match op {
102                RelationOp::Create => {
103                    write!(f, ":create ")?;
104                }
105                RelationOp::Replace => {
106                    write!(f, ":replace ")?;
107                }
108                RelationOp::Insert => {
109                    write!(f, ":insert ")?;
110                }
111                RelationOp::Put => {
112                    write!(f, ":put ")?;
113                }
114                RelationOp::Update => {
115                    write!(f, ":update ")?;
116                }
117                RelationOp::Rm => {
118                    write!(f, ":rm ")?;
119                }
120                RelationOp::Delete => {
121                    write!(f, ":delete ")?;
122                }
123                RelationOp::Ensure => {
124                    write!(f, ":ensure ")?;
125                }
126                RelationOp::EnsureNot => {
127                    write!(f, ":ensure_not ")?;
128                }
129            }
130            write!(f, "{name} {{")?;
131            let mut is_first = true;
132            for (col, bind) in keys.iter().zip(key_bindings) {
133                if is_first {
134                    is_first = false
135                } else {
136                    write!(f, ", ")?;
137                }
138                write!(f, "{}: {}", col.name, col.typing)?;
139                if let Some(gen) = &col.default_gen {
140                    write!(f, " default {gen}")?;
141                } else {
142                    write!(f, " = {bind}")?;
143                }
144            }
145            write!(f, " => ")?;
146            let mut is_first = true;
147            for (col, bind) in non_keys.iter().zip(dep_bindings) {
148                if is_first {
149                    is_first = false
150                } else {
151                    write!(f, ", ")?;
152                }
153                write!(f, "{}: {}", col.name, col.typing)?;
154                if let Some(gen) = &col.default_gen {
155                    write!(f, " default {gen}")?;
156                } else {
157                    write!(f, " = {bind}")?;
158                }
159            }
160            writeln!(f, "}};")?;
161        }
162
163        if let Some(a) = &self.assertion {
164            match a {
165                QueryAssertion::AssertNone(_) => {
166                    writeln!(f, ":assert none;")?;
167                }
168                QueryAssertion::AssertSome(_) => {
169                    writeln!(f, ":assert some;")?;
170                }
171            }
172        }
173
174        Ok(())
175    }
176}
177
178impl QueryOutOptions {
179    pub(crate) fn num_to_take(&self) -> Option<usize> {
180        match (self.limit, self.offset) {
181            (None, _) => None,
182            (Some(i), None) => Some(i),
183            (Some(i), Some(j)) => Some(i + j),
184        }
185    }
186}
187
/// Sort direction of one entry in `QueryOutOptions::sorters`.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum SortDir {
    /// Rendered as a plain `:order <symbol>;` line.
    Asc,
    /// Rendered with a leading `-`, i.e. `:order -<symbol>;`.
    Dsc,
}
193
/// Operation applied to the relation named in `QueryOutOptions::store_relation`.
///
/// The per-variant keywords below are the ones written by the `Display` impl of
/// `QueryOutOptions`.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum RelationOp {
    /// `:create`
    Create,
    /// `:replace`
    Replace,
    /// `:put`
    Put,
    /// `:insert`
    Insert,
    /// `:update`
    Update,
    /// `:rm`
    Rm,
    /// `:delete`
    Delete,
    /// `:ensure`
    Ensure,
    /// `:ensure_not`
    EnsureNot,
}
206
/// Generator of numbered temporary symbols.
///
/// The default value starts with `last_id == 0`; every generated symbol first
/// increments the counter, so the first symbol handed out carries the number 1.
#[derive(Default)]
pub(crate) struct TempSymbGen {
    /// Number used by the most recently generated symbol.
    last_id: u32,
}
211
212impl TempSymbGen {
213    pub(crate) fn next(&mut self, span: SourceSpan) -> Symbol {
214        self.last_id += 1;
215        Symbol::new(&format!("*{}", self.last_id) as &str, span)
216    }
217    pub(crate) fn next_ignored(&mut self, span: SourceSpan) -> Symbol {
218        self.last_id += 1;
219        Symbol::new(&format!("~{}", self.last_id) as &str, span)
220    }
221}
222
/// Value type of `InputProgram::prog`: what a rule name maps to.
#[derive(Debug, Clone)]
pub enum InputInlineRulesOrFixed {
    /// One or more inline rules sharing the same name. `first_span` and the
    /// entry-head helpers index into this vector, so it is expected to be
    /// non-empty.
    Rules { rules: Vec<InputInlineRule> },
    /// A single fixed rule application.
    Fixed { fixed: FixedRuleApply },
}
228
229impl InputInlineRulesOrFixed {
230    pub(crate) fn first_span(&self) -> SourceSpan {
231        match self {
232            InputInlineRulesOrFixed::Rules { rules, .. } => rules[0].span,
233            InputInlineRulesOrFixed::Fixed { fixed, .. } => fixed.span,
234        }
235    }
236    // pub(crate) fn used_rule(&self, rule_name: &Symbol) -> bool {
237    //     match self {
238    //         InputInlineRulesOrFixed::Rules { rules, .. } => rules
239    //             .iter()
240    //             .any(|rule| rule.body.iter().any(|atom| atom.used_rule(rule_name))),
241    //         InputInlineRulesOrFixed::Fixed { fixed, .. } => fixed.rule_args.iter().any(|arg| {
242    //             if let FixedRuleArg::InMem { name, .. } = arg {
243    //                 if name == rule_name {
244    //                     return true;
245    //                 }
246    //             }
247    //             false
248    //         }),
249    //     }
250    // }
251}
252
/// Application of a fixed rule as it appears in an input program.
///
/// The hand-written `Debug` impl prints only the handle name, the rule
/// arguments and the options.
#[derive(Clone)]
pub struct FixedRuleApply {
    /// Handle of the fixed rule; its `name` is what gets displayed.
    pub fixed_handle: FixedRuleHandle,
    /// Positional relation arguments.
    pub rule_args: Vec<FixedRuleArg>,
    /// Named options, keyed by option name.
    pub options: Arc<BTreeMap<SmartString<LazyCompact>, Expr>>,
    /// Head symbols of the rule; may be empty (see
    /// `InputProgram::get_entry_out_head`, which reports an error in that case).
    pub head: Vec<Symbol>,
    /// Stored arity value. Note that the `arity()` method does not read this
    /// field; it asks `fixed_impl` instead.
    pub arity: usize,
    /// Source span of the application.
    pub span: SourceSpan,
    /// Implementation object the `arity()` method delegates to.
    pub fixed_impl: Arc<Box<dyn FixedRule>>,
}
263
264impl FixedRuleApply {
265    pub(crate) fn arity(&self) -> Result<usize> {
266        self.fixed_impl
267            .as_ref()
268            .arity(&self.options, &self.head, self.span)
269    }
270}
271
272impl Debug for FixedRuleApply {
273    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
274        f.debug_struct("FixedRuleApply")
275            .field("name", &self.fixed_handle.name)
276            .field("rules", &self.rule_args)
277            .field("options", &self.options)
278            .finish()
279    }
280}
281
/// Fixed rule application whose in-memory arguments are named by `MagicSymbol`s.
///
/// Mirrors `FixedRuleApply` but has no `head` field, and `rule_args` holds
/// `MagicFixedRuleRuleArg`s. Its `Debug` impl prints under the struct name
/// `FixedRuleApply`.
pub(crate) struct MagicFixedRuleApply {
    /// Handle of the fixed rule; its `name` is used in error messages.
    pub(crate) fixed_handle: FixedRuleHandle,
    /// Positional relation arguments, accessed through `relation(idx)`.
    pub(crate) rule_args: Vec<MagicFixedRuleRuleArg>,
    /// Named options, keyed by option name.
    pub(crate) options: Arc<BTreeMap<SmartString<LazyCompact>, Expr>>,
    /// Source span of the application; used to label missing-argument errors.
    pub(crate) span: SourceSpan,
    /// Arity reported by `MagicRulesOrFixed::arity` for this application.
    pub(crate) arity: usize,
    /// Implementation object of the fixed rule.
    pub(crate) fixed_impl: Arc<Box<dyn FixedRule>>,
}
290
/// Diagnostic (code `fixed_rule::arg_not_found`) for a required named option
/// that is missing from a fixed rule application.
#[derive(Error, Diagnostic, Debug)]
#[error("Cannot find a required named option '{name}' for '{rule_name}'")]
#[diagnostic(code(fixed_rule::arg_not_found))]
pub(crate) struct FixedRuleOptionNotFoundError {
    // Name of the missing option, interpolated into the message.
    pub(crate) name: String,
    // Span highlighted in the rendered diagnostic.
    #[label]
    pub(crate) span: SourceSpan,
    // Name of the fixed rule, interpolated into the message.
    pub(crate) rule_name: String,
}
300
/// Diagnostic (code `fixed_rule::arg_wrong`) for a named option of a fixed rule
/// that has an unacceptable value.
#[derive(Error, Diagnostic, Debug)]
#[error("Wrong value for option '{name}' of '{rule_name}'")]
#[diagnostic(code(fixed_rule::arg_wrong))]
pub(crate) struct WrongFixedRuleOptionError {
    // Name of the offending option, interpolated into the message.
    pub(crate) name: String,
    // Span highlighted in the rendered diagnostic.
    #[label]
    pub(crate) span: SourceSpan,
    // Name of the fixed rule, interpolated into the message.
    pub(crate) rule_name: String,
    // Free-form help text shown with the diagnostic.
    #[help]
    pub(crate) help: String,
}
312
313impl MagicFixedRuleApply {
314    #[allow(dead_code)]
315    pub(crate) fn relation_with_min_len(
316        &self,
317        idx: usize,
318        len: usize,
319        tx: &SessionTx<'_>,
320        stores: &BTreeMap<MagicSymbol, EpochStore>,
321    ) -> Result<&MagicFixedRuleRuleArg> {
322        #[derive(Error, Diagnostic, Debug)]
323        #[error("Input relation to fixed rule has insufficient arity")]
324        #[diagnostic(help("Arity should be at least {0} but is {1}"))]
325        #[diagnostic(code(fixed_rule::input_relation_bad_arity))]
326        struct InputRelationArityError(usize, usize, #[label] SourceSpan);
327
328        let rel = self.relation(idx)?;
329        let arity = rel.arity(tx, stores)?;
330        ensure!(
331            arity >= len,
332            InputRelationArityError(len, arity, rel.span())
333        );
334        Ok(rel)
335    }
336    pub(crate) fn relations_count(&self) -> usize {
337        self.rule_args.len()
338    }
339    pub(crate) fn relation(&self, idx: usize) -> Result<&MagicFixedRuleRuleArg> {
340        #[derive(Error, Diagnostic, Debug)]
341        #[error("Cannot find a required positional argument at index {idx} for '{rule_name}'")]
342        #[diagnostic(code(fixed_rule::not_enough_args))]
343        pub(crate) struct FixedRuleNotEnoughRelationError {
344            idx: usize,
345            #[label]
346            span: SourceSpan,
347            rule_name: String,
348        }
349
350        Ok(self
351            .rule_args
352            .get(idx)
353            .ok_or_else(|| FixedRuleNotEnoughRelationError {
354                idx,
355                span: self.span,
356                rule_name: self.fixed_handle.name.to_string(),
357            })?)
358    }
359}
360
361impl Debug for MagicFixedRuleApply {
362    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
363        f.debug_struct("FixedRuleApply")
364            .field("name", &self.fixed_handle.name)
365            .field("rules", &self.rule_args)
366            .field("options", &self.options)
367            .finish()
368    }
369}
370
/// Positional relation argument of a fixed rule in an input program.
///
/// `Debug` for this type delegates to `Display`; the per-variant rendering is
/// noted below.
#[derive(Clone)]
pub enum FixedRuleArg {
    /// Displayed as the name followed by the list of bindings.
    InMem {
        name: Symbol,
        bindings: Vec<Symbol>,
        span: SourceSpan,
    },
    /// Displayed as `:` plus the name, followed by the list of bindings.
    /// `valid_at` is not part of the displayed text.
    Stored {
        name: Symbol,
        bindings: Vec<Symbol>,
        valid_at: Option<ValidityTs>,
        span: SourceSpan,
    },
    /// Displayed as `*` followed by a struct-style listing of the bindings,
    /// which are keyed by name here rather than positional. `valid_at` is not
    /// part of the displayed text.
    NamedStored {
        name: Symbol,
        bindings: BTreeMap<SmartString<LazyCompact>, Symbol>,
        valid_at: Option<ValidityTs>,
        span: SourceSpan,
    },
}
391
392impl Debug for FixedRuleArg {
393    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
394        write!(f, "{self}")
395    }
396}
397
398impl Display for FixedRuleArg {
399    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
400        match self {
401            FixedRuleArg::InMem { name, bindings, .. } => {
402                write!(f, "{name}")?;
403                f.debug_list().entries(bindings).finish()?;
404            }
405            FixedRuleArg::Stored { name, bindings, .. } => {
406                write!(f, ":{name}")?;
407                f.debug_list().entries(bindings).finish()?;
408            }
409            FixedRuleArg::NamedStored { name, bindings, .. } => {
410                write!(f, "*")?;
411                let mut sf = f.debug_struct(name);
412                for (k, v) in bindings {
413                    sf.field(k, v);
414                }
415                sf.finish()?;
416            }
417        }
418        Ok(())
419    }
420}
421
/// Positional relation argument of a `MagicFixedRuleApply`.
///
/// Unlike `FixedRuleArg` there is no named-binding variant, and the in-memory
/// variant is named by a `MagicSymbol`.
#[derive(Debug)]
pub(crate) enum MagicFixedRuleRuleArg {
    /// Argument named by a `MagicSymbol`.
    InMem {
        name: MagicSymbol,
        bindings: Vec<Symbol>,
        span: SourceSpan,
    },
    /// Argument named by a plain `Symbol`, with an optional validity timestamp.
    Stored {
        name: Symbol,
        bindings: Vec<Symbol>,
        valid_at: Option<ValidityTs>,
        span: SourceSpan,
    },
}
436
437impl MagicFixedRuleRuleArg {
438    #[allow(dead_code)]
439    pub(crate) fn bindings(&self) -> &[Symbol] {
440        match self {
441            MagicFixedRuleRuleArg::InMem { bindings, .. }
442            | MagicFixedRuleRuleArg::Stored { bindings, .. } => bindings,
443        }
444    }
445    #[allow(dead_code)]
446    pub(crate) fn span(&self) -> SourceSpan {
447        match self {
448            MagicFixedRuleRuleArg::InMem { span, .. }
449            | MagicFixedRuleRuleArg::Stored { span, .. } => *span,
450        }
451    }
452    pub(crate) fn get_binding_map(&self, starting: usize) -> BTreeMap<Symbol, usize> {
453        let bindings = match self {
454            MagicFixedRuleRuleArg::InMem { bindings, .. }
455            | MagicFixedRuleRuleArg::Stored { bindings, .. } => bindings,
456        };
457        bindings
458            .iter()
459            .enumerate()
460            .map(|(idx, symb)| (symb.clone(), idx + starting))
461            .collect()
462    }
463}
464
/// This is a single query, as you'd find between `{}` in a chained query script or with no `{}` in a single query script.
///
/// Its `Display` impl prints every rule followed by the output options.
#[derive(Debug, Clone)]
pub struct InputProgram {
    /// A mapping of names to rules.  The entry rule must be named `?`.
    ///
    /// Ex: `?` in `?[a, b] := ...`
    pub prog: BTreeMap<Symbol, InputInlineRulesOrFixed>,
    /// Output options belonging to this query; returned unchanged by
    /// `into_normalized_program`.
    pub out_opts: QueryOutOptions,
    /// Copied verbatim into the `NormalFormProgram` built by
    /// `into_normalized_program`.
    pub disable_magic_rewrite: bool,
}
475
476impl Display for InputProgram {
477    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
478        for (name, rules) in &self.prog {
479            match rules {
480                InputInlineRulesOrFixed::Rules { rules, .. } => {
481                    for InputInlineRule {
482                        head, aggr, body, ..
483                    } in rules
484                    {
485                        write!(f, "{name}[")?;
486
487                        for (i, (h, a)) in head.iter().zip(aggr).enumerate() {
488                            if i > 0 {
489                                write!(f, ", ")?;
490                            }
491                            if let Some((aggr, aggr_args)) = a {
492                                write!(f, "{}({}", aggr.name, h)?;
493                                for aga in aggr_args {
494                                    write!(f, ", {aga}")?;
495                                }
496                                write!(f, ")")?;
497                            } else {
498                                write!(f, "{h}")?;
499                            }
500                        }
501                        write!(f, "] := ")?;
502                        for (i, atom) in body.iter().enumerate() {
503                            if i > 0 {
504                                write!(f, ", ")?;
505                            }
506                            write!(f, "{atom}")?;
507                        }
508                        writeln!(f, ";")?;
509                    }
510                }
511                InputInlineRulesOrFixed::Fixed {
512                    fixed:
513                        FixedRuleApply {
514                            fixed_handle: handle,
515                            rule_args,
516                            options,
517                            head,
518                            ..
519                        },
520                } => {
521                    write!(f, "{name}")?;
522                    f.debug_list().entries(head).finish()?;
523                    write!(f, " <~ ")?;
524                    write!(f, "{}(", handle.name)?;
525                    let mut first = true;
526                    for rule_arg in rule_args {
527                        if first {
528                            first = false;
529                        } else {
530                            write!(f, ", ")?;
531                        }
532                        write!(f, "{rule_arg}")?;
533                    }
534                    for (k, v) in options.as_ref() {
535                        if first {
536                            first = false;
537                        } else {
538                            write!(f, ", ")?;
539                        }
540                        write!(f, "{k}: {v}")?;
541                    }
542                    writeln!(f, ");")?;
543                }
544            }
545        }
546        write!(f, "{}", self.out_opts)?;
547        Ok(())
548    }
549}
550
/// Diagnostic (code `parser::no_entry_head`) raised by
/// `InputProgram::get_entry_out_head` when the entry is a fixed rule whose head
/// is empty. The span is the entry's first span.
#[derive(Debug, Diagnostic, Error)]
#[error("Entry head not found")]
#[diagnostic(code(parser::no_entry_head))]
#[diagnostic(help("You need to explicitly name your entry arguments"))]
struct EntryHeadNotExplicitlyDefinedError(#[label] SourceSpan);
556
/// Diagnostic (code `parser::no_entry`) returned by the `InputProgram` entry
/// helpers when the program map has no rule under the entry name.
#[derive(Debug, Diagnostic, Error)]
#[error("Program has no entry")]
#[diagnostic(code(parser::no_entry))]
#[diagnostic(help("You need to have one rule named '?'"))]
pub(crate) struct NoEntryError;
562
563impl InputProgram {
564    pub(crate) fn needs_write_lock(&self) -> Option<SmartString<LazyCompact>> {
565        if let Some((h, _, _)) = &self.out_opts.store_relation {
566            if !h.name.name.starts_with('_') {
567                Some(h.name.name.clone())
568            } else {
569                None
570            }
571        } else {
572            None
573        }
574    }
575
576    pub(crate) fn get_entry_arity(&self) -> Result<usize> {
577        if let Some(entry) = self.prog.get(&Symbol::new(PROG_ENTRY, SourceSpan(0, 0))) {
578            return match entry {
579                InputInlineRulesOrFixed::Rules { rules } => Ok(rules.last().unwrap().head.len()),
580                InputInlineRulesOrFixed::Fixed { fixed } => fixed.arity(),
581            };
582        }
583
584        Err(NoEntryError.into())
585    }
586    pub(crate) fn get_entry_out_head_or_default(&self) -> Result<Vec<Symbol>> {
587        match self.get_entry_out_head() {
588            Ok(r) => Ok(r),
589            Err(_) => {
590                let arity = self.get_entry_arity()?;
591                Ok((0..arity)
592                    .map(|i| Symbol::new(format!("_{i}"), SourceSpan(0, 0)))
593                    .collect())
594            }
595        }
596    }
597    pub(crate) fn get_entry_out_head(&self) -> Result<Vec<Symbol>> {
598        if let Some(entry) = self.prog.get(&Symbol::new(PROG_ENTRY, SourceSpan(0, 0))) {
599            return match entry {
600                InputInlineRulesOrFixed::Rules { rules } => {
601                    let head = &rules.last().unwrap().head;
602                    let mut ret = Vec::with_capacity(head.len());
603                    let aggrs = &rules.last().unwrap().aggr;
604                    for (symb, aggr) in head.iter().zip(aggrs.iter()) {
605                        if let Some((aggr, _)) = aggr {
606                            ret.push(Symbol::new(
607                                format!(
608                                    "{}({})",
609                                    aggr.name
610                                        .strip_prefix("AGGR_")
611                                        .unwrap()
612                                        .to_ascii_lowercase(),
613                                    symb
614                                ),
615                                symb.span,
616                            ))
617                        } else {
618                            ret.push(symb.clone())
619                        }
620                    }
621                    Ok(ret)
622                }
623                InputInlineRulesOrFixed::Fixed { fixed } => {
624                    if fixed.head.is_empty() {
625                        Err(EntryHeadNotExplicitlyDefinedError(entry.first_span()).into())
626                    } else {
627                        Ok(fixed.head.to_vec())
628                    }
629                }
630            };
631        }
632
633        Err(NoEntryError.into())
634    }
    /// Consumes the program and converts every inline rule into normal form.
    ///
    /// For each inline rule the body is wrapped in a conjunction and turned
    /// into disjunctive normal form; every resulting conjunction becomes its own
    /// `NormalFormInlineRule`, which is then passed through
    /// `convert_to_well_ordered_rule`. Repeated symbols in a rule head are
    /// replaced by fresh symbols that are unified with the first occurrence.
    /// Fixed rules are carried over unchanged.
    ///
    /// Returns the normalized program together with the original output
    /// options.
    ///
    /// # Errors
    ///
    /// Propagates errors from `disjunctive_normal_form` and
    /// `convert_to_well_ordered_rule`.
    pub(crate) fn into_normalized_program(
        self,
        tx: &SessionTx<'_>,
    ) -> Result<(NormalFormProgram, QueryOutOptions)> {
        let mut prog: BTreeMap<Symbol, _> = Default::default();
        for (k, rules_or_fixed) in self.prog {
            match rules_or_fixed {
                InputInlineRulesOrFixed::Rules { rules } => {
                    let mut collected_rules = vec![];
                    for rule in rules {
                        // Fresh-symbol numbering restarts for every rule; the
                        // first generated symbol is `***0`.
                        let mut counter = -1;
                        let mut gen_symb = |span| {
                            counter += 1;
                            Symbol::new(&format!("***{counter}") as &str, span)
                        };
                        let normalized_body = InputAtom::Conjunction {
                            inner: rule.body,
                            span: rule.span,
                        }
                        .disjunctive_normal_form(tx)?;
                        let mut new_head = Vec::with_capacity(rule.head.len());
                        // Maps each head symbol to the fresh symbols that
                        // replace its second and later occurrences.
                        let mut seen: BTreeMap<&Symbol, Vec<Symbol>> = BTreeMap::default();
                        for symb in rule.head.iter() {
                            match seen.entry(symb) {
                                // First occurrence: keep the symbol as is.
                                Entry::Vacant(e) => {
                                    e.insert(vec![]);
                                    new_head.push(symb.clone());
                                }
                                // Repeated occurrence: substitute a fresh symbol.
                                Entry::Occupied(mut e) => {
                                    let new_symb = gen_symb(symb.span);
                                    e.get_mut().push(new_symb.clone());
                                    new_head.push(new_symb);
                                }
                            }
                        }
                        // One output rule per conjunction of the normal form.
                        for conj in normalized_body.inner {
                            let mut body = conj.0;
                            // Tie every fresh head symbol back to the symbol it
                            // stands in for.
                            for (old_symb, new_symbs) in seen.iter() {
                                for new_symb in new_symbs.iter() {
                                    body.push(NormalFormAtom::Unification(Unification {
                                        binding: new_symb.clone(),
                                        expr: Expr::Binding {
                                            var: (*old_symb).clone(),
                                            tuple_pos: None,
                                        },
                                        one_many_unif: false,
                                        span: new_symb.span,
                                    }))
                                }
                            }
                            let normalized_rule = NormalFormInlineRule {
                                head: new_head.clone(),
                                aggr: rule.aggr.clone(),
                                body,
                            };
                            collected_rules.push(normalized_rule.convert_to_well_ordered_rule()?);
                        }
                    }
                    prog.insert(
                        k.clone(),
                        NormalFormRulesOrFixed::Rules {
                            rules: collected_rules,
                        },
                    );
                }
                // Fixed rules need no normalization.
                InputInlineRulesOrFixed::Fixed { fixed } => {
                    prog.insert(k.clone(), NormalFormRulesOrFixed::Fixed { fixed });
                }
            }
        }
        Ok((
            NormalFormProgram {
                prog,
                disable_magic_rewrite: self.disable_magic_rewrite,
            },
            self.out_opts,
        ))
    }
713}
714
/// Newtype around an ordered list of `NormalFormProgram`s (presumably one per
/// stratum, as the name suggests — the ordering convention is defined by the
/// code that builds it).
#[derive(Debug)]
pub(crate) struct StratifiedNormalFormProgram(pub(crate) Vec<NormalFormProgram>);
717
/// Value type of `NormalFormProgram::prog`: either normalized inline rules or a
/// fixed rule application carried over unchanged from the input program.
#[derive(Debug)]
pub(crate) enum NormalFormRulesOrFixed {
    /// Inline rules in normal form.
    Rules { rules: Vec<NormalFormInlineRule> },
    /// A fixed rule application.
    Fixed { fixed: FixedRuleApply },
}
723
724impl NormalFormRulesOrFixed {
725    pub(crate) fn rules(&self) -> Option<&[NormalFormInlineRule]> {
726        match self {
727            NormalFormRulesOrFixed::Rules { rules: r } => Some(r),
728            NormalFormRulesOrFixed::Fixed { fixed: _ } => None,
729        }
730    }
731}
732
/// Program whose inline rules have been normalized, as produced by
/// `InputProgram::into_normalized_program`.
#[derive(Debug, Default)]
pub(crate) struct NormalFormProgram {
    /// Rules keyed by rule name.
    pub(crate) prog: BTreeMap<Symbol, NormalFormRulesOrFixed>,
    /// Copied from `InputProgram::disable_magic_rewrite`.
    pub(crate) disable_magic_rewrite: bool,
}
738
/// Newtype around an ordered list of `MagicProgram`s (presumably one per
/// stratum, as the name suggests — the ordering convention is defined by the
/// code that builds it).
#[derive(Debug)]
pub(crate) struct StratifiedMagicProgram(pub(crate) Vec<MagicProgram>);
741
/// Value type of `MagicProgram::prog`.
///
/// The `Default` value is `Rules` with an empty rule list.
#[derive(Debug)]
pub(crate) enum MagicRulesOrFixed {
    /// Inline rules; may be extended in place through `mut_rules`.
    Rules { rules: Vec<MagicInlineRule> },
    /// A fixed rule application.
    Fixed { fixed: MagicFixedRuleApply },
}
747
748impl Default for MagicRulesOrFixed {
749    fn default() -> Self {
750        Self::Rules { rules: vec![] }
751    }
752}
753
754impl MagicRulesOrFixed {
755    pub(crate) fn arity(&self) -> Result<usize> {
756        Ok(match self {
757            MagicRulesOrFixed::Rules { rules } => rules.first().unwrap().head.len(),
758            MagicRulesOrFixed::Fixed { fixed } => fixed.arity,
759        })
760    }
761    pub(crate) fn mut_rules(&mut self) -> Option<&mut Vec<MagicInlineRule>> {
762        match self {
763            MagicRulesOrFixed::Rules { rules } => Some(rules),
764            MagicRulesOrFixed::Fixed { fixed: _ } => None,
765        }
766    }
767}
768
/// Program whose rules are keyed by `MagicSymbol` rather than plain `Symbol`.
#[derive(Debug)]
pub(crate) struct MagicProgram {
    /// Rules keyed by magic symbol.
    pub(crate) prog: BTreeMap<MagicSymbol, MagicRulesOrFixed>,
}
773
/// Rule name used in magic programs: a plain symbol, optionally decorated with
/// an adornment (a list of booleans) and, for `Sup`, two indices.
///
/// `Ord`/`PartialOrd` are derived, so the ordering depends on the declaration
/// order of the variants and fields below. `Debug` (and `Display`, which
/// delegates to it) renders the variants as noted; adornment entries are
/// written as `b` for `true` and `f` for `false`.
#[derive(Clone, Ord, PartialOrd, Eq, PartialEq)]
pub(crate) enum MagicSymbol {
    /// Undecorated symbol; rendered as just the name.
    Muggle {
        inner: Symbol,
    },
    /// Rendered as `<name>|M` followed by the adornment letters.
    Magic {
        inner: Symbol,
        adornment: SmallVec<[bool; 8]>,
    },
    /// Rendered as `<name>|I` followed by the adornment letters.
    Input {
        inner: Symbol,
        adornment: SmallVec<[bool; 8]>,
    },
    /// Rendered as `<name>|S.<rule_idx>.<sup_idx>` followed by the adornment
    /// letters.
    Sup {
        inner: Symbol,
        adornment: SmallVec<[bool; 8]>,
        rule_idx: u16,
        sup_idx: u16,
    },
}
794
795impl MagicSymbol {
796    pub(crate) fn symbol(&self) -> &Symbol {
797        match self {
798            MagicSymbol::Muggle { inner, .. }
799            | MagicSymbol::Magic { inner, .. }
800            | MagicSymbol::Input { inner, .. }
801            | MagicSymbol::Sup { inner, .. } => inner,
802        }
803    }
804}
805
806impl Display for MagicSymbol {
807    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
808        write!(f, "{self:?}")
809    }
810}
811
812impl Debug for MagicSymbol {
813    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
814        match self {
815            MagicSymbol::Muggle { inner } => write!(f, "{}", inner.name),
816            MagicSymbol::Magic { inner, adornment } => {
817                write!(f, "{}|M", inner.name)?;
818                for b in adornment {
819                    if *b {
820                        write!(f, "b")?
821                    } else {
822                        write!(f, "f")?
823                    }
824                }
825                Ok(())
826            }
827            MagicSymbol::Input { inner, adornment } => {
828                write!(f, "{}|I", inner.name)?;
829                for b in adornment {
830                    if *b {
831                        write!(f, "b")?
832                    } else {
833                        write!(f, "f")?
834                    }
835                }
836                Ok(())
837            }
838            MagicSymbol::Sup {
839                inner,
840                adornment,
841                rule_idx,
842                sup_idx,
843            } => {
844                write!(f, "{}|S.{}.{}", inner.name, rule_idx, sup_idx)?;
845                for b in adornment {
846                    if *b {
847                        write!(f, "b")?
848                    } else {
849                        write!(f, "f")?
850                    }
851                }
852                Ok(())
853            }
854        }
855    }
856}
857
858impl MagicSymbol {
859    pub(crate) fn as_plain_symbol(&self) -> &Symbol {
860        match self {
861            MagicSymbol::Muggle { inner, .. }
862            | MagicSymbol::Magic { inner, .. }
863            | MagicSymbol::Input { inner, .. }
864            | MagicSymbol::Sup { inner, .. } => inner,
865        }
866    }
867    pub(crate) fn magic_adornment(&self) -> &[bool] {
868        match self {
869            MagicSymbol::Muggle { .. } => &[],
870            MagicSymbol::Magic { adornment, .. }
871            | MagicSymbol::Input { adornment, .. }
872            | MagicSymbol::Sup { adornment, .. } => adornment,
873        }
874    }
875    pub(crate) fn has_bound_adornment(&self) -> bool {
876        self.magic_adornment().iter().any(|b| *b)
877    }
878    pub(crate) fn is_prog_entry(&self) -> bool {
879        if let MagicSymbol::Muggle { inner } = self {
880            inner.is_prog_entry()
881        } else {
882            false
883        }
884    }
885}
886
/// One inline rule of an input program, displayed as `name[head] := body;`.
#[derive(Debug, Clone)]
pub struct InputInlineRule {
    /// Head symbols, in order.
    pub head: Vec<Symbol>,
    /// Aggregation (with its extra arguments) per head position; `None` for a
    /// position that is not aggregated. Zipped with `head` wherever it is read
    /// in this file.
    pub aggr: Vec<Option<(Aggregation, Vec<DataValue>)>>,
    /// Body atoms; treated as a conjunction by `into_normalized_program`.
    pub body: Vec<InputAtom>,
    /// Source span of the rule.
    pub span: SourceSpan,
}
894
/// Inline rule after normalization: same head/aggregation layout as
/// `InputInlineRule`, with a body of `NormalFormAtom`s.
#[derive(Debug)]
pub(crate) struct NormalFormInlineRule {
    /// Head symbols; repeated input symbols have been replaced by fresh ones.
    pub(crate) head: Vec<Symbol>,
    /// Aggregation per head position, cloned from the input rule.
    pub(crate) aggr: Vec<Option<(Aggregation, Vec<DataValue>)>>,
    /// Atoms of one conjunction of the normalized body.
    pub(crate) body: Vec<NormalFormAtom>,
}
901
/// Inline rule of a magic program: same head/aggregation layout as
/// `NormalFormInlineRule`, with a body of `MagicAtom`s.
#[derive(Debug)]
pub(crate) struct MagicInlineRule {
    /// Head symbols; their count is the arity reported by
    /// `MagicRulesOrFixed::arity`.
    pub(crate) head: Vec<Symbol>,
    /// Aggregation per head position.
    pub(crate) aggr: Vec<Option<(Aggregation, Vec<DataValue>)>>,
    /// Body atoms.
    pub(crate) body: Vec<MagicAtom>,
}
908
909impl MagicInlineRule {
910    pub(crate) fn contained_rules(&self) -> BTreeMap<MagicSymbol, ContainedRuleMultiplicity> {
911        let mut coll = BTreeMap::new();
912        for atom in self.body.iter() {
913            match atom {
914                MagicAtom::Rule(rule) | MagicAtom::NegatedRule(rule) => {
915                    match coll.entry(rule.name.clone()) {
916                        Entry::Vacant(ent) => {
917                            ent.insert(ContainedRuleMultiplicity::One);
918                        }
919                        Entry::Occupied(mut ent) => {
920                            *ent.get_mut() = ContainedRuleMultiplicity::Many;
921                        }
922                    }
923                }
924                _ => {}
925            }
926        }
927        coll
928    }
929}
930
/// An atom in the body of an [`InputInlineRule`].
///
/// `Debug` and `Display` are implemented by hand further down in this file.
931#[derive(Clone)]
932pub enum InputAtom {
    /// Application of a rule to positional arguments; displayed as `name[args]`.
933    Rule {
934        inner: InputRuleApplyAtom,
935    },
    /// Application of a relation with named fields; displayed with a leading `*`.
936    NamedFieldRelation {
937        inner: InputNamedFieldRelationApplyAtom,
938    },
    /// Application of a relation to positional arguments; displayed as `:name[args]`.
939    Relation {
940        inner: InputRelationApplyAtom,
941    },
    /// A bare expression used as an atom.
942    Predicate {
943        inner: Expr,
944    },
    /// Negation of another atom; displayed as `not <atom>`.
945    Negation {
946        inner: Box<InputAtom>,
947        span: SourceSpan,
948    },
    /// Atoms joined with ` and ` when displayed.
949    Conjunction {
950        inner: Vec<InputAtom>,
951        span: SourceSpan,
952    },
    /// Atoms joined with ` or ` when displayed.
953    Disjunction {
954        inner: Vec<InputAtom>,
955        span: SourceSpan,
956    },
957    /// `x = y` or `x in y`
958    Unification {
959        inner: Unification,
960    },
    /// Search through a named index of a relation; displayed as
    /// `~relation:index{bindings | parameters}`.
961    Search {
962        inner: SearchInput,
963    },
964}
965
/// An index search atom as it appears in the input program.
///
/// `SearchInput::normalize` looks `index` up among the HNSW, FTS and LSH indices of
/// `relation` (in that order) and turns the atom into normal-form atoms.
966#[derive(Clone)]
967pub struct SearchInput {
    /// Name of the relation being searched.
968    pub relation: Symbol,
    /// Name of the index on `relation` to search through.
969    pub index: Symbol,
    /// Expressions keyed by column name of the relation.
970    pub bindings: BTreeMap<SmartString<LazyCompact>, Expr>,
    /// Search parameters keyed by name (e.g. `query`, `k`, `filter`).
971    pub parameters: BTreeMap<SmartString<LazyCompact>, Expr>,
    /// Source span used to label diagnostics about this atom.
972    pub span: SourceSpan,
973}
974
/// A normalized search through an HNSW index, as built by `SearchInput::normalize_hnsw`.
975#[derive(Clone, Debug)]
976pub(crate) struct HnswSearch {
    /// Handle of the relation the index belongs to.
977    pub(crate) base_handle: RelationHandle,
    /// Handle of the index itself.
978    pub(crate) idx_handle: RelationHandle,
    /// Manifest stored alongside the index.
979    pub(crate) manifest: HnswIndexManifest,
    /// One symbol per column of the base relation: keys first, then non-keys.
980    pub(crate) bindings: Vec<Symbol>,
    /// The `k` parameter; validated to be a positive integer.
981    pub(crate) k: usize,
    /// The `ef` parameter; validated to be a positive integer.
982    pub(crate) ef: usize,
    /// Symbol holding the `query` parameter.
983    pub(crate) query: Symbol,
    /// Symbol for the optional `bind_field` parameter.
984    pub(crate) bind_field: Option<Symbol>,
    /// Symbol for the optional `bind_field_idx` parameter.
985    pub(crate) bind_field_idx: Option<Symbol>,
    /// Symbol for the optional `bind_distance` parameter.
986    pub(crate) bind_distance: Option<Symbol>,
    /// Symbol for the optional `bind_vector` parameter.
987    pub(crate) bind_vector: Option<Symbol>,
    /// The optional `radius` parameter; validated to be a positive float when given.
988    pub(crate) radius: Option<f64>,
    /// The optional `filter` parameter, kept as an unevaluated expression.
989    pub(crate) filter: Option<Expr>,
    /// Source span of the originating search atom.
990    pub(crate) span: SourceSpan,
991}
992
/// Scoring scheme selected by the `score_kind` parameter of a full-text search.
///
/// `SearchInput::normalize_fts` falls back to `TfIdf` when the parameter is absent.
993#[derive(Copy, Clone, Debug, PartialEq, Eq)]
994pub(crate) enum FtsScoreKind {
    /// Selected by the string `"tf_idf"`.
995    TfIdf,
    /// Selected by the string `"tf"`.
996    Tf,
997}
998
/// A normalized search through a full-text index, as built by `SearchInput::normalize_fts`.
999#[derive(Clone, Debug)]
1000pub(crate) struct FtsSearch {
    /// Handle of the relation the index belongs to.
1001    pub(crate) base_handle: RelationHandle,
    /// Handle of the index itself.
1002    pub(crate) idx_handle: RelationHandle,
    /// Manifest stored alongside the index.
1003    pub(crate) manifest: FtsIndexManifest,
    /// One symbol per column of the base relation: keys first, then non-keys.
1004    pub(crate) bindings: Vec<Symbol>,
    /// The `k` parameter; validated to be a positive integer.
1005    pub(crate) k: usize,
1006    // pub(crate) k1: f64,
1007    // pub(crate) b: f64,
    /// Symbol holding the `query` parameter.
1008    pub(crate) query: Symbol,
    /// Scoring scheme from the `score_kind` parameter.
1009    pub(crate) score_kind: FtsScoreKind,
    /// Symbol for the optional `bind_score` parameter.
1010    pub(crate) bind_score: Option<Symbol>,
1011    // pub(crate) lax_mode: bool,
    /// The optional `filter` parameter, kept as an unevaluated expression.
1012    pub(crate) filter: Option<Expr>,
    /// Source span of the originating search atom.
1013    pub(crate) span: SourceSpan,
1014}
1015
1016impl HnswSearch {
1017    pub(crate) fn all_bindings(&self) -> impl Iterator<Item = &Symbol> {
1018        self.bindings
1019            .iter()
1020            .chain(self.bind_field.iter())
1021            .chain(self.bind_field_idx.iter())
1022            .chain(self.bind_distance.iter())
1023            .chain(self.bind_vector.iter())
1024    }
1025}
1026
1027impl FtsSearch {
1028    pub(crate) fn all_bindings(&self) -> impl Iterator<Item = &Symbol> {
1029        self.bindings.iter().chain(self.bind_score.iter())
1030    }
1031}
1032
1033impl SearchInput {
1034    fn normalize_lsh(
1035        mut self,
1036        base_handle: RelationHandle,
1037        idx_handle: RelationHandle,
1038        manifest: MinHashLshIndexManifest,
1039        gen: &mut TempSymbGen,
1040    ) -> Result<Disjunction> {
1041        let mut conj = Vec::with_capacity(self.bindings.len() + 8);
1042        let mut bindings = Vec::with_capacity(self.bindings.len());
1043        let mut seen_variables = BTreeSet::new();
1044
1045        for col in base_handle
1046            .metadata
1047            .keys
1048            .iter()
1049            .chain(base_handle.metadata.non_keys.iter())
1050        {
1051            if let Some(arg) = self.bindings.remove(&col.name) {
1052                match arg {
1053                    Expr::Binding { var, .. } => {
1054                        if var.is_ignored_symbol() {
1055                            bindings.push(gen.next_ignored(var.span));
1056                        } else if seen_variables.insert(var.clone()) {
1057                            bindings.push(var);
1058                        } else {
1059                            let span = var.span;
1060                            let dup = gen.next(span);
1061                            let unif = NormalFormAtom::Unification(Unification {
1062                                binding: dup.clone(),
1063                                expr: Expr::Binding {
1064                                    var,
1065                                    tuple_pos: None,
1066                                },
1067                                one_many_unif: false,
1068                                span,
1069                            });
1070                            conj.push(unif);
1071                            bindings.push(dup);
1072                        }
1073                    }
1074                    expr => {
1075                        let span = expr.span();
1076                        let kw = gen.next(span);
1077                        bindings.push(kw.clone());
1078                        let unif = NormalFormAtom::Unification(Unification {
1079                            binding: kw,
1080                            expr,
1081                            one_many_unif: false,
1082                            span,
1083                        });
1084                        conj.push(unif)
1085                    }
1086                }
1087            } else {
1088                bindings.push(gen.next_ignored(self.span));
1089            }
1090        }
1091
1092        if let Some((name, _)) = self.bindings.pop_first() {
1093            bail!(NamedFieldNotFound(
1094                self.relation.name.to_string(),
1095                name.to_string(),
1096                self.span
1097            ));
1098        }
1099
1100        #[derive(Debug, Error, Diagnostic)]
1101        #[error("Field `{0}` is required for LSH search")]
1102        #[diagnostic(code(parser::hnsw_query_required))]
1103        struct LshRequiredMissing(String, #[label] SourceSpan);
1104
1105        #[derive(Debug, Error, Diagnostic)]
1106        #[error("Expected a list of keys for LSH search")]
1107        #[diagnostic(code(parser::expected_list_for_lsh_keys))]
1108        struct ExpectedListForLshKeys(#[label] SourceSpan);
1109
1110        #[derive(Debug, Error, Diagnostic)]
1111        #[error("Wrong arity for LSH keys, expected {1}, got {2}")]
1112        #[diagnostic(code(parser::wrong_arity_for_lsh_keys))]
1113        struct WrongArityForKeys(#[label] SourceSpan, usize, usize);
1114
1115        let query = match self
1116            .parameters
1117            .remove("query")
1118            .ok_or_else(|| miette!(LshRequiredMissing("query".to_string(), self.span)))?
1119        {
1120            Expr::Binding { var, .. } => var,
1121            expr => {
1122                let span = expr.span();
1123                let kw = gen.next(span);
1124                let unif = NormalFormAtom::Unification(Unification {
1125                    binding: kw.clone(),
1126                    expr,
1127                    one_many_unif: false,
1128                    span,
1129                });
1130                conj.push(unif);
1131                kw
1132            }
1133        };
1134
1135        let k = match self.parameters.remove("k") {
1136            None => None,
1137            Some(k_expr) => {
1138                let k = k_expr.eval_to_const()?;
1139                let k = k.get_int().ok_or(ExpectedPosIntForFtsK(self.span))?;
1140
1141                #[derive(Debug, Error, Diagnostic)]
1142                #[error("Expected positive integer for `k`")]
1143                #[diagnostic(code(parser::expected_int_for_hnsw_k))]
1144                struct ExpectedPosIntForFtsK(#[label] SourceSpan);
1145
1146                ensure!(k > 0, ExpectedPosIntForFtsK(self.span));
1147                Some(k as usize)
1148            }
1149        };
1150
1151        let filter = self.parameters.remove("filter");
1152
1153        #[derive(Debug, Error, Diagnostic)]
1154        #[error("Extra parameters for LSH search: {0:?}")]
1155        #[diagnostic(code(parser::extra_parameters_for_lsh_search))]
1156        struct ExtraParametersForLshSearch(Vec<String>, #[label] SourceSpan);
1157
1158        if !self.parameters.is_empty() {
1159            bail!(ExtraParametersForLshSearch(
1160                self.parameters.keys().map(|s| s.to_string()).collect(),
1161                self.span
1162            ));
1163        }
1164
1165        conj.push(NormalFormAtom::LshSearch(LshSearch {
1166            base_handle,
1167            idx_handle,
1168            manifest,
1169            bindings,
1170            k,
1171            query,
1172            span: self.span,
1173            filter,
1174        }));
1175
1176        Ok(Disjunction::conj(conj))
1177    }
1178    fn normalize_fts(
1179        mut self,
1180        base_handle: RelationHandle,
1181        idx_handle: RelationHandle,
1182        manifest: FtsIndexManifest,
1183        gen: &mut TempSymbGen,
1184    ) -> Result<Disjunction> {
1185        let mut conj = Vec::with_capacity(self.bindings.len() + 8);
1186        let mut bindings = Vec::with_capacity(self.bindings.len());
1187        let mut seen_variables = BTreeSet::new();
1188
1189        for col in base_handle
1190            .metadata
1191            .keys
1192            .iter()
1193            .chain(base_handle.metadata.non_keys.iter())
1194        {
1195            if let Some(arg) = self.bindings.remove(&col.name) {
1196                match arg {
1197                    Expr::Binding { var, .. } => {
1198                        if var.is_ignored_symbol() {
1199                            bindings.push(gen.next_ignored(var.span));
1200                        } else if seen_variables.insert(var.clone()) {
1201                            bindings.push(var);
1202                        } else {
1203                            let span = var.span;
1204                            let dup = gen.next(span);
1205                            let unif = NormalFormAtom::Unification(Unification {
1206                                binding: dup.clone(),
1207                                expr: Expr::Binding {
1208                                    var,
1209                                    tuple_pos: None,
1210                                },
1211                                one_many_unif: false,
1212                                span,
1213                            });
1214                            conj.push(unif);
1215                            bindings.push(dup);
1216                        }
1217                    }
1218                    expr => {
1219                        let span = expr.span();
1220                        let kw = gen.next(span);
1221                        bindings.push(kw.clone());
1222                        let unif = NormalFormAtom::Unification(Unification {
1223                            binding: kw,
1224                            expr,
1225                            one_many_unif: false,
1226                            span,
1227                        });
1228                        conj.push(unif)
1229                    }
1230                }
1231            } else {
1232                bindings.push(gen.next_ignored(self.span));
1233            }
1234        }
1235
1236        if let Some((name, _)) = self.bindings.pop_first() {
1237            bail!(NamedFieldNotFound(
1238                self.relation.name.to_string(),
1239                name.to_string(),
1240                self.span
1241            ));
1242        }
1243
1244        #[derive(Debug, Error, Diagnostic)]
1245        #[error("Field `{0}` is required for HNSW search")]
1246        #[diagnostic(code(parser::hnsw_query_required))]
1247        struct HnswRequiredMissing(String, #[label] SourceSpan);
1248
1249        let query = match self
1250            .parameters
1251            .remove("query")
1252            .ok_or_else(|| miette!(HnswRequiredMissing("query".to_string(), self.span)))?
1253        {
1254            Expr::Binding { var, .. } => var,
1255            expr => {
1256                let span = expr.span();
1257                let kw = gen.next(span);
1258                let unif = NormalFormAtom::Unification(Unification {
1259                    binding: kw.clone(),
1260                    expr,
1261                    one_many_unif: false,
1262                    span,
1263                });
1264                conj.push(unif);
1265                kw
1266            }
1267        };
1268
1269        let k_expr = self
1270            .parameters
1271            .remove("k")
1272            .ok_or_else(|| miette!(HnswRequiredMissing("k".to_string(), self.span)))?;
1273        let k = k_expr.eval_to_const()?;
1274        let k = k.get_int().ok_or(ExpectedPosIntForFtsK(self.span))?;
1275
1276        #[derive(Debug, Error, Diagnostic)]
1277        #[error("Expected positive integer for `k`")]
1278        #[diagnostic(code(parser::expected_int_for_hnsw_k))]
1279        struct ExpectedPosIntForFtsK(#[label] SourceSpan);
1280
1281        ensure!(k > 0, ExpectedPosIntForFtsK(self.span));
1282
1283        let score_kind_expr = self.parameters.remove("score_kind");
1284        let score_kind = match score_kind_expr {
1285            Some(expr) => {
1286                let r = expr.eval_to_const()?;
1287                let r = r
1288                    .get_str()
1289                    .ok_or_else(|| miette!("Score kind for FTS must be a string"))?;
1290
1291                match r {
1292                    "tf_idf" => FtsScoreKind::TfIdf,
1293                    "tf" => FtsScoreKind::Tf,
1294                    s => bail!("Unknown score kind for FTS: {}", s),
1295                }
1296            }
1297            None => FtsScoreKind::TfIdf,
1298        };
1299
1300        let filter = self.parameters.remove("filter");
1301
1302        let bind_score = match self.parameters.remove("bind_score") {
1303            None => None,
1304            Some(Expr::Binding { var, .. }) => Some(var),
1305            Some(expr) => {
1306                let span = expr.span();
1307                let kw = gen.next(span);
1308                let unif = NormalFormAtom::Unification(Unification {
1309                    binding: kw.clone(),
1310                    expr,
1311                    one_many_unif: false,
1312                    span,
1313                });
1314                conj.push(unif);
1315                Some(kw)
1316            }
1317        };
1318
1319        if !self.parameters.is_empty() {
1320            bail!("Unknown parameters for FTS: {:?}", self.parameters.keys());
1321        }
1322
1323        conj.push(NormalFormAtom::FtsSearch(FtsSearch {
1324            base_handle,
1325            idx_handle,
1326            manifest,
1327            bindings,
1328            k: k as usize,
1329            query,
1330            score_kind,
1331            bind_score,
1332            // lax_mode,
1333            // k1,
1334            // b,
1335            filter,
1336            span: self.span,
1337        }));
1338
1339        Ok(Disjunction::conj(conj))
1340    }
1341    fn normalize_hnsw(
1342        mut self,
1343        base_handle: RelationHandle,
1344        idx_handle: RelationHandle,
1345        manifest: HnswIndexManifest,
1346        gen: &mut TempSymbGen,
1347    ) -> Result<Disjunction> {
1348        let mut conj = Vec::with_capacity(self.bindings.len() + 8);
1349        let mut bindings = Vec::with_capacity(self.bindings.len());
1350        let mut seen_variables = BTreeSet::new();
1351
1352        for col in base_handle
1353            .metadata
1354            .keys
1355            .iter()
1356            .chain(base_handle.metadata.non_keys.iter())
1357        {
1358            if let Some(arg) = self.bindings.remove(&col.name) {
1359                match arg {
1360                    Expr::Binding { var, .. } => {
1361                        if var.is_ignored_symbol() {
1362                            bindings.push(gen.next_ignored(var.span));
1363                        } else if seen_variables.insert(var.clone()) {
1364                            bindings.push(var);
1365                        } else {
1366                            let span = var.span;
1367                            let dup = gen.next(span);
1368                            let unif = NormalFormAtom::Unification(Unification {
1369                                binding: dup.clone(),
1370                                expr: Expr::Binding {
1371                                    var,
1372                                    tuple_pos: None,
1373                                },
1374                                one_many_unif: false,
1375                                span,
1376                            });
1377                            conj.push(unif);
1378                            bindings.push(dup);
1379                        }
1380                    }
1381                    expr => {
1382                        let span = expr.span();
1383                        let kw = gen.next(span);
1384                        bindings.push(kw.clone());
1385                        let unif = NormalFormAtom::Unification(Unification {
1386                            binding: kw,
1387                            expr,
1388                            one_many_unif: false,
1389                            span,
1390                        });
1391                        conj.push(unif)
1392                    }
1393                }
1394            } else {
1395                bindings.push(gen.next_ignored(self.span));
1396            }
1397        }
1398
1399        if let Some((name, _)) = self.bindings.pop_first() {
1400            bail!(NamedFieldNotFound(
1401                self.relation.name.to_string(),
1402                name.to_string(),
1403                self.span
1404            ));
1405        }
1406
1407        #[derive(Debug, Error, Diagnostic)]
1408        #[error("Field `{0}` is required for HNSW search")]
1409        #[diagnostic(code(parser::hnsw_query_required))]
1410        struct HnswRequiredMissing(String, #[label] SourceSpan);
1411
1412        let query = match self
1413            .parameters
1414            .remove("query")
1415            .ok_or_else(|| miette!(HnswRequiredMissing("query".to_string(), self.span)))?
1416        {
1417            Expr::Binding { var, .. } => var,
1418            expr => {
1419                let span = expr.span();
1420                let kw = gen.next(span);
1421                let unif = NormalFormAtom::Unification(Unification {
1422                    binding: kw.clone(),
1423                    expr,
1424                    one_many_unif: false,
1425                    span,
1426                });
1427                conj.push(unif);
1428                kw
1429            }
1430        };
1431
1432        let k_expr = self
1433            .parameters
1434            .remove("k")
1435            .ok_or_else(|| miette!(HnswRequiredMissing("k".to_string(), self.span)))?;
1436        let k = k_expr.eval_to_const()?;
1437        let k = k.get_int().ok_or(ExpectedPosIntForHnswK(self.span))?;
1438
1439        #[derive(Debug, Error, Diagnostic)]
1440        #[error("Expected positive integer for `k`")]
1441        #[diagnostic(code(parser::expected_int_for_hnsw_k))]
1442        struct ExpectedPosIntForHnswK(#[label] SourceSpan);
1443
1444        ensure!(k > 0, ExpectedPosIntForHnswK(self.span));
1445
1446        let ef_expr = self
1447            .parameters
1448            .remove("ef")
1449            .ok_or_else(|| miette!(HnswRequiredMissing("ef".to_string(), self.span)))?;
1450        let ef = ef_expr.eval_to_const()?;
1451        let ef = ef.get_int().ok_or(ExpectedPosIntForHnswEf(self.span))?;
1452
1453        #[derive(Debug, Error, Diagnostic)]
1454        #[error("Expected positive integer for `ef`")]
1455        #[diagnostic(code(parser::expected_int_for_hnsw_ef))]
1456        struct ExpectedPosIntForHnswEf(#[label] SourceSpan);
1457
1458        ensure!(ef > 0, ExpectedPosIntForHnswEf(self.span));
1459
1460        let radius_expr = self.parameters.remove("radius");
1461        let radius = match radius_expr {
1462            Some(expr) => {
1463                let r = expr.eval_to_const()?;
1464                let r = r.get_float().ok_or(ExpectedFloatForHnswRadius(self.span))?;
1465
1466                #[derive(Debug, Error, Diagnostic)]
1467                #[error("Expected positive float for `radius`")]
1468                #[diagnostic(code(parser::expected_float_for_hnsw_radius))]
1469                struct ExpectedFloatForHnswRadius(#[label] SourceSpan);
1470
1471                ensure!(r > 0.0, ExpectedFloatForHnswRadius(self.span));
1472                Some(r)
1473            }
1474            None => None,
1475        };
1476
1477        let filter = self.parameters.remove("filter");
1478
1479        let bind_field = match self.parameters.remove("bind_field") {
1480            None => None,
1481            Some(Expr::Binding { var, .. }) => Some(var),
1482            Some(expr) => {
1483                let span = expr.span();
1484                let kw = gen.next(span);
1485                let unif = NormalFormAtom::Unification(Unification {
1486                    binding: kw.clone(),
1487                    expr,
1488                    one_many_unif: false,
1489                    span,
1490                });
1491                conj.push(unif);
1492                Some(kw)
1493            }
1494        };
1495
1496        let bind_field_idx = match self.parameters.remove("bind_field_idx") {
1497            None => None,
1498            Some(Expr::Binding { var, .. }) => Some(var),
1499            Some(expr) => {
1500                let span = expr.span();
1501                let kw = gen.next(span);
1502                let unif = NormalFormAtom::Unification(Unification {
1503                    binding: kw.clone(),
1504                    expr,
1505                    one_many_unif: false,
1506                    span,
1507                });
1508                conj.push(unif);
1509                Some(kw)
1510            }
1511        };
1512
1513        let bind_distance = match self.parameters.remove("bind_distance") {
1514            None => None,
1515            Some(Expr::Binding { var, .. }) => Some(var),
1516            Some(expr) => {
1517                let span = expr.span();
1518                let kw = gen.next(span);
1519                let unif = NormalFormAtom::Unification(Unification {
1520                    binding: kw.clone(),
1521                    expr,
1522                    one_many_unif: false,
1523                    span,
1524                });
1525                conj.push(unif);
1526                Some(kw)
1527            }
1528        };
1529
1530        let bind_vector = match self.parameters.remove("bind_vector") {
1531            None => None,
1532            Some(Expr::Binding { var, .. }) => Some(var),
1533            Some(expr) => {
1534                let span = expr.span();
1535                let kw = gen.next(span);
1536                let unif = NormalFormAtom::Unification(Unification {
1537                    binding: kw.clone(),
1538                    expr,
1539                    one_many_unif: false,
1540                    span,
1541                });
1542                conj.push(unif);
1543                Some(kw)
1544            }
1545        };
1546
1547        if !self.parameters.is_empty() {
1548            bail!("Unexpected parameters for HNSW: {:?}", self.parameters);
1549        }
1550
1551        conj.push(NormalFormAtom::HnswSearch(HnswSearch {
1552            base_handle,
1553            idx_handle,
1554            manifest,
1555            bindings,
1556            k: k as usize,
1557            ef: ef as usize,
1558            query,
1559            bind_field,
1560            bind_field_idx,
1561            bind_distance,
1562            bind_vector,
1563            radius,
1564            filter,
1565            span: self.span,
1566        }));
1567
1568        Ok(Disjunction::conj(conj))
1569    }
1570    pub(crate) fn normalize(
1571        self,
1572        gen: &mut TempSymbGen,
1573        tx: &SessionTx<'_>,
1574    ) -> Result<Disjunction> {
1575        let base_handle = tx.get_relation(&self.relation, false)?;
1576        if base_handle.access_level < AccessLevel::ReadOnly {
1577            bail!(InsufficientAccessLevel(
1578                base_handle.name.to_string(),
1579                "reading rows".to_string(),
1580                base_handle.access_level
1581            ));
1582        }
1583        if let Some((idx_handle, manifest)) =
1584            base_handle.hnsw_indices.get(&self.index.name).cloned()
1585        {
1586            return self.normalize_hnsw(base_handle, idx_handle, manifest, gen);
1587        }
1588        if let Some((idx_handle, manifest)) = base_handle.fts_indices.get(&self.index.name).cloned()
1589        {
1590            return self.normalize_fts(base_handle, idx_handle, manifest, gen);
1591        }
1592        if let Some((idx_handle, _, manifest)) =
1593            base_handle.lsh_indices.get(&self.index.name).cloned()
1594        {
1595            return self.normalize_lsh(base_handle, idx_handle, manifest, gen);
1596        }
1597        #[derive(Debug, Error, Diagnostic)]
1598        #[error("Index {name} not found on relation {relation}")]
1599        #[diagnostic(code(eval::hnsw_index_not_found))]
1600        struct IndexNotFound {
1601            relation: String,
1602            name: String,
1603            #[label]
1604            span: SourceSpan,
1605        }
1606        bail!(IndexNotFound {
1607            relation: self.relation.to_string(),
1608            name: self.index.to_string(),
1609            span: self.span,
1610        })
1611    }
1612}
1613
1614impl Debug for InputAtom {
1615    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1616        write!(f, "{self}")
1617    }
1618}
1619
1620impl Display for InputAtom {
1621    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1622        match self {
1623            InputAtom::Rule {
1624                inner: InputRuleApplyAtom { name, args, .. },
1625            } => {
1626                write!(f, "{name}")?;
1627                f.debug_list().entries(args).finish()?;
1628            }
1629            InputAtom::NamedFieldRelation {
1630                inner: InputNamedFieldRelationApplyAtom { name, args, .. },
1631            } => {
1632                f.write_str("*")?;
1633                let mut sf = f.debug_struct(name);
1634                for (k, v) in args {
1635                    sf.field(k, v);
1636                }
1637                sf.finish()?;
1638            }
1639            InputAtom::Relation {
1640                inner: InputRelationApplyAtom { name, args, .. },
1641            } => {
1642                write!(f, ":{name}")?;
1643                f.debug_list().entries(args).finish()?;
1644            }
1645            InputAtom::Search { inner } => {
1646                write!(f, "~{}:{}{{", inner.relation, inner.index)?;
1647                for (binding, expr) in &inner.bindings {
1648                    write!(f, "{binding}: {expr}, ")?;
1649                }
1650                write!(f, "| ")?;
1651                for (k, v) in inner.parameters.iter() {
1652                    write!(f, "{k}: {v}, ")?;
1653                }
1654                write!(f, "}}")?;
1655            }
1656            InputAtom::Predicate { inner } => {
1657                write!(f, "{inner}")?;
1658            }
1659            InputAtom::Negation { inner, .. } => {
1660                write!(f, "not {inner}")?;
1661            }
1662            InputAtom::Conjunction { inner, .. } => {
1663                for (i, a) in inner.iter().enumerate() {
1664                    if i > 0 {
1665                        write!(f, " and ")?;
1666                    }
1667                    write!(f, "({a})")?;
1668                }
1669            }
1670            InputAtom::Disjunction { inner, .. } => {
1671                for (i, a) in inner.iter().enumerate() {
1672                    if i > 0 {
1673                        write!(f, " or ")?;
1674                    }
1675                    write!(f, "({a})")?;
1676                }
1677            }
1678            InputAtom::Unification {
1679                inner:
1680                    Unification {
1681                        binding,
1682                        expr,
1683                        one_many_unif,
1684                        ..
1685                    },
1686            } => {
1687                write!(f, "{binding}")?;
1688                if *one_many_unif {
1689                    write!(f, " in ")?;
1690                } else {
1691                    write!(f, " = ")?;
1692                }
1693                write!(f, "{expr}")?;
1694            }
1695        }
1696        Ok(())
1697    }
1698}
1699
1700impl InputAtom {
1701    // pub(crate) fn used_rule(&self, rule_name: &Symbol) -> bool {
1702    //     match self {
1703    //         InputAtom::Rule { inner } => inner.name == *rule_name,
1704    //         InputAtom::Negation { inner, .. } => inner.used_rule(rule_name),
1705    //         InputAtom::Conjunction { inner, .. } | InputAtom::Disjunction { inner, .. } => {
1706    //             inner.iter().any(|a| a.used_rule(rule_name))
1707    //         }
1708    //         _ => false,
1709    //     }
1710    // }
1711    pub(crate) fn span(&self) -> SourceSpan {
1712        match self {
1713            InputAtom::Negation { span, .. }
1714            | InputAtom::Conjunction { span, .. }
1715            | InputAtom::Disjunction { span, .. } => *span,
1716            InputAtom::Rule { inner, .. } => inner.span,
1717            InputAtom::NamedFieldRelation { inner, .. } => inner.span,
1718            InputAtom::Relation { inner, .. } => inner.span,
1719            InputAtom::Predicate { inner, .. } => inner.span(),
1720            InputAtom::Unification { inner, .. } => inner.span,
1721            InputAtom::Search { inner, .. } => inner.span,
1722        }
1723    }
1724}
1725
/// An atom in the body of a [`NormalFormInlineRule`].
1726#[derive(Debug, Clone)]
1727pub(crate) enum NormalFormAtom {
    /// Application of a rule to symbols.
1728    Rule(NormalFormRuleApplyAtom),
    /// Application of a relation to symbols.
1729    Relation(NormalFormRelationApplyAtom),
    /// Negated application of a rule.
1730    NegatedRule(NormalFormRuleApplyAtom),
    /// Negated application of a relation.
1731    NegatedRelation(NormalFormRelationApplyAtom),
    /// A bare expression used as an atom.
1732    Predicate(Expr),
    /// Unification of a symbol with an expression.
1733    Unification(Unification),
    /// Search through an HNSW index.
1734    HnswSearch(HnswSearch),
    /// Search through a full-text index.
1735    FtsSearch(FtsSearch),
    /// Search through a MinHash-LSH index.
1736    LshSearch(LshSearch),
1737}
1738
/// An atom in the body of a [`MagicInlineRule`].
///
/// Mirrors [`NormalFormAtom`], except that rule applications are named by a
/// [`MagicSymbol`].
1739#[derive(Debug, Clone)]
1740pub(crate) enum MagicAtom {
    /// Application of a rule to symbols.
1741    Rule(MagicRuleApplyAtom),
    /// Application of a relation to symbols.
1742    Relation(MagicRelationApplyAtom),
    /// A bare expression used as an atom.
1743    Predicate(Expr),
    /// Negated application of a rule.
1744    NegatedRule(MagicRuleApplyAtom),
    /// Negated application of a relation.
1745    NegatedRelation(MagicRelationApplyAtom),
    /// Unification of a symbol with an expression.
1746    Unification(Unification),
    /// Search through an HNSW index.
1747    HnswSearch(HnswSearch),
    /// Search through a full-text index.
1748    FtsSearch(FtsSearch),
    /// Search through a MinHash-LSH index.
1749    LshSearch(LshSearch),
1750}
1751
1752#[derive(Clone, Debug)]
1753pub struct InputRuleApplyAtom {
1754    pub name: Symbol,
1755    pub args: Vec<Expr>,
1756    pub span: SourceSpan,
1757}
1758
1759#[derive(Clone, Debug)]
1760pub struct InputNamedFieldRelationApplyAtom {
1761    pub name: Symbol,
1762    pub args: BTreeMap<SmartString<LazyCompact>, Expr>,
1763    pub valid_at: Option<ValidityTs>,
1764    pub span: SourceSpan,
1765}
1766
1767#[derive(Clone, Debug)]
1768pub struct InputRelationApplyAtom {
1769    pub name: Symbol,
1770    pub args: Vec<Expr>,
1771    pub valid_at: Option<ValidityTs>,
1772    pub span: SourceSpan,
1773}
1774
1775#[derive(Clone, Debug)]
1776pub(crate) struct NormalFormRuleApplyAtom {
1777    pub(crate) name: Symbol,
1778    pub(crate) args: Vec<Symbol>,
1779    pub(crate) span: SourceSpan,
1780}
1781
1782#[derive(Clone, Debug)]
1783pub(crate) struct NormalFormRelationApplyAtom {
1784    pub(crate) name: Symbol,
1785    pub(crate) args: Vec<Symbol>,
1786    pub(crate) valid_at: Option<ValidityTs>,
1787    pub(crate) span: SourceSpan,
1788}
1789
1790#[derive(Clone, Debug)]
1791pub(crate) struct MagicRuleApplyAtom {
1792    pub(crate) name: MagicSymbol,
1793    pub(crate) args: Vec<Symbol>,
1794    pub(crate) span: SourceSpan,
1795}
1796
1797#[derive(Clone, Debug)]
1798pub(crate) struct MagicRelationApplyAtom {
1799    pub(crate) name: Symbol,
1800    pub(crate) args: Vec<Symbol>,
1801    pub(crate) valid_at: Option<ValidityTs>,
1802    pub(crate) span: SourceSpan,
1803}
1804
1805#[derive(Clone, Debug)]
1806pub struct Unification {
1807    /// Symbol to bind expression to.
1808    pub binding: Symbol,
1809    pub expr: Expr,
1810    /// If false, `=`, if true, `in`. If true, one row is created for each value in the list in `expr`.
1811    pub one_many_unif: bool,
1812    pub span: SourceSpan,
1813}
1814
1815impl Unification {
1816    pub(crate) fn is_const(&self) -> bool {
1817        matches!(self.expr, Expr::Const { .. })
1818    }
1819    pub(crate) fn bindings_in_expr(&self) -> Result<BTreeSet<Symbol>> {
1820        self.expr.bindings()
1821    }
1822}