proof_of_sql/sql/postprocessing/order_by_postprocessing.rs

1use super::{PostprocessingError, PostprocessingResult, PostprocessingStep};
2use crate::base::{
3    database::{
4        order_by_util::compare_indexes_by_owned_columns_with_direction, OwnedColumn, OwnedTable,
5    },
6    math::permutation::Permutation,
7    scalar::Scalar,
8};
9use alloc::{string::ToString, vec::Vec};
10use proof_of_sql_parser::intermediate_ast::{OrderBy, OrderByDirection};
11use serde::{Deserialize, Serialize};
12
13/// A node representing a list of `OrderBy` expressions.
14#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
15pub struct OrderByPostprocessing {
16    by_exprs: Vec<OrderBy>,
17}
18
19impl OrderByPostprocessing {
20    /// Create a new `OrderByPostprocessing` node.
21    #[must_use]
22    pub fn new(by_exprs: Vec<OrderBy>) -> Self {
23        Self { by_exprs }
24    }
25}
26
27impl<S: Scalar> PostprocessingStep<S> for OrderByPostprocessing {
28    /// Apply the slice transformation to the given `OwnedTable`.
29    fn apply(&self, owned_table: OwnedTable<S>) -> PostprocessingResult<OwnedTable<S>> {
30        // Evaluate the columns by which we order
31        // Once we allow OrderBy for general aggregation-free expressions here we will need to call eval()
32        let order_by_pairs: Vec<(OwnedColumn<S>, OrderByDirection)> = self
33            .by_exprs
34            .iter()
35            .map(
36                |order_by| -> PostprocessingResult<(OwnedColumn<S>, OrderByDirection)> {
37                    let identifier: sqlparser::ast::Ident = order_by.expr.into();
38                    Ok((
39                        owned_table
40                            .inner_table()
41                            .get(&identifier)
42                            .ok_or(PostprocessingError::ColumnNotFound {
43                                column: order_by.expr.to_string(),
44                            })?
45                            .clone(),
46                        order_by.direction,
47                    ))
48                },
49            )
50            .collect::<PostprocessingResult<Vec<(OwnedColumn<S>, OrderByDirection)>>>()?;
51        // Define the ordering
52        let permutation = Permutation::unchecked_new_from_cmp(owned_table.num_rows(), |&a, &b| {
53            compare_indexes_by_owned_columns_with_direction(&order_by_pairs, a, b)
54        });
55        // Apply the ordering
56        Ok(
57            OwnedTable::<S>::try_from_iter(owned_table.into_inner().into_iter().map(
58                |(identifier, column)| {
59                    (
60                        identifier,
61                        column
62                            .try_permute(&permutation)
63                            .expect("There should be no column length mismatch here"),
64                    )
65                },
66            ))
67            .expect("There should be no column length mismatch here"),
68        )
69    }
70}