use super::{PostprocessingError, PostprocessingResult, PostprocessingStep};
use crate::base::{
    database::{group_by_util::aggregate_columns, Column, OwnedColumn, OwnedTable},
    map::{indexmap, IndexMap, IndexSet},
    scalar::Scalar,
};
use alloc::{boxed::Box, format, string::ToString, vec, vec::Vec};
use bumpalo::Bump;
use itertools::{izip, Itertools};
use proof_of_sql_parser::{
    intermediate_ast::{AggregationOperator, AliasedResultExpr, Expression},
    Identifier,
};
use serde::{Deserialize, Serialize};
/// A postprocessing step that applies SQL `GROUP BY` semantics to an [`OwnedTable`].
///
/// Built by [`GroupByPostprocessing::try_new`], which splits every selected expression into the
/// aggregations it contains (computed once per group) and a "remainder" expression that is
/// evaluated on the grouped table to produce the final output column.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct GroupByPostprocessing {
    /// The selected expressions with every aggregation replaced by a reference to its generated
    /// `__col_agg_N` column; evaluated against the grouped table to produce the result.
    remainder_exprs: Vec<AliasedResultExpr>,
    /// The `GROUP BY` column identifiers, with duplicates from the input removed.
    group_by_identifiers: Vec<Identifier>,
    /// Each distinct aggregation: its operator, its argument expression, and the generated
    /// identifier of the column that will hold its per-group result.
    aggregation_exprs: Vec<(AggregationOperator, Expression, Identifier)>,
}
/// Returns `true` if `expr` contains an aggregation nested inside another aggregation.
///
/// `is_agg` states whether `expr` already sits inside an aggregation. Callers checking a
/// top-level expression pass `false`. When `is_agg` is `true`, any aggregation at all counts as
/// nested.
fn contains_nested_aggregation(expr: &Expression, is_agg: bool) -> bool {
    match expr {
        // Already inside an aggregation: meeting another one means nesting.
        Expression::Aggregation { .. } if is_agg => true,
        // First aggregation on this path: everything below it is "inside" an aggregation.
        Expression::Aggregation { expr: argument, .. } => {
            contains_nested_aggregation(argument, true)
        }
        Expression::Binary { left, right, .. } => [left, right]
            .iter()
            .any(|operand| contains_nested_aggregation(operand, is_agg)),
        Expression::Unary { expr: operand, .. } => contains_nested_aggregation(operand, is_agg),
        Expression::Column(_) | Expression::Literal(_) | Expression::Wildcard => false,
    }
}
/// Collects the column identifiers referenced by `expr` outside of any aggregation.
///
/// Identifiers are recorded in the order they are first met in a left-to-right traversal.
/// Anything underneath an aggregation is ignored, because those columns are consumed by the
/// aggregation rather than referenced directly.
fn get_free_identifiers_from_expr(expr: &Expression) -> IndexSet<Identifier> {
    // Depth-first walk that inserts identifiers into `found`; re-inserting an identifier that is
    // already present leaves its original position untouched.
    fn collect_into(expr: &Expression, found: &mut IndexSet<Identifier>) {
        match expr {
            Expression::Column(identifier) => {
                found.insert(*identifier);
            }
            Expression::Binary { left, right, .. } => {
                collect_into(left, found);
                collect_into(right, found);
            }
            Expression::Unary { expr: operand, .. } => collect_into(operand, found),
            Expression::Literal(_) | Expression::Aggregation { .. } | Expression::Wildcard => {}
        }
    }
    let mut found = IndexSet::default();
    collect_into(expr, &mut found);
    found
}
/// Rewrites `expr` so that every aggregation is replaced by a reference to a generated column.
///
/// Each distinct `(operator, argument)` pair is recorded in `aggregation_expr_map` under a fresh
/// identifier `__col_agg_N`, where `N` is the map's size at insertion time. An aggregation that is
/// already in the map reuses its existing identifier.
///
/// Columns, literals and wildcards are returned unchanged. Binary and unary expressions are
/// rebuilt from their rewritten operands, visiting the left operand first, so identifiers are
/// assigned in left-to-right order.
fn get_aggregate_and_remainder_expressions(
    expr: Expression,
    aggregation_expr_map: &mut IndexMap<(AggregationOperator, Expression), Identifier>,
) -> Expression {
    match expr {
        Expression::Column(_) | Expression::Literal(_) | Expression::Wildcard => expr,
        Expression::Aggregation { op, expr } => {
            // Read the size before `entry` borrows the map; it becomes the suffix of a newly
            // generated identifier. A single `entry` lookup replaces the former
            // contains_key/get/insert sequence, which hashed the (recursive) key repeatedly.
            let next_index = aggregation_expr_map.len();
            let column_id = *aggregation_expr_map
                .entry((op, *expr))
                .or_insert_with(|| {
                    format!("__col_agg_{next_index}")
                        .parse()
                        .expect("generated aggregation column names are valid identifiers")
                });
            Expression::Column(column_id)
        }
        Expression::Binary { op, left, right } => {
            let left_remainder =
                get_aggregate_and_remainder_expressions(*left, aggregation_expr_map);
            let right_remainder =
                get_aggregate_and_remainder_expressions(*right, aggregation_expr_map);
            Expression::Binary {
                op,
                left: Box::new(left_remainder),
                right: Box::new(right_remainder),
            }
        }
        Expression::Unary { op, expr } => {
            let remainder = get_aggregate_and_remainder_expressions(*expr, aggregation_expr_map);
            Expression::Unary {
                op,
                expr: Box::new(remainder),
            }
        }
    }
}
/// Validates one selected expression against the `GROUP BY` clause and returns its remainder.
///
/// If the expression is accepted, every aggregation in it is registered in
/// `aggregation_expr_map` and replaced by a column reference. The rewritten expression is
/// returned under the original alias.
///
/// # Errors
/// - [`PostprocessingError::NestedAggregationInGroupByClause`] if the expression nests one
///   aggregation inside another. This is checked first.
/// - [`PostprocessingError::IdentifierNotInAggregationOperatorOrGroupByClause`] naming the first
///   column that is referenced outside any aggregation and is not in `group_by_identifiers`.
fn check_and_get_aggregation_and_remainder(
    expr: AliasedResultExpr,
    group_by_identifiers: &[Identifier],
    aggregation_expr_map: &mut IndexMap<(AggregationOperator, Expression), Identifier>,
) -> PostprocessingResult<AliasedResultExpr> {
    if contains_nested_aggregation(&expr.expr, false) {
        return Err(PostprocessingError::NestedAggregationInGroupByClause {
            error: format!("Nested aggregations found {:?}", expr.expr),
        });
    }
    let allowed: IndexSet<Identifier> = group_by_identifiers.iter().copied().collect();
    let free_identifiers = get_free_identifiers_from_expr(&expr.expr);
    // Any free identifier outside the GROUP BY set is an error; report the first one found.
    if let Some(missing) = free_identifiers.difference(&allowed).next() {
        return Err(
            PostprocessingError::IdentifierNotInAggregationOperatorOrGroupByClause {
                column: *missing,
            },
        );
    }
    let AliasedResultExpr { alias, expr: inner } = expr;
    let remainder = get_aggregate_and_remainder_expressions(*inner, aggregation_expr_map);
    Ok(AliasedResultExpr {
        alias,
        expr: Box::new(remainder),
    })
}
impl GroupByPostprocessing {
    pub fn try_new(
        by_ids: Vec<Identifier>,
        aliased_exprs: Vec<AliasedResultExpr>,
    ) -> PostprocessingResult<Self> {
        let mut aggregation_expr_map: IndexMap<(AggregationOperator, Expression), Identifier> =
            IndexMap::default();
        let remainder_exprs: Vec<AliasedResultExpr> = aliased_exprs
            .into_iter()
            .map(|aliased_expr| -> PostprocessingResult<_> {
                check_and_get_aggregation_and_remainder(
                    aliased_expr,
                    &by_ids,
                    &mut aggregation_expr_map,
                )
            })
            .collect::<PostprocessingResult<Vec<AliasedResultExpr>>>()?;
        let group_by_identifiers = Vec::from_iter(IndexSet::from_iter(by_ids));
        Ok(Self {
            remainder_exprs,
            group_by_identifiers,
            aggregation_exprs: aggregation_expr_map
                .into_iter()
                .map(|((op, expr), id)| (op, expr, id))
                .collect(),
        })
    }
    #[must_use]
    pub fn group_by(&self) -> &[Identifier] {
        &self.group_by_identifiers
    }
    #[must_use]
    pub fn remainder_exprs(&self) -> &[AliasedResultExpr] {
        &self.remainder_exprs
    }
    #[must_use]
    pub fn aggregation_exprs(&self) -> &[(AggregationOperator, Expression, Identifier)] {
        &self.aggregation_exprs
    }
}
impl<S: Scalar> PostprocessingStep<S> for GroupByPostprocessing {
    #[allow(clippy::too_many_lines)]
    fn apply(&self, owned_table: OwnedTable<S>) -> PostprocessingResult<OwnedTable<S>> {
        let alloc = Bump::new();
        let evaluated_columns = self
            .aggregation_exprs
            .iter()
            .map(|(agg_op, expr, id)| -> PostprocessingResult<_> {
                let evaluated_owned_column = owned_table.evaluate(expr)?;
                Ok((*agg_op, (*id, evaluated_owned_column)))
            })
            .process_results(|iter| {
                iter.fold(
                    IndexMap::<_, Vec<_>>::default(),
                    |mut lookup, (key, val)| {
                        lookup.entry(key).or_default().push(val);
                        lookup
                    },
                )
            })?;
        let group_by_ins = self
            .group_by_identifiers
            .iter()
            .map(|id| {
                let column = owned_table.inner_table().get(id).ok_or(
                    PostprocessingError::ColumnNotFound {
                        column: id.to_string(),
                    },
                )?;
                Ok(Column::<S>::from_owned_column(column, &alloc))
            })
            .collect::<PostprocessingResult<Vec<_>>>()?;
        let selection_in = vec![true; owned_table.num_rows()];
        let (sum_identifiers, sum_columns): (Vec<_>, Vec<_>) = evaluated_columns
            .get(&AggregationOperator::Sum)
            .map_or((vec![], vec![]), |tuple| {
                tuple
                    .iter()
                    .map(|(id, c)| (*id, Column::<S>::from_owned_column(c, &alloc)))
                    .unzip()
            });
        let (max_identifiers, max_columns): (Vec<_>, Vec<_>) = evaluated_columns
            .get(&AggregationOperator::Max)
            .map_or((vec![], vec![]), |tuple| {
                tuple
                    .iter()
                    .map(|(id, c)| (*id, Column::<S>::from_owned_column(c, &alloc)))
                    .unzip()
            });
        let (min_identifiers, min_columns): (Vec<_>, Vec<_>) = evaluated_columns
            .get(&AggregationOperator::Min)
            .map_or((vec![], vec![]), |tuple| {
                tuple
                    .iter()
                    .map(|(id, c)| (*id, Column::<S>::from_owned_column(c, &alloc)))
                    .unzip()
            });
        let aggregation_results = aggregate_columns(
            &alloc,
            &group_by_ins,
            &sum_columns,
            &max_columns,
            &min_columns,
            &selection_in,
        )?;
        let group_by_outs = aggregation_results
            .group_by_columns
            .iter()
            .zip(self.group_by_identifiers.iter())
            .map(|(column, id)| Ok((*id, OwnedColumn::from(column))));
        let sum_outs = izip!(
            aggregation_results.sum_columns,
            sum_identifiers,
            sum_columns,
        )
        .map(|(c_out, id, c_in)| {
            Ok((
                id,
                OwnedColumn::try_from_scalars(c_out, c_in.column_type())?,
            ))
        });
        let max_outs = izip!(
            aggregation_results.max_columns,
            max_identifiers,
            max_columns,
        )
        .map(|(c_out, id, c_in)| {
            Ok((
                id,
                OwnedColumn::try_from_option_scalars(c_out, c_in.column_type())?,
            ))
        });
        let min_outs = izip!(
            aggregation_results.min_columns,
            min_identifiers,
            min_columns,
        )
        .map(|(c_out, id, c_in)| {
            Ok((
                id,
                OwnedColumn::try_from_option_scalars(c_out, c_in.column_type())?,
            ))
        });
        let count_column = OwnedColumn::BigInt(aggregation_results.count_column.to_vec());
        let count_outs = evaluated_columns
            .get(&AggregationOperator::Count)
            .into_iter()
            .flatten()
            .map(|(id, _)| -> PostprocessingResult<_> { Ok((*id, count_column.clone())) });
        let new_owned_table: OwnedTable<S> = group_by_outs
            .into_iter()
            .chain(sum_outs)
            .chain(max_outs)
            .chain(min_outs)
            .chain(count_outs)
            .process_results(|iter| OwnedTable::try_from_iter(iter))??;
        let target_table = if new_owned_table.is_empty() {
            OwnedTable::try_new(indexmap! {"__count__".parse().unwrap() => count_column})?
        } else {
            new_owned_table
        };
        let result = self
            .remainder_exprs
            .iter()
            .map(|aliased_expr| -> PostprocessingResult<_> {
                let column = target_table.evaluate(&aliased_expr.expr)?;
                Ok((aliased_expr.alias, column))
            })
            .process_results(|iter| OwnedTable::try_from_iter(iter))??;
        Ok(result)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use proof_of_sql_parser::utility::*;
    #[test]
    fn we_can_detect_nested_aggregation() {
        let expr = sum(sum(col("a")));
        assert!(contains_nested_aggregation(&expr, false));
        assert!(contains_nested_aggregation(&expr, true));
        // With `is_agg == true`, any aggregation at all counts as nested.
        let expr = add(max(col("a")), sum(col("b")));
        assert!(!contains_nested_aggregation(&expr, false));
        assert!(contains_nested_aggregation(&expr, true));
        let expr = add(col("a"), sum(col("b")));
        assert!(!contains_nested_aggregation(&expr, false));
        assert!(contains_nested_aggregation(&expr, true));
        let expr = sub(add(sum(col("a")), col("b")), sum(mul(lit(2), col("c"))));
        assert!(!contains_nested_aggregation(&expr, false));
        assert!(contains_nested_aggregation(&expr, true));
        let expr = add(col("a"), count(sum(col("a"))));
        assert!(contains_nested_aggregation(&expr, false));
        assert!(contains_nested_aggregation(&expr, true));
        let expr = add(add(col("a"), col("b")), lit(1));
        assert!(!contains_nested_aggregation(&expr, false));
        assert!(!contains_nested_aggregation(&expr, true));
    }
    #[test]
    fn we_can_get_free_identifiers_from_expr() {
        let expr = lit("Not an identifier");
        let expected: IndexSet<Identifier> = IndexSet::default();
        let actual = get_free_identifiers_from_expr(&expr);
        assert_eq!(actual, expected);
        let expr = add(add(col("a"), col("b")), lit(1));
        let expected: IndexSet<Identifier> = [ident("a"), ident("b")].iter().copied().collect();
        let actual = get_free_identifiers_from_expr(&expr);
        assert_eq!(actual, expected);
        let expr = not(or(equal(col("a"), col("b")), ge(col("c"), col("a"))));
        let expected: IndexSet<Identifier> = [ident("a"), ident("b"), ident("c")]
            .iter()
            .copied()
            .collect();
        let actual = get_free_identifiers_from_expr(&expr);
        assert_eq!(actual, expected);
        // Columns inside an aggregation are not free.
        let expr = mul(sum(add(col("a"), col("b"))), lit(2));
        let expected: IndexSet<Identifier> = IndexSet::default();
        let actual = get_free_identifiers_from_expr(&expr);
        assert_eq!(actual, expected);
        let expr = mul(add(count(add(col("a"), col("b"))), col("c")), col("d"));
        let expected: IndexSet<Identifier> = [ident("c"), ident("d")].iter().copied().collect();
        let actual = get_free_identifiers_from_expr(&expr);
        assert_eq!(actual, expected);
    }
    #[test]
    fn we_can_get_aggregate_and_remainder_expressions() {
        // The same map is shared across calls, so generated identifiers keep counting up and
        // repeated aggregations reuse their earlier identifier.
        let mut aggregation_expr_map: IndexMap<(AggregationOperator, Expression), Identifier> =
            IndexMap::default();
        let expr = add(sum(col("a")), col("b"));
        let remainder_expr =
            get_aggregate_and_remainder_expressions(*expr, &mut aggregation_expr_map);
        assert_eq!(
            aggregation_expr_map[&(AggregationOperator::Sum, *col("a"))],
            ident("__col_agg_0")
        );
        assert_eq!(remainder_expr, *add(col("__col_agg_0"), col("b")));
        assert_eq!(aggregation_expr_map.len(), 1);
        let expr = add(sum(col("a")), sum(col("b")));
        let remainder_expr =
            get_aggregate_and_remainder_expressions(*expr, &mut aggregation_expr_map);
        assert_eq!(
            aggregation_expr_map[&(AggregationOperator::Sum, *col("a"))],
            ident("__col_agg_0")
        );
        assert_eq!(
            aggregation_expr_map[&(AggregationOperator::Sum, *col("b"))],
            ident("__col_agg_1")
        );
        assert_eq!(remainder_expr, *add(col("__col_agg_0"), col("__col_agg_1")));
        assert_eq!(aggregation_expr_map.len(), 2);
        let expr = add(
            add(
                max(add(col("a"), lit(1))),
                min(sub(mul(lit(2), col("b")), lit(4))),
            ),
            col("c"),
        );
        let remainder_expr =
            get_aggregate_and_remainder_expressions(*expr, &mut aggregation_expr_map);
        assert_eq!(
            aggregation_expr_map[&(AggregationOperator::Max, *add(col("a"), lit(1)))],
            ident("__col_agg_2")
        );
        assert_eq!(
            aggregation_expr_map[&(
                AggregationOperator::Min,
                *sub(mul(lit(2), col("b")), lit(4))
            )],
            ident("__col_agg_3")
        );
        assert_eq!(
            remainder_expr,
            *add(add(col("__col_agg_2"), col("__col_agg_3")), col("c"))
        );
        assert_eq!(aggregation_expr_map.len(), 4);
        let expr = add(
            add(mul(count(mul(lit(2), col("a"))), lit(2)), sum(col("b"))),
            lit(1),
        );
        let remainder_expr =
            get_aggregate_and_remainder_expressions(*expr, &mut aggregation_expr_map);
        assert_eq!(
            aggregation_expr_map[&(AggregationOperator::Count, *mul(lit(2), col("a")))],
            ident("__col_agg_4")
        );
        assert_eq!(
            remainder_expr,
            *add(
                add(mul(col("__col_agg_4"), lit(2)), col("__col_agg_1")),
                lit(1)
            )
        );
        assert_eq!(aggregation_expr_map.len(), 5);
    }
    #[test]
    fn we_can_create_group_by_postprocessing_with_shared_aggregations() {
        // Duplicate GROUP BY ids are removed, and the same aggregation used in two selected
        // expressions is registered only once.
        let postprocessing = GroupByPostprocessing::try_new(
            vec![ident("a"), ident("a")],
            vec![
                AliasedResultExpr {
                    alias: ident("x"),
                    expr: add(col("a"), sum(col("b"))),
                },
                AliasedResultExpr {
                    alias: ident("y"),
                    expr: mul(sum(col("b")), lit(2)),
                },
            ],
        )
        .unwrap();
        assert_eq!(postprocessing.group_by(), &[ident("a")]);
        assert_eq!(
            postprocessing.aggregation_exprs(),
            &[(AggregationOperator::Sum, *col("b"), ident("__col_agg_0"))]
        );
        assert_eq!(
            postprocessing.remainder_exprs(),
            &[
                AliasedResultExpr {
                    alias: ident("x"),
                    expr: add(col("a"), col("__col_agg_0")),
                },
                AliasedResultExpr {
                    alias: ident("y"),
                    expr: mul(col("__col_agg_0"), lit(2)),
                },
            ]
        );
    }
    #[test]
    fn we_cannot_create_group_by_postprocessing_with_nested_aggregation() {
        let result = GroupByPostprocessing::try_new(
            vec![ident("a")],
            vec![AliasedResultExpr {
                alias: ident("x"),
                expr: sum(sum(col("a"))),
            }],
        );
        assert!(matches!(
            result,
            Err(PostprocessingError::NestedAggregationInGroupByClause { .. })
        ));
    }
    #[test]
    fn we_cannot_create_group_by_postprocessing_with_ungrouped_column() {
        let result = GroupByPostprocessing::try_new(
            vec![ident("a")],
            vec![AliasedResultExpr {
                alias: ident("x"),
                expr: add(col("a"), col("b")),
            }],
        );
        assert!(matches!(
            result,
            Err(PostprocessingError::IdentifierNotInAggregationOperatorOrGroupByClause { column })
                if column == ident("b")
        ));
    }
}