1use super::{
2 logical_plan_to_proof_plan, postprocessing::SelectPostprocessing, PlannerError, PlannerResult,
3};
4use datafusion::logical_expr::{LogicalPlan, Projection};
5use proof_of_sql::{base::database::SchemaAccessor, sql::proof_plans::DynProofPlan};
6
7#[derive(Debug, Clone)]
9pub struct ProofPlanWithPostprocessing {
10 plan: DynProofPlan,
11 postprocessing: Option<SelectPostprocessing>,
12}
13
14impl ProofPlanWithPostprocessing {
15 #[must_use]
17 pub fn new(plan: DynProofPlan, postprocessing: Option<SelectPostprocessing>) -> Self {
18 Self {
19 plan,
20 postprocessing,
21 }
22 }
23
24 #[must_use]
26 pub fn plan(&self) -> &DynProofPlan {
27 &self.plan
28 }
29
30 #[must_use]
32 pub fn postprocessing(&self) -> Option<&SelectPostprocessing> {
33 self.postprocessing.as_ref()
34 }
35}
36
37pub fn logical_plan_to_proof_plan_with_postprocessing(
39 plan: &LogicalPlan,
40 schemas: &impl SchemaAccessor,
41) -> PlannerResult<ProofPlanWithPostprocessing> {
42 let result_proof_plan = logical_plan_to_proof_plan(plan, schemas);
43 match result_proof_plan {
44 Ok(proof_plan) => Ok(ProofPlanWithPostprocessing::new(proof_plan, None)),
45 Err(_err) => {
46 match plan {
47 LogicalPlan::Projection(Projection { input, expr, .. }) => {
49 let input_proof_plan = logical_plan_to_proof_plan(input, schemas)?;
51 let postprocessing = SelectPostprocessing::new(expr.clone());
52 Ok(ProofPlanWithPostprocessing::new(
53 input_proof_plan,
54 Some(postprocessing),
55 ))
56 }
57 _ => Err(PlannerError::UnsupportedLogicalPlan { plan: plan.clone() }),
58 }
59 }
60 }
61}
62
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{df_util::*, PoSqlTableSource};
    use ahash::AHasher;
    use alloc::sync::Arc;
    use core::ops::Mul;
    use datafusion::{
        common::{Column, DFSchema, ScalarValue},
        logical_expr::{
            expr::{AggregateFunction, AggregateFunctionDefinition},
            Aggregate, EmptyRelation, Expr, LogicalPlan, Prepare, TableScan, TableSource,
        },
        physical_plan,
        sql::TableReference,
    };
    use indexmap::{indexmap_with_default, IndexMap};
    use proof_of_sql::{
        base::database::{ColumnField, ColumnRef, ColumnType, TableRef, TestSchemaAccessor},
        sql::{
            proof_exprs::{AliasedDynProofExpr, ColumnExpr, DynProofExpr, TableExpr},
            proof_plans::DynProofPlan,
        },
    };
    use sqlparser::ast::Ident;
    use std::hash::BuildHasherDefault;

    /// Built-in `SUM` aggregate function definition.
    const SUM: AggregateFunctionDefinition =
        AggregateFunctionDefinition::BuiltIn(physical_plan::aggregates::AggregateFunction::Sum);
    /// Built-in `COUNT` aggregate function definition.
    const COUNT: AggregateFunctionDefinition =
        AggregateFunctionDefinition::BuiltIn(physical_plan::aggregates::AggregateFunction::Count);

    /// Name and type of every column of the `languages` test table, in declaration order.
    fn language_columns() -> [(&'static str, ColumnType); 8] {
        [
            ("name", ColumnType::VarChar),
            ("language_family", ColumnType::VarChar),
            ("uses_abjad", ColumnType::Boolean),
            ("num_of_letters", ColumnType::BigInt),
            ("grace", ColumnType::VarChar),
            ("love", ColumnType::VarChar),
            ("joy", ColumnType::VarChar),
            ("peace", ColumnType::VarChar),
        ]
    }

    /// The `languages` columns as a list of [`ColumnField`]s.
    fn language_fields() -> Vec<ColumnField> {
        language_columns()
            .into_iter()
            .map(|(name, column_type)| ColumnField::new(name.into(), column_type))
            .collect()
    }

    /// A [`ColumnRef`] to the column `name` of the `languages` table.
    fn languages_column_ref(name: &str, column_type: ColumnType) -> ColumnRef {
        ColumnRef::new(TABLE_LANGUAGES(), name.into(), column_type)
    }

    /// A `VarChar` column of the `languages` table, aliased to its own name.
    fn aliased_varchar_column(name: &str) -> AliasedDynProofExpr {
        AliasedDynProofExpr {
            expr: DynProofExpr::new_column(languages_column_ref(name, ColumnType::VarChar)),
            alias: name.into(),
        }
    }

    /// A DataFusion column expression without a table qualifier.
    fn unqualified_column(name: &str) -> Expr {
        Expr::Column(Column::new(None::<TableReference>, name.to_string()))
    }

    /// A non-distinct aggregate call on a single argument, with no filter,
    /// ordering or null treatment.
    fn aggregate_call(func_def: AggregateFunctionDefinition, arg: Expr) -> Expr {
        Expr::AggregateFunction(AggregateFunction {
            func_def,
            args: vec![arg],
            distinct: false,
            filter: None,
            order_by: None,
            null_treatment: None,
        })
    }

    /// The projection expressions used by the postprocessing test, both as the
    /// input projection and as the expected postprocessing payload.
    fn projection_exprs() -> Vec<Expr> {
        let twice_count = unqualified_column("COUNT(Int64(1))")
            .mul(Expr::Literal(ScalarValue::Int64(Some(2))))
            .alias("twice_num_languages_using_abjad");
        let sum_letters =
            unqualified_column("SUM(languages.num_of_letters)").alias("sum_num_of_letters");
        vec![
            df_column("languages", "language_family"),
            twice_count,
            sum_letters,
        ]
    }

    /// Reference to the schema-less `languages` table.
    #[expect(non_snake_case)]
    fn TABLE_LANGUAGES() -> TableRef {
        TableRef::from_names(None, "languages")
    }

    /// Schema accessor that knows only the `languages` table.
    #[expect(non_snake_case)]
    fn SCHEMAS() -> impl SchemaAccessor {
        let columns: IndexMap<Ident, ColumnType, BuildHasherDefault<AHasher>> = language_columns()
            .into_iter()
            .map(|(name, column_type)| (Ident::from(name), column_type))
            .collect();
        let tables = indexmap_with_default! {
            AHasher;
            TableRef::new("", "languages") => columns
        };
        TestSchemaAccessor::new(tables)
    }

    /// DataFusion table source exposing the `languages` columns.
    #[expect(non_snake_case)]
    fn TABLE_SOURCE() -> Arc<dyn TableSource> {
        Arc::new(PoSqlTableSource::new(language_fields()))
    }

    /// `COUNT` applied to the `Int64` literal `1`.
    #[expect(non_snake_case)]
    fn COUNT_1() -> Expr {
        aggregate_call(COUNT, Expr::Literal(ScalarValue::Int64(Some(1))))
    }

    /// `SUM` applied to `languages.num_of_letters`.
    #[expect(non_snake_case)]
    fn SUM_NUM_LETTERS() -> Expr {
        aggregate_call(SUM, df_column("languages", "num_of_letters"))
    }

    #[test]
    fn we_can_convert_logical_plan_to_proof_plan_without_postprocessing() {
        // A bare table scan projecting columns 0, 4, 5, 6 and 7.
        let scan = LogicalPlan::TableScan(
            TableScan::try_new(
                "languages",
                TABLE_SOURCE(),
                Some(vec![0, 4, 5, 6, 7]),
                vec![],
                None,
            )
            .unwrap(),
        );

        let result = logical_plan_to_proof_plan_with_postprocessing(&scan, &SCHEMAS()).unwrap();

        let expected = DynProofPlan::new_projection(
            ["name", "grace", "love", "joy", "peace"]
                .into_iter()
                .map(aliased_varchar_column)
                .collect(),
            DynProofPlan::new_table(TABLE_LANGUAGES(), language_fields()),
        );
        assert_eq!(result.plan(), &expected);
        assert!(result.postprocessing().is_none());
    }

    #[test]
    fn we_can_convert_logical_plan_to_proof_plan_with_postprocessing() {
        // Filtered scan -> aggregate grouped by `language_family` -> projection.
        let scan = LogicalPlan::TableScan(
            TableScan::try_new(
                "languages",
                TABLE_SOURCE(),
                Some(vec![1, 2, 3]),
                vec![df_column("languages", "uses_abjad")],
                None,
            )
            .unwrap(),
        );
        let aggregate = LogicalPlan::Aggregate(
            Aggregate::try_new(
                Arc::new(scan),
                vec![df_column("languages", "language_family")],
                vec![SUM_NUM_LETTERS(), COUNT_1()],
            )
            .unwrap(),
        );
        let projection = LogicalPlan::Projection(
            Projection::try_new(projection_exprs(), Arc::new(aggregate)).unwrap(),
        );

        let result =
            logical_plan_to_proof_plan_with_postprocessing(&projection, &SCHEMAS()).unwrap();

        // The proof plan is expected to cover the aggregate below the projection.
        let expected_plan = DynProofPlan::new_group_by(
            vec![ColumnExpr::new(languages_column_ref(
                "language_family",
                ColumnType::VarChar,
            ))],
            vec![AliasedDynProofExpr {
                expr: DynProofExpr::new_column(languages_column_ref(
                    "num_of_letters",
                    ColumnType::BigInt,
                )),
                alias: "SUM(languages.num_of_letters)".into(),
            }],
            "COUNT(Int64(1))".into(),
            TableExpr {
                table_ref: TABLE_LANGUAGES(),
            },
            DynProofExpr::new_column(languages_column_ref("uses_abjad", ColumnType::Boolean)),
        );
        // The projection's own expressions are expected as the postprocessing step.
        let expected_postprocessing = SelectPostprocessing::new(projection_exprs());

        assert_eq!(result.plan(), &expected_plan);
        assert_eq!(result.postprocessing().unwrap(), &expected_postprocessing);
    }

    #[test]
    fn we_cannot_convert_unsupported_logical_plan_to_proof_plan_with_postprocessing() {
        let empty_input = LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: Arc::new(DFSchema::empty()),
        });
        let prepare = LogicalPlan::Prepare(Prepare {
            name: "not_a_real_plan".to_string(),
            data_types: vec![],
            input: Arc::new(empty_input),
        });

        let outcome = logical_plan_to_proof_plan_with_postprocessing(&prepare, &SCHEMAS());

        assert!(matches!(
            outcome,
            Err(PlannerError::UnsupportedLogicalPlan { .. })
        ));
    }
}