spacetimedb_vm/
expr.rs

1use crate::errors::{ErrorKind, ErrorLang};
2use crate::operator::{OpCmp, OpLogic, OpQuery};
3use crate::relation::{MemTable, RelValue};
4use arrayvec::ArrayVec;
5use core::slice::from_ref;
6use derive_more::From;
7use itertools::Itertools;
8use smallvec::SmallVec;
9use spacetimedb_data_structures::map::{HashSet, IntMap};
10use spacetimedb_lib::db::auth::{StAccess, StTableType};
11use spacetimedb_lib::Identity;
12use spacetimedb_primitives::*;
13use spacetimedb_sats::satn::Satn;
14use spacetimedb_sats::{AlgebraicType, AlgebraicValue, ProductValue};
15use spacetimedb_schema::def::error::{AuthError, RelationError};
16use spacetimedb_schema::relation::{ColExpr, DbTable, FieldName, Header};
17use spacetimedb_schema::schema::TableSchema;
18use std::borrow::Cow;
19use std::cmp::Reverse;
20use std::collections::btree_map::Entry;
21use std::collections::BTreeMap;
22use std::ops::Bound;
23use std::sync::Arc;
24use std::{fmt, iter, mem};
25
/// Trait for checking whether the `caller` has access to `Self`.
pub trait AuthAccess {
    /// Checks whether `caller` may access `self`, where `owner` is the identity of the owner.
    ///
    /// Implementations are expected to return `Ok(())` when access is granted
    /// and an [`AuthError`] describing the refusal otherwise.
    fn check_auth(&self, owner: Identity, caller: Identity) -> Result<(), AuthError>;
}
30
/// An operand of a field-level expression:
/// either a reference to a named field or a constant value.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, From)]
pub enum FieldExpr {
    /// A reference to a field, identified by its [`FieldName`].
    Name(FieldName),
    /// A constant [`AlgebraicValue`].
    Value(AlgebraicValue),
}
36
37impl FieldExpr {
38    pub fn strip_table(self) -> ColExpr {
39        match self {
40            Self::Name(field) => ColExpr::Col(field.col),
41            Self::Value(value) => ColExpr::Value(value),
42        }
43    }
44
45    pub fn name_to_col(self, head: &Header) -> Result<ColExpr, RelationError> {
46        match self {
47            Self::Value(val) => Ok(ColExpr::Value(val)),
48            Self::Name(field) => head.column_pos_or_err(field).map(ColExpr::Col),
49        }
50    }
51}
52
53impl fmt::Display for FieldExpr {
54    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
55        match self {
56            FieldExpr::Name(x) => write!(f, "{x}"),
57            FieldExpr::Value(x) => write!(f, "{}", x.to_satn()),
58        }
59    }
60}
61
/// A field-level expression tree:
/// either a single operand or a binary operation over two sub-expressions.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, From)]
pub enum FieldOp {
    /// A leaf operand (a named field or a constant value).
    #[from]
    Field(FieldExpr),
    /// The binary operation `lhs op rhs`,
    /// where `op` is an [`OpQuery`], i.e., a comparison or a logical connective.
    Cmp {
        op: OpQuery,
        lhs: Box<FieldOp>,
        rhs: Box<FieldOp>,
    },
}
72
/// A flat list of [`FieldOp`]s, as returned by [`FieldOp::flatten_ands`].
///
/// Backed by a `SmallVec` with inline room for a single element.
type FieldOpFlat = SmallVec<[FieldOp; 1]>;
74
75impl FieldOp {
76    pub fn new(op: OpQuery, lhs: Self, rhs: Self) -> Self {
77        Self::Cmp {
78            op,
79            lhs: Box::new(lhs),
80            rhs: Box::new(rhs),
81        }
82    }
83
84    pub fn cmp(field: impl Into<FieldName>, op: OpCmp, value: impl Into<AlgebraicValue>) -> Self {
85        Self::new(
86            OpQuery::Cmp(op),
87            Self::Field(FieldExpr::Name(field.into())),
88            Self::Field(FieldExpr::Value(value.into())),
89        )
90    }
91
92    pub fn names_to_cols(self, head: &Header) -> Result<ColumnOp, RelationError> {
93        match self {
94            Self::Field(field) => field.name_to_col(head).map(ColumnOp::from),
95            Self::Cmp { op, lhs, rhs } => {
96                let lhs = lhs.names_to_cols(head)?;
97                let rhs = rhs.names_to_cols(head)?;
98                Ok(ColumnOp::new(op, lhs, rhs))
99            }
100        }
101    }
102
103    /// Flattens a nested conjunction of AND expressions.
104    ///
105    /// For example, `a = 1 AND b = 2 AND c = 3` becomes `[a = 1, b = 2, c = 3]`.
106    ///
107    /// This helps with splitting the kinds of `queries`,
108    /// that *could* be answered by a `index`,
109    /// from the ones that need to be executed with a `scan`.
110    pub fn flatten_ands(self) -> FieldOpFlat {
111        fn fill_vec(buf: &mut FieldOpFlat, op: FieldOp) {
112            match op {
113                FieldOp::Cmp {
114                    op: OpQuery::Logic(OpLogic::And),
115                    lhs,
116                    rhs,
117                } => {
118                    fill_vec(buf, *lhs);
119                    fill_vec(buf, *rhs);
120                }
121                op => buf.push(op),
122            }
123        }
124        let mut buf = SmallVec::new();
125        fill_vec(&mut buf, self);
126        buf
127    }
128}
129
130impl fmt::Display for FieldOp {
131    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
132        match self {
133            Self::Field(x) => {
134                write!(f, "{x}")
135            }
136            Self::Cmp { op, lhs, rhs } => {
137                write!(f, "{lhs} {op} {rhs}")
138            }
139        }
140    }
141}
142
/// A column-level expression that is evaluated against a single row.
///
/// Unlike [`FieldOp`], columns are referenced by [`ColId`] rather than by field name.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, From)]
pub enum ColumnOp {
    /// The value is the column at `to_index(col)` in the row, i.e., `row.read_column(to_index(col))`.
    #[from]
    Col(ColId),
    /// The value is the embedded value.
    #[from]
    Val(AlgebraicValue),
    /// The value is `eval_cmp(cmp, row.read_column(to_index(lhs)), rhs)`.
    /// This is an optimized version of `Cmp`, avoiding one depth of nesting.
    ColCmpVal {
        lhs: ColId,
        cmp: OpCmp,
        rhs: AlgebraicValue,
    },
    /// The value is `eval_cmp(cmp, eval(row, lhs), eval(row, rhs))`.
    Cmp {
        lhs: Box<ColumnOp>,
        cmp: OpCmp,
        rhs: Box<ColumnOp>,
    },
    /// Let `conds = eval(row, operands_i)`.
    /// For `op = OpLogic::And`, the value is `all(conds)`.
    /// For `op = OpLogic::Or`, the value is `any(conds)`.
    Log { op: OpLogic, operands: Box<[ColumnOp]> },
}
169
170impl ColumnOp {
171    pub fn new(op: OpQuery, lhs: Self, rhs: Self) -> Self {
172        match op {
173            OpQuery::Cmp(cmp) => match (lhs, rhs) {
174                (ColumnOp::Col(lhs), ColumnOp::Val(rhs)) => Self::cmp(lhs, cmp, rhs),
175                (lhs, rhs) => Self::Cmp {
176                    lhs: Box::new(lhs),
177                    cmp,
178                    rhs: Box::new(rhs),
179                },
180            },
181            OpQuery::Logic(op) => Self::Log {
182                op,
183                operands: [lhs, rhs].into(),
184            },
185        }
186    }
187
188    pub fn cmp(col: impl Into<ColId>, cmp: OpCmp, val: impl Into<AlgebraicValue>) -> Self {
189        let lhs = col.into();
190        let rhs = val.into();
191        Self::ColCmpVal { lhs, cmp, rhs }
192    }
193
194    /// Returns a new op where `lhs` and `rhs` are logically AND-ed together.
195    fn and(lhs: Self, rhs: Self) -> Self {
196        let ands = |operands| {
197            let op = OpLogic::And;
198            Self::Log { op, operands }
199        };
200
201        match (lhs, rhs) {
202            // Merge a pair of ⋀ into a single ⋀.
203            (
204                Self::Log {
205                    op: OpLogic::And,
206                    operands: lhs,
207                },
208                Self::Log {
209                    op: OpLogic::And,
210                    operands: rhs,
211                },
212            ) => {
213                let mut operands = Vec::from(lhs);
214                operands.append(&mut Vec::from(rhs));
215                ands(operands.into())
216            }
217            // Merge ⋀ with a single operand.
218            (
219                Self::Log {
220                    op: OpLogic::And,
221                    operands: lhs,
222                },
223                rhs,
224            ) => {
225                let mut operands = Vec::from(lhs);
226                operands.push(rhs);
227                ands(operands.into())
228            }
229            // And together lhs and rhs.
230            (lhs, rhs) => ands([lhs, rhs].into()),
231        }
232    }
233
234    /// Returns an op where `col_i op value_i` are all `AND`ed together.
235    fn and_cmp(op: OpCmp, cols: &ColList, value: AlgebraicValue) -> Self {
236        let cmp = |(col, value): (ColId, _)| Self::cmp(col, op, value);
237
238        // For singleton constraints, the `value` must be used directly.
239        if let Some(head) = cols.as_singleton() {
240            return cmp((head, value));
241        }
242
243        // Otherwise, pair column ids and product fields together.
244        let operands = cols.iter().zip(value.into_product().unwrap()).map(cmp).collect();
245        Self::Log {
246            op: OpLogic::And,
247            operands,
248        }
249    }
250
251    /// Returns an op where `cols` must be within bounds.
252    /// This handles both the case of single-col bounds and multi-col bounds.
253    fn from_op_col_bounds(cols: &ColList, bounds: (Bound<AlgebraicValue>, Bound<AlgebraicValue>)) -> Self {
254        let (cmp, value) = match bounds {
255            // Equality; field <= value && field >= value <=> field = value
256            (Bound::Included(a), Bound::Included(b)) if a == b => (OpCmp::Eq, a),
257            // Inclusive lower bound => field >= value
258            (Bound::Included(value), Bound::Unbounded) => (OpCmp::GtEq, value),
259            // Exclusive lower bound => field > value
260            (Bound::Excluded(value), Bound::Unbounded) => (OpCmp::Gt, value),
261            // Inclusive upper bound => field <= value
262            (Bound::Unbounded, Bound::Included(value)) => (OpCmp::LtEq, value),
263            // Exclusive upper bound => field < value
264            (Bound::Unbounded, Bound::Excluded(value)) => (OpCmp::Lt, value),
265            (Bound::Unbounded, Bound::Unbounded) => unreachable!(),
266            (lower_bound, upper_bound) => {
267                let lhs = Self::from_op_col_bounds(cols, (lower_bound, Bound::Unbounded));
268                let rhs = Self::from_op_col_bounds(cols, (Bound::Unbounded, upper_bound));
269                return ColumnOp::and(lhs, rhs);
270            }
271        };
272        ColumnOp::and_cmp(cmp, cols, value)
273    }
274
275    /// Converts `self` to the lhs `ColId` and the `OpCmp` if this is a comparison.
276    fn as_col_cmp(&self) -> Option<(ColId, OpCmp)> {
277        match self {
278            Self::ColCmpVal { lhs, cmp, rhs: _ } => Some((*lhs, *cmp)),
279            Self::Cmp { lhs, cmp, rhs: _ } => match &**lhs {
280                ColumnOp::Col(col) => Some((*col, *cmp)),
281                _ => None,
282            },
283            _ => None,
284        }
285    }
286
287    /// Evaluate `self` where `ColId`s are translated to values by indexing into `row`.
288    fn eval<'a>(&'a self, row: &'a RelValue<'_>) -> Cow<'a, AlgebraicValue> {
289        let into = |b| Cow::Owned(AlgebraicValue::Bool(b));
290
291        match self {
292            Self::Col(col) => row.read_column(col.idx()).unwrap(),
293            Self::Val(val) => Cow::Borrowed(val),
294            Self::ColCmpVal { lhs, cmp, rhs } => into(Self::eval_cmp_col_val(row, *cmp, *lhs, rhs)),
295            Self::Cmp { lhs, cmp, rhs } => into(Self::eval_cmp(row, *cmp, lhs, rhs)),
296            Self::Log { op, operands } => into(Self::eval_log(row, *op, operands)),
297        }
298    }
299
300    /// Evaluate `self` to a `bool` where `ColId`s are translated to values by indexing into `row`.
301    pub fn eval_bool(&self, row: &RelValue<'_>) -> bool {
302        match self {
303            Self::Col(col) => *row.read_column(col.idx()).unwrap().as_bool().unwrap(),
304            Self::Val(val) => *val.as_bool().unwrap(),
305            Self::ColCmpVal { lhs, cmp, rhs } => Self::eval_cmp_col_val(row, *cmp, *lhs, rhs),
306            Self::Cmp { lhs, cmp, rhs } => Self::eval_cmp(row, *cmp, lhs, rhs),
307            Self::Log { op, operands } => Self::eval_log(row, *op, operands),
308        }
309    }
310
311    /// Evaluates `lhs cmp rhs` according to `Ord for AlgebraicValue`.
312    fn eval_op_cmp(cmp: OpCmp, lhs: &AlgebraicValue, rhs: &AlgebraicValue) -> bool {
313        match cmp {
314            OpCmp::Eq => lhs == rhs,
315            OpCmp::NotEq => lhs != rhs,
316            OpCmp::Lt => lhs < rhs,
317            OpCmp::LtEq => lhs <= rhs,
318            OpCmp::Gt => lhs > rhs,
319            OpCmp::GtEq => lhs >= rhs,
320        }
321    }
322
323    /// Evaluates `lhs` to an [`AlgebraicValue`] and runs the comparison `lhs_av op rhs`.
324    fn eval_cmp_col_val(row: &RelValue<'_>, cmp: OpCmp, lhs: ColId, rhs: &AlgebraicValue) -> bool {
325        let lhs = row.read_column(lhs.idx()).unwrap();
326        Self::eval_op_cmp(cmp, &lhs, rhs)
327    }
328
329    /// Evaluates `lhs` and `rhs` to [`AlgebraicValue`]s
330    /// and then runs the comparison `cmp` on them,
331    /// returning the final `bool` result.
332    fn eval_cmp(row: &RelValue<'_>, cmp: OpCmp, lhs: &Self, rhs: &Self) -> bool {
333        let lhs = lhs.eval(row);
334        let rhs = rhs.eval(row);
335        Self::eval_op_cmp(cmp, &lhs, &rhs)
336    }
337
338    /// Evaluates if
339    /// - `op = OpLogic::And` the conjunctions (`⋀`) of `opers`
340    /// - `op = OpLogic::Or` the disjunctions (`⋁`) of `opers`
341    fn eval_log(row: &RelValue<'_>, op: OpLogic, opers: &[ColumnOp]) -> bool {
342        match op {
343            OpLogic::And => opers.iter().all(|o| o.eval_bool(row)),
344            OpLogic::Or => opers.iter().any(|o| o.eval_bool(row)),
345        }
346    }
347}
348
349impl fmt::Display for ColumnOp {
350    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
351        match self {
352            Self::Col(col) => write!(f, "{col}"),
353            Self::Val(val) => write!(f, "{}", val.to_satn()),
354            Self::ColCmpVal { lhs, cmp, rhs } => write!(f, "{lhs} {cmp} {}", rhs.to_satn()),
355            Self::Cmp { cmp, lhs, rhs } => write!(f, "{lhs} {cmp} {rhs}"),
356            Self::Log { op, operands } => write!(f, "{}", operands.iter().format((*op).into())),
357        }
358    }
359}
360
361impl From<ColExpr> for ColumnOp {
362    fn from(ce: ColExpr) -> Self {
363        match ce {
364            ColExpr::Col(c) => c.into(),
365            ColExpr::Value(v) => v.into(),
366        }
367    }
368}
369
370impl From<Query> for Option<ColumnOp> {
371    fn from(value: Query) -> Self {
372        match value {
373            Query::IndexScan(op) => Some(ColumnOp::from_op_col_bounds(&op.columns, op.bounds)),
374            Query::Select(op) => Some(op),
375            _ => None,
376        }
377    }
378}
379
/// An identifier for a data source (i.e. a table) in a query plan.
///
/// When compiling a query plan, rather than embedding the inputs in the plan,
/// we annotate each input with a `SourceId`, and the compiled plan refers to its inputs by id.
/// This allows the plan to be re-used with distinct inputs,
/// assuming the inputs obey the same schema.
///
/// Note that re-using a query plan is only a good idea
/// if the new inputs are similar to those used for compilation
/// in terms of cardinality and distribution.
///
/// The wrapped `usize` is what a [`SourceSet`] uses as the position of the entry in its storage.
#[derive(Debug, Copy, Clone, PartialEq, Eq, From, Hash)]
pub struct SourceId(pub usize);
392
/// Types that relate [`SourceId`]s to their in-memory tables.
///
/// Rather than embedding tables in query plans, we store a [`SourceExpr::InMemory`],
/// which contains the information necessary for optimization along with a `SourceId`.
/// Query execution then executes the plan, and when it encounters a `SourceExpr::InMemory`,
/// retrieves the `Self::Source` table from the corresponding provider.
/// This allows query plans to be re-used, though each execution might require a new provider.
///
/// An in-memory table `Self::Source` is a type capable of producing [`RelValue<'a>`]s.
/// The general form of this is `Iterator<Item = RelValue<'a>>`.
/// Depending on the situation, this could be e.g.,
/// - [`MemTable`], producing [`RelValue::Projection`],
/// - `&'a [ProductValue]` producing [`RelValue::ProjRef`].
pub trait SourceProvider<'a> {
    /// The type of in-memory tables that this provider uses.
    type Source: 'a + IntoIterator<Item = RelValue<'a>>;

    /// Retrieve the `Self::Source` associated with `id`, if any.
    ///
    /// The source is handed over by value, so
    /// taking the same `id` a second time may or may not yield the same source.
    /// Callers should not assume that a generic provider will yield it more than once.
    /// This means that a query plan may not include multiple references to the same [`SourceId`].
    ///
    /// Implementations are also not obligated to inspect `id`, e.g., if there's only one option.
    fn take_source(&mut self, id: SourceId) -> Option<Self::Source>;
}
419
/// Any `FnMut(SourceId) -> Option<I>` closure can act as a [`SourceProvider`]:
/// taking a source simply calls the closure with the requested id.
impl<'a, I: 'a + IntoIterator<Item = RelValue<'a>>, F: FnMut(SourceId) -> Option<I>> SourceProvider<'a> for F {
    type Source = I;
    fn take_source(&mut self, id: SourceId) -> Option<Self::Source> {
        self(id)
    }
}
426
/// An `Option<I>` acts as a [`SourceProvider`] holding at most one source.
///
/// The requested id is ignored; the source is moved out on the first call,
/// leaving `None` behind for any later call.
impl<'a, I: 'a + IntoIterator<Item = RelValue<'a>>> SourceProvider<'a> for Option<I> {
    type Source = I;
    fn take_source(&mut self, _: SourceId) -> Option<Self::Source> {
        self.take()
    }
}
433
/// A [`SourceProvider`] with no in-memory sources at all;
/// its `take_source` returns `None` for every id.
pub struct NoInMemUsed;
435
impl<'a> SourceProvider<'a> for NoInMemUsed {
    /// An empty iterator type; no value of it is ever handed out.
    type Source = iter::Empty<RelValue<'a>>;
    /// Always returns `None`, regardless of the requested id.
    fn take_source(&mut self, _: SourceId) -> Option<Self::Source> {
        None
    }
}
442
/// A [`SourceProvider`] backed by an `ArrayVec`.
///
/// Internally, the `SourceSet` stores an `Option<T>` for each planned [`SourceId`]
/// which are [`Option::take`]n out of the set.
///
/// The const parameter `N` is the capacity of the backing `ArrayVec`.
#[derive(Debug, PartialEq, Eq, Clone)]
#[repr(transparent)]
pub struct SourceSet<T, const N: usize>(
    // Benchmarks showed an improvement in performance
    // on incr-select by ~10% by not using `Vec<Option<T>>`.
    ArrayVec<Option<T>, N>,
);
454
impl<'a, T: 'a + IntoIterator<Item = RelValue<'a>>, const N: usize> SourceProvider<'a> for SourceSet<T, N> {
    type Source = T;
    /// Delegates to the inherent `take` method,
    /// so a given `id` yields its source at most once.
    fn take_source(&mut self, id: SourceId) -> Option<T> {
        self.take(id)
    }
}
461
462impl<T, const N: usize> From<[T; N]> for SourceSet<T, N> {
463    #[inline]
464    fn from(sources: [T; N]) -> Self {
465        Self(sources.map(Some).into())
466    }
467}
468
469impl<T, const N: usize> SourceSet<T, N> {
470    /// Returns an empty source set.
471    pub fn empty() -> Self {
472        Self(ArrayVec::new())
473    }
474
475    /// Get a fresh `SourceId` which can be used as the id for a new entry.
476    fn next_id(&self) -> SourceId {
477        SourceId(self.0.len())
478    }
479
480    /// Insert an entry into this `SourceSet` so it can be used in a query plan,
481    /// and return a [`SourceId`] which can be embedded in that plan.
482    pub fn add(&mut self, table: T) -> SourceId {
483        let source_id = self.next_id();
484        self.0.push(Some(table));
485        source_id
486    }
487
488    /// Extract the entry referred to by `id` from this `SourceSet`,
489    /// leaving a "gap" in its place.
490    ///
491    /// Subsequent calls to `take` on the same `id` will return `None`.
492    pub fn take(&mut self, id: SourceId) -> Option<T> {
493        self.0.get_mut(id.0).map(mem::take).unwrap_or_default()
494    }
495
496    /// Returns the number of slots for [`MemTable`]s in this set.
497    ///
498    /// Calling `self.take_mem_table(...)` or `self.take_table(...)` won't affect this number.
499    pub fn len(&self) -> usize {
500        self.0.len()
501    }
502
503    /// Returns whether this set has any slots for [`MemTable`]s.
504    ///
505    /// Calling `self.take_mem_table(...)` or `self.take_table(...)` won't affect whether the set is empty.
506    pub fn is_empty(&self) -> bool {
507        self.0.is_empty()
508    }
509}
510
511impl<T, const N: usize> std::ops::Index<SourceId> for SourceSet<T, N> {
512    type Output = Option<T>;
513
514    fn index(&self, idx: SourceId) -> &Self::Output {
515        &self.0[idx.0]
516    }
517}
518
519impl<T, const N: usize> std::ops::IndexMut<SourceId> for SourceSet<T, N> {
520    fn index_mut(&mut self, idx: SourceId) -> &mut Self::Output {
521        &mut self.0[idx.0]
522    }
523}
524
525impl<const N: usize> SourceSet<Vec<ProductValue>, N> {
526    /// Insert a [`MemTable`] into this `SourceSet` so it can be used in a query plan,
527    /// and return a [`SourceExpr`] which can be embedded in that plan.
528    pub fn add_mem_table(&mut self, table: MemTable) -> SourceExpr {
529        let id = self.add(table.data);
530        SourceExpr::from_mem_table(table.head, table.table_access, id)
531    }
532}
533
/// A reference to a table within a query plan,
/// used as the source for selections, scans, filters and joins.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub enum SourceExpr {
    /// A plan for a "virtual" or projected table.
    ///
    /// The actual in-memory table, e.g., [`MemTable`] or `&'a [ProductValue]`
    /// is not stored within the query plan;
    /// rather, the `source_id` is an index which corresponds to the table in e.g., a [`SourceSet`].
    ///
    /// This allows query plans to be reused by supplying e.g., a new [`SourceSet`].
    InMemory {
        /// Identifies the table's rows within the provider used at execution time.
        source_id: SourceId,
        /// The header (schema information) of the table.
        header: Arc<Header>,
        /// The type of the table.
        table_type: StTableType,
        /// The access level of the table.
        table_access: StAccess,
    },
    /// A plan for a database table. Because [`DbTable`] is small and efficiently cloneable,
    /// no indirection into a [`SourceSet`] is required.
    DbTable(DbTable),
}
555
556impl SourceExpr {
557    /// If `self` refers to a [`MemTable`], returns the [`SourceId`] for its location in the plan's [`SourceSet`].
558    ///
559    /// Returns `None` if `self` refers to a [`DbTable`], as [`DbTable`]s are stored directly in the `SourceExpr`,
560    /// rather than indirected through the [`SourceSet`].
561    pub fn source_id(&self) -> Option<SourceId> {
562        if let SourceExpr::InMemory { source_id, .. } = self {
563            Some(*source_id)
564        } else {
565            None
566        }
567    }
568
569    pub fn table_name(&self) -> &str {
570        &self.head().table_name
571    }
572
573    pub fn table_type(&self) -> StTableType {
574        match self {
575            SourceExpr::InMemory { table_type, .. } => *table_type,
576            SourceExpr::DbTable(db_table) => db_table.table_type,
577        }
578    }
579
580    pub fn table_access(&self) -> StAccess {
581        match self {
582            SourceExpr::InMemory { table_access, .. } => *table_access,
583            SourceExpr::DbTable(db_table) => db_table.table_access,
584        }
585    }
586
587    pub fn head(&self) -> &Arc<Header> {
588        match self {
589            SourceExpr::InMemory { header, .. } => header,
590            SourceExpr::DbTable(db_table) => &db_table.head,
591        }
592    }
593
594    pub fn is_mem_table(&self) -> bool {
595        matches!(self, SourceExpr::InMemory { .. })
596    }
597
598    pub fn is_db_table(&self) -> bool {
599        matches!(self, SourceExpr::DbTable(_))
600    }
601
602    pub fn from_mem_table(header: Arc<Header>, table_access: StAccess, id: SourceId) -> Self {
603        SourceExpr::InMemory {
604            source_id: id,
605            header,
606            table_type: StTableType::User,
607            table_access,
608        }
609    }
610
611    pub fn table_id(&self) -> Option<TableId> {
612        if let SourceExpr::DbTable(db_table) = self {
613            Some(db_table.table_id)
614        } else {
615            None
616        }
617    }
618
619    /// If `self` refers to a [`DbTable`], get a reference to it.
620    ///
621    /// Returns `None` if `self` refers to a [`MemTable`].
622    /// In that case, retrieving the [`MemTable`] requires inspecting the plan's corresponding [`SourceSet`]
623    /// via [`SourceSet::take_mem_table`] or [`SourceSet::take_table`].
624    pub fn get_db_table(&self) -> Option<&DbTable> {
625        if let SourceExpr::DbTable(db_table) = self {
626            Some(db_table)
627        } else {
628            None
629        }
630    }
631}
632
633impl From<&TableSchema> for SourceExpr {
634    fn from(value: &TableSchema) -> Self {
635        SourceExpr::DbTable(value.into())
636    }
637}
638
/// A descriptor for an index semi join operation.
///
/// The semantics are those of a semijoin with rows from the index or the probe side being returned.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct IndexJoin {
    /// The query producing the rows that are used to probe the index.
    pub probe_side: QueryExpr,
    /// The join column on the probe side.
    pub probe_col: ColId,
    /// The table that is looked up through its index.
    pub index_side: SourceExpr,
    /// An optional predicate on the rows of the `index_side`.
    pub index_select: Option<ColumnOp>,
    /// The join column on the index side.
    pub index_col: ColId,
    /// If true, returns rows from the `index_side`.
    /// Otherwise, returns rows from the `probe_side`.
    pub return_index_rows: bool,
}
653
654impl From<IndexJoin> for QueryExpr {
655    fn from(join: IndexJoin) -> Self {
656        let source: SourceExpr = if join.return_index_rows {
657            join.index_side.clone()
658        } else {
659            join.probe_side.source.clone()
660        };
661        QueryExpr {
662            source,
663            query: vec![Query::IndexJoin(join)],
664        }
665    }
666}
667
668impl IndexJoin {
669    // Reorder the index and probe sides of an index join.
670    // This is necessary if the indexed table has been replaced by a delta table.
671    // A delta table is a virtual table consisting of changes or updates to a physical table.
672    pub fn reorder(self, row_count: impl Fn(TableId, &str) -> i64) -> Self {
673        // The probe table must be a physical table.
674        if self.probe_side.source.is_mem_table() {
675            return self;
676        }
677        // It must have an index defined on the join field.
678        if !self
679            .probe_side
680            .source
681            .head()
682            .has_constraint(self.probe_col, Constraints::indexed())
683        {
684            return self;
685        }
686        // It must be a linear pipeline of selections.
687        if !self
688            .probe_side
689            .query
690            .iter()
691            .all(|op| matches!(op, Query::Select(_) | Query::IndexScan(_)))
692        {
693            return self;
694        }
695        match self.index_side.get_db_table() {
696            // If the size of the indexed table is sufficiently large,
697            // do not reorder.
698            //
699            // TODO: This determination is quite arbitrary.
700            // Ultimately we should be using cardinality estimation.
701            Some(DbTable { head, table_id, .. }) if row_count(*table_id, &head.table_name) > 500 => self,
702            // If this is a delta table, we must reorder.
703            // If this is a sufficiently small physical table, we should reorder.
704            _ => {
705                // Merge all selections from the original probe side into a single predicate.
706                // This includes an index scan if present.
707                let predicate = self
708                    .probe_side
709                    .query
710                    .into_iter()
711                    .filter_map(<Query as Into<Option<ColumnOp>>>::into)
712                    .reduce(ColumnOp::and);
713                // Push any selections on the index side to the probe side.
714                let probe_side = if let Some(predicate) = self.index_select {
715                    QueryExpr {
716                        source: self.index_side,
717                        query: vec![predicate.into()],
718                    }
719                } else {
720                    self.index_side.into()
721                };
722                IndexJoin {
723                    // The new probe side consists of the updated rows.
724                    // Plus any selections from the original index probe.
725                    probe_side,
726                    // The new probe field is the previous index field.
727                    probe_col: self.index_col,
728                    // The original probe table is now the table that is being probed.
729                    index_side: self.probe_side.source,
730                    // Any selections from the original probe side are pulled above the index lookup.
731                    index_select: predicate,
732                    // The new index field is the previous probe field.
733                    index_col: self.probe_col,
734                    // Because we have swapped the original index and probe sides of the join,
735                    // the new index join needs to return rows from the opposite side.
736                    return_index_rows: !self.return_index_rows,
737                }
738            }
739        }
740    }
741
742    // Convert this index join to an inner join, followed by a projection.
743    // This is needed for incremental evaluation of index joins.
744    // In particular when there are updates to both the left and right tables.
745    // In other words, when an index join has two delta tables.
746    pub fn to_inner_join(self) -> QueryExpr {
747        if self.return_index_rows {
748            let (col_lhs, col_rhs) = (self.index_col, self.probe_col);
749            let rhs = self.probe_side;
750
751            let source = self.index_side;
752            let inner_join = Query::JoinInner(JoinExpr::new(rhs, col_lhs, col_rhs, None));
753            let query = if let Some(predicate) = self.index_select {
754                vec![predicate.into(), inner_join]
755            } else {
756                vec![inner_join]
757            };
758            QueryExpr { source, query }
759        } else {
760            let (col_lhs, col_rhs) = (self.probe_col, self.index_col);
761            let mut rhs: QueryExpr = self.index_side.into();
762
763            if let Some(predicate) = self.index_select {
764                rhs.query.push(predicate.into());
765            }
766
767            let source = self.probe_side.source;
768            let inner_join = Query::JoinInner(JoinExpr::new(rhs, col_lhs, col_rhs, None));
769            let query = vec![inner_join];
770            QueryExpr { source, query }
771        }
772    }
773}
774
/// A join of the current relation (the left-hand side) with the relation produced by `rhs`,
/// matching `col_lhs` against `col_rhs`.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct JoinExpr {
    /// The query producing the right-hand side of the join.
    pub rhs: QueryExpr,
    /// The join column of the left-hand side.
    pub col_lhs: ColId,
    /// The join column of the right-hand side.
    pub col_rhs: ColId,
    /// If None, this is a left semi-join, returning rows only from the source table,
    /// using the `rhs` as a filter.
    ///
    /// If Some(_), this is an inner join, returning the concatenation of the matching rows.
    pub inner: Option<Arc<Header>>,
}
786
impl JoinExpr {
    /// Creates a join descriptor from its parts; each argument is stored in the field of the same name.
    ///
    /// Pass `None` for `inner` to get a left semi-join, or `Some(header)` for an inner join.
    pub fn new(rhs: QueryExpr, col_lhs: ColId, col_rhs: ColId, inner: Option<Arc<Header>>) -> Self {
        Self {
            rhs,
            col_lhs,
            col_rhs,
            inner,
        }
    }
}
797
/// The kind of database object that a [`Crud::Create`] or [`Crud::Drop`] refers to.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum DbType {
    Table,
    Index,
    Sequence,
    Constraint,
}
805
/// The kind of an operation, without any of its operands.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum Crud {
    Query,
    Insert,
    Update,
    Delete,
    /// Creation of a database object of the given [`DbType`].
    Create(DbType),
    /// Removal of a database object of the given [`DbType`].
    Drop(DbType),
    Config,
}
816
/// An operation together with its operands.
///
/// Only `Query` and `ReadVar` count as reads (see `CrudExpr::is_reads`).
#[derive(Debug, Eq, PartialEq)]
pub enum CrudExpr {
    /// Runs a query.
    Query(QueryExpr),
    /// Inserts `rows` into `table`.
    Insert {
        table: DbTable,
        rows: Vec<ProductValue>,
    },
    /// Updates rows.
    Update {
        // NOTE(review): presumably the query selecting the rows to be replaced — confirm against the executor.
        delete: QueryExpr,
        /// New values, keyed by the column they are assigned to.
        assignments: IntMap<ColId, ColExpr>,
    },
    /// Deletes the rows produced by `query`.
    Delete {
        query: QueryExpr,
    },
    /// Sets the variable `name` from the textual `literal`.
    SetVar {
        name: String,
        literal: String,
    },
    /// Reads the variable `name`.
    ReadVar {
        name: String,
    },
}
839
840impl CrudExpr {
841    pub fn optimize(self, row_count: &impl Fn(TableId, &str) -> i64) -> Self {
842        match self {
843            CrudExpr::Query(x) => CrudExpr::Query(x.optimize(row_count)),
844            _ => self,
845        }
846    }
847
848    pub fn is_reads<'a>(exprs: impl IntoIterator<Item = &'a CrudExpr>) -> bool {
849        exprs
850            .into_iter()
851            .all(|expr| matches!(expr, CrudExpr::Query(_) | CrudExpr::ReadVar { .. }))
852    }
853}
854
/// A scan of `table` through an index on `columns`, restricted to `bounds`.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct IndexScan {
    /// The table being scanned.
    pub table: DbTable,
    /// The columns of the index used for the scan.
    pub columns: ColList,
    /// The `(lower, upper)` bounds of the scanned range.
    pub bounds: (Bound<AlgebraicValue>, Bound<AlgebraicValue>),
}
861
862impl IndexScan {
863    /// Returns whether this is a point range.
864    pub fn is_point(&self) -> bool {
865        match &self.bounds {
866            (Bound::Included(lower), Bound::Included(upper)) => lower == upper,
867            _ => false,
868        }
869    }
870}
871
/// A projection operation in a query.
#[derive(Debug, Clone, Eq, PartialEq, From, Hash)]
pub struct ProjectExpr {
    /// The column expressions to project.
    pub cols: Vec<ColExpr>,
    /// The table id for a qualified wildcard project, if any.
    /// If present, further optimizations are possible.
    pub wildcard_table: Option<TableId>,
    /// The header of the relation after the projection;
    /// this is what [`QueryExpr::head`] reports for a `Project`.
    pub header_after: Arc<Header>,
}
881
/// An individual operation in a query.
#[derive(Debug, Clone, Eq, PartialEq, From, Hash)]
pub enum Query {
    /// Fetching rows via an index.
    IndexScan(IndexScan),
    /// Joining rows via an index.
    /// Equivalent to Index Nested Loop Join.
    IndexJoin(IndexJoin),
    /// A filter over an intermediate relation.
    /// In particular it does not utilize any indexes.
    /// If it could it would have already been transformed into an IndexScan.
    Select(ColumnOp),
    /// Projects a set of columns.
    Project(ProjectExpr),
    /// A join of two relations (base or intermediate) based on equality.
    /// Equivalent to a Nested Loop Join.
    /// Its operands may use indexes but the join itself does not.
    JoinInner(JoinExpr),
}
901
902impl Query {
903    /// Iterate over all [`SourceExpr`]s involved in the [`Query`].
904    ///
905    /// Sources are yielded from left to right. Duplicates are not filtered out.
906    pub fn walk_sources<E>(&self, on_source: &mut impl FnMut(&SourceExpr) -> Result<(), E>) -> Result<(), E> {
907        match self {
908            Self::Select(..) | Self::Project(..) => Ok(()),
909            Self::IndexScan(scan) => on_source(&SourceExpr::DbTable(scan.table.clone())),
910            Self::IndexJoin(join) => join.probe_side.walk_sources(on_source),
911            Self::JoinInner(join) => join.rhs.walk_sources(on_source),
912        }
913    }
914}
915
/// IndexArgument represents an equality or range predicate that can be answered
/// using an index.
///
/// In every variant, `columns` borrows the column list of the index
/// and `value` is the value compared against.
#[derive(Debug, PartialEq, Clone)]
enum IndexArgument<'a> {
    /// `columns = value`.
    Eq {
        columns: &'a ColList,
        value: AlgebraicValue,
    },
    /// `columns > value`, or `columns >= value` when `inclusive`.
    LowerBound {
        columns: &'a ColList,
        value: AlgebraicValue,
        inclusive: bool,
    },
    /// `columns < value`, or `columns <= value` when `inclusive`.
    UpperBound {
        columns: &'a ColList,
        value: AlgebraicValue,
        inclusive: bool,
    },
}
935
/// The outcome of index selection for a single constraint:
/// either it can be answered by an index, or it must be evaluated by a scan.
#[derive(Debug, PartialEq, Clone)]
enum IndexColumnOp<'a> {
    /// The constraint is served by an index, as described by the [`IndexArgument`].
    Index(IndexArgument<'a>),
    /// The constraint is not served by an index; the borrowed [`ColumnOp`] must be evaluated as a filter.
    Scan(&'a ColumnOp),
}
941
942fn make_index_arg(cmp: OpCmp, columns: &ColList, value: AlgebraicValue) -> IndexColumnOp<'_> {
943    let arg = match cmp {
944        OpCmp::Eq => IndexArgument::Eq { columns, value },
945        OpCmp::NotEq => unreachable!("No IndexArgument for NotEq, caller should've filtered out"),
946        // a < 5 => exclusive upper bound
947        OpCmp::Lt => IndexArgument::UpperBound {
948            columns,
949            value,
950            inclusive: false,
951        },
952        // a > 5 => exclusive lower bound
953        OpCmp::Gt => IndexArgument::LowerBound {
954            columns,
955            value,
956            inclusive: false,
957        },
958        // a <= 5 => inclusive upper bound
959        OpCmp::LtEq => IndexArgument::UpperBound {
960            columns,
961            value,
962            inclusive: true,
963        },
964        // a >= 5 => inclusive lower bound
965        OpCmp::GtEq => IndexArgument::LowerBound {
966            columns,
967            value,
968            inclusive: true,
969        },
970    };
971    IndexColumnOp::Index(arg)
972}
973
/// A single `col cmp value` constraint extracted from a [`ColumnOp`] tree,
/// as collected by `extract_cols`.
#[derive(Debug)]
struct ColValue<'a> {
    /// The whole [`ColumnOp`] this constraint was extracted from.
    /// Used to fall back to a scan when no index serves the constraint.
    parent: &'a ColumnOp,
    /// The constrained column.
    col: ColId,
    /// The comparison operator.
    cmp: OpCmp,
    /// The value the column is compared against.
    value: &'a AlgebraicValue,
}
981
982impl<'a> ColValue<'a> {
983    pub fn new(parent: &'a ColumnOp, col: ColId, cmp: OpCmp, value: &'a AlgebraicValue) -> Self {
984        Self {
985            parent,
986            col,
987            cmp,
988            value,
989        }
990    }
991}
992
/// The list of operations produced by `select_best_index`; stores one element inline.
type IndexColumnOpSink<'a> = SmallVec<[IndexColumnOp<'a>; 1]>;
/// The set of `(column, operator)` pairs that were found to be served by an index.
type ColsIndexed = HashSet<(ColId, OpCmp)>;
995
/// Pick the best indices that can serve the constraints in `op`
/// where the indices are taken from `header`.
///
/// This function is designed to handle complex scenarios when selecting the optimal index for a query.
/// The scenarios include:
///
/// - Combinations of multi- and single-column indexes that could refer to the same column.
///   For example, the table could have indexes `[a]` and `[a, b]`
///   and a user could query for `WHERE a = 1 AND b = 2 AND a = 3`.
///
/// - Query constraints can be supplied in any order;
///   i.e., both `WHERE a = 1 AND b = 2`
///   and `WHERE b = 2 AND a = 1` are valid.
///
/// - Queries against multi-col indices must use `=`, for now, in their constraints.
///   Otherwise, the index cannot be used.
///
/// - The use of multiple tables could generate redundant/duplicate operations like
///   `[IndexColumnOp::Index(a = 1), IndexColumnOp::Index(a = 1), IndexColumnOp::Scan(a = 1)]`.
///   This *cannot* be handled here.
///
/// # Returns
///
/// - A list of `IndexColumnOp`s representing the selected `index` OR `scan` operations.
///
/// - Additionally, `cols_indexed` is extended with every `(ColId, OpCmp)` pair,
///   i.e., the columns and operators, that can be served by an index.
///
///   This is required to remove the redundant operation on e.g.,
///   `[IndexColumnOp::Index(a = 1), IndexColumnOp::Index(a = 1), IndexColumnOp::Scan(a = 1)]`,
///   that could be generated by calling this function several times by using multiple `JOINS`.
///
/// # Example
///
/// If we have a table with `indexes`: `[a], [b], [b, c]` and then try to
/// optimize `WHERE a = 1 AND d > 2 AND c = 2 AND b = 1` we should return
///
/// - `IndexColumnOp::Index([b, c] = [1, 2])`
/// - `IndexColumnOp::Index(a = 1)`
/// - `IndexColumnOp::Scan(d > 2)`
///
/// # Note
///
/// NOTE: For a query like `SELECT * FROM students WHERE age > 18 AND height < 180`
/// we cannot serve this with a single `IndexScan`,
/// but rather, `select_best_index`
/// would give us two separate `IndexScan`s.
/// However, the upper layers of `QueryExpr` building will convert both of those into `Select`s.
fn select_best_index<'a>(
    cols_indexed: &mut ColsIndexed,
    header: &'a Header,
    op: &'a ColumnOp,
) -> IndexColumnOpSink<'a> {
    // Collect and sort indices by their lengths, with longest first.
    // We do this so that multi-col indices are used first, as they are more efficient.
    // The sort is unstable, so the relative order of indices of equal length is unspecified.
    // TODO(Centril): This could be computed when `Header` is constructed.
    let mut indices = header
        .constraints
        .iter()
        .filter(|(_, c)| c.has_indexed())
        .map(|(cl, _)| cl)
        .collect::<SmallVec<[_; 1]>>();
    indices.sort_unstable_by_key(|cl| Reverse(cl.len()));

    let mut found: IndexColumnOpSink = IndexColumnOpSink::default();

    // Collect fields into a multi-map `(col_id, cmp) -> [col value]`.
    // This gives us `log(N)` seek + deletion.
    // Operations that cannot be expressed as `col cmp value` are pushed to `found` as scans right away.
    // TODO(Centril): Consider https://docs.rs/small-map/0.1.3/small_map/enum.SmallMap.html
    let mut col_map = BTreeMap::<_, SmallVec<[_; 1]>>::new();
    extract_cols(op, &mut col_map, &mut found);

    // Go through each index,
    // consuming all column constraints that can be served by an index.
    for col_list in indices {
        // (1) No columns left? We're done.
        if col_map.is_empty() {
            break;
        }

        if let Some(head) = col_list.as_singleton() {
            // Go through each operator.
            // NOTE: We do not consider `OpCmp::NotEq` at the moment
            // since those are typically not answered using an index.
            for cmp in [OpCmp::Eq, OpCmp::Lt, OpCmp::LtEq, OpCmp::Gt, OpCmp::GtEq] {
                // For a single column index,
                // we want to avoid the `ProductValue` indirection of below.
                // Every constraint on `(head, cmp)` is consumed and becomes its own index argument.
                for ColValue { cmp, value, col, .. } in col_map.remove(&(head, cmp)).into_iter().flatten() {
                    found.push(make_index_arg(cmp, col_list, value.clone()));
                    cols_indexed.insert((col, cmp));
                }
            }
        } else {
            // We have a multi column index.
            // Try to fit constraints `c_0 = v_0, ..., c_n = v_n` to this index.
            //
            // For the time being, we restrict multi-col index scans to `=` only.
            // This is what our infrastructure is set-up to handle soundly.
            // To extend this support to ranges requires deeper changes.
            // TODO(Centril, 2024-05-30): extend this support to ranges.
            let cmp = OpCmp::Eq;

            // Compute the minimum number of `=` constraints that every column in the index has.
            // This is the number of complete `(c_0, ..., c_n)` tuples we can form;
            // it is `0` as soon as one column of the index has no `=` constraint.
            let mut min_all_cols_num_eq = col_list
                .iter()
                .map(|col| col_map.get(&(col, cmp)).map_or(0, |fs| fs.len()))
                .min()
                .unwrap_or_default();

            // For all of these sets of constraints,
            // construct the value to compare against.
            while min_all_cols_num_eq > 0 {
                let mut elems = Vec::with_capacity(col_list.len() as usize);
                for col in col_list.iter() {
                    // Cannot panic as `min_all_cols_num_eq > 0`.
                    let col_val = pop_multimap(&mut col_map, (col, cmp)).unwrap();
                    cols_indexed.insert((col_val.col, cmp));
                    // Add the column value to the product value.
                    // Values are pushed in the column order of the index.
                    elems.push(col_val.value.clone());
                }
                // Construct the index scan.
                let value = AlgebraicValue::product(elems);
                found.push(make_index_arg(cmp, col_list, value));
                min_all_cols_num_eq -= 1;
            }
        }
    }

    // The remaining constraints must be served by a scan.
    // Each one falls back to the operation it was originally extracted from.
    found.extend(
        col_map
            .into_iter()
            .flat_map(|(_, fs)| fs)
            .map(|f| IndexColumnOp::Scan(f.parent)),
    );

    found
}
1134
1135/// Pop an element from `map[key]` in the multimap `map`,
1136/// removing the entry entirely if there are no more elements left after popping.
1137fn pop_multimap<K: Ord, V, const N: usize>(map: &mut BTreeMap<K, SmallVec<[V; N]>>, key: K) -> Option<V> {
1138    let Entry::Occupied(mut entry) = map.entry(key) else {
1139        return None;
1140    };
1141    let fields = entry.get_mut();
1142    let val = fields.pop();
1143    if fields.is_empty() {
1144        entry.remove();
1145    }
1146    val
1147}
1148
1149/// Extracts a list of `col = val` constraints that *could* be answered by an index
1150/// and populates those into `col_map`.
1151/// The [`ColumnOp`]s that don't fit `col = val`
1152/// are made into [`IndexColumnOp::Scan`]s immediately which are added to `found`.
1153fn extract_cols<'a>(
1154    op: &'a ColumnOp,
1155    col_map: &mut BTreeMap<(ColId, OpCmp), SmallVec<[ColValue<'a>; 1]>>,
1156    found: &mut IndexColumnOpSink<'a>,
1157) {
1158    let mut add_field = |parent, op, col, val| {
1159        let fv = ColValue::new(parent, col, op, val);
1160        col_map.entry((col, op)).or_default().push(fv);
1161    };
1162
1163    match op {
1164        ColumnOp::Cmp { cmp, lhs, rhs } => {
1165            if let (ColumnOp::Col(col), ColumnOp::Val(val)) = (&**lhs, &**rhs) {
1166                // `lhs` must be a field that exists and `rhs` must be a value.
1167                add_field(op, *cmp, *col, val);
1168            }
1169        }
1170        ColumnOp::ColCmpVal { lhs, cmp, rhs } => add_field(op, *cmp, *lhs, rhs),
1171        ColumnOp::Log {
1172            op: OpLogic::And,
1173            operands,
1174        } => {
1175            for oper in operands.iter() {
1176                extract_cols(oper, col_map, found);
1177            }
1178        }
1179        ColumnOp::Log { op: OpLogic::Or, .. } | ColumnOp::Col(_) | ColumnOp::Val(_) => {
1180            found.push(IndexColumnOp::Scan(op));
1181        }
1182    }
1183}
1184
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
// TODO(bikeshedding): Refactor this struct so that `IndexJoin`s replace the `table`,
// rather than appearing as the first element of the `query`.
//
// `IndexJoin`s do not behave like filters; in fact they behave more like data sources.
// A query conceptually starts with either a single table or an `IndexJoin`,
// and then stacks a set of filters on top of that.
/// A query: a `source` relation followed by the operations applied to it.
pub struct QueryExpr {
    /// The relation the query starts from.
    pub source: SourceExpr,
    /// The operations of the query, in order; empty for a plain read of `source`.
    pub query: Vec<Query>,
}
1196
1197impl From<SourceExpr> for QueryExpr {
1198    fn from(source: SourceExpr) -> Self {
1199        QueryExpr { source, query: vec![] }
1200    }
1201}
1202
1203impl QueryExpr {
1204    pub fn new<T: Into<SourceExpr>>(source: T) -> Self {
1205        Self {
1206            source: source.into(),
1207            query: vec![],
1208        }
1209    }
1210
1211    /// Iterate over all [`SourceExpr`]s involved in the [`QueryExpr`].
1212    ///
1213    /// Sources are yielded from left to right. Duplicates are not filtered out.
1214    pub fn walk_sources<E>(&self, on_source: &mut impl FnMut(&SourceExpr) -> Result<(), E>) -> Result<(), E> {
1215        on_source(&self.source)?;
1216        self.query.iter().try_for_each(|q| q.walk_sources(on_source))
1217    }
1218
1219    /// Returns the last [`Header`] of this query.
1220    ///
1221    /// Starts the scan from the back to the front,
1222    /// looking for query operations that change the `Header`.
1223    /// These are `JoinInner` and `Project`.
1224    /// If there are no operations that alter the `Header`,
1225    /// this falls back to the origin `self.source.head()`.
1226    pub fn head(&self) -> &Arc<Header> {
1227        self.query
1228            .iter()
1229            .rev()
1230            .find_map(|op| match op {
1231                Query::Select(_) => None,
1232                Query::IndexScan(scan) => Some(&scan.table.head),
1233                Query::IndexJoin(join) if join.return_index_rows => Some(join.index_side.head()),
1234                Query::IndexJoin(join) => Some(join.probe_side.head()),
1235                Query::Project(proj) => Some(&proj.header_after),
1236                Query::JoinInner(join) => join.inner.as_ref(),
1237            })
1238            .unwrap_or_else(|| self.source.head())
1239    }
1240
1241    /// Does this query read from a given table?
1242    pub fn reads_from_table(&self, id: &TableId) -> bool {
1243        self.source.table_id() == Some(*id)
1244            || self.query.iter().any(|q| match q {
1245                Query::Select(_) | Query::Project(..) => false,
1246                Query::IndexScan(scan) => scan.table.table_id == *id,
1247                Query::JoinInner(join) => join.rhs.reads_from_table(id),
1248                Query::IndexJoin(join) => {
1249                    join.index_side.table_id() == Some(*id) || join.probe_side.reads_from_table(id)
1250                }
1251            })
1252    }
1253
1254    // Generate an index scan for an equality predicate if this is the first operator.
1255    // Otherwise generate a select.
1256    // TODO: Replace these methods with a proper query optimization pass.
1257    pub fn with_index_eq(mut self, table: DbTable, columns: ColList, value: AlgebraicValue) -> Self {
1258        let point = |v: AlgebraicValue| (Bound::Included(v.clone()), Bound::Included(v));
1259
1260        // if this is the first operator in the list, generate index scan
1261        let Some(query) = self.query.pop() else {
1262            let bounds = point(value);
1263            self.query.push(Query::IndexScan(IndexScan { table, columns, bounds }));
1264            return self;
1265        };
1266        match query {
1267            // try to push below join's lhs
1268            Query::JoinInner(JoinExpr {
1269                rhs:
1270                    QueryExpr {
1271                        source: SourceExpr::DbTable(ref db_table),
1272                        ..
1273                    },
1274                ..
1275            }) if table.table_id != db_table.table_id => {
1276                self = self.with_index_eq(db_table.clone(), columns, value);
1277                self.query.push(query);
1278                self
1279            }
1280            // try to push below join's rhs
1281            Query::JoinInner(JoinExpr {
1282                rhs,
1283                col_lhs,
1284                col_rhs,
1285                inner: semi,
1286            }) => {
1287                self.query.push(Query::JoinInner(JoinExpr {
1288                    rhs: rhs.with_index_eq(table, columns, value),
1289                    col_lhs,
1290                    col_rhs,
1291                    inner: semi,
1292                }));
1293                self
1294            }
1295            // merge with a preceding select
1296            Query::Select(filter) => {
1297                let op = ColumnOp::and_cmp(OpCmp::Eq, &columns, value);
1298                self.query.push(Query::Select(ColumnOp::and(filter, op)));
1299                self
1300            }
1301            // else generate a new select
1302            query => {
1303                self.query.push(query);
1304                let op = ColumnOp::and_cmp(OpCmp::Eq, &columns, value);
1305                self.query.push(Query::Select(op));
1306                self
1307            }
1308        }
1309    }
1310
    /// Adds the lower-bound predicate `columns > value`
    /// (or `columns >= value` when `inclusive`), answerable by an index on `table`, to this query.
    ///
    /// Generate an index scan for a range predicate or try merging with a previous index scan.
    /// Otherwise generate a select.
    ///
    /// Depending on the last operator of the query:
    ///
    /// - none: a new `IndexScan` with an unbounded upper end is emitted.
    /// - `JoinInner` whose right-hand side reads directly from a `DbTable` other than `table`:
    ///   the predicate is pushed below the join, onto the left-hand side.
    /// - any other `JoinInner`: the predicate is pushed into the join's right-hand side.
    /// - `IndexScan` on the same `columns` with only an upper bound:
    ///   both are merged into a single scan with two bounds.
    /// - `Select`: the predicate is merged into the filter using `AND`.
    /// - anything else: a new `Select` is appended.
    // TODO: Replace these methods with a proper query optimization pass.
    pub fn with_index_lower_bound(
        mut self,
        table: DbTable,
        columns: ColList,
        value: AlgebraicValue,
        inclusive: bool,
    ) -> Self {
        // if this is the first operator in the list, generate an index scan
        let Some(query) = self.query.pop() else {
            let bounds = (Self::bound(value, inclusive), Bound::Unbounded);
            self.query.push(Query::IndexScan(IndexScan { table, columns, bounds }));
            return self;
        };
        match query {
            // try to push below join's lhs
            Query::JoinInner(JoinExpr {
                rhs:
                    QueryExpr {
                        source: SourceExpr::DbTable(ref db_table),
                        ..
                    },
                ..
            }) if table.table_id != db_table.table_id => {
                // The join was popped above, so the recursion applies to the lhs only;
                // the join is put back on top afterwards.
                self = self.with_index_lower_bound(table, columns, value, inclusive);
                self.query.push(query);
                self
            }
            // try to push below join's rhs
            Query::JoinInner(JoinExpr {
                rhs,
                col_lhs,
                col_rhs,
                inner: semi,
            }) => {
                self.query.push(Query::JoinInner(JoinExpr {
                    rhs: rhs.with_index_lower_bound(table, columns, value, inclusive),
                    col_lhs,
                    col_rhs,
                    inner: semi,
                }));
                self
            }
            // merge with a preceding upper bounded index scan (inclusive)
            // Only the columns of the preceding scan are compared;
            // its table is discarded and the merged scan uses `table`.
            Query::IndexScan(IndexScan {
                columns: lhs_col_id,
                bounds: (Bound::Unbounded, Bound::Included(upper)),
                ..
            }) if columns == lhs_col_id => {
                let bounds = (Self::bound(value, inclusive), Bound::Included(upper));
                self.query.push(Query::IndexScan(IndexScan { table, columns, bounds }));
                self
            }
            // merge with a preceding upper bounded index scan (exclusive)
            Query::IndexScan(IndexScan {
                columns: lhs_col_id,
                bounds: (Bound::Unbounded, Bound::Excluded(upper)),
                ..
            }) if columns == lhs_col_id => {
                // Queries like `WHERE x < 5 AND x > 5` never return any rows and are likely mistakes.
                // Detect such queries and log a warning.
                // Compute this condition early, then compute the resulting query and log it.
                // TODO: We should not emit an `IndexScan` in this case.
                // Further design work is necessary to decide whether this should be an error at query compile time,
                // or whether we should emit a query plan which explicitly says that it will return 0 rows.
                // The current behavior is a hack
                // because this patch was written (2024-04-01 pgoldman) a short time before the BitCraft alpha,
                // and a more invasive change was infeasible.
                let is_never = !inclusive && value == upper;

                let bounds = (Self::bound(value, inclusive), Bound::Excluded(upper));
                self.query.push(Query::IndexScan(IndexScan { table, columns, bounds }));

                if is_never {
                    log::warn!("Query will select no rows due to equal excluded bounds: {self:?}")
                }

                self
            }
            // merge with a preceding select
            Query::Select(filter) => {
                let bounds = (Self::bound(value, inclusive), Bound::Unbounded);
                let op = ColumnOp::from_op_col_bounds(&columns, bounds);
                self.query.push(Query::Select(ColumnOp::and(filter, op)));
                self
            }
            // else generate a new select
            query => {
                self.query.push(query);
                let bounds = (Self::bound(value, inclusive), Bound::Unbounded);
                let op = ColumnOp::from_op_col_bounds(&columns, bounds);
                self.query.push(Query::Select(op));
                self
            }
        }
    }
1409
    /// Adds the upper-bound predicate `columns < value`
    /// (or `columns <= value` when `inclusive`), answerable by an index on `table`, to this query.
    ///
    /// Generate an index scan for a range predicate or try merging with a previous index scan.
    /// Otherwise generate a select.
    ///
    /// Depending on the last operator of the query:
    ///
    /// - none: a new `IndexScan` with an unbounded lower end is emitted.
    /// - `JoinInner` whose right-hand side reads directly from a `DbTable` other than `table`:
    ///   the predicate is pushed below the join, onto the left-hand side.
    /// - any other `JoinInner`: the predicate is pushed into the join's right-hand side.
    /// - `IndexScan` on the same `columns` with only a lower bound:
    ///   both are merged into a single scan with two bounds.
    /// - `Select`: the predicate is merged into the filter using `AND`.
    /// - anything else: a new `Select` is appended.
    // TODO: Replace these methods with a proper query optimization pass.
    pub fn with_index_upper_bound(
        mut self,
        table: DbTable,
        columns: ColList,
        value: AlgebraicValue,
        inclusive: bool,
    ) -> Self {
        // if this is the first operator in the list, generate an index scan
        let Some(query) = self.query.pop() else {
            self.query.push(Query::IndexScan(IndexScan {
                table,
                columns,
                bounds: (Bound::Unbounded, Self::bound(value, inclusive)),
            }));
            return self;
        };
        match query {
            // try to push below join's lhs
            Query::JoinInner(JoinExpr {
                rhs:
                    QueryExpr {
                        source: SourceExpr::DbTable(ref db_table),
                        ..
                    },
                ..
            }) if table.table_id != db_table.table_id => {
                // The join was popped above, so the recursion applies to the lhs only;
                // the join is put back on top afterwards.
                self = self.with_index_upper_bound(table, columns, value, inclusive);
                self.query.push(query);
                self
            }
            // try to push below join's rhs
            Query::JoinInner(JoinExpr {
                rhs,
                col_lhs,
                col_rhs,
                inner: semi,
            }) => {
                self.query.push(Query::JoinInner(JoinExpr {
                    rhs: rhs.with_index_upper_bound(table, columns, value, inclusive),
                    col_lhs,
                    col_rhs,
                    inner: semi,
                }));
                self
            }
            // merge with a preceding lower bounded index scan (inclusive)
            // Only the columns of the preceding scan are compared;
            // its table is discarded and the merged scan uses `table`.
            Query::IndexScan(IndexScan {
                columns: lhs_col_id,
                bounds: (Bound::Included(lower), Bound::Unbounded),
                ..
            }) if columns == lhs_col_id => {
                let bounds = (Bound::Included(lower), Self::bound(value, inclusive));
                self.query.push(Query::IndexScan(IndexScan { table, columns, bounds }));
                self
            }
            // merge with a preceding lower bounded index scan (exclusive)
            Query::IndexScan(IndexScan {
                columns: lhs_col_id,
                bounds: (Bound::Excluded(lower), Bound::Unbounded),
                ..
            }) if columns == lhs_col_id => {
                // Queries like `WHERE x < 5 AND x > 5` never return any rows and are likely mistakes.
                // Detect such queries and log a warning.
                // Compute this condition early, then compute the resulting query and log it.
                // TODO: We should not emit an `IndexScan` in this case.
                // Further design work is necessary to decide whether this should be an error at query compile time,
                // or whether we should emit a query plan which explicitly says that it will return 0 rows.
                // The current behavior is a hack
                // because this patch was written (2024-04-01 pgoldman) a short time before the BitCraft alpha,
                // and a more invasive change was infeasible.
                let is_never = !inclusive && value == lower;

                let bounds = (Bound::Excluded(lower), Self::bound(value, inclusive));
                self.query.push(Query::IndexScan(IndexScan { table, columns, bounds }));

                if is_never {
                    log::warn!("Query will select no rows due to equal excluded bounds: {self:?}")
                }

                self
            }
            // merge with a preceding select
            Query::Select(filter) => {
                let bounds = (Bound::Unbounded, Self::bound(value, inclusive));
                let op = ColumnOp::from_op_col_bounds(&columns, bounds);
                self.query.push(Query::Select(ColumnOp::and(filter, op)));
                self
            }
            // else generate a new select
            query => {
                self.query.push(query);
                let bounds = (Bound::Unbounded, Self::bound(value, inclusive));
                let op = ColumnOp::from_op_col_bounds(&columns, bounds);
                self.query.push(Query::Select(op));
                self
            }
        }
    }
1511
    /// Adds the filter `op` to this query, pushing it below a trailing join where possible.
    ///
    /// Depending on the last operator of the query:
    ///
    /// - none: a base `Select` is added.
    /// - `JoinInner`, and `op` is a comparison `field cmp value`
    ///   of a named field against a literal value:
    ///   the filter is pushed onto the side of the join whose header contains `field`,
    ///   trying the left-hand side first.
    /// - `JoinInner`, and `op` is any other comparison:
    ///   a `Select` is appended after the join.
    /// - `Select`: `op` is merged into the existing filter using `AND`.
    /// - anything else: a base `Select` is appended.
    ///
    /// # Errors
    ///
    /// Returns a [`RelationError`] when `op` fails the checks done by
    /// `check_field_op` / `check_field_op_logics`.
    ///
    /// # Panics
    ///
    /// Panics if converting the field names of `op` to columns fails.
    /// NOTE(review): presumably the preceding checks rule this out — confirm.
    pub fn with_select<O>(mut self, op: O) -> Result<Self, RelationError>
    where
        O: Into<FieldOp>,
    {
        let op = op.into();
        let Some(query) = self.query.pop() else {
            return self.add_base_select(op);
        };

        match (query, op) {
            (
                Query::JoinInner(JoinExpr {
                    rhs,
                    col_lhs,
                    col_rhs,
                    inner,
                }),
                FieldOp::Cmp {
                    op: OpQuery::Cmp(cmp),
                    lhs: field,
                    rhs: value,
                },
            ) => match (*field, *value) {
                (FieldOp::Field(FieldExpr::Name(field)), FieldOp::Field(FieldExpr::Value(value)))
                // Field is from lhs, so push onto join's left arg
                // The join has been popped, so `self.head()` is the header of the lhs here.
                if self.head().column_pos(field).is_some() =>
                    {
                        // No typing restrictions on `field cmp value`,
                        // and there are no binary operators to recurse into.
                        self = self.with_select(FieldOp::cmp(field, cmp, value))?;
                        self.query.push(Query::JoinInner(JoinExpr { rhs, col_lhs, col_rhs, inner }));
                        Ok(self)
                    }
                (FieldOp::Field(FieldExpr::Name(field)), FieldOp::Field(FieldExpr::Value(value)))
                // Field is from rhs, so push onto join's right arg
                if rhs.head().column_pos(field).is_some() =>
                    {
                        // No typing restrictions on `field cmp value`,
                        // and there are no binary operators to recurse into.
                        let rhs = rhs.with_select(FieldOp::cmp(field, cmp, value))?;
                        self.query.push(Query::JoinInner(JoinExpr {
                            rhs,
                            col_lhs,
                            col_rhs,
                            inner,
                        }));
                        Ok(self)
                    }
                (field, value) => {
                    // Put the join back first so that `self.head()` below reflects the join's output.
                    self.query.push(Query::JoinInner(JoinExpr { rhs, col_lhs, col_rhs, inner, }));

                    // As we have `field op value` we need not demand `bool`,
                    // but we must still recurse into each side.
                    self.check_field_op_logics(&field)?;
                    self.check_field_op_logics(&value)?;
                    // Convert to `ColumnOp`.
                    let col = field.names_to_cols(self.head()).unwrap();
                    let value = value.names_to_cols(self.head()).unwrap();
                    // Add `col op value` filter to query.
                    self.query.push(Query::Select(ColumnOp::new(OpQuery::Cmp(cmp), col, value)));
                    Ok(self)
                }
            },
            // We have a previous filter `lhs`, so join with `rhs` forming `lhs AND rhs`.
            (Query::Select(lhs), rhs) => {
                // Type check `rhs`, demanding `bool`.
                self.check_field_op(&rhs)?;
                // Convert to `ColumnOp`.
                let rhs = rhs.names_to_cols(self.head()).unwrap();
                // Add `lhs AND op` to query.
                self.query.push(Query::Select(ColumnOp::and(lhs, rhs)));
                Ok(self)
            }
            // No previous filter, so add a base one.
            (query, op) => {
                self.query.push(query);
                self.add_base_select(op)
            }
        }
    }
1592
1593    /// Add a base `Select` query that filters according to `op`.
1594    /// The `op` is checked to produce a `bool` value.
1595    fn add_base_select(mut self, op: FieldOp) -> Result<Self, RelationError> {
1596        // Type check the filter, demanding `bool`.
1597        self.check_field_op(&op)?;
1598        // Convert to `ColumnOp`.
1599        let op = op.names_to_cols(self.head()).unwrap();
1600        // Add the filter.
1601        self.query.push(Query::Select(op));
1602        Ok(self)
1603    }
1604
1605    /// Type checks a `FieldOp` with respect to `self`,
1606    /// ensuring that query evaluation cannot get stuck or panic due to `reduce_bool`.
1607    fn check_field_op(&self, op: &FieldOp) -> Result<(), RelationError> {
1608        use OpQuery::*;
1609        match op {
1610            // `lhs` and `rhs` must both be typed at `bool`.
1611            FieldOp::Cmp { op: Logic(_), lhs, rhs } => {
1612                self.check_field_op(lhs)?;
1613                self.check_field_op(rhs)?;
1614                Ok(())
1615            }
1616            // `lhs` and `rhs` have no typing restrictions.
1617            // The result of `lhs op rhs` will always be a `bool`
1618            // either by `Eq` or `Ord` on `AlgebraicValue` (see `ColumnOp::compare_bin_op`).
1619            // However, we still have to recurse into `lhs` and `rhs`
1620            // in case we have e.g., `a == (b == c)`.
1621            FieldOp::Cmp { op: Cmp(_), lhs, rhs } => {
1622                self.check_field_op_logics(lhs)?;
1623                self.check_field_op_logics(rhs)?;
1624                Ok(())
1625            }
1626            FieldOp::Field(FieldExpr::Value(AlgebraicValue::Bool(_))) => Ok(()),
1627            FieldOp::Field(FieldExpr::Value(v)) => Err(RelationError::NotBoolValue { val: v.clone() }),
1628            FieldOp::Field(FieldExpr::Name(field)) => {
1629                let field = *field;
1630                let head = self.head();
1631                let col_id = head.column_pos_or_err(field)?;
1632                let col_ty = &head.fields[col_id.idx()].algebraic_type;
1633                match col_ty {
1634                    &AlgebraicType::Bool => Ok(()),
1635                    ty => Err(RelationError::NotBoolType { field, ty: ty.clone() }),
1636                }
1637            }
1638        }
1639    }
1640
1641    /// Traverses `op`, checking any logical operators for bool-typed operands.
1642    fn check_field_op_logics(&self, op: &FieldOp) -> Result<(), RelationError> {
1643        use OpQuery::*;
1644        match op {
1645            FieldOp::Field(_) => Ok(()),
1646            FieldOp::Cmp { op: Cmp(_), lhs, rhs } => {
1647                self.check_field_op_logics(lhs)?;
1648                self.check_field_op_logics(rhs)?;
1649                Ok(())
1650            }
1651            FieldOp::Cmp { op: Logic(_), lhs, rhs } => {
1652                self.check_field_op(lhs)?;
1653                self.check_field_op(rhs)?;
1654                Ok(())
1655            }
1656        }
1657    }
1658
1659    pub fn with_select_cmp<LHS, RHS, O>(self, op: O, lhs: LHS, rhs: RHS) -> Result<Self, RelationError>
1660    where
1661        LHS: Into<FieldExpr>,
1662        RHS: Into<FieldExpr>,
1663        O: Into<OpQuery>,
1664    {
1665        let op = FieldOp::new(op.into(), FieldOp::Field(lhs.into()), FieldOp::Field(rhs.into()));
1666        self.with_select(op)
1667    }
1668
1669    // Appends a project operation to the query operator pipeline.
1670    // The `wildcard_table_id` represents a projection of the form `table.*`.
1671    // This is used to determine if an inner join can be rewritten as an index join.
1672    pub fn with_project(
1673        mut self,
1674        fields: Vec<FieldExpr>,
1675        wildcard_table: Option<TableId>,
1676    ) -> Result<Self, RelationError> {
1677        if !fields.is_empty() {
1678            let header_before = self.head();
1679
1680            // Translate the field expressions to column expressions.
1681            let mut cols = Vec::with_capacity(fields.len());
1682            for field in fields {
1683                cols.push(field.name_to_col(header_before)?);
1684            }
1685
1686            // Project the header.
1687            // We'll store that so subsequent operations use that as a base.
1688            let header_after = Arc::new(header_before.project(&cols)?);
1689
1690            // Add the projection.
1691            self.query.push(Query::Project(ProjectExpr {
1692                cols,
1693                wildcard_table,
1694                header_after,
1695            }));
1696        }
1697        Ok(self)
1698    }
1699
1700    pub fn with_join_inner_raw(
1701        mut self,
1702        q_rhs: QueryExpr,
1703        c_lhs: ColId,
1704        c_rhs: ColId,
1705        inner: Option<Arc<Header>>,
1706    ) -> Self {
1707        self.query
1708            .push(Query::JoinInner(JoinExpr::new(q_rhs, c_lhs, c_rhs, inner)));
1709        self
1710    }
1711
1712    pub fn with_join_inner(self, q_rhs: impl Into<QueryExpr>, c_lhs: ColId, c_rhs: ColId, semi: bool) -> Self {
1713        let q_rhs = q_rhs.into();
1714        let inner = (!semi).then(|| Arc::new(self.head().extend(q_rhs.head())));
1715        self.with_join_inner_raw(q_rhs, c_lhs, c_rhs, inner)
1716    }
1717
1718    fn bound(value: AlgebraicValue, inclusive: bool) -> Bound<AlgebraicValue> {
1719        if inclusive {
1720            Bound::Included(value)
1721        } else {
1722            Bound::Excluded(value)
1723        }
1724    }
1725
1726    /// Try to turn an inner join followed by a projection into a semijoin.
1727    ///
1728    /// This optimization recognizes queries of the form:
1729    ///
1730    /// ```ignore
1731    /// QueryExpr {
1732    ///   source: LHS,
1733    ///   query: [
1734    ///     JoinInner(JoinExpr {
1735    ///       rhs: RHS,
1736    ///       semi: false,
1737    ///       ..
1738    ///     }),
1739    ///     Project(LHS.*),
1740    ///     ...
1741    ///   ]
1742    /// }
1743    /// ```
1744    ///
1745    /// And combines the `JoinInner` with the `Project` into a `JoinInner` with `semi: true`.
1746    ///
1747    /// Current limitations of this optimization:
1748    /// - The `JoinInner` must be the first (0th) element of the `query`.
1749    ///   Future work could search through the `query` to find any applicable `JoinInner`s,
1750    ///   but the current implementation inspects only the first expr.
1751    ///   This is likely sufficient because this optimization is primarily useful for enabling `try_index_join`,
1752    ///   which is fundamentally limited to operate on the first expr.
1753    ///   Note that we still get to optimize incremental joins, because we first optimize the original query
1754    ///   with [`DbTable`] sources, which results in an [`IndexJoin`]
1755    ///   then we replace the sources with [`MemTable`]s and go back to a [`JoinInner`] with `semi: true`.
1756    /// - The `Project` must immediately follow the `JoinInner`, with no intervening exprs.
1757    ///   Future work could search through intervening exprs to detect that the RHS table is unused.
1758    /// - The LHS/source table must be a [`DbTable`], not a [`MemTable`].
1759    ///   This is so we can recognize a wildcard project by its table id.
1760    ///   Future work could inspect the set of projected fields and compare them to the LHS table's header instead.
1761    pub fn try_semi_join(self) -> QueryExpr {
1762        let QueryExpr { source, query } = self;
1763
1764        let Some(source_table_id) = source.table_id() else {
1765            // Source is a `MemTable`, so we can't recognize a wildcard projection. Bail.
1766            return QueryExpr { source, query };
1767        };
1768
1769        let mut exprs = query.into_iter();
1770        let Some(join_candidate) = exprs.next() else {
1771            // No first (0th) expr to be the join; bail.
1772            return QueryExpr { source, query: vec![] };
1773        };
1774        let Query::JoinInner(join) = join_candidate else {
1775            // First (0th) expr is not an inner join. Bail.
1776            return QueryExpr {
1777                source,
1778                query: itertools::chain![Some(join_candidate), exprs].collect(),
1779            };
1780        };
1781
1782        let Some(project_candidate) = exprs.next() else {
1783            // No second (1st) expr to be the project. Bail.
1784            return QueryExpr {
1785                source,
1786                query: vec![Query::JoinInner(join)],
1787            };
1788        };
1789
1790        let Query::Project(proj) = project_candidate else {
1791            // Second (1st) expr is not a wildcard projection. Bail.
1792            return QueryExpr {
1793                source,
1794                query: itertools::chain![Some(Query::JoinInner(join)), Some(project_candidate), exprs].collect(),
1795            };
1796        };
1797
1798        if proj.wildcard_table != Some(source_table_id) {
1799            // Projection is selecting the RHS table. Bail.
1800            return QueryExpr {
1801                source,
1802                query: itertools::chain![Some(Query::JoinInner(join)), Some(Query::Project(proj)), exprs].collect(),
1803            };
1804        };
1805
1806        // All conditions met; return a semijoin.
1807        let semijoin = JoinExpr { inner: None, ..join };
1808
1809        QueryExpr {
1810            source,
1811            query: itertools::chain![Some(Query::JoinInner(semijoin)), exprs].collect(),
1812        }
1813    }
1814
1815    // Try to turn an applicable join into an index join.
1816    // An applicable join is one that can use an index to probe the lhs.
1817    // It must also project only the columns from the lhs.
1818    //
1819    // Ex. SELECT Left.* FROM Left JOIN Right ON Left.id = Right.id ...
1820    // where `Left` has an index defined on `id`.
1821    fn try_index_join(self) -> QueryExpr {
1822        let mut query = self;
1823        // We expect a single operation - an inner join with `semi: true`.
1824        // These can be transformed by `try_semi_join` from a sequence of two queries, an inner join followed by a wildcard project.
1825        if query.query.len() != 1 {
1826            return query;
1827        }
1828
1829        // If the source is a `MemTable`, it doesn't have any indexes,
1830        // so we can't plan an index join.
1831        if query.source.is_mem_table() {
1832            return query;
1833        }
1834        let source = query.source;
1835        let join = query.query.pop().unwrap();
1836
1837        match join {
1838            Query::JoinInner(join @ JoinExpr { inner: None, .. }) => {
1839                if !join.rhs.query.is_empty() {
1840                    // An applicable join must have an index defined on the correct field.
1841                    if source.head().has_constraint(join.col_lhs, Constraints::indexed()) {
1842                        let index_join = IndexJoin {
1843                            probe_side: join.rhs,
1844                            probe_col: join.col_rhs,
1845                            index_side: source.clone(),
1846                            index_select: None,
1847                            index_col: join.col_lhs,
1848                            return_index_rows: true,
1849                        };
1850                        let query = [Query::IndexJoin(index_join)].into();
1851                        return QueryExpr { source, query };
1852                    }
1853                }
1854                QueryExpr {
1855                    source,
1856                    query: vec![Query::JoinInner(join)],
1857                }
1858            }
1859            first => QueryExpr {
1860                source,
1861                query: vec![first],
1862            },
1863        }
1864    }
1865
    /// Look for filters that could use indexes
    ///
    /// Splits the filter `op` with `select_best_index`, once per entry of `tables`,
    /// and appends the resulting operations to `q`:
    /// index operations become index scans via the `with_index_*` builders,
    /// the remaining ones are folded into a `Select`.
    ///
    /// Returns `q` with those operations appended.
    fn optimize_select(mut q: QueryExpr, op: ColumnOp, tables: &[SourceExpr]) -> QueryExpr {
        // Go through each table schema referenced in the query
        // and collect the operations `select_best_index` yields for it.
        // `fields_found` is shared across all tables and all yielded operations.
        let mut fields_found = HashSet::default();
        for schema in tables {
            for op in select_best_index(&mut fields_found, schema.head(), &op) {
                if let IndexColumnOp::Scan(op) = &op {
                    // Remove a duplicated/redundant operation on the same `field` and `op`
                    // like `[Index(a = 1), Index(a = 1), Scan(a = 1)]`
                    // `insert` returns `false` when the comparison was already recorded.
                    if op.as_col_cmp().is_some_and(|cc| !fields_found.insert(cc)) {
                        continue;
                    }
                }

                match op {
                    // A sargable condition on one of the table schemas,
                    // either an equality or range condition.
                    IndexColumnOp::Index(idx) => {
                        // NOTE(review): the message names `find_sargable_ops`,
                        // while the operations come from `select_best_index` -- presumably a stale name.
                        let table = schema
                            .get_db_table()
                            .expect("find_sargable_ops(schema, op) implies `schema.is_db_table()`")
                            .clone();

                        // Each kind of index argument has a matching builder on `QueryExpr`.
                        q = match idx {
                            IndexArgument::Eq { columns, value } => q.with_index_eq(table, columns.clone(), value),
                            IndexArgument::LowerBound {
                                columns,
                                value,
                                inclusive,
                            } => q.with_index_lower_bound(table, columns.clone(), value, inclusive),
                            IndexArgument::UpperBound {
                                columns,
                                value,
                                inclusive,
                            } => q.with_index_upper_bound(table, columns.clone(), value, inclusive),
                        };
                    }
                    // Filter condition cannot be answered using an index.
                    IndexColumnOp::Scan(rhs) => {
                        let rhs = rhs.clone();
                        // Look at the last operation of `q` to decide how to add the condition.
                        let op = match q.query.pop() {
                            // Merge condition into any pre-existing `Select`.
                            Some(Query::Select(lhs)) => ColumnOp::and(lhs, rhs),
                            // Nothing in the pipeline yet; the condition stands alone.
                            None => rhs,
                            // Any other operation is put back, and the condition follows it.
                            Some(other) => {
                                q.query.push(other);
                                rhs
                            }
                        };
                        q.query.push(Query::Select(op));
                    }
                }
            }
        }

        q
    }
1924
1925    pub fn optimize(mut self, row_count: &impl Fn(TableId, &str) -> i64) -> Self {
1926        let mut q = Self {
1927            source: self.source.clone(),
1928            query: Vec::with_capacity(self.query.len()),
1929        };
1930
1931        if matches!(&*self.query, [Query::IndexJoin(_)]) {
1932            if let Some(Query::IndexJoin(join)) = self.query.pop() {
1933                q.query.push(Query::IndexJoin(join.reorder(row_count)));
1934                return q;
1935            }
1936        }
1937
1938        for query in self.query {
1939            match query {
1940                Query::Select(op) => {
1941                    q = Self::optimize_select(q, op, from_ref(&self.source));
1942                }
1943                Query::JoinInner(join) => {
1944                    q = q.with_join_inner_raw(join.rhs.optimize(row_count), join.col_lhs, join.col_rhs, join.inner);
1945                }
1946                _ => q.query.push(query),
1947            };
1948        }
1949
1950        // Make sure to `try_semi_join` before `try_index_join`, as the latter depends on the former.
1951        let q = q.try_semi_join();
1952        let q = q.try_index_join();
1953        if matches!(&*q.query, [Query::IndexJoin(_)]) {
1954            return q.optimize(row_count);
1955        }
1956        q
1957    }
1958}
1959
1960impl AuthAccess for Query {
1961    fn check_auth(&self, owner: Identity, caller: Identity) -> Result<(), AuthError> {
1962        if owner == caller {
1963            return Ok(());
1964        }
1965
1966        self.walk_sources(&mut |s| s.check_auth(owner, caller))
1967    }
1968}
1969
/// An expression of the VM language.
#[derive(Debug, Eq, PartialEq, From)]
pub enum Expr {
    /// A plain value.
    /// The only variant marked `#[from]` for the derived `From` conversion.
    #[from]
    Value(AlgebraicValue),
    /// A sequence of expressions.
    Block(Vec<Expr>),
    /// An identifier, held as its name.
    Ident(String),
    /// A boxed [`CrudExpr`]; queries are converted into this variant by `From<QueryExpr>`.
    Crud(Box<CrudExpr>),
    /// Carries an [`ErrorLang`]; presumably stops evaluation with that error -- confirm against the evaluator.
    Halt(ErrorLang),
}
1979
1980impl From<QueryExpr> for Expr {
1981    fn from(x: QueryExpr) -> Self {
1982        Expr::Crud(Box::new(CrudExpr::Query(x)))
1983    }
1984}
1985
1986impl fmt::Display for Query {
1987    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1988        match self {
1989            Query::IndexScan(op) => {
1990                write!(f, "index_scan {op:?}")
1991            }
1992            Query::IndexJoin(op) => {
1993                write!(f, "index_join {op:?}")
1994            }
1995            Query::Select(q) => {
1996                write!(f, "select {q}")
1997            }
1998            Query::Project(proj) => {
1999                let q = &proj.cols;
2000                write!(f, "project")?;
2001                if !q.is_empty() {
2002                    write!(f, " ")?;
2003                }
2004                for (pos, x) in q.iter().enumerate() {
2005                    write!(f, "{x}")?;
2006                    if pos + 1 < q.len() {
2007                        write!(f, ", ")?;
2008                    }
2009                }
2010                Ok(())
2011            }
2012            Query::JoinInner(q) => {
2013                write!(f, "&inner {:?} ON {} = {}", q.rhs, q.col_lhs, q.col_rhs)
2014            }
2015        }
2016    }
2017}
2018
2019impl AuthAccess for SourceExpr {
2020    fn check_auth(&self, owner: Identity, caller: Identity) -> Result<(), AuthError> {
2021        if owner == caller || self.table_access() == StAccess::Public {
2022            return Ok(());
2023        }
2024
2025        Err(AuthError::TablePrivate {
2026            named: self.table_name().to_string(),
2027        })
2028    }
2029}
2030
2031impl AuthAccess for QueryExpr {
2032    fn check_auth(&self, owner: Identity, caller: Identity) -> Result<(), AuthError> {
2033        if owner == caller {
2034            return Ok(());
2035        }
2036        self.walk_sources(&mut |s| s.check_auth(owner, caller))
2037    }
2038}
2039
2040impl AuthAccess for CrudExpr {
2041    fn check_auth(&self, owner: Identity, caller: Identity) -> Result<(), AuthError> {
2042        if owner == caller {
2043            return Ok(());
2044        }
2045        // Anyone may query, so as long as the tables involved are public.
2046        if let CrudExpr::Query(q) = self {
2047            return q.check_auth(owner, caller);
2048        }
2049
2050        // Mutating operations require `owner == caller`.
2051        Err(AuthError::OwnerRequired)
2052    }
2053}
2054
/// A set of row changes for a single table.
#[derive(Debug, PartialEq)]
pub struct Update {
    /// Id of the table the changes belong to.
    pub table_id: TableId,
    /// Name of that table.
    pub table_name: Box<str>,
    /// Presumably the rows inserted into the table -- confirm against the producers of `Update`.
    pub inserts: Vec<ProductValue>,
    /// Presumably the rows deleted from the table -- confirm against the producers of `Update`.
    pub deletes: Vec<ProductValue>,
}
2062
/// A unit of VM code; converted into a [`CodeResult`] by `From<Code>`.
#[derive(Debug, PartialEq)]
pub enum Code {
    /// A plain value.
    Value(AlgebraicValue),
    /// An in-memory table.
    Table(MemTable),
    /// Carries an [`ErrorLang`].
    Halt(ErrorLang),
    /// A sequence of codes.
    Block(Vec<Code>),
    /// A [`CrudExpr`]; the only variant without a counterpart in [`CodeResult`].
    Crud(CrudExpr),
    /// Optionally carries an [`Update`].
    Pass(Option<Update>),
}
2072
2073impl fmt::Display for Code {
2074    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2075        match self {
2076            Code::Value(x) => {
2077                write!(f, "{:?}", &x)
2078            }
2079            Code::Block(_) => write!(f, "Block"),
2080            x => todo!("{:?}", x),
2081        }
2082    }
2083}
2084
/// The result form of a [`Code`], as produced by `From<Code>`.
///
/// Mirrors [`Code`], except that there is no `Crud` variant.
#[derive(Debug, PartialEq)]
pub enum CodeResult {
    /// A plain value.
    Value(AlgebraicValue),
    /// An in-memory table.
    Table(MemTable),
    /// The results of a non-empty block, in order.
    Block(Vec<CodeResult>),
    /// Carries an [`ErrorLang`].
    Halt(ErrorLang),
    /// Optionally carries an [`Update`]; an empty block converts to `Pass(None)`.
    Pass(Option<Update>),
}
2093
2094impl From<Code> for CodeResult {
2095    fn from(code: Code) -> Self {
2096        match code {
2097            Code::Value(x) => Self::Value(x),
2098            Code::Table(x) => Self::Table(x),
2099            Code::Halt(x) => Self::Halt(x),
2100            Code::Block(x) => {
2101                if x.is_empty() {
2102                    Self::Pass(None)
2103                } else {
2104                    Self::Block(x.into_iter().map(CodeResult::from).collect())
2105                }
2106            }
2107            Code::Pass(x) => Self::Pass(x),
2108            x => Self::Halt(ErrorLang::new(
2109                ErrorKind::Compiler,
2110                Some(&format!("Invalid result: {x}")),
2111            )),
2112        }
2113    }
2114}
2115
2116#[cfg(test)]
2117mod tests {
2118    use super::*;
2119
2120    use spacetimedb_lib::db::raw_def::v9::RawModuleDefV9Builder;
2121    use spacetimedb_sats::{product, AlgebraicType, ProductType};
2122    use spacetimedb_schema::{def::ModuleDef, relation::Column, schema::Schema};
2123    use typed_arena::Arena;
2124
    /// Test identity made of 32 bytes of value `1`; passed as the owner by the auth assertion helpers.
    const ALICE: Identity = Identity::from_byte_array([1; 32]);
    /// Test identity made of 32 bytes of value `2`; passed as a caller distinct from the owner.
    const BOB: Identity = Identity::from_byte_array([2; 32]);
2127
    // TODO(kim): Property testing would be a better fit here, but writing
    // generators for recursive types (i.e. `Query` and friends) is tricky.
2130
2131    fn tables() -> [SourceExpr; 2] {
2132        [
2133            SourceExpr::InMemory {
2134                source_id: SourceId(0),
2135                header: Arc::new(Header {
2136                    table_id: 42.into(),
2137                    table_name: "foo".into(),
2138                    fields: vec![],
2139                    constraints: Default::default(),
2140                }),
2141                table_type: StTableType::User,
2142                table_access: StAccess::Private,
2143            },
2144            SourceExpr::DbTable(DbTable {
2145                head: Arc::new(Header {
2146                    table_id: 42.into(),
2147                    table_name: "foo".into(),
2148                    fields: vec![],
2149                    constraints: [(ColId(42).into(), Constraints::indexed())].into_iter().collect(),
2150                }),
2151                table_id: 42.into(),
2152                table_type: StTableType::User,
2153                table_access: StAccess::Private,
2154            }),
2155        ]
2156    }
2157
2158    fn queries() -> impl IntoIterator<Item = Query> {
2159        let [mem_table, db_table] = tables();
2160        // Skip `Query::Select` and `QueryProject` -- they don't have table
2161        // information
2162        [
2163            Query::IndexScan(IndexScan {
2164                table: db_table.get_db_table().unwrap().clone(),
2165                columns: ColList::new(42.into()),
2166                bounds: (Bound::Included(22.into()), Bound::Unbounded),
2167            }),
2168            Query::IndexJoin(IndexJoin {
2169                probe_side: mem_table.clone().into(),
2170                probe_col: 0.into(),
2171                index_side: SourceExpr::DbTable(DbTable {
2172                    head: Arc::new(Header {
2173                        table_id: db_table.head().table_id,
2174                        table_name: db_table.table_name().into(),
2175                        fields: vec![],
2176                        constraints: Default::default(),
2177                    }),
2178                    table_id: db_table.head().table_id,
2179                    table_type: StTableType::User,
2180                    table_access: StAccess::Public,
2181                }),
2182                index_select: None,
2183                index_col: 22.into(),
2184                return_index_rows: true,
2185            }),
2186            Query::JoinInner(JoinExpr {
2187                col_rhs: 1.into(),
2188                rhs: mem_table.into(),
2189                col_lhs: 1.into(),
2190                inner: None,
2191            }),
2192        ]
2193    }
2194
2195    fn query_exprs() -> impl IntoIterator<Item = QueryExpr> {
2196        tables().map(|table| {
2197            let mut expr = QueryExpr::from(table);
2198            expr.query = queries().into_iter().collect();
2199            expr
2200        })
2201    }
2202
2203    fn assert_owner_private<T: AuthAccess>(auth: &T) {
2204        assert!(auth.check_auth(ALICE, ALICE).is_ok());
2205        assert!(matches!(
2206            auth.check_auth(ALICE, BOB),
2207            Err(AuthError::TablePrivate { .. })
2208        ));
2209    }
2210
2211    fn assert_owner_required<T: AuthAccess>(auth: T) {
2212        assert!(auth.check_auth(ALICE, ALICE).is_ok());
2213        assert!(matches!(auth.check_auth(ALICE, BOB), Err(AuthError::OwnerRequired)));
2214    }
2215
2216    fn mem_table(id: TableId, name: &str, fields: &[(u16, AlgebraicType, bool)]) -> SourceExpr {
2217        let table_access = StAccess::Public;
2218        let head = Header::new(
2219            id,
2220            name.into(),
2221            fields
2222                .iter()
2223                .map(|(col, ty, _)| Column::new(FieldName::new(id, (*col).into()), ty.clone()))
2224                .collect(),
2225            fields
2226                .iter()
2227                .enumerate()
2228                .filter(|(_, (_, _, indexed))| *indexed)
2229                .map(|(i, _)| (ColId::from(i).into(), Constraints::indexed())),
2230        );
2231        SourceExpr::InMemory {
2232            source_id: SourceId(0),
2233            header: Arc::new(head),
2234            table_access,
2235            table_type: StTableType::User,
2236        }
2237    }
2238
2239    #[test]
2240    fn test_index_to_inner_join() {
2241        let index_side = mem_table(
2242            0.into(),
2243            "index",
2244            &[(0, AlgebraicType::U8, false), (1, AlgebraicType::U8, true)],
2245        );
2246        let probe_side = mem_table(
2247            1.into(),
2248            "probe",
2249            &[(0, AlgebraicType::U8, false), (1, AlgebraicType::U8, true)],
2250        );
2251
2252        let index_col = 1.into();
2253        let probe_col = 1.into();
2254        let index_select = ColumnOp::cmp(0, OpCmp::Eq, 0u8);
2255        let join = IndexJoin {
2256            probe_side: probe_side.clone().into(),
2257            probe_col,
2258            index_side: index_side.clone(),
2259            index_select: Some(index_select.clone()),
2260            index_col,
2261            return_index_rows: false,
2262        };
2263
2264        let expr = join.to_inner_join();
2265
2266        assert_eq!(expr.source, probe_side);
2267        assert_eq!(expr.query.len(), 1);
2268
2269        let Query::JoinInner(ref join) = expr.query[0] else {
2270            panic!("expected an inner join, but got {:#?}", expr.query[0]);
2271        };
2272
2273        assert_eq!(join.col_lhs, probe_col);
2274        assert_eq!(join.col_rhs, index_col);
2275        assert_eq!(
2276            join.rhs,
2277            QueryExpr {
2278                source: index_side,
2279                query: vec![index_select.into()]
2280            }
2281        );
2282        assert_eq!(join.inner, None);
2283    }
2284
2285    fn setup_best_index() -> (Header, [ColId; 5], [AlgebraicValue; 5]) {
2286        let table_id = 0.into();
2287
2288        let vals = [1, 2, 3, 4, 5].map(AlgebraicValue::U64);
2289        let col_ids = [0, 1, 2, 3, 4].map(ColId);
2290        let [a, b, c, d, _] = col_ids;
2291        let columns = col_ids.map(|c| Column::new(FieldName::new(table_id, c), AlgebraicType::I8));
2292
2293        let head1 = Header::new(
2294            table_id,
2295            "t1".into(),
2296            columns.to_vec(),
2297            vec![
2298                // Index a
2299                (a.into(), Constraints::primary_key()),
2300                // Index b
2301                (b.into(), Constraints::indexed()),
2302                // Index b + c
2303                (col_list![b, c], Constraints::unique()),
2304                // Index a + b + c + d
2305                (col_list![a, b, c, d], Constraints::indexed()),
2306            ],
2307        );
2308
2309        (head1, col_ids, vals)
2310    }
2311
2312    fn make_field_value((cmp, col, value): (OpCmp, ColId, &AlgebraicValue)) -> ColumnOp {
2313        ColumnOp::cmp(col, cmp, value.clone())
2314    }
2315
2316    fn scan_eq<'a>(arena: &'a Arena<ColumnOp>, col: ColId, val: &'a AlgebraicValue) -> IndexColumnOp<'a> {
2317        scan(arena, OpCmp::Eq, col, val)
2318    }
2319
2320    fn scan<'a>(arena: &'a Arena<ColumnOp>, cmp: OpCmp, col: ColId, val: &'a AlgebraicValue) -> IndexColumnOp<'a> {
2321        IndexColumnOp::Scan(arena.alloc(make_field_value((cmp, col, val))))
2322    }
2323
    /// Checks which operations `select_best_index` yields for conjunctions of equality conditions
    /// over the header from `setup_best_index`,
    /// which declares constraints on `a`, `b`, `b + c` and `a + b + c + d`.
    #[test]
    fn best_index() {
        let (head1, fields, vals) = setup_best_index();
        let [col_a, col_b, col_c, col_d, col_e] = fields;
        let [val_a, val_b, val_c, val_d, val_e] = vals;

        // Local helper shadowing the function under test:
        // it ANDs `col = val` for all given pairs, allocates the result in `arena`,
        // and runs the real `select_best_index` on it with a fresh default first argument.
        let arena = Arena::new();
        let select_best_index = |fields: &[_]| {
            let fields = fields
                .iter()
                .copied()
                .map(|(col, val): (ColId, _)| make_field_value((OpCmp::Eq, col, val)))
                .reduce(ColumnOp::and)
                .unwrap();
            select_best_index(&mut <_>::default(), &head1, arena.alloc(fields))
        };

        // Builds the expected equality index argument; the column list is allocated in its own arena.
        let col_list_arena = Arena::new();
        let idx_eq = |cols, val| make_index_arg(OpCmp::Eq, col_list_arena.alloc(cols), val);

        // Check for simple scan
        // (no constraint is declared on `d` alone).
        assert_eq!(
            select_best_index(&[(col_d, &val_e)]),
            [scan_eq(&arena, col_d, &val_e)].into(),
        );

        // Single columns with their own constraint (`a`, then `b`) are answered by an index.
        assert_eq!(
            select_best_index(&[(col_a, &val_a)]),
            [idx_eq(col_a.into(), val_a.clone())].into(),
        );

        assert_eq!(
            select_best_index(&[(col_b, &val_b)]),
            [idx_eq(col_b.into(), val_b.clone())].into(),
        );

        // Check for permutation
        // (`b, c` and `c, b` must both select the `b + c` index).
        assert_eq!(
            select_best_index(&[(col_b, &val_b), (col_c, &val_c)]),
            [idx_eq(
                col_list![col_b, col_c],
                product![val_b.clone(), val_c.clone()].into()
            )]
            .into(),
        );

        assert_eq!(
            select_best_index(&[(col_c, &val_c), (col_b, &val_b)]),
            [idx_eq(
                col_list![col_b, col_c],
                product![val_b.clone(), val_c.clone()].into()
            )]
            .into(),
        );

        // Check for permutation
        // (any order of `a, b, c, d` must select the `a + b + c + d` index).
        assert_eq!(
            select_best_index(&[(col_a, &val_a), (col_b, &val_b), (col_c, &val_c), (col_d, &val_d)]),
            [idx_eq(
                col_list![col_a, col_b, col_c, col_d],
                product![val_a.clone(), val_b.clone(), val_c.clone(), val_d.clone()].into(),
            )]
            .into(),
        );

        assert_eq!(
            select_best_index(&[(col_b, &val_b), (col_a, &val_a), (col_d, &val_d), (col_c, &val_c)]),
            [idx_eq(
                col_list![col_a, col_b, col_c, col_d],
                product![val_a.clone(), val_b.clone(), val_c.clone(), val_d.clone()].into(),
            )]
            .into()
        );

        // Check mix scan + index
        // (`a` and `b` use their single-column indexes; `d` and `e` are scanned).
        assert_eq!(
            select_best_index(&[(col_b, &val_b), (col_a, &val_a), (col_e, &val_e), (col_d, &val_d)]),
            [
                idx_eq(col_a.into(), val_a.clone()),
                idx_eq(col_b.into(), val_b.clone()),
                scan_eq(&arena, col_d, &val_d),
                scan_eq(&arena, col_e, &val_e),
            ]
            .into()
        );

        // `b, c` use the `b + c` index, and the remaining `d` is scanned.
        assert_eq!(
            select_best_index(&[(col_b, &val_b), (col_c, &val_c), (col_d, &val_d)]),
            [
                idx_eq(col_list![col_b, col_c], product![val_b.clone(), val_c.clone()].into(),),
                scan_eq(&arena, col_d, &val_d),
            ]
            .into()
        );
    }
2419
2420    #[test]
2421    fn best_index_range() {
2422        let arena = Arena::new();
2423
2424        let (head1, cols, vals) = setup_best_index();
2425        let [col_a, col_b, col_c, col_d, _] = cols;
2426        let [val_a, val_b, val_c, val_d, _] = vals;
2427
2428        let select_best_index = |cols: &[_]| {
2429            let fields = cols.iter().map(|x| make_field_value(*x)).reduce(ColumnOp::and).unwrap();
2430            select_best_index(&mut <_>::default(), &head1, arena.alloc(fields))
2431        };
2432
2433        let col_list_arena = Arena::new();
2434        let idx = |cmp, cols: &[ColId], val: &AlgebraicValue| {
2435            let columns = cols.iter().copied().collect::<ColList>();
2436            let columns = col_list_arena.alloc(columns);
2437            make_index_arg(cmp, columns, val.clone())
2438        };
2439
2440        // `a > va AND a < vb` => `[index(a), index(a)]`
2441        assert_eq!(
2442            select_best_index(&[(OpCmp::Gt, col_a, &val_a), (OpCmp::Lt, col_a, &val_b)]),
2443            [idx(OpCmp::Lt, &[col_a], &val_b), idx(OpCmp::Gt, &[col_a], &val_a)].into()
2444        );
2445
2446        // `d > vd AND d < vb` => `[scan(d), scan(d)]`
2447        assert_eq!(
2448            select_best_index(&[(OpCmp::Gt, col_d, &val_d), (OpCmp::Lt, col_d, &val_b)]),
2449            [
2450                scan(&arena, OpCmp::Lt, col_d, &val_b),
2451                scan(&arena, OpCmp::Gt, col_d, &val_d)
2452            ]
2453            .into()
2454        );
2455
2456        // `b > vb AND c < vc` => `[index(b), scan(c)]`.
2457        assert_eq!(
2458            select_best_index(&[(OpCmp::Gt, col_b, &val_b), (OpCmp::Lt, col_c, &val_c)]),
2459            [idx(OpCmp::Gt, &[col_b], &val_b), scan(&arena, OpCmp::Lt, col_c, &val_c)].into()
2460        );
2461
2462        // `b = vb AND a >= va AND c = vc` => `[index(b, c), index(a)]`
2463        let idx_bc = idx(
2464            OpCmp::Eq,
2465            &[col_b, col_c],
2466            &product![val_b.clone(), val_c.clone()].into(),
2467        );
2468        assert_eq!(
2469            //
2470            select_best_index(&[
2471                (OpCmp::Eq, col_b, &val_b),
2472                (OpCmp::GtEq, col_a, &val_a),
2473                (OpCmp::Eq, col_c, &val_c),
2474            ]),
2475            [idx_bc.clone(), idx(OpCmp::GtEq, &[col_a], &val_a),].into()
2476        );
2477
2478        // `b > vb AND a = va AND c = vc` => `[index(a), index(b), scan(c)]`
2479        assert_eq!(
2480            select_best_index(&[
2481                (OpCmp::Gt, col_b, &val_b),
2482                (OpCmp::Eq, col_a, &val_a),
2483                (OpCmp::Lt, col_c, &val_c),
2484            ]),
2485            [
2486                idx(OpCmp::Eq, &[col_a], &val_a),
2487                idx(OpCmp::Gt, &[col_b], &val_b),
2488                scan(&arena, OpCmp::Lt, col_c, &val_c),
2489            ]
2490            .into()
2491        );
2492
2493        // `a = va AND b = vb AND c = vc AND d > vd` => `[index(b, c), index(a), scan(d)]`
2494        assert_eq!(
2495            select_best_index(&[
2496                (OpCmp::Eq, col_a, &val_a),
2497                (OpCmp::Eq, col_b, &val_b),
2498                (OpCmp::Eq, col_c, &val_c),
2499                (OpCmp::Gt, col_d, &val_d),
2500            ]),
2501            [
2502                idx_bc.clone(),
2503                idx(OpCmp::Eq, &[col_a], &val_a),
2504                scan(&arena, OpCmp::Gt, col_d, &val_d),
2505            ]
2506            .into()
2507        );
2508
2509        // `b = vb AND c = vc AND b = vb AND c = vc` => `[index(b, c), index(b, c)]`
2510        assert_eq!(
2511            select_best_index(&[
2512                (OpCmp::Eq, col_b, &val_b),
2513                (OpCmp::Eq, col_c, &val_c),
2514                (OpCmp::Eq, col_b, &val_b),
2515                (OpCmp::Eq, col_c, &val_c),
2516            ]),
2517            [idx_bc.clone(), idx_bc].into()
2518        );
2519    }
2520
2521    #[test]
2522    fn test_auth_table() {
2523        tables().iter().for_each(assert_owner_private)
2524    }
2525
2526    #[test]
2527    fn test_auth_query_code() {
2528        for code in query_exprs() {
2529            assert_owner_private(&code)
2530        }
2531    }
2532
2533    #[test]
2534    fn test_auth_query() {
2535        for query in queries() {
2536            assert_owner_private(&query);
2537        }
2538    }
2539
2540    #[test]
2541    fn test_auth_crud_code_query() {
2542        for query in query_exprs() {
2543            let crud = CrudExpr::Query(query);
2544            assert_owner_private(&crud);
2545        }
2546    }
2547
2548    #[test]
2549    fn test_auth_crud_code_insert() {
2550        for table in tables().into_iter().filter_map(|s| s.get_db_table().cloned()) {
2551            let crud = CrudExpr::Insert { table, rows: vec![] };
2552            assert_owner_required(crud);
2553        }
2554    }
2555
2556    #[test]
2557    fn test_auth_crud_code_update() {
2558        for qc in query_exprs() {
2559            let crud = CrudExpr::Update {
2560                delete: qc,
2561                assignments: Default::default(),
2562            };
2563            assert_owner_required(crud);
2564        }
2565    }
2566
2567    #[test]
2568    fn test_auth_crud_code_delete() {
2569        for query in query_exprs() {
2570            let crud = CrudExpr::Delete { query };
2571            assert_owner_required(crud);
2572        }
2573    }
2574
2575    fn test_def() -> ModuleDef {
2576        let mut builder = RawModuleDefV9Builder::new();
2577        builder.build_table_with_new_type(
2578            "lhs",
2579            ProductType::from([("a", AlgebraicType::I32), ("b", AlgebraicType::String)]),
2580            true,
2581        );
2582        builder.build_table_with_new_type(
2583            "rhs",
2584            ProductType::from([("c", AlgebraicType::I32), ("d", AlgebraicType::I64)]),
2585            true,
2586        );
2587        builder.finish().try_into().expect("test def should be valid")
2588    }
2589
2590    #[test]
2591    /// Tests that [`QueryExpr::optimize`] can rewrite inner joins followed by projections into semijoins.
2592    fn optimize_inner_join_to_semijoin() {
2593        let def: ModuleDef = test_def();
2594        let lhs = TableSchema::from_module_def(&def, def.table("lhs").unwrap(), (), 0.into());
2595        let rhs = TableSchema::from_module_def(&def, def.table("rhs").unwrap(), (), 1.into());
2596
2597        let lhs_source = SourceExpr::from(&lhs);
2598        let rhs_source = SourceExpr::from(&rhs);
2599
2600        let q = QueryExpr::new(lhs_source.clone())
2601            .with_join_inner(rhs_source.clone(), 0.into(), 0.into(), false)
2602            .with_project(
2603                [0, 1]
2604                    .map(|c| FieldExpr::Name(FieldName::new(lhs.table_id, c.into())))
2605                    .into(),
2606                Some(TableId::SENTINEL),
2607            )
2608            .unwrap();
2609        let q = q.optimize(&|_, _| 0);
2610
2611        assert_eq!(q.source, lhs_source, "Optimized query should read from lhs");
2612
2613        assert_eq!(
2614            q.query.len(),
2615            1,
2616            "Optimized query should have a single member, a semijoin"
2617        );
2618        match &q.query[0] {
2619            Query::JoinInner(JoinExpr { rhs, inner: semi, .. }) => {
2620                assert_eq!(semi, &None, "Optimized query should be a semijoin");
2621                assert_eq!(rhs.source, rhs_source, "Optimized query should filter with rhs");
2622                assert!(
2623                    rhs.query.is_empty(),
2624                    "Optimized query should not filter rhs before joining"
2625                );
2626            }
2627            wrong => panic!("Expected an inner join, but found {wrong:?}"),
2628        }
2629    }
2630
2631    #[test]
2632    /// Tests that [`QueryExpr::optimize`] will not rewrite inner joins which are not followed by projections to the LHS table.
2633    fn optimize_inner_join_no_project() {
2634        let def: ModuleDef = test_def();
2635        let lhs = TableSchema::from_module_def(&def, def.table("lhs").unwrap(), (), 0.into());
2636        let rhs = TableSchema::from_module_def(&def, def.table("rhs").unwrap(), (), 1.into());
2637
2638        let lhs_source = SourceExpr::from(&lhs);
2639        let rhs_source = SourceExpr::from(&rhs);
2640
2641        let q = QueryExpr::new(lhs_source.clone()).with_join_inner(rhs_source.clone(), 0.into(), 0.into(), false);
2642        let optimized = q.clone().optimize(&|_, _| 0);
2643        assert_eq!(q, optimized);
2644    }
2645
2646    #[test]
2647    /// Tests that [`QueryExpr::optimize`] will not rewrite inner joins followed by projections to the RHS rather than LHS table.
2648    fn optimize_inner_join_wrong_project() {
2649        let def: ModuleDef = test_def();
2650        let lhs = TableSchema::from_module_def(&def, def.table("lhs").unwrap(), (), 0.into());
2651        let rhs = TableSchema::from_module_def(&def, def.table("rhs").unwrap(), (), 1.into());
2652
2653        let lhs_source = SourceExpr::from(&lhs);
2654        let rhs_source = SourceExpr::from(&rhs);
2655
2656        let q = QueryExpr::new(lhs_source.clone())
2657            .with_join_inner(rhs_source.clone(), 0.into(), 0.into(), false)
2658            .with_project(
2659                [0, 1]
2660                    .map(|c| FieldExpr::Name(FieldName::new(rhs.table_id, c.into())))
2661                    .into(),
2662                Some(TableId(1)),
2663            )
2664            .unwrap();
2665        let optimized = q.clone().optimize(&|_, _| 0);
2666        assert_eq!(q, optimized);
2667    }
2668}