use std::{
ops::{Add, Div},
sync::{
Arc,
atomic::{AtomicBool, Ordering},
},
};
use anyhow::anyhow;
use chrono::{NaiveDate, NaiveDateTime};
use itertools::{Itertools, izip};
use polars::{
frame::DataFrame,
prelude::{AnyValue, ChunkAgg, DataType, NamedFrom, SeriesMethods, TimeUnit},
series::{ChunkCompareEq, Series},
};
use ratatui::widgets::Cell;
use rayon::iter::{ParallelBridge, ParallelIterator};
use unicode_width::UnicodeWidthStr;
use crate::{AppResult, misc::ragged_vec::RaggedVec, tui::sheet::SheetSection};
use super::type_ext::HasSubsequence;
/// Rendering and parsing helpers for polars [`AnyValue`]s shown in the TUI.
pub trait AnyValueExt {
    /// Renders the value as a single display line (tabs become spaces,
    /// binary blobs become a short length summary).
    fn into_single_line(self) -> String;
    /// Display width in terminal columns; `num_buffer` is reusable scratch
    /// space for integer/float formatting so hot loops avoid allocating.
    fn width(self, num_buffer: &mut NumBuffer) -> usize;
    /// Renders the value for multi-line display (binary blobs become a hex dump).
    fn into_multi_line(self) -> String;
    /// Converts the value into a ratatui table cell; floats are right-aligned
    /// to `width` with two decimals.
    fn into_cell(self, width: usize) -> Cell<'static>;
    /// Subsequence ("fuzzy") match of `other` against the rendered value.
    fn fuzzy_cmp(self, other: &str) -> bool;
    /// Parses exactly "true"/"false" into a boolean value; `None` otherwise.
    fn parse_bool(slice: &str) -> Option<AnyValue<'static>>;
    /// Parses a date with the chrono format string `fmt`; `None` on failure.
    fn parse_date(slice: &str, fmt: &str) -> Option<AnyValue<'static>>;
    /// Parses a datetime with the chrono format string `fmt`; `None` on failure.
    fn parse_datetime(slice: &str, fmt: &str) -> Option<AnyValue<'static>>;
}
impl AnyValueExt for AnyValue<'_> {
    /// One-line rendering for table cells.
    fn into_single_line(self) -> String {
        match self {
            AnyValue::Null => "".to_owned(),
            // Tabs would desync the column layout; render them as spaces.
            AnyValue::StringOwned(v) => {
                v.chars().map(|c| if c == '\t' { ' ' } else { c }).collect()
            }
            AnyValue::String(v) => v.chars().map(|c| if c == '\t' { ' ' } else { c }).collect(),
            AnyValue::Categorical(idx, rev_map) => {
                rev_map.cat_to_str(idx).unwrap_or_default().to_owned()
            }
            AnyValue::CategoricalOwned(idx, rev_map) => {
                rev_map.cat_to_str(idx).unwrap_or_default().to_owned()
            }
            // Binary payloads are only summarized here; the full hex dump is
            // produced by `into_multi_line`.
            AnyValue::Binary(buf) => format!("Blob (Length: {})", buf.len()),
            AnyValue::BinaryOwned(buf) => format!("Blob (Length: {})", buf.len()),
            _ => self.to_string(),
        }
    }

    /// Terminal-column width of the single-line rendering.
    fn width(self, num_buffer: &mut NumBuffer) -> usize {
        match self {
            AnyValue::Null => 0,
            // "true" is 4 columns, "false" is 5.
            AnyValue::Boolean(v) => if v { 4 } else { 5 },
            // Only the first line matters for single-line table cells.
            AnyValue::String(s) => s.lines().next().unwrap_or_default().width(),
            // BUGFIX: `StringOwned` previously fell through to
            // `self.to_string().width()`, which measures polars' quoted,
            // multi-line Display output instead of the first line like the
            // `String` arm (and every other method here) does.
            AnyValue::StringOwned(s) => s.lines().next().unwrap_or_default().width(),
            // Numbers go through itoa/ryu scratch buffers to avoid allocating
            // a String per cell.
            AnyValue::UInt8(u) => num_buffer.itoa.format(u).len(),
            AnyValue::UInt16(u) => num_buffer.itoa.format(u).len(),
            AnyValue::UInt32(u) => num_buffer.itoa.format(u).len(),
            AnyValue::UInt64(u) => num_buffer.itoa.format(u).len(),
            AnyValue::UInt128(u) => num_buffer.itoa.format(u).len(),
            AnyValue::Int8(i) => num_buffer.itoa.format(i).len(),
            AnyValue::Int16(i) => num_buffer.itoa.format(i).len(),
            AnyValue::Int32(i) => num_buffer.itoa.format(i).len(),
            AnyValue::Int64(i) => num_buffer.itoa.format(i).len(),
            AnyValue::Int128(i) => num_buffer.itoa.format(i).len(),
            AnyValue::Float32(f) => num_buffer.ryu.format(f).len(),
            AnyValue::Float64(f) => num_buffer.ryu.format(f).len(),
            // "YYYY-MM-DD" and "YYYY-MM-DD HH:MM:SS".
            AnyValue::Date(_) => 10,
            AnyValue::Datetime(_, _, _) | AnyValue::DatetimeOwned(_, _, _) => 19,
            _ => self.to_string().width(),
        }
    }

    /// Multi-line rendering for the detail sheet; blobs get a full hex dump.
    fn into_multi_line(self) -> String {
        match self {
            AnyValue::Null => "".to_owned(),
            AnyValue::StringOwned(v) => {
                v.chars().map(|c| if c == '\t' { ' ' } else { c }).collect()
            }
            AnyValue::String(v) => v.chars().map(|c| if c == '\t' { ' ' } else { c }).collect(),
            AnyValue::Categorical(idx, rev_map) => {
                rev_map.cat_to_str(idx).unwrap_or_default().to_owned()
            }
            AnyValue::CategoricalOwned(idx, rev_map) => {
                rev_map.cat_to_str(idx).unwrap_or_default().to_owned()
            }
            AnyValue::Binary(buf) => bytes_to_string(buf),
            AnyValue::BinaryOwned(buf) => bytes_to_string(buf),
            _ => self.to_string(),
        }
    }

    /// Table cell: floats are right-aligned to `width` with two decimals,
    /// everything else uses the single-line rendering.
    fn into_cell(self, width: usize) -> Cell<'static> {
        match self {
            AnyValue::Float32(f) => Cell::new(format!("{f:>w$.2}", w = width)),
            AnyValue::Float64(f) => Cell::new(format!("{f:>w$.2}", w = width)),
            _ => Cell::new(self.into_single_line()),
        }
    }

    /// Fuzzy (subsequence) match; strings are matched directly, other values
    /// via their multi-line rendering.
    fn fuzzy_cmp(self, other: &str) -> bool {
        match self {
            AnyValue::Null => false,
            AnyValue::StringOwned(pl_small_str) => pl_small_str.has_subsequence(other),
            AnyValue::String(val) => val.has_subsequence(other),
            _ => self.into_multi_line().has_subsequence(other),
        }
    }

    fn parse_bool(slice: &str) -> Option<AnyValue<'static>> {
        match slice {
            "true" => Some(AnyValue::Boolean(true)),
            "false" => Some(AnyValue::Boolean(false)),
            _ => None,
        }
    }

    fn parse_date(slice: &str, fmt: &str) -> Option<AnyValue<'static>> {
        // Polars stores dates as days since the Unix epoch.
        const UNIX_EPOCH: NaiveDate = match NaiveDate::from_ymd_opt(1970, 1, 1) {
            Some(date) => date,
            None => unreachable!(),
        };
        NaiveDate::parse_from_str(slice, fmt)
            .map(|date| {
                AnyValue::Date(date.signed_duration_since(UNIX_EPOCH).num_days() as i32)
            })
            .ok()
    }

    fn parse_datetime(slice: &str, fmt: &str) -> Option<AnyValue<'static>> {
        NaiveDateTime::parse_from_str(slice, fmt)
            .map(|date| {
                // Naive input is interpreted as UTC with millisecond precision.
                AnyValue::DatetimeOwned(
                    date.and_utc().timestamp_millis(),
                    TimeUnit::Milliseconds,
                    None,
                )
            })
            .ok()
    }
}
/// Reusable formatting buffers for float (`ryu`) and integer (`itoa`)
/// stringification, so width computations avoid a heap allocation per value.
#[derive(Default, Clone)]
pub struct NumBuffer {
    // Scratch buffer for formatting floats.
    ryu: ryu::Buffer,
    // Scratch buffer for formatting integers.
    itoa: itoa::Buffer,
}
/// Fallible "refinement" casts: convert a series to another dtype only when
/// every original value survives the conversion (no new nulls appear).
pub trait SeriesExt {
    /// Refines to [`DataType::String`].
    fn refine_to_string(&self) -> AppResult<Series>;
    /// Refines to [`DataType::Int64`].
    fn refine_to_int(&self) -> AppResult<Series>;
    /// Refines to [`DataType::Float64`].
    fn refine_to_float(&self) -> AppResult<Series>;
    /// Refines string values "true"/"false" to [`DataType::Boolean`].
    fn refine_to_bool(&self) -> AppResult<Series>;
    /// Refines string values to [`DataType::Date`], trying a list of common formats.
    fn refine_to_date(&self) -> AppResult<Series>;
    /// Refines string values to a millisecond [`DataType::Datetime`], trying
    /// a list of common formats.
    fn refine_to_datetime(&self) -> AppResult<Series>;
}
impl SeriesExt for Series {
fn refine_to_string(&self) -> AppResult<Series> {
let casted = self.cast(&DataType::String)?;
if casted.is_null().equal(&self.is_null()).all() {
Ok(casted)
} else {
Err(anyhow!(
"Column '{}' cannot be refined to {}",
self.name(),
DataType::String
))
}
}
fn refine_to_int(&self) -> AppResult<Series> {
let casted = self.cast(&DataType::Int64)?;
if casted.is_null().equal(&self.is_null()).all() {
Ok(casted)
} else {
Err(anyhow!(
"Column '{}' cannot be refined to {}",
self.name(),
DataType::Int64
))
}
}
fn refine_to_float(&self) -> AppResult<Series> {
let casted = self.cast(&DataType::Float64)?;
if casted.is_null().equal(&self.is_null()).all() {
Ok(casted)
} else {
Err(anyhow!(
"Column '{}' cannot be refined to {}",
self.name(),
DataType::Float64
))
}
}
fn refine_to_bool(&self) -> AppResult<Series> {
self.try_map_all(|val| match val {
AnyValue::String(s) => AnyValue::parse_bool(s),
AnyValue::StringOwned(s) => AnyValue::parse_bool(s.as_str()),
AnyValue::Null => Some(AnyValue::Null),
_ => None,
})
.ok_or(anyhow!(
"Column '{}' cannot be refined to {}",
self.name(),
DataType::Boolean
))
}
fn refine_to_date(&self) -> AppResult<Series> {
[
"%Y-%m-%d", "%Y/%m/%d", "%Y.%m.%d", "%Y %m %d", "%Y%m%d", "%d-%m-%Y", "%d/%m/%Y",
"%d.%m.%Y", "%d %m %Y", "%d%m%Y", "%m-%d-%Y", "%m/%d/%Y", "%m.%d.%Y", "%m %d %Y",
"%m%d%Y", "%B %d %Y", "%B-%d-%Y", "%Y-%j",
]
.into_iter()
.find_map(|fmt| {
self.try_map_all(|val| match val {
AnyValue::String(s) => AnyValue::parse_date(s, fmt),
AnyValue::StringOwned(s) => AnyValue::parse_date(s.as_str(), fmt),
AnyValue::Null => Some(AnyValue::Null),
_ => None,
})
})
.ok_or(anyhow!(
"Column '{}' cannot be refined to {}",
self.name(),
DataType::Date
))
}
fn refine_to_datetime(&self) -> AppResult<Series> {
[
"%Y-%m-%d %H:%M:%S",
"%Y-%m-%dT%H:%M:%S",
"%Y-%m-%dT%H:%M:%S%.f",
"%Y/%m/%d %H:%M:%S",
"%Y %m %d %H:%M:%S",
"%Y.%m.%d %H:%M:%S",
"%d-%m-%Y %H:%M:%S",
"%d/%m/%Y %H:%M:%S",
"%d %m %Y %H:%M:%S",
"%d.%m.%Y %H:%M:%S",
"%m-%d-%Y %H:%M:%S",
"%m/%d/%Y %H:%M:%S",
"%m %d %Y %H:%M:%S",
"%m.%d.%Y %H:%M:%S",
"%B %d %Y %H:%M:%S",
"%B-%d-%Y %H:%M:%S",
"%Y%m%dT%H%M%S",
]
.into_iter()
.find_map(|fmt| {
self.try_map_all(|val| match val {
AnyValue::String(s) => AnyValue::parse_datetime(s, fmt),
AnyValue::StringOwned(s) => AnyValue::parse_datetime(s.as_str(), fmt),
AnyValue::Null => Some(AnyValue::Null),
_ => None,
})
})
.ok_or(anyhow!(
"Column '{}' cannot be refined to {}",
self.name(),
DataType::Datetime(TimeUnit::Milliseconds, None)
))
}
}
/// Rendering and plot-data helpers for [`DataFrame`]s shown in the TUI.
pub trait DataFrameExt {
    /// Display width of each column (max of header width and widest cell).
    fn widths(&self) -> Vec<usize>;
    /// One sheet section per column for the row at `pos`, titled "name (dtype)".
    fn get_sheet_sections(&self, pos: usize) -> Vec<SheetSection>;
    /// (x, y) points for a scatter plot of two numeric columns; rows where
    /// either value is null are skipped.
    fn scatter_plot_data(&self, x_label: &str, y_label: &str) -> AppResult<RaggedVec<(f64, f64)>>;
    /// Scatter points partitioned by the `group_by` column, plus the group
    /// labels, sorted by label.
    #[allow(clippy::type_complexity)]
    fn scatter_plot_data_grouped(
        &self,
        x_label: &str,
        y_label: &str,
        group_by: &str,
    ) -> AppResult<(RaggedVec<(f64, f64)>, Vec<String>)>;
    /// (label, count) pairs for a histogram of `col`; `buckets` applies only
    /// to continuous (numeric) columns.
    fn histogram_plot_data(&self, col: &str, buckets: usize) -> AppResult<Vec<(String, u64)>>;
}
/// Parallel, all-or-nothing value conversion over a series.
pub trait TryMapAll {
    /// Applies `f` to every value; returns `None` as soon as `f` rejects any
    /// value, otherwise the fully converted series.
    fn try_map_all(
        &self,
        f: impl Fn(AnyValue) -> Option<AnyValue<'static>> + Sync + Send + 'static,
    ) -> Option<Series>;
}
/// Formats a binary blob as a "Blob (Length: N)" header followed by a hex
/// dump: 16 bytes per row in two 8-byte groups, each row prefixed with a
/// zero-padded row index.
fn bytes_to_string(buf: impl AsRef<[u8]>) -> String {
    let bytes = buf.as_ref();
    // Row-index column width: digit count of len/16, rounded up to even.
    // (`.div` keeps the `std::ops::Div` import in use.)
    let digits = bytes.len().div(16).to_string().len();
    let index_width = digits + digits % 2;
    let dump = bytes
        .iter()
        .map(|byte| format!("{byte:02X}"))
        .chunks(8)
        .into_iter()
        .map(|mut group| group.join(" "))
        .chunks(2)
        .into_iter()
        .enumerate()
        .map(|(row, mut halves)| format!("{row:0index_width$}: {}", halves.join(" ")))
        .join("\n");
    format!("Blob (Length: {})\n{}", bytes.len(), dump)
}
impl DataFrameExt for DataFrame {
    /// Display width of every column (header included), computed in parallel
    /// per column by `series_width`.
    fn widths(&self) -> Vec<usize> {
        self.columns()
            .iter()
            .map(|col| series_width(col.as_materialized_series()))
            .collect()
    }

    /// One sheet section per column for row `pos`; an out-of-range `pos`
    /// yields an empty row and therefore empty sections.
    fn get_sheet_sections(&self, pos: usize) -> Vec<SheetSection> {
        izip!(
            self.get_column_names().into_iter(),
            self.get(pos)
                .unwrap_or_default()
                .into_iter()
                .map(AnyValueExt::into_multi_line),
            self.dtypes()
        )
        .map(|(header, content, dtype)| SheetSection::new(format!("{header} ({dtype})"), content))
        .collect_vec()
    }

    /// Pairs up two columns cast to f64; rows where either side is null are
    /// dropped rather than erroring.
    fn scatter_plot_data(&self, x_label: &str, y_label: &str) -> AppResult<RaggedVec<(f64, f64)>> {
        Ok(self
            .column(x_label)?
            .cast(&DataType::Float64)?
            .f64()?
            .iter()
            .zip(
                self.column(y_label)?
                    .cast(&DataType::Float64)?
                    .f64()?
                    .iter(),
            )
            .filter_map(|(x, y)| Some((x?, y?)))
            .collect())
    }

    /// Partitions by `group_by`, then collects per-group scatter data sorted
    /// by the group's rendered label.
    fn scatter_plot_data_grouped(
        &self,
        x_label: &str,
        y_label: &str,
        group_by: &str,
    ) -> AppResult<(RaggedVec<(f64, f64)>, Vec<String>)> {
        let mut groups = Vec::new();
        let mut data = RaggedVec::new();
        for (name, df) in self
            .partition_by(vec![group_by], true)?
            .into_iter()
            .map(|df| {
                let name = df
                    .column(group_by)
                    .and_then(|column| column.get(0))
                    .map(AnyValueExt::into_single_line)
                    // `unwrap_or_else`: don't allocate the fallback label on
                    // the success path (clippy `or_fun_call`).
                    .unwrap_or_else(|_| "null".to_owned());
                (name, df)
            })
            .sorted_by(|(a, _), (b, _)| a.cmp(b))
        {
            groups.push(name);
            data.push(df.scatter_plot_data(x_label, y_label)?);
        }
        Ok((data, groups))
    }

    /// Histogram data: numeric columns are bucketed (`continues_histogram`),
    /// boolean/string columns get one bar per distinct value.
    fn histogram_plot_data(&self, col_name: &str, buckets: usize) -> AppResult<Vec<(String, u64)>> {
        let col = self.column(col_name)?;
        match col.dtype() {
            DataType::UInt8
            | DataType::UInt16
            | DataType::UInt32
            | DataType::UInt64
            | DataType::Int8
            | DataType::Int16
            | DataType::Int32
            | DataType::Int64
            | DataType::Int128
            | DataType::Float32
            | DataType::Float64
            | DataType::Decimal(_, _) => continues_histogram(
                col.as_materialized_series()
                    .value_counts(true, true, "value".into(), false)?,
                buckets,
            ),
            DataType::Boolean | DataType::String => discrete_histogram(
                col.as_materialized_series()
                    .value_counts(true, true, "value".into(), false)?,
            ),
            // Return the error directly instead of `Err(...)?` (clippy
            // `try_err`); matches the style used elsewhere in this file.
            _ => Err(anyhow!("Unsupported column type")),
        }
    }
}
/// Widest single-line rendering among the series' values, never narrower
/// than the series name. Values are measured in parallel; each rayon worker
/// keeps its own `NumBuffer` scratch space.
fn series_width(series: &Series) -> usize {
    let widest_value = series
        .iter()
        .par_bridge()
        .fold_with(
            (0_usize, NumBuffer::default()),
            |(best, mut scratch), value| (best.max(value.width(&mut scratch)), scratch),
        )
        .map(|(best, _)| best)
        .max()
        .unwrap_or_default();
    series.name().width().max(widest_value)
}
impl TryMapAll for Series {
    /// Converts every value with `cast` across scoped worker threads.
    /// Returns `None` if `cast` rejects any value, `Some(converted)` otherwise.
    fn try_map_all(
        &self,
        cast: impl Fn(AnyValue) -> Option<AnyValue<'static>> + Sync + Send + 'static,
    ) -> Option<Series> {
        // Cooperative cancellation flag: set by the first worker whose value
        // fails to convert, observed by all other workers.
        let break_out = Arc::new(AtomicBool::new(false));
        // Pre-filled output; each worker writes only its own disjoint chunk.
        let mut new = vec![AnyValue::Null; self.len()];
        std::thread::scope(|scope| {
            // Roughly one slice per logical CPU; minimum 1 so chunks_mut
            // never sees a zero chunk size.
            let piece_len = if self.len() > num_cpus::get() {
                self.len() / num_cpus::get()
            } else {
                1
            };
            for (idx, new_chunk) in new.chunks_mut(piece_len).enumerate() {
                let offset = (idx * piece_len) as i64;
                let break_out = break_out.clone();
                // Borrow, don't move: the closure is shared by all workers.
                let cast = &cast;
                scope.spawn(move || {
                    // `slice` clamps at the series end; the zip below bounds
                    // iteration to this worker's chunk length anyway.
                    let series = self.slice(offset, piece_len);
                    for (new_val, val) in new_chunk.iter_mut().zip(series.iter()) {
                        if let Some(parsed) = cast(val) {
                            *new_val = parsed;
                        } else {
                            // Conversion failed: signal everyone and stop.
                            break_out.store(true, Ordering::Relaxed);
                            break;
                        }
                        // Another worker failed; no point continuing here.
                        if break_out.load(Ordering::Relaxed) {
                            break;
                        }
                    }
                });
            }
        });
        // A partially converted series is meaningless — all or nothing.
        (!break_out.load(Ordering::Relaxed)).then_some(Series::new(self.name().to_owned(), new))
    }
}
fn discrete_histogram(mut counts: DataFrame) -> AppResult<Vec<(String, u64)>> {
counts.rechunk_mut();
Ok(counts[0]
.as_materialized_series()
.iter()
.map(AnyValue::into_single_line)
.zip(counts[1].as_materialized_series().u32()?.iter())
.map(|(v, c)| (v, c.unwrap_or_default() as u64))
.collect_vec())
}
/// Buckets a polars `value_counts` frame (numeric values in column 0, u32
/// counts in column 1) into `buckets` equal-width ranges spanning [min, max].
/// Returns `(range label, total count)` per bucket.
fn continues_histogram(counts: DataFrame, buckets: usize) -> AppResult<Vec<(String, u64)>> {
    // BUGFIX: `buckets == 0` previously panicked indexing the empty Vec below.
    if buckets == 0 {
        return Err(anyhow!("Bucket count must be greater than zero"));
    }
    let casted = counts[0].cast(&DataType::Float64)?;
    let arr = casted.f64()?;
    let (min, max) = arr.min_max().ok_or_else(|| anyhow!("No value found"))?;
    let width = (max - min) / (buckets as f64);
    let counts = arr
        .iter()
        .flatten()
        .zip(counts[1].as_materialized_series().u32()?.iter().flatten())
        // Accumulate in u64 so very large frames cannot overflow the u32
        // per-value counts when summed.
        .fold(vec![0_u64; buckets], |mut acc, (v, c)| {
            // BUGFIX: all-equal values give width == 0, making the division
            // NaN; route everything to bucket 0 explicitly instead.
            let idx = if width > 0.0 {
                (((v - min) / width) as usize).min(buckets - 1)
            } else {
                0
            };
            acc[idx] += u64::from(c);
            acc
        });
    // Align both range endpoints to the width of the largest value.
    let label_len = format!("{max:.2}").len();
    Ok(counts
        .into_iter()
        .enumerate()
        .map(|(idx, total)| {
            let start = (idx as f64) * width + min;
            // (`.add` keeps the `std::ops::Add` import in use.)
            let end = (idx.add(1) as f64) * width + min;
            (
                format!(" {start:>w$.2} - {end:>w$.2}", w = label_len),
                total,
            )
        })
        .collect())
}