proof_of_sql_planner/
proof_plan_with_postprocessing.rs

1use super::{
2    logical_plan_to_proof_plan, postprocessing::SelectPostprocessing, PlannerError, PlannerResult,
3};
4use datafusion::logical_expr::{LogicalPlan, Projection};
5use proof_of_sql::{base::database::SchemaAccessor, sql::proof_plans::DynProofPlan};
6
7/// A [`DynProofPlan`] with optional postprocessing
8#[derive(Debug, Clone)]
9pub struct ProofPlanWithPostprocessing {
10    plan: DynProofPlan,
11    postprocessing: Option<SelectPostprocessing>,
12}
13
14impl ProofPlanWithPostprocessing {
15    /// Create a new `ProofPlanWithPostprocessing`
16    #[must_use]
17    pub fn new(plan: DynProofPlan, postprocessing: Option<SelectPostprocessing>) -> Self {
18        Self {
19            plan,
20            postprocessing,
21        }
22    }
23
24    /// Get the `DynProofPlan`
25    #[must_use]
26    pub fn plan(&self) -> &DynProofPlan {
27        &self.plan
28    }
29
30    /// Get the postprocessing
31    #[must_use]
32    pub fn postprocessing(&self) -> Option<&SelectPostprocessing> {
33        self.postprocessing.as_ref()
34    }
35}
36
37/// Visit a [`datafusion::logical_plan::LogicalPlan`] and return a [`DynProofPlan`] with optional postprocessing
38pub fn logical_plan_to_proof_plan_with_postprocessing(
39    plan: &LogicalPlan,
40    schemas: &impl SchemaAccessor,
41) -> PlannerResult<ProofPlanWithPostprocessing> {
42    let result_proof_plan = logical_plan_to_proof_plan(plan, schemas);
43    match result_proof_plan {
44        Ok(proof_plan) => Ok(ProofPlanWithPostprocessing::new(proof_plan, None)),
45        Err(_err) => {
46            match plan {
47                // For projections, we can apply a postprocessing step
48                LogicalPlan::Projection(Projection { input, expr, .. }) => {
49                    // If the inner `LogicalPlan` is not provable we error out
50                    let input_proof_plan = logical_plan_to_proof_plan(input, schemas)?;
51                    let postprocessing = SelectPostprocessing::new(expr.clone());
52                    Ok(ProofPlanWithPostprocessing::new(
53                        input_proof_plan,
54                        Some(postprocessing),
55                    ))
56                }
57                _ => Err(PlannerError::UnsupportedLogicalPlan {
58                    plan: Box::new(plan.clone()),
59                }),
60            }
61        }
62    }
63}
64
65#[cfg(test)]
66mod tests {
67    use super::*;
68    use crate::{df_util::*, PoSqlTableSource};
69    use ahash::AHasher;
70    use alloc::sync::Arc;
71    use core::ops::Mul;
72    use datafusion::{
73        common::{Column, DFSchema, ScalarValue},
74        logical_expr::{
75            expr::{AggregateFunction, AggregateFunctionDefinition},
76            Aggregate, EmptyRelation, Expr, LogicalPlan, Prepare, TableScan, TableSource,
77        },
78        physical_plan,
79        sql::TableReference,
80    };
81    use indexmap::{indexmap_with_default, IndexMap};
82    use proof_of_sql::{
83        base::database::{ColumnField, ColumnRef, ColumnType, TableRef, TestSchemaAccessor},
84        sql::{
85            proof_exprs::{AliasedDynProofExpr, ColumnExpr, DynProofExpr, TableExpr},
86            proof_plans::DynProofPlan,
87        },
88    };
89    use sqlparser::ast::Ident;
90    use std::hash::BuildHasherDefault;
91
92    const SUM: AggregateFunctionDefinition =
93        AggregateFunctionDefinition::BuiltIn(physical_plan::aggregates::AggregateFunction::Sum);
94    const COUNT: AggregateFunctionDefinition =
95        AggregateFunctionDefinition::BuiltIn(physical_plan::aggregates::AggregateFunction::Count);
96
97    #[expect(non_snake_case)]
98    fn TABLE_LANGUAGES() -> TableRef {
99        TableRef::from_names(None, "languages")
100    }
101
102    #[expect(non_snake_case)]
103    fn SCHEMAS() -> impl SchemaAccessor {
104        let schema: IndexMap<Ident, ColumnType, BuildHasherDefault<AHasher>> = indexmap_with_default! {
105            AHasher;
106            "name".into() => ColumnType::VarChar,
107            "language_family".into() => ColumnType::VarChar,
108            "uses_abjad".into() => ColumnType::Boolean,
109            "num_of_letters".into() => ColumnType::BigInt,
110            "grace".into() => ColumnType::VarChar,
111            "love".into() => ColumnType::VarChar,
112            "joy".into() => ColumnType::VarChar,
113            "peace".into() => ColumnType::VarChar
114        };
115        let table_ref = TableRef::new("", "languages");
116        let schema_accessor = indexmap_with_default! {
117            AHasher;
118            table_ref => schema
119        };
120        TestSchemaAccessor::new(schema_accessor)
121    }
122
123    #[expect(non_snake_case)]
124    fn TABLE_SOURCE() -> Arc<dyn TableSource> {
125        Arc::new(PoSqlTableSource::new(vec![
126            ColumnField::new("name".into(), ColumnType::VarChar),
127            ColumnField::new("language_family".into(), ColumnType::VarChar),
128            ColumnField::new("uses_abjad".into(), ColumnType::Boolean),
129            ColumnField::new("num_of_letters".into(), ColumnType::BigInt),
130            ColumnField::new("grace".into(), ColumnType::VarChar),
131            ColumnField::new("love".into(), ColumnType::VarChar),
132            ColumnField::new("joy".into(), ColumnType::VarChar),
133            ColumnField::new("peace".into(), ColumnType::VarChar),
134        ]))
135    }
136
137    #[expect(non_snake_case)]
138    fn ALIASED_NAME() -> AliasedDynProofExpr {
139        AliasedDynProofExpr {
140            expr: DynProofExpr::new_column(ColumnRef::new(
141                TABLE_LANGUAGES(),
142                "name".into(),
143                ColumnType::VarChar,
144            )),
145            alias: "name".into(),
146        }
147    }
148
149    #[expect(non_snake_case)]
150    fn ALIASED_GRACE() -> AliasedDynProofExpr {
151        AliasedDynProofExpr {
152            expr: DynProofExpr::new_column(ColumnRef::new(
153                TABLE_LANGUAGES(),
154                "grace".into(),
155                ColumnType::VarChar,
156            )),
157            alias: "grace".into(),
158        }
159    }
160
161    #[expect(non_snake_case)]
162    fn ALIASED_LOVE() -> AliasedDynProofExpr {
163        AliasedDynProofExpr {
164            expr: DynProofExpr::new_column(ColumnRef::new(
165                TABLE_LANGUAGES(),
166                "love".into(),
167                ColumnType::VarChar,
168            )),
169            alias: "love".into(),
170        }
171    }
172
173    #[expect(non_snake_case)]
174    fn ALIASED_JOY() -> AliasedDynProofExpr {
175        AliasedDynProofExpr {
176            expr: DynProofExpr::new_column(ColumnRef::new(
177                TABLE_LANGUAGES(),
178                "joy".into(),
179                ColumnType::VarChar,
180            )),
181            alias: "joy".into(),
182        }
183    }
184
185    #[expect(non_snake_case)]
186    fn COUNT_1() -> Expr {
187        Expr::AggregateFunction(AggregateFunction {
188            func_def: COUNT,
189            args: vec![Expr::Literal(ScalarValue::Int64(Some(1)))],
190            distinct: false,
191            filter: None,
192            order_by: None,
193            null_treatment: None,
194        })
195    }
196
197    #[expect(non_snake_case)]
198    fn SUM_NUM_LETTERS() -> Expr {
199        Expr::AggregateFunction(AggregateFunction {
200            func_def: SUM,
201            args: vec![df_column("languages", "num_of_letters")],
202            distinct: false,
203            filter: None,
204            order_by: None,
205            null_treatment: None,
206        })
207    }
208
209    #[expect(non_snake_case)]
210    fn ALIASED_PEACE() -> AliasedDynProofExpr {
211        AliasedDynProofExpr {
212            expr: DynProofExpr::new_column(ColumnRef::new(
213                TABLE_LANGUAGES(),
214                "peace".into(),
215                ColumnType::VarChar,
216            )),
217            alias: "peace".into(),
218        }
219    }
220
221    #[test]
222    fn we_can_convert_logical_plan_to_proof_plan_without_postprocessing() {
223        let plan = LogicalPlan::TableScan(
224            TableScan::try_new(
225                "languages",
226                TABLE_SOURCE(),
227                Some(vec![0, 4, 5, 6, 7]),
228                vec![],
229                None,
230            )
231            .unwrap(),
232        );
233        let schemas = SCHEMAS();
234        let result = logical_plan_to_proof_plan_with_postprocessing(&plan, &schemas).unwrap();
235        let expected = DynProofPlan::new_projection(
236            vec![
237                ALIASED_NAME(),
238                ALIASED_GRACE(),
239                ALIASED_LOVE(),
240                ALIASED_JOY(),
241                ALIASED_PEACE(),
242            ],
243            DynProofPlan::new_table(
244                TABLE_LANGUAGES(),
245                vec![
246                    ColumnField::new("name".into(), ColumnType::VarChar),
247                    ColumnField::new("language_family".into(), ColumnType::VarChar),
248                    ColumnField::new("uses_abjad".into(), ColumnType::Boolean),
249                    ColumnField::new("num_of_letters".into(), ColumnType::BigInt),
250                    ColumnField::new("grace".into(), ColumnType::VarChar),
251                    ColumnField::new("love".into(), ColumnType::VarChar),
252                    ColumnField::new("joy".into(), ColumnType::VarChar),
253                    ColumnField::new("peace".into(), ColumnType::VarChar),
254                ],
255            ),
256        );
257        assert_eq!(result.plan(), &expected);
258        assert!(result.postprocessing().is_none());
259    }
260
261    #[test]
262    fn we_can_convert_logical_plan_to_proof_plan_with_postprocessing() {
263        // Setup group expression
264        let group_expr = vec![df_column("languages", "language_family")];
265
266        // Create the aggregate expressions
267        let aggr_expr = vec![
268            SUM_NUM_LETTERS(), // SUM
269            COUNT_1(),         // COUNT
270        ];
271
272        // Create filters
273        let filter_exprs = vec![
274            df_column("languages", "uses_abjad"), // Boolean column as filter
275        ];
276
277        // Create the input plan with filters
278        let input_plan = LogicalPlan::TableScan(
279            TableScan::try_new(
280                "languages",
281                TABLE_SOURCE(),
282                Some(vec![1, 2, 3]),
283                filter_exprs,
284                None,
285            )
286            .unwrap(),
287        );
288
289        let agg_plan = LogicalPlan::Aggregate(
290            Aggregate::try_new(Arc::new(input_plan), group_expr.clone(), aggr_expr.clone())
291                .unwrap(),
292        );
293
294        let proj_plan = LogicalPlan::Projection(
295            Projection::try_new(
296                vec![
297                    df_column("languages", "language_family"),
298                    Expr::Column(Column::new(
299                        None::<TableReference>,
300                        "COUNT(Int64(1))".to_string(),
301                    ))
302                    .mul(Expr::Literal(ScalarValue::Int64(Some(2))))
303                    .alias("twice_num_languages_using_abjad"),
304                    Expr::Column(Column::new(
305                        None::<TableReference>,
306                        "SUM(languages.num_of_letters)".to_string(),
307                    ))
308                    .alias("sum_num_of_letters"),
309                ],
310                Arc::new(agg_plan),
311            )
312            .unwrap(),
313        );
314
315        // Test the function
316        let result =
317            logical_plan_to_proof_plan_with_postprocessing(&proj_plan, &SCHEMAS()).unwrap();
318
319        // Expected result
320        let expected_plan = DynProofPlan::new_group_by(
321            vec![ColumnExpr::new(ColumnRef::new(
322                TABLE_LANGUAGES(),
323                "language_family".into(),
324                ColumnType::VarChar,
325            ))],
326            vec![AliasedDynProofExpr {
327                expr: DynProofExpr::new_column(ColumnRef::new(
328                    TABLE_LANGUAGES(),
329                    "num_of_letters".into(),
330                    ColumnType::BigInt,
331                )),
332                alias: "SUM(languages.num_of_letters)".into(),
333            }],
334            "COUNT(Int64(1))".into(),
335            TableExpr {
336                table_ref: TABLE_LANGUAGES(),
337            },
338            DynProofExpr::new_column(ColumnRef::new(
339                TABLE_LANGUAGES(),
340                "uses_abjad".into(),
341                ColumnType::Boolean,
342            )),
343        );
344
345        let expected_postprocessing = SelectPostprocessing::new(vec![
346            df_column("languages", "language_family"),
347            Expr::Column(Column::new(
348                None::<TableReference>,
349                "COUNT(Int64(1))".to_string(),
350            ))
351            .mul(Expr::Literal(ScalarValue::Int64(Some(2))))
352            .alias("twice_num_languages_using_abjad"),
353            Expr::Column(Column::new(
354                None::<TableReference>,
355                "SUM(languages.num_of_letters)".to_string(),
356            ))
357            .alias("sum_num_of_letters"),
358        ]);
359
360        assert_eq!(result.plan(), &expected_plan);
361        assert_eq!(result.postprocessing().unwrap(), &expected_postprocessing);
362    }
363
364    // Unsupported
365    #[test]
366    fn we_cannot_convert_unsupported_logical_plan_to_proof_plan_with_postprocessing() {
367        let plan = LogicalPlan::Prepare(Prepare {
368            name: "not_a_real_plan".to_string(),
369            data_types: vec![],
370            input: Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
371                produce_one_row: false,
372                schema: Arc::new(DFSchema::empty()),
373            })),
374        });
375        let schemas = SCHEMAS();
376        assert!(matches!(
377            logical_plan_to_proof_plan_with_postprocessing(&plan, &schemas),
378            Err(PlannerError::UnsupportedLogicalPlan { .. })
379        ));
380    }
381}