use pandrs::error::{PandRSError, Result};
use pandrs::series::{CategoricalOrder, StringCategorical};
use pandrs::DataFrame;
use std::collections::HashMap;
/// Categorical-column operations for pandrs `DataFrame`s.
///
/// Every method returns `pandrs::error::Result`. Its error type belongs to pandrs and
/// cannot be boxed here without changing these signatures, so each declaration opts out
/// of clippy's `result_large_err` lint.
pub trait CategoricalExt {
    /// Returns a new frame in which `column_name` is stored as a categorical column.
    ///
    /// * `categories` - optional explicit category list.
    /// * `ordered` - requested ordering, if any.
    #[allow(clippy::result_large_err)]
    fn astype_categorical(
        &self,
        column_name: &str,
        categories: Option<Vec<String>>,
        ordered: Option<CategoricalOrder>,
    ) -> Result<Self>
    where
        Self: Sized;

    /// Inserts `categorical` into the frame as a column called `column_name`.
    #[allow(clippy::result_large_err)]
    fn add_categorical_column(
        &mut self,
        column_name: String,
        categorical: StringCategorical,
    ) -> Result<()>;

    /// Re-encodes `column_name` as a categorical with the given `order`.
    #[allow(clippy::result_large_err)]
    fn set_categorical_ordered(
        &mut self,
        column_name: &str,
        order: CategoricalOrder,
    ) -> Result<()>;

    /// Groups rows by the values of `column_names`, gathers the matching
    /// `value_column` entries per group, and reduces each group with `aggregator`.
    #[allow(clippy::result_large_err)]
    fn get_categorical_aggregates(
        &self,
        column_names: &[&str],
        value_column: &str,
        aggregator: impl Fn(Vec<String>) -> Result<usize>,
    ) -> Result<HashMap<Vec<String>, usize>>;

    /// Replaces the category list of `column_name` with `new_categories`.
    #[allow(clippy::result_large_err)]
    fn reorder_categories(&mut self, column_name: &str, new_categories: Vec<String>) -> Result<()>;

    /// Adds `categories` to the category list of `column_name`.
    #[allow(clippy::result_large_err)]
    fn add_categories(&mut self, column_name: &str, categories: Vec<String>) -> Result<()>;

    /// Removes `categories` from the category list of `column_name`.
    #[allow(clippy::result_large_err)]
    fn remove_categories(&mut self, column_name: &str, categories: &[String]) -> Result<()>;
}
impl CategoricalExt for DataFrame {
fn astype_categorical(
&self,
column_name: &str,
_categories: Option<Vec<String>>,
ordered: Option<CategoricalOrder>,
) -> Result<Self> {
let mut result = self.clone();
let order_bool = matches!(ordered, Some(CategoricalOrder::Ordered));
let values = result.get_column_string_values(column_name)?;
let cat = StringCategorical::new(values, None, order_bool)?;
let mut new_df = DataFrame::new();
for col in result.column_names() {
if col != column_name {
let series = result.get_column_string_values(&col)?;
new_df.add_column(
col.to_string(),
pandrs::series::Series::new(series, Some(col.to_string()))?,
)?;
}
}
let series = cat.to_series(Some(column_name.to_string()))?;
new_df.add_column(column_name.to_string(), series)?;
result = new_df;
Ok(result)
}
fn add_categorical_column(
&mut self,
column_name: String,
categorical: StringCategorical,
) -> Result<()> {
let series = categorical.to_series(Some(column_name.clone()))?;
self.add_column(column_name, series)
}
fn set_categorical_ordered(
&mut self,
column_name: &str,
order: CategoricalOrder,
) -> Result<()> {
if !self.contains_column(column_name) {
return Err(pandrs::error::PandRSError::ColumnNotFound(
column_name.to_string(),
));
}
let values = self.get_column_string_values(column_name)?;
let order_bool = match order {
CategoricalOrder::Ordered => true,
CategoricalOrder::Unordered => false,
};
let cat = StringCategorical::new(values, None, order_bool)?;
let mut new_df = DataFrame::new();
for col in self.column_names() {
if col != column_name {
let series = self.get_column_string_values(&col)?;
new_df.add_column(
col.to_string(),
pandrs::series::Series::new(series, Some(col.to_string()))?,
)?;
}
}
let series = cat.to_series(Some(column_name.to_string()))?;
new_df.add_column(column_name.to_string(), series)?;
*self = new_df;
Ok(())
}
fn get_categorical_aggregates(
&self,
column_names: &[&str],
value_column: &str,
aggregator: impl Fn(Vec<String>) -> Result<usize>,
) -> Result<HashMap<Vec<String>, usize>> {
for &col_name in column_names {
if !self.contains_column(col_name) {
return Err(pandrs::error::PandRSError::ColumnNotFound(
col_name.to_string(),
));
}
}
if !self.contains_column(value_column) {
return Err(pandrs::error::PandRSError::ColumnNotFound(
value_column.to_string(),
));
}
let row_count = self.row_count();
let mut grouped_data: HashMap<Vec<String>, Vec<String>> = HashMap::new();
for row_idx in 0..row_count {
let mut key = Vec::new();
for &col_name in column_names {
let values = self.get_column_string_values(col_name)?;
if row_idx < values.len() {
key.push(values[row_idx].clone());
}
}
let values = self.get_column_string_values(value_column)?;
if row_idx < values.len() {
grouped_data
.entry(key)
.or_default()
.push(values[row_idx].clone());
}
}
let mut result = HashMap::new();
for (key, values) in grouped_data {
let agg_value = aggregator(values)?;
result.insert(key, agg_value);
}
Ok(result)
}
#[allow(clippy::result_large_err)]
#[allow(clippy::result_large_err)]
fn reorder_categories(&mut self, column_name: &str, new_categories: Vec<String>) -> Result<()> {
if !self.contains_column(column_name) {
return Err(PandRSError::ColumnNotFound(column_name.to_string()));
}
let values = self.get_column_string_values(column_name)?;
let current_values: std::collections::HashSet<String> = values.iter().cloned().collect();
let new_cats_set: std::collections::HashSet<String> =
new_categories.iter().cloned().collect();
for value in ¤t_values {
if !new_cats_set.contains(value) {
return Err(PandRSError::InvalidValue(format!(
"Value '{value}' exists in column but not in new categories"
)));
}
}
let cat = StringCategorical::new(values, Some(new_categories), true)?;
let series = cat.to_series(Some(column_name.to_string()))?;
let mut new_df = DataFrame::new();
for col in self.column_names() {
if col != column_name {
let col_values = self.get_column_string_values(&col)?;
new_df.add_column(
col.to_string(),
pandrs::series::Series::new(col_values, Some(col.to_string()))?,
)?;
}
}
new_df.add_column(column_name.to_string(), series)?;
*self = new_df;
Ok(())
}
#[allow(clippy::result_large_err)]
#[allow(clippy::result_large_err)]
fn add_categories(&mut self, column_name: &str, categories: Vec<String>) -> Result<()> {
if !self.contains_column(column_name) {
return Err(PandRSError::ColumnNotFound(column_name.to_string()));
}
let values = self.get_column_string_values(column_name)?;
let mut current_categories: Vec<String> = values.to_vec();
current_categories.sort();
current_categories.dedup();
let mut all_categories = current_categories;
for cat in categories {
if !all_categories.contains(&cat) {
all_categories.push(cat);
}
}
let cat = StringCategorical::new(values, Some(all_categories), true)?;
let series = cat.to_series(Some(column_name.to_string()))?;
let mut new_df = DataFrame::new();
for col in self.column_names() {
if col != column_name {
let col_values = self.get_column_string_values(&col)?;
new_df.add_column(
col.to_string(),
pandrs::series::Series::new(col_values, Some(col.to_string()))?,
)?;
}
}
new_df.add_column(column_name.to_string(), series)?;
*self = new_df;
Ok(())
}
#[allow(clippy::result_large_err)]
#[allow(clippy::result_large_err)]
fn remove_categories(&mut self, column_name: &str, categories: &[String]) -> Result<()> {
if !self.contains_column(column_name) {
return Err(PandRSError::ColumnNotFound(column_name.to_string()));
}
let values = self.get_column_string_values(column_name)?;
let mut current_categories: Vec<String> = values.to_vec();
current_categories.sort();
current_categories.dedup();
let remove_set: std::collections::HashSet<&String> = categories.iter().collect();
let filtered_categories: Vec<String> = current_categories
.into_iter()
.filter(|cat| !remove_set.contains(cat))
.collect();
for value in &values {
if remove_set.contains(&value) {
return Err(PandRSError::InvalidValue(format!(
"Cannot remove category '{value}' as it has data"
)));
}
}
let cat = StringCategorical::new(values, Some(filtered_categories), true)?;
let series = cat.to_series(Some(column_name.to_string()))?;
let mut new_df = DataFrame::new();
for col in self.column_names() {
if col != column_name {
let col_values = self.get_column_string_values(&col)?;
new_df.add_column(
col.to_string(),
pandrs::series::Series::new(col_values, Some(col.to_string()))?,
)?;
}
}
new_df.add_column(column_name.to_string(), series)?;
*self = new_df;
Ok(())
}
}