use ndarray::{s, Array2};
use serde::Serialize;
use crate::{dataframe::TopN, error::Error};
use super::ColumnFrame;
/// A borrowed [`ColumnFrame`] paired with a row ordering.
///
/// `row_indicies[i]` holds the row position in `df` of the row that belongs
/// at position `i` of the sorted view. The underlying data is not copied or
/// reordered until `get_sorted` or `topn` materialises a new frame.
#[derive(Debug, Clone, Serialize, PartialEq)]
pub struct SortedDataFrame<'a> {
/// Row positions into `df`, in sorted order. The (misspelled) field name is
/// part of the public API and of the derived `Serialize` output, so it is
/// kept as-is.
pub row_indicies: Vec<usize>,
/// The frame whose rows are being ordered; borrowed, not copied.
pub df: &'a ColumnFrame,
}
impl<'a> SortedDataFrame<'a> {
    /// Wraps `df` together with a row permutation describing its sorted order.
    ///
    /// `row_indicies[i]` is the position, in `df`, of the row that belongs at
    /// position `i` of the sorted output. The frame itself is borrowed, not
    /// copied.
    pub fn new(df: &'a ColumnFrame, row_indicies: Vec<usize>) -> Self {
        Self { row_indicies, df }
    }

    /// Materialises the sorted view as a new, owned [`ColumnFrame`].
    ///
    /// Output row `i` is a copy of source row `row_indicies[i]`; the column
    /// index is cloned unchanged. If `row_indicies` has fewer entries than the
    /// frame has rows, the remaining output rows keep their original
    /// (unsorted) contents.
    ///
    /// # Panics
    ///
    /// Panics (via ndarray slicing) if any entry of `row_indicies` is out of
    /// bounds, or if `row_indicies` has more entries than the frame has rows.
    pub fn get_sorted(&self) -> ColumnFrame {
        let source = &self.df.data_frame;
        // Start from a copy so the output already has the right shape, then
        // overwrite each row in sorted order. Reads always come from `source`,
        // never from the partially-written copy, so the permutation is applied
        // correctly even when rows are swapped.
        let mut sorted = source.clone();
        for (cur_idx, &row_idx) in self.row_indicies.iter().enumerate() {
            sorted
                .slice_mut(s![cur_idx, ..])
                .assign(&source.slice(s![row_idx, ..]));
        }
        ColumnFrame::new(self.df.index.clone(), sorted)
    }

    /// Returns a new [`ColumnFrame`] holding the first or last `n` rows of the
    /// sorted order.
    ///
    /// * `TopN::First(n)`: the first `n` rows, in sorted order.
    /// * `TopN::Last(n)`: the last `n` rows, starting from the final sorted
    ///   row and walking backwards (i.e. in reverse sorted order).
    ///
    /// If `n` exceeds the number of sorted rows, only the available rows are
    /// returned. The result is never padded with default-valued rows.
    /// An empty source frame yields a frame with the same column index and a
    /// default (empty) data array.
    ///
    /// # Errors
    ///
    /// Currently never returns `Err`; the `Result` is kept for API stability.
    ///
    /// # Panics
    ///
    /// Panics (via ndarray slicing) if a selected entry of `row_indicies` is
    /// out of bounds for the underlying frame.
    pub fn topn(&self, topn: TopN) -> Result<ColumnFrame, Error> {
        let nrows = self.df.len();
        let ncols = self.df.index.len();
        if nrows == 0 {
            return Ok(ColumnFrame::new(self.df.index.clone(), Default::default()));
        }

        // Resolve which source rows to emit, in output order. `take` clamps to
        // the number of sorted indices, so the output is sized exactly instead
        // of being padded with default rows when `n` is too large.
        let selected: Vec<usize> = match topn {
            TopN::First(n) => self.row_indicies.iter().take(n).copied().collect(),
            TopN::Last(n) => self.row_indicies.iter().rev().take(n).copied().collect(),
        };

        let mut arr = Array2::default((selected.len(), ncols));
        for (out_idx, &row_idx) in selected.iter().enumerate() {
            arr.row_mut(out_idx)
                .assign(&self.df.data_frame.slice(s![row_idx, ..]));
        }
        Ok(ColumnFrame::new(self.df.index.clone(), arr))
    }
}