proof_of_sql_planner/
proof_plan_with_postprocessing.rs

1use super::{
2    logical_plan_to_proof_plan, postprocessing::SelectPostprocessing, PlannerError, PlannerResult,
3};
4use datafusion::logical_expr::{LogicalPlan, Projection};
5use proof_of_sql::{base::database::SchemaAccessor, sql::proof_plans::DynProofPlan};
6
7/// A [`DynProofPlan`] with optional postprocessing
8#[derive(Debug, Clone)]
9pub struct ProofPlanWithPostprocessing {
10    plan: DynProofPlan,
11    postprocessing: Option<SelectPostprocessing>,
12}
13
14impl ProofPlanWithPostprocessing {
15    /// Create a new `ProofPlanWithPostprocessing`
16    #[must_use]
17    pub fn new(plan: DynProofPlan, postprocessing: Option<SelectPostprocessing>) -> Self {
18        Self {
19            plan,
20            postprocessing,
21        }
22    }
23
24    /// Get the `DynProofPlan`
25    #[must_use]
26    pub fn plan(&self) -> &DynProofPlan {
27        &self.plan
28    }
29
30    /// Get the postprocessing
31    #[must_use]
32    pub fn postprocessing(&self) -> Option<&SelectPostprocessing> {
33        self.postprocessing.as_ref()
34    }
35}
36
37/// Visit a [`datafusion::logical_plan::LogicalPlan`] and return a [`DynProofPlan`] with optional postprocessing
38pub fn logical_plan_to_proof_plan_with_postprocessing(
39    plan: &LogicalPlan,
40    schemas: &impl SchemaAccessor,
41) -> PlannerResult<ProofPlanWithPostprocessing> {
42    let result_proof_plan = logical_plan_to_proof_plan(plan, schemas);
43    match result_proof_plan {
44        Ok(proof_plan) => Ok(ProofPlanWithPostprocessing::new(proof_plan, None)),
45        Err(_err) => {
46            match plan {
47                // For projections, we can apply a postprocessing step
48                LogicalPlan::Projection(Projection { input, expr, .. }) => {
49                    // If the inner `LogicalPlan` is not provable we error out
50                    let input_proof_plan = logical_plan_to_proof_plan(input, schemas)?;
51                    let postprocessing = SelectPostprocessing::new(expr.clone());
52                    Ok(ProofPlanWithPostprocessing::new(
53                        input_proof_plan,
54                        Some(postprocessing),
55                    ))
56                }
57                _ => Err(PlannerError::UnsupportedLogicalPlan { plan: plan.clone() }),
58            }
59        }
60    }
61}
62
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{df_util::*, PoSqlTableSource};
    use ahash::AHasher;
    use alloc::sync::Arc;
    use core::ops::Mul;
    use datafusion::{
        common::{Column, DFSchema, ScalarValue},
        logical_expr::{
            expr::{AggregateFunction, AggregateFunctionDefinition},
            Aggregate, EmptyRelation, Expr, LogicalPlan, Prepare, TableScan, TableSource,
        },
        physical_plan,
        sql::TableReference,
    };
    use indexmap::{indexmap_with_default, IndexMap};
    use proof_of_sql::{
        base::database::{ColumnField, ColumnRef, ColumnType, TableRef, TestSchemaAccessor},
        sql::{
            proof_exprs::{AliasedDynProofExpr, ColumnExpr, DynProofExpr, TableExpr},
            proof_plans::DynProofPlan,
        },
    };
    use sqlparser::ast::Ident;
    use std::hash::BuildHasherDefault;

    /// Built-in `SUM` aggregate definition.
    const SUM: AggregateFunctionDefinition =
        AggregateFunctionDefinition::BuiltIn(physical_plan::aggregates::AggregateFunction::Sum);
    /// Built-in `COUNT` aggregate definition.
    const COUNT: AggregateFunctionDefinition =
        AggregateFunctionDefinition::BuiltIn(physical_plan::aggregates::AggregateFunction::Count);

    // ---------------------------------------------------------------------
    // Shared helpers
    // ---------------------------------------------------------------------

    /// The full, ordered column list of the `languages` test table.
    fn language_column_fields() -> Vec<ColumnField> {
        vec![
            ColumnField::new("name".into(), ColumnType::VarChar),
            ColumnField::new("language_family".into(), ColumnType::VarChar),
            ColumnField::new("uses_abjad".into(), ColumnType::Boolean),
            ColumnField::new("num_of_letters".into(), ColumnType::BigInt),
            ColumnField::new("grace".into(), ColumnType::VarChar),
            ColumnField::new("love".into(), ColumnType::VarChar),
            ColumnField::new("joy".into(), ColumnType::VarChar),
            ColumnField::new("peace".into(), ColumnType::VarChar),
        ]
    }

    /// A `VarChar` column of the `languages` table, aliased to its own name.
    fn aliased_varchar_column(name: &str) -> AliasedDynProofExpr {
        AliasedDynProofExpr {
            expr: DynProofExpr::new_column(ColumnRef::new(
                TABLE_LANGUAGES(),
                name.into(),
                ColumnType::VarChar,
            )),
            alias: name.into(),
        }
    }

    /// A non-distinct aggregate call with no filter, ordering or null treatment.
    fn plain_aggregate(func_def: AggregateFunctionDefinition, args: Vec<Expr>) -> Expr {
        Expr::AggregateFunction(AggregateFunction {
            func_def,
            args,
            distinct: false,
            filter: None,
            order_by: None,
            null_treatment: None,
        })
    }

    /// A column expression without a table qualifier.
    fn unqualified_column(name: &str) -> Expr {
        Expr::Column(Column::new(None::<TableReference>, name.to_string()))
    }

    // ---------------------------------------------------------------------
    // Fixtures
    // ---------------------------------------------------------------------

    #[expect(non_snake_case)]
    fn TABLE_LANGUAGES() -> TableRef {
        TableRef::from_names(None, "languages")
    }

    #[expect(non_snake_case)]
    fn SCHEMAS() -> impl SchemaAccessor {
        let columns: IndexMap<Ident, ColumnType, BuildHasherDefault<AHasher>> = indexmap_with_default! {
            AHasher;
            "name".into() => ColumnType::VarChar,
            "language_family".into() => ColumnType::VarChar,
            "uses_abjad".into() => ColumnType::Boolean,
            "num_of_letters".into() => ColumnType::BigInt,
            "grace".into() => ColumnType::VarChar,
            "love".into() => ColumnType::VarChar,
            "joy".into() => ColumnType::VarChar,
            "peace".into() => ColumnType::VarChar
        };
        let languages = TableRef::new("", "languages");
        let tables = indexmap_with_default! {
            AHasher;
            languages => columns
        };
        TestSchemaAccessor::new(tables)
    }

    #[expect(non_snake_case)]
    fn TABLE_SOURCE() -> Arc<dyn TableSource> {
        Arc::new(PoSqlTableSource::new(language_column_fields()))
    }

    #[expect(non_snake_case)]
    fn ALIASED_NAME() -> AliasedDynProofExpr {
        aliased_varchar_column("name")
    }

    #[expect(non_snake_case)]
    fn ALIASED_GRACE() -> AliasedDynProofExpr {
        aliased_varchar_column("grace")
    }

    #[expect(non_snake_case)]
    fn ALIASED_LOVE() -> AliasedDynProofExpr {
        aliased_varchar_column("love")
    }

    #[expect(non_snake_case)]
    fn ALIASED_JOY() -> AliasedDynProofExpr {
        aliased_varchar_column("joy")
    }

    #[expect(non_snake_case)]
    fn COUNT_1() -> Expr {
        plain_aggregate(COUNT, vec![Expr::Literal(ScalarValue::Int64(Some(1)))])
    }

    #[expect(non_snake_case)]
    fn SUM_NUM_LETTERS() -> Expr {
        plain_aggregate(SUM, vec![df_column("languages", "num_of_letters")])
    }

    #[expect(non_snake_case)]
    fn ALIASED_PEACE() -> AliasedDynProofExpr {
        aliased_varchar_column("peace")
    }

    // ---------------------------------------------------------------------
    // Tests
    // ---------------------------------------------------------------------

    #[test]
    fn we_can_convert_logical_plan_to_proof_plan_without_postprocessing() {
        // Scan selecting `name`, `grace`, `love`, `joy` and `peace` by index.
        let scan = TableScan::try_new(
            "languages",
            TABLE_SOURCE(),
            Some(vec![0, 4, 5, 6, 7]),
            vec![],
            None,
        )
        .unwrap();
        let plan = LogicalPlan::TableScan(scan);
        let schemas = SCHEMAS();

        let result = logical_plan_to_proof_plan_with_postprocessing(&plan, &schemas).unwrap();

        let expected = DynProofPlan::new_projection(
            vec![
                ALIASED_NAME(),
                ALIASED_GRACE(),
                ALIASED_LOVE(),
                ALIASED_JOY(),
                ALIASED_PEACE(),
            ],
            DynProofPlan::new_table(TABLE_LANGUAGES(), language_column_fields()),
        );
        assert_eq!(result.plan(), &expected);
        assert!(result.postprocessing().is_none());
    }

    #[test]
    fn we_can_convert_logical_plan_to_proof_plan_with_postprocessing() {
        // Innermost node: a scan filtered on the boolean `uses_abjad` column.
        let scan = LogicalPlan::TableScan(
            TableScan::try_new(
                "languages",
                TABLE_SOURCE(),
                Some(vec![1, 2, 3]),
                vec![df_column("languages", "uses_abjad")],
                None,
            )
            .unwrap(),
        );

        // Middle node: group by `language_family`, computing a SUM and a COUNT.
        let aggregate = LogicalPlan::Aggregate(
            Aggregate::try_new(
                Arc::new(scan),
                vec![df_column("languages", "language_family")],
                vec![SUM_NUM_LETTERS(), COUNT_1()],
            )
            .unwrap(),
        );

        // Outermost node: a projection over the aggregate's output columns.
        // The same expressions are expected back as the postprocessing step.
        let select_exprs = vec![
            df_column("languages", "language_family"),
            unqualified_column("COUNT(Int64(1))")
                .mul(Expr::Literal(ScalarValue::Int64(Some(2))))
                .alias("twice_num_languages_using_abjad"),
            unqualified_column("SUM(languages.num_of_letters)").alias("sum_num_of_letters"),
        ];
        let projection = LogicalPlan::Projection(
            Projection::try_new(select_exprs.clone(), Arc::new(aggregate)).unwrap(),
        );

        let result =
            logical_plan_to_proof_plan_with_postprocessing(&projection, &SCHEMAS()).unwrap();

        // The returned proof plan is the group-by; the projection is left to postprocessing.
        let expected_plan = DynProofPlan::new_group_by(
            vec![ColumnExpr::new(ColumnRef::new(
                TABLE_LANGUAGES(),
                "language_family".into(),
                ColumnType::VarChar,
            ))],
            vec![AliasedDynProofExpr {
                expr: DynProofExpr::new_column(ColumnRef::new(
                    TABLE_LANGUAGES(),
                    "num_of_letters".into(),
                    ColumnType::BigInt,
                )),
                alias: "SUM(languages.num_of_letters)".into(),
            }],
            "COUNT(Int64(1))".into(),
            TableExpr {
                table_ref: TABLE_LANGUAGES(),
            },
            DynProofExpr::new_column(ColumnRef::new(
                TABLE_LANGUAGES(),
                "uses_abjad".into(),
                ColumnType::Boolean,
            )),
        );
        let expected_postprocessing = SelectPostprocessing::new(select_exprs);

        assert_eq!(result.plan(), &expected_plan);
        assert_eq!(result.postprocessing().unwrap(), &expected_postprocessing);
    }

    // Unsupported
    #[test]
    fn we_cannot_convert_unsupported_logical_plan_to_proof_plan_with_postprocessing() {
        let empty_input = LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: Arc::new(DFSchema::empty()),
        });
        let plan = LogicalPlan::Prepare(Prepare {
            name: "not_a_real_plan".to_string(),
            data_types: vec![],
            input: Arc::new(empty_input),
        });
        let schemas = SCHEMAS();

        let outcome = logical_plan_to_proof_plan_with_postprocessing(&plan, &schemas);
        assert!(matches!(
            outcome,
            Err(PlannerError::UnsupportedLogicalPlan { .. })
        ));
    }
}