proof_of_sql_planner/
plan.rs

1use super::{
2    aggregate_function_to_proof_expr, column_to_column_ref, expr_to_proof_expr,
3    schema_to_column_fields, table_reference_to_table_ref, AggregateFunc, PlannerError,
4    PlannerResult,
5};
6use alloc::vec::Vec;
7use datafusion::{
8    common::{DFSchema, JoinConstraint, JoinType},
9    logical_expr::{
10        expr::Alias, Aggregate, Expr, Join, Limit, LogicalPlan, Projection, TableScan, Union,
11    },
12    sql::{sqlparser::ast::Ident, TableReference},
13};
14use indexmap::{IndexMap, IndexSet};
15use proof_of_sql::{
16    base::database::{ColumnRef, ColumnType, LiteralValue, SchemaAccessor, TableRef},
17    sql::{
18        proof::ProofPlan,
19        proof_exprs::{AliasedDynProofExpr, ColumnExpr, DynProofExpr, TableExpr},
20        proof_plans::{DynProofPlan, SortMergeJoinExec},
21    },
22};
23
24/// Get `AliasedDynProofExpr` from a `TableRef`, column indices for projection as well as
25/// input and output schemas
26///
27/// Note that at least in the current implementation of `DataFusion`
28/// the output schema should be a subset of the input schema
29/// and that no aliasing should take place.
30/// However that shouldn't be taken for granted.
31fn get_aliased_dyn_proof_exprs(
32    table_ref: &TableRef,
33    projection: &[usize],
34    input_schema: &[(Ident, ColumnType)],
35    output_schema: &DFSchema,
36) -> PlannerResult<Vec<AliasedDynProofExpr>> {
37    projection
38        .iter()
39        .enumerate()
40        .map(
41            |(output_index, input_index)| -> PlannerResult<AliasedDynProofExpr> {
42                // Get output column name / alias
43                let alias: Ident = output_schema.field(output_index).name().as_str().into();
44                let (input_column_name, data_type) = input_schema
45                    .get(*input_index)
46                    .ok_or(PlannerError::ColumnNotFound)?;
47                let expr = DynProofExpr::new_column(ColumnRef::new(
48                    table_ref.clone(),
49                    input_column_name.clone(),
50                    *data_type,
51                ));
52                Ok(AliasedDynProofExpr { expr, alias })
53            },
54        )
55        .collect::<PlannerResult<Vec<_>>>()
56}
57
58/// Convert a `TableScan` without filters or fetch limit to a `DynProofPlan`
59fn table_scan_to_projection(
60    table_name: &TableReference,
61    schemas: &impl SchemaAccessor,
62    projection: &[usize],
63    projected_schema: &DFSchema,
64) -> PlannerResult<DynProofPlan> {
65    // Check if the table exists
66    let table_ref = table_reference_to_table_ref(table_name)?;
67    let input_schema = schemas.lookup_schema(&table_ref);
68    // Get aliased expressions
69    let aliased_dyn_proof_exprs =
70        get_aliased_dyn_proof_exprs(&table_ref, projection, &input_schema, projected_schema)?;
71    let input_column_fields = schema_to_column_fields(input_schema);
72    let table_exec = DynProofPlan::new_table(table_ref, input_column_fields);
73    Ok(DynProofPlan::new_projection(
74        aliased_dyn_proof_exprs,
75        table_exec,
76    ))
77}
78
79/// Convert a `TableScan` with filters but without fetch limit to a `DynProofPlan`
80///
81/// # Panics
82/// Panics if there are no filters which should not happen if called from `logical_plan_to_proof_plan`
83fn table_scan_to_filter(
84    table_name: &TableReference,
85    schemas: &impl SchemaAccessor,
86    projection: &[usize],
87    projected_schema: &DFSchema,
88    filters: &[Expr],
89) -> PlannerResult<DynProofPlan> {
90    // Check if the table exists
91    let table_ref = table_reference_to_table_ref(table_name)?;
92    let input_schema = schemas.lookup_schema(&table_ref);
93    // Get aliased expressions
94    let aliased_dyn_proof_exprs =
95        get_aliased_dyn_proof_exprs(&table_ref, projection, &input_schema, projected_schema)?;
96    let table_expr = TableExpr { table_ref };
97    // Filter
98    let consolidated_filter_proof_expr = filters
99        .iter()
100        .map(|f| expr_to_proof_expr(f, &input_schema))
101        .reduce(|a, b| Ok(DynProofExpr::try_new_and(a?, b?)?))
102        .expect("At least one filter expression is required")?;
103    Ok(DynProofPlan::new_filter(
104        aliased_dyn_proof_exprs,
105        table_expr,
106        consolidated_filter_proof_expr,
107    ))
108}
109
110fn try_get_schema_as_vec_from_df_schema(
111    df_schema: &DFSchema,
112) -> PlannerResult<Vec<(Ident, ColumnType)>> {
113    df_schema
114        .inner()
115        .fields()
116        .into_iter()
117        .map(|f| {
118            ColumnType::try_from(f.data_type().clone())
119                .map_err(|_| PlannerError::UnsupportedDataType {
120                    data_type: f.data_type().clone(),
121                })
122                .map(|t| (Ident::from(f.name().as_ref()), t))
123        })
124        .collect::<Result<Vec<_>, _>>()
125}
126
127/// Converts a [`datafusion::logical_expr::Projection`] to a [`DynProofPlan`]
128fn projection_to_proof_plan(
129    expr: &[Expr],
130    input: &LogicalPlan,
131    output_schema: &DFSchema,
132    schemas: &impl SchemaAccessor,
133) -> PlannerResult<DynProofPlan> {
134    let input_plan = logical_plan_to_proof_plan(input, schemas)?;
135    let input_schema = try_get_schema_as_vec_from_df_schema(input.schema())?;
136    let aliased_exprs = expr
137        .iter()
138        .zip(output_schema.fields().into_iter())
139        .map(|(e, field)| -> PlannerResult<AliasedDynProofExpr> {
140            let proof_expr = expr_to_proof_expr(e, &input_schema)?;
141            let alias = field.name().as_str().into();
142            Ok(AliasedDynProofExpr {
143                expr: proof_expr,
144                alias,
145            })
146        })
147        .collect::<PlannerResult<Vec<_>>>()?;
148    Ok(DynProofPlan::new_projection(aliased_exprs, input_plan))
149}
150
151/// Convert a [`datafusion::logical_plan::LogicalPlan`] to a [`DynProofPlan`] for GROUP BYs
152///
153/// TODO: Improve how we handle GROUP BYs so that all the tech debt is resolved
154///
155/// # Panics
156/// The code should never panic
157fn aggregate_to_proof_plan(
158    input: &LogicalPlan,
159    group_expr: &[Expr],
160    aggr_expr: &[Expr],
161    schemas: &impl SchemaAccessor,
162    alias_map: &IndexMap<&str, &str>,
163) -> PlannerResult<DynProofPlan> {
164    // Check that all of `group_expr` are columns and get their names
165    let group_columns = group_expr
166        .iter()
167        .map(|e| match e {
168            Expr::Column(c) => Ok(c),
169            _ => Err(PlannerError::UnsupportedLogicalPlan {
170                plan: Box::new(input.clone()),
171            }),
172        })
173        .collect::<PlannerResult<Vec<_>>>()?;
174    match input {
175        // Only TableScan without fetch is supported
176        LogicalPlan::TableScan(TableScan {
177            table_name,
178            filters,
179            fetch: None,
180            ..
181        }) => {
182            let table_ref = table_reference_to_table_ref(table_name)?;
183            let input_schema = schemas.lookup_schema(&table_ref);
184            let table_expr = TableExpr { table_ref };
185            // Filter
186            let consolidated_filter_proof_expr = filters
187                .iter()
188                .map(|f| expr_to_proof_expr(f, &input_schema))
189                .reduce(|a, b| Ok(DynProofExpr::try_new_and(a?, b?)?))
190                .unwrap_or_else(|| Ok(DynProofExpr::new_literal(LiteralValue::Boolean(true))))?;
191            // Aggregate
192            // Prove that the ordering of `aggr_expr` is
193            // 1. All group columns according to `group_columns`
194            // 2. (Optional) All the SUMs
195            // 3. COUNT
196            if aggr_expr.is_empty() {
197                return Err(PlannerError::UnsupportedLogicalPlan {
198                    plan: Box::new(input.clone()),
199                });
200            }
201            let agg_aliased_proof_exprs: Vec<((AggregateFunc, DynProofExpr), Ident)> = aggr_expr
202                .iter()
203                .map(|e| match e.clone().unalias() {
204                    Expr::AggregateFunction(agg) => {
205                        let name_string = e.display_name()?;
206                        let name = name_string.as_str();
207                        let alias = alias_map.get(&name).ok_or_else(|| {
208                            PlannerError::UnsupportedLogicalPlan {
209                                plan: Box::new(input.clone()),
210                            }
211                        })?;
212                        Ok((
213                            aggregate_function_to_proof_expr(&agg, &input_schema)?,
214                            (*alias).into(),
215                        ))
216                    }
217                    _ => Err(PlannerError::UnsupportedLogicalPlan {
218                        plan: Box::new(input.clone()),
219                    }),
220                })
221                .collect::<PlannerResult<Vec<_>>>()?;
222            // Check that the last expression is COUNT and the rest are SUMs
223            let (sum_tuples, count_tuple) =
224                agg_aliased_proof_exprs.split_at(agg_aliased_proof_exprs.len() - 1);
225            let sum_is_compliant = sum_tuples
226                .iter()
227                .all(|((op, _), _)| matches!(op, AggregateFunc::Sum));
228            let count_is_compliant = count_tuple
229                .iter()
230                .all(|((op, _), _)| matches!(op, AggregateFunc::Count));
231            if !sum_is_compliant || !count_is_compliant {
232                return Err(PlannerError::UnsupportedLogicalPlan {
233                    plan: Box::new(input.clone()),
234                });
235            }
236            let count_alias = agg_aliased_proof_exprs
237                .last()
238                .expect("We have already checked that this exists")
239                .1
240                .clone();
241            // `group_by_exprs`
242            let group_by_exprs = group_columns
243                .iter()
244                .map(|column| {
245                    Ok(ColumnExpr::new(column_to_column_ref(
246                        column,
247                        &input_schema,
248                    )?))
249                })
250                .collect::<PlannerResult<Vec<_>>>()?;
251            // `sum_expr`
252            let sum_expr = sum_tuples
253                .iter()
254                .map(|((_, expr), alias)| AliasedDynProofExpr {
255                    expr: expr.clone(),
256                    alias: alias.clone(),
257                })
258                .collect::<Vec<_>>();
259            Ok(DynProofPlan::new_group_by(
260                group_by_exprs,
261                sum_expr,
262                count_alias,
263                table_expr,
264                consolidated_filter_proof_expr,
265            ))
266        }
267        _ => Err(PlannerError::UnsupportedLogicalPlan {
268            plan: Box::new(input.clone()),
269        }),
270    }
271}
272
273fn join_to_proof_plan(
274    join: &Join,
275    schema_accessor: &impl SchemaAccessor,
276    plan: &LogicalPlan,
277) -> PlannerResult<DynProofPlan> {
278    if join.join_type != JoinType::Inner || join.join_constraint != JoinConstraint::On {
279        return Err(PlannerError::UnsupportedLogicalPlan {
280            plan: Box::new(plan.clone()),
281        });
282    }
283    let left_plan = Box::new(logical_plan_to_proof_plan(&join.left, schema_accessor)?);
284    let right_plan = Box::new(logical_plan_to_proof_plan(&join.right, schema_accessor)?);
285    let left_column_result_fields = left_plan
286        .get_column_result_fields()
287        .into_iter()
288        .map(|c| c.name())
289        .collect::<IndexSet<_>>();
290    let right_column_result_fields = right_plan
291        .get_column_result_fields()
292        .into_iter()
293        .map(|c| c.name())
294        .collect::<IndexSet<_>>();
295    let on_indices_and_idents = join
296        .on
297        .iter()
298        .filter_map(|(left_expr, right_expr)| {
299            Some(match (left_expr, right_expr) {
300                (Expr::Column(col_a), Expr::Column(col_b)) if col_a.name == col_b.name => {
301                    let column_id = Ident::new(col_a.name.clone());
302                    Ok((
303                        (
304                            left_column_result_fields.get_index_of(&column_id)?,
305                            right_column_result_fields.get_index_of(&column_id)?,
306                        ),
307                        column_id,
308                    ))
309                }
310                _ => Err(PlannerError::UnsupportedLogicalPlan {
311                    plan: Box::new(plan.clone()),
312                }),
313            })
314        })
315        .collect::<Result<Vec<_>, _>>()?;
316    let (on_indices, join_idents): (Vec<(usize, usize)>, Vec<Ident>) =
317        on_indices_and_idents.into_iter().unzip();
318    let (left_indices, right_indices): (Vec<usize>, Vec<usize>) = on_indices.into_iter().unzip();
319    let (left_indices_cloned, right_indices_cloned) = (left_indices.clone(), right_indices.clone());
320    let left_other_column_idents = left_column_result_fields
321        .clone()
322        .into_iter()
323        .enumerate()
324        .filter_map(|(i, col_ident)| (!left_indices.contains(&i)).then_some(col_ident));
325    let right_other_column_idents = right_column_result_fields
326        .into_iter()
327        .enumerate()
328        .filter_map(|(i, col_ident)| (!right_indices.contains(&i)).then_some(col_ident));
329    Ok(DynProofPlan::SortMergeJoin(SortMergeJoinExec::new(
330        left_plan,
331        right_plan,
332        left_indices_cloned,
333        right_indices_cloned,
334        join_idents
335            .into_iter()
336            .chain(left_other_column_idents)
337            .chain(right_other_column_idents)
338            .collect(),
339    )))
340}
341
342/// Visit a [`datafusion::logical_plan::LogicalPlan`] and return a [`DynProofPlan`]
343#[expect(clippy::too_many_lines)]
344pub fn logical_plan_to_proof_plan(
345    plan: &LogicalPlan,
346    schema_accessor: &impl SchemaAccessor,
347) -> PlannerResult<DynProofPlan> {
348    match plan {
349        LogicalPlan::EmptyRelation { .. } => Ok(DynProofPlan::new_empty()),
350        // `projection` shouldn't be None in analyzed and optimized plans
351        LogicalPlan::TableScan(TableScan {
352            table_name,
353            projection: Some(projection),
354            projected_schema,
355            filters,
356            fetch,
357            ..
358        }) => {
359            let base_plan = if filters.is_empty() {
360                table_scan_to_projection(table_name, schema_accessor, projection, projected_schema)
361            } else {
362                table_scan_to_filter(
363                    table_name,
364                    schema_accessor,
365                    projection,
366                    projected_schema,
367                    filters,
368                )
369            }?;
370            if let Some(fetch) = fetch {
371                Ok(DynProofPlan::new_slice(base_plan, 0, Some(*fetch)))
372            } else {
373                Ok(base_plan)
374            }
375        }
376        // Aggregation
377        LogicalPlan::Aggregate(Aggregate {
378            input,
379            group_expr,
380            aggr_expr,
381            schema,
382            ..
383        }) => {
384            let name_strings = group_expr
385                .iter()
386                .chain(aggr_expr.iter())
387                .map(Expr::display_name)
388                .collect::<Result<Vec<_>, _>>()?;
389            let alias_map = name_strings
390                .iter()
391                .zip(schema.fields().iter())
392                .map(|(name_string, field)| {
393                    let name = name_string.as_str();
394                    let alias = field.name().as_str();
395                    Ok((name, alias))
396                })
397                .collect::<PlannerResult<IndexMap<_, _>>>()?;
398            aggregate_to_proof_plan(input, group_expr, aggr_expr, schema_accessor, &alias_map)
399        }
400        // Projection
401        LogicalPlan::Projection(Projection {
402            input,
403            expr,
404            schema,
405            ..
406        }) => {
407            match &**input {
408                LogicalPlan::Aggregate(Aggregate {
409                    input: agg_input,
410                    group_expr,
411                    aggr_expr,
412                    ..
413                }) => {
414                    // Check whether the last layer is identity
415                    let alias_map = expr
416                        .iter()
417                        .map(|e| match e {
418                            Expr::Column(c) => Ok((c.name.as_str(), c.name.as_str())),
419                            Expr::Alias(Alias { expr, name, .. }) => {
420                                if let Expr::Column(c) = expr.as_ref() {
421                                    Ok((c.name.as_str(), name.as_str()))
422                                } else {
423                                    Err(PlannerError::UnsupportedLogicalPlan {
424                                        plan: Box::new(plan.clone()),
425                                    })
426                                }
427                            }
428                            _ => Err(PlannerError::UnsupportedLogicalPlan {
429                                plan: Box::new(plan.clone()),
430                            }),
431                        })
432                        .collect::<PlannerResult<IndexMap<_, _>>>()?;
433                    aggregate_to_proof_plan(
434                        agg_input,
435                        group_expr,
436                        aggr_expr,
437                        schema_accessor,
438                        &alias_map,
439                    )
440                }
441                _ => projection_to_proof_plan(expr, input, schema, schema_accessor),
442            }
443        }
444        // Limit
445        LogicalPlan::Limit(Limit { input, fetch, skip }) => {
446            let input_plan = logical_plan_to_proof_plan(input, schema_accessor)?;
447            Ok(DynProofPlan::new_slice(input_plan, *skip, *fetch))
448        }
449        // Union
450        LogicalPlan::Union(Union { inputs, schema: _ }) => {
451            let input_plans = inputs
452                .iter()
453                .map(|input| logical_plan_to_proof_plan(input, schema_accessor))
454                .collect::<PlannerResult<Vec<_>>>()?;
455            Ok(DynProofPlan::try_new_union(input_plans)?)
456        }
457        LogicalPlan::Join(join) => join_to_proof_plan(join, schema_accessor, plan),
458        _ => Err(PlannerError::UnsupportedLogicalPlan {
459            plan: Box::new(plan.clone()),
460        }),
461    }
462}
463
464#[cfg(test)]
465mod tests {
466    use super::*;
467    use crate::{df_util::*, PoSqlTableSource};
468    use ahash::AHasher;
469    use alloc::{sync::Arc, vec};
470    use arrow::datatypes::DataType;
471    use core::ops::Add;
472    use datafusion::{
473        common::{Column, ScalarValue},
474        logical_expr::{
475            expr::{AggregateFunction, AggregateFunctionDefinition},
476            not, BinaryExpr, EmptyRelation, Operator, Prepare, TableScan, TableSource,
477        },
478        physical_plan,
479    };
480    use indexmap::{indexmap, indexmap_with_default};
481    use proof_of_sql::base::{
482        database::{ColumnField, TestSchemaAccessor},
483        math::decimal::Precision,
484    };
485    use std::hash::BuildHasherDefault;
486
487    const SUM: AggregateFunctionDefinition =
488        AggregateFunctionDefinition::BuiltIn(physical_plan::aggregates::AggregateFunction::Sum);
489    const COUNT: AggregateFunctionDefinition =
490        AggregateFunctionDefinition::BuiltIn(physical_plan::aggregates::AggregateFunction::Count);
491    const AVG: AggregateFunctionDefinition =
492        AggregateFunctionDefinition::BuiltIn(physical_plan::aggregates::AggregateFunction::Avg);
493
494    #[expect(non_snake_case)]
495    fn TABLE_REF_TABLE() -> TableRef {
496        TableRef::from_names(None, "table")
497    }
498
499    #[expect(non_snake_case)]
500    fn SCHEMAS() -> impl SchemaAccessor {
501        let schema: IndexMap<Ident, ColumnType, BuildHasherDefault<AHasher>> = indexmap_with_default! {
502            AHasher;
503            "a".into() => ColumnType::BigInt,
504            "b".into() => ColumnType::Int,
505            "c".into() => ColumnType::VarChar,
506            "d".into() => ColumnType::Boolean
507        };
508        let table_ref = TableRef::new("", "table");
509        let schema_accessor = indexmap_with_default! {
510            AHasher;
511            table_ref => schema
512        };
513        TestSchemaAccessor::new(schema_accessor)
514    }
515
516    #[expect(non_snake_case)]
517    fn UNION_SCHEMAS() -> impl SchemaAccessor {
518        TestSchemaAccessor::new(indexmap_with_default! {AHasher;
519            TableRef::new("", "table1") => indexmap_with_default! {AHasher;
520                "a1".into() => ColumnType::BigInt,
521                "b1".into() => ColumnType::Int
522            },
523            TableRef::new("", "table2") => indexmap_with_default! {AHasher;
524                "a2".into() => ColumnType::BigInt,
525                "b2".into() => ColumnType::Int
526            },
527            TableRef::new("schema", "table3") => indexmap_with_default! {AHasher;
528                "a3".into() => ColumnType::BigInt,
529                "b3".into() => ColumnType::Int
530            },
531        })
532    }
533
534    #[expect(non_snake_case)]
535    fn EMPTY_SCHEMAS() -> impl SchemaAccessor {
536        TestSchemaAccessor::new(indexmap_with_default! {AHasher;})
537    }
538
539    #[expect(non_snake_case)]
540    fn TABLE_SOURCE() -> Arc<dyn TableSource> {
541        Arc::new(PoSqlTableSource::new(vec![
542            ColumnField::new("a".into(), ColumnType::BigInt),
543            ColumnField::new("b".into(), ColumnType::Int),
544            ColumnField::new("c".into(), ColumnType::VarChar),
545            ColumnField::new("d".into(), ColumnType::Boolean),
546        ]))
547    }
548
549    #[expect(non_snake_case)]
550    fn ALIASED_A() -> AliasedDynProofExpr {
551        AliasedDynProofExpr {
552            expr: DynProofExpr::new_column(ColumnRef::new(
553                TABLE_REF_TABLE(),
554                "a".into(),
555                ColumnType::BigInt,
556            )),
557            alias: "a".into(),
558        }
559    }
560
561    #[expect(non_snake_case)]
562    fn ALIASED_B() -> AliasedDynProofExpr {
563        AliasedDynProofExpr {
564            expr: DynProofExpr::new_column(ColumnRef::new(
565                TABLE_REF_TABLE(),
566                "b".into(),
567                ColumnType::Int,
568            )),
569            alias: "b".into(),
570        }
571    }
572
573    #[expect(non_snake_case)]
574    fn ALIASED_C() -> AliasedDynProofExpr {
575        AliasedDynProofExpr {
576            expr: DynProofExpr::new_column(ColumnRef::new(
577                TABLE_REF_TABLE(),
578                "c".into(),
579                ColumnType::VarChar,
580            )),
581            alias: "c".into(),
582        }
583    }
584
585    #[expect(non_snake_case)]
586    fn ALIASED_D() -> AliasedDynProofExpr {
587        AliasedDynProofExpr {
588            expr: DynProofExpr::new_column(ColumnRef::new(
589                TABLE_REF_TABLE(),
590                "d".into(),
591                ColumnType::Boolean,
592            )),
593            alias: "d".into(),
594        }
595    }
596
597    #[expect(non_snake_case)]
598    fn COUNT_1() -> Expr {
599        Expr::AggregateFunction(AggregateFunction {
600            func_def: COUNT,
601            args: vec![Expr::Literal(ScalarValue::Int64(Some(1)))],
602            distinct: false,
603            filter: None,
604            order_by: None,
605            null_treatment: None,
606        })
607    }
608
609    #[expect(non_snake_case)]
610    fn SUM_B() -> Expr {
611        Expr::AggregateFunction(AggregateFunction {
612            func_def: SUM,
613            args: vec![df_column("table", "b")],
614            distinct: false,
615            filter: None,
616            order_by: None,
617            null_treatment: None,
618        })
619    }
620
621    #[expect(non_snake_case)]
622    fn SUM_D() -> Expr {
623        Expr::AggregateFunction(AggregateFunction {
624            func_def: SUM,
625            args: vec![df_column("table", "d")],
626            distinct: false,
627            filter: None,
628            order_by: None,
629            null_treatment: None,
630        })
631    }
632
633    // get_aliased_dyn_proof_exprs
634    #[test]
635    fn we_can_get_aliased_proof_expr_with_specified_projection_columns() {
636        // Unused columns can be of unsupported types
637        let table_ref = TABLE_REF_TABLE();
638        let input_schema = vec![
639            ("a".into(), ColumnType::BigInt),
640            ("b".into(), ColumnType::Int),
641            ("c".into(), ColumnType::VarChar),
642            (
643                "d".into(),
644                ColumnType::Decimal75(Precision::new(5).unwrap(), 1),
645            ), // Unused column
646        ];
647        let output_schema = df_schema("table", vec![("b", DataType::Int32), ("c", DataType::Utf8)]);
648        let result =
649            get_aliased_dyn_proof_exprs(&table_ref, &[1, 2], &input_schema, &output_schema)
650                .unwrap();
651        let expected = vec![ALIASED_B(), ALIASED_C()];
652        assert_eq!(result, expected);
653    }
654
655    #[test]
656    fn we_can_get_aliased_proof_expr_without_specified_projection_columns() {
657        let table_ref = TABLE_REF_TABLE();
658        let input_schema = vec![
659            ("a".into(), ColumnType::BigInt),
660            ("b".into(), ColumnType::Int),
661            ("c".into(), ColumnType::VarChar),
662            ("d".into(), ColumnType::Boolean),
663        ];
664        let output_schema = df_schema(
665            "table",
666            vec![
667                ("a", DataType::Int64),
668                ("b", DataType::Int32),
669                ("c", DataType::Utf8),
670                ("d", DataType::Boolean),
671            ],
672        );
673        let result =
674            get_aliased_dyn_proof_exprs(&table_ref, &[0, 1, 2, 3], &input_schema, &output_schema)
675                .unwrap();
676        let expected = vec![ALIASED_A(), ALIASED_B(), ALIASED_C(), ALIASED_D()];
677        assert_eq!(result, expected);
678    }
679
680    // aggregate_to_proof_plan
681    #[test]
682    fn we_can_aggregate_with_group_by_and_sum_count() {
683        // Setup group expression
684        let group_expr = vec![df_column("table", "a")];
685
686        // Create the aggregate expressions (must follow the pattern: group columns, then SUMs, then COUNT)
687        let aggr_expr = vec![
688            SUM_B(),   // SUM
689            COUNT_1(), // COUNT
690        ];
691
692        // Create the input plan
693        let input_plan = LogicalPlan::TableScan(
694            TableScan::try_new(
695                "table",
696                TABLE_SOURCE(),
697                Some(vec![0, 1, 2, 3]),
698                vec![],
699                None,
700            )
701            .unwrap(),
702        );
703        let alias_map = indexmap! {
704            "a" => "a",
705            "SUM(table.b)" => "sum_b",
706            "COUNT(Int64(1))" => "count_1",
707        };
708
709        // Test the function
710        let result =
711            aggregate_to_proof_plan(&input_plan, &group_expr, &aggr_expr, &SCHEMAS(), &alias_map)
712                .unwrap();
713
714        // Expected result
715        let expected = DynProofPlan::new_group_by(
716            vec![ColumnExpr::new(ColumnRef::new(
717                TABLE_REF_TABLE(),
718                "a".into(),
719                ColumnType::BigInt,
720            ))],
721            vec![AliasedDynProofExpr {
722                expr: DynProofExpr::new_column(ColumnRef::new(
723                    TABLE_REF_TABLE(),
724                    "b".into(),
725                    ColumnType::Int,
726                )),
727                alias: "sum_b".into(),
728            }],
729            "count_1".into(),
730            TableExpr {
731                table_ref: TABLE_REF_TABLE(),
732            },
733            DynProofExpr::new_literal(LiteralValue::Boolean(true)),
734        );
735
736        assert_eq!(result, expected);
737    }
738
739    #[test]
740    fn we_can_aggregate_with_filters() {
741        // Setup group expression
742        let group_expr = vec![df_column("table", "a")];
743
744        // Create the aggregate expressions
745        let aggr_expr = vec![
746            SUM_B(),   // SUM
747            COUNT_1(), // COUNT
748        ];
749
750        // Create filters
751        let filter_exprs = vec![
752            df_column("table", "d"), // Boolean column as filter
753        ];
754
755        // Create the input plan with filters
756        let input_plan = LogicalPlan::TableScan(
757            TableScan::try_new(
758                "table",
759                TABLE_SOURCE(),
760                Some(vec![0, 1, 2, 3]),
761                filter_exprs,
762                None,
763            )
764            .unwrap(),
765        );
766        let alias_map = indexmap! {
767            "a" => "a",
768            "SUM(table.b)" => "sum_b",
769            "COUNT(Int64(1))" => "count_1",
770        };
771
772        // Test the function
773        let result =
774            aggregate_to_proof_plan(&input_plan, &group_expr, &aggr_expr, &SCHEMAS(), &alias_map)
775                .unwrap();
776
777        // Expected result
778        let expected = DynProofPlan::new_group_by(
779            vec![ColumnExpr::new(ColumnRef::new(
780                TABLE_REF_TABLE(),
781                "a".into(),
782                ColumnType::BigInt,
783            ))],
784            vec![AliasedDynProofExpr {
785                expr: DynProofExpr::new_column(ColumnRef::new(
786                    TABLE_REF_TABLE(),
787                    "b".into(),
788                    ColumnType::Int,
789                )),
790                alias: "sum_b".into(),
791            }],
792            "count_1".into(),
793            TableExpr {
794                table_ref: TABLE_REF_TABLE(),
795            },
796            DynProofExpr::new_column(ColumnRef::new(
797                TABLE_REF_TABLE(),
798                "d".into(),
799                ColumnType::Boolean,
800            )),
801        );
802
803        assert_eq!(result, expected);
804    }
805
806    #[test]
807    fn we_can_aggregate_with_multiple_group_columns() {
808        // Setup group expressions
809        let group_expr = vec![df_column("table", "a"), df_column("table", "c")];
810
811        // Create the aggregate expressions
812        let aggr_expr = vec![
813            SUM_B(),   // SUM
814            COUNT_1(), // COUNT
815        ];
816
817        // Create the input plan
818        let input_plan = LogicalPlan::TableScan(
819            TableScan::try_new(
820                "table",
821                TABLE_SOURCE(),
822                Some(vec![0, 1, 2, 3]),
823                vec![],
824                None,
825            )
826            .unwrap(),
827        );
828        let alias_map = indexmap! {
829            "a" => "a",
830            "c" => "c",
831            "SUM(table.b)" => "sum_b",
832            "COUNT(Int64(1))" => "count_1",
833        };
834
835        // Test the function
836        let result =
837            aggregate_to_proof_plan(&input_plan, &group_expr, &aggr_expr, &SCHEMAS(), &alias_map)
838                .unwrap();
839
840        // Expected result
841        let expected = DynProofPlan::new_group_by(
842            vec![
843                ColumnExpr::new(ColumnRef::new(
844                    TABLE_REF_TABLE(),
845                    "a".into(),
846                    ColumnType::BigInt,
847                )),
848                ColumnExpr::new(ColumnRef::new(
849                    TABLE_REF_TABLE(),
850                    "c".into(),
851                    ColumnType::VarChar,
852                )),
853            ],
854            vec![AliasedDynProofExpr {
855                expr: DynProofExpr::new_column(ColumnRef::new(
856                    TABLE_REF_TABLE(),
857                    "b".into(),
858                    ColumnType::Int,
859                )),
860                alias: "sum_b".into(),
861            }],
862            "count_1".into(),
863            TableExpr {
864                table_ref: TABLE_REF_TABLE(),
865            },
866            DynProofExpr::new_literal(LiteralValue::Boolean(true)),
867        );
868
869        assert_eq!(result, expected);
870    }
871
872    #[test]
873    fn we_can_aggregate_with_multiple_sum_expressions() {
874        // Setup group expression
875        let group_expr = vec![df_column("table", "a")];
876
877        // Create the aggregate expressions
878        let aggr_expr = vec![
879            SUM_B(),   // First SUM
880            SUM_D(),   // Second SUM
881            COUNT_1(), // COUNT
882        ];
883
884        // Create the input plan
885        let input_plan = LogicalPlan::TableScan(
886            TableScan::try_new(
887                "table",
888                TABLE_SOURCE(),
889                Some(vec![0, 1, 2, 3]),
890                vec![],
891                None,
892            )
893            .unwrap(),
894        );
895        let alias_map = indexmap! {
896            "a" => "a",
897            "SUM(table.b)" => "sum_b",
898            "SUM(table.d)" => "sum_d",
899            "COUNT(Int64(1))" => "count_1",
900        };
901
902        // Test the function
903        let result =
904            aggregate_to_proof_plan(&input_plan, &group_expr, &aggr_expr, &SCHEMAS(), &alias_map)
905                .unwrap();
906
907        // Expected result
908        let expected = DynProofPlan::new_group_by(
909            vec![ColumnExpr::new(ColumnRef::new(
910                TABLE_REF_TABLE(),
911                "a".into(),
912                ColumnType::BigInt,
913            ))],
914            vec![
915                AliasedDynProofExpr {
916                    expr: DynProofExpr::new_column(ColumnRef::new(
917                        TABLE_REF_TABLE(),
918                        "b".into(),
919                        ColumnType::Int,
920                    )),
921                    alias: "sum_b".into(),
922                },
923                AliasedDynProofExpr {
924                    expr: DynProofExpr::new_column(ColumnRef::new(
925                        TABLE_REF_TABLE(),
926                        "d".into(),
927                        ColumnType::Boolean,
928                    )),
929                    alias: "sum_d".into(),
930                },
931            ],
932            "count_1".into(),
933            TableExpr {
934                table_ref: TABLE_REF_TABLE(),
935            },
936            DynProofExpr::new_literal(LiteralValue::Boolean(true)),
937        );
938
939        assert_eq!(result, expected);
940    }
941
942    #[test]
943    fn we_can_aggregate_without_sum_expressions() {
944        // Setup group expression
945        let group_expr = vec![df_column("table", "a")];
946
947        // Create the aggregate expressions
948        let aggr_expr = vec![
949            COUNT_1(), // COUNT
950        ];
951
952        // Create the input plan
953        let input_plan = LogicalPlan::TableScan(
954            TableScan::try_new(
955                "table",
956                TABLE_SOURCE(),
957                Some(vec![0, 1, 2, 3]),
958                vec![],
959                None,
960            )
961            .unwrap(),
962        );
963        let alias_map = indexmap! {
964            "a" => "a",
965            "COUNT(Int64(1))" => "count_1",
966        };
967
968        // Test the function
969        let result =
970            aggregate_to_proof_plan(&input_plan, &group_expr, &aggr_expr, &SCHEMAS(), &alias_map)
971                .unwrap();
972
973        // Expected result
974        let expected = DynProofPlan::new_group_by(
975            vec![ColumnExpr::new(ColumnRef::new(
976                TABLE_REF_TABLE(),
977                "a".into(),
978                ColumnType::BigInt,
979            ))],
980            vec![], // No SUMs
981            "count_1".into(),
982            TableExpr {
983                table_ref: TABLE_REF_TABLE(),
984            },
985            DynProofExpr::new_literal(LiteralValue::Boolean(true)),
986        );
987
988        assert_eq!(result, expected);
989    }
990
991    // Error case tests
992    #[test]
993    fn we_cannot_aggregate_with_non_column_group_expr() {
994        // Setup group expression with a non-column expression
995        let group_expr = vec![Expr::BinaryExpr(BinaryExpr::new(
996            Box::new(df_column("table", "a")),
997            Operator::Plus,
998            Box::new(df_column("table", "b")),
999        ))];
1000
1001        // Create the aggregate expressions
1002        let aggr_expr = vec![
1003            Expr::BinaryExpr(BinaryExpr::new(
1004                Box::new(df_column("table", "a")),
1005                Operator::Plus,
1006                Box::new(df_column("table", "b")),
1007            )),
1008            COUNT_1(),
1009        ];
1010
1011        // Create the input plan
1012        let input_plan = LogicalPlan::TableScan(
1013            TableScan::try_new(
1014                "table",
1015                TABLE_SOURCE(),
1016                Some(vec![0, 1, 2, 3]),
1017                vec![],
1018                None,
1019            )
1020            .unwrap(),
1021        );
1022        let alias_map = indexmap! {
1023            "a+b" => "res",
1024            "COUNT(Int64(1))" => "count_1",
1025        };
1026
1027        // Test the function - should return an error
1028        let result =
1029            aggregate_to_proof_plan(&input_plan, &group_expr, &aggr_expr, &SCHEMAS(), &alias_map);
1030        assert!(matches!(
1031            result,
1032            Err(PlannerError::UnsupportedLogicalPlan { .. })
1033        ));
1034    }
1035
1036    #[test]
1037    fn we_cannot_aggregate_with_non_aggregate_expression() {
1038        // Setup group expression
1039        let group_expr = vec![df_column("table", "a")];
1040
1041        // Setup a non-aggregate expression
1042        let non_agg_expr = Expr::BinaryExpr(BinaryExpr::new(
1043            Box::new(df_column("table", "b")),
1044            Operator::Plus,
1045            Box::new(df_column("table", "c")),
1046        ));
1047
1048        // Setup aliased expression
1049        let aliased_non_agg = Expr::Alias(Alias {
1050            expr: Box::new(non_agg_expr),
1051            relation: None,
1052            name: "b_plus_c".to_string(),
1053        });
1054
1055        // Create the aggregate expressions
1056        let aggr_expr = vec![
1057            aliased_non_agg, // Non-aggregate expression
1058        ];
1059
1060        // Create the input plan
1061        let input_plan = LogicalPlan::TableScan(
1062            TableScan::try_new(
1063                "table",
1064                TABLE_SOURCE(),
1065                Some(vec![0, 1, 2, 3]),
1066                vec![],
1067                None,
1068            )
1069            .unwrap(),
1070        );
1071        let alias_map = indexmap! {
1072            "b+c" => "b_plus_c",
1073        };
1074
1075        // Test the function - should return an error
1076        let result =
1077            aggregate_to_proof_plan(&input_plan, &group_expr, &aggr_expr, &SCHEMAS(), &alias_map);
1078        assert!(matches!(
1079            result,
1080            Err(PlannerError::UnsupportedLogicalPlan { .. })
1081        ));
1082    }
1083
1084    #[test]
1085    fn we_cannot_aggregate_with_non_sum_aggregate_function() {
1086        // Setup group expression
1087        let group_expr = vec![df_column("table", "a")];
1088
1089        // Setup a non-SUM aggregate function (e.g., Avg)
1090        let avg_expr = Expr::AggregateFunction(AggregateFunction {
1091            func_def: AVG,
1092            args: vec![df_column("table", "b")],
1093            distinct: false,
1094            filter: None,
1095            order_by: None,
1096            null_treatment: None,
1097        });
1098
1099        // Setup aliased expressions
1100        let aliased_avg = Expr::Alias(Alias {
1101            expr: Box::new(avg_expr),
1102            relation: None,
1103            name: "avg_b".to_string(),
1104        });
1105
1106        // Create the aggregate expressions
1107        let aggr_expr = vec![
1108            aliased_avg, // AVG aggregate (not SUM)
1109            COUNT_1(),   // COUNT
1110        ];
1111
1112        // Create the input plan
1113        let input_plan = LogicalPlan::TableScan(
1114            TableScan::try_new(
1115                "table",
1116                TABLE_SOURCE(),
1117                Some(vec![0, 1, 2, 3]),
1118                vec![],
1119                None,
1120            )
1121            .unwrap(),
1122        );
1123        let alias_map = indexmap! {
1124            "a" => "a",
1125            "AVG(table.b)" => "avg_b",
1126            "COUNT(Int64(1))" => "count_1",
1127        };
1128
1129        // Test the function - should return an error
1130        let result =
1131            aggregate_to_proof_plan(&input_plan, &group_expr, &aggr_expr, &SCHEMAS(), &alias_map);
1132        assert!(matches!(
1133            result,
1134            Err(PlannerError::UnsupportedLogicalPlan { .. })
1135        ));
1136    }
1137
1138    #[test]
1139    fn we_cannot_aggregate_with_non_count_last_aggregate() {
1140        // Setup group expression
1141        let group_expr = vec![df_column("table", "a")];
1142
1143        // Setup SUM aggregates
1144        let sum_expr1 = Expr::AggregateFunction(AggregateFunction {
1145            func_def: SUM,
1146            args: vec![df_column("table", "b")],
1147            distinct: false,
1148            filter: None,
1149            order_by: None,
1150            null_treatment: None,
1151        });
1152
1153        let sum_expr2 = Expr::AggregateFunction(AggregateFunction {
1154            func_def: SUM,
1155            args: vec![df_column("table", "c")],
1156            distinct: false,
1157            filter: None,
1158            order_by: None,
1159            null_treatment: None,
1160        });
1161
1162        // Setup aliased expressions
1163        let aliased_sum1 = Expr::Alias(Alias {
1164            expr: Box::new(sum_expr1),
1165            relation: None,
1166            name: "sum_b".to_string(),
1167        });
1168
1169        let aliased_sum2 = Expr::Alias(Alias {
1170            expr: Box::new(sum_expr2),
1171            relation: None,
1172            name: "sum_c".to_string(),
1173        });
1174
1175        // Create the aggregate expressions with no COUNT at the end
1176        let aggr_expr = vec![
1177            aliased_sum1, // SUM
1178            aliased_sum2, // Another SUM (should be COUNT)
1179        ];
1180
1181        // Create the input plan
1182        let input_plan = LogicalPlan::TableScan(
1183            TableScan::try_new(
1184                "table",
1185                TABLE_SOURCE(),
1186                Some(vec![0, 1, 2, 3]),
1187                vec![],
1188                None,
1189            )
1190            .unwrap(),
1191        );
1192        let alias_map = indexmap! {
1193            "a" => "a",
1194            "SUM(table.b)" => "sum_b",
1195            "SUM(c)" => "sum_c",
1196        };
1197
1198        // Test the function - should return an error
1199        let result =
1200            aggregate_to_proof_plan(&input_plan, &group_expr, &aggr_expr, &SCHEMAS(), &alias_map);
1201        assert!(matches!(
1202            result,
1203            Err(PlannerError::UnsupportedLogicalPlan { .. })
1204        ));
1205    }
1206
1207    #[test]
1208    fn we_cannot_aggregate_with_fetch_limit() {
1209        // Setup group expression
1210        let group_expr = vec![df_column("table", "a")];
1211
1212        // Create the aggregate expressions
1213        let aggr_expr = vec![
1214            COUNT_1(), // COUNT
1215        ];
1216
1217        // Create the input plan with fetch limit
1218        let input_plan = LogicalPlan::TableScan(
1219            TableScan::try_new(
1220                "table",
1221                TABLE_SOURCE(),
1222                Some(vec![0, 1, 2, 3]),
1223                vec![],
1224                Some(10),
1225            )
1226            .unwrap(),
1227        );
1228        let alias_map = indexmap! {
1229            "a" => "a",
1230            "COUNT(Int64(1))" => "count_1",
1231        };
1232
1233        // Test the function - should return an error because fetch limit is not supported
1234        let result =
1235            aggregate_to_proof_plan(&input_plan, &group_expr, &aggr_expr, &SCHEMAS(), &alias_map);
1236        assert!(matches!(
1237            result,
1238            Err(PlannerError::UnsupportedLogicalPlan { .. })
1239        ));
1240    }
1241
1242    #[test]
1243    fn we_cannot_aggregate_with_non_table_scan_input() {
1244        // Setup group expression
1245        let group_expr = vec![df_column("table", "a")];
1246
1247        // Create the aggregate expressions
1248        let aggr_expr = vec![
1249            COUNT_1(), // COUNT
1250        ];
1251
1252        // Create a non-TableScan input plan
1253        let input_plan = LogicalPlan::EmptyRelation(EmptyRelation {
1254            produce_one_row: false,
1255            schema: Arc::new(DFSchema::empty()),
1256        });
1257        let alias_map = indexmap! {
1258            "a" => "a",
1259            "COUNT(Int64(1))" => "count_1",
1260        };
1261
1262        // Test the function - should return an error
1263        let result =
1264            aggregate_to_proof_plan(&input_plan, &group_expr, &aggr_expr, &SCHEMAS(), &alias_map);
1265        assert!(matches!(
1266            result,
1267            Err(PlannerError::UnsupportedLogicalPlan { .. })
1268        ));
1269    }
1270
1271    // EmptyRelation
1272    #[test]
1273    fn we_can_convert_empty_plan_to_proof_plan() {
1274        let empty_plan = LogicalPlan::EmptyRelation(EmptyRelation {
1275            produce_one_row: false,
1276            schema: Arc::new(DFSchema::empty()),
1277        });
1278        let result = logical_plan_to_proof_plan(&empty_plan, &EMPTY_SCHEMAS()).unwrap();
1279        assert_eq!(result, DynProofPlan::new_empty());
1280    }
1281
1282    // TableScan
1283    #[test]
1284    fn we_can_convert_table_scan_plan_to_proof_plan_without_filter_or_fetch_limit() {
1285        let plan = LogicalPlan::TableScan(
1286            TableScan::try_new("table", TABLE_SOURCE(), Some(vec![0, 1, 2]), vec![], None).unwrap(),
1287        );
1288        let schemas = SCHEMAS();
1289        let result = logical_plan_to_proof_plan(&plan, &schemas).unwrap();
1290        let expected = DynProofPlan::new_projection(
1291            vec![ALIASED_A(), ALIASED_B(), ALIASED_C()],
1292            DynProofPlan::new_table(
1293                TABLE_REF_TABLE(),
1294                vec![
1295                    ColumnField::new("a".into(), ColumnType::BigInt),
1296                    ColumnField::new("b".into(), ColumnType::Int),
1297                    ColumnField::new("c".into(), ColumnType::VarChar),
1298                    ColumnField::new("d".into(), ColumnType::Boolean),
1299                ],
1300            ),
1301        );
1302        assert_eq!(result, expected);
1303    }
1304
1305    #[test]
1306    fn we_cannot_convert_table_scan_plan_to_proof_plan_without_filter_or_fetch_limit_if_bad_schemas(
1307    ) {
1308        let plan = LogicalPlan::TableScan(
1309            TableScan::try_new(
1310                "table",
1311                TABLE_SOURCE(),
1312                Some(vec![0, 1, 2, 3]),
1313                vec![],
1314                None,
1315            )
1316            .unwrap(),
1317        );
1318        let schemas = EMPTY_SCHEMAS();
1319        let result = logical_plan_to_proof_plan(&plan, &schemas);
1320        assert!(matches!(result, Err(PlannerError::ColumnNotFound)));
1321    }
1322
1323    #[test]
1324    fn we_can_convert_table_scan_plan_to_proof_plan_with_filter_but_without_fetch_limit() {
1325        let filter_exprs = vec![
1326            df_column("table", "a").eq(df_column("table", "b")),
1327            df_column("table", "d"),
1328        ];
1329        let plan = LogicalPlan::TableScan(
1330            TableScan::try_new(
1331                "table",
1332                TABLE_SOURCE(),
1333                Some(vec![0, 2]),
1334                filter_exprs,
1335                None,
1336            )
1337            .unwrap(),
1338        );
1339        let schemas = SCHEMAS();
1340        let result = logical_plan_to_proof_plan(&plan, &schemas).unwrap();
1341        let expected = DynProofPlan::new_filter(
1342            vec![ALIASED_A(), ALIASED_C()],
1343            TableExpr {
1344                table_ref: TABLE_REF_TABLE(),
1345            },
1346            DynProofExpr::try_new_and(
1347                DynProofExpr::try_new_equals(
1348                    DynProofExpr::new_column(ColumnRef::new(
1349                        TABLE_REF_TABLE(),
1350                        "a".into(),
1351                        ColumnType::BigInt,
1352                    )),
1353                    DynProofExpr::new_column(ColumnRef::new(
1354                        TABLE_REF_TABLE(),
1355                        "b".into(),
1356                        ColumnType::Int,
1357                    )),
1358                )
1359                .unwrap(),
1360                DynProofExpr::new_column(ColumnRef::new(
1361                    TABLE_REF_TABLE(),
1362                    "d".into(),
1363                    ColumnType::Boolean,
1364                )),
1365            )
1366            .unwrap(),
1367        );
1368        assert_eq!(result, expected);
1369    }
1370
1371    #[test]
1372    fn we_cannot_convert_table_scan_plan_to_proof_plan_with_filter_but_without_fetch_limit_if_bad_schemas(
1373    ) {
1374        let filter_exprs = vec![
1375            df_column("table", "a").eq(df_column("table", "b")),
1376            df_column("table", "d"),
1377        ];
1378        let plan = LogicalPlan::TableScan(
1379            TableScan::try_new(
1380                "table",
1381                TABLE_SOURCE(),
1382                Some(vec![0, 2]),
1383                filter_exprs,
1384                None,
1385            )
1386            .unwrap(),
1387        );
1388        let schemas = EMPTY_SCHEMAS();
1389        let result = logical_plan_to_proof_plan(&plan, &schemas);
1390        assert!(matches!(result, Err(PlannerError::ColumnNotFound)));
1391    }
1392
1393    #[test]
1394    fn we_can_convert_table_scan_plan_to_proof_plan_without_filter_but_with_fetch_limit() {
1395        let plan = LogicalPlan::TableScan(
1396            TableScan::try_new(
1397                "table",
1398                TABLE_SOURCE(),
1399                Some(vec![0, 1, 2, 3]),
1400                vec![],
1401                Some(2),
1402            )
1403            .unwrap(),
1404        );
1405        let schemas = SCHEMAS();
1406        let result = logical_plan_to_proof_plan(&plan, &schemas).unwrap();
1407        let expected = DynProofPlan::new_slice(
1408            DynProofPlan::new_projection(
1409                vec![ALIASED_A(), ALIASED_B(), ALIASED_C(), ALIASED_D()],
1410                DynProofPlan::new_table(
1411                    TABLE_REF_TABLE(),
1412                    vec![
1413                        ColumnField::new("a".into(), ColumnType::BigInt),
1414                        ColumnField::new("b".into(), ColumnType::Int),
1415                        ColumnField::new("c".into(), ColumnType::VarChar),
1416                        ColumnField::new("d".into(), ColumnType::Boolean),
1417                    ],
1418                ),
1419            ),
1420            0,
1421            Some(2),
1422        );
1423        assert_eq!(result, expected);
1424    }
1425
1426    #[test]
1427    fn we_can_convert_table_scan_plan_to_proof_plan_with_filter_and_fetch_limit() {
1428        let filter_exprs = vec![
1429            df_column("table", "a").gt(df_column("table", "b")),
1430            df_column("table", "d"),
1431        ];
1432        let plan = LogicalPlan::TableScan(
1433            TableScan::try_new(
1434                "table",
1435                TABLE_SOURCE(),
1436                Some(vec![0, 3]),
1437                filter_exprs,
1438                Some(5),
1439            )
1440            .unwrap(),
1441        );
1442        let schemas = SCHEMAS();
1443        let result = logical_plan_to_proof_plan(&plan, &schemas).unwrap();
1444        let expected = DynProofPlan::new_slice(
1445            DynProofPlan::new_filter(
1446                vec![ALIASED_A(), ALIASED_D()],
1447                TableExpr {
1448                    table_ref: TABLE_REF_TABLE(),
1449                },
1450                DynProofExpr::try_new_and(
1451                    DynProofExpr::try_new_inequality(
1452                        DynProofExpr::new_column(ColumnRef::new(
1453                            TABLE_REF_TABLE(),
1454                            "a".into(),
1455                            ColumnType::BigInt,
1456                        )),
1457                        DynProofExpr::new_column(ColumnRef::new(
1458                            TABLE_REF_TABLE(),
1459                            "b".into(),
1460                            ColumnType::Int,
1461                        )),
1462                        false,
1463                    )
1464                    .unwrap(),
1465                    DynProofExpr::new_column(ColumnRef::new(
1466                        TABLE_REF_TABLE(),
1467                        "d".into(),
1468                        ColumnType::Boolean,
1469                    )),
1470                )
1471                .unwrap(),
1472            ),
1473            0,
1474            Some(5),
1475        );
1476        assert_eq!(result, expected);
1477    }
1478
1479    // Projection
1480    #[test]
1481    fn we_can_convert_projection_plan_to_proof_plan() {
1482        let plan = LogicalPlan::Projection(
1483            Projection::try_new(
1484                vec![
1485                    Expr::BinaryExpr(BinaryExpr::new(
1486                        Box::new(df_column("table", "a")),
1487                        Operator::Plus,
1488                        Box::new(df_column("table", "b")),
1489                    )),
1490                    not(df_column("table", "d")),
1491                ],
1492                Arc::new(LogicalPlan::TableScan(
1493                    TableScan::try_new("table", TABLE_SOURCE(), Some(vec![0, 1, 3]), vec![], None)
1494                        .unwrap(),
1495                )),
1496            )
1497            .unwrap(),
1498        );
1499        let schemas = SCHEMAS();
1500        let result = logical_plan_to_proof_plan(&plan, &schemas).unwrap();
1501        let expected = DynProofPlan::new_projection(
1502            vec![
1503                AliasedDynProofExpr {
1504                    expr: DynProofExpr::try_new_add(
1505                        DynProofExpr::new_column(ColumnRef::new(
1506                            TABLE_REF_TABLE(),
1507                            "a".into(),
1508                            ColumnType::BigInt,
1509                        )),
1510                        DynProofExpr::new_column(ColumnRef::new(
1511                            TABLE_REF_TABLE(),
1512                            "b".into(),
1513                            ColumnType::Int,
1514                        )),
1515                    )
1516                    .unwrap(),
1517                    alias: "table.a + table.b".into(),
1518                },
1519                AliasedDynProofExpr {
1520                    expr: DynProofExpr::try_new_not(DynProofExpr::new_column(ColumnRef::new(
1521                        TABLE_REF_TABLE(),
1522                        "d".into(),
1523                        ColumnType::Boolean,
1524                    )))
1525                    .unwrap(),
1526                    alias: "NOT table.d".into(),
1527                },
1528            ],
1529            DynProofPlan::new_projection(
1530                vec![ALIASED_A(), ALIASED_B(), ALIASED_D()],
1531                DynProofPlan::new_table(
1532                    TABLE_REF_TABLE(),
1533                    vec![
1534                        ColumnField::new("a".into(), ColumnType::BigInt),
1535                        ColumnField::new("b".into(), ColumnType::Int),
1536                        ColumnField::new("c".into(), ColumnType::VarChar),
1537                        ColumnField::new("d".into(), ColumnType::Boolean),
1538                    ],
1539                ),
1540            ),
1541        );
1542        assert_eq!(result, expected);
1543    }
1544
1545    // Limit
    // Note that at least one of fetch or skip is present; otherwise the optimizer removes the Limit node
1547    #[test]
1548    fn we_can_convert_limit_plan_with_fetch_and_skip_to_proof_plan() {
1549        let plan = LogicalPlan::Limit(Limit {
1550            input: Arc::new(LogicalPlan::TableScan(
1551                TableScan::try_new(
1552                    "table",
1553                    TABLE_SOURCE(),
1554                    Some(vec![0, 1]),
1555                    vec![],
1556                    // Optimizer will put a fetch on TableScan if there is a non-empty fetch in an outer Limit
1557                    Some(5),
1558                )
1559                .unwrap(),
1560            )),
1561            fetch: Some(3),
1562            skip: 2,
1563        });
1564        let schemas = SCHEMAS();
1565        let result = logical_plan_to_proof_plan(&plan, &schemas).unwrap();
1566        let expected = DynProofPlan::new_slice(
1567            DynProofPlan::new_slice(
1568                DynProofPlan::new_projection(
1569                    vec![ALIASED_A(), ALIASED_B()],
1570                    DynProofPlan::new_table(
1571                        TABLE_REF_TABLE(),
1572                        vec![
1573                            ColumnField::new("a".into(), ColumnType::BigInt),
1574                            ColumnField::new("b".into(), ColumnType::Int),
1575                            ColumnField::new("c".into(), ColumnType::VarChar),
1576                            ColumnField::new("d".into(), ColumnType::Boolean),
1577                        ],
1578                    ),
1579                ),
1580                0,
1581                Some(5),
1582            ),
1583            2,
1584            Some(3),
1585        );
1586        assert_eq!(result, expected);
1587    }
1588
1589    #[test]
1590    fn we_can_convert_limit_plan_with_fetch_no_skip_to_proof_plan() {
1591        //TODO: Optimize proof plan to remove redundant slices
1592        let plan = LogicalPlan::Limit(Limit {
1593            input: Arc::new(LogicalPlan::TableScan(
1594                TableScan::try_new("table", TABLE_SOURCE(), Some(vec![0, 1]), vec![], Some(3))
1595                    .unwrap(),
1596            )),
1597            fetch: Some(3),
1598            skip: 0,
1599        });
1600
1601        let schemas = SCHEMAS();
1602        let result = logical_plan_to_proof_plan(&plan, &schemas).unwrap();
1603
1604        let expected = DynProofPlan::new_slice(
1605            DynProofPlan::new_slice(
1606                DynProofPlan::new_projection(
1607                    vec![ALIASED_A(), ALIASED_B()],
1608                    DynProofPlan::new_table(
1609                        TABLE_REF_TABLE(),
1610                        vec![
1611                            ColumnField::new("a".into(), ColumnType::BigInt),
1612                            ColumnField::new("b".into(), ColumnType::Int),
1613                            ColumnField::new("c".into(), ColumnType::VarChar),
1614                            ColumnField::new("d".into(), ColumnType::Boolean),
1615                        ],
1616                    ),
1617                ),
1618                0,
1619                Some(3),
1620            ),
1621            0,
1622            Some(3),
1623        );
1624        assert_eq!(result, expected);
1625    }
1626
1627    #[test]
1628    fn we_can_convert_limit_plan_with_skip_no_fetch_to_proof_plan() {
1629        let plan = LogicalPlan::Limit(Limit {
1630            input: Arc::new(LogicalPlan::TableScan(
1631                TableScan::try_new("table", TABLE_SOURCE(), Some(vec![0, 1]), vec![], None)
1632                    .unwrap(),
1633            )),
1634            fetch: None,
1635            skip: 2,
1636        });
1637
1638        let schemas = SCHEMAS();
1639        let result = logical_plan_to_proof_plan(&plan, &schemas).unwrap();
1640
1641        let expected = DynProofPlan::new_slice(
1642            DynProofPlan::new_projection(
1643                vec![ALIASED_A(), ALIASED_B()],
1644                DynProofPlan::new_table(
1645                    TABLE_REF_TABLE(),
1646                    vec![
1647                        ColumnField::new("a".into(), ColumnType::BigInt),
1648                        ColumnField::new("b".into(), ColumnType::Int),
1649                        ColumnField::new("c".into(), ColumnType::VarChar),
1650                        ColumnField::new("d".into(), ColumnType::Boolean),
1651                    ],
1652                ),
1653            ),
1654            2,
1655            None,
1656        );
1657        assert_eq!(result, expected);
1658    }
1659
1660    // Union
1661    #[expect(clippy::too_many_lines)]
1662    #[test]
1663    fn we_can_convert_union_plan_to_proof_plan() {
1664        let plan = LogicalPlan::Union(Union {
1665            schema: Arc::new(df_schema(
1666                "table",
1667                vec![("a", DataType::Int64), ("b", DataType::Int32)],
1668            )),
1669            inputs: vec![
1670                Arc::new(LogicalPlan::TableScan(
1671                    TableScan::try_new("table1", TABLE_SOURCE(), Some(vec![0, 1]), vec![], None)
1672                        .unwrap(),
1673                )),
1674                Arc::new(LogicalPlan::TableScan(
1675                    TableScan::try_new("table2", TABLE_SOURCE(), Some(vec![0, 1]), vec![], None)
1676                        .unwrap(),
1677                )),
1678                Arc::new(LogicalPlan::TableScan(
1679                    TableScan::try_new(
1680                        "schema.table3",
1681                        TABLE_SOURCE(),
1682                        Some(vec![0, 1]),
1683                        vec![],
1684                        None,
1685                    )
1686                    .unwrap(),
1687                )),
1688            ],
1689        });
1690        let schemas = UNION_SCHEMAS();
1691        let result = logical_plan_to_proof_plan(&plan, &schemas).unwrap();
1692        let expected = DynProofPlan::try_new_union(vec![
1693            DynProofPlan::new_projection(
1694                vec![
1695                    AliasedDynProofExpr {
1696                        expr: DynProofExpr::new_column(ColumnRef::new(
1697                            TableRef::from_names(None, "table1"),
1698                            "a1".into(),
1699                            ColumnType::BigInt,
1700                        )),
1701                        alias: "a".into(),
1702                    },
1703                    AliasedDynProofExpr {
1704                        expr: DynProofExpr::new_column(ColumnRef::new(
1705                            TableRef::from_names(None, "table1"),
1706                            "b1".into(),
1707                            ColumnType::Int,
1708                        )),
1709                        alias: "b".into(),
1710                    },
1711                ],
1712                DynProofPlan::new_table(
1713                    TableRef::from_names(None, "table1"),
1714                    vec![
1715                        ColumnField::new("a1".into(), ColumnType::BigInt),
1716                        ColumnField::new("b1".into(), ColumnType::Int),
1717                    ],
1718                ),
1719            ),
1720            DynProofPlan::new_projection(
1721                vec![
1722                    AliasedDynProofExpr {
1723                        expr: DynProofExpr::new_column(ColumnRef::new(
1724                            TableRef::from_names(None, "table2"),
1725                            "a2".into(),
1726                            ColumnType::BigInt,
1727                        )),
1728                        alias: "a".into(),
1729                    },
1730                    AliasedDynProofExpr {
1731                        expr: DynProofExpr::new_column(ColumnRef::new(
1732                            TableRef::from_names(None, "table2"),
1733                            "b2".into(),
1734                            ColumnType::Int,
1735                        )),
1736                        alias: "b".into(),
1737                    },
1738                ],
1739                DynProofPlan::new_table(
1740                    TableRef::from_names(None, "table2"),
1741                    vec![
1742                        ColumnField::new("a2".into(), ColumnType::BigInt),
1743                        ColumnField::new("b2".into(), ColumnType::Int),
1744                    ],
1745                ),
1746            ),
1747            DynProofPlan::new_projection(
1748                vec![
1749                    AliasedDynProofExpr {
1750                        expr: DynProofExpr::new_column(ColumnRef::new(
1751                            TableRef::from_names(Some("schema"), "table3"),
1752                            "a3".into(),
1753                            ColumnType::BigInt,
1754                        )),
1755                        alias: "a".into(),
1756                    },
1757                    AliasedDynProofExpr {
1758                        expr: DynProofExpr::new_column(ColumnRef::new(
1759                            TableRef::from_names(Some("schema"), "table3"),
1760                            "b3".into(),
1761                            ColumnType::Int,
1762                        )),
1763                        alias: "b".into(),
1764                    },
1765                ],
1766                DynProofPlan::new_table(
1767                    TableRef::from_names(Some("schema"), "table3"),
1768                    vec![
1769                        ColumnField::new("a3".into(), ColumnType::BigInt),
1770                        ColumnField::new("b3".into(), ColumnType::Int),
1771                    ],
1772                ),
1773            ),
1774        ])
1775        .unwrap();
1776        assert_eq!(result, expected);
1777    }
1778
1779    // Aggregate
1780    #[test]
1781    fn we_can_convert_supported_simple_agg_plan_to_proof_plan() {
1782        // Setup group expression
1783        let group_expr = vec![df_column("table", "a")];
1784
1785        // Create the aggregate expressions
1786        let aggr_expr = vec![
1787            SUM_B(),   // SUM
1788            COUNT_1(), // COUNT
1789        ];
1790
1791        // Create filters
1792        let filter_exprs = vec![
1793            df_column("table", "d"), // Boolean column as filter
1794        ];
1795
1796        // Create the input plan with filters
1797        let input_plan = LogicalPlan::TableScan(
1798            TableScan::try_new(
1799                "table",
1800                TABLE_SOURCE(),
1801                Some(vec![0, 1, 2, 3]),
1802                filter_exprs,
1803                None,
1804            )
1805            .unwrap(),
1806        );
1807
1808        let agg_plan = LogicalPlan::Aggregate(
1809            Aggregate::try_new(Arc::new(input_plan), group_expr.clone(), aggr_expr.clone())
1810                .unwrap(),
1811        );
1812
1813        // Test the function
1814        let result = logical_plan_to_proof_plan(&agg_plan, &SCHEMAS()).unwrap();
1815
1816        // Expected result
1817        let expected = DynProofPlan::new_group_by(
1818            vec![ColumnExpr::new(ColumnRef::new(
1819                TABLE_REF_TABLE(),
1820                "a".into(),
1821                ColumnType::BigInt,
1822            ))],
1823            vec![AliasedDynProofExpr {
1824                expr: DynProofExpr::new_column(ColumnRef::new(
1825                    TABLE_REF_TABLE(),
1826                    "b".into(),
1827                    ColumnType::Int,
1828                )),
1829                alias: "SUM(table.b)".into(),
1830            }],
1831            "COUNT(Int64(1))".into(),
1832            TableExpr {
1833                table_ref: TABLE_REF_TABLE(),
1834            },
1835            DynProofExpr::new_column(ColumnRef::new(
1836                TABLE_REF_TABLE(),
1837                "d".into(),
1838                ColumnType::Boolean,
1839            )),
1840        );
1841
1842        assert_eq!(result, expected);
1843    }
1844
1845    // Aggregate + Projection
1846    #[test]
1847    fn we_can_convert_supported_agg_plan_to_proof_plan() {
1848        // Setup group expression
1849        let group_expr = vec![df_column("table", "a")];
1850
1851        // Create the aggregate expressions
1852        let aggr_expr = vec![
1853            SUM_B(),   // SUM
1854            COUNT_1(), // COUNT
1855        ];
1856
1857        // Create filters
1858        let filter_exprs = vec![
1859            df_column("table", "d"), // Boolean column as filter
1860        ];
1861
1862        // Create the input plan with filters
1863        let input_plan = LogicalPlan::TableScan(
1864            TableScan::try_new(
1865                "table",
1866                TABLE_SOURCE(),
1867                Some(vec![0, 1, 2, 3]),
1868                filter_exprs,
1869                None,
1870            )
1871            .unwrap(),
1872        );
1873
1874        let agg_plan = LogicalPlan::Aggregate(
1875            Aggregate::try_new(Arc::new(input_plan), group_expr.clone(), aggr_expr.clone())
1876                .unwrap(),
1877        );
1878
1879        let proj_plan = LogicalPlan::Projection(
1880            Projection::try_new(
1881                vec![
1882                    df_column("table", "a"),
1883                    Expr::Column(Column::new(
1884                        None::<TableReference>,
1885                        "SUM(table.b)".to_string(),
1886                    ))
1887                    .alias("sum_b"),
1888                    Expr::Column(Column::new(
1889                        None::<TableReference>,
1890                        "COUNT(Int64(1))".to_string(),
1891                    ))
1892                    .alias("count_1"),
1893                ],
1894                Arc::new(agg_plan),
1895            )
1896            .unwrap(),
1897        );
1898
1899        // Test the function
1900        let result = logical_plan_to_proof_plan(&proj_plan, &SCHEMAS()).unwrap();
1901
1902        // Expected result
1903        let expected = DynProofPlan::new_group_by(
1904            vec![ColumnExpr::new(ColumnRef::new(
1905                TABLE_REF_TABLE(),
1906                "a".into(),
1907                ColumnType::BigInt,
1908            ))],
1909            vec![AliasedDynProofExpr {
1910                expr: DynProofExpr::new_column(ColumnRef::new(
1911                    TABLE_REF_TABLE(),
1912                    "b".into(),
1913                    ColumnType::Int,
1914                )),
1915                alias: "sum_b".into(),
1916            }],
1917            "count_1".into(),
1918            TableExpr {
1919                table_ref: TABLE_REF_TABLE(),
1920            },
1921            DynProofExpr::new_column(ColumnRef::new(
1922                TABLE_REF_TABLE(),
1923                "d".into(),
1924                ColumnType::Boolean,
1925            )),
1926        );
1927
1928        assert_eq!(result, expected);
1929    }
1930
1931    #[test]
1932    fn we_cannot_convert_unsupported_agg_plan_to_proof_plan() {
1933        // Setup group expression
1934        let group_expr = vec![df_column("table", "a")];
1935
1936        // Create the aggregate expressions
1937        let aggr_expr = vec![
1938            SUM_B(),   // SUM
1939            COUNT_1(), // COUNT
1940        ];
1941
1942        // Create filters
1943        let filter_exprs = vec![
1944            df_column("table", "d"), // Boolean column as filter
1945        ];
1946
1947        // Create the input plan with filters
1948        let input_plan = LogicalPlan::TableScan(
1949            TableScan::try_new(
1950                "table",
1951                TABLE_SOURCE(),
1952                Some(vec![0, 1, 2, 3]),
1953                filter_exprs,
1954                None,
1955            )
1956            .unwrap(),
1957        );
1958
1959        let agg_plan = LogicalPlan::Aggregate(
1960            Aggregate::try_new(Arc::new(input_plan), group_expr.clone(), aggr_expr.clone())
1961                .unwrap(),
1962        );
1963
1964        let proj_plan = LogicalPlan::Projection(
1965            Projection::try_new(
1966                vec![df_column("table", "a").add(df_column("table", "a"))],
1967                Arc::new(agg_plan),
1968            )
1969            .unwrap(),
1970        );
1971
1972        // Test the function
1973        assert!(matches!(
1974            logical_plan_to_proof_plan(&proj_plan, &SCHEMAS()),
1975            Err(PlannerError::UnsupportedLogicalPlan { .. })
1976        ));
1977    }
1978
1979    // Unsupported
1980    #[test]
1981    fn we_cannot_convert_unsupported_logical_plan_to_proof_plan() {
1982        let plan = LogicalPlan::Prepare(Prepare {
1983            name: "not_a_real_plan".to_string(),
1984            data_types: vec![],
1985            input: Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1986                produce_one_row: false,
1987                schema: Arc::new(DFSchema::empty()),
1988            })),
1989        });
1990        let schemas = SCHEMAS();
1991        assert!(matches!(
1992            logical_plan_to_proof_plan(&plan, &schemas),
1993            Err(PlannerError::UnsupportedLogicalPlan { .. })
1994        ));
1995    }
1996
1997    #[test]
1998    fn we_can_error_if_not_inner_join() {
1999        // Most of the arguments here are bogus. The only thing that really matters is the join type.
2000        let plan = LogicalPlan::Prepare(Prepare {
2001            name: "not_a_real_plan".to_string(),
2002            data_types: vec![],
2003            input: Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
2004                produce_one_row: false,
2005                schema: Arc::new(DFSchema::empty()),
2006            })),
2007        });
2008        let schemas = SCHEMAS();
2009        let join_err = join_to_proof_plan(
2010            &Join {
2011                left: Arc::new(plan.clone()),
2012                right: Arc::new(plan.clone()),
2013                on: Vec::new(),
2014                filter: None,
2015                join_type: JoinType::Left,
2016                join_constraint: JoinConstraint::On,
2017                schema: Arc::new(DFSchema::empty()),
2018                null_equals_null: false,
2019            },
2020            &schemas,
2021            &plan,
2022        )
2023        .unwrap_err();
2024        assert!(
2025            matches!(join_err, PlannerError::UnsupportedLogicalPlan { plan: logical_plan } if *logical_plan == plan )
2026        );
2027    }
2028}