1use super::{
2 logical_plan_to_proof_plan, postprocessing::SelectPostprocessing, PlannerError, PlannerResult,
3};
4use datafusion::logical_expr::{LogicalPlan, Projection};
5use proof_of_sql::{base::database::SchemaAccessor, sql::proof_plans::DynProofPlan};
6
7#[derive(Debug, Clone)]
9pub struct ProofPlanWithPostprocessing {
10 plan: DynProofPlan,
11 postprocessing: Option<SelectPostprocessing>,
12}
13
14impl ProofPlanWithPostprocessing {
15 #[must_use]
17 pub fn new(plan: DynProofPlan, postprocessing: Option<SelectPostprocessing>) -> Self {
18 Self {
19 plan,
20 postprocessing,
21 }
22 }
23
24 #[must_use]
26 pub fn plan(&self) -> &DynProofPlan {
27 &self.plan
28 }
29
30 #[must_use]
32 pub fn postprocessing(&self) -> Option<&SelectPostprocessing> {
33 self.postprocessing.as_ref()
34 }
35}
36
37pub fn logical_plan_to_proof_plan_with_postprocessing(
39 plan: &LogicalPlan,
40 schemas: &impl SchemaAccessor,
41) -> PlannerResult<ProofPlanWithPostprocessing> {
42 let result_proof_plan = logical_plan_to_proof_plan(plan, schemas);
43 match result_proof_plan {
44 Ok(proof_plan) => Ok(ProofPlanWithPostprocessing::new(proof_plan, None)),
45 Err(_err) => {
46 match plan {
47 LogicalPlan::Projection(Projection { input, expr, .. }) => {
49 let input_proof_plan = logical_plan_to_proof_plan(input, schemas)?;
51 let postprocessing = SelectPostprocessing::new(expr.clone());
52 Ok(ProofPlanWithPostprocessing::new(
53 input_proof_plan,
54 Some(postprocessing),
55 ))
56 }
57 _ => Err(PlannerError::UnsupportedLogicalPlan {
58 plan: Box::new(plan.clone()),
59 }),
60 }
61 }
62 }
63}
64
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{df_util::*, PoSqlTableSource};
    use ahash::AHasher;
    use alloc::sync::Arc;
    use core::ops::Mul;
    use datafusion::{
        common::{Column, DFSchema, ScalarValue},
        logical_expr::{
            expr::{AggregateFunction, AggregateFunctionDefinition},
            Aggregate, EmptyRelation, Expr, LogicalPlan, Prepare, TableScan, TableSource,
        },
        physical_plan,
        sql::TableReference,
    };
    use indexmap::{indexmap_with_default, IndexMap};
    use proof_of_sql::{
        base::database::{ColumnField, ColumnRef, ColumnType, TableRef, TestSchemaAccessor},
        sql::{
            proof_exprs::{AliasedDynProofExpr, ColumnExpr, DynProofExpr, TableExpr},
            proof_plans::DynProofPlan,
        },
    };
    use sqlparser::ast::Ident;
    use std::hash::BuildHasherDefault;

    const SUM: AggregateFunctionDefinition =
        AggregateFunctionDefinition::BuiltIn(physical_plan::aggregates::AggregateFunction::Sum);
    const COUNT: AggregateFunctionDefinition =
        AggregateFunctionDefinition::BuiltIn(physical_plan::aggregates::AggregateFunction::Count);

    /// Reference to the `languages` table used throughout these tests.
    #[expect(non_snake_case)]
    fn TABLE_LANGUAGES() -> TableRef {
        TableRef::from_names(None, "languages")
    }

    /// All columns of the `languages` table, in declaration order.
    #[expect(non_snake_case)]
    fn LANGUAGES_COLUMN_FIELDS() -> Vec<ColumnField> {
        vec![
            ColumnField::new("name".into(), ColumnType::VarChar),
            ColumnField::new("language_family".into(), ColumnType::VarChar),
            ColumnField::new("uses_abjad".into(), ColumnType::Boolean),
            ColumnField::new("num_of_letters".into(), ColumnType::BigInt),
            ColumnField::new("grace".into(), ColumnType::VarChar),
            ColumnField::new("love".into(), ColumnType::VarChar),
            ColumnField::new("joy".into(), ColumnType::VarChar),
            ColumnField::new("peace".into(), ColumnType::VarChar),
        ]
    }

    /// Schema accessor that knows only the `languages` table.
    #[expect(non_snake_case)]
    fn SCHEMAS() -> impl SchemaAccessor {
        let schema: IndexMap<Ident, ColumnType, BuildHasherDefault<AHasher>> = indexmap_with_default! {
            AHasher;
            "name".into() => ColumnType::VarChar,
            "language_family".into() => ColumnType::VarChar,
            "uses_abjad".into() => ColumnType::Boolean,
            "num_of_letters".into() => ColumnType::BigInt,
            "grace".into() => ColumnType::VarChar,
            "love".into() => ColumnType::VarChar,
            "joy".into() => ColumnType::VarChar,
            "peace".into() => ColumnType::VarChar
        };
        // Key by the same table reference the expected plans use, so the fixtures agree.
        let schema_accessor = indexmap_with_default! {
            AHasher;
            TABLE_LANGUAGES() => schema
        };
        TestSchemaAccessor::new(schema_accessor)
    }

    /// DataFusion table source exposing the `languages` columns.
    #[expect(non_snake_case)]
    fn TABLE_SOURCE() -> Arc<dyn TableSource> {
        Arc::new(PoSqlTableSource::new(LANGUAGES_COLUMN_FIELDS()))
    }

    /// A `VarChar` column of `languages`, aliased to its own name.
    fn aliased_varchar_column(name: &str) -> AliasedDynProofExpr {
        AliasedDynProofExpr {
            expr: DynProofExpr::new_column(ColumnRef::new(
                TABLE_LANGUAGES(),
                name.into(),
                ColumnType::VarChar,
            )),
            alias: name.into(),
        }
    }

    /// `COUNT(1)` as a DataFusion aggregate expression.
    #[expect(non_snake_case)]
    fn COUNT_1() -> Expr {
        Expr::AggregateFunction(AggregateFunction {
            func_def: COUNT,
            args: vec![Expr::Literal(ScalarValue::Int64(Some(1)))],
            distinct: false,
            filter: None,
            order_by: None,
            null_treatment: None,
        })
    }

    /// `SUM(languages.num_of_letters)` as a DataFusion aggregate expression.
    #[expect(non_snake_case)]
    fn SUM_NUM_LETTERS() -> Expr {
        Expr::AggregateFunction(AggregateFunction {
            func_def: SUM,
            args: vec![df_column("languages", "num_of_letters")],
            distinct: false,
            filter: None,
            order_by: None,
            null_treatment: None,
        })
    }

    #[test]
    fn we_can_convert_logical_plan_to_proof_plan_without_postprocessing() {
        // Projects `name`, `grace`, `love`, `joy` and `peace` (indices 0, 4, 5, 6, 7).
        let plan = LogicalPlan::TableScan(
            TableScan::try_new(
                "languages",
                TABLE_SOURCE(),
                Some(vec![0, 4, 5, 6, 7]),
                vec![],
                None,
            )
            .unwrap(),
        );
        let schemas = SCHEMAS();
        let result = logical_plan_to_proof_plan_with_postprocessing(&plan, &schemas).unwrap();
        let expected = DynProofPlan::new_projection(
            ["name", "grace", "love", "joy", "peace"]
                .into_iter()
                .map(aliased_varchar_column)
                .collect::<Vec<_>>(),
            DynProofPlan::new_table(TABLE_LANGUAGES(), LANGUAGES_COLUMN_FIELDS()),
        );
        assert_eq!(result.plan(), &expected);
        assert!(result.postprocessing().is_none());
    }

    #[test]
    fn we_can_convert_logical_plan_to_proof_plan_with_postprocessing() {
        // SELECT language_family, COUNT(1) * 2, SUM(num_of_letters)
        // FROM languages WHERE uses_abjad GROUP BY language_family
        let group_expr = vec![df_column("languages", "language_family")];
        let aggr_expr = vec![SUM_NUM_LETTERS(), COUNT_1()];
        let filter_exprs = vec![df_column("languages", "uses_abjad")];

        let input_plan = LogicalPlan::TableScan(
            TableScan::try_new(
                "languages",
                TABLE_SOURCE(),
                Some(vec![1, 2, 3]),
                filter_exprs,
                None,
            )
            .unwrap(),
        );
        let agg_plan = LogicalPlan::Aggregate(
            Aggregate::try_new(Arc::new(input_plan), group_expr, aggr_expr).unwrap(),
        );

        // The aggregate outputs are referenced by DataFusion's generated column names.
        let projection_exprs = vec![
            df_column("languages", "language_family"),
            Expr::Column(Column::new(
                None::<TableReference>,
                "COUNT(Int64(1))".to_string(),
            ))
            .mul(Expr::Literal(ScalarValue::Int64(Some(2))))
            .alias("twice_num_languages_using_abjad"),
            Expr::Column(Column::new(
                None::<TableReference>,
                "SUM(languages.num_of_letters)".to_string(),
            ))
            .alias("sum_num_of_letters"),
        ];
        let proj_plan = LogicalPlan::Projection(
            Projection::try_new(projection_exprs.clone(), Arc::new(agg_plan)).unwrap(),
        );

        let result =
            logical_plan_to_proof_plan_with_postprocessing(&proj_plan, &SCHEMAS()).unwrap();

        // The projection is not proven directly: its aggregate input is proven instead...
        let expected_plan = DynProofPlan::new_group_by(
            vec![ColumnExpr::new(ColumnRef::new(
                TABLE_LANGUAGES(),
                "language_family".into(),
                ColumnType::VarChar,
            ))],
            vec![AliasedDynProofExpr {
                expr: DynProofExpr::new_column(ColumnRef::new(
                    TABLE_LANGUAGES(),
                    "num_of_letters".into(),
                    ColumnType::BigInt,
                )),
                alias: "SUM(languages.num_of_letters)".into(),
            }],
            "COUNT(Int64(1))".into(),
            TableExpr {
                table_ref: TABLE_LANGUAGES(),
            },
            DynProofExpr::new_column(ColumnRef::new(
                TABLE_LANGUAGES(),
                "uses_abjad".into(),
                ColumnType::Boolean,
            )),
        );
        // ...and the projection expressions are carried verbatim as postprocessing.
        let expected_postprocessing = SelectPostprocessing::new(projection_exprs);

        assert_eq!(result.plan(), &expected_plan);
        assert_eq!(result.postprocessing().unwrap(), &expected_postprocessing);
    }

    #[test]
    fn we_cannot_convert_unsupported_logical_plan_to_proof_plan_with_postprocessing() {
        let plan = LogicalPlan::Prepare(Prepare {
            name: "not_a_real_plan".to_string(),
            data_types: vec![],
            input: Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
                produce_one_row: false,
                schema: Arc::new(DFSchema::empty()),
            })),
        });
        let schemas = SCHEMAS();
        assert!(matches!(
            logical_plan_to_proof_plan_with_postprocessing(&plan, &schemas),
            Err(PlannerError::UnsupportedLogicalPlan { .. })
        ));
    }
}