radiate-expr 1.2.22

A Rust library for genetic algorithms and artificial evolution.
Documentation
use crate::{AnyValue, DataType, Expr, ExprProjection, ExprQuery, ExprResult, value};
use radiate_error::radiate_bail;
use radiate_utils::{Slope, Statistic, WindowBuffer};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use std::fmt::Debug;

/// The aggregation applied by an [`AggExpr`] to the values produced by its
/// child expression.
///
/// The enum is a plain, field-less tag, so besides `Copy`/`PartialEq` it also
/// derives `Eq` and `Hash`, which lets it be used as a key in hash maps/sets.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Rollup {
    /// The first value of the series.
    First,
    /// The last value of the series.
    Last,
    /// Arithmetic mean of the numeric values.
    Mean,
    /// Standard deviation of the numeric values.
    StdDev,
    /// Smallest numeric value.
    Min,
    /// Largest numeric value.
    Max,
    /// Sum of the numeric values.
    Sum,
    /// Variance of the numeric values.
    ///
    /// NOTE(review): the aggregation routine visible in this file has no
    /// dedicated computation for this variant — confirm intended behavior.
    Var,
    /// Skewness of the numeric values.
    ///
    /// NOTE(review): the aggregation routine visible in this file has no
    /// dedicated computation for this variant — confirm intended behavior.
    Skew,
    /// Number of values in the series.
    Count,
    /// The de-duplicated values of the series.
    Unique,
    /// Slope fitted through the numeric values of the series.
    Slope,
}

/// An expression that aggregates the output of a child expression with a
/// [`Rollup`].
///
/// Without a buffer the aggregation is applied to whatever the child produces
/// on each dispatch. With a buffer (see [`AggExpr::rolling`]) every dispatch
/// pushes the child's output into the buffer and the aggregation is computed
/// over the buffer's current values.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq)]
pub struct AggExpr {
    /// Expression whose output is aggregated.
    pub(super) child: Box<Expr>,
    /// Which aggregation to compute.
    pub(super) rollup: Rollup,
    /// Optional window of previously seen child outputs; `None` until
    /// `rolling` is called.
    pub(super) buffer: Option<WindowBuffer<AnyValue<'static>>>,
}

impl AggExpr {
    /// Creates an aggregation of `child` using `rollup`, with no rolling
    /// buffer attached.
    pub fn new(child: Expr, rollup: Rollup) -> Self {
        Self {
            child: Box::new(child),
            rollup,
            buffer: None,
        }
    }

    /// Attaches a buffer built with `WindowBuffer::with_window(window_size)`,
    /// so that subsequent dispatches aggregate over the buffered child outputs
    /// instead of only the most recent one.
    pub fn rolling(mut self, window_size: usize) -> Self {
        self.buffer = Some(WindowBuffer::with_window(window_size));
        self
    }

    /// Reduces `values` to a single value according to `rollup`.
    ///
    /// * `Count` always yields `UInt64(values.len())`.
    /// * An empty input yields `Float32(0.0)` for every other rollup.
    /// * `Slope` needs at least two points; with a single value it yields
    ///   `Float32(0.0)`.
    /// * Any other rollup over a single value yields a clone of that value.
    /// * `First` / `Last` / `Unique` return the first value, the last value
    ///   and `value::dedup_slice(values)` respectively.
    /// * `Mean`, `StdDev`, `Min`, `Max` and `Sum` are computed over the values
    ///   that can be extracted as `f32`; the result is cast to `dtype`, and
    ///   `Null` is returned when that cast does not produce a value.
    fn compute_rollup<'a>(
        values: &[AnyValue<'a>],
        rollup: Rollup,
        dtype: DataType,
    ) -> ExprResult<'a> {
        // Everything that does not need numeric statistics is resolved here.
        // Arm order matters: the empty / single-value shortcuts must be tried
        // before the general per-rollup arms.
        match (rollup, values) {
            (Rollup::Count, _) => return Ok(AnyValue::UInt64(values.len() as u64)),
            (_, []) => return Ok(AnyValue::Float32(0.0)),
            // A slope is undefined for a single point; report a flat line
            // rather than echoing the value back.
            (Rollup::Slope, [_]) => return Ok(AnyValue::Float32(0.0)),
            // NOTE(review): this also applies to StdDev/Var/Skew, which
            // therefore echo the lone value — confirm that is intended.
            (_, [only]) => return Ok(only.clone()),
            (Rollup::First, [first, ..]) => return Ok(first.clone()),
            (Rollup::Last, [.., last]) => return Ok(last.clone()),
            (Rollup::Unique, _) => return Ok(value::dedup_slice(values)),
            (Rollup::Slope, _) => {
                let slope = values
                    .iter()
                    .filter_map(|v| v.extract::<f32>())
                    .collect::<Slope<f32>>();

                return Ok(AnyValue::Float32(slope.value().unwrap_or(0.0)));
            }
            _ => {}
        }

        // Values that cannot be extracted as f32 are skipped.
        let stats = values
            .iter()
            .filter_map(|val| val.extract::<f32>())
            .collect::<Statistic>();

        let result = match rollup {
            Rollup::Mean => AnyValue::Float32(stats.mean()),
            // `std_dev` is fallible; fall back to 0.0 instead of panicking
            // when it has no value to report.
            Rollup::StdDev => AnyValue::Float32(stats.std_dev().unwrap_or(0.0)),
            Rollup::Min => AnyValue::Float32(stats.min()),
            Rollup::Max => AnyValue::Float32(stats.max()),
            Rollup::Sum => AnyValue::Float32(stats.sum()),
            // NOTE(review): no computation is implemented for these two yet.
            Rollup::Var | Rollup::Skew => AnyValue::Null,
            // Already answered by the early returns above; listed explicitly
            // so that adding a new `Rollup` variant is a compile error here.
            Rollup::First | Rollup::Last | Rollup::Count | Rollup::Unique | Rollup::Slope => {
                AnyValue::Null
            }
        };

        Ok(result.cast(&dtype).unwrap_or(AnyValue::Null))
    }
}

impl<T> ExprQuery<T> for AggExpr
where
    T: ExprProjection,
{
    /// Evaluates the child expression against `input` and aggregates its
    /// output.
    ///
    /// In rolling mode the child's output is pushed into the buffer and the
    /// rollup is computed over the buffer's values. Otherwise a `Slice` or
    /// `Vector` output is aggregated element-wise, and any other output is
    /// treated as a lone value: `Count` is 1, `Unique` wraps it in a
    /// one-element `Vector`, and every other rollup passes it through.
    fn dispatch<'a>(&'a mut self, input: &T) -> ExprResult<'a> {
        let produced = self.child.dispatch(input)?;
        let dtype = produced.dtype();
        let rollup = self.rollup;

        match self.buffer.as_mut() {
            Some(window) => {
                window.push(produced.into_static());
                Self::compute_rollup(window.values(), rollup, dtype)
            }
            None => match (produced, rollup) {
                (AnyValue::Slice(items), _) => Self::compute_rollup(items, rollup, dtype),
                (AnyValue::Vector(items), _) => Self::compute_rollup(&items, rollup, dtype),
                (_, Rollup::Count) => Ok(AnyValue::UInt64(1)),
                (scalar, Rollup::Unique) => Ok(AnyValue::Vector(vec![scalar])),
                (scalar, _) => Ok(scalar),
            },
        }
    }
}

/// An expression that collects the successive outputs of a child expression
/// into a window buffer and exposes the buffered values as a slice.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq)]
pub struct BufferExpr {
    /// Buffered child outputs, stored with a `'static` lifetime.
    pub(super) buffer: WindowBuffer<AnyValue<'static>>,
    /// Expression whose output is buffered on every dispatch.
    pub(super) child: Box<Expr>,
    /// Data type the buffer accepts. Starts as `DataType::Null` and is set
    /// from the child's output during dispatch; later outputs of a different
    /// type are rejected.
    pub(super) dtype: DataType,
}

impl BufferExpr {
    /// Builds a buffering expression around `child` whose buffer is created
    /// with `WindowBuffer::with_window(window_size)`.
    ///
    /// The element type is left as `DataType::Null` until the first value is
    /// dispatched.
    pub fn new(child: Expr, window_size: usize) -> Self {
        let buffer = WindowBuffer::with_window(window_size);
        let child = Box::new(child);
        let dtype = DataType::Null;

        Self {
            buffer,
            child,
            dtype,
        }
    }
}

impl<T> ExprQuery<T> for BufferExpr
where
    T: ExprProjection,
{
    /// Evaluates the child expression against `input`, appends its output to
    /// the buffer and returns the buffered values as an `AnyValue::Slice`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the child expression, and bails when the
    /// child's output is nested or when its data type differs from the type
    /// recorded for this buffer.
    fn dispatch<'a>(&'a mut self, input: &T) -> ExprResult<'a> {
        let child_output = self.child.dispatch(input)?.into_static();

        if child_output.is_nested() {
            radiate_bail!(Expr: "BufferExpr does not support nested values");
        }

        // Look the type up once; it is needed for both the comparison and the
        // error message.
        let incoming = child_output.dtype();

        if self.dtype == DataType::Null {
            // No element type recorded yet: adopt the incoming one.
            self.dtype = incoming;
        } else if self.dtype != incoming {
            radiate_bail!(Expr:
                "BufferExpr received value of type {:?} but expected {:?}",
                incoming,
                self.dtype
            );
        }

        self.buffer.push(child_output);
        Ok(AnyValue::Slice(self.buffer.values()))
    }
}