proof_of_sql/sql/postprocessing/
order_by_postprocessing.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
use super::{PostprocessingError, PostprocessingResult, PostprocessingStep};
use crate::base::{
    database::{compare_indexes_by_owned_columns_with_direction, OwnedColumn, OwnedTable},
    if_rayon,
    math::permutation::Permutation,
    scalar::Scalar,
};
use alloc::{string::ToString, vec::Vec};
use proof_of_sql_parser::intermediate_ast::{OrderBy, OrderByDirection};
#[cfg(feature = "rayon")]
use rayon::prelude::ParallelSliceMut;
use serde::{Deserialize, Serialize};

/// A node representing a list of `OrderBy` expressions.
///
/// When applied as a postprocessing step, the rows of the input table are
/// reordered using the column named by each `OrderBy`'s `expr` together with
/// that entry's `direction`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct OrderByPostprocessing {
    /// The ordering expressions, kept in the order they were supplied to `new`.
    by_exprs: Vec<OrderBy>,
}

impl OrderByPostprocessing {
    /// Create a new `OrderByPostprocessing` node.
    #[must_use]
    pub fn new(by_exprs: Vec<OrderBy>) -> Self {
        Self { by_exprs }
    }
}

impl<S: Scalar> PostprocessingStep<S> for OrderByPostprocessing {
    /// Apply the `ORDER BY` transformation to the given `OwnedTable`.
    ///
    /// Every column of the table is permuted with the same row permutation, which
    /// is obtained by sorting the row indexes with
    /// `compare_indexes_by_owned_columns_with_direction` over the columns named in
    /// `by_exprs`. The sort is unstable, so rows that compare equal may end up in
    /// any relative order.
    ///
    /// # Errors
    ///
    /// Returns [`PostprocessingError::ColumnNotFound`] if any `OrderBy` expression
    /// names a column that is not present in `owned_table`.
    ///
    /// # Panics
    ///
    /// Panics if permuting a column or rebuilding the table reports a length
    /// mismatch. The permutation is built from `0..owned_table.num_rows()`, so this
    /// is treated as an invariant violation rather than a recoverable error.
    fn apply(&self, owned_table: OwnedTable<S>) -> PostprocessingResult<OwnedTable<S>> {
        // Evaluate the columns by which we order.
        // Once we allow OrderBy for general aggregation-free expressions here we will need to call eval()
        let order_by_pairs: Vec<(OwnedColumn<S>, OrderByDirection)> = self
            .by_exprs
            .iter()
            .map(|order_by| {
                owned_table
                    .inner_table()
                    .get(&order_by.expr)
                    // The column is cloned because the table itself is consumed below.
                    .map(|column| (column.clone(), order_by.direction))
                    // Build the error lazily so the column name is only stringified
                    // (and allocated) when the lookup actually fails.
                    .ok_or_else(|| PostprocessingError::ColumnNotFound {
                        column: order_by.expr.to_string(),
                    })
            })
            .collect::<PostprocessingResult<Vec<(OwnedColumn<S>, OrderByDirection)>>>()?;

        // Define the ordering. The index vector is only allocated once all order-by
        // columns are known to exist.
        let mut indexes = (0..owned_table.num_rows()).collect::<Vec<_>>();
        if_rayon!(
            indexes.par_sort_unstable_by(|&a, &b| {
                compare_indexes_by_owned_columns_with_direction(&order_by_pairs, a, b)
            }),
            indexes.sort_unstable_by(|&a, &b| {
                compare_indexes_by_owned_columns_with_direction(&order_by_pairs, a, b)
            })
        );
        // Sorting only rearranges the entries of `0..num_rows`, so `indexes` is a
        // permutation by construction and the unchecked constructor is used.
        let permutation = Permutation::unchecked_new(indexes);

        // Apply the ordering to every column of the table.
        let permuted_columns = owned_table
            .into_inner()
            .into_iter()
            .map(|(identifier, column)| {
                (
                    identifier,
                    column
                        .try_permute(&permutation)
                        .expect("There should be no column length mismatch here"),
                )
            });
        Ok(OwnedTable::<S>::try_from_iter(permuted_columns)
            .expect("There should be no column length mismatch here"))
    }
}