1use std::cmp::{max, min};
13use std::collections::{BTreeMap, BTreeSet};
14use std::fmt::{Display, Formatter};
15use std::sync::Arc;
16
17use either::{Either, Left};
18use miette::{bail, Diagnostic, IntoDiagnostic, Result};
19use pest::error::InputLocation;
20use pest::Parser;
21use smartstring::{LazyCompact, SmartString};
22use thiserror::Error;
23
24use crate::data::program::InputProgram;
25use crate::data::relation::NullableColType;
26use crate::data::value::{DataValue, ValidityTs};
27use crate::parse::expr::build_expr;
28use crate::parse::imperative::parse_imperative_block;
29use crate::parse::query::parse_query;
30use crate::parse::schema::parse_nullable_type;
31use crate::parse::sys::{parse_sys, SysOp};
32use crate::{Expr, FixedRule};
33
34pub(crate) mod expr;
35pub(crate) mod fts;
36pub(crate) mod imperative;
37pub(crate) mod query;
38pub(crate) mod schema;
39pub(crate) mod sys;
40
/// Parser generated by `pest_derive` from the `cozoscript.pest` grammar.
///
/// The derive also generates the `Rule` enum (`Rule::script`, `Rule::expression_script`, ...)
/// that the parsing functions in this module dispatch on.
#[derive(pest_derive::Parser)]
#[grammar = "cozoscript.pest"]
pub(crate) struct CozoScriptParser;

/// A single node of the parse tree produced by [`CozoScriptParser`].
pub(crate) type Pair<'a> = pest::iterators::Pair<'a, Rule>;
/// A sequence of sibling parse-tree nodes produced by [`CozoScriptParser`].
pub(crate) type Pairs<'a> = pest::iterators::Pairs<'a, Rule>;
47
48#[derive(Debug)]
50pub enum CozoScript {
51 #[allow(missing_docs)]
52 Single(InputProgram),
53 #[allow(missing_docs)]
54 Imperative(ImperativeProgram),
55 #[allow(missing_docs)]
56 Sys(SysOp),
57}
58
59#[allow(missing_docs)]
60#[derive(Debug)]
61pub struct ImperativeStmtClause {
62 pub prog: InputProgram,
63 pub store_as: Option<SmartString<LazyCompact>>,
64}
65
66#[allow(missing_docs)]
67#[derive(Debug)]
68pub struct ImperativeSysop {
69 pub sysop: SysOp,
70 pub store_as: Option<SmartString<LazyCompact>>,
71}
72
73#[allow(missing_docs)]
74#[derive(Debug)]
75pub enum ImperativeStmt {
76 Break {
77 target: Option<SmartString<LazyCompact>>,
78 span: SourceSpan,
79 },
80 Continue {
81 target: Option<SmartString<LazyCompact>>,
82 span: SourceSpan,
83 },
84 Return {
85 returns: Vec<Either<ImperativeStmtClause, SmartString<LazyCompact>>>,
86 },
87 Program {
88 prog: ImperativeStmtClause,
89 },
90 SysOp {
91 sysop: ImperativeSysop,
92 },
93 IgnoreErrorProgram {
94 prog: ImperativeStmtClause,
95 },
96 If {
97 condition: ImperativeCondition,
98 then_branch: ImperativeProgram,
99 else_branch: ImperativeProgram,
100 negated: bool,
101 },
102 Loop {
103 label: Option<SmartString<LazyCompact>>,
104 body: ImperativeProgram,
105 },
106 TempSwap {
107 left: SmartString<LazyCompact>,
108 right: SmartString<LazyCompact>,
109 },
111 TempDebug {
112 temp: SmartString<LazyCompact>,
113 },
114}
115
/// Condition of an `ImperativeStmt::If`: either a bare name or a program clause.
///
/// Only the `Right` (program clause) case is considered a potential writer by
/// `ImperativeStmt::needs_write_locks`. The meaning of the `Left` name is decided by the
/// executor (presumably a temporary relation whose contents are tested — TODO confirm).
pub(crate) type ImperativeCondition = Either<SmartString<LazyCompact>, ImperativeStmtClause>;

/// An ordered list of imperative statements: a whole imperative script, an `If` branch,
/// or a `Loop` body.
pub type ImperativeProgram = Vec<ImperativeStmt>;
121
122impl ImperativeStmt {
123 pub(crate) fn needs_write_locks(&self, collector: &mut BTreeSet<SmartString<LazyCompact>>) {
124 match self {
125 ImperativeStmt::Program { prog, .. }
126 | ImperativeStmt::IgnoreErrorProgram { prog, .. } => {
127 if let Some(name) = prog.prog.needs_write_lock() {
128 collector.insert(name);
129 }
130 }
131 ImperativeStmt::Return { returns, .. } => {
132 for ret in returns {
133 if let Left(prog) = ret {
134 if let Some(name) = prog.prog.needs_write_lock() {
135 collector.insert(name);
136 }
137 }
138 }
139 }
140 ImperativeStmt::If {
141 condition,
142 then_branch,
143 else_branch,
144 ..
145 } => {
146 if let ImperativeCondition::Right(prog) = condition {
147 if let Some(name) = prog.prog.needs_write_lock() {
148 collector.insert(name);
149 }
150 }
151 for prog in then_branch.iter().chain(else_branch.iter()) {
152 prog.needs_write_locks(collector);
153 }
154 }
155 ImperativeStmt::Loop { body, .. } => {
156 for prog in body {
157 prog.needs_write_locks(collector);
158 }
159 }
160 ImperativeStmt::TempDebug { .. }
161 | ImperativeStmt::Break { .. }
162 | ImperativeStmt::Continue { .. }
163 | ImperativeStmt::TempSwap { .. } => {}
164 ImperativeStmt::SysOp { sysop } => match &sysop.sysop {
165 SysOp::RemoveRelation(rels) => {
166 for rel in rels {
167 collector.insert(rel.name.clone());
168 }
169 }
170 SysOp::RenameRelation(renames) => {
171 for (old, new) in renames {
172 collector.insert(old.name.clone());
173 collector.insert(new.name.clone());
174 }
175 }
176 SysOp::CreateIndex(symb, subs, _) => {
177 collector.insert(symb.name.clone());
178 collector.insert(SmartString::from(format!("{}:{}", symb.name, subs.name)));
179 }
180 SysOp::CreateVectorIndex(m) => {
181 collector.insert(m.base_relation.clone());
182 collector.insert(SmartString::from(format!(
183 "{}:{}",
184 m.base_relation, m.index_name
185 )));
186 }
187 SysOp::CreateFtsIndex(m) => {
188 collector.insert(m.base_relation.clone());
189 collector.insert(SmartString::from(format!(
190 "{}:{}",
191 m.base_relation, m.index_name
192 )));
193 }
194 SysOp::CreateMinHashLshIndex(m) => {
195 collector.insert(m.base_relation.clone());
196 collector.insert(SmartString::from(format!(
197 "{}:{}",
198 m.base_relation, m.index_name
199 )));
200 }
201 SysOp::RemoveIndex(rel, idx) => {
202 collector.insert(SmartString::from(format!("{}:{}", rel.name, idx.name)));
203 }
204 _ => {}
205 },
206 }
207 }
208}
209
210impl CozoScript {
211 pub(crate) fn get_single_program(self) -> Result<InputProgram> {
212 #[derive(Debug, Error, Diagnostic)]
213 #[error("expect script to contain only a single program")]
214 #[diagnostic(code(parser::expect_singleton))]
215 struct ExpectSingleProgram;
216 match self {
217 CozoScript::Single(s) => Ok(s),
218 CozoScript::Imperative(_) | CozoScript::Sys(_) => {
219 bail!(ExpectSingleProgram)
220 }
221 }
222 }
223}
224
225#[derive(
227 Eq, PartialEq, Debug, serde_derive::Serialize, serde_derive::Deserialize, Copy, Clone, Default,
228)]
229pub struct SourceSpan(pub usize, pub usize);
230
231impl Display for SourceSpan {
232 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
233 write!(f, "{}..{}", self.0, self.0 + self.1)
234 }
235}
236
237impl SourceSpan {
238 pub(crate) fn merge(self, other: Self) -> Self {
239 let s1 = self.0;
240 let e1 = self.0 + self.1;
241 let s2 = other.0;
242 let e2 = other.0 + other.1;
243 let s = min(s1, s2);
244 let e = max(e1, e2);
245 Self(s, e - s)
246 }
247}
248
249impl From<&'_ SourceSpan> for miette::SourceSpan {
250 fn from(s: &'_ SourceSpan) -> Self {
251 miette::SourceSpan::new(s.0.into(), s.1.into())
252 }
253}
254
255impl From<SourceSpan> for miette::SourceSpan {
256 fn from(s: SourceSpan) -> Self {
257 miette::SourceSpan::new(s.0.into(), s.1.into())
258 }
259}
260
261#[derive(thiserror::Error, Diagnostic, Debug)]
262#[error("The query parser has encountered unexpected input / end of input at {span}")]
263#[diagnostic(code(parser::pest))]
264pub(crate) struct ParseError {
265 #[label]
266 pub(crate) span: SourceSpan,
267}
268
269pub(crate) fn parse_type(src: &str) -> Result<NullableColType> {
270 let parsed = CozoScriptParser::parse(Rule::col_type_with_term, src)
271 .into_diagnostic()?
272 .next()
273 .unwrap();
274 parse_nullable_type(parsed.into_inner().next().unwrap())
275}
276
277pub(crate) fn parse_expressions(
278 src: &str,
279 param_pool: &BTreeMap<String, DataValue>,
280) -> Result<Expr> {
281 let parsed = CozoScriptParser::parse(Rule::expression_script, src)
282 .map_err(|err| {
283 let span = match err.location {
284 InputLocation::Pos(p) => SourceSpan(p, 0),
285 InputLocation::Span((start, end)) => SourceSpan(start, end - start),
286 };
287 ParseError { span }
288 })?
289 .next()
290 .unwrap();
291
292 build_expr(parsed.into_inner().next().unwrap(), param_pool)
293}
294
295pub fn parse_script(
307 src: &str,
308 param_pool: &BTreeMap<String, DataValue>,
309 fixed_rules: &BTreeMap<String, Arc<Box<dyn FixedRule>>>,
310 cur_vld: ValidityTs,
311) -> Result<CozoScript> {
312 let parsed = CozoScriptParser::parse(Rule::script, src)
313 .map_err(|err| {
314 let span = match err.location {
315 InputLocation::Pos(p) => SourceSpan(p, 0),
316 InputLocation::Span((start, end)) => SourceSpan(start, end - start),
317 };
318 ParseError { span }
319 })?
320 .next()
321 .unwrap();
322 Ok(match parsed.as_rule() {
323 Rule::query_script => {
324 let q = parse_query(parsed.into_inner(), param_pool, fixed_rules, cur_vld)?;
325 CozoScript::Single(q)
326 }
327 Rule::imperative_script => {
328 let p = parse_imperative_block(parsed, param_pool, fixed_rules, cur_vld)?;
329 CozoScript::Imperative(p)
330 }
331
332 Rule::sys_script => CozoScript::Sys(parse_sys(
333 parsed.into_inner(),
334 param_pool,
335 fixed_rules,
336 cur_vld,
337 )?),
338 _ => unreachable!(),
339 })
340}
341
342trait ExtractSpan {
343 fn extract_span(&self) -> SourceSpan;
344}
345
346impl ExtractSpan for Pair<'_> {
347 fn extract_span(&self) -> SourceSpan {
348 let span = self.as_span();
349 let start = span.start();
350 let end = span.end();
351 SourceSpan(start, end - start)
352 }
353}