use crate::{Error, Result, Value};
use polars::prelude::*;
mod cumulative;
mod ewma;
mod group_by;
mod pivot;
mod rolling;
mod unpivot;
pub use cumulative::cumulative_agg;
pub use ewma::ewma;
pub use group_by::{group_by, group_by_agg};
pub use pivot::pivot;
pub use rolling::{rolling_agg, rolling_std, WindowFunction};
pub use unpivot::unpivot;
pub(super) fn any_value_to_value(any_val: &AnyValue) -> Result<Value> {
use serde_json::Value as JsonValue;
let json_val = match any_val {
AnyValue::Null => JsonValue::Null,
AnyValue::Boolean(b) => JsonValue::Bool(*b),
AnyValue::Int8(i) => JsonValue::Number(serde_json::Number::from(*i)),
AnyValue::Int16(i) => JsonValue::Number(serde_json::Number::from(*i)),
AnyValue::Int32(i) => JsonValue::Number(serde_json::Number::from(*i)),
AnyValue::Int64(i) => JsonValue::Number(serde_json::Number::from(*i)),
AnyValue::UInt8(i) => JsonValue::Number(serde_json::Number::from(*i)),
AnyValue::UInt16(i) => JsonValue::Number(serde_json::Number::from(*i)),
AnyValue::UInt32(i) => JsonValue::Number(serde_json::Number::from(*i)),
AnyValue::UInt64(i) => JsonValue::Number(serde_json::Number::from(*i)),
AnyValue::Float32(f) => JsonValue::Number(
serde_json::Number::from_f64(f64::from(*f))
.ok_or_else(|| Error::operation("Invalid float"))?,
),
AnyValue::Float64(f) => JsonValue::Number(
serde_json::Number::from_f64(*f).ok_or_else(|| Error::operation("Invalid float"))?,
),
AnyValue::String(s) => JsonValue::String((*s).to_string()),
_ => return Err(Error::operation("Unsupported AnyValue type")),
};
Ok(Value::from_json(json_val))
}
pub(super) fn df_to_array(df: &DataFrame) -> Result<Vec<Value>> {
let columns = df.get_column_names();
let mut result = Vec::with_capacity(df.height());
for row_idx in 0..df.height() {
let mut obj = std::collections::HashMap::new();
for col_name in &columns {
let series = df.column(col_name).map_err(Error::from)?;
let any_val = series.get(row_idx).map_err(Error::from)?;
let value = any_value_to_value(&any_val)?;
obj.insert(col_name.to_string(), value);
}
result.push(Value::Object(obj));
}
Ok(result)
}
/// An aggregation that can be turned into a Polars expression via
/// `to_polars_expr`.
///
/// Except for `Count`, the first `String` field is the name of the input
/// column; the output column is named `<column>_<suffix>` (see
/// `output_column_name`).
#[derive(Debug, Clone)]
pub enum AggregationFunction {
/// Number of rows (`len()`); the output column is always `count`.
Count,
/// Sum of the column; output suffix `_sum`.
Sum(String),
/// Mean of the column; output suffix `_mean`.
Mean(String),
/// Median of the column; output suffix `_median`.
Median(String),
/// Minimum of the column; output suffix `_min`.
Min(String),
/// Maximum of the column; output suffix `_max`.
Max(String),
/// Standard deviation, built with `std(1)` (i.e. `ddof = 1`); output suffix `_std`.
Std(String),
/// Variance, built with `var(1)` (i.e. `ddof = 1`); output suffix `_var`.
Var(String),
/// First value of the column; output suffix `_first`.
First(String),
/// Last value of the column; output suffix `_last`.
Last(String),
/// The column itself with no reducing call applied; output suffix `_list`.
List(String),
/// Number of unique values (`n_unique()`); output suffix `_nunique`.
CountUnique(String),
/// Column name plus an optional separator; output suffix `_concat`.
/// NOTE(review): `to_polars_expr` does not apply the separator to the
/// expression it builds — confirm whether real string joining is intended.
StringConcat(String, Option<String>), }
impl AggregationFunction {
pub fn to_polars_expr(&self) -> Result<Expr> {
match self {
AggregationFunction::Count => Ok(len().alias("count")),
AggregationFunction::Sum(col_name) => {
Ok(col(col_name).sum().alias(format!("{col_name}_sum")))
}
AggregationFunction::Mean(col_name) => {
Ok(col(col_name).mean().alias(format!("{col_name}_mean")))
}
AggregationFunction::Median(col_name) => {
Ok(col(col_name).median().alias(format!("{col_name}_median")))
}
AggregationFunction::Min(col_name) => {
Ok(col(col_name).min().alias(format!("{col_name}_min")))
}
AggregationFunction::Max(col_name) => {
Ok(col(col_name).max().alias(format!("{col_name}_max")))
}
AggregationFunction::Std(col_name) => {
Ok(col(col_name).std(1).alias(format!("{col_name}_std")))
}
AggregationFunction::Var(col_name) => {
Ok(col(col_name).var(1).alias(format!("{col_name}_var")))
}
AggregationFunction::First(col_name) => {
Ok(col(col_name).first().alias(format!("{col_name}_first")))
}
AggregationFunction::Last(col_name) => {
Ok(col(col_name).last().alias(format!("{col_name}_last")))
}
AggregationFunction::List(col_name) => {
Ok(col(col_name).alias(format!("{col_name}_list")))
}
AggregationFunction::CountUnique(col_name) => Ok(col(col_name)
.n_unique()
.alias(format!("{col_name}_nunique"))),
AggregationFunction::StringConcat(col_name, separator) => {
let _sep = separator.as_deref().unwrap_or(",");
Ok(col(col_name).alias(format!("{col_name}_concat")))
}
}
}
#[must_use]
pub fn output_column_name(&self) -> String {
match self {
AggregationFunction::Count => "count".to_string(),
AggregationFunction::Sum(col_name) => format!("{col_name}_sum"),
AggregationFunction::Mean(col_name) => format!("{col_name}_mean"),
AggregationFunction::Median(col_name) => format!("{col_name}_median"),
AggregationFunction::Min(col_name) => format!("{col_name}_min"),
AggregationFunction::Max(col_name) => format!("{col_name}_max"),
AggregationFunction::Std(col_name) => format!("{col_name}_std"),
AggregationFunction::Var(col_name) => format!("{col_name}_var"),
AggregationFunction::First(col_name) => format!("{col_name}_first"),
AggregationFunction::Last(col_name) => format!("{col_name}_last"),
AggregationFunction::List(col_name) => format!("{col_name}_list"),
AggregationFunction::CountUnique(col_name) => format!("{col_name}_nunique"),
AggregationFunction::StringConcat(col_name, _) => format!("{col_name}_concat"),
}
}
}
/// Orders two [`Value`]s for sorting.
///
/// * `Null` sorts before every non-null value; two nulls are equal.
/// * Two booleans, two integers, or two strings use their natural ordering.
/// * Floats, and integer/float pairs (the integer is cast to `f64`), are
///   compared numerically.
/// * Any other combination is compared by the values' `to_string()` output.
///
/// When a float comparison involves NaN, `partial_cmp` has no answer; the
/// comparison then falls back to `f64::total_cmp` so NaN gets a fixed
/// position instead of comparing equal to every number, which would not be
/// a consistent order for sorting. Results for non-NaN inputs are unchanged.
pub(super) fn compare_values_for_ordering(a: &Value, b: &Value) -> std::cmp::Ordering {
    use std::cmp::Ordering;

    /// Numeric comparison that stays consistent when NaN is present.
    fn cmp_floats(lhs: f64, rhs: f64) -> Ordering {
        lhs.partial_cmp(&rhs)
            .unwrap_or_else(|| lhs.total_cmp(&rhs))
    }

    match (a, b) {
        (Value::Null, Value::Null) => Ordering::Equal,
        (Value::Null, _) => Ordering::Less,
        (_, Value::Null) => Ordering::Greater,
        (Value::Bool(lhs), Value::Bool(rhs)) => lhs.cmp(rhs),
        (Value::Int(lhs), Value::Int(rhs)) => lhs.cmp(rhs),
        (Value::Float(lhs), Value::Float(rhs)) => cmp_floats(*lhs, *rhs),
        (Value::String(lhs), Value::String(rhs)) => lhs.cmp(rhs),
        // Mixed numeric pairs are compared as f64; very large integers may
        // lose precision in the cast, which is accepted here.
        #[allow(clippy::cast_precision_loss)]
        (Value::Int(lhs), Value::Float(rhs)) => cmp_floats(*lhs as f64, *rhs),
        #[allow(clippy::cast_precision_loss)]
        (Value::Float(lhs), Value::Int(rhs)) => cmp_floats(*lhs, *rhs as f64),
        // Remaining (mixed or non-scalar) combinations: textual comparison.
        _ => a.to_string().cmp(&b.to_string()),
    }
}
#[cfg(test)]
mod tests;