use std::collections::HashMap;
use std::fmt::Debug;
use std::hash::Hash;
use crate::core::error::Result;
use crate::na::NA;
use crate::series::{NASeries, Series};
pub use crate::series::categorical::{
Categorical as LegacyCategorical, CategoricalOrder as LegacyCategoricalOrder,
StringCategorical as LegacyStringCategorical,
};
/// Whether the categories of a categorical column are treated as ordered.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CategoricalOrder {
    /// The categories carry no order.
    Unordered,
    /// The categories are ordered.
    Ordered,
}
/// A categorical column: every element is stored as an `i32` code that
/// indexes into a table of categories.
#[derive(Debug, Clone)]
pub struct Categorical<T: Debug + Clone + Eq + Hash> {
    /// Type marker for `T`.
    /// NOTE(review): redundant, since the fields below already mention `T`.
    _phantom: std::marker::PhantomData<T>,
    /// Category table; a non-negative code is an index into this list.
    categories_list: Vec<T>,
    /// Raw input values. Left empty by `new_compact`.
    values: Vec<T>,
    /// One code per element; a negative code means "no category".
    codes: Vec<i32>,
    /// `true` when the categories are ordered.
    ordered_flag: bool,
    /// Reverse lookup from a category value to its code.
    category_to_code: HashMap<T, i32>,
}
impl<T> Categorical<T>
where
    T: Debug + Clone + Eq + Hash,
{
    /// Builds a categorical from raw `values`.
    ///
    /// With `categories == None`, the categories are the distinct values in
    /// first-appearance order. With an explicit list, that list seeds the
    /// category table. Any value not present in it is appended as a new
    /// category; it is not treated as missing. As a result, every value
    /// receives a non-negative code.
    ///
    /// # Errors
    /// Never fails at present; the `Result` is part of the public signature.
    pub fn new(values: Vec<T>, categories: Option<Vec<T>>, ordered: bool) -> Result<Self> {
        let mut categories_list: Vec<T> = categories.unwrap_or_default();
        let mut category_to_code: HashMap<T, i32> =
            HashMap::with_capacity(categories_list.len());
        // A duplicated entry in an explicit list maps to its last position
        // because later inserts overwrite earlier ones.
        for (i, cat) in categories_list.iter().enumerate() {
            category_to_code.insert(cat.clone(), i as i32);
        }
        // One pass both discovers unseen categories (in first-appearance
        // order) and assigns codes. The hash lookup replaces an O(n^2)
        // `Vec::contains` scan and produces the same categories and codes.
        let mut codes = Vec::with_capacity(values.len());
        for v in &values {
            let code = if let Some(&known) = category_to_code.get(v) {
                known
            } else {
                let fresh = categories_list.len() as i32;
                categories_list.push(v.clone());
                category_to_code.insert(v.clone(), fresh);
                fresh
            };
            codes.push(code);
        }
        Ok(Self {
            _phantom: std::marker::PhantomData,
            categories_list,
            values,
            codes,
            ordered_flag: ordered,
            category_to_code,
        })
    }

    /// Like [`Categorical::new`], but drops the raw `values` after encoding,
    /// so only the codes and the category table are kept.
    ///
    /// Element accessors (`len`, `get`, `to_series`, ...) reconstruct data
    /// from the codes for such instances.
    pub fn new_compact(values: Vec<T>, categories: Option<Vec<T>>, ordered: bool) -> Result<Self> {
        let mut cat = Self::new(values, categories, ordered)?;
        cat.values = Vec::new();
        Ok(cat)
    }

    /// Approximate size in bytes of the encoded form: the code buffer plus
    /// one inline `T` per category.
    ///
    /// The following are not counted:
    /// - the retained `values`;
    /// - the lookup map;
    /// - spare `Vec` capacity;
    /// - heap data owned by `T` (e.g. `String` contents).
    pub fn memory_usage_bytes(&self) -> usize {
        let codes_size = self.codes.len() * std::mem::size_of::<i32>();
        let categories_overhead = self.categories_list.len() * std::mem::size_of::<T>();
        codes_size + categories_overhead
    }

    /// Maps every code back to its category.
    ///
    /// Negative or out-of-range codes yield `None`.
    pub fn decode(&self) -> Vec<Option<T>> {
        self.codes
            .iter()
            .map(|&code| self.get_category(code).cloned())
            .collect()
    }

    /// Encodes `values` against this categorical's categories.
    ///
    /// Unknown values map to `-1`.
    pub fn encode(&self, values: &[T]) -> Vec<i32> {
        values
            .iter()
            .map(|v| self.category_to_code.get(v).copied().unwrap_or(-1))
            .collect()
    }

    /// Number of entries in the category table.
    pub fn num_categories(&self) -> usize {
        self.categories_list.len()
    }

    /// Returns `true` if `value` is one of the categories.
    pub fn contains_category(&self, value: &T) -> bool {
        self.category_to_code.contains_key(value)
    }

    /// Code assigned to `value`, if it is a category.
    pub fn get_code(&self, value: &T) -> Option<i32> {
        self.category_to_code.get(value).copied()
    }

    /// Category for `code`.
    ///
    /// Returns `None` for negative or out-of-range codes.
    pub fn get_category(&self, code: i32) -> Option<&T> {
        if code < 0 {
            None
        } else {
            self.categories_list.get(code as usize)
        }
    }

    /// Drops categories that no code refers to and renumbers the remaining
    /// codes so they stay dense and keep pointing at the same category.
    pub fn remove_unused_categories(&mut self) -> Result<()> {
        let mut in_use = vec![false; self.categories_list.len()];
        for &code in &self.codes {
            if code >= 0 {
                if let Some(flag) = in_use.get_mut(code as usize) {
                    *flag = true;
                }
            }
        }
        // old_to_new[old] is the new code for a kept category, None if dropped.
        let mut old_to_new: Vec<Option<i32>> = Vec::with_capacity(in_use.len());
        let mut kept: Vec<T> = Vec::new();
        for (cat, &used) in self.categories_list.iter().zip(in_use.iter()) {
            if used {
                old_to_new.push(Some(kept.len() as i32));
                kept.push(cat.clone());
            } else {
                old_to_new.push(None);
            }
        }
        for code in &mut self.codes {
            if *code >= 0 {
                *code = old_to_new
                    .get(*code as usize)
                    .copied()
                    .flatten()
                    .unwrap_or(-1);
            }
        }
        self.categories_list = kept;
        self.rebuild_code_map();
        Ok(())
    }

    /// Returns owned copies of the codes and the category table.
    pub fn factorize(&self) -> (Vec<i32>, Vec<T>) {
        (self.codes.clone(), self.categories_list.clone())
    }

    /// Builds a categorical from possibly-missing values.
    ///
    /// NOTE(review): NA entries are dropped before encoding. The result is
    /// therefore shorter than `values`, and original positions are not
    /// preserved.
    pub fn from_na_vec(
        values: Vec<NA<T>>,
        categories: Option<Vec<T>>,
        ordered: Option<CategoricalOrder>,
    ) -> Result<Self> {
        let non_na_values: Vec<T> = values.iter().filter_map(|v| v.value().cloned()).collect();
        Self::new(
            non_na_values,
            categories,
            matches!(ordered, Some(CategoricalOrder::Ordered)),
        )
    }

    /// The category table, in code order.
    pub fn categories(&self) -> &Vec<T> {
        &self.categories_list
    }

    /// Number of elements (one per code).
    ///
    /// Also correct for compact instances.
    pub fn len(&self) -> usize {
        self.codes.len()
    }

    /// Returns `true` if there are no elements.
    pub fn is_empty(&self) -> bool {
        self.codes.is_empty()
    }

    /// The per-element codes.
    pub fn codes(&self) -> &Vec<i32> {
        &self.codes
    }

    /// Whether the categories are ordered.
    pub fn ordered(&self) -> CategoricalOrder {
        if self.ordered_flag {
            CategoricalOrder::Ordered
        } else {
            CategoricalOrder::Unordered
        }
    }

    /// Sets whether the categories are ordered.
    pub fn set_ordered(&mut self, order: CategoricalOrder) {
        self.ordered_flag = matches!(order, CategoricalOrder::Ordered);
    }

    /// Category of the element at `index`.
    ///
    /// Returns `None` when `index` is out of bounds or the element has no
    /// category. The value is resolved through the codes, so this also works
    /// for compact instances.
    pub fn get(&self, index: usize) -> Option<&T> {
        self.codes
            .get(index)
            .and_then(|&code| self.get_category(code))
    }

    /// Converts to a plain [`Series`].
    ///
    /// Uses the retained values, or values decoded from the codes for
    /// compact instances.
    pub fn to_series(&self, name: Option<String>) -> Result<Series<T>>
    where
        T: 'static + Clone + Debug + Send + Sync,
    {
        Series::new(self.materialized_values(), name)
    }

    /// Replaces the category table with `new_categories`. Every code is
    /// re-pointed so it still refers to the same category value.
    ///
    /// Codes whose category is absent from `new_categories` become `-1`.
    /// Categories that appear only in `new_categories` are added as unused.
    /// NOTE(review): pandas raises when the two sets differ; consider
    /// returning an error here once a suitable error variant is chosen.
    pub fn reorder_categories(&mut self, new_categories: Vec<T>) -> Result<()> {
        let old_categories = std::mem::replace(&mut self.categories_list, new_categories);
        self.rebuild_code_map();
        self.remap_codes(&old_categories);
        Ok(())
    }

    /// Appends categories that are not already present. Duplicates within
    /// `new_categories` are added once. Existing codes are unaffected.
    pub fn add_categories(&mut self, new_categories: Vec<T>) -> Result<()> {
        for cat in new_categories {
            if !self.category_to_code.contains_key(&cat) {
                let code = self.categories_list.len() as i32;
                self.category_to_code.insert(cat.clone(), code);
                self.categories_list.push(cat);
            }
        }
        Ok(())
    }

    /// Removes the given categories and renumbers the codes of the rest.
    ///
    /// Elements whose category was removed get code `-1`. The retained raw
    /// `values`, if any, are not modified.
    pub fn remove_categories(&mut self, categories_to_remove: &[T]) -> Result<()> {
        let doomed: std::collections::HashSet<&T> = categories_to_remove.iter().collect();
        let old_categories = std::mem::take(&mut self.categories_list);
        self.categories_list = old_categories
            .iter()
            .filter(|cat| !doomed.contains(cat))
            .cloned()
            .collect();
        self.rebuild_code_map();
        self.remap_codes(&old_categories);
        Ok(())
    }

    /// Counts elements per category.
    ///
    /// The result has one entry per category, in [`Categorical::categories`]
    /// order, so `counts[i]` belongs to `categories()[i]`. Unused categories
    /// count as `0`, and elements without a category (code `-1`) are not
    /// counted. Counting works from the codes, so the order is deterministic
    /// and compact instances are handled.
    pub fn value_counts(&self) -> Result<Series<usize>> {
        let mut counts = vec![0usize; self.categories_list.len()];
        for &code in &self.codes {
            if code >= 0 {
                if let Some(slot) = counts.get_mut(code as usize) {
                    *slot += 1;
                }
            }
        }
        Series::new(counts, Some("count".to_string()))
    }

    /// Every element wrapped as [`NA::Value`].
    ///
    /// Compact instances are decoded from the codes; elements without a
    /// category are skipped.
    pub fn to_na_vec(&self) -> Vec<NA<T>>
    where
        T: Clone,
    {
        self.materialized_values()
            .into_iter()
            .map(NA::Value)
            .collect()
    }

    /// Converts to an [`NASeries`] built from [`Categorical::to_na_vec`].
    pub fn to_na_series(&self, name: Option<String>) -> Result<NASeries<T>>
    where
        T: 'static + Clone + Debug + Send + Sync,
    {
        NASeries::new(self.to_na_vec(), name)
    }

    /// Keeps this categorical's values and widens its categories with those of
    /// `other` that are not yet present (appended in `other`'s order).
    pub fn union(&self, other: &Self) -> Result<Self> {
        let mut seen: std::collections::HashSet<&T> = self.categories_list.iter().collect();
        let mut all_categories = self.categories_list.clone();
        for cat in other.categories() {
            if seen.insert(cat) {
                all_categories.push(cat.clone());
            }
        }
        Self::new(self.materialized_values(), Some(all_categories), self.ordered_flag)
    }

    /// Restricts to the categories shared with `other`. Values outside that
    /// set are filtered out.
    pub fn intersection(&self, other: &Self) -> Result<Self> {
        let other_set: std::collections::HashSet<&T> = other.categories().iter().collect();
        let common_categories: Vec<T> = self
            .categories_list
            .iter()
            .filter(|cat| other_set.contains(cat))
            .cloned()
            .collect();
        let filtered_values: Vec<T> = {
            let keep: std::collections::HashSet<&T> = common_categories.iter().collect();
            self.materialized_values()
                .into_iter()
                .filter(|v| keep.contains(v))
                .collect()
        };
        Self::new(filtered_values, Some(common_categories), self.ordered_flag)
    }

    /// Restricts to the categories not present in `other`. Values outside
    /// that set are filtered out.
    pub fn difference(&self, other: &Self) -> Result<Self> {
        let other_set: std::collections::HashSet<&T> = other.categories().iter().collect();
        let diff_categories: Vec<T> = self
            .categories_list
            .iter()
            .filter(|cat| !other_set.contains(cat))
            .cloned()
            .collect();
        let filtered_values: Vec<T> = {
            let keep: std::collections::HashSet<&T> = diff_categories.iter().collect();
            self.materialized_values()
                .into_iter()
                .filter(|v| keep.contains(v))
                .collect()
        };
        Self::new(filtered_values, Some(diff_categories), self.ordered_flag)
    }

    /// Rebuilds `category_to_code` from `categories_list` so the two agree.
    ///
    /// As in `new`, a duplicated category maps to its last position.
    fn rebuild_code_map(&mut self) {
        self.category_to_code.clear();
        for (i, cat) in self.categories_list.iter().enumerate() {
            self.category_to_code.insert(cat.clone(), i as i32);
        }
    }

    /// Re-points each non-negative code at the same category value in the
    /// current table. `old_categories` is the table the codes were built
    /// against. Codes whose category no longer exists become `-1`.
    ///
    /// `category_to_code` must already reflect the current table.
    fn remap_codes(&mut self, old_categories: &[T]) {
        let lookup = &self.category_to_code;
        for code in &mut self.codes {
            if *code >= 0 {
                *code = old_categories
                    .get(*code as usize)
                    .and_then(|cat| lookup.get(cat))
                    .copied()
                    .unwrap_or(-1);
            }
        }
    }

    /// Owned logical values.
    ///
    /// Uses the retained `values` when they are stored. For compact instances
    /// (values dropped) the values are decoded from the codes, skipping
    /// elements without a category because there is no `T` to emit for them.
    fn materialized_values(&self) -> Vec<T> {
        if self.values.len() == self.codes.len() {
            self.values.clone()
        } else {
            self.codes
                .iter()
                .filter_map(|&code| self.get_category(code).cloned())
                .collect()
        }
    }
}
/// A [`Categorical`] whose categories are owned `String`s.
pub type StringCategorical = Categorical<String>;