1use super::{
2 aggregate_function_to_proof_expr, column_to_column_ref, expr_to_proof_expr,
3 schema_to_column_fields, table_reference_to_table_ref, AggregateFunc, PlannerError,
4 PlannerResult,
5};
6use alloc::vec::Vec;
7use datafusion::{
8 common::{DFSchema, JoinConstraint, JoinType},
9 logical_expr::{
10 expr::Alias, Aggregate, Expr, Join, Limit, LogicalPlan, Projection, TableScan, Union,
11 },
12 sql::{sqlparser::ast::Ident, TableReference},
13};
14use indexmap::{IndexMap, IndexSet};
15use proof_of_sql::{
16 base::database::{ColumnRef, ColumnType, LiteralValue, SchemaAccessor, TableRef},
17 sql::{
18 proof::ProofPlan,
19 proof_exprs::{AliasedDynProofExpr, ColumnExpr, DynProofExpr, TableExpr},
20 proof_plans::{DynProofPlan, SortMergeJoinExec},
21 },
22};
23
24fn get_aliased_dyn_proof_exprs(
32 table_ref: &TableRef,
33 projection: &[usize],
34 input_schema: &[(Ident, ColumnType)],
35 output_schema: &DFSchema,
36) -> PlannerResult<Vec<AliasedDynProofExpr>> {
37 projection
38 .iter()
39 .enumerate()
40 .map(
41 |(output_index, input_index)| -> PlannerResult<AliasedDynProofExpr> {
42 let alias: Ident = output_schema.field(output_index).name().as_str().into();
44 let (input_column_name, data_type) = input_schema
45 .get(*input_index)
46 .ok_or(PlannerError::ColumnNotFound)?;
47 let expr = DynProofExpr::new_column(ColumnRef::new(
48 table_ref.clone(),
49 input_column_name.clone(),
50 *data_type,
51 ));
52 Ok(AliasedDynProofExpr { expr, alias })
53 },
54 )
55 .collect::<PlannerResult<Vec<_>>>()
56}
57
58fn table_scan_to_projection(
60 table_name: &TableReference,
61 schemas: &impl SchemaAccessor,
62 projection: &[usize],
63 projected_schema: &DFSchema,
64) -> PlannerResult<DynProofPlan> {
65 let table_ref = table_reference_to_table_ref(table_name)?;
67 let input_schema = schemas.lookup_schema(&table_ref);
68 let aliased_dyn_proof_exprs =
70 get_aliased_dyn_proof_exprs(&table_ref, projection, &input_schema, projected_schema)?;
71 let input_column_fields = schema_to_column_fields(input_schema);
72 let table_exec = DynProofPlan::new_table(table_ref, input_column_fields);
73 Ok(DynProofPlan::new_projection(
74 aliased_dyn_proof_exprs,
75 table_exec,
76 ))
77}
78
79fn table_scan_to_filter(
84 table_name: &TableReference,
85 schemas: &impl SchemaAccessor,
86 projection: &[usize],
87 projected_schema: &DFSchema,
88 filters: &[Expr],
89) -> PlannerResult<DynProofPlan> {
90 let table_ref = table_reference_to_table_ref(table_name)?;
92 let input_schema = schemas.lookup_schema(&table_ref);
93 let aliased_dyn_proof_exprs =
95 get_aliased_dyn_proof_exprs(&table_ref, projection, &input_schema, projected_schema)?;
96 let table_expr = TableExpr { table_ref };
97 let consolidated_filter_proof_expr = filters
99 .iter()
100 .map(|f| expr_to_proof_expr(f, &input_schema))
101 .reduce(|a, b| Ok(DynProofExpr::try_new_and(a?, b?)?))
102 .expect("At least one filter expression is required")?;
103 Ok(DynProofPlan::new_filter(
104 aliased_dyn_proof_exprs,
105 table_expr,
106 consolidated_filter_proof_expr,
107 ))
108}
109
110fn try_get_schema_as_vec_from_df_schema(
111 df_schema: &DFSchema,
112) -> PlannerResult<Vec<(Ident, ColumnType)>> {
113 df_schema
114 .inner()
115 .fields()
116 .into_iter()
117 .map(|f| {
118 ColumnType::try_from(f.data_type().clone())
119 .map_err(|_| PlannerError::UnsupportedDataType {
120 data_type: f.data_type().clone(),
121 })
122 .map(|t| (Ident::from(f.name().as_ref()), t))
123 })
124 .collect::<Result<Vec<_>, _>>()
125}
126
127fn projection_to_proof_plan(
129 expr: &[Expr],
130 input: &LogicalPlan,
131 output_schema: &DFSchema,
132 schemas: &impl SchemaAccessor,
133) -> PlannerResult<DynProofPlan> {
134 let input_plan = logical_plan_to_proof_plan(input, schemas)?;
135 let input_schema = try_get_schema_as_vec_from_df_schema(input.schema())?;
136 let aliased_exprs = expr
137 .iter()
138 .zip(output_schema.fields().into_iter())
139 .map(|(e, field)| -> PlannerResult<AliasedDynProofExpr> {
140 let proof_expr = expr_to_proof_expr(e, &input_schema)?;
141 let alias = field.name().as_str().into();
142 Ok(AliasedDynProofExpr {
143 expr: proof_expr,
144 alias,
145 })
146 })
147 .collect::<PlannerResult<Vec<_>>>()?;
148 Ok(DynProofPlan::new_projection(aliased_exprs, input_plan))
149}
150
151fn aggregate_to_proof_plan(
158 input: &LogicalPlan,
159 group_expr: &[Expr],
160 aggr_expr: &[Expr],
161 schemas: &impl SchemaAccessor,
162 alias_map: &IndexMap<&str, &str>,
163) -> PlannerResult<DynProofPlan> {
164 let group_columns = group_expr
166 .iter()
167 .map(|e| match e {
168 Expr::Column(c) => Ok(c),
169 _ => Err(PlannerError::UnsupportedLogicalPlan {
170 plan: Box::new(input.clone()),
171 }),
172 })
173 .collect::<PlannerResult<Vec<_>>>()?;
174 match input {
175 LogicalPlan::TableScan(TableScan {
177 table_name,
178 filters,
179 fetch: None,
180 ..
181 }) => {
182 let table_ref = table_reference_to_table_ref(table_name)?;
183 let input_schema = schemas.lookup_schema(&table_ref);
184 let table_expr = TableExpr { table_ref };
185 let consolidated_filter_proof_expr = filters
187 .iter()
188 .map(|f| expr_to_proof_expr(f, &input_schema))
189 .reduce(|a, b| Ok(DynProofExpr::try_new_and(a?, b?)?))
190 .unwrap_or_else(|| Ok(DynProofExpr::new_literal(LiteralValue::Boolean(true))))?;
191 if aggr_expr.is_empty() {
197 return Err(PlannerError::UnsupportedLogicalPlan {
198 plan: Box::new(input.clone()),
199 });
200 }
201 let agg_aliased_proof_exprs: Vec<((AggregateFunc, DynProofExpr), Ident)> = aggr_expr
202 .iter()
203 .map(|e| match e.clone().unalias() {
204 Expr::AggregateFunction(agg) => {
205 let name_string = e.display_name()?;
206 let name = name_string.as_str();
207 let alias = alias_map.get(&name).ok_or_else(|| {
208 PlannerError::UnsupportedLogicalPlan {
209 plan: Box::new(input.clone()),
210 }
211 })?;
212 Ok((
213 aggregate_function_to_proof_expr(&agg, &input_schema)?,
214 (*alias).into(),
215 ))
216 }
217 _ => Err(PlannerError::UnsupportedLogicalPlan {
218 plan: Box::new(input.clone()),
219 }),
220 })
221 .collect::<PlannerResult<Vec<_>>>()?;
222 let (sum_tuples, count_tuple) =
224 agg_aliased_proof_exprs.split_at(agg_aliased_proof_exprs.len() - 1);
225 let sum_is_compliant = sum_tuples
226 .iter()
227 .all(|((op, _), _)| matches!(op, AggregateFunc::Sum));
228 let count_is_compliant = count_tuple
229 .iter()
230 .all(|((op, _), _)| matches!(op, AggregateFunc::Count));
231 if !sum_is_compliant || !count_is_compliant {
232 return Err(PlannerError::UnsupportedLogicalPlan {
233 plan: Box::new(input.clone()),
234 });
235 }
236 let count_alias = agg_aliased_proof_exprs
237 .last()
238 .expect("We have already checked that this exists")
239 .1
240 .clone();
241 let group_by_exprs = group_columns
243 .iter()
244 .map(|column| {
245 Ok(ColumnExpr::new(column_to_column_ref(
246 column,
247 &input_schema,
248 )?))
249 })
250 .collect::<PlannerResult<Vec<_>>>()?;
251 let sum_expr = sum_tuples
253 .iter()
254 .map(|((_, expr), alias)| AliasedDynProofExpr {
255 expr: expr.clone(),
256 alias: alias.clone(),
257 })
258 .collect::<Vec<_>>();
259 Ok(DynProofPlan::new_group_by(
260 group_by_exprs,
261 sum_expr,
262 count_alias,
263 table_expr,
264 consolidated_filter_proof_expr,
265 ))
266 }
267 _ => Err(PlannerError::UnsupportedLogicalPlan {
268 plan: Box::new(input.clone()),
269 }),
270 }
271}
272
273fn join_to_proof_plan(
274 join: &Join,
275 schema_accessor: &impl SchemaAccessor,
276 plan: &LogicalPlan,
277) -> PlannerResult<DynProofPlan> {
278 if join.join_type != JoinType::Inner || join.join_constraint != JoinConstraint::On {
279 return Err(PlannerError::UnsupportedLogicalPlan {
280 plan: Box::new(plan.clone()),
281 });
282 }
283 let left_plan = Box::new(logical_plan_to_proof_plan(&join.left, schema_accessor)?);
284 let right_plan = Box::new(logical_plan_to_proof_plan(&join.right, schema_accessor)?);
285 let left_column_result_fields = left_plan
286 .get_column_result_fields()
287 .into_iter()
288 .map(|c| c.name())
289 .collect::<IndexSet<_>>();
290 let right_column_result_fields = right_plan
291 .get_column_result_fields()
292 .into_iter()
293 .map(|c| c.name())
294 .collect::<IndexSet<_>>();
295 let on_indices_and_idents = join
296 .on
297 .iter()
298 .filter_map(|(left_expr, right_expr)| {
299 Some(match (left_expr, right_expr) {
300 (Expr::Column(col_a), Expr::Column(col_b)) if col_a.name == col_b.name => {
301 let column_id = Ident::new(col_a.name.clone());
302 Ok((
303 (
304 left_column_result_fields.get_index_of(&column_id)?,
305 right_column_result_fields.get_index_of(&column_id)?,
306 ),
307 column_id,
308 ))
309 }
310 _ => Err(PlannerError::UnsupportedLogicalPlan {
311 plan: Box::new(plan.clone()),
312 }),
313 })
314 })
315 .collect::<Result<Vec<_>, _>>()?;
316 let (on_indices, join_idents): (Vec<(usize, usize)>, Vec<Ident>) =
317 on_indices_and_idents.into_iter().unzip();
318 let (left_indices, right_indices): (Vec<usize>, Vec<usize>) = on_indices.into_iter().unzip();
319 let (left_indices_cloned, right_indices_cloned) = (left_indices.clone(), right_indices.clone());
320 let left_other_column_idents = left_column_result_fields
321 .clone()
322 .into_iter()
323 .enumerate()
324 .filter_map(|(i, col_ident)| (!left_indices.contains(&i)).then_some(col_ident));
325 let right_other_column_idents = right_column_result_fields
326 .into_iter()
327 .enumerate()
328 .filter_map(|(i, col_ident)| (!right_indices.contains(&i)).then_some(col_ident));
329 Ok(DynProofPlan::SortMergeJoin(SortMergeJoinExec::new(
330 left_plan,
331 right_plan,
332 left_indices_cloned,
333 right_indices_cloned,
334 join_idents
335 .into_iter()
336 .chain(left_other_column_idents)
337 .chain(right_other_column_idents)
338 .collect(),
339 )))
340}
341
342#[expect(clippy::too_many_lines)]
344pub fn logical_plan_to_proof_plan(
345 plan: &LogicalPlan,
346 schema_accessor: &impl SchemaAccessor,
347) -> PlannerResult<DynProofPlan> {
348 match plan {
349 LogicalPlan::EmptyRelation { .. } => Ok(DynProofPlan::new_empty()),
350 LogicalPlan::TableScan(TableScan {
352 table_name,
353 projection: Some(projection),
354 projected_schema,
355 filters,
356 fetch,
357 ..
358 }) => {
359 let base_plan = if filters.is_empty() {
360 table_scan_to_projection(table_name, schema_accessor, projection, projected_schema)
361 } else {
362 table_scan_to_filter(
363 table_name,
364 schema_accessor,
365 projection,
366 projected_schema,
367 filters,
368 )
369 }?;
370 if let Some(fetch) = fetch {
371 Ok(DynProofPlan::new_slice(base_plan, 0, Some(*fetch)))
372 } else {
373 Ok(base_plan)
374 }
375 }
376 LogicalPlan::Aggregate(Aggregate {
378 input,
379 group_expr,
380 aggr_expr,
381 schema,
382 ..
383 }) => {
384 let name_strings = group_expr
385 .iter()
386 .chain(aggr_expr.iter())
387 .map(Expr::display_name)
388 .collect::<Result<Vec<_>, _>>()?;
389 let alias_map = name_strings
390 .iter()
391 .zip(schema.fields().iter())
392 .map(|(name_string, field)| {
393 let name = name_string.as_str();
394 let alias = field.name().as_str();
395 Ok((name, alias))
396 })
397 .collect::<PlannerResult<IndexMap<_, _>>>()?;
398 aggregate_to_proof_plan(input, group_expr, aggr_expr, schema_accessor, &alias_map)
399 }
400 LogicalPlan::Projection(Projection {
402 input,
403 expr,
404 schema,
405 ..
406 }) => {
407 match &**input {
408 LogicalPlan::Aggregate(Aggregate {
409 input: agg_input,
410 group_expr,
411 aggr_expr,
412 ..
413 }) => {
414 let alias_map = expr
416 .iter()
417 .map(|e| match e {
418 Expr::Column(c) => Ok((c.name.as_str(), c.name.as_str())),
419 Expr::Alias(Alias { expr, name, .. }) => {
420 if let Expr::Column(c) = expr.as_ref() {
421 Ok((c.name.as_str(), name.as_str()))
422 } else {
423 Err(PlannerError::UnsupportedLogicalPlan {
424 plan: Box::new(plan.clone()),
425 })
426 }
427 }
428 _ => Err(PlannerError::UnsupportedLogicalPlan {
429 plan: Box::new(plan.clone()),
430 }),
431 })
432 .collect::<PlannerResult<IndexMap<_, _>>>()?;
433 aggregate_to_proof_plan(
434 agg_input,
435 group_expr,
436 aggr_expr,
437 schema_accessor,
438 &alias_map,
439 )
440 }
441 _ => projection_to_proof_plan(expr, input, schema, schema_accessor),
442 }
443 }
444 LogicalPlan::Limit(Limit { input, fetch, skip }) => {
446 let input_plan = logical_plan_to_proof_plan(input, schema_accessor)?;
447 Ok(DynProofPlan::new_slice(input_plan, *skip, *fetch))
448 }
449 LogicalPlan::Union(Union { inputs, schema: _ }) => {
451 let input_plans = inputs
452 .iter()
453 .map(|input| logical_plan_to_proof_plan(input, schema_accessor))
454 .collect::<PlannerResult<Vec<_>>>()?;
455 Ok(DynProofPlan::try_new_union(input_plans)?)
456 }
457 LogicalPlan::Join(join) => join_to_proof_plan(join, schema_accessor, plan),
458 _ => Err(PlannerError::UnsupportedLogicalPlan {
459 plan: Box::new(plan.clone()),
460 }),
461 }
462}
463
464#[cfg(test)]
465mod tests {
466 use super::*;
467 use crate::{df_util::*, PoSqlTableSource};
468 use ahash::AHasher;
469 use alloc::{sync::Arc, vec};
470 use arrow::datatypes::DataType;
471 use core::ops::Add;
472 use datafusion::{
473 common::{Column, ScalarValue},
474 logical_expr::{
475 expr::{AggregateFunction, AggregateFunctionDefinition},
476 not, BinaryExpr, EmptyRelation, Operator, Prepare, TableScan, TableSource,
477 },
478 physical_plan,
479 };
480 use indexmap::{indexmap, indexmap_with_default};
481 use proof_of_sql::base::{
482 database::{ColumnField, TestSchemaAccessor},
483 math::decimal::Precision,
484 };
485 use std::hash::BuildHasherDefault;
486
/// Built-in `SUM` aggregate definition used to build test expressions.
const SUM: AggregateFunctionDefinition =
    AggregateFunctionDefinition::BuiltIn(physical_plan::aggregates::AggregateFunction::Sum);
/// Built-in `COUNT` aggregate definition used to build test expressions.
const COUNT: AggregateFunctionDefinition =
    AggregateFunctionDefinition::BuiltIn(physical_plan::aggregates::AggregateFunction::Count);
/// Built-in `AVG` aggregate definition used to build test expressions.
const AVG: AggregateFunctionDefinition =
    AggregateFunctionDefinition::BuiltIn(physical_plan::aggregates::AggregateFunction::Avg);
493
/// Reference to the schema-less test table named `table`.
#[expect(non_snake_case)]
fn TABLE_REF_TABLE() -> TableRef {
    TableRef::from_names(None, "table")
}
498
/// Schema accessor exposing a single table `table` with the columns
/// `a: BigInt`, `b: Int`, `c: VarChar` and `d: Boolean`.
#[expect(non_snake_case)]
fn SCHEMAS() -> impl SchemaAccessor {
    let schema: IndexMap<Ident, ColumnType, BuildHasherDefault<AHasher>> = indexmap_with_default! {
        AHasher;
        "a".into() => ColumnType::BigInt,
        "b".into() => ColumnType::Int,
        "c".into() => ColumnType::VarChar,
        "d".into() => ColumnType::Boolean
    };
    // The table is registered under an empty schema name.
    let table_ref = TableRef::new("", "table");
    let schema_accessor = indexmap_with_default! {
        AHasher;
        table_ref => schema
    };
    TestSchemaAccessor::new(schema_accessor)
}
515
/// Schema accessor exposing three two-column tables (`table1`, `table2` and `schema.table3`),
/// each with one `BigInt` and one `Int` column.
#[expect(non_snake_case)]
fn UNION_SCHEMAS() -> impl SchemaAccessor {
    TestSchemaAccessor::new(indexmap_with_default! {AHasher;
        TableRef::new("", "table1") => indexmap_with_default! {AHasher;
            "a1".into() => ColumnType::BigInt,
            "b1".into() => ColumnType::Int
        },
        TableRef::new("", "table2") => indexmap_with_default! {AHasher;
            "a2".into() => ColumnType::BigInt,
            "b2".into() => ColumnType::Int
        },
        TableRef::new("schema", "table3") => indexmap_with_default! {AHasher;
            "a3".into() => ColumnType::BigInt,
            "b3".into() => ColumnType::Int
        },
    })
}
533
/// Schema accessor that knows no tables at all.
#[expect(non_snake_case)]
fn EMPTY_SCHEMAS() -> impl SchemaAccessor {
    TestSchemaAccessor::new(indexmap_with_default! {AHasher;})
}
538
/// Table source with the columns `a: BigInt`, `b: Int`, `c: VarChar` and `d: Boolean`,
/// i.e. the same columns that `SCHEMAS` registers for `table`.
#[expect(non_snake_case)]
fn TABLE_SOURCE() -> Arc<dyn TableSource> {
    Arc::new(PoSqlTableSource::new(vec![
        ColumnField::new("a".into(), ColumnType::BigInt),
        ColumnField::new("b".into(), ColumnType::Int),
        ColumnField::new("c".into(), ColumnType::VarChar),
        ColumnField::new("d".into(), ColumnType::Boolean),
    ]))
}
548
/// Column `table.a` (`BigInt`) aliased as `a`.
#[expect(non_snake_case)]
fn ALIASED_A() -> AliasedDynProofExpr {
    AliasedDynProofExpr {
        expr: DynProofExpr::new_column(ColumnRef::new(
            TABLE_REF_TABLE(),
            "a".into(),
            ColumnType::BigInt,
        )),
        alias: "a".into(),
    }
}
560
/// Column `table.b` (`Int`) aliased as `b`.
#[expect(non_snake_case)]
fn ALIASED_B() -> AliasedDynProofExpr {
    AliasedDynProofExpr {
        expr: DynProofExpr::new_column(ColumnRef::new(
            TABLE_REF_TABLE(),
            "b".into(),
            ColumnType::Int,
        )),
        alias: "b".into(),
    }
}
572
/// Column `table.c` (`VarChar`) aliased as `c`.
#[expect(non_snake_case)]
fn ALIASED_C() -> AliasedDynProofExpr {
    AliasedDynProofExpr {
        expr: DynProofExpr::new_column(ColumnRef::new(
            TABLE_REF_TABLE(),
            "c".into(),
            ColumnType::VarChar,
        )),
        alias: "c".into(),
    }
}
584
/// Column `table.d` (`Boolean`) aliased as `d`.
#[expect(non_snake_case)]
fn ALIASED_D() -> AliasedDynProofExpr {
    AliasedDynProofExpr {
        expr: DynProofExpr::new_column(ColumnRef::new(
            TABLE_REF_TABLE(),
            "d".into(),
            ColumnType::Boolean,
        )),
        alias: "d".into(),
    }
}
596
/// The aggregate expression `COUNT(1)` over the `Int64` literal `1`, without `DISTINCT`,
/// filter, ordering or null treatment.
#[expect(non_snake_case)]
fn COUNT_1() -> Expr {
    Expr::AggregateFunction(AggregateFunction {
        func_def: COUNT,
        args: vec![Expr::Literal(ScalarValue::Int64(Some(1)))],
        distinct: false,
        filter: None,
        order_by: None,
        null_treatment: None,
    })
}
608
/// The aggregate expression `SUM(table.b)`, without `DISTINCT`, filter, ordering or null
/// treatment.
#[expect(non_snake_case)]
fn SUM_B() -> Expr {
    Expr::AggregateFunction(AggregateFunction {
        func_def: SUM,
        args: vec![df_column("table", "b")],
        distinct: false,
        filter: None,
        order_by: None,
        null_treatment: None,
    })
}
620
/// The aggregate expression `SUM(table.d)`, without `DISTINCT`, filter, ordering or null
/// treatment.
#[expect(non_snake_case)]
fn SUM_D() -> Expr {
    Expr::AggregateFunction(AggregateFunction {
        func_def: SUM,
        args: vec![df_column("table", "d")],
        distinct: false,
        filter: None,
        order_by: None,
        null_treatment: None,
    })
}
632
/// Projecting input columns 1 and 2 yields the `b` and `c` column expressions, named after
/// the output schema.
#[test]
fn we_can_get_aliased_proof_expr_with_specified_projection_columns() {
    let source_columns: Vec<(Ident, ColumnType)> = vec![
        ("a".into(), ColumnType::BigInt),
        ("b".into(), ColumnType::Int),
        ("c".into(), ColumnType::VarChar),
        (
            "d".into(),
            ColumnType::Decimal75(Precision::new(5).unwrap(), 1),
        ),
    ];
    let projected_schema = df_schema("table", vec![("b", DataType::Int32), ("c", DataType::Utf8)]);

    let aliased = get_aliased_dyn_proof_exprs(
        &TABLE_REF_TABLE(),
        &[1, 2],
        &source_columns,
        &projected_schema,
    )
    .unwrap();

    assert_eq!(aliased, vec![ALIASED_B(), ALIASED_C()]);
}
654
/// Projecting every input column in order yields all four column expressions.
#[test]
fn we_can_get_aliased_proof_expr_without_specified_projection_columns() {
    let source_columns: Vec<(Ident, ColumnType)> = vec![
        ("a".into(), ColumnType::BigInt),
        ("b".into(), ColumnType::Int),
        ("c".into(), ColumnType::VarChar),
        ("d".into(), ColumnType::Boolean),
    ];
    let output_fields = vec![
        ("a", DataType::Int64),
        ("b", DataType::Int32),
        ("c", DataType::Utf8),
        ("d", DataType::Boolean),
    ];
    let projected_schema = df_schema("table", output_fields);

    let aliased = get_aliased_dyn_proof_exprs(
        &TABLE_REF_TABLE(),
        &[0, 1, 2, 3],
        &source_columns,
        &projected_schema,
    )
    .unwrap();

    assert_eq!(
        aliased,
        vec![ALIASED_A(), ALIASED_B(), ALIASED_C(), ALIASED_D()]
    );
}
679
/// Grouping by `a` with `SUM(b)` and a trailing `COUNT(1)` over an unfiltered scan produces a
/// group-by plan whose `WHERE` clause is the literal `true`.
#[test]
fn we_can_aggregate_with_group_by_and_sum_count() {
    let scan = TableScan::try_new(
        "table",
        TABLE_SOURCE(),
        Some(vec![0, 1, 2, 3]),
        vec![],
        None,
    )
    .unwrap();
    let aliases = indexmap! {
        "a" => "a",
        "SUM(table.b)" => "sum_b",
        "COUNT(Int64(1))" => "count_1",
    };

    let actual = aggregate_to_proof_plan(
        &LogicalPlan::TableScan(scan),
        &[df_column("table", "a")],
        &[SUM_B(), COUNT_1()],
        &SCHEMAS(),
        &aliases,
    )
    .unwrap();

    let group_by_a = ColumnExpr::new(ColumnRef::new(
        TABLE_REF_TABLE(),
        "a".into(),
        ColumnType::BigInt,
    ));
    let sum_of_b = AliasedDynProofExpr {
        alias: "sum_b".into(),
        ..ALIASED_B()
    };
    assert_eq!(
        actual,
        DynProofPlan::new_group_by(
            vec![group_by_a],
            vec![sum_of_b],
            "count_1".into(),
            TableExpr {
                table_ref: TABLE_REF_TABLE(),
            },
            DynProofExpr::new_literal(LiteralValue::Boolean(true)),
        )
    );
}
738
/// A scan filter (the boolean column `d`) becomes the `WHERE` clause of the group-by plan.
#[test]
fn we_can_aggregate_with_filters() {
    let scan = TableScan::try_new(
        "table",
        TABLE_SOURCE(),
        Some(vec![0, 1, 2, 3]),
        vec![df_column("table", "d")],
        None,
    )
    .unwrap();
    let aliases = indexmap! {
        "a" => "a",
        "SUM(table.b)" => "sum_b",
        "COUNT(Int64(1))" => "count_1",
    };

    let actual = aggregate_to_proof_plan(
        &LogicalPlan::TableScan(scan),
        &[df_column("table", "a")],
        &[SUM_B(), COUNT_1()],
        &SCHEMAS(),
        &aliases,
    )
    .unwrap();

    let group_by_a = ColumnExpr::new(ColumnRef::new(
        TABLE_REF_TABLE(),
        "a".into(),
        ColumnType::BigInt,
    ));
    let sum_of_b = AliasedDynProofExpr {
        alias: "sum_b".into(),
        ..ALIASED_B()
    };
    // The single filter is used as-is: the column expression for `table.d`.
    let where_d = ALIASED_D().expr;
    assert_eq!(
        actual,
        DynProofPlan::new_group_by(
            vec![group_by_a],
            vec![sum_of_b],
            "count_1".into(),
            TableExpr {
                table_ref: TABLE_REF_TABLE(),
            },
            where_d,
        )
    );
}
805
/// Grouping by two columns (`a`, `c`) keeps both columns, in order, in the group-by plan.
#[test]
fn we_can_aggregate_with_multiple_group_columns() {
    let scan = TableScan::try_new(
        "table",
        TABLE_SOURCE(),
        Some(vec![0, 1, 2, 3]),
        vec![],
        None,
    )
    .unwrap();
    let aliases = indexmap! {
        "a" => "a",
        "c" => "c",
        "SUM(table.b)" => "sum_b",
        "COUNT(Int64(1))" => "count_1",
    };

    let actual = aggregate_to_proof_plan(
        &LogicalPlan::TableScan(scan),
        &[df_column("table", "a"), df_column("table", "c")],
        &[SUM_B(), COUNT_1()],
        &SCHEMAS(),
        &aliases,
    )
    .unwrap();

    let group_by_a = ColumnExpr::new(ColumnRef::new(
        TABLE_REF_TABLE(),
        "a".into(),
        ColumnType::BigInt,
    ));
    let group_by_c = ColumnExpr::new(ColumnRef::new(
        TABLE_REF_TABLE(),
        "c".into(),
        ColumnType::VarChar,
    ));
    let sum_of_b = AliasedDynProofExpr {
        alias: "sum_b".into(),
        ..ALIASED_B()
    };
    assert_eq!(
        actual,
        DynProofPlan::new_group_by(
            vec![group_by_a, group_by_c],
            vec![sum_of_b],
            "count_1".into(),
            TableExpr {
                table_ref: TABLE_REF_TABLE(),
            },
            DynProofExpr::new_literal(LiteralValue::Boolean(true)),
        )
    );
}
871
/// Several `SUM`s before the trailing `COUNT(1)` are all carried into the group-by plan, in
/// order.
#[test]
fn we_can_aggregate_with_multiple_sum_expressions() {
    let scan = TableScan::try_new(
        "table",
        TABLE_SOURCE(),
        Some(vec![0, 1, 2, 3]),
        vec![],
        None,
    )
    .unwrap();
    let aliases = indexmap! {
        "a" => "a",
        "SUM(table.b)" => "sum_b",
        "SUM(table.d)" => "sum_d",
        "COUNT(Int64(1))" => "count_1",
    };

    let actual = aggregate_to_proof_plan(
        &LogicalPlan::TableScan(scan),
        &[df_column("table", "a")],
        &[SUM_B(), SUM_D(), COUNT_1()],
        &SCHEMAS(),
        &aliases,
    )
    .unwrap();

    let group_by_a = ColumnExpr::new(ColumnRef::new(
        TABLE_REF_TABLE(),
        "a".into(),
        ColumnType::BigInt,
    ));
    let sum_of_b = AliasedDynProofExpr {
        alias: "sum_b".into(),
        ..ALIASED_B()
    };
    let sum_of_d = AliasedDynProofExpr {
        alias: "sum_d".into(),
        ..ALIASED_D()
    };
    assert_eq!(
        actual,
        DynProofPlan::new_group_by(
            vec![group_by_a],
            vec![sum_of_b, sum_of_d],
            "count_1".into(),
            TableExpr {
                table_ref: TABLE_REF_TABLE(),
            },
            DynProofExpr::new_literal(LiteralValue::Boolean(true)),
        )
    );
}
941
/// A lone `COUNT(1)` is enough: the group-by plan then has no sum expressions.
#[test]
fn we_can_aggregate_without_sum_expressions() {
    let scan = TableScan::try_new(
        "table",
        TABLE_SOURCE(),
        Some(vec![0, 1, 2, 3]),
        vec![],
        None,
    )
    .unwrap();
    let aliases = indexmap! {
        "a" => "a",
        "COUNT(Int64(1))" => "count_1",
    };

    let actual = aggregate_to_proof_plan(
        &LogicalPlan::TableScan(scan),
        &[df_column("table", "a")],
        &[COUNT_1()],
        &SCHEMAS(),
        &aliases,
    )
    .unwrap();

    let group_by_a = ColumnExpr::new(ColumnRef::new(
        TABLE_REF_TABLE(),
        "a".into(),
        ColumnType::BigInt,
    ));
    assert_eq!(
        actual,
        DynProofPlan::new_group_by(
            vec![group_by_a],
            Vec::new(),
            "count_1".into(),
            TableExpr {
                table_ref: TABLE_REF_TABLE(),
            },
            DynProofExpr::new_literal(LiteralValue::Boolean(true)),
        )
    );
}
990
/// Grouping by an expression that is not a bare column (`a + b`) is rejected.
#[test]
fn we_cannot_aggregate_with_non_column_group_expr() {
    // Builds a fresh `table.a + table.b` expression each time it is called.
    let a_plus_b = || {
        Expr::BinaryExpr(BinaryExpr::new(
            Box::new(df_column("table", "a")),
            Operator::Plus,
            Box::new(df_column("table", "b")),
        ))
    };
    let grouping = vec![a_plus_b()];
    let aggregates = vec![a_plus_b(), COUNT_1()];

    let scan = TableScan::try_new(
        "table",
        TABLE_SOURCE(),
        Some(vec![0, 1, 2, 3]),
        vec![],
        None,
    )
    .unwrap();
    let aliases = indexmap! {
        "a+b" => "res",
        "COUNT(Int64(1))" => "count_1",
    };

    let outcome = aggregate_to_proof_plan(
        &LogicalPlan::TableScan(scan),
        &grouping,
        &aggregates,
        &SCHEMAS(),
        &aliases,
    );

    assert!(matches!(
        outcome,
        Err(PlannerError::UnsupportedLogicalPlan { .. })
    ));
}
1035
/// An aggregate list entry that is not an aggregate function (an aliased `b + c`) is rejected.
#[test]
fn we_cannot_aggregate_with_non_aggregate_expression() {
    let b_plus_c = Expr::BinaryExpr(BinaryExpr::new(
        Box::new(df_column("table", "b")),
        Operator::Plus,
        Box::new(df_column("table", "c")),
    ));
    let aggregates = vec![Expr::Alias(Alias {
        expr: Box::new(b_plus_c),
        relation: None,
        name: "b_plus_c".to_string(),
    })];

    let scan = TableScan::try_new(
        "table",
        TABLE_SOURCE(),
        Some(vec![0, 1, 2, 3]),
        vec![],
        None,
    )
    .unwrap();
    let aliases = indexmap! {
        "b+c" => "b_plus_c",
    };

    let outcome = aggregate_to_proof_plan(
        &LogicalPlan::TableScan(scan),
        &[df_column("table", "a")],
        &aggregates,
        &SCHEMAS(),
        &aliases,
    );

    assert!(matches!(
        outcome,
        Err(PlannerError::UnsupportedLogicalPlan { .. })
    ));
}
1083
/// An `AVG` in place of a `SUM` is rejected.
///
/// NOTE(review): the aggregate is aliased as `avg_b`, while the alias map is keyed by
/// `AVG(table.b)`. If the display name of an aliased expression is its alias, this test fails
/// at the alias lookup and never reaches the SUM/COUNT check — confirm which path is intended.
#[test]
fn we_cannot_aggregate_with_non_sum_aggregate_function() {
    let average_of_b = Expr::AggregateFunction(AggregateFunction {
        func_def: AVG,
        args: vec![df_column("table", "b")],
        distinct: false,
        filter: None,
        order_by: None,
        null_treatment: None,
    });
    let aggregates = vec![
        Expr::Alias(Alias {
            expr: Box::new(average_of_b),
            relation: None,
            name: "avg_b".to_string(),
        }),
        COUNT_1(),
    ];

    let scan = TableScan::try_new(
        "table",
        TABLE_SOURCE(),
        Some(vec![0, 1, 2, 3]),
        vec![],
        None,
    )
    .unwrap();
    let aliases = indexmap! {
        "a" => "a",
        "AVG(table.b)" => "avg_b",
        "COUNT(Int64(1))" => "count_1",
    };

    let outcome = aggregate_to_proof_plan(
        &LogicalPlan::TableScan(scan),
        &[df_column("table", "a")],
        &aggregates,
        &SCHEMAS(),
        &aliases,
    );

    assert!(matches!(
        outcome,
        Err(PlannerError::UnsupportedLogicalPlan { .. })
    ));
}
1137
/// An aggregate list whose last entry is a `SUM` instead of a `COUNT` is rejected.
///
/// NOTE(review): both aggregates are aliased (`sum_b`, `sum_c`), while the alias map is keyed
/// by `SUM(table.b)` / `SUM(c)`. If the display name of an aliased expression is its alias,
/// this test fails at the alias lookup and never reaches the trailing-COUNT check — confirm.
#[test]
fn we_cannot_aggregate_with_non_count_last_aggregate() {
    let sum_of_c = Expr::AggregateFunction(AggregateFunction {
        func_def: SUM,
        args: vec![df_column("table", "c")],
        distinct: false,
        filter: None,
        order_by: None,
        null_treatment: None,
    });
    // Wraps an expression in an unqualified alias.
    let aliased = |inner: Expr, name: &str| {
        Expr::Alias(Alias {
            expr: Box::new(inner),
            relation: None,
            name: name.to_string(),
        })
    };
    let aggregates = vec![aliased(SUM_B(), "sum_b"), aliased(sum_of_c, "sum_c")];

    let scan = TableScan::try_new(
        "table",
        TABLE_SOURCE(),
        Some(vec![0, 1, 2, 3]),
        vec![],
        None,
    )
    .unwrap();
    let aliases = indexmap! {
        "a" => "a",
        "SUM(table.b)" => "sum_b",
        "SUM(c)" => "sum_c",
    };

    let outcome = aggregate_to_proof_plan(
        &LogicalPlan::TableScan(scan),
        &[df_column("table", "a")],
        &aggregates,
        &SCHEMAS(),
        &aliases,
    );

    assert!(matches!(
        outcome,
        Err(PlannerError::UnsupportedLogicalPlan { .. })
    ));
}
1206
/// Aggregating over a scan that carries a `fetch` limit is rejected.
#[test]
fn we_cannot_aggregate_with_fetch_limit() {
    let limited_scan = TableScan::try_new(
        "table",
        TABLE_SOURCE(),
        Some(vec![0, 1, 2, 3]),
        vec![],
        Some(10),
    )
    .unwrap();
    let aliases = indexmap! {
        "a" => "a",
        "COUNT(Int64(1))" => "count_1",
    };

    let outcome = aggregate_to_proof_plan(
        &LogicalPlan::TableScan(limited_scan),
        &[df_column("table", "a")],
        &[COUNT_1()],
        &SCHEMAS(),
        &aliases,
    );

    assert!(matches!(
        outcome,
        Err(PlannerError::UnsupportedLogicalPlan { .. })
    ));
}
1241
/// Aggregating over anything other than a table scan (here: an empty relation) is rejected.
#[test]
fn we_cannot_aggregate_with_non_table_scan_input() {
    let empty_relation = LogicalPlan::EmptyRelation(EmptyRelation {
        produce_one_row: false,
        schema: Arc::new(DFSchema::empty()),
    });
    let aliases = indexmap! {
        "a" => "a",
        "COUNT(Int64(1))" => "count_1",
    };

    let outcome = aggregate_to_proof_plan(
        &empty_relation,
        &[df_column("table", "a")],
        &[COUNT_1()],
        &SCHEMAS(),
        &aliases,
    );

    assert!(matches!(
        outcome,
        Err(PlannerError::UnsupportedLogicalPlan { .. })
    ));
}
1270
/// An empty relation converts to the empty proof plan.
#[test]
fn we_can_convert_empty_plan_to_proof_plan() {
    let relation = EmptyRelation {
        produce_one_row: false,
        schema: Arc::new(DFSchema::empty()),
    };

    let converted =
        logical_plan_to_proof_plan(&LogicalPlan::EmptyRelation(relation), &EMPTY_SCHEMAS())
            .unwrap();

    assert_eq!(converted, DynProofPlan::new_empty());
}
1281
/// An unfiltered, unlimited scan of columns 0..=2 becomes a projection of `a`, `b`, `c` over a
/// table plan that still exposes all four columns.
#[test]
fn we_can_convert_table_scan_plan_to_proof_plan_without_filter_or_fetch_limit() {
    let scan =
        TableScan::try_new("table", TABLE_SOURCE(), Some(vec![0, 1, 2]), vec![], None).unwrap();

    let converted = logical_plan_to_proof_plan(&LogicalPlan::TableScan(scan), &SCHEMAS()).unwrap();

    let full_table = DynProofPlan::new_table(
        TABLE_REF_TABLE(),
        vec![
            ColumnField::new("a".into(), ColumnType::BigInt),
            ColumnField::new("b".into(), ColumnType::Int),
            ColumnField::new("c".into(), ColumnType::VarChar),
            ColumnField::new("d".into(), ColumnType::Boolean),
        ],
    );
    assert_eq!(
        converted,
        DynProofPlan::new_projection(vec![ALIASED_A(), ALIASED_B(), ALIASED_C()], full_table)
    );
}
1304
/// With a schema accessor that knows no tables, the projected columns cannot be resolved and
/// the conversion fails with `ColumnNotFound`.
#[test]
fn we_cannot_convert_table_scan_plan_to_proof_plan_without_filter_or_fetch_limit_if_bad_schemas(
) {
    let scan = TableScan::try_new(
        "table",
        TABLE_SOURCE(),
        Some(vec![0, 1, 2, 3]),
        vec![],
        None,
    )
    .unwrap();

    let outcome = logical_plan_to_proof_plan(&LogicalPlan::TableScan(scan), &EMPTY_SCHEMAS());

    assert!(matches!(outcome, Err(PlannerError::ColumnNotFound)));
}
1322
/// A table scan carrying filters but no fetch limit becomes a single filter
/// plan whose predicate is the conjunction of the scan's filters.
#[test]
fn we_can_convert_table_scan_plan_to_proof_plan_with_filter_but_without_fetch_limit() {
    // Filters: `table.a = table.b` and the boolean column `table.d`.
    let filters = vec![
        df_column("table", "a").eq(df_column("table", "b")),
        df_column("table", "d"),
    ];
    let scan =
        TableScan::try_new("table", TABLE_SOURCE(), Some(vec![0, 2]), filters, None).unwrap();
    let logical = LogicalPlan::TableScan(scan);
    let accessor = SCHEMAS();
    let converted = logical_plan_to_proof_plan(&logical, &accessor).unwrap();

    // Shorthand for a column expression on `table`.
    let column = |name: &str, data_type: ColumnType| {
        DynProofExpr::new_column(ColumnRef::new(TABLE_REF_TABLE(), name.into(), data_type))
    };
    let a_equals_b = DynProofExpr::try_new_equals(
        column("a", ColumnType::BigInt),
        column("b", ColumnType::Int),
    )
    .unwrap();
    let predicate =
        DynProofExpr::try_new_and(a_equals_b, column("d", ColumnType::Boolean)).unwrap();

    let expected = DynProofPlan::new_filter(
        vec![ALIASED_A(), ALIASED_C()],
        TableExpr {
            table_ref: TABLE_REF_TABLE(),
        },
        predicate,
    );
    assert_eq!(converted, expected);
}
1370
/// With `EMPTY_SCHEMAS()` as the schema accessor, converting a filtered table
/// scan fails with `ColumnNotFound`.
#[test]
fn we_cannot_convert_table_scan_plan_to_proof_plan_with_filter_but_without_fetch_limit_if_bad_schemas(
) {
    // Filters: `table.a = table.b` and the boolean column `table.d`.
    let a_equals_b = df_column("table", "a").eq(df_column("table", "b"));
    let d_is_true = df_column("table", "d");
    let filters = vec![a_equals_b, d_is_true];

    let scan =
        TableScan::try_new("table", TABLE_SOURCE(), Some(vec![0, 2]), filters, None).unwrap();
    let logical = LogicalPlan::TableScan(scan);

    let accessor = EMPTY_SCHEMAS();
    let outcome = logical_plan_to_proof_plan(&logical, &accessor);

    assert!(matches!(outcome, Err(PlannerError::ColumnNotFound)));
}
1392
/// A table scan with a fetch limit but no filters becomes a projection
/// wrapped in a slice that skips nothing and fetches the limit.
#[test]
fn we_can_convert_table_scan_plan_to_proof_plan_without_filter_but_with_fetch_limit() {
    // Scan every column of `table`, fetching at most two rows.
    let scan = TableScan::try_new(
        "table",
        TABLE_SOURCE(),
        Some(vec![0, 1, 2, 3]),
        vec![],
        Some(2),
    )
    .unwrap();
    let logical = LogicalPlan::TableScan(scan);
    let accessor = SCHEMAS();
    let converted = logical_plan_to_proof_plan(&logical, &accessor).unwrap();

    let table_columns = vec![
        ColumnField::new("a".into(), ColumnType::BigInt),
        ColumnField::new("b".into(), ColumnType::Int),
        ColumnField::new("c".into(), ColumnType::VarChar),
        ColumnField::new("d".into(), ColumnType::Boolean),
    ];
    let source_table = DynProofPlan::new_table(TABLE_REF_TABLE(), table_columns);
    let projected = DynProofPlan::new_projection(
        vec![ALIASED_A(), ALIASED_B(), ALIASED_C(), ALIASED_D()],
        source_table,
    );
    // Slice arguments: skip 0 rows, fetch 2.
    let expected = DynProofPlan::new_slice(projected, 0, Some(2));
    assert_eq!(converted, expected);
}
1425
/// A table scan with both filters and a fetch limit becomes a filter plan
/// wrapped in a slice that skips nothing and fetches the limit.
#[test]
fn we_can_convert_table_scan_plan_to_proof_plan_with_filter_and_fetch_limit() {
    // Filters: `table.a > table.b` and the boolean column `table.d`.
    let filters = vec![
        df_column("table", "a").gt(df_column("table", "b")),
        df_column("table", "d"),
    ];
    let scan =
        TableScan::try_new("table", TABLE_SOURCE(), Some(vec![0, 3]), filters, Some(5)).unwrap();
    let logical = LogicalPlan::TableScan(scan);
    let accessor = SCHEMAS();
    let converted = logical_plan_to_proof_plan(&logical, &accessor).unwrap();

    // Shorthand for a column expression on `table`.
    let column = |name: &str, data_type: ColumnType| {
        DynProofExpr::new_column(ColumnRef::new(TABLE_REF_TABLE(), name.into(), data_type))
    };
    // The `>` filter is expected as an inequality with the final flag set to `false`.
    let a_greater_than_b = DynProofExpr::try_new_inequality(
        column("a", ColumnType::BigInt),
        column("b", ColumnType::Int),
        false,
    )
    .unwrap();
    let predicate =
        DynProofExpr::try_new_and(a_greater_than_b, column("d", ColumnType::Boolean)).unwrap();

    let filtered = DynProofPlan::new_filter(
        vec![ALIASED_A(), ALIASED_D()],
        TableExpr {
            table_ref: TABLE_REF_TABLE(),
        },
        predicate,
    );
    // Slice arguments: skip 0 rows, fetch 5.
    let expected = DynProofPlan::new_slice(filtered, 0, Some(5));
    assert_eq!(converted, expected);
}
1478
/// A `Projection` over a table scan becomes a projection of the computed
/// expressions stacked on top of the scan's own column projection.
#[test]
fn we_can_convert_projection_plan_to_proof_plan() {
    // Projected expressions: `table.a + table.b` and `NOT table.d`.
    let sum_of_a_and_b = Expr::BinaryExpr(BinaryExpr::new(
        Box::new(df_column("table", "a")),
        Operator::Plus,
        Box::new(df_column("table", "b")),
    ));
    let negated_d = not(df_column("table", "d"));

    let scan =
        TableScan::try_new("table", TABLE_SOURCE(), Some(vec![0, 1, 3]), vec![], None).unwrap();
    let projection = Projection::try_new(
        vec![sum_of_a_and_b, negated_d],
        Arc::new(LogicalPlan::TableScan(scan)),
    )
    .unwrap();
    let logical = LogicalPlan::Projection(projection);
    let accessor = SCHEMAS();
    let converted = logical_plan_to_proof_plan(&logical, &accessor).unwrap();

    // Outer projection: the computed expressions, aliased as written below.
    let aliased_sum = AliasedDynProofExpr {
        expr: DynProofExpr::try_new_add(
            DynProofExpr::new_column(ColumnRef::new(
                TABLE_REF_TABLE(),
                "a".into(),
                ColumnType::BigInt,
            )),
            DynProofExpr::new_column(ColumnRef::new(
                TABLE_REF_TABLE(),
                "b".into(),
                ColumnType::Int,
            )),
        )
        .unwrap(),
        alias: "table.a + table.b".into(),
    };
    let aliased_negation = AliasedDynProofExpr {
        expr: DynProofExpr::try_new_not(DynProofExpr::new_column(ColumnRef::new(
            TABLE_REF_TABLE(),
            "d".into(),
            ColumnType::Boolean,
        )))
        .unwrap(),
        alias: "NOT table.d".into(),
    };

    // Inner projection: the scanned columns over the full table.
    let source_table = DynProofPlan::new_table(
        TABLE_REF_TABLE(),
        vec![
            ColumnField::new("a".into(), ColumnType::BigInt),
            ColumnField::new("b".into(), ColumnType::Int),
            ColumnField::new("c".into(), ColumnType::VarChar),
            ColumnField::new("d".into(), ColumnType::Boolean),
        ],
    );
    let scanned_columns = DynProofPlan::new_projection(
        vec![ALIASED_A(), ALIASED_B(), ALIASED_D()],
        source_table,
    );

    let expected =
        DynProofPlan::new_projection(vec![aliased_sum, aliased_negation], scanned_columns);
    assert_eq!(converted, expected);
}
1544
/// A `Limit` with both fetch and skip over a scan that has its own fetch
/// produces two nested slices: the scan's slice inside, the limit's outside.
#[test]
fn we_can_convert_limit_plan_with_fetch_and_skip_to_proof_plan() {
    // The scan itself fetches 5 rows; the limit then skips 2 and fetches 3.
    let scan = TableScan::try_new(
        "table",
        TABLE_SOURCE(),
        Some(vec![0, 1]),
        vec![],
        Some(5),
    )
    .unwrap();
    let limited_input = Arc::new(LogicalPlan::TableScan(scan));
    let logical = LogicalPlan::Limit(Limit {
        input: limited_input,
        fetch: Some(3),
        skip: 2,
    });
    let accessor = SCHEMAS();
    let converted = logical_plan_to_proof_plan(&logical, &accessor).unwrap();

    let source_table = DynProofPlan::new_table(
        TABLE_REF_TABLE(),
        vec![
            ColumnField::new("a".into(), ColumnType::BigInt),
            ColumnField::new("b".into(), ColumnType::Int),
            ColumnField::new("c".into(), ColumnType::VarChar),
            ColumnField::new("d".into(), ColumnType::Boolean),
        ],
    );
    let projected = DynProofPlan::new_projection(vec![ALIASED_A(), ALIASED_B()], source_table);
    // Inner slice mirrors the scan's fetch; outer slice mirrors the limit node.
    let scan_slice = DynProofPlan::new_slice(projected, 0, Some(5));
    let expected = DynProofPlan::new_slice(scan_slice, 2, Some(3));
    assert_eq!(converted, expected);
}
1588
/// A `Limit` with a fetch and zero skip over a scan with the same fetch still
/// produces two nested slices, each skipping nothing and fetching three rows.
#[test]
fn we_can_convert_limit_plan_with_fetch_no_skip_to_proof_plan() {
    let scan =
        TableScan::try_new("table", TABLE_SOURCE(), Some(vec![0, 1]), vec![], Some(3)).unwrap();
    let limited_input = Arc::new(LogicalPlan::TableScan(scan));
    let logical = LogicalPlan::Limit(Limit {
        input: limited_input,
        fetch: Some(3),
        skip: 0,
    });

    let accessor = SCHEMAS();
    let converted = logical_plan_to_proof_plan(&logical, &accessor).unwrap();

    let source_table = DynProofPlan::new_table(
        TABLE_REF_TABLE(),
        vec![
            ColumnField::new("a".into(), ColumnType::BigInt),
            ColumnField::new("b".into(), ColumnType::Int),
            ColumnField::new("c".into(), ColumnType::VarChar),
            ColumnField::new("d".into(), ColumnType::Boolean),
        ],
    );
    let projected = DynProofPlan::new_projection(vec![ALIASED_A(), ALIASED_B()], source_table);
    // Inner slice mirrors the scan's fetch; outer slice mirrors the limit node.
    let scan_slice = DynProofPlan::new_slice(projected, 0, Some(3));
    let expected = DynProofPlan::new_slice(scan_slice, 0, Some(3));
    assert_eq!(converted, expected);
}
1626
/// A `Limit` with only a skip over an unlimited scan produces a single slice
/// with that skip and no fetch.
#[test]
fn we_can_convert_limit_plan_with_skip_no_fetch_to_proof_plan() {
    let scan =
        TableScan::try_new("table", TABLE_SOURCE(), Some(vec![0, 1]), vec![], None).unwrap();
    let limited_input = Arc::new(LogicalPlan::TableScan(scan));
    let logical = LogicalPlan::Limit(Limit {
        input: limited_input,
        fetch: None,
        skip: 2,
    });

    let accessor = SCHEMAS();
    let converted = logical_plan_to_proof_plan(&logical, &accessor).unwrap();

    let source_table = DynProofPlan::new_table(
        TABLE_REF_TABLE(),
        vec![
            ColumnField::new("a".into(), ColumnType::BigInt),
            ColumnField::new("b".into(), ColumnType::Int),
            ColumnField::new("c".into(), ColumnType::VarChar),
            ColumnField::new("d".into(), ColumnType::Boolean),
        ],
    );
    let projected = DynProofPlan::new_projection(vec![ALIASED_A(), ALIASED_B()], source_table);
    // Slice arguments: skip 2 rows, no fetch bound.
    let expected = DynProofPlan::new_slice(projected, 2, None);
    assert_eq!(converted, expected);
}
1659
/// A `Union` of three table scans becomes a union proof plan with one
/// projection per input, each aliasing that table's columns to `a` and `b`.
#[expect(clippy::too_many_lines)]
#[test]
fn we_can_convert_union_plan_to_proof_plan() {
    // Output schema of the union: `a: Int64`, `b: Int32`.
    let union_schema = df_schema(
        "table",
        vec![("a", DataType::Int64), ("b", DataType::Int32)],
    );
    // Each input scans the first two columns of its own table.
    let first_scan =
        TableScan::try_new("table1", TABLE_SOURCE(), Some(vec![0, 1]), vec![], None).unwrap();
    let second_scan =
        TableScan::try_new("table2", TABLE_SOURCE(), Some(vec![0, 1]), vec![], None).unwrap();
    let third_scan = TableScan::try_new(
        "schema.table3",
        TABLE_SOURCE(),
        Some(vec![0, 1]),
        vec![],
        None,
    )
    .unwrap();
    let logical = LogicalPlan::Union(Union {
        schema: Arc::new(union_schema),
        inputs: vec![
            Arc::new(LogicalPlan::TableScan(first_scan)),
            Arc::new(LogicalPlan::TableScan(second_scan)),
            Arc::new(LogicalPlan::TableScan(third_scan)),
        ],
    });
    let accessor = UNION_SCHEMAS();
    let converted = logical_plan_to_proof_plan(&logical, &accessor).unwrap();

    // Branch for `table1`: columns `a1`/`b1` aliased to `a`/`b`.
    let first_branch = DynProofPlan::new_projection(
        vec![
            AliasedDynProofExpr {
                expr: DynProofExpr::new_column(ColumnRef::new(
                    TableRef::from_names(None, "table1"),
                    "a1".into(),
                    ColumnType::BigInt,
                )),
                alias: "a".into(),
            },
            AliasedDynProofExpr {
                expr: DynProofExpr::new_column(ColumnRef::new(
                    TableRef::from_names(None, "table1"),
                    "b1".into(),
                    ColumnType::Int,
                )),
                alias: "b".into(),
            },
        ],
        DynProofPlan::new_table(
            TableRef::from_names(None, "table1"),
            vec![
                ColumnField::new("a1".into(), ColumnType::BigInt),
                ColumnField::new("b1".into(), ColumnType::Int),
            ],
        ),
    );
    // Branch for `table2`: columns `a2`/`b2` aliased to `a`/`b`.
    let second_branch = DynProofPlan::new_projection(
        vec![
            AliasedDynProofExpr {
                expr: DynProofExpr::new_column(ColumnRef::new(
                    TableRef::from_names(None, "table2"),
                    "a2".into(),
                    ColumnType::BigInt,
                )),
                alias: "a".into(),
            },
            AliasedDynProofExpr {
                expr: DynProofExpr::new_column(ColumnRef::new(
                    TableRef::from_names(None, "table2"),
                    "b2".into(),
                    ColumnType::Int,
                )),
                alias: "b".into(),
            },
        ],
        DynProofPlan::new_table(
            TableRef::from_names(None, "table2"),
            vec![
                ColumnField::new("a2".into(), ColumnType::BigInt),
                ColumnField::new("b2".into(), ColumnType::Int),
            ],
        ),
    );
    // Branch for `schema.table3`: columns `a3`/`b3` aliased to `a`/`b`.
    let third_branch = DynProofPlan::new_projection(
        vec![
            AliasedDynProofExpr {
                expr: DynProofExpr::new_column(ColumnRef::new(
                    TableRef::from_names(Some("schema"), "table3"),
                    "a3".into(),
                    ColumnType::BigInt,
                )),
                alias: "a".into(),
            },
            AliasedDynProofExpr {
                expr: DynProofExpr::new_column(ColumnRef::new(
                    TableRef::from_names(Some("schema"), "table3"),
                    "b3".into(),
                    ColumnType::Int,
                )),
                alias: "b".into(),
            },
        ],
        DynProofPlan::new_table(
            TableRef::from_names(Some("schema"), "table3"),
            vec![
                ColumnField::new("a3".into(), ColumnType::BigInt),
                ColumnField::new("b3".into(), ColumnType::Int),
            ],
        ),
    );

    let expected =
        DynProofPlan::try_new_union(vec![first_branch, second_branch, third_branch]).unwrap();
    assert_eq!(converted, expected);
}
1778
/// An `Aggregate` directly over a filtered table scan becomes a group-by
/// proof plan whose aliases are the aggregate expressions' own names.
#[test]
fn we_can_convert_supported_simple_agg_plan_to_proof_plan() {
    // Group by `table.a`, aggregating with the `SUM_B()` and `COUNT_1()` fixtures.
    let grouping = vec![df_column("table", "a")];
    let aggregates = vec![SUM_B(), COUNT_1()];
    // The scan is filtered on the boolean column `table.d`.
    let filters = vec![df_column("table", "d")];

    let scan = TableScan::try_new(
        "table",
        TABLE_SOURCE(),
        Some(vec![0, 1, 2, 3]),
        filters,
        None,
    )
    .unwrap();
    let aggregate =
        Aggregate::try_new(Arc::new(LogicalPlan::TableScan(scan)), grouping, aggregates).unwrap();
    let logical = LogicalPlan::Aggregate(aggregate);

    let accessor = SCHEMAS();
    let converted = logical_plan_to_proof_plan(&logical, &accessor).unwrap();

    let group_by_columns = vec![ColumnExpr::new(ColumnRef::new(
        TABLE_REF_TABLE(),
        "a".into(),
        ColumnType::BigInt,
    ))];
    let sum_columns = vec![AliasedDynProofExpr {
        expr: DynProofExpr::new_column(ColumnRef::new(
            TABLE_REF_TABLE(),
            "b".into(),
            ColumnType::Int,
        )),
        alias: "SUM(table.b)".into(),
    }];
    let where_clause = DynProofExpr::new_column(ColumnRef::new(
        TABLE_REF_TABLE(),
        "d".into(),
        ColumnType::Boolean,
    ));
    let expected = DynProofPlan::new_group_by(
        group_by_columns,
        sum_columns,
        "COUNT(Int64(1))".into(),
        TableExpr {
            table_ref: TABLE_REF_TABLE(),
        },
        where_clause,
    );

    assert_eq!(converted, expected);
}
1844
/// A `Projection` that only re-aliases the outputs of an `Aggregate` is folded
/// into one group-by proof plan that uses the projection's aliases.
#[test]
fn we_can_convert_supported_agg_plan_to_proof_plan() {
    // Group by `table.a`, aggregating with the `SUM_B()` and `COUNT_1()` fixtures.
    let grouping = vec![df_column("table", "a")];
    let aggregates = vec![SUM_B(), COUNT_1()];
    // The scan is filtered on the boolean column `table.d`.
    let filters = vec![df_column("table", "d")];

    let scan = TableScan::try_new(
        "table",
        TABLE_SOURCE(),
        Some(vec![0, 1, 2, 3]),
        filters,
        None,
    )
    .unwrap();
    let aggregate =
        Aggregate::try_new(Arc::new(LogicalPlan::TableScan(scan)), grouping, aggregates).unwrap();
    let aggregate_plan = LogicalPlan::Aggregate(aggregate);

    // The projection keeps `table.a` and renames the two aggregate outputs.
    let renamed_sum = Expr::Column(Column::new(
        None::<TableReference>,
        "SUM(table.b)".to_string(),
    ))
    .alias("sum_b");
    let renamed_count = Expr::Column(Column::new(
        None::<TableReference>,
        "COUNT(Int64(1))".to_string(),
    ))
    .alias("count_1");
    let projection = Projection::try_new(
        vec![df_column("table", "a"), renamed_sum, renamed_count],
        Arc::new(aggregate_plan),
    )
    .unwrap();
    let logical = LogicalPlan::Projection(projection);

    let accessor = SCHEMAS();
    let converted = logical_plan_to_proof_plan(&logical, &accessor).unwrap();

    let group_by_columns = vec![ColumnExpr::new(ColumnRef::new(
        TABLE_REF_TABLE(),
        "a".into(),
        ColumnType::BigInt,
    ))];
    let sum_columns = vec![AliasedDynProofExpr {
        expr: DynProofExpr::new_column(ColumnRef::new(
            TABLE_REF_TABLE(),
            "b".into(),
            ColumnType::Int,
        )),
        alias: "sum_b".into(),
    }];
    let where_clause = DynProofExpr::new_column(ColumnRef::new(
        TABLE_REF_TABLE(),
        "d".into(),
        ColumnType::Boolean,
    ));
    let expected = DynProofPlan::new_group_by(
        group_by_columns,
        sum_columns,
        "count_1".into(),
        TableExpr {
            table_ref: TABLE_REF_TABLE(),
        },
        where_clause,
    );

    assert_eq!(converted, expected);
}
1930
/// A `Projection` that computes a new expression (`table.a + table.a`) on top
/// of an `Aggregate` is rejected with `UnsupportedLogicalPlan`.
#[test]
fn we_cannot_convert_unsupported_agg_plan_to_proof_plan() {
    // Group by `table.a`, aggregating with the `SUM_B()` and `COUNT_1()` fixtures.
    let grouping = vec![df_column("table", "a")];
    let aggregates = vec![SUM_B(), COUNT_1()];
    // The scan is filtered on the boolean column `table.d`.
    let filters = vec![df_column("table", "d")];

    let scan = TableScan::try_new(
        "table",
        TABLE_SOURCE(),
        Some(vec![0, 1, 2, 3]),
        filters,
        None,
    )
    .unwrap();
    let aggregate =
        Aggregate::try_new(Arc::new(LogicalPlan::TableScan(scan)), grouping, aggregates).unwrap();
    let aggregate_plan = LogicalPlan::Aggregate(aggregate);

    // Unlike a pure re-aliasing, this projection does arithmetic on a group column.
    let doubled_a = df_column("table", "a").add(df_column("table", "a"));
    let projection = Projection::try_new(vec![doubled_a], Arc::new(aggregate_plan)).unwrap();
    let logical = LogicalPlan::Projection(projection);

    let accessor = SCHEMAS();
    let outcome = logical_plan_to_proof_plan(&logical, &accessor);
    assert!(matches!(
        outcome,
        Err(PlannerError::UnsupportedLogicalPlan { .. })
    ));
}
1978
/// A `Prepare` logical plan is rejected with `UnsupportedLogicalPlan`.
#[test]
fn we_cannot_convert_unsupported_logical_plan_to_proof_plan() {
    let empty_input = LogicalPlan::EmptyRelation(EmptyRelation {
        produce_one_row: false,
        schema: Arc::new(DFSchema::empty()),
    });
    let prepare = Prepare {
        name: "not_a_real_plan".to_string(),
        data_types: vec![],
        input: Arc::new(empty_input),
    };
    let logical = LogicalPlan::Prepare(prepare);

    let accessor = SCHEMAS();
    let outcome = logical_plan_to_proof_plan(&logical, &accessor);
    assert!(matches!(
        outcome,
        Err(PlannerError::UnsupportedLogicalPlan { .. })
    ));
}
1996
/// `join_to_proof_plan` rejects a `JoinType::Left` join with
/// `UnsupportedLogicalPlan`, and the error carries the plan passed as the
/// third argument.
#[test]
fn we_can_error_if_not_inner_join() {
    // Placeholder plan, used for both join sides and as the plan to report.
    let empty_input = LogicalPlan::EmptyRelation(EmptyRelation {
        produce_one_row: false,
        schema: Arc::new(DFSchema::empty()),
    });
    let plan = LogicalPlan::Prepare(Prepare {
        name: "not_a_real_plan".to_string(),
        data_types: vec![],
        input: Arc::new(empty_input),
    });
    let schemas = SCHEMAS();

    let left_join = Join {
        left: Arc::new(plan.clone()),
        right: Arc::new(plan.clone()),
        on: Vec::new(),
        filter: None,
        join_type: JoinType::Left,
        join_constraint: JoinConstraint::On,
        schema: Arc::new(DFSchema::empty()),
        null_equals_null: false,
    };
    let join_err = join_to_proof_plan(&left_join, &schemas, &plan).unwrap_err();

    let reports_original_plan = matches!(
        join_err,
        PlannerError::UnsupportedLogicalPlan { plan: rejected } if *rejected == plan
    );
    assert!(reports_original_plan);
}
2028}