Skip to main content

palimpsest_dataflow/palimpsest/
compile_mir.rs

1// Copyright 2026 Thousand Birds Inc.
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! MIR → dataflow build-plan compiler.
5//!
6//! [`compile_mir`] walks a `MirGraph` topologically and emits a
7//! [`CompiledPlan`] of per-node "recipes" — boxed closures that read
8//! / extract / aggregate values from `Row`s. [`install_plan`]
9//! consumes a plan inside a timely scope, wiring the appropriate
10//! [`relational`](crate::palimpsest::relational) operator for each
11//! node and returning a final `VecCollection<G, Row, isize>` that the
12//! caller arranges into a trace for snapshot + diff delivery.
13//!
14//! The compiler is intentionally permissive about column shapes: every
15//! intermediate collection carries `Row` values rather than typed
16//! tuples, with per-node closures projecting columns by index. That
17//! erases the type-system pressure that comes from compiling SQL into
18//! differential's strongly-typed operators, while still passing each
19//! operator data of the exact shape it expects (e.g. `(i64, i64)`
20//! pairs into `aggregate_i64`).
21//!
22//! Coverage today: `BaseTable`, `Filter` (boolean predicates),
23//! `Project` (column rename / reorder), `Aggregate` (group-by with
24//! `COUNT` / `SUM` / `MIN` / `MAX` / `AVG`), `TopK` (single-column
25//! sort), and `CteRef`. `Join`, `Distinct`, `Union`, `Except`,
26//! `Intersect`, and `Leaf` return
27//! [`CompileError::Unsupported`]. The walker is structured so each of
28//! those reduces to a single `match` arm + recipe variant when wired.
29
30use std::collections::HashMap;
31use std::sync::Arc;
32
33use palimpsest_sql::catalog::ColumnType;
34use palimpsest_sql::mir::{AggExpr, ColumnRef, MirGraph, MirNodeKind, OrderKey};
35use palimpsest_wal::{Datum, TableId};
36use petgraph::graph::NodeIndex;
37use petgraph::Direction;
38use smallvec::SmallVec;
39use thiserror::Error;
40
41use crate::palimpsest::eval::{compile_predicate, EvalError, ScalarSchema};
42use crate::palimpsest::relational::{self, AggregateFunc, AggregateValue, SortDirection};
43use crate::palimpsest::wal::Row;
44use crate::{lattice::Lattice, VecCollection};
45
46// -----------------------------------------------------------------------------
47// Public types
48// -----------------------------------------------------------------------------
49
/// A query compiled from MIR into a plan that can be instantiated into
/// a timely scope. The compiler doesn't itself touch timely — the
/// dataflow host owns that lifecycle (see [`install_plan`]).
#[derive(Clone)]
pub struct CompiledPlan {
    /// The MIR the plan was compiled from. Held so the installer can
    /// re-walk edges to find each node's inputs.
    pub graph: MirGraph,
    /// Root node of `graph` — the query's output operator.
    pub root: NodeIndex,
    /// Distinct tables this query reads from, in the order their first
    /// `BaseTable` node was compiled (topological order). A table that is
    /// referenced by several `BaseTable` nodes appears only once, so the
    /// host wires one input collection per entry.
    pub inputs: Vec<TableId>,
    /// Full (unprojected) schema of every table in `inputs`, captured at
    /// compile time. The host uses these to know how to encode raw
    /// `WalUpdate` rows that feed the dataflow's input handles.
    pub input_schemas: HashMap<TableId, ScalarSchema>,
    /// Output schema of `root` — the rows the query produces.
    pub output_schema: ScalarSchema,
    /// Per-node output schemas (intermediate + leaf). Useful for debugging
    /// and for downstream consumers that want to render an explain
    /// plan.
    pub node_schemas: HashMap<NodeIndex, ScalarSchema>,
    /// Per-node compiled recipes, keyed by MIR node index. Every node the
    /// compiler visited has an entry.
    pub recipes: HashMap<NodeIndex, NodeRecipe>,
}
76
/// Compiled per-node payload. Each variant is a fragment of dataflow
/// the installer can lift into a scope. Closures are `Arc`-wrapped so a
/// recipe (and therefore a [`CompiledPlan`]) is cheap to clone, and the
/// installer can hand its own clone of a closure to each operator it
/// builds without re-compiling the MIR.
#[derive(Clone)]
pub enum NodeRecipe {
    /// Read raw rows from one of the dataflow host's input handles.
    BaseTable {
        /// Key used to look up the input collection in
        /// `install_plan`'s `inputs` map.
        table: TableId,
    },
    /// Keep only the rows for which the predicate returns `true`.
    Filter {
        /// Compiled predicate evaluated against each input row.
        predicate: Arc<dyn Fn(&Row) -> bool + Send + Sync>,
    },
    /// Reorder + rename columns. The closure returns the new row.
    Project {
        /// Closure that reads from the input row and returns the
        /// projected row.
        extract: Arc<dyn Fn(&Row) -> Row + Send + Sync>,
    },
    /// Group-by aggregate keyed on a single group-by column, where every
    /// aggregate function reads the same value column. Richer shapes
    /// (several group columns, aggregates over different columns) are
    /// rejected at compile time rather than rewritten.
    Aggregate {
        /// Closure that reads the group-by column out of the input
        /// row. Returns a `Datum` so the original column type is
        /// preserved end-to-end — `aggregate_i64`'s key parameter is
        /// generic, so we use the full `Datum` rather than coercing
        /// `Bool` / `Text` group keys onto `i64`.
        group_extract: Arc<dyn Fn(&Row) -> Datum + Send + Sync>,
        /// Closure that reads the shared value column as an `i64`. When
        /// no aggregate reads a column (only `COUNT`s), this is a
        /// constant-zero extractor.
        value_extract: Arc<dyn Fn(&Row) -> i64 + Send + Sync>,
        /// One entry per aggregate function in projection order.
        funcs: Vec<AggregateFunc>,
    },
    /// Global TopK over a single sort key.
    TopK {
        /// Closure that maps each row to its `i64` sort key.
        sort_key_extract: Arc<dyn Fn(&Row) -> i64 + Send + Sync>,
        /// Ascending vs descending order.
        direction: SortDirection,
        /// Limit (max rows retained).
        limit: usize,
        /// Offset (rows skipped from the head of the sorted slice).
        offset: usize,
    },
    /// CTE reference: the installer reuses the collection built for
    /// another node in the same graph.
    CteRef {
        /// MIR node index of the CTE's root.
        target: NodeIndex,
    },
}
134
/// Compile-time errors. The compiled closures all have infallible
/// signatures: extractors fall back to defaults such as `Datum::Null` or
/// `0` for missing or unexpected datums. Every rejection therefore
/// happens here, at compile time.
#[derive(Debug, Error)]
pub enum CompileError {
    /// Operator kind (or operator shape, e.g. a node with several inputs)
    /// not yet implemented by the compiler.
    #[error("unsupported MIR node: {0}")]
    Unsupported(String),
    /// The expression evaluator rejected a predicate string.
    #[error("expression: {0}")]
    Expression(#[from] EvalError),
    /// Identifier references a column / table the schema lookup
    /// doesn't know about.
    #[error("unknown identifier: {0}")]
    Unknown(String),
    /// MIR has a cycle. petgraph's toposort surfaces this.
    #[error("MIR graph has a cycle")]
    Cycle,
    /// Aggregate function name other than COUNT / SUM / MIN / MAX / AVG
    /// (matched case-insensitively). The payload is the lower-cased name.
    #[error("unsupported aggregate function: {0}")]
    UnsupportedAggregate(String),
    /// Multi-column GROUP BY (single-column grouping is the only shape
    /// we wire).
    #[error("multi-column GROUP BY not yet supported")]
    MultiColumnGroupBy,
    /// Aggregates over different value columns aren't wired —
    /// `aggregate_i64` only takes one input column per call.
    #[error("aggregate columns disagree: {0}")]
    HeterogeneousAggregateColumns(String),
    /// Multi-column ORDER BY not yet wired.
    #[error("multi-column ORDER BY not yet supported")]
    MultiColumnOrderBy,
}
167
/// Resolves base-table names for the compiler. The caller that owns the
/// catalog (presumably the dataflow host) supplies an implementation when
/// building the plan. Any `Fn(&str) -> Option<(TableId, ScalarSchema)>`
/// closure qualifies through the blanket impl below.
pub trait TableSchemaLookup {
    /// Resolve `table` to its `(table_id, schema)` pair, or `None`
    /// if the table isn't known. The compiler treats the returned schema
    /// as the table's full column layout and turns `None` into
    /// `CompileError::Unknown`.
    fn lookup(&self, table: &str) -> Option<(TableId, ScalarSchema)>;
}
176
177impl<F> TableSchemaLookup for F
178where
179    F: Fn(&str) -> Option<(TableId, ScalarSchema)>,
180{
181    fn lookup(&self, table: &str) -> Option<(TableId, ScalarSchema)> {
182        (self)(table)
183    }
184}
185
186// -----------------------------------------------------------------------------
187// Compile entry point
188// -----------------------------------------------------------------------------
189
190/// Walk `graph` and emit a [`CompiledPlan`].
191///
192/// # Errors
193/// Returns [`CompileError`] on cycles, unknown identifiers, or MIR
194/// shapes the walker hasn't been taught yet.
195pub fn compile_mir<L: TableSchemaLookup>(
196    graph: &MirGraph,
197    tables: &L,
198) -> Result<CompiledPlan, CompileError> {
199    let topo = petgraph::algo::toposort(graph.graph(), None).map_err(|_| CompileError::Cycle)?;
200
201    let mut state = CompileState {
202        node_schemas: HashMap::new(),
203        recipes: HashMap::new(),
204        inputs: Vec::new(),
205        input_schemas: HashMap::new(),
206    };
207
208    for node in topo {
209        compile_node(graph, node, tables, &mut state)?;
210    }
211
212    let root = graph.root();
213    let output_schema = state.node_schemas.get(&root).cloned().unwrap_or_default();
214
215    Ok(CompiledPlan {
216        graph: graph.clone(),
217        root,
218        inputs: state.inputs,
219        input_schemas: state.input_schemas,
220        output_schema,
221        node_schemas: state.node_schemas,
222        recipes: state.recipes,
223    })
224}
225
/// Mutable accumulator threaded through the per-node compile helpers.
/// `compile_mir` moves its fields into the final `CompiledPlan`.
struct CompileState {
    /// Output schema of every node compiled so far.
    node_schemas: HashMap<NodeIndex, ScalarSchema>,
    /// Recipe of every node compiled so far.
    recipes: HashMap<NodeIndex, NodeRecipe>,
    /// Distinct input tables, in first-seen order.
    inputs: Vec<TableId>,
    /// Full schema for each entry in `inputs`. Also serves as the
    /// "already registered" set that keeps `inputs` free of duplicates.
    input_schemas: HashMap<TableId, ScalarSchema>,
}
232
233fn compile_node<L: TableSchemaLookup>(
234    graph: &MirGraph,
235    node: NodeIndex,
236    tables: &L,
237    state: &mut CompileState,
238) -> Result<(), CompileError> {
239    let kind = graph.node_kind(node);
240    match kind {
241        MirNodeKind::BaseTable { table, project } => {
242            compile_base_table(node, table, project, tables, state)
243        }
244        MirNodeKind::Filter { predicate } => compile_filter(graph, node, predicate, state),
245        MirNodeKind::Project { columns } => compile_project(graph, node, columns, state),
246        MirNodeKind::Aggregate { group_by, aggs } => {
247            compile_aggregate(graph, node, group_by, aggs, state)
248        }
249        MirNodeKind::TopK {
250            order_by,
251            limit,
252            offset,
253        } => compile_topk(graph, node, order_by, *limit, *offset, state),
254        MirNodeKind::CteRef { cte } => compile_cte_ref(graph, node, cte, state),
255        MirNodeKind::Join { .. } => Err(CompileError::Unsupported("Join".to_owned())),
256        MirNodeKind::Distinct => Err(CompileError::Unsupported("Distinct".to_owned())),
257        MirNodeKind::Union { .. } => Err(CompileError::Unsupported("Union".to_owned())),
258        MirNodeKind::Except { .. } => Err(CompileError::Unsupported("Except".to_owned())),
259        MirNodeKind::Intersect { .. } => Err(CompileError::Unsupported("Intersect".to_owned())),
260        MirNodeKind::Leaf { .. } => Err(CompileError::Unsupported("Leaf".to_owned())),
261    }
262}
263
264// -----------------------------------------------------------------------------
265// Per-node compile helpers
266// -----------------------------------------------------------------------------
267
268fn compile_base_table<L: TableSchemaLookup>(
269    node: NodeIndex,
270    table: &str,
271    project: &[ColumnRef],
272    tables: &L,
273    state: &mut CompileState,
274) -> Result<(), CompileError> {
275    let (table_id, full_schema) = tables
276        .lookup(table)
277        .ok_or_else(|| CompileError::Unknown(format!("table {table}")))?;
278
279    // If `project` is empty, expose the table's full schema. Otherwise
280    // narrow to the named columns (in MIR order).
281    let schema = if project.is_empty() {
282        full_schema.clone()
283    } else {
284        let pairs = project
285            .iter()
286            .map(|col| {
287                full_schema
288                    .column_type(&col.name)
289                    .ok_or_else(|| CompileError::Unknown(format!("{table}.{}", col.name)))
290                    .map(|ty| (col.name.clone(), ty))
291            })
292            .collect::<Result<Vec<_>, _>>()?;
293        ScalarSchema::from_pairs(pairs)
294    };
295
296    if !state.input_schemas.contains_key(&table_id) {
297        state.inputs.push(table_id);
298        state.input_schemas.insert(table_id, full_schema);
299    }
300
301    state.node_schemas.insert(node, schema);
302    state
303        .recipes
304        .insert(node, NodeRecipe::BaseTable { table: table_id });
305    Ok(())
306}
307
308fn compile_filter(
309    graph: &MirGraph,
310    node: NodeIndex,
311    predicate: &str,
312    state: &mut CompileState,
313) -> Result<(), CompileError> {
314    let input_node = single_input(graph, node)?;
315    let input_schema = state
316        .node_schemas
317        .get(&input_node)
318        .ok_or_else(|| CompileError::Unknown("filter input schema".to_owned()))?
319        .clone();
320
321    let pred = compile_predicate(predicate, &input_schema)?;
322    let pred: Arc<dyn Fn(&Row) -> bool + Send + Sync> = Arc::from(pred);
323
324    state.node_schemas.insert(node, input_schema);
325    state
326        .recipes
327        .insert(node, NodeRecipe::Filter { predicate: pred });
328    Ok(())
329}
330
331fn compile_project(
332    graph: &MirGraph,
333    node: NodeIndex,
334    columns: &[String],
335    state: &mut CompileState,
336) -> Result<(), CompileError> {
337    let input_node = single_input(graph, node)?;
338    let input_schema = state
339        .node_schemas
340        .get(&input_node)
341        .ok_or_else(|| CompileError::Unknown("project input schema".to_owned()))?
342        .clone();
343
344    let mut indices = Vec::with_capacity(columns.len());
345    let mut output_pairs = Vec::with_capacity(columns.len());
346    for col in columns {
347        let idx = input_schema
348            .index_of(col)
349            .ok_or_else(|| CompileError::Unknown(format!("project column {col}")))?;
350        let ty = input_schema
351            .column_type(col)
352            .expect("type for known column");
353        indices.push(idx);
354        output_pairs.push((col.clone(), ty));
355    }
356
357    let output_schema = ScalarSchema::from_pairs(output_pairs);
358    let indices_owned = indices;
359    let extract: Arc<dyn Fn(&Row) -> Row + Send + Sync> = Arc::new(move |row: &Row| {
360        let mut out: Row = SmallVec::with_capacity(indices_owned.len());
361        for &i in &indices_owned {
362            out.push(row.get(i).cloned().unwrap_or(Datum::Null));
363        }
364        out
365    });
366
367    state.node_schemas.insert(node, output_schema);
368    state.recipes.insert(node, NodeRecipe::Project { extract });
369    Ok(())
370}
371
372fn compile_aggregate(
373    graph: &MirGraph,
374    node: NodeIndex,
375    group_by: &[ColumnRef],
376    aggs: &[AggExpr],
377    state: &mut CompileState,
378) -> Result<(), CompileError> {
379    let input_node = single_input(graph, node)?;
380    let input_schema = state
381        .node_schemas
382        .get(&input_node)
383        .ok_or_else(|| CompileError::Unknown("aggregate input schema".to_owned()))?
384        .clone();
385
386    if group_by.len() != 1 {
387        return Err(CompileError::MultiColumnGroupBy);
388    }
389    let group_col = &group_by[0].name;
390    let group_idx = input_schema
391        .index_of(group_col)
392        .ok_or_else(|| CompileError::Unknown(format!("group column {group_col}")))?;
393    let group_type = input_schema
394        .column_type(group_col)
395        .expect("type for known column");
396
397    // Return the raw `Datum` so the operator key preserves the
398    // column's original SQL type. `aggregate_i64` is generic over
399    // `K`, so passing `Datum` works directly — and the schema we
400    // advertise to clients (group_col with its original `group_type`)
401    // round-trips end-to-end.
402    let group_extract: Arc<dyn Fn(&Row) -> Datum + Send + Sync> =
403        Arc::new(move |row: &Row| row.get(group_idx).cloned().unwrap_or(Datum::Null));
404
405    // `aggregate_i64` reads one input column per call, with each
406    // `AggregateFunc` evaluating against the same value stream.
407    // Enforce that all aggregates either reference the same column
408    // or are `COUNT(*)` (which ignores its argument).
409    let mut value_column: Option<String> = None;
410    let mut funcs = Vec::with_capacity(aggs.len());
411    let mut output_pairs = Vec::with_capacity(group_by.len() + aggs.len());
412    output_pairs.push((group_col.clone(), group_type));
413
414    for agg in aggs {
415        let func = parse_agg_func(&agg.function)?;
416        funcs.push(func);
417
418        // Validate the value column. `*` is the wildcard for COUNT.
419        let arg_text = agg.args.first().map(String::as_str).unwrap_or("*");
420        let arg_col = arg_text.trim();
421        if arg_col != "*" && !matches!(func, AggregateFunc::Count) {
422            match &value_column {
423                None => value_column = Some(arg_col.to_owned()),
424                Some(prev) if prev == arg_col => {}
425                Some(prev) => {
426                    return Err(CompileError::HeterogeneousAggregateColumns(format!(
427                        "{prev} vs {arg_col}"
428                    )));
429                }
430            }
431        }
432
433        let alias = agg
434            .alias
435            .clone()
436            .unwrap_or_else(|| format!("{}_{}", agg.function.to_lowercase(), output_pairs.len()));
437        let output_type = match func {
438            AggregateFunc::Avg => ColumnType::Float,
439            _ => ColumnType::Int,
440        };
441        output_pairs.push((alias, output_type));
442    }
443
444    // If every agg is COUNT(*), there's no value column — pass a
445    // zero extractor; aggregate_i64 only reads diffs in that case.
446    let value_extract: Arc<dyn Fn(&Row) -> i64 + Send + Sync> = match value_column {
447        None => Arc::new(|_row: &Row| 0),
448        Some(col) => {
449            let value_idx = input_schema
450                .index_of(&col)
451                .ok_or_else(|| CompileError::Unknown(format!("aggregate column {col}")))?;
452            Arc::new(move |row: &Row| match row.get(value_idx) {
453                Some(Datum::I64(v)) => *v,
454                Some(Datum::I32(v)) => i64::from(*v),
455                Some(Datum::I16(v)) => i64::from(*v),
456                _ => 0,
457            })
458        }
459    };
460
461    let output_schema = ScalarSchema::from_pairs(output_pairs);
462
463    state.node_schemas.insert(node, output_schema);
464    state.recipes.insert(
465        node,
466        NodeRecipe::Aggregate {
467            group_extract,
468            value_extract,
469            funcs,
470        },
471    );
472    Ok(())
473}
474
475fn parse_agg_func(name: &str) -> Result<AggregateFunc, CompileError> {
476    match name.to_ascii_lowercase().as_str() {
477        "count" => Ok(AggregateFunc::Count),
478        "sum" => Ok(AggregateFunc::Sum),
479        "min" => Ok(AggregateFunc::Min),
480        "max" => Ok(AggregateFunc::Max),
481        "avg" => Ok(AggregateFunc::Avg),
482        other => Err(CompileError::UnsupportedAggregate(other.to_owned())),
483    }
484}
485
486fn compile_topk(
487    graph: &MirGraph,
488    node: NodeIndex,
489    order_by: &[OrderKey],
490    limit: usize,
491    offset: usize,
492    state: &mut CompileState,
493) -> Result<(), CompileError> {
494    let input_node = single_input(graph, node)?;
495    let input_schema = state
496        .node_schemas
497        .get(&input_node)
498        .ok_or_else(|| CompileError::Unknown("topk input schema".to_owned()))?
499        .clone();
500
501    if order_by.len() != 1 {
502        return Err(CompileError::MultiColumnOrderBy);
503    }
504    let key = &order_by[0];
505    let sort_idx = input_schema
506        .index_of(&key.expression)
507        .ok_or_else(|| CompileError::Unknown(format!("order column {}", key.expression)))?;
508
509    let sort_key_extract: Arc<dyn Fn(&Row) -> i64 + Send + Sync> =
510        Arc::new(move |row: &Row| match row.get(sort_idx) {
511            Some(Datum::I64(v)) => *v,
512            Some(Datum::I32(v)) => i64::from(*v),
513            Some(Datum::I16(v)) => i64::from(*v),
514            _ => 0,
515        });
516
517    let direction = if key.descending {
518        SortDirection::Descending
519    } else {
520        SortDirection::Ascending
521    };
522
523    state.node_schemas.insert(node, input_schema);
524    state.recipes.insert(
525        node,
526        NodeRecipe::TopK {
527            sort_key_extract,
528            direction,
529            limit,
530            offset,
531        },
532    );
533    Ok(())
534}
535
536fn compile_cte_ref(
537    graph: &MirGraph,
538    node: NodeIndex,
539    cte: &str,
540    state: &mut CompileState,
541) -> Result<(), CompileError> {
542    // CteExpansion edges point from the CTE's root *into* the
543    // CteRef node — i.e. they're incoming edges on this node,
544    // sourced at the cte_root. See `palimpsest_sql::lower::lower_table_factor`.
545    use petgraph::visit::EdgeRef;
546    let target = graph
547        .graph()
548        .edges_directed(node, Direction::Incoming)
549        .find(|edge| {
550            matches!(
551                edge.weight(),
552                palimpsest_sql::mir::MirEdgeKind::CteExpansion
553            )
554        })
555        .map(|edge| edge.source());
556    let target = target.ok_or_else(|| CompileError::Unknown(format!("cte {cte}")))?;
557
558    let schema = state
559        .node_schemas
560        .get(&target)
561        .cloned()
562        .ok_or_else(|| CompileError::Unknown(format!("cte target schema {cte}")))?;
563
564    state.node_schemas.insert(node, schema);
565    state.recipes.insert(node, NodeRecipe::CteRef { target });
566    Ok(())
567}
568
569// -----------------------------------------------------------------------------
570// Graph helpers
571// -----------------------------------------------------------------------------
572
573fn single_input(graph: &MirGraph, node: NodeIndex) -> Result<NodeIndex, CompileError> {
574    use petgraph::visit::EdgeRef;
575    let mut inputs = graph
576        .graph()
577        .edges_directed(node, Direction::Incoming)
578        .filter(|edge| matches!(edge.weight(), palimpsest_sql::mir::MirEdgeKind::Input))
579        .map(|edge| edge.source());
580    let first = inputs
581        .next()
582        .ok_or_else(|| CompileError::Unknown("expected input edge".to_owned()))?;
583    if inputs.next().is_some() {
584        return Err(CompileError::Unsupported("multi-input node".to_owned()));
585    }
586    Ok(first)
587}
588
589// -----------------------------------------------------------------------------
590// Install plan into a scope
591// -----------------------------------------------------------------------------
592
593/// Materialize a [`CompiledPlan`] inside the given timely `scope`.
594///
595/// The caller supplies `inputs` — one `VecCollection<G, Row, isize>`
596/// per `TableId` the plan references. The function returns the final
597/// `VecCollection<G, Row, isize>` whose rows are the query output;
598/// the caller is responsible for arranging it into a trace.
599pub fn install_plan<G>(
600    plan: &CompiledPlan,
601    scope: &mut G,
602    inputs: &HashMap<TableId, VecCollection<G, Row, isize>>,
603) -> VecCollection<G, Row, isize>
604where
605    G: timely::dataflow::Scope,
606    G::Timestamp: Lattice + Ord,
607{
608    let mut cache: HashMap<NodeIndex, VecCollection<G, Row, isize>> = HashMap::new();
609    install_recursive(plan, scope, inputs, plan.root, &mut cache)
610}
611
612fn install_recursive<G>(
613    plan: &CompiledPlan,
614    scope: &mut G,
615    inputs: &HashMap<TableId, VecCollection<G, Row, isize>>,
616    node: NodeIndex,
617    cache: &mut HashMap<NodeIndex, VecCollection<G, Row, isize>>,
618) -> VecCollection<G, Row, isize>
619where
620    G: timely::dataflow::Scope,
621    G::Timestamp: Lattice + Ord,
622{
623    if let Some(c) = cache.get(&node) {
624        return c.clone();
625    }
626
627    let recipe = plan
628        .recipes
629        .get(&node)
630        .expect("compile_mir guarantees a recipe per node");
631    let collection = match recipe {
632        NodeRecipe::BaseTable { table } => inputs
633            .get(table)
634            .expect("install_plan caller wires every BaseTable input")
635            .clone(),
636        NodeRecipe::Filter { predicate } => {
637            let input_node = single_input(&plan.graph, node).expect("compile_mir validated");
638            let input = install_recursive(plan, scope, inputs, input_node, cache);
639            let pred = Arc::clone(predicate);
640            relational::filter(&input, move |row: &Row| pred(row))
641        }
642        NodeRecipe::Project { extract } => {
643            let input_node = single_input(&plan.graph, node).expect("compile_mir validated");
644            let input = install_recursive(plan, scope, inputs, input_node, cache);
645            let ext = Arc::clone(extract);
646            relational::project(&input, move |row: Row| ext(&row))
647        }
648        NodeRecipe::Aggregate {
649            group_extract,
650            value_extract,
651            funcs,
652        } => {
653            let input_node = single_input(&plan.graph, node).expect("compile_mir validated");
654            let input = install_recursive(plan, scope, inputs, input_node, cache);
655
656            // Project Row → (group_key, value).
657            let ge = Arc::clone(group_extract);
658            let ve = Arc::clone(value_extract);
659            let projected = relational::project(&input, move |row: Row| (ge(&row), ve(&row)));
660            let funcs = funcs.clone();
661            let aggregated = relational::aggregate_i64(&projected, funcs);
662
663            // Project (group_key, Vec<AggregateValue>) → Row. The
664            // group key keeps its original Datum type, so a Bool
665            // group column emits a Bool first column and matches the
666            // schema we advertised to clients.
667            relational::project(
668                &aggregated,
669                |(group, aggs): (Datum, Vec<AggregateValue>)| {
670                    let mut row: Row = SmallVec::with_capacity(1 + aggs.len());
671                    row.push(group);
672                    for av in aggs {
673                        let datum = match av {
674                            AggregateValue::Integer(v) => Datum::I64(saturating_i128_to_i64(v)),
675                            AggregateValue::Average { sum, count } => {
676                                let avg = if count == 0 {
677                                    0.0
678                                } else {
679                                    sum as f64 / count as f64
680                                };
681                                Datum::F64(avg.to_bits())
682                            }
683                        };
684                        row.push(datum);
685                    }
686                    row
687                },
688            )
689        }
690        NodeRecipe::TopK {
691            sort_key_extract,
692            direction,
693            limit,
694            offset,
695        } => {
696            let input_node = single_input(&plan.graph, node).expect("compile_mir validated");
697            let input = install_recursive(plan, scope, inputs, input_node, cache);
698
699            // Use the input's natural Ord — Row's lexicographic order
700            // doesn't always match the desired key. Pre-project to
701            // (sort_key, row) so TopK sorts by `sort_key`, then strip
702            // the prefix once the slice is selected.
703            let extract = Arc::clone(sort_key_extract);
704            let with_key = relational::project(&input, move |row: Row| (extract(&row), row));
705            let sliced = relational::topk(&with_key, *direction, *limit, *offset);
706            relational::project(&sliced, |(_, row): (i64, Row)| row)
707        }
708        NodeRecipe::CteRef { target } => install_recursive(plan, scope, inputs, *target, cache),
709    };
710
711    cache.insert(node, collection.clone());
712    collection
713}
714
/// Narrow an `i128` aggregate result to `i64`. Values outside the `i64`
/// range saturate to `i64::MIN` / `i64::MAX` instead of wrapping.
fn saturating_i128_to_i64(v: i128) -> i64 {
    // After clamping, the value is guaranteed to fit, so the cast is lossless.
    v.clamp(i128::from(i64::MIN), i128::from(i64::MAX)) as i64
}
724
725#[cfg(test)]
726mod tests {
727    use super::*;
728    use crate::input::Input;
729    use palimpsest_sql::lower::parse_and_lower;
730
731    fn posts_schema() -> ScalarSchema {
732        ScalarSchema::from_pairs([
733            ("id".to_owned(), ColumnType::Int),
734            ("title".to_owned(), ColumnType::Text),
735            ("published".to_owned(), ColumnType::Bool),
736        ])
737    }
738
739    fn events_schema() -> ScalarSchema {
740        ScalarSchema::from_pairs([
741            ("id".to_owned(), ColumnType::Int),
742            ("category_id".to_owned(), ColumnType::Int),
743            ("value".to_owned(), ColumnType::Int),
744        ])
745    }
746
747    fn lookup(table: &str) -> Option<(TableId, ScalarSchema)> {
748        match table {
749            "posts" => Some((TableId::new(1), posts_schema())),
750            "events" => Some((TableId::new(2), events_schema())),
751            _ => None,
752        }
753    }
754
755    #[test]
756    fn compile_simple_select() {
757        let graph = parse_and_lower("SELECT id, title, published FROM posts").unwrap();
758        let plan = compile_mir(&graph, &lookup).unwrap();
759        assert_eq!(plan.inputs, vec![TableId::new(1)]);
760        assert_eq!(plan.output_schema.len(), 3);
761    }
762
763    #[test]
764    fn compile_filter() {
765        let graph =
766            parse_and_lower("SELECT id, title, published FROM posts WHERE published = true")
767                .unwrap();
768        let plan = compile_mir(&graph, &lookup).unwrap();
769        let recipes_include_filter = plan
770            .recipes
771            .values()
772            .any(|r| matches!(r, NodeRecipe::Filter { .. }));
773        assert!(recipes_include_filter);
774    }
775
776    #[test]
777    fn compile_aggregate_with_cte() {
778        let sql = "WITH per_category AS (
779            SELECT category_id, COUNT(*) AS n, SUM(value) AS total
780            FROM events
781            GROUP BY category_id
782        )
783        SELECT category_id, n, total
784        FROM per_category
785        ORDER BY total DESC
786        LIMIT 8";
787        let graph = parse_and_lower(sql).unwrap();
788        let plan = compile_mir(&graph, &lookup).unwrap();
789        assert_eq!(plan.inputs, vec![TableId::new(2)]);
790        assert_eq!(plan.output_schema.len(), 3);
791        let has_agg = plan
792            .recipes
793            .values()
794            .any(|r| matches!(r, NodeRecipe::Aggregate { .. }));
795        let has_topk = plan
796            .recipes
797            .values()
798            .any(|r| matches!(r, NodeRecipe::TopK { .. }));
799        assert!(has_agg, "aggregate recipe missing");
800        assert!(has_topk, "topk recipe missing");
801    }
802
803    fn datum_row(values: Vec<Datum>) -> Row {
804        values.into_iter().collect()
805    }
806
807    #[test]
808    fn aggregate_preserves_bool_group_key_type() {
809        // Regression: `GROUP BY <bool column>` used to coerce the
810        // group key into an `i64`, then emit `Datum::I64` on output —
811        // which produced a `schema/datum mismatch at column 0:
812        // schema=Bool, datum=i64` decode failure on the wire.
813        let sql = "SELECT published, COUNT(*) AS n
814                   FROM posts
815                   GROUP BY published";
816        let graph = parse_and_lower(sql).unwrap();
817        let posts_schema = ScalarSchema::from_pairs([
818            ("id".to_owned(), ColumnType::Int),
819            ("title".to_owned(), ColumnType::Text),
820            ("published".to_owned(), ColumnType::Bool),
821        ]);
822        let plan = compile_mir(&graph, &|table: &str| match table {
823            "posts" => Some((TableId::new(1), posts_schema.clone())),
824            _ => None,
825        })
826        .unwrap();
827        assert_eq!(
828            plan.output_schema.column_type("published"),
829            Some(ColumnType::Bool)
830        );
831        assert_eq!(plan.output_schema.column_type("n"), Some(ColumnType::Int));
832
833        // Drive the pipeline through timely and verify the emitted
834        // rows actually carry `Datum::Bool` at column 0.
835        let seed = vec![
836            datum_row(vec![
837                Datum::I64(1),
838                Datum::Text(bytes::Bytes::from_static(b"a")),
839                Datum::Bool(true),
840            ]),
841            datum_row(vec![
842                Datum::I64(2),
843                Datum::Text(bytes::Bytes::from_static(b"b")),
844                Datum::Bool(true),
845            ]),
846            datum_row(vec![
847                Datum::I64(3),
848                Datum::Text(bytes::Bytes::from_static(b"c")),
849                Datum::Bool(false),
850            ]),
851        ];
852
853        timely::example(move |scope| {
854            let (_, posts) = scope.new_collection_from(seed);
855            let mut inputs: HashMap<TableId, VecCollection<_, Row, isize>> = HashMap::new();
856            inputs.insert(TableId::new(1), posts);
857            let output = install_plan(&plan, scope, &inputs);
858
859            let expected = vec![
860                datum_row(vec![Datum::Bool(true), Datum::I64(2)]),
861                datum_row(vec![Datum::Bool(false), Datum::I64(1)]),
862            ];
863            let expected_coll = scope.new_collection_from(expected).1;
864            output.assert_eq(&expected_coll);
865        });
866    }
867
868    #[test]
869    fn install_aggregate_pipeline_emits_grouped_rows() {
870        let sql = "WITH per_category AS (
871            SELECT category_id, COUNT(*) AS n, SUM(value) AS total
872            FROM events
873            GROUP BY category_id
874        )
875        SELECT category_id, n, total
876        FROM per_category
877        ORDER BY total DESC
878        LIMIT 8";
879        let graph = parse_and_lower(sql).unwrap();
880        let plan = compile_mir(&graph, &lookup).unwrap();
881
882        let seed: Vec<Row> = vec![
883            datum_row(vec![Datum::I64(1), Datum::I64(7), Datum::I64(100)]),
884            datum_row(vec![Datum::I64(2), Datum::I64(7), Datum::I64(50)]),
885            datum_row(vec![Datum::I64(3), Datum::I64(9), Datum::I64(20)]),
886            datum_row(vec![Datum::I64(4), Datum::I64(9), Datum::I64(20)]),
887        ];
888
889        let expected: Vec<Row> = vec![
890            datum_row(vec![Datum::I64(7), Datum::I64(2), Datum::I64(150)]),
891            datum_row(vec![Datum::I64(9), Datum::I64(2), Datum::I64(40)]),
892        ];
893
894        timely::example(move |scope| {
895            let (_, posts) = scope.new_collection_from(Vec::<Row>::new());
896            let (_, events) = scope.new_collection_from(seed);
897            let mut inputs: HashMap<TableId, VecCollection<_, Row, isize>> = HashMap::new();
898            inputs.insert(TableId::new(1), posts);
899            inputs.insert(TableId::new(2), events);
900
901            let output = install_plan(&plan, scope, &inputs);
902            let expected = scope.new_collection_from(expected).1;
903            output.assert_eq(&expected);
904        });
905    }
906}