// cozo/runtime/imperative.rs

/*
 * Copyright 2022, The Cozo Project Authors.
 *
 * This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0.
 * If a copy of the MPL was not distributed with this file,
 * You can obtain one at https://mozilla.org/MPL/2.0/.
 */
8
9use std::collections::{BTreeMap, BTreeSet};
10use std::sync::atomic::Ordering;
11
12use either::{Either, Left, Right};
13use itertools::Itertools;
14use miette::{bail, Diagnostic, Report, Result};
15use smartstring::{LazyCompact, SmartString};
16use thiserror::Error;
17
18use crate::data::program::RelationOp;
19use crate::data::relation::{ColType, ColumnDef, NullableColType, StoredRelationMetadata};
20use crate::data::symb::Symbol;
21use crate::parse::{ImperativeCondition, ImperativeProgram, ImperativeStmt, SourceSpan};
22use crate::runtime::callback::CallbackCollector;
23use crate::runtime::db::{seconds_since_the_epoch, RunningQueryCleanup, RunningQueryHandle};
24use crate::runtime::relation::InputRelationHandle;
25use crate::runtime::transact::SessionTx;
26use crate::{DataValue, Db, NamedRows, Poison, Storage, ValidityTs};
27
28enum ControlCode {
29    Termination(NamedRows),
30    Break(Option<SmartString<LazyCompact>>, SourceSpan),
31    Continue(Option<SmartString<LazyCompact>>, SourceSpan),
32}
33
34impl<'s, S: Storage<'s>> Db<S> {
35    fn execute_imperative_condition(
36        &'s self,
37        p: &ImperativeCondition,
38        tx: &mut SessionTx<'_>,
39        cleanups: &mut Vec<(Vec<u8>, Vec<u8>)>,
40        cur_vld: ValidityTs,
41        callback_targets: &BTreeSet<SmartString<LazyCompact>>,
42        callback_collector: &mut CallbackCollector,
43    ) -> Result<bool> {
44        let res = match p {
45            Left(rel) => {
46                let relation = tx.get_relation(rel, false)?;
47                relation.as_named_rows(tx)?
48            }
49            Right(p) => self.execute_single_program(
50                p.prog.clone(),
51                tx,
52                cleanups,
53                cur_vld,
54                callback_targets,
55                callback_collector,
56            )?,
57        };
58        if let Right(pg) = &p {
59            if let Some(store_as) = &pg.store_as {
60                tx.script_store_as_relation(self, store_as, &res, cur_vld)?;
61            }
62        }
63        Ok(!res.rows.is_empty())
64    }
65
66    fn execute_imperative_stmts(
67        &'s self,
68        ps: &ImperativeProgram,
69        tx: &mut SessionTx<'_>,
70        cleanups: &mut Vec<(Vec<u8>, Vec<u8>)>,
71        cur_vld: ValidityTs,
72        callback_targets: &BTreeSet<SmartString<LazyCompact>>,
73        callback_collector: &mut CallbackCollector,
74        poison: &Poison,
75        readonly: bool,
76    ) -> Result<Either<NamedRows, ControlCode>> {
77        let mut ret = NamedRows::default();
78        for p in ps {
79            poison.check()?;
80            match p {
81                ImperativeStmt::Break { target, span, .. } => {
82                    return Ok(Right(ControlCode::Break(target.clone(), *span)));
83                }
84                ImperativeStmt::Continue { target, span, .. } => {
85                    return Ok(Right(ControlCode::Continue(target.clone(), *span)));
86                }
87                ImperativeStmt::Return { returns } => {
88                    if returns.is_empty() {
89                        return Ok(Right(ControlCode::Termination(NamedRows::default())));
90                    }
91                    let mut current = None;
92                    for nxt in returns.iter().rev() {
93                        let mut nr = match nxt {
94                            Left(prog) => self.execute_single_program(
95                                prog.prog.clone(),
96                                tx,
97                                cleanups,
98                                cur_vld,
99                                callback_targets,
100                                callback_collector,
101                            )?,
102                            Right(rel) => {
103                                let relation = tx.get_relation(rel, false)?;
104                                relation.as_named_rows(tx)?
105                            }
106                        };
107                        if let Left(pg) = nxt {
108                            if let Some(store_as) = &pg.store_as {
109                                tx.script_store_as_relation(self, store_as, &nr, cur_vld)?;
110                            }
111                        }
112                        nr.next = current;
113                        current = Some(Box::new(nr))
114                    }
115                    return Ok(Right(ControlCode::Termination(*current.unwrap())));
116                }
117                ImperativeStmt::TempDebug { temp, .. } => {
118                    let relation = tx.get_relation(temp, false)?;
119                    println!("{}: {:?}", temp, relation.as_named_rows(tx)?);
120                    ret = NamedRows::default();
121                }
122                ImperativeStmt::SysOp { sysop, .. } => {
123                    ret = self.run_sys_op_with_tx(tx, &sysop.sysop, readonly, true)?;
124                    if let Some(store_as) = &sysop.store_as {
125                        tx.script_store_as_relation(self, store_as, &ret, cur_vld)?;
126                    }
127                }
128                ImperativeStmt::Program { prog, .. } => {
129                    ret = self.execute_single_program(
130                        prog.prog.clone(),
131                        tx,
132                        cleanups,
133                        cur_vld,
134                        callback_targets,
135                        callback_collector,
136                    )?;
137                    if let Some(store_as) = &prog.store_as {
138                        tx.script_store_as_relation(self, store_as, &ret, cur_vld)?;
139                    }
140                }
141                ImperativeStmt::IgnoreErrorProgram { prog, .. } => {
142                    match self.execute_single_program(
143                        prog.prog.clone(),
144                        tx,
145                        cleanups,
146                        cur_vld,
147                        callback_targets,
148                        callback_collector,
149                    ) {
150                        Ok(res) => {
151                            if let Some(store_as) = &prog.store_as {
152                                tx.script_store_as_relation(self, store_as, &res, cur_vld)?;
153                            }
154                            ret = res
155                        }
156                        Err(_) => {
157                            ret = NamedRows::new(
158                                vec!["status".to_string()],
159                                vec![vec![DataValue::from("FAILED")]],
160                            )
161                        }
162                    }
163                }
164                ImperativeStmt::If {
165                    condition,
166                    then_branch,
167                    else_branch,
168                    negated,
169                    ..
170                } => {
171                    let cond_val = self.execute_imperative_condition(
172                        condition,
173                        tx,
174                        cleanups,
175                        cur_vld,
176                        callback_targets,
177                        callback_collector,
178                    )?;
179                    let cond_val = if *negated { !cond_val } else { cond_val };
180                    let to_execute = if cond_val { then_branch } else { else_branch };
181                    match self.execute_imperative_stmts(
182                        to_execute,
183                        tx,
184                        cleanups,
185                        cur_vld,
186                        callback_targets,
187                        callback_collector,
188                        poison,
189                        readonly,
190                    )? {
191                        Left(rows) => {
192                            ret = rows;
193                        }
194                        Right(ctrl) => return Ok(Right(ctrl)),
195                    }
196                }
197                ImperativeStmt::Loop { label, body, .. } => {
198                    ret = Default::default();
199                    loop {
200                        poison.check()?;
201
202                        match self.execute_imperative_stmts(
203                            body,
204                            tx,
205                            cleanups,
206                            cur_vld,
207                            callback_targets,
208                            callback_collector,
209                            poison,
210                            readonly,
211                        )? {
212                            Left(_) => {}
213                            Right(ctrl) => match ctrl {
214                                ControlCode::Termination(ret) => {
215                                    return Ok(Right(ControlCode::Termination(ret)))
216                                }
217                                ControlCode::Break(break_label, span) => {
218                                    if break_label.is_none() || break_label == *label {
219                                        break;
220                                    } else {
221                                        return Ok(Right(ControlCode::Break(break_label, span)));
222                                    }
223                                }
224                                ControlCode::Continue(cont_label, span) => {
225                                    if cont_label.is_none() || cont_label == *label {
226                                        continue;
227                                    } else {
228                                        return Ok(Right(ControlCode::Continue(cont_label, span)));
229                                    }
230                                }
231                            },
232                        }
233                    }
234                }
235                ImperativeStmt::TempSwap { left, right, .. } => {
236                    tx.rename_temp_relation(
237                        Symbol::new(left.clone(), Default::default()),
238                        Symbol::new(SmartString::from("_*temp*"), Default::default()),
239                    )?;
240                    tx.rename_temp_relation(
241                        Symbol::new(right.clone(), Default::default()),
242                        Symbol::new(left.clone(), Default::default()),
243                    )?;
244                    tx.rename_temp_relation(
245                        Symbol::new(SmartString::from("_*temp*"), Default::default()),
246                        Symbol::new(right.clone(), Default::default()),
247                    )?;
248                    ret = NamedRows::default();
249                    break;
250                }
251            }
252        }
253        Ok(Left(ret))
254    }
255    pub(crate) fn execute_imperative(
256        &'s self,
257        cur_vld: ValidityTs,
258        ps: &ImperativeProgram,
259        readonly: bool,
260    ) -> Result<NamedRows, Report> {
261        let mut callback_collector = BTreeMap::new();
262        let mut write_lock_names = BTreeSet::new();
263        for p in ps {
264            p.needs_write_locks(&mut write_lock_names);
265        }
266        if readonly && !write_lock_names.is_empty() {
267            bail!("Read-only imperative program attempted to acquire write locks");
268        }
269        let is_write = !write_lock_names.is_empty();
270        let write_lock = self.obtain_relation_locks(write_lock_names.iter());
271        let _write_lock_guards = write_lock.iter().map(|l| l.read().unwrap()).collect_vec();
272
273        let callback_targets = if is_write {
274            self.current_callback_targets()
275        } else {
276            Default::default()
277        };
278        let mut cleanups: Vec<(Vec<u8>, Vec<u8>)> = vec![];
279        let ret;
280        {
281            let mut tx = if is_write {
282                self.transact_write()?
283            } else {
284                self.transact()?
285            };
286
287            let poison = Poison::default();
288            let qid = self.queries_count.fetch_add(1, Ordering::AcqRel);
289            let since_the_epoch = seconds_since_the_epoch()?;
290
291            let q_handle = RunningQueryHandle {
292                started_at: since_the_epoch,
293                poison: poison.clone(),
294            };
295            self.running_queries.lock().unwrap().insert(qid, q_handle);
296            let _guard = RunningQueryCleanup {
297                id: qid,
298                running_queries: self.running_queries.clone(),
299            };
300
301            match self.execute_imperative_stmts(
302                ps,
303                &mut tx,
304                &mut cleanups,
305                cur_vld,
306                &callback_targets,
307                &mut callback_collector,
308                &poison,
309                readonly,
310            )? {
311                Left(res) => ret = res,
312                Right(ctrl) => match ctrl {
313                    ControlCode::Termination(res) => {
314                        ret = res;
315                    }
316                    ControlCode::Break(_, span) | ControlCode::Continue(_, span) => {
317                        #[derive(Debug, Error, Diagnostic)]
318                        #[error("control flow has nowhere to go")]
319                        #[diagnostic(code(eval::dangling_ctrl_flow))]
320                        struct DanglingControlFlow(#[label] SourceSpan);
321
322                        bail!(DanglingControlFlow(span))
323                    }
324                },
325            }
326
327            for (lower, upper) in cleanups {
328                tx.store_tx.del_range_from_persisted(&lower, &upper)?;
329            }
330
331            tx.commit_tx()?;
332        }
333        #[cfg(not(target_arch = "wasm32"))]
334        if !callback_collector.is_empty() {
335            self.send_callbacks(callback_collector)
336        }
337
338        Ok(ret)
339    }
340}
341
342impl SessionTx<'_> {
343    fn script_store_as_relation<'s, S: Storage<'s>>(
344        &mut self,
345        db: &Db<S>,
346        name: &str,
347        rels: &NamedRows,
348        cur_vld: ValidityTs,
349    ) -> Result<()> {
350        let mut key_bindings = vec![];
351        for k in rels.headers.iter() {
352            let k = k.replace('(', "_").replace(')', "");
353            let k = Symbol::new(k.clone(), Default::default());
354            if key_bindings.contains(&k) {
355                bail!(
356                    "Duplicate variable name {}, please use distinct variables in `as` construct.",
357                    k
358                );
359            }
360            key_bindings.push(k);
361        }
362        let keys = key_bindings
363            .iter()
364            .map(|s| ColumnDef {
365                name: s.name.clone(),
366                typing: NullableColType {
367                    coltype: ColType::Any,
368                    nullable: true,
369                },
370                default_gen: None,
371            })
372            .collect_vec();
373
374        let meta = InputRelationHandle {
375            name: Symbol::new(name, Default::default()),
376            metadata: StoredRelationMetadata {
377                keys,
378                non_keys: vec![],
379            },
380            key_bindings,
381            dep_bindings: vec![],
382            span: Default::default(),
383        };
384        let headers = meta.key_bindings.clone();
385        self.execute_relation(
386            db,
387            rels.rows.iter().cloned(),
388            RelationOp::Replace,
389            &meta,
390            &headers,
391            cur_vld,
392            &Default::default(),
393            &mut Default::default(),
394            true,
395            "",
396        )?;
397        Ok(())
398    }
399}