radiate-core 1.3.0

Core traits and interfaces for the Radiate genetic algorithm library.
Documentation
use super::{Evaluate, Expr, ExprResult};
use crate::stats::ExprSelector;
use radiate_error::radiate_bail;
use radiate_utils::{AnyValue, DataType, Quantile, Slope, Statistic, WindowBuffer, dedup_slice};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use std::fmt::Debug;

/// The reduction an [`AggExpr`] applies to the values produced by its child expression.
///
/// Numeric rollups only use values that can be extracted as `f32`; values that cannot be
/// extracted are skipped.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq)]
pub enum Rollup {
    /// The first value.
    First,
    /// The last value.
    Last,
    /// Mean of the numeric values (computed through `Statistic`).
    Mean,
    /// Standard deviation of the numeric values (computed through `Statistic`).
    StdDev,
    /// Minimum of the numeric values (computed through `Statistic`).
    Min,
    /// Maximum of the numeric values (computed through `Statistic`).
    Max,
    /// Sum of the numeric values (computed through `Statistic`).
    Sum,
    /// Variance of the numeric values.
    /// NOTE(review): confirm `AggExpr`'s rollup computation handles this variant — an
    /// unhandled rollup evaluates to `Null`.
    Var,
    /// Skewness of the numeric values.
    /// NOTE(review): confirm `AggExpr`'s rollup computation handles this variant — an
    /// unhandled rollup evaluates to `Null`.
    Skew,
    /// Number of values.
    Count,
    /// The distinct values, as produced by `dedup_slice`.
    Unique,
    /// Slope of the numeric values taken in order, computed with [`Slope`].
    Slope,
    /// A quantile of the numeric values; non-finite values are ignored. The estimator holds
    /// state, which `AggExpr::reset` clears.
    Quantile(Quantile<f32>),
}

/// An expression that reduces the output of a child expression with a [`Rollup`].
///
/// By default each evaluation aggregates the child's current output (list outputs are
/// reduced element-wise). After [`AggExpr::rolling`], child outputs are accumulated in a
/// window buffer and the rollup is computed over the buffered values instead.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq)]
pub struct AggExpr {
    /// Expression whose output is aggregated.
    pub(super) child: Box<Expr>,
    /// Reduction to apply; the `Quantile` variant carries mutable estimator state.
    pub(super) rollup: Rollup,
    /// Window of past child outputs used in rolling mode (set by `rolling`); `None` means
    /// each evaluation only aggregates the current output.
    pub(super) buffer: Option<WindowBuffer<AnyValue<'static>>>,
}

impl AggExpr {
    pub fn new(child: Expr, rollup: Rollup) -> Self {
        Self {
            child: Box::new(child),
            rollup,
            buffer: None,
        }
    }

    pub fn rolling(mut self, window_size: usize) -> Self {
        self.buffer = Some(WindowBuffer::with_capacity(window_size));
        self
    }

    pub(super) fn reset(&mut self) {
        if let Some(buf) = &mut self.buffer {
            buf.clear();
        }
        self.child.reset();

        if let Rollup::Quantile(q) = &mut self.rollup {
            q.clear();
        }
    }

    fn compute_rollup<'a>(
        values: &[AnyValue<'a>],
        rollup: &mut Rollup,
        dtype: DataType,
    ) -> ExprResult<'a> {
        if values.is_empty() {
            return match rollup {
                Rollup::Count => Ok(AnyValue::UInt64(0)),
                _ => Ok(AnyValue::Float32(0.0)),
            };
        }

        if values.len() == 1 {
            return match rollup {
                Rollup::Count => Ok(AnyValue::UInt64(1)),
                Rollup::Unique => Ok(values[0].clone()),
                _ => Ok(values[0].clone()),
            };
        }

        if let Rollup::Unique = rollup {
            return Ok(dedup_slice(values));
        } else if let Rollup::Count = rollup {
            return Ok(AnyValue::UInt64(values.len() as u64));
        } else if let Rollup::First = rollup {
            return Ok(values[0].clone());
        } else if let Rollup::Last = rollup {
            return Ok(values[values.len() - 1].clone());
        } else if let Rollup::Slope = rollup {
            if values.len() < 2 {
                return Ok(AnyValue::Float32(0.0));
            }

            let slope = values
                .iter()
                .filter_map(|v| v.extract::<f32>())
                .collect::<Slope<f32>>();

            return Ok(AnyValue::Float32(slope.value().unwrap_or(0.0)));
        } else if let Rollup::Quantile(quantile) = rollup {
            quantile.clear();
            for v in values.iter().filter_map(|v| v.extract::<f32>()) {
                if v.is_finite() {
                    quantile.add(v);
                }
            }

            return Ok(quantile
                .value()
                .map(AnyValue::Float32)
                .unwrap_or(AnyValue::Null));
        }

        let stats = values
            .iter()
            .filter_map(|val| val.extract::<f32>())
            .collect::<Statistic>();

        let result = match rollup {
            Rollup::Mean => AnyValue::Float32(stats.mean()),
            Rollup::StdDev => AnyValue::Float32(stats.std_dev().unwrap()),
            Rollup::Min => AnyValue::Float32(stats.min()),
            Rollup::Max => AnyValue::Float32(stats.max()),
            Rollup::Sum => AnyValue::Float32(stats.sum()),
            Rollup::Count => AnyValue::UInt64(stats.count() as u64),
            _ => AnyValue::Null,
        };

        Ok(result.cast(&dtype).unwrap_or(AnyValue::Null))
    }
}

impl<T> Evaluate<T> for AggExpr
where
    T: ExprSelector,
{
    fn eval<'a>(&'a mut self, metrics: &T) -> ExprResult<'a> {
        let child_output = self.child.eval(metrics)?;
        let dtype = child_output.dtype();

        if let Some(buffer) = &mut self.buffer {
            if child_output.is_nested() {
                radiate_bail!(Expr: "AggExpr with rolling window does not support nested values");
            }

            buffer.push(child_output.into_static());
            return Self::compute_rollup(buffer.values(), &mut self.rollup, dtype);
        }

        match child_output {
            AnyValue::Slice(values) => {
                let elem_dtype = if let DataType::List(inner) = dtype {
                    *inner
                } else {
                    dtype
                };
                Self::compute_rollup(values, &mut self.rollup, elem_dtype)
            }
            AnyValue::Vector(values) => {
                let elem_dtype = if let DataType::List(inner) = dtype {
                    *inner
                } else {
                    dtype
                };
                Self::compute_rollup(&values, &mut self.rollup, elem_dtype)
            }
            _ => match self.rollup {
                Rollup::Count => Ok(AnyValue::UInt64(1)),
                Rollup::Unique => Ok(AnyValue::Vector(vec![child_output])),
                Rollup::Quantile(ref mut q) => {
                    if let Some(v) = child_output.extract::<f32>() {
                        if v.is_finite() {
                            q.add(v);
                        }
                    } else {
                        return Ok(AnyValue::Null);
                    }

                    Ok(q.value().map(AnyValue::Float32).unwrap_or(AnyValue::Null))
                }
                _ => Ok(child_output),
            },
        }
    }
}