micromegas-analytics 0.11.0

analytics module of micromegas
Documentation
use datafusion::{
    arrow::{
        array::{Array, ArrayRef, Float64Array, ListArray, StructArray, UInt64Array},
        datatypes::{DataType, Fields},
    },
    error::DataFusionError,
    logical_expr::{
        Accumulator, AggregateUDF, ColumnarValue, Volatility, function::AccumulatorArgs,
    },
    physical_plan::expressions::Literal,
    prelude::*,
    scalar::ScalarValue,
};
use std::sync::Arc;

use super::accumulator::{HistogramAccumulator, state_arrow_fields};

#[derive(Debug)]
pub struct HistogramArray {
    inner: Arc<StructArray>,
}

impl HistogramArray {
    pub fn new(inner: Arc<StructArray>) -> Self {
        Self { inner }
    }

    pub fn inner(&self) -> Arc<StructArray> {
        self.inner.clone()
    }

    pub fn len(&self) -> usize {
        self.inner.len()
    }

    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    pub fn get_start(&self, index: usize) -> Result<f64, DataFusionError> {
        let starts = self
            .inner
            .column(0)
            .as_any()
            .downcast_ref::<Float64Array>()
            .ok_or_else(|| DataFusionError::Execution("downcasting to Float64Array".into()))?;
        Ok(starts.value(index))
    }

    pub fn get_end(&self, index: usize) -> Result<f64, DataFusionError> {
        let ends = self
            .inner
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .ok_or_else(|| DataFusionError::Execution("downcasting to Float64Array".into()))?;
        Ok(ends.value(index))
    }

    pub fn get_min(&self, index: usize) -> Result<f64, DataFusionError> {
        let mins = self
            .inner
            .column(2)
            .as_any()
            .downcast_ref::<Float64Array>()
            .ok_or_else(|| DataFusionError::Execution("downcasting to Float64Array".into()))?;
        Ok(mins.value(index))
    }

    pub fn get_max(&self, index: usize) -> Result<f64, DataFusionError> {
        let maxs = self
            .inner
            .column(3)
            .as_any()
            .downcast_ref::<Float64Array>()
            .ok_or_else(|| DataFusionError::Execution("downcasting to Float64Array".into()))?;
        Ok(maxs.value(index))
    }

    pub fn get_sum(&self, index: usize) -> Result<f64, DataFusionError> {
        let sums = self
            .inner
            .column(4)
            .as_any()
            .downcast_ref::<Float64Array>()
            .ok_or_else(|| DataFusionError::Execution("downcasting to Float64Array".into()))?;
        Ok(sums.value(index))
    }

    pub fn get_sum_sq(&self, index: usize) -> Result<f64, DataFusionError> {
        let sums_sq = self
            .inner
            .column(5)
            .as_any()
            .downcast_ref::<Float64Array>()
            .ok_or_else(|| DataFusionError::Execution("downcasting to Float64Array".into()))?;
        Ok(sums_sq.value(index))
    }

    pub fn get_count(&self, index: usize) -> Result<u64, DataFusionError> {
        let counts = self
            .inner
            .column(6)
            .as_any()
            .downcast_ref::<UInt64Array>()
            .ok_or_else(|| DataFusionError::Execution("downcasting to UInt64Array".into()))?;
        Ok(counts.value(index))
    }

    pub fn get_bins(&self, index: usize) -> Result<UInt64Array, DataFusionError> {
        let bins_list = self
            .inner
            .column(7)
            .as_any()
            .downcast_ref::<ListArray>()
            .ok_or_else(|| DataFusionError::Execution("downcasting to ListArray".into()))?;
        let bins = bins_list.value(index);
        let bins = bins
            .as_any()
            .downcast_ref::<UInt64Array>()
            .ok_or_else(|| DataFusionError::Execution("downcasting to UInt64Array".into()))?;
        Ok(bins.clone())
    }
}

impl TryFrom<&ArrayRef> for HistogramArray {
    type Error = DataFusionError;

    fn try_from(value: &ArrayRef) -> Result<Self, Self::Error> {
        let struct_array = value
            .as_any()
            .downcast_ref::<StructArray>()
            .ok_or_else(|| DataFusionError::Execution("downcasting to StructArray".into()))?;
        let inner = Arc::new(struct_array.clone());
        Ok(Self { inner })
    }
}

impl TryFrom<&dyn Array> for HistogramArray {
    type Error = DataFusionError;

    fn try_from(value: &dyn Array) -> Result<Self, Self::Error> {
        let struct_array = value
            .as_any()
            .downcast_ref::<StructArray>()
            .ok_or_else(|| DataFusionError::Execution("downcasting to StructArray".into()))?;
        let inner = Arc::new(struct_array.clone());
        Ok(Self { inner })
    }
}

impl TryFrom<&ColumnarValue> for HistogramArray {
    type Error = DataFusionError;

    fn try_from(value: &ColumnarValue) -> Result<Self, Self::Error> {
        match value {
            ColumnarValue::Array(array) => array.try_into(),
            ColumnarValue::Scalar(scalar_value) => {
                if let ScalarValue::Struct(array) = scalar_value {
                    Ok(Self::new(array.clone()))
                } else {
                    Err(DataFusionError::Execution( "Can't convert ColumnarValue into HistogramArray: ScalarValue is not a struct".into()))
                }
            }
        }
    }
}

fn make_state(args: AccumulatorArgs) -> Result<Box<dyn Accumulator>, DataFusionError> {
    let start_arg = args
        .exprs
        .first()
        .ok_or_else(|| DataFusionError::Execution("Reading first argument".into()))?
        .as_any()
        .downcast_ref::<Literal>()
        .ok_or_else(|| DataFusionError::Execution("Downcasting first argument to Literal".into()))?
        .value();
    let start = if let ScalarValue::Float64(Some(start_value)) = start_arg {
        start_value
    } else {
        return Err(DataFusionError::Execution(format!(
            "arg 0 should be a float64, found {start_arg:?}"
        )));
    };

    let end_arg = args
        .exprs
        .get(1)
        .ok_or_else(|| DataFusionError::Execution("Reading argument 1".into()))?
        .as_any()
        .downcast_ref::<Literal>()
        .ok_or_else(|| DataFusionError::Execution("Downcasting argument 1 to Literal".into()))?
        .value();
    let end = if let ScalarValue::Float64(Some(end_value)) = end_arg {
        end_value
    } else {
        return Err(DataFusionError::Execution(format!(
            "arg 0 should be a float64, found {end_arg:?}"
        )));
    };

    let nb_bins_arg = args
        .exprs
        .get(2)
        .ok_or_else(|| DataFusionError::Execution("Reading argument 2".into()))?
        .as_any()
        .downcast_ref::<Literal>()
        .ok_or_else(|| DataFusionError::Execution("Downcasting argument 2 to Literal".into()))?
        .value();
    let nb_bins = if let ScalarValue::Int64(Some(nb_bins_value)) = nb_bins_arg {
        nb_bins_value
    } else {
        return Err(DataFusionError::Execution(format!(
            "arg 0 should be a int64, found {nb_bins_arg:?}"
        )));
    };

    Ok(Box::new(HistogramAccumulator::new(
        *start,
        *end,
        *nb_bins as usize,
    )))
}

pub fn make_histogram_arrow_type() -> DataType {
    DataType::Struct(Fields::from(state_arrow_fields()))
}

pub fn make_histo_udaf() -> AggregateUDF {
    create_udaf(
        "make_histogram",
        vec![
            DataType::Float64,
            DataType::Float64,
            DataType::Int64,
            DataType::Float64,
        ],
        Arc::new(make_histogram_arrow_type()),
        Volatility::Immutable,
        Arc::new(&make_state),
        Arc::new(vec![make_histogram_arrow_type()]),
    )
}