use crate::error::ChartonError;
use polars::datatypes::DataType;
use polars::prelude::*;
use std::collections::HashMap;
use std::io::Cursor;
/// Conversion of a tabular input into an eager [`DataFrameSource`].
///
/// Implemented below for borrowed frames, lazy frames, and raw Parquet
/// bytes (`&[u8]`, `&Vec<u8>`, `Vec<u8>`).
pub trait IntoChartonSource {
/// Consume `self` and produce an eager `DataFrameSource`.
///
/// # Errors
/// Returns [`ChartonError`] when collection or Parquet decoding fails.
fn into_source(self) -> Result<DataFrameSource, ChartonError>;
}
/// Eager, owned tabular data backing a chart.
#[derive(Clone)]
pub struct DataFrameSource {
// Materialized polars frame; crate-internal access only.
pub(crate) df: DataFrame,
}
impl DataFrameSource {
    /// Wrap an already-materialized `DataFrame`.
    pub(crate) fn new(df: DataFrame) -> Self {
        Self { df }
    }

    /// Decode an in-memory Parquet payload into a `DataFrameSource`.
    ///
    /// # Errors
    /// Returns [`ChartonError`] when `bytes` is not valid Parquet.
    pub fn from_parquet_bytes(bytes: &[u8]) -> Result<Self, ChartonError> {
        let df = ParquetReader::new(Cursor::new(bytes)).finish()?;
        Ok(Self::new(df))
    }

    /// Fetch a column by name as an owned `Series`.
    ///
    /// # Errors
    /// Returns [`ChartonError`] when no column named `name` exists.
    pub(crate) fn column(&self, name: &str) -> Result<Series, ChartonError> {
        Ok(self.df.column(name)?.as_materialized_series().clone())
    }
}
impl IntoChartonSource for &DataFrame {
    /// Clone the borrowed frame into an owned source.
    fn into_source(self) -> Result<DataFrameSource, ChartonError> {
        Ok(DataFrameSource::new(self.to_owned()))
    }
}
impl IntoChartonSource for &LazyFrame {
    /// Materialize the lazy query plan; collection errors are propagated.
    fn into_source(self) -> Result<DataFrameSource, ChartonError> {
        Ok(DataFrameSource::new(self.clone().collect()?))
    }
}
/// Byte slices are interpreted as in-memory Parquet payloads.
impl IntoChartonSource for &[u8] {
fn into_source(self) -> Result<DataFrameSource, ChartonError> {
DataFrameSource::from_parquet_bytes(self)
}
}
/// Borrowed byte buffers are interpreted as in-memory Parquet payloads.
impl IntoChartonSource for &Vec<u8> {
    /// Delegates to [`DataFrameSource::from_parquet_bytes`] so every byte
    /// source shares one decoding path (previously this impl re-implemented
    /// the Parquet read inline, duplicating `from_parquet_bytes`).
    fn into_source(self) -> Result<DataFrameSource, ChartonError> {
        DataFrameSource::from_parquet_bytes(self.as_slice())
    }
}
/// Owned byte buffers are interpreted as in-memory Parquet payloads.
impl IntoChartonSource for Vec<u8> {
    fn into_source(self) -> Result<DataFrameSource, ChartonError> {
        DataFrameSource::from_parquet_bytes(&self)
    }
}
/// Cast every numeric column of the source to `Float64`, leaving all other
/// columns untouched.
///
/// # Errors
/// Returns [`ChartonError`] if a cast fails or the rebuilt frame is invalid.
pub(crate) fn convert_numeric_types(
    df_source: DataFrameSource,
) -> Result<DataFrameSource, ChartonError> {
    let columns = df_source.df.get_columns();
    // Preallocate: the output always has exactly one column per input column.
    let mut new_columns = Vec::with_capacity(columns.len());
    for col in columns {
        use polars::datatypes::DataType::*;
        match col.dtype() {
            // Already the target dtype — skip the redundant cast.
            Float64 => new_columns.push(col.clone()),
            UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Int128
            | Float32 => {
                new_columns.push(col.cast(&Float64)?);
            }
            // Non-numeric columns pass through unchanged.
            _ => new_columns.push(col.clone()),
        }
    }
    Ok(DataFrameSource::new(DataFrame::new(new_columns)?))
}
/// Coarse semantic category a column's dtype maps to, used when validating
/// a DataFrame's schema against a chart's expected encodings.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SemanticType {
// Quantitative numeric values.
Continuous,
// Categorical / nominal values (also the fallback for unknown dtypes).
Discrete,
// Dates, datetimes, and times.
Temporal,
}
/// Map a polars dtype to its coarse [`SemanticType`].
///
/// All numeric widths are treated as `Continuous`, matching the set of
/// dtypes that `convert_numeric_types` casts to `Float64`. (The original
/// omitted the 8/16-bit and 128-bit integers, misclassifying them as
/// `Discrete` even though the conversion pass treats them as numeric.)
pub(crate) fn interpret_semantic_type(dtype: &DataType) -> SemanticType {
    match dtype {
        DataType::Float32
        | DataType::Float64
        | DataType::Int8
        | DataType::Int16
        | DataType::Int32
        | DataType::Int64
        | DataType::Int128
        | DataType::UInt8
        | DataType::UInt16
        | DataType::UInt32
        | DataType::UInt64 => SemanticType::Continuous,
        DataType::Date | DataType::Datetime(_, _) | DataType::Time => SemanticType::Temporal,
        // Strings, categoricals, booleans, and anything unrecognized.
        _ => SemanticType::Discrete,
    }
}
/// Validate that `df` contains every required column and that each column's
/// dtype falls into one of the semantic categories allowed for it.
///
/// Takes `&DataFrame` — the original took `&mut DataFrame` but never mutated
/// the frame; callers passing `&mut df` still compile via `&mut T → &T`
/// coercion.
///
/// # Errors
/// - [`ChartonError::Encoding`] if a required column is missing.
/// - [`ChartonError::Data`] if a column's semantic category is not allowed.
pub(crate) fn check_schema(
    df: &DataFrame,
    required_columns: &[&str],
    expected_semantics: &HashMap<&str, Vec<SemanticType>>,
) -> Result<(), ChartonError> {
    let schema = df.schema();
    for &col_name in required_columns {
        let actual_dtype = schema.get(col_name).ok_or_else(|| {
            ChartonError::Encoding(format!("Column '{}' not found in DataFrame", col_name))
        })?;
        let actual_semantic = interpret_semantic_type(actual_dtype);
        // Columns without an entry in `expected_semantics` only need to exist.
        if let Some(allowed_semantics) = expected_semantics.get(col_name)
            && !allowed_semantics.contains(&actual_semantic)
        {
            return Err(ChartonError::Data(format!(
                "Column '{}' (Type: {:?}) is categorized as {:?}, but expected one of {:?}",
                col_name, actual_dtype, actual_semantic, allowed_semantics
            )));
        }
    }
    Ok(())
}
/// Aggregation applied when a chart groups rows; defaults to `Sum`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum AggregateOp {
#[default]
Sum,
Mean,
Median,
Min,
Max,
Count,
}
impl AggregateOp {
    /// Build the polars aggregation expression for `field` corresponding to
    /// this operation.
    ///
    /// NOTE(review): the `into_` prefix conventionally implies consuming
    /// `self`, but this takes `&self`; kept as-is to avoid breaking callers.
    pub fn into_expr(&self, field: &str) -> Expr {
        let target = col(field);
        match self {
            Self::Sum => target.sum(),
            Self::Mean => target.mean(),
            Self::Median => target.median(),
            Self::Min => target.min(),
            Self::Max => target.max(),
            Self::Count => target.count(),
        }
    }
}
impl From<&str> for AggregateOp {
    /// Parse a case-insensitive operation name ("mean"/"avg", "sum", "min",
    /// "max", "count"/"n", "median"). Unrecognized names silently fall back
    /// to `Sum`, the enum's default variant.
    fn from(s: &str) -> Self {
        let key = s.to_lowercase();
        match key.as_str() {
            "sum" => Self::Sum,
            "mean" | "avg" => Self::Mean,
            "median" => Self::Median,
            "min" => Self::Min,
            "max" => Self::Max,
            "count" | "n" => Self::Count,
            _ => Self::Sum,
        }
    }
}