use ahash::RandomState;
use datafusion_physical_expr_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator;
use std::collections::HashSet;
use std::ops::BitAnd;
use std::{fmt::Debug, sync::Arc};
use arrow::{
    array::{ArrayRef, AsArray},
    compute,
    datatypes::{
        DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field,
        Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
        Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
        Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
        TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
        UInt16Type, UInt32Type, UInt64Type, UInt8Type,
    },
};
use arrow::{
    array::{Array, BooleanArray, Int64Array, PrimitiveArray},
    buffer::BooleanBuffer,
};
use datafusion_common::{
    downcast_value, internal_err, not_impl_err, DataFusionError, Result, ScalarValue,
};
use datafusion_expr::function::StateFieldsArgs;
use datafusion_expr::{
    function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl,
    EmitTo, GroupsAccumulator, Signature, Volatility,
};
use datafusion_expr::{Expr, ReversedUDAF, TypeSignature};
use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::accumulate_indices;
use datafusion_physical_expr_common::{
    aggregate::count_distinct::{
        BytesDistinctCountAccumulator, FloatDistinctCountAccumulator,
        PrimitiveDistinctCountAccumulator,
    },
    binary_map::OutputType,
};
// Wires up the [`Count`] UDAF: generates the `count(expr)` expression helper and
// the `count_udaf` accessor (called by `count_distinct` below).
// NOTE(review): the exact generated signatures come from the
// `make_udaf_expr_and_func!` macro defined elsewhere — confirm there.
make_udaf_expr_and_func!(
    Count,
    count,
    expr,
    "Count the number of non-null values in the column",
    count_udaf
);
/// Builds a `COUNT(DISTINCT expr)` aggregate expression using the `count` UDAF.
///
/// The resulting aggregate has no filter, no ordering and no null treatment.
pub fn count_distinct(expr: Expr) -> Expr {
    let distinct = true;
    let aggregate = datafusion_expr::expr::AggregateFunction::new_udf(
        count_udaf(),
        vec![expr],
        distinct,
        None,
        None,
        None,
    );
    Expr::AggregateFunction(aggregate)
}
/// The `count` aggregate UDAF.
///
/// Counts non-null values, or the number of distinct non-null values when used
/// as `COUNT(DISTINCT ...)`.
pub struct Count {
    // Accepts any number of arguments of any type (see `Count::new`).
    signature: Signature,
}
impl Debug for Count {
    /// Debug output shows the UDAF name and its signature.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let mut builder = f.debug_struct("Count");
        builder.field("name", &self.name());
        builder.field("signature", &self.signature);
        builder.finish()
    }
}
impl Default for Count {
    fn default() -> Self {
        Self::new()
    }
}
impl Count {
    /// Creates the `count` UDAF.
    ///
    /// The signature accepts either any number of arguments of any type
    /// (`VariadicAny`) or zero arguments (`Any(0)`). The function is immutable.
    pub fn new() -> Self {
        let accepted = vec![TypeSignature::VariadicAny, TypeSignature::Any(0)];
        let signature = Signature::one_of(accepted, Volatility::Immutable);
        Self { signature }
    }
}
/// [`AggregateUDFImpl`] for `count`, covering both `COUNT(...)` and
/// `COUNT(DISTINCT ...)`.
impl AggregateUDFImpl for Count {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
    fn name(&self) -> &str {
        "count"
    }
    fn signature(&self) -> &Signature {
        &self.signature
    }
    /// `count` always yields an `Int64`, regardless of argument types.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Int64)
    }
    /// Intermediate state layout.
    ///
    /// - DISTINCT: a single non-nullable list field of the first input's type
    ///   (with nullable items) holding the distinct values.
    /// - Otherwise: a single nullable `Int64` field holding the partial count.
    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
        if args.is_distinct {
            Ok(vec![Field::new_list(
                format_state_name(args.name, "count distinct"),
                Field::new("item", args.input_types[0].clone(), true),
                false,
            )])
        } else {
            Ok(vec![Field::new(
                format_state_name(args.name, "count"),
                DataType::Int64,
                true,
            )])
        }
    }
    /// Creates the row-at-a-time accumulator.
    ///
    /// Non-distinct counts use [`CountAccumulator`]. `COUNT(DISTINCT ...)` with
    /// more than one argument is not implemented and returns an error. Otherwise
    /// a specialized distinct accumulator is chosen from the first input type,
    /// falling back to the generic [`DistinctCountAccumulator`] (a `HashSet` of
    /// `ScalarValue`s).
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        if !acc_args.is_distinct {
            return Ok(Box::new(CountAccumulator::new()));
        }
        if acc_args.input_exprs.len() > 1 {
            return not_impl_err!("COUNT DISTINCT with multiple arguments");
        }
        let data_type = &acc_args.input_types[0];
        Ok(match data_type {
            // Integer, decimal, date, time and timestamp inputs use
            // `PrimitiveDistinctCountAccumulator` keyed by the arrow primitive type.
            DataType::Int8 => Box::new(
                PrimitiveDistinctCountAccumulator::<Int8Type>::new(data_type),
            ),
            DataType::Int16 => Box::new(
                PrimitiveDistinctCountAccumulator::<Int16Type>::new(data_type),
            ),
            DataType::Int32 => Box::new(
                PrimitiveDistinctCountAccumulator::<Int32Type>::new(data_type),
            ),
            DataType::Int64 => Box::new(
                PrimitiveDistinctCountAccumulator::<Int64Type>::new(data_type),
            ),
            DataType::UInt8 => Box::new(
                PrimitiveDistinctCountAccumulator::<UInt8Type>::new(data_type),
            ),
            DataType::UInt16 => Box::new(
                PrimitiveDistinctCountAccumulator::<UInt16Type>::new(data_type),
            ),
            DataType::UInt32 => Box::new(
                PrimitiveDistinctCountAccumulator::<UInt32Type>::new(data_type),
            ),
            DataType::UInt64 => Box::new(
                PrimitiveDistinctCountAccumulator::<UInt64Type>::new(data_type),
            ),
            DataType::Decimal128(_, _) => Box::new(PrimitiveDistinctCountAccumulator::<
                Decimal128Type,
            >::new(data_type)),
            DataType::Decimal256(_, _) => Box::new(PrimitiveDistinctCountAccumulator::<
                Decimal256Type,
            >::new(data_type)),
            DataType::Date32 => Box::new(
                PrimitiveDistinctCountAccumulator::<Date32Type>::new(data_type),
            ),
            DataType::Date64 => Box::new(
                PrimitiveDistinctCountAccumulator::<Date64Type>::new(data_type),
            ),
            DataType::Time32(TimeUnit::Millisecond) => Box::new(
                PrimitiveDistinctCountAccumulator::<Time32MillisecondType>::new(
                    data_type,
                ),
            ),
            DataType::Time32(TimeUnit::Second) => Box::new(
                PrimitiveDistinctCountAccumulator::<Time32SecondType>::new(data_type),
            ),
            DataType::Time64(TimeUnit::Microsecond) => Box::new(
                PrimitiveDistinctCountAccumulator::<Time64MicrosecondType>::new(
                    data_type,
                ),
            ),
            DataType::Time64(TimeUnit::Nanosecond) => Box::new(
                PrimitiveDistinctCountAccumulator::<Time64NanosecondType>::new(data_type),
            ),
            // Timestamps: the timezone (second field) does not affect the choice.
            DataType::Timestamp(TimeUnit::Microsecond, _) => Box::new(
                PrimitiveDistinctCountAccumulator::<TimestampMicrosecondType>::new(
                    data_type,
                ),
            ),
            DataType::Timestamp(TimeUnit::Millisecond, _) => Box::new(
                PrimitiveDistinctCountAccumulator::<TimestampMillisecondType>::new(
                    data_type,
                ),
            ),
            DataType::Timestamp(TimeUnit::Nanosecond, _) => Box::new(
                PrimitiveDistinctCountAccumulator::<TimestampNanosecondType>::new(
                    data_type,
                ),
            ),
            DataType::Timestamp(TimeUnit::Second, _) => Box::new(
                PrimitiveDistinctCountAccumulator::<TimestampSecondType>::new(data_type),
            ),
            // Floats get a dedicated accumulator. NOTE(review): presumably to give
            // NaN / signed-zero a well-defined equality — confirm in
            // `FloatDistinctCountAccumulator`.
            DataType::Float16 => {
                Box::new(FloatDistinctCountAccumulator::<Float16Type>::new())
            }
            DataType::Float32 => {
                Box::new(FloatDistinctCountAccumulator::<Float32Type>::new())
            }
            DataType::Float64 => {
                Box::new(FloatDistinctCountAccumulator::<Float64Type>::new())
            }
            // String / binary inputs. The `i32` / `i64` parameter presumably
            // selects the offset width of the regular vs. `Large*` variants; the
            // `*View` variants use the view-specific accumulator.
            DataType::Utf8 => {
                Box::new(BytesDistinctCountAccumulator::<i32>::new(OutputType::Utf8))
            }
            DataType::Utf8View => {
                Box::new(BytesViewDistinctCountAccumulator::new(OutputType::Utf8View))
            }
            DataType::LargeUtf8 => {
                Box::new(BytesDistinctCountAccumulator::<i64>::new(OutputType::Utf8))
            }
            DataType::Binary => Box::new(BytesDistinctCountAccumulator::<i32>::new(
                OutputType::Binary,
            )),
            DataType::BinaryView => Box::new(BytesViewDistinctCountAccumulator::new(
                OutputType::BinaryView,
            )),
            DataType::LargeBinary => Box::new(BytesDistinctCountAccumulator::<i64>::new(
                OutputType::Binary,
            )),
            // Everything else: generic `ScalarValue`-based set.
            _ => Box::new(DistinctCountAccumulator {
                values: HashSet::default(),
                state_data_type: data_type.clone(),
            }),
        })
    }
    fn aliases(&self) -> &[String] {
        &[]
    }
    /// The vectorized [`GroupsAccumulator`] path is only available for
    /// non-distinct counts of exactly one argument.
    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
        if args.is_distinct {
            return false;
        }
        args.input_exprs.len() == 1
    }
    /// Always returns a [`CountGroupsAccumulator`]. NOTE(review): presumably only
    /// invoked after `groups_accumulator_supported` returned `true`.
    fn create_groups_accumulator(
        &self,
        _args: AccumulatorArgs,
    ) -> Result<Box<dyn GroupsAccumulator>> {
        Ok(Box::new(CountGroupsAccumulator::new()))
    }
    /// Reversing the input order does not change a count.
    fn reverse_expr(&self) -> ReversedUDAF {
        ReversedUDAF::Identical
    }
}
/// Row-at-a-time accumulator for non-distinct `COUNT`: counts rows in which
/// every argument is non-null.
#[derive(Debug)]
struct CountAccumulator {
    // Running count of qualifying rows; `retract_batch` subtracts from it.
    count: i64,
}
impl CountAccumulator {
    pub fn new() -> Self {
        Self { count: 0 }
    }
}
impl Accumulator for CountAccumulator {
    /// Intermediate state is a single `Int64` holding the running count.
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        let current = ScalarValue::Int64(Some(self.count));
        Ok(vec![current])
    }

    /// Adds the number of rows in which every argument is non-null.
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let non_null_rows = values[0].len() - null_count_for_multiple_cols(values);
        self.count += non_null_rows as i64;
        Ok(())
    }

    /// Subtracts the number of rows in which every argument is non-null.
    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let non_null_rows = values[0].len() - null_count_for_multiple_cols(values);
        self.count -= non_null_rows as i64;
        Ok(())
    }

    /// Folds partial counts (an `Int64` array) produced by `state` into this one.
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        let partial_counts = downcast_value!(states[0], Int64Array);
        // `sum` yields `None` when there is nothing to add.
        if let Some(total) = arrow::compute::sum(partial_counts) {
            self.count += total;
        }
        Ok(())
    }

    /// The final result is the running count as `Int64`.
    fn evaluate(&mut self) -> Result<ScalarValue> {
        let result = ScalarValue::Int64(Some(self.count));
        Ok(result)
    }

    fn supports_retract_batch(&self) -> bool {
        true
    }

    fn size(&self) -> usize {
        std::mem::size_of::<Self>()
    }
}
/// Vectorized per-group accumulator for non-distinct, single-argument `COUNT`.
#[derive(Debug)]
struct CountGroupsAccumulator {
    // `counts[g]` is the running count for group index `g`.
    counts: Vec<i64>,
}
impl CountGroupsAccumulator {
    /// Creates an accumulator with no groups yet. `counts` is resized on demand
    /// whenever a batch reports a larger `total_num_groups`.
    pub fn new() -> Self {
        Self { counts: Vec::new() }
    }
}
impl GroupsAccumulator for CountGroupsAccumulator {
    fn update_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        opt_filter: Option<&BooleanArray>,
        total_num_groups: usize,
    ) -> Result<()> {
        assert_eq!(values.len(), 1, "single argument to update_batch");
        let values = &values[0];
        self.counts.resize(total_num_groups, 0);
        accumulate_indices(
            group_indices,
            values.logical_nulls().as_ref(),
            opt_filter,
            |group_index| {
                self.counts[group_index] += 1;
            },
        );
        Ok(())
    }
    fn merge_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        opt_filter: Option<&BooleanArray>,
        total_num_groups: usize,
    ) -> Result<()> {
        assert_eq!(values.len(), 1, "one argument to merge_batch");
        let partial_counts = values[0].as_primitive::<Int64Type>();
        assert_eq!(partial_counts.null_count(), 0);
        let partial_counts = partial_counts.values();
        self.counts.resize(total_num_groups, 0);
        match opt_filter {
            Some(filter) => filter
                .iter()
                .zip(group_indices.iter())
                .zip(partial_counts.iter())
                .for_each(|((filter_value, &group_index), partial_count)| {
                    if let Some(true) = filter_value {
                        self.counts[group_index] += partial_count;
                    }
                }),
            None => group_indices.iter().zip(partial_counts.iter()).for_each(
                |(&group_index, partial_count)| {
                    self.counts[group_index] += partial_count;
                },
            ),
        }
        Ok(())
    }
    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
        let counts = emit_to.take_needed(&mut self.counts);
        let nulls = None;
        let array = PrimitiveArray::<Int64Type>::new(counts.into(), nulls);
        Ok(Arc::new(array))
    }
    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
        let counts = emit_to.take_needed(&mut self.counts);
        let counts: PrimitiveArray<Int64Type> = Int64Array::from(counts); Ok(vec![Arc::new(counts) as ArrayRef])
    }
    fn convert_to_state(
        &self,
        values: &[ArrayRef],
        opt_filter: Option<&BooleanArray>,
    ) -> Result<Vec<ArrayRef>> {
        let values = &values[0];
        let state_array = match (values.logical_nulls(), opt_filter) {
            (None, None) => {
                Arc::new(Int64Array::from_value(1, values.len()))
            }
            (Some(nulls), None) => {
                let nulls = BooleanArray::new(nulls.into_inner(), None);
                compute::cast(&nulls, &DataType::Int64)?
            }
            (None, Some(filter)) => {
                let (filter_values, filter_nulls) = filter.clone().into_parts();
                let state_buf = match filter_nulls {
                    Some(filter_nulls) => &filter_values & filter_nulls.inner(),
                    None => filter_values,
                };
                let boolean_state = BooleanArray::new(state_buf, None);
                compute::cast(&boolean_state, &DataType::Int64)?
            }
            (Some(nulls), Some(filter)) => {
                let (filter_values, filter_nulls) = filter.clone().into_parts();
                let filter_buf = match filter_nulls {
                    Some(filter_nulls) => &filter_values & filter_nulls.inner(),
                    None => filter_values,
                };
                let state_buf = &filter_buf & nulls.inner();
                let boolean_state = BooleanArray::new(state_buf, None);
                compute::cast(&boolean_state, &DataType::Int64)?
            }
        };
        Ok(vec![state_array])
    }
    fn supports_convert_to_state(&self) -> bool {
        true
    }
    fn size(&self) -> usize {
        self.counts.capacity() * std::mem::size_of::<usize>()
    }
}
/// Returns the number of rows in which at least one of `values` is (logically)
/// null.
///
/// All arrays are assumed to have the same length as `values[0]`.
///
/// # Panics
/// Panics if `values` is empty.
fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize {
    let first = &values[0];
    if values.len() <= 1 {
        return first.logical_nulls().map_or(0, |nulls| nulls.null_count());
    }

    // Intersect the validity bitmaps: a row is valid only if it is valid in
    // every column. Columns without a null mask are entirely valid and are
    // skipped.
    let mut combined: Option<BooleanBuffer> = None;
    for nulls in values.iter().filter_map(|a| a.logical_nulls()) {
        combined = Some(match combined {
            Some(acc) => acc.bitand(nulls.inner()),
            None => nulls.into_inner(),
        });
    }

    combined.map_or(0, |valid| first.len() - valid.count_set_bits())
}
/// Generic fallback accumulator for `COUNT(DISTINCT ...)` over input types that
/// have no specialized accumulator; stores each distinct non-null value as a
/// `ScalarValue`.
#[derive(Debug)]
struct DistinctCountAccumulator {
    // Distinct values seen so far, hashed with `ahash`'s `RandomState`.
    values: HashSet<ScalarValue, RandomState>,
    // Item type of the list emitted as intermediate state (see `state`).
    state_data_type: DataType,
}
impl DistinctCountAccumulator {
    /// Bytes that do not depend on the stored values' own heap payloads: the
    /// struct itself, one `ScalarValue` slot per unit of set capacity, and one
    /// `DataType`.
    fn base_size(&self) -> usize {
        std::mem::size_of_val(self)
            + std::mem::size_of::<ScalarValue>() * self.values.capacity()
            + std::mem::size_of::<DataType>()
    }

    /// Bytes owned by `value` beyond its inline `ScalarValue` footprint.
    fn extra_size_of(value: &ScalarValue) -> usize {
        ScalarValue::size(value) - std::mem::size_of_val(value)
    }

    /// Cheap estimate for fixed-width value types: extrapolates the extra size
    /// from a single sample element (or 0 when the set is empty).
    fn fixed_size(&self) -> usize {
        let sample = self.values.iter().next().map_or(0, Self::extra_size_of);
        self.base_size() + sample
    }

    /// Exact estimate: sums the extra size of every stored value.
    fn full_size(&self) -> usize {
        let payload: usize = self.values.iter().map(Self::extra_size_of).sum();
        self.base_size() + payload
    }
}
impl Accumulator for DistinctCountAccumulator {
    /// Intermediate state is a single `List` scalar holding every distinct
    /// value seen so far.
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        let scalars = self.values.iter().cloned().collect::<Vec<_>>();
        let arr =
            ScalarValue::new_list_nullable(scalars.as_slice(), &self.state_data_type);
        Ok(vec![ScalarValue::List(arr)])
    }

    /// Inserts every logically non-null value of the first argument into the set.
    ///
    /// An empty `values` slice or a `Null`-typed array is a no-op.
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let Some(arr) = values.first() else {
            return Ok(());
        };
        if arr.data_type() == &DataType::Null {
            return Ok(());
        }

        // `Array::is_null` only consults the physical null buffer. That misses
        // nulls encoded elsewhere (e.g. a valid dictionary key that points at a
        // null value, or run-end encoded nulls), which would otherwise be counted
        // as a distinct value. Use the logical null mask, as the non-distinct
        // count accumulators in this file do.
        let nulls = arr.logical_nulls();
        (0..arr.len())
            .filter(|&index| nulls.as_ref().map_or(true, |n| n.is_valid(index)))
            .try_for_each(|index| {
                let scalar = ScalarValue::try_from_array(arr, index)?;
                self.values.insert(scalar);
                Ok(())
            })
    }

    /// Merges intermediate states: a single `List<i32>` array whose entries are
    /// the distinct-value lists produced by `state`.
    ///
    /// # Errors
    /// Returns an internal error if any list entry is null.
    ///
    /// # Panics
    /// Panics if more than one state array is supplied.
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        if states.is_empty() {
            return Ok(());
        }
        assert_eq!(states.len(), 1, "COUNT DISTINCT states must be singleton!");
        let array = &states[0];
        let list_array = array.as_list::<i32>();
        for inner_array in list_array.iter() {
            let Some(inner_array) = inner_array else {
                return internal_err!(
                    "Intermediate results of COUNT DISTINCT should always be non null"
                );
            };
            self.update_batch(&[inner_array])?;
        }
        Ok(())
    }

    /// The result is the number of distinct non-null values seen.
    fn evaluate(&mut self) -> Result<ScalarValue> {
        Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
    }

    /// Memory estimate. Boolean, null and primitive types use the cheap sampled
    /// estimate; other types sum every stored value.
    fn size(&self) -> usize {
        match &self.state_data_type {
            DataType::Boolean | DataType::Null => self.fixed_size(),
            d if d.is_primitive() => self.fixed_size(),
            _ => self.full_size(),
        }
    }
}