use std::collections::HashMap;
use std::fmt::Debug;
use crate::core::error::{Error, Result};
use crate::dataframe::base::DataFrame;
use crate::series::base::Series;
/// Direction along which [`ApplyExt::apply`] walks a `DataFrame`.
///
/// The explicit discriminants (`0` and `1`) are part of the public definition.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Axis {
/// The function is called once per column and receives that column's values.
Column = 0,
/// The function is called once per row and receives that row's values taken
/// across the columns.
Row = 1,
}
/// Function-application and de-duplication helpers for tabular data.
///
/// The behaviour described on each method is that of the `DataFrame`
/// implementation in this file, which works on the string values returned by
/// `get_column_string_values`.
pub trait ApplyExt {
/// Reduces every column (`Axis::Column`) or every row (`Axis::Row`) to one
/// string by calling `f` on it, and collects the outputs into a series named
/// `result_name`.
fn apply<F>(&self, f: F, axis: Axis, result_name: Option<String>) -> Result<Series<String>>
where
F: Fn(&Series<String>) -> String;
/// Builds a new frame with the same column names in which every cell value
/// has been replaced by `f(value)`.
fn applymap<F>(&self, f: F) -> Result<DataFrame>
where
F: Fn(&str) -> String;
/// Builds a new frame in which every cell for which `condition` returns
/// `true` is replaced by `other`; all other cells are copied unchanged.
fn mask<F>(&self, condition: F, other: &str) -> Result<DataFrame>
where
F: Fn(&str) -> bool;
/// Builds a new frame in which every cell for which `condition` returns
/// `true` is kept and all other cells are replaced by `other` (the inverse of
/// `mask`).
fn where_func<F>(&self, condition: F, other: &str) -> Result<DataFrame>
where
F: Fn(&str) -> bool;
/// Builds a new frame in which every cell whose whole value is a key of
/// `replace_map` is replaced by the mapped value; other cells are copied.
fn replace(&self, replace_map: &HashMap<String, String>) -> Result<DataFrame>;
/// Returns a boolean series named `"duplicated"` with one flag per row.
///
/// Rows are compared on the `subset` columns, or on all columns when
/// `subset` is `None`. `keep` selects which occurrence stays unflagged:
/// `"first"` (the default when `None`), `"last"`, or `"false"` to flag every
/// row that occurs more than once.
///
/// Fails with `Error::ColumnNotFound` for an unknown subset column and with
/// `Error::InvalidValue` for any other `keep` string.
fn duplicated(&self, subset: Option<&[String]>, keep: Option<&str>) -> Result<Series<bool>>;
/// Returns a new frame without the rows that `duplicated` would flag for the
/// same `subset` and `keep` arguments; the remaining rows keep their
/// original order. Errors are the same as for `duplicated`.
fn drop_duplicates(&self, subset: Option<&[String]>, keep: Option<&str>) -> Result<DataFrame>;
}
impl ApplyExt for DataFrame {
fn apply<F>(&self, f: F, axis: Axis, result_name: Option<String>) -> Result<Series<String>>
where
F: Fn(&Series<String>) -> String,
{
match axis {
Axis::Column => {
let mut results = Vec::new();
for column_name in &self.column_names() {
let string_values = self.get_column_string_values(column_name)?;
let series = Series::new(string_values, Some(column_name.to_string()))?;
let result = f(&series);
results.push(result);
}
Series::new(results, result_name)
}
Axis::Row => {
let mut results = Vec::new();
for row_idx in 0..self.row_count() {
let mut row_values = Vec::new();
for column_name in &self.column_names() {
let string_values = self.get_column_string_values(column_name)?;
if row_idx < string_values.len() {
row_values.push(string_values[row_idx].clone());
}
}
let row_series = Series::new(row_values, Some(format!("row_{}", row_idx)))?;
let result = f(&row_series);
results.push(result);
}
Series::new(results, result_name)
}
}
}
fn applymap<F>(&self, f: F) -> Result<DataFrame>
where
F: Fn(&str) -> String,
{
let mut result = DataFrame::new();
for column_name in &self.column_names() {
let string_values = self.get_column_string_values(column_name)?;
let transformed_values: Vec<String> = string_values.iter().map(|val| f(val)).collect();
let new_series = Series::new(transformed_values, Some(column_name.to_string()))?;
result.add_column(column_name.to_string(), new_series)?;
}
Ok(result)
}
fn mask<F>(&self, condition: F, other: &str) -> Result<DataFrame>
where
F: Fn(&str) -> bool,
{
let mut result = DataFrame::new();
for column_name in &self.column_names() {
let string_values = self.get_column_string_values(column_name)?;
let masked_values: Vec<String> = string_values
.iter()
.map(|val| {
if condition(val) {
other.to_string()
} else {
val.clone()
}
})
.collect();
let new_series = Series::new(masked_values, Some(column_name.to_string()))?;
result.add_column(column_name.to_string(), new_series)?;
}
Ok(result)
}
fn where_func<F>(&self, condition: F, other: &str) -> Result<DataFrame>
where
F: Fn(&str) -> bool,
{
let mut result = DataFrame::new();
for column_name in &self.column_names() {
let string_values = self.get_column_string_values(column_name)?;
let where_values: Vec<String> = string_values
.iter()
.map(|val| {
if condition(val) {
val.clone()
} else {
other.to_string()
}
})
.collect();
let new_series = Series::new(where_values, Some(column_name.to_string()))?;
result.add_column(column_name.to_string(), new_series)?;
}
Ok(result)
}
fn replace(&self, replace_map: &HashMap<String, String>) -> Result<DataFrame> {
let mut result = DataFrame::new();
for column_name in &self.column_names() {
let string_values = self.get_column_string_values(column_name)?;
let replaced_values: Vec<String> = string_values
.iter()
.map(|val| replace_map.get(val).cloned().unwrap_or_else(|| val.clone()))
.collect();
let new_series = Series::new(replaced_values, Some(column_name.to_string()))?;
result.add_column(column_name.to_string(), new_series)?;
}
Ok(result)
}
fn duplicated(&self, subset: Option<&[String]>, keep: Option<&str>) -> Result<Series<bool>> {
let keep_option = keep.unwrap_or("first");
let mut result = vec![false; self.row_count()];
let columns_to_check = if let Some(subset_cols) = subset {
subset_cols.to_vec()
} else {
self.column_names()
};
for col_name in &columns_to_check {
if !self.contains_column(col_name) {
return Err(Error::ColumnNotFound(col_name.to_string()));
}
}
let mut row_data = Vec::new();
for row_idx in 0..self.row_count() {
let mut row_values = Vec::new();
for col_name in &columns_to_check {
let column_values = self.get_column_string_values(col_name)?;
if row_idx < column_values.len() {
row_values.push(column_values[row_idx].clone());
}
}
row_data.push(row_values);
}
match keep_option {
"first" => {
let mut seen = std::collections::HashSet::new();
for (idx, row) in row_data.iter().enumerate() {
if seen.contains(row) {
result[idx] = true;
} else {
seen.insert(row.clone());
}
}
}
"last" => {
let mut seen = std::collections::HashSet::new();
for (idx, row) in row_data.iter().enumerate().rev() {
if seen.contains(row) {
result[idx] = true;
} else {
seen.insert(row.clone());
}
}
}
"false" => {
let mut counts = std::collections::HashMap::new();
for row in &row_data {
*counts.entry(row.clone()).or_insert(0) += 1;
}
for (idx, row) in row_data.iter().enumerate() {
if counts[row] > 1 {
result[idx] = true;
}
}
}
_ => {
return Err(Error::InvalidValue(format!(
"Invalid keep option: {}. Must be 'first', 'last', or 'false'",
keep_option
)));
}
}
Series::new(result, Some("duplicated".to_string()))
}
fn drop_duplicates(&self, subset: Option<&[String]>, keep: Option<&str>) -> Result<DataFrame> {
let keep_option = keep.unwrap_or("first");
let columns_to_check = if let Some(subset_cols) = subset {
subset_cols.to_vec()
} else {
self.column_names()
};
for col_name in &columns_to_check {
if !self.contains_column(col_name) {
return Err(Error::ColumnNotFound(col_name.to_string()));
}
}
let mut row_data = Vec::new();
for row_idx in 0..self.row_count() {
let mut row_values = Vec::new();
for col_name in &columns_to_check {
let column_values = self.get_column_string_values(col_name)?;
if row_idx < column_values.len() {
row_values.push(column_values[row_idx].clone());
}
}
row_data.push(row_values);
}
let mut rows_to_keep = Vec::new();
match keep_option {
"first" => {
let mut seen = std::collections::HashSet::new();
for (idx, row) in row_data.iter().enumerate() {
if !seen.contains(row) {
seen.insert(row.clone());
rows_to_keep.push(idx);
}
}
}
"last" => {
let mut seen = std::collections::HashSet::new();
for (idx, row) in row_data.iter().enumerate().rev() {
if !seen.contains(row) {
seen.insert(row.clone());
rows_to_keep.push(idx);
}
}
rows_to_keep.reverse(); }
"false" => {
let mut counts = std::collections::HashMap::new();
for row in &row_data {
*counts.entry(row.clone()).or_insert(0) += 1;
}
for (idx, row) in row_data.iter().enumerate() {
if counts[row] == 1 {
rows_to_keep.push(idx);
}
}
}
_ => {
return Err(Error::InvalidValue(format!(
"Invalid keep option: {}. Must be 'first', 'last', or 'false'",
keep_option
)));
}
}
let mut result = DataFrame::new();
for column_name in &self.column_names() {
let column_values = self.get_column_string_values(column_name)?;
let filtered_values: Vec<String> = rows_to_keep
.iter()
.filter_map(|&row_idx| column_values.get(row_idx).cloned())
.collect();
let new_series = Series::new(filtered_values, Some(column_name.to_string()))?;
result.add_column(column_name.to_string(), new_series)?;
}
Ok(result)
}
}
pub use crate::dataframe::apply::Axis as LegacyAxis;