cozo_ce/parse/
mod.rs

1/*
2 * Copyright 2022, The Cozo Project Authors.
3 *
4 * This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0.
5 * If a copy of the MPL was not distributed with this file,
6 * You can obtain one at https://mozilla.org/MPL/2.0/.
7 */
8//! AST for Cozo scripts, for generating Cozo scripts programmatically.
9//!
//! NOTE: this API is unstable; the AST structure and method signatures may change in any release. Use at your own risk.
11
12use std::cmp::{max, min};
13use std::collections::{BTreeMap, BTreeSet};
14use std::fmt::{Display, Formatter};
15use std::sync::Arc;
16
17use either::{Either, Left};
18use miette::{bail, Diagnostic, IntoDiagnostic, Result};
19use pest::error::InputLocation;
20use pest::Parser;
21use smartstring::{LazyCompact, SmartString};
22use thiserror::Error;
23
24use crate::data::program::InputProgram;
25use crate::data::relation::NullableColType;
26use crate::data::value::{DataValue, ValidityTs};
27use crate::parse::expr::build_expr;
28use crate::parse::imperative::parse_imperative_block;
29use crate::parse::query::parse_query;
30use crate::parse::schema::parse_nullable_type;
31use crate::parse::sys::{parse_sys, SysOp};
32use crate::{Expr, FixedRule};
33
34pub(crate) mod expr;
35pub(crate) mod fts;
36pub(crate) mod imperative;
37pub(crate) mod query;
38pub(crate) mod schema;
39pub(crate) mod sys;
40
/// Parser generated by `pest_derive` from the grammar file `cozoscript.pest`.
///
/// The derive also generates the `Rule` enum used throughout this module to
/// select entry points (e.g. `Rule::script`, `Rule::expression_script`).
#[derive(pest_derive::Parser)]
#[grammar = "cozoscript.pest"]
pub(crate) struct CozoScriptParser;
44
/// A single pest parse node over the CozoScript grammar's `Rule`s.
pub(crate) type Pair<'a> = pest::iterators::Pair<'a, Rule>;
/// A sequence of pest parse nodes over the CozoScript grammar's `Rule`s.
pub(crate) type Pairs<'a> = pest::iterators::Pairs<'a, Rule>;
47
48/// This represents a full Cozo script, as you'd pass to `run_script`.
49#[derive(Debug)]
50pub enum CozoScript {
51    #[allow(missing_docs)]
52    Single(InputProgram),
53    #[allow(missing_docs)]
54    Imperative(ImperativeProgram),
55    #[allow(missing_docs)]
56    Sys(SysOp),
57}
58
/// A query program appearing as one statement of an imperative script.
#[allow(missing_docs)]
#[derive(Debug)]
pub struct ImperativeStmtClause {
    /// The parsed program; its `needs_write_lock()` decides which relation
    /// (if any) must be write-locked for this statement.
    pub prog: InputProgram,
    /// Optional name to store the program's result under
    /// (presumably a temporary relation — confirm against the imperative parser).
    pub store_as: Option<SmartString<LazyCompact>>,
}
65
/// A system operation appearing as one statement of an imperative script.
#[allow(missing_docs)]
#[derive(Debug)]
pub struct ImperativeSysop {
    /// The parsed system operation.
    pub sysop: SysOp,
    /// Optional name to store the operation's result under
    /// (presumably a temporary relation — confirm against the imperative parser).
    pub store_as: Option<SmartString<LazyCompact>>,
}
72
/// One statement of an imperative (chained) script.
///
/// NOTE(review): some variant descriptions below are inferred from names; the
/// imperative parser/executor is the authoritative source for their semantics.
#[allow(missing_docs)]
#[derive(Debug)]
pub enum ImperativeStmt {
    /// Loop exit. `target` optionally names the loop to leave
    /// (presumably matched against `Loop::label`).
    Break {
        target: Option<SmartString<LazyCompact>>,
        span: SourceSpan,
    },
    /// Loop continuation. `target` optionally names the loop
    /// (presumably matched against `Loop::label`).
    Continue {
        target: Option<SmartString<LazyCompact>>,
        span: SourceSpan,
    },
    /// Return from the script. Each entry is either a program clause (`Left`)
    /// or a bare name (`Right`).
    Return {
        returns: Vec<Either<ImperativeStmtClause, SmartString<LazyCompact>>>,
    },
    /// Run a program clause.
    Program {
        prog: ImperativeStmtClause,
    },
    /// Run a system operation.
    SysOp {
        sysop: ImperativeSysop,
    },
    /// Run a program clause, presumably ignoring any error it raises.
    IgnoreErrorProgram {
        prog: ImperativeStmtClause,
    },
    /// Conditional: evaluates `condition`, then runs `then_branch` or
    /// `else_branch`; `negated` presumably inverts the test.
    If {
        condition: ImperativeCondition,
        then_branch: ImperativeProgram,
        else_branch: ImperativeProgram,
        negated: bool,
    },
    /// Loop over `body`, optionally labelled so `Break`/`Continue` can target it.
    Loop {
        label: Option<SmartString<LazyCompact>>,
        body: ImperativeProgram,
    },
    /// Operates on two named relations; presumably swaps temporary relations
    /// `left` and `right`.
    TempSwap {
        left: SmartString<LazyCompact>,
        right: SmartString<LazyCompact>,
        // span: SourceSpan,
    },
    /// Debug directive for the named relation `temp` (exact output is defined
    /// by the executor).
    TempDebug {
        temp: SmartString<LazyCompact>,
    },
}
115
/// Condition of an `If` statement: either a name (`Left`, presumably a
/// temporary relation — confirm against the imperative parser) or a program
/// clause to evaluate (`Right`).
pub(crate) type ImperativeCondition = Either<SmartString<LazyCompact>, ImperativeStmtClause>;

/// This is a [chained query](https://docs.cozodb.org/en/latest/stored.html#chaining-queries):
/// a series of `{}` queries, possibly with imperative directives like `%if` and `%loop`.
pub type ImperativeProgram = Vec<ImperativeStmt>;
121
122impl ImperativeStmt {
123    pub(crate) fn needs_write_locks(&self, collector: &mut BTreeSet<SmartString<LazyCompact>>) {
124        match self {
125            ImperativeStmt::Program { prog, .. }
126            | ImperativeStmt::IgnoreErrorProgram { prog, .. } => {
127                if let Some(name) = prog.prog.needs_write_lock() {
128                    collector.insert(name);
129                }
130            }
131            ImperativeStmt::Return { returns, .. } => {
132                for ret in returns {
133                    if let Left(prog) = ret {
134                        if let Some(name) = prog.prog.needs_write_lock() {
135                            collector.insert(name);
136                        }
137                    }
138                }
139            }
140            ImperativeStmt::If {
141                condition,
142                then_branch,
143                else_branch,
144                ..
145            } => {
146                if let ImperativeCondition::Right(prog) = condition {
147                    if let Some(name) = prog.prog.needs_write_lock() {
148                        collector.insert(name);
149                    }
150                }
151                for prog in then_branch.iter().chain(else_branch.iter()) {
152                    prog.needs_write_locks(collector);
153                }
154            }
155            ImperativeStmt::Loop { body, .. } => {
156                for prog in body {
157                    prog.needs_write_locks(collector);
158                }
159            }
160            ImperativeStmt::TempDebug { .. }
161            | ImperativeStmt::Break { .. }
162            | ImperativeStmt::Continue { .. }
163            | ImperativeStmt::TempSwap { .. } => {}
164            ImperativeStmt::SysOp { sysop } => match &sysop.sysop {
165                SysOp::RemoveRelation(rels) => {
166                    for rel in rels {
167                        collector.insert(rel.name.clone());
168                    }
169                }
170                SysOp::RenameRelation(renames) => {
171                    for (old, new) in renames {
172                        collector.insert(old.name.clone());
173                        collector.insert(new.name.clone());
174                    }
175                }
176                SysOp::CreateIndex(symb, subs, _) => {
177                    collector.insert(symb.name.clone());
178                    collector.insert(SmartString::from(format!("{}:{}", symb.name, subs.name)));
179                }
180                SysOp::CreateVectorIndex(m) => {
181                    collector.insert(m.base_relation.clone());
182                    collector.insert(SmartString::from(format!(
183                        "{}:{}",
184                        m.base_relation, m.index_name
185                    )));
186                }
187                SysOp::CreateFtsIndex(m) => {
188                    collector.insert(m.base_relation.clone());
189                    collector.insert(SmartString::from(format!(
190                        "{}:{}",
191                        m.base_relation, m.index_name
192                    )));
193                }
194                SysOp::CreateMinHashLshIndex(m) => {
195                    collector.insert(m.base_relation.clone());
196                    collector.insert(SmartString::from(format!(
197                        "{}:{}",
198                        m.base_relation, m.index_name
199                    )));
200                }
201                SysOp::RemoveIndex(rel, idx) => {
202                    collector.insert(SmartString::from(format!("{}:{}", rel.name, idx.name)));
203                }
204                _ => {}
205            },
206        }
207    }
208}
209
210impl CozoScript {
211    pub(crate) fn get_single_program(self) -> Result<InputProgram> {
212        #[derive(Debug, Error, Diagnostic)]
213        #[error("expect script to contain only a single program")]
214        #[diagnostic(code(parser::expect_singleton))]
215        struct ExpectSingleProgram;
216        match self {
217            CozoScript::Single(s) => Ok(s),
218            CozoScript::Imperative(_) | CozoScript::Sys(_) => {
219                bail!(ExpectSingleProgram)
220            }
221        }
222    }
223}
224
/// Span of an element in the source script, stored as `(start offset, length)`.
///
/// The second field is a *length*, not an end position: the end offset is
/// `self.0 + self.1`. This is how `Display`, `merge`, and the conversion into
/// `miette::SourceSpan` interpret it.
#[derive(
    Eq, PartialEq, Debug, serde_derive::Serialize, serde_derive::Deserialize, Copy, Clone, Default,
)]
pub struct SourceSpan(pub usize, pub usize);
230
231impl Display for SourceSpan {
232    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
233        write!(f, "{}..{}", self.0, self.0 + self.1)
234    }
235}
236
237impl SourceSpan {
238    pub(crate) fn merge(self, other: Self) -> Self {
239        let s1 = self.0;
240        let e1 = self.0 + self.1;
241        let s2 = other.0;
242        let e2 = other.0 + other.1;
243        let s = min(s1, s2);
244        let e = max(e1, e2);
245        Self(s, e - s)
246    }
247}
248
249impl From<&'_ SourceSpan> for miette::SourceSpan {
250    fn from(s: &'_ SourceSpan) -> Self {
251        miette::SourceSpan::new(s.0.into(), s.1.into())
252    }
253}
254
255impl From<SourceSpan> for miette::SourceSpan {
256    fn from(s: SourceSpan) -> Self {
257        miette::SourceSpan::new(s.0.into(), s.1.into())
258    }
259}
260
/// Diagnostic returned when pest rejects the input text.
///
/// The message embeds `span` via `SourceSpan`'s `Display` (`start..end`), and the
/// same span is attached as a diagnostic label.
#[derive(thiserror::Error, Diagnostic, Debug)]
#[error("The query parser has encountered unexpected input / end of input at {span}")]
#[diagnostic(code(parser::pest))]
pub(crate) struct ParseError {
    /// Where parsing failed; zero-length when pest reports a single position.
    #[label]
    pub(crate) span: SourceSpan,
}
268
269pub(crate) fn parse_type(src: &str) -> Result<NullableColType> {
270    let parsed = CozoScriptParser::parse(Rule::col_type_with_term, src)
271        .into_diagnostic()?
272        .next()
273        .unwrap();
274    parse_nullable_type(parsed.into_inner().next().unwrap())
275}
276
277pub(crate) fn parse_expressions(
278    src: &str,
279    param_pool: &BTreeMap<String, DataValue>,
280) -> Result<Expr> {
281    let parsed = CozoScriptParser::parse(Rule::expression_script, src)
282        .map_err(|err| {
283            let span = match err.location {
284                InputLocation::Pos(p) => SourceSpan(p, 0),
285                InputLocation::Span((start, end)) => SourceSpan(start, end - start),
286            };
287            ParseError { span }
288        })?
289        .next()
290        .unwrap();
291
292    build_expr(parsed.into_inner().next().unwrap(), param_pool)
293}
294
295/// This parses a text script into the AST used by Cozo.
296///
297/// Note! This is an unstable interface, the signature may change between releases. Depend on it at your own risk.
298///
299/// * `src` - the script to parse
300///
301/// * `param_pool` - the list of parameters to execute the script with. These are substituted into the syntax tree during parsing.
302///
303/// * `fixed_rules` - a mapping of fixed rule names to their implementations. These are substituted into the syntax tree during parsing.
304///
305/// * `cur_vld` - the current timestamp, substituted into expressions where validity is relevant.
306pub fn parse_script(
307    src: &str,
308    param_pool: &BTreeMap<String, DataValue>,
309    fixed_rules: &BTreeMap<String, Arc<Box<dyn FixedRule>>>,
310    cur_vld: ValidityTs,
311) -> Result<CozoScript> {
312    let parsed = CozoScriptParser::parse(Rule::script, src)
313        .map_err(|err| {
314            let span = match err.location {
315                InputLocation::Pos(p) => SourceSpan(p, 0),
316                InputLocation::Span((start, end)) => SourceSpan(start, end - start),
317            };
318            ParseError { span }
319        })?
320        .next()
321        .unwrap();
322    Ok(match parsed.as_rule() {
323        Rule::query_script => {
324            let q = parse_query(parsed.into_inner(), param_pool, fixed_rules, cur_vld)?;
325            CozoScript::Single(q)
326        }
327        Rule::imperative_script => {
328            let p = parse_imperative_block(parsed, param_pool, fixed_rules, cur_vld)?;
329            CozoScript::Imperative(p)
330        }
331
332        Rule::sys_script => CozoScript::Sys(parse_sys(
333            parsed.into_inner(),
334            param_pool,
335            fixed_rules,
336            cur_vld,
337        )?),
338        _ => unreachable!(),
339    })
340}
341
342trait ExtractSpan {
343    fn extract_span(&self) -> SourceSpan;
344}
345
346impl ExtractSpan for Pair<'_> {
347    fn extract_span(&self) -> SourceSpan {
348        let span = self.as_span();
349        let start = span.start();
350        let end = span.end();
351        SourceSpan(start, end - start)
352    }
353}