1use std::collections::HashMap;
31use std::sync::Arc;
32
33use palimpsest_sql::catalog::ColumnType;
34use palimpsest_sql::mir::{AggExpr, ColumnRef, MirGraph, MirNodeKind, OrderKey};
35use palimpsest_wal::{Datum, TableId};
36use petgraph::graph::NodeIndex;
37use petgraph::Direction;
38use smallvec::SmallVec;
39use thiserror::Error;
40
41use crate::palimpsest::eval::{compile_predicate, EvalError, ScalarSchema};
42use crate::palimpsest::relational::{self, AggregateFunc, AggregateValue, SortDirection};
43use crate::palimpsest::wal::Row;
44use crate::{lattice::Lattice, VecCollection};
45
/// A MIR graph compiled into per-node dataflow recipes.
///
/// Produced by [`compile_mir`] and consumed by [`install_plan`], which turns the
/// recipes into collections inside a timely scope.
#[derive(Clone)]
pub struct CompiledPlan {
    /// Copy of the MIR graph the plan was compiled from; installation walks its
    /// `Input` edges to find each node's upstream node.
    pub graph: MirGraph,
    /// Root node of `graph` (from `MirGraph::root`); its collection is the plan's output.
    pub root: NodeIndex,
    /// Base tables the plan reads, deduplicated, in the order they were first
    /// reached during topological compilation. `install_plan` needs a collection
    /// for each of them.
    pub inputs: Vec<TableId>,
    /// Full (unprojected) schema of every table in `inputs`, as returned by the
    /// table lookup.
    pub input_schemas: HashMap<TableId, ScalarSchema>,
    /// Schema recorded for `root` (the default schema if none was recorded).
    pub output_schema: ScalarSchema,
    /// Schema recorded for each compiled node; downstream nodes resolve column
    /// names to row positions through these.
    pub node_schemas: HashMap<NodeIndex, ScalarSchema>,
    /// How to build each node's collection; `install_plan` expects an entry for
    /// every node it visits.
    pub recipes: HashMap<NodeIndex, NodeRecipe>,
}
76
/// Per-node instructions for building a dataflow collection, produced by
/// [`compile_mir`] and interpreted by `install_plan`.
///
/// Row-level logic is held in `Arc`'d closures, so recipes are cheap to clone
/// and can be moved into dataflow operators.
#[derive(Clone)]
pub enum NodeRecipe {
    /// Use the collection supplied for `table` in `install_plan`'s `inputs` map, unchanged.
    BaseTable {
        /// Table id as returned by the [`TableSchemaLookup`].
        table: TableId,
    },
    /// Keep only the rows for which `predicate` returns `true`.
    Filter {
        /// Predicate compiled from the SQL expression against the input schema.
        predicate: Arc<dyn Fn(&Row) -> bool + Send + Sync>,
    },
    /// Map each input row to a new row.
    Project {
        /// Builds the output row by copying input columns by position; positions
        /// past the end of the input row yield `Datum::Null`.
        extract: Arc<dyn Fn(&Row) -> Row + Send + Sync>,
    },
    /// Single-column GROUP BY evaluated over integer values.
    Aggregate {
        /// Extracts the group key (the GROUP BY column) from an input row.
        group_extract: Arc<dyn Fn(&Row) -> Datum + Send + Sync>,
        /// Extracts the single integer value that all non-COUNT aggregates read.
        /// Non-integer or missing datums read as `0`. The extractor is constant
        /// `0` when no aggregate names a column.
        value_extract: Arc<dyn Fn(&Row) -> i64 + Send + Sync>,
        /// Aggregate functions, in the same order as the aggregate output columns.
        funcs: Vec<AggregateFunc>,
    },
    /// ORDER BY a single column, then LIMIT / OFFSET.
    TopK {
        /// Extracts the integer sort key; non-integer or missing datums read as `0`.
        sort_key_extract: Arc<dyn Fn(&Row) -> i64 + Send + Sync>,
        /// Ascending or descending order.
        direction: SortDirection,
        /// LIMIT value, passed through to `relational::topk`.
        limit: usize,
        /// OFFSET value, passed through to `relational::topk`.
        offset: usize,
    },
    /// Reuse the collection built for a CTE definition.
    CteRef {
        /// Source of the node's incoming `CteExpansion` edge.
        target: NodeIndex,
    },
}
134
/// Reasons [`compile_mir`] can reject a MIR graph.
#[derive(Debug, Error)]
pub enum CompileError {
    /// The node kind (or node shape) has no recipe yet; the payload names it.
    #[error("unsupported MIR node: {0}")]
    Unsupported(String),
    /// Compiling a SQL expression (e.g. a filter predicate) failed.
    #[error("expression: {0}")]
    Expression(#[from] EvalError),
    /// A table, column, input edge, node schema or CTE could not be resolved;
    /// the payload describes what was missing.
    #[error("unknown identifier: {0}")]
    Unknown(String),
    /// The MIR graph has no topological order.
    #[error("MIR graph has a cycle")]
    Cycle,
    /// The (lowercased) aggregate function name is not count/sum/min/max/avg.
    #[error("unsupported aggregate function: {0}")]
    UnsupportedAggregate(String),
    /// Returned when an aggregate's GROUP BY lists more than one column.
    #[error("multi-column GROUP BY not yet supported")]
    MultiColumnGroupBy,
    /// Two non-COUNT aggregates of one node read different columns; the payload
    /// is `"<first> vs <second>"`.
    #[error("aggregate columns disagree: {0}")]
    HeterogeneousAggregateColumns(String),
    /// Returned when a TopK's ORDER BY lists more than one key.
    #[error("multi-column ORDER BY not yet supported")]
    MultiColumnOrderBy,
}
167
/// Resolves the table names referenced by a MIR graph.
///
/// Any `Fn(&str) -> Option<(TableId, ScalarSchema)>` implements this trait, so a
/// closure can be passed directly to [`compile_mir`].
pub trait TableSchemaLookup {
    /// Returns the table's id and full column schema, or `None` if `table` is
    /// unknown (compilation then fails with `CompileError::Unknown`).
    fn lookup(&self, table: &str) -> Option<(TableId, ScalarSchema)>;
}
176
177impl<F> TableSchemaLookup for F
178where
179 F: Fn(&str) -> Option<(TableId, ScalarSchema)>,
180{
181 fn lookup(&self, table: &str) -> Option<(TableId, ScalarSchema)> {
182 (self)(table)
183 }
184}
185
186pub fn compile_mir<L: TableSchemaLookup>(
196 graph: &MirGraph,
197 tables: &L,
198) -> Result<CompiledPlan, CompileError> {
199 let topo = petgraph::algo::toposort(graph.graph(), None).map_err(|_| CompileError::Cycle)?;
200
201 let mut state = CompileState {
202 node_schemas: HashMap::new(),
203 recipes: HashMap::new(),
204 inputs: Vec::new(),
205 input_schemas: HashMap::new(),
206 };
207
208 for node in topo {
209 compile_node(graph, node, tables, &mut state)?;
210 }
211
212 let root = graph.root();
213 let output_schema = state.node_schemas.get(&root).cloned().unwrap_or_default();
214
215 Ok(CompiledPlan {
216 graph: graph.clone(),
217 root,
218 inputs: state.inputs,
219 input_schemas: state.input_schemas,
220 output_schema,
221 node_schemas: state.node_schemas,
222 recipes: state.recipes,
223 })
224}
225
/// Mutable state threaded through the per-node compile functions.
struct CompileState {
    /// Schema recorded for each node compiled so far. Later nodes look up their
    /// inputs here, which is why nodes are compiled in topological order.
    node_schemas: HashMap<NodeIndex, ScalarSchema>,
    /// Recipe for each node compiled so far.
    recipes: HashMap<NodeIndex, NodeRecipe>,
    /// Distinct base tables, in first-seen order.
    inputs: Vec<TableId>,
    /// Full schema of each table in `inputs`; also serves as the "already seen" set.
    input_schemas: HashMap<TableId, ScalarSchema>,
}
232
233fn compile_node<L: TableSchemaLookup>(
234 graph: &MirGraph,
235 node: NodeIndex,
236 tables: &L,
237 state: &mut CompileState,
238) -> Result<(), CompileError> {
239 let kind = graph.node_kind(node);
240 match kind {
241 MirNodeKind::BaseTable { table, project } => {
242 compile_base_table(node, table, project, tables, state)
243 }
244 MirNodeKind::Filter { predicate } => compile_filter(graph, node, predicate, state),
245 MirNodeKind::Project { columns } => compile_project(graph, node, columns, state),
246 MirNodeKind::Aggregate { group_by, aggs } => {
247 compile_aggregate(graph, node, group_by, aggs, state)
248 }
249 MirNodeKind::TopK {
250 order_by,
251 limit,
252 offset,
253 } => compile_topk(graph, node, order_by, *limit, *offset, state),
254 MirNodeKind::CteRef { cte } => compile_cte_ref(graph, node, cte, state),
255 MirNodeKind::Join { .. } => Err(CompileError::Unsupported("Join".to_owned())),
256 MirNodeKind::Distinct => Err(CompileError::Unsupported("Distinct".to_owned())),
257 MirNodeKind::Union { .. } => Err(CompileError::Unsupported("Union".to_owned())),
258 MirNodeKind::Except { .. } => Err(CompileError::Unsupported("Except".to_owned())),
259 MirNodeKind::Intersect { .. } => Err(CompileError::Unsupported("Intersect".to_owned())),
260 MirNodeKind::Leaf { .. } => Err(CompileError::Unsupported("Leaf".to_owned())),
261 }
262}
263
264fn compile_base_table<L: TableSchemaLookup>(
269 node: NodeIndex,
270 table: &str,
271 project: &[ColumnRef],
272 tables: &L,
273 state: &mut CompileState,
274) -> Result<(), CompileError> {
275 let (table_id, full_schema) = tables
276 .lookup(table)
277 .ok_or_else(|| CompileError::Unknown(format!("table {table}")))?;
278
279 let schema = if project.is_empty() {
282 full_schema.clone()
283 } else {
284 let pairs = project
285 .iter()
286 .map(|col| {
287 full_schema
288 .column_type(&col.name)
289 .ok_or_else(|| CompileError::Unknown(format!("{table}.{}", col.name)))
290 .map(|ty| (col.name.clone(), ty))
291 })
292 .collect::<Result<Vec<_>, _>>()?;
293 ScalarSchema::from_pairs(pairs)
294 };
295
296 if !state.input_schemas.contains_key(&table_id) {
297 state.inputs.push(table_id);
298 state.input_schemas.insert(table_id, full_schema);
299 }
300
301 state.node_schemas.insert(node, schema);
302 state
303 .recipes
304 .insert(node, NodeRecipe::BaseTable { table: table_id });
305 Ok(())
306}
307
308fn compile_filter(
309 graph: &MirGraph,
310 node: NodeIndex,
311 predicate: &str,
312 state: &mut CompileState,
313) -> Result<(), CompileError> {
314 let input_node = single_input(graph, node)?;
315 let input_schema = state
316 .node_schemas
317 .get(&input_node)
318 .ok_or_else(|| CompileError::Unknown("filter input schema".to_owned()))?
319 .clone();
320
321 let pred = compile_predicate(predicate, &input_schema)?;
322 let pred: Arc<dyn Fn(&Row) -> bool + Send + Sync> = Arc::from(pred);
323
324 state.node_schemas.insert(node, input_schema);
325 state
326 .recipes
327 .insert(node, NodeRecipe::Filter { predicate: pred });
328 Ok(())
329}
330
331fn compile_project(
332 graph: &MirGraph,
333 node: NodeIndex,
334 columns: &[String],
335 state: &mut CompileState,
336) -> Result<(), CompileError> {
337 let input_node = single_input(graph, node)?;
338 let input_schema = state
339 .node_schemas
340 .get(&input_node)
341 .ok_or_else(|| CompileError::Unknown("project input schema".to_owned()))?
342 .clone();
343
344 let mut indices = Vec::with_capacity(columns.len());
345 let mut output_pairs = Vec::with_capacity(columns.len());
346 for col in columns {
347 let idx = input_schema
348 .index_of(col)
349 .ok_or_else(|| CompileError::Unknown(format!("project column {col}")))?;
350 let ty = input_schema
351 .column_type(col)
352 .expect("type for known column");
353 indices.push(idx);
354 output_pairs.push((col.clone(), ty));
355 }
356
357 let output_schema = ScalarSchema::from_pairs(output_pairs);
358 let indices_owned = indices;
359 let extract: Arc<dyn Fn(&Row) -> Row + Send + Sync> = Arc::new(move |row: &Row| {
360 let mut out: Row = SmallVec::with_capacity(indices_owned.len());
361 for &i in &indices_owned {
362 out.push(row.get(i).cloned().unwrap_or(Datum::Null));
363 }
364 out
365 });
366
367 state.node_schemas.insert(node, output_schema);
368 state.recipes.insert(node, NodeRecipe::Project { extract });
369 Ok(())
370}
371
372fn compile_aggregate(
373 graph: &MirGraph,
374 node: NodeIndex,
375 group_by: &[ColumnRef],
376 aggs: &[AggExpr],
377 state: &mut CompileState,
378) -> Result<(), CompileError> {
379 let input_node = single_input(graph, node)?;
380 let input_schema = state
381 .node_schemas
382 .get(&input_node)
383 .ok_or_else(|| CompileError::Unknown("aggregate input schema".to_owned()))?
384 .clone();
385
386 if group_by.len() != 1 {
387 return Err(CompileError::MultiColumnGroupBy);
388 }
389 let group_col = &group_by[0].name;
390 let group_idx = input_schema
391 .index_of(group_col)
392 .ok_or_else(|| CompileError::Unknown(format!("group column {group_col}")))?;
393 let group_type = input_schema
394 .column_type(group_col)
395 .expect("type for known column");
396
397 let group_extract: Arc<dyn Fn(&Row) -> Datum + Send + Sync> =
403 Arc::new(move |row: &Row| row.get(group_idx).cloned().unwrap_or(Datum::Null));
404
405 let mut value_column: Option<String> = None;
410 let mut funcs = Vec::with_capacity(aggs.len());
411 let mut output_pairs = Vec::with_capacity(group_by.len() + aggs.len());
412 output_pairs.push((group_col.clone(), group_type));
413
414 for agg in aggs {
415 let func = parse_agg_func(&agg.function)?;
416 funcs.push(func);
417
418 let arg_text = agg.args.first().map(String::as_str).unwrap_or("*");
420 let arg_col = arg_text.trim();
421 if arg_col != "*" && !matches!(func, AggregateFunc::Count) {
422 match &value_column {
423 None => value_column = Some(arg_col.to_owned()),
424 Some(prev) if prev == arg_col => {}
425 Some(prev) => {
426 return Err(CompileError::HeterogeneousAggregateColumns(format!(
427 "{prev} vs {arg_col}"
428 )));
429 }
430 }
431 }
432
433 let alias = agg
434 .alias
435 .clone()
436 .unwrap_or_else(|| format!("{}_{}", agg.function.to_lowercase(), output_pairs.len()));
437 let output_type = match func {
438 AggregateFunc::Avg => ColumnType::Float,
439 _ => ColumnType::Int,
440 };
441 output_pairs.push((alias, output_type));
442 }
443
444 let value_extract: Arc<dyn Fn(&Row) -> i64 + Send + Sync> = match value_column {
447 None => Arc::new(|_row: &Row| 0),
448 Some(col) => {
449 let value_idx = input_schema
450 .index_of(&col)
451 .ok_or_else(|| CompileError::Unknown(format!("aggregate column {col}")))?;
452 Arc::new(move |row: &Row| match row.get(value_idx) {
453 Some(Datum::I64(v)) => *v,
454 Some(Datum::I32(v)) => i64::from(*v),
455 Some(Datum::I16(v)) => i64::from(*v),
456 _ => 0,
457 })
458 }
459 };
460
461 let output_schema = ScalarSchema::from_pairs(output_pairs);
462
463 state.node_schemas.insert(node, output_schema);
464 state.recipes.insert(
465 node,
466 NodeRecipe::Aggregate {
467 group_extract,
468 value_extract,
469 funcs,
470 },
471 );
472 Ok(())
473}
474
475fn parse_agg_func(name: &str) -> Result<AggregateFunc, CompileError> {
476 match name.to_ascii_lowercase().as_str() {
477 "count" => Ok(AggregateFunc::Count),
478 "sum" => Ok(AggregateFunc::Sum),
479 "min" => Ok(AggregateFunc::Min),
480 "max" => Ok(AggregateFunc::Max),
481 "avg" => Ok(AggregateFunc::Avg),
482 other => Err(CompileError::UnsupportedAggregate(other.to_owned())),
483 }
484}
485
486fn compile_topk(
487 graph: &MirGraph,
488 node: NodeIndex,
489 order_by: &[OrderKey],
490 limit: usize,
491 offset: usize,
492 state: &mut CompileState,
493) -> Result<(), CompileError> {
494 let input_node = single_input(graph, node)?;
495 let input_schema = state
496 .node_schemas
497 .get(&input_node)
498 .ok_or_else(|| CompileError::Unknown("topk input schema".to_owned()))?
499 .clone();
500
501 if order_by.len() != 1 {
502 return Err(CompileError::MultiColumnOrderBy);
503 }
504 let key = &order_by[0];
505 let sort_idx = input_schema
506 .index_of(&key.expression)
507 .ok_or_else(|| CompileError::Unknown(format!("order column {}", key.expression)))?;
508
509 let sort_key_extract: Arc<dyn Fn(&Row) -> i64 + Send + Sync> =
510 Arc::new(move |row: &Row| match row.get(sort_idx) {
511 Some(Datum::I64(v)) => *v,
512 Some(Datum::I32(v)) => i64::from(*v),
513 Some(Datum::I16(v)) => i64::from(*v),
514 _ => 0,
515 });
516
517 let direction = if key.descending {
518 SortDirection::Descending
519 } else {
520 SortDirection::Ascending
521 };
522
523 state.node_schemas.insert(node, input_schema);
524 state.recipes.insert(
525 node,
526 NodeRecipe::TopK {
527 sort_key_extract,
528 direction,
529 limit,
530 offset,
531 },
532 );
533 Ok(())
534}
535
536fn compile_cte_ref(
537 graph: &MirGraph,
538 node: NodeIndex,
539 cte: &str,
540 state: &mut CompileState,
541) -> Result<(), CompileError> {
542 use petgraph::visit::EdgeRef;
546 let target = graph
547 .graph()
548 .edges_directed(node, Direction::Incoming)
549 .find(|edge| {
550 matches!(
551 edge.weight(),
552 palimpsest_sql::mir::MirEdgeKind::CteExpansion
553 )
554 })
555 .map(|edge| edge.source());
556 let target = target.ok_or_else(|| CompileError::Unknown(format!("cte {cte}")))?;
557
558 let schema = state
559 .node_schemas
560 .get(&target)
561 .cloned()
562 .ok_or_else(|| CompileError::Unknown(format!("cte target schema {cte}")))?;
563
564 state.node_schemas.insert(node, schema);
565 state.recipes.insert(node, NodeRecipe::CteRef { target });
566 Ok(())
567}
568
569fn single_input(graph: &MirGraph, node: NodeIndex) -> Result<NodeIndex, CompileError> {
574 use petgraph::visit::EdgeRef;
575 let mut inputs = graph
576 .graph()
577 .edges_directed(node, Direction::Incoming)
578 .filter(|edge| matches!(edge.weight(), palimpsest_sql::mir::MirEdgeKind::Input))
579 .map(|edge| edge.source());
580 let first = inputs
581 .next()
582 .ok_or_else(|| CompileError::Unknown("expected input edge".to_owned()))?;
583 if inputs.next().is_some() {
584 return Err(CompileError::Unsupported("multi-input node".to_owned()));
585 }
586 Ok(first)
587}
588
589pub fn install_plan<G>(
600 plan: &CompiledPlan,
601 scope: &mut G,
602 inputs: &HashMap<TableId, VecCollection<G, Row, isize>>,
603) -> VecCollection<G, Row, isize>
604where
605 G: timely::dataflow::Scope,
606 G::Timestamp: Lattice + Ord,
607{
608 let mut cache: HashMap<NodeIndex, VecCollection<G, Row, isize>> = HashMap::new();
609 install_recursive(plan, scope, inputs, plan.root, &mut cache)
610}
611
/// Returns the collection for `node`, building it (and, recursively, its
/// upstream nodes) unless it is already in `cache`.
///
/// Every collection is stored in `cache` after it is built, so a node reached
/// along several paths (e.g. a CTE referenced more than once) is built only
/// once. `scope` is only threaded through to the recursive calls.
///
/// # Panics
///
/// Panics in any of these cases:
/// - `node` has no recipe;
/// - a `BaseTable`'s table is missing from `inputs`;
/// - a single-input node does not have exactly one `Input` edge.
fn install_recursive<G>(
    plan: &CompiledPlan,
    scope: &mut G,
    inputs: &HashMap<TableId, VecCollection<G, Row, isize>>,
    node: NodeIndex,
    cache: &mut HashMap<NodeIndex, VecCollection<G, Row, isize>>,
) -> VecCollection<G, Row, isize>
where
    G: timely::dataflow::Scope,
    G::Timestamp: Lattice + Ord,
{
    if let Some(c) = cache.get(&node) {
        return c.clone();
    }

    let recipe = plan
        .recipes
        .get(&node)
        .expect("compile_mir guarantees a recipe per node");
    let collection = match recipe {
        NodeRecipe::BaseTable { table } => inputs
            .get(table)
            .expect("install_plan caller wires every BaseTable input")
            .clone(),
        NodeRecipe::Filter { predicate } => {
            let input_node = single_input(&plan.graph, node).expect("compile_mir validated");
            let input = install_recursive(plan, scope, inputs, input_node, cache);
            let pred = Arc::clone(predicate);
            relational::filter(&input, move |row: &Row| pred(row))
        }
        NodeRecipe::Project { extract } => {
            let input_node = single_input(&plan.graph, node).expect("compile_mir validated");
            let input = install_recursive(plan, scope, inputs, input_node, cache);
            let ext = Arc::clone(extract);
            relational::project(&input, move |row: Row| ext(&row))
        }
        NodeRecipe::Aggregate {
            group_extract,
            value_extract,
            funcs,
        } => {
            let input_node = single_input(&plan.graph, node).expect("compile_mir validated");
            let input = install_recursive(plan, scope, inputs, input_node, cache);

            // Reshape each row into a (group key, integer value) pair for aggregate_i64.
            let ge = Arc::clone(group_extract);
            let ve = Arc::clone(value_extract);
            let projected = relational::project(&input, move |row: Row| (ge(&row), ve(&row)));
            let funcs = funcs.clone();
            let aggregated = relational::aggregate_i64(&projected, funcs);

            // Flatten (group, aggregates) back into a row: the group key followed by
            // one datum per aggregate. Integer results saturate to i64; averages
            // become F64 bit patterns (0.0 when the count is zero).
            relational::project(
                &aggregated,
                |(group, aggs): (Datum, Vec<AggregateValue>)| {
                    let mut row: Row = SmallVec::with_capacity(1 + aggs.len());
                    row.push(group);
                    for av in aggs {
                        let datum = match av {
                            AggregateValue::Integer(v) => Datum::I64(saturating_i128_to_i64(v)),
                            AggregateValue::Average { sum, count } => {
                                let avg = if count == 0 {
                                    0.0
                                } else {
                                    sum as f64 / count as f64
                                };
                                Datum::F64(avg.to_bits())
                            }
                        };
                        row.push(datum);
                    }
                    row
                },
            )
        }
        NodeRecipe::TopK {
            sort_key_extract,
            direction,
            limit,
            offset,
        } => {
            let input_node = single_input(&plan.graph, node).expect("compile_mir validated");
            let input = install_recursive(plan, scope, inputs, input_node, cache);

            // Pair each row with its sort key, slice with topk, then drop the key again.
            let extract = Arc::clone(sort_key_extract);
            let with_key = relational::project(&input, move |row: Row| (extract(&row), row));
            let sliced = relational::topk(&with_key, *direction, *limit, *offset);
            relational::project(&sliced, |(_, row): (i64, Row)| row)
        }
        // A CTE reference shares the collection of the definition it expands to.
        NodeRecipe::CteRef { target } => install_recursive(plan, scope, inputs, *target, cache),
    };

    cache.insert(node, collection.clone());
    collection
}
714
/// Narrows an `i128` to `i64`, clamping out-of-range values to `i64::MIN` / `i64::MAX`.
fn saturating_i128_to_i64(v: i128) -> i64 {
    let clamped = v.clamp(i128::from(i64::MIN), i128::from(i64::MAX));
    // Lossless: `clamped` is within i64's range by construction.
    clamped as i64
}
724
// Tests compile SQL through `parse_and_lower` against a two-table in-memory
// catalog; two of them also run the installed dataflow under `timely::example`.
#[cfg(test)]
mod tests {
    use super::*;
    // NOTE(review): presumably the extension trait providing `new_collection_from`
    // on scopes — confirm.
    use crate::input::Input;
    use palimpsest_sql::lower::parse_and_lower;

    /// Schema of the `posts` test table.
    fn posts_schema() -> ScalarSchema {
        ScalarSchema::from_pairs([
            ("id".to_owned(), ColumnType::Int),
            ("title".to_owned(), ColumnType::Text),
            ("published".to_owned(), ColumnType::Bool),
        ])
    }

    /// Schema of the `events` test table.
    fn events_schema() -> ScalarSchema {
        ScalarSchema::from_pairs([
            ("id".to_owned(), ColumnType::Int),
            ("category_id".to_owned(), ColumnType::Int),
            ("value".to_owned(), ColumnType::Int),
        ])
    }

    /// Test catalog: `posts` is table 1, `events` is table 2.
    fn lookup(table: &str) -> Option<(TableId, ScalarSchema)> {
        match table {
            "posts" => Some((TableId::new(1), posts_schema())),
            "events" => Some((TableId::new(2), events_schema())),
            _ => None,
        }
    }

    /// A plain SELECT reads only `posts` and yields three output columns.
    #[test]
    fn compile_simple_select() {
        let graph = parse_and_lower("SELECT id, title, published FROM posts").unwrap();
        let plan = compile_mir(&graph, &lookup).unwrap();
        assert_eq!(plan.inputs, vec![TableId::new(1)]);
        assert_eq!(plan.output_schema.len(), 3);
    }

    /// A WHERE clause produces a `Filter` recipe somewhere in the plan.
    #[test]
    fn compile_filter() {
        let graph =
            parse_and_lower("SELECT id, title, published FROM posts WHERE published = true")
                .unwrap();
        let plan = compile_mir(&graph, &lookup).unwrap();
        let recipes_include_filter = plan
            .recipes
            .values()
            .any(|r| matches!(r, NodeRecipe::Filter { .. }));
        assert!(recipes_include_filter);
    }

    /// A grouped CTE read through ORDER BY ... LIMIT reads only `events` and
    /// compiles to both an `Aggregate` and a `TopK` recipe.
    #[test]
    fn compile_aggregate_with_cte() {
        let sql = "WITH per_category AS (
 SELECT category_id, COUNT(*) AS n, SUM(value) AS total
 FROM events
 GROUP BY category_id
 )
 SELECT category_id, n, total
 FROM per_category
 ORDER BY total DESC
 LIMIT 8";
        let graph = parse_and_lower(sql).unwrap();
        let plan = compile_mir(&graph, &lookup).unwrap();
        assert_eq!(plan.inputs, vec![TableId::new(2)]);
        assert_eq!(plan.output_schema.len(), 3);
        let has_agg = plan
            .recipes
            .values()
            .any(|r| matches!(r, NodeRecipe::Aggregate { .. }));
        let has_topk = plan
            .recipes
            .values()
            .any(|r| matches!(r, NodeRecipe::TopK { .. }));
        assert!(has_agg, "aggregate recipe missing");
        assert!(has_topk, "topk recipe missing");
    }

    /// Builds a `Row` from the given datums.
    fn datum_row(values: Vec<Datum>) -> Row {
        values.into_iter().collect()
    }

    /// The GROUP BY key keeps its input type (`Bool`) both in the output schema
    /// and in the emitted rows; COUNT is typed `Int`.
    #[test]
    fn aggregate_preserves_bool_group_key_type() {
        let sql = "SELECT published, COUNT(*) AS n
 FROM posts
 GROUP BY published";
        let graph = parse_and_lower(sql).unwrap();
        // Same columns as `posts_schema()`, captured by the lookup closure below.
        let posts_schema = ScalarSchema::from_pairs([
            ("id".to_owned(), ColumnType::Int),
            ("title".to_owned(), ColumnType::Text),
            ("published".to_owned(), ColumnType::Bool),
        ]);
        let plan = compile_mir(&graph, &|table: &str| match table {
            "posts" => Some((TableId::new(1), posts_schema.clone())),
            _ => None,
        })
        .unwrap();
        assert_eq!(
            plan.output_schema.column_type("published"),
            Some(ColumnType::Bool)
        );
        assert_eq!(plan.output_schema.column_type("n"), Some(ColumnType::Int));

        // Two published posts and one unpublished post.
        let seed = vec![
            datum_row(vec![
                Datum::I64(1),
                Datum::Text(bytes::Bytes::from_static(b"a")),
                Datum::Bool(true),
            ]),
            datum_row(vec![
                Datum::I64(2),
                Datum::Text(bytes::Bytes::from_static(b"b")),
                Datum::Bool(true),
            ]),
            datum_row(vec![
                Datum::I64(3),
                Datum::Text(bytes::Bytes::from_static(b"c")),
                Datum::Bool(false),
            ]),
        ];

        timely::example(move |scope| {
            let (_, posts) = scope.new_collection_from(seed);
            let mut inputs: HashMap<TableId, VecCollection<_, Row, isize>> = HashMap::new();
            inputs.insert(TableId::new(1), posts);
            let output = install_plan(&plan, scope, &inputs);

            // One row per boolean group, carrying that group's row count.
            let expected = vec![
                datum_row(vec![Datum::Bool(true), Datum::I64(2)]),
                datum_row(vec![Datum::Bool(false), Datum::I64(1)]),
            ];
            let expected_coll = scope.new_collection_from(expected).1;
            output.assert_eq(&expected_coll);
        });
    }

    /// End to end: grouping `events` by `category_id` yields one row per category
    /// with its row count and the sum of `value`.
    #[test]
    fn install_aggregate_pipeline_emits_grouped_rows() {
        let sql = "WITH per_category AS (
 SELECT category_id, COUNT(*) AS n, SUM(value) AS total
 FROM events
 GROUP BY category_id
 )
 SELECT category_id, n, total
 FROM per_category
 ORDER BY total DESC
 LIMIT 8";
        let graph = parse_and_lower(sql).unwrap();
        let plan = compile_mir(&graph, &lookup).unwrap();

        // Rows are (id, category_id, value).
        let seed: Vec<Row> = vec![
            datum_row(vec![Datum::I64(1), Datum::I64(7), Datum::I64(100)]),
            datum_row(vec![Datum::I64(2), Datum::I64(7), Datum::I64(50)]),
            datum_row(vec![Datum::I64(3), Datum::I64(9), Datum::I64(20)]),
            datum_row(vec![Datum::I64(4), Datum::I64(9), Datum::I64(20)]),
        ];

        // Category 7: 2 rows, 100 + 50; category 9: 2 rows, 20 + 20.
        let expected: Vec<Row> = vec![
            datum_row(vec![Datum::I64(7), Datum::I64(2), Datum::I64(150)]),
            datum_row(vec![Datum::I64(9), Datum::I64(2), Datum::I64(40)]),
        ];

        timely::example(move |scope| {
            // `posts` is wired too, even though this plan only reads `events`.
            let (_, posts) = scope.new_collection_from(Vec::<Row>::new());
            let (_, events) = scope.new_collection_from(seed);
            let mut inputs: HashMap<TableId, VecCollection<_, Row, isize>> = HashMap::new();
            inputs.insert(TableId::new(1), posts);
            inputs.insert(TableId::new(2), events);

            let output = install_plan(&plan, scope, &inputs);
            let expected = scope.new_collection_from(expected).1;
            output.assert_eq(&expected);
        });
    }
}