1use crate::expr::Alias;
21use crate::expr_rewriter::normalize_col;
22use crate::{Cast, Expr, LogicalPlan, TryCast, expr::Sort};
23
24use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
25use datafusion_common::{Column, Result};
26
27pub fn rewrite_sort_cols_by_aggs(
30 sorts: impl IntoIterator<Item = impl Into<Sort>>,
31 plan: &LogicalPlan,
32) -> Result<Vec<Sort>> {
33 sorts
34 .into_iter()
35 .map(|e| {
36 let sort = e.into();
37 Ok(Sort::new(
38 rewrite_sort_col_by_aggs(sort.expr, plan)?,
39 sort.asc,
40 sort.nulls_first,
41 ))
42 })
43 .collect()
44}
45
46fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
47 let plan_inputs = plan.inputs();
48
49 if plan_inputs.len() == 1 {
52 let proj_exprs = plan.expressions();
53 rewrite_in_terms_of_projection(expr, &proj_exprs, plan_inputs[0])
54 } else {
55 Ok(expr)
56 }
57}
58
59fn rewrite_in_terms_of_projection(
71 expr: Expr,
72 proj_exprs: &[Expr],
73 input: &LogicalPlan,
74) -> Result<Expr> {
75 expr.transform(|expr| {
78 if let Some(found) = proj_exprs.iter().find(|a| expr_match(&expr, a)) {
82 let (qualifier, field_name) = found.qualified_name();
83 let col = Expr::Column(Column::new(qualifier, field_name));
84 return Ok(Transformed::yes(col));
85 }
86
87 let normalized_expr = if let Ok(e) = normalize_col(expr.clone(), input) {
93 e
94 } else {
95 return Ok(Transformed::no(expr));
97 };
98
99 let name = normalized_expr.schema_name().to_string();
102
103 let search_col = Expr::Column(Column::new_unqualified(name));
104
105 let found = proj_exprs
110 .iter()
111 .find(|proj_expr| expr_match(&search_col, proj_expr));
112
113 if let Some(found) = found {
114 let (qualifier, field_name) = found.qualified_name();
115 let col = Expr::Column(Column::new(qualifier, field_name));
116 return Ok(Transformed::yes(match normalized_expr {
117 Expr::Cast(Cast { expr: _, field }) => Expr::Cast(Cast {
118 expr: Box::new(col),
119 field,
120 }),
121 Expr::TryCast(TryCast { expr: _, field }) => Expr::TryCast(TryCast {
122 expr: Box::new(col),
123 field,
124 }),
125 _ => col,
126 }));
127 }
128
129 Ok(Transformed::no(expr))
130 })
131 .data()
132}
133
134fn expr_match(needle: &Expr, expr: &Expr) -> bool {
137 if let Expr::Alias(Alias { expr, .. }) = &expr {
139 expr.as_ref() == needle
140 } else {
141 expr == needle
142 }
143}
144
#[cfg(test)]
mod test {
    use std::ops::Add;
    use std::sync::Arc;

    use arrow::datatypes::{DataType, Field, Schema};

    use crate::{
        LogicalPlanBuilder, cast, col, lit, logical_plan::builder::LogicalTableSource,
        try_cast,
    };

    use super::*;
    use crate::test::function_stub::avg;
    use crate::test::function_stub::count;
    use crate::test::function_stub::max;
    use crate::test::function_stub::min;
    use crate::test::function_stub::sum;

    /// Sorting directly above an aggregate: none of these sort expressions
    /// are rewritten.
    #[test]
    fn rewrite_sort_cols_by_agg() {
        let agg = make_input()
            // GROUP BY c1 with aggregate min(c2)
            .aggregate(
                // group by: c1
                vec![col("c1")],
                // aggregate: min(c2)
                vec![min(col("c2"))],
            )
            .unwrap()
            .build()
            .unwrap();

        let cases = vec![
            TestCase {
                desc: "c1 --> c1",
                input: sort(col("c1")),
                expected: sort(col("c1")),
            },
            TestCase {
                desc: "c1 + c2 --> c1 + c2",
                input: sort(col("c1") + col("c2")),
                expected: sort(col("c1") + col("c2")),
            },
            TestCase {
                desc: r#"min(c2) --> "min(c2)"#,
                input: sort(min(col("c2"))),
                expected: sort(min(col("c2"))),
            },
            TestCase {
                desc: r#"c1 + min(c2) --> "c1 + min(c2)"#,
                input: sort(col("c1") + min(col("c2"))),
                expected: sort(col("c1") + min(col("c2"))),
            },
        ];

        for case in cases {
            case.run(&agg)
        }
    }

    /// Sorting above a projection over an aggregate: aggregates are resolved
    /// to the projection's output columns, including through aliases.
    #[test]
    fn rewrite_sort_cols_by_agg_alias() {
        let agg = make_input()
            .aggregate(
                // group by: c1
                vec![col("c1")],
                // aggregates: min(c2), avg(c3)
                vec![min(col("c2")), avg(col("c3"))],
            )
            .unwrap()
            .project(vec![
                // the projected column named `c1` is `c1 + 1`, not `t.c1`
                col("c1").add(lit(1)).alias("c1"),
                // un-aliased aggregate
                min(col("c2")),
                // aliased aggregate
                avg(col("c3")).alias("average"),
            ])
            .unwrap()
            .build()
            .unwrap();

        let cases = vec![
            TestCase {
                desc: "c1 --> c1 -- column *named* c1 that came out of the projection, (not t.c1)",
                input: sort(col("c1")),
                expected: sort(col("c1")),
            },
            TestCase {
                desc: r#"min(c2) --> "min(c2)" -- (column *named* "min(t.c2)"!)"#,
                input: sort(min(col("c2"))),
                expected: sort(Expr::Column(Column::new_unqualified("min(t.c2)"))),
            },
            TestCase {
                desc: r#"c1 + min(c2) --> "c1 + min(c2)" -- (column *named* "min(t.c2)"!)"#,
                input: sort(col("c1") + min(col("c2"))),
                expected: sort(
                    col("c1") + Expr::Column(Column::new_unqualified("min(t.c2)")),
                ),
            },
            TestCase {
                desc: r#"avg(c3) --> "average" (column *named* "average", from alias)"#,
                input: sort(avg(col("c3"))),
                expected: sort(col("average")),
            },
        ];

        for case in cases {
            case.run(&agg)
        }
    }

    /// An aliased aggregate in the projection resolves to a reference to the
    /// alias name.
    #[test]
    fn rewrite_sort_resolves_alias_to_column_ref() {
        let plan = make_input()
            .aggregate(vec![col("c1")], vec![min(col("c2")), max(col("c3"))])
            .unwrap()
            .project(vec![
                col("c1"),
                min(col("c2")).alias("min_val"),
                max(col("c3")).alias("max_val"),
            ])
            .unwrap()
            .build()
            .unwrap();

        let cases = vec![
            TestCase {
                desc: "min(c2) with alias 'min_val' should resolve to col(min_val)",
                input: sort(min(col("c2"))),
                expected: sort(col("min_val")),
            },
            TestCase {
                desc: "max(c3) with alias 'max_val' should resolve to col(max_val)",
                input: sort(max(col("c3"))),
                expected: sort(col("max_val")),
            },
        ];

        for case in cases {
            case.run(&plan)
        }
    }

    /// A projection entry that merely *contains* the sort expression must not
    /// be picked over the entry that *is* the sort expression.
    #[test]
    fn composite_proj_expr_containing_sort_col_as_subexpr() {
        let plan = make_input()
            .aggregate(vec![col("c1")], vec![min(col("c2")), max(col("c3"))])
            .unwrap()
            .project(vec![
                col("c1"),
                (min(col("c2")) + max(col("c3"))).alias("range"),
                min(col("c2")).alias("min_val"),
                max(col("c3")).alias("max_val"),
            ])
            .unwrap()
            .build()
            .unwrap();

        let cases = vec![
            TestCase {
                desc: "sort by min(c2) should resolve to col(min_val), not col(range)",
                input: sort(min(col("c2"))),
                expected: sort(col("min_val")),
            },
            TestCase {
                desc: "sort by max(c3) should resolve to col(max_val), not col(range)",
                input: sort(max(col("c3"))),
                expected: sort(col("max_val")),
            },
        ];

        for case in cases {
            case.run(&plan)
        }
    }

    /// A composite entry listed *before* the standalone aggregate must not
    /// shadow it.
    #[test]
    fn composite_before_standalone_should_not_shadow() {
        let plan = make_input()
            .aggregate(vec![col("c1")], vec![min(col("c2")), max(col("c2"))])
            .unwrap()
            .project(vec![
                col("c1"),
                (min(col("c2")) + max(col("c2"))).alias("combined"),
                min(col("c2")),
            ])
            .unwrap()
            .build()
            .unwrap();

        let cases = vec![TestCase {
            desc: "sort by min(c2) should resolve to col(min(t.c2)), not col(combined)",
            input: sort(min(col("c2"))),
            expected: sort(Expr::Column(Column::new_unqualified("min(t.c2)"))),
        }];

        for case in cases {
            case.run(&plan)
        }
    }

    /// When the same aggregate is projected twice, the first entry wins.
    #[test]
    fn duplicate_aggregate_in_multiple_proj_exprs() {
        let plan = make_input()
            .aggregate(vec![col("c1")], vec![min(col("c2"))])
            .unwrap()
            .project(vec![
                col("c1"),
                min(col("c2")).alias("first_alias"),
                min(col("c2")).alias("second_alias"),
            ])
            .unwrap()
            .build()
            .unwrap();

        let cases = vec![TestCase {
            desc: "sort by min(c2) with two aliases picks first_alias",
            input: sort(min(col("c2"))),
            expected: sort(col("first_alias")),
        }];

        for case in cases {
            case.run(&plan)
        }
    }

    /// An aggregate computed below but not projected is left untouched.
    #[test]
    fn sort_agg_not_in_select_with_aliased_aggs() {
        let plan = make_input()
            .aggregate(
                vec![col("c1")],
                vec![min(col("c2")), max(col("c3")), sum(col("c3"))],
            )
            .unwrap()
            .project(vec![
                col("c1"),
                min(col("c2")).alias("min_val"),
                max(col("c3")).alias("max_val"),
            ])
            .unwrap()
            .build()
            .unwrap();

        let cases = vec![TestCase {
            desc: "sort by sum(c3) not in projection should not be rewritten",
            input: sort(sum(col("c3"))),
            expected: sort(sum(col("c3"))),
        }];

        for case in cases {
            case.run(&plan)
        }
    }

    /// A CAST / TRY_CAST wrapped around an aliased aggregate keeps the cast
    /// and resolves the aggregate to its alias.
    #[test]
    fn cast_on_aliased_aggregate() {
        let plan = make_input()
            .aggregate(vec![col("c1")], vec![min(col("c2"))])
            .unwrap()
            .project(vec![col("c1"), min(col("c2")).alias("min_val")])
            .unwrap()
            .build()
            .unwrap();

        let cases = vec![
            TestCase {
                desc: "CAST on aliased aggregate should preserve cast and resolve alias",
                input: sort(cast(min(col("c2")), DataType::Int64)),
                expected: sort(cast(col("min_val"), DataType::Int64)),
            },
            TestCase {
                desc: "TryCast on aliased aggregate should preserve try_cast and resolve alias",
                input: sort(try_cast(min(col("c2")), DataType::Int64)),
                expected: sort(try_cast(col("min_val"), DataType::Int64)),
            },
        ];

        for case in cases {
            case.run(&plan)
        }
    }

    /// `count(1)` resolves to its alias like any other aggregate.
    #[test]
    fn count_star_with_alias() {
        let plan = make_input()
            .aggregate(vec![col("c1")], vec![count(lit(1))])
            .unwrap()
            .project(vec![col("c1"), count(lit(1)).alias("cnt")])
            .unwrap()
            .build()
            .unwrap();

        let cases = vec![TestCase {
            desc: "sort by count(1) should resolve to cnt alias",
            input: sort(count(lit(1))),
            expected: sort(col("cnt")),
        }];

        for case in cases {
            case.run(&plan)
        }
    }

    /// Casts on plain (aliased) columns survive the rewrite.
    #[test]
    fn preserve_cast() {
        let plan = make_input()
            .project(vec![col("c2").alias("c2")])
            .unwrap()
            .project(vec![col("c2").alias("c2")])
            .unwrap()
            .build()
            .unwrap();

        let cases = vec![
            TestCase {
                desc: "Cast is preserved by rewrite_sort_cols_by_aggs",
                input: sort(cast(col("c2"), DataType::Int64)),
                expected: sort(cast(col("c2"), DataType::Int64)),
            },
            TestCase {
                desc: "TryCast is preserved by rewrite_sort_cols_by_aggs",
                input: sort(try_cast(col("c2"), DataType::Int64)),
                expected: sort(try_cast(col("c2"), DataType::Int64)),
            },
        ];

        for case in cases {
            case.run(&plan)
        }
    }

    /// One rewrite scenario: `input` rewritten against a plan must equal
    /// `expected`. `desc` identifies the scenario in output and failures.
    struct TestCase {
        desc: &'static str,
        input: Sort,
        expected: Sort,
    }

    impl TestCase {
        /// Rewrites `self.input` against `input_plan` and asserts that the
        /// result equals `self.expected`.
        fn run(self, input_plan: &LogicalPlan) {
            let Self {
                desc,
                input,
                expected,
            } = self;

            println!("running: '{desc}'");
            let mut exprs =
                rewrite_sort_cols_by_aggs(vec![input.clone()], input_plan).unwrap();

            assert_eq!(exprs.len(), 1, "{desc}: expected exactly one rewritten sort");
            let rewritten = exprs.pop().unwrap();

            assert_eq!(
                rewritten, expected,
                "\n\ndesc:{desc}\ninput:{input:?}\nrewritten:{rewritten:?}\nexpected:{expected:?}\n"
            );
        }
    }

    /// Builder scanning table `t` with nullable columns
    /// `c1: Int32`, `c2: Utf8`, `c3: Float64`.
    fn make_input() -> LogicalPlanBuilder {
        let schema = Arc::new(Schema::new(vec![
            Field::new("c1", DataType::Int32, true),
            Field::new("c2", DataType::Utf8, true),
            Field::new("c3", DataType::Float64, true),
        ]));
        let projection = None;
        LogicalPlanBuilder::scan(
            "t",
            Arc::new(LogicalTableSource::new(schema)),
            projection,
        )
        .unwrap()
    }

    /// Ascending, nulls-first sort on `expr`.
    fn sort(expr: Expr) -> Sort {
        let asc = true;
        let nulls_first = true;
        expr.sort(asc, nulls_first)
    }
}