use crate::temex_error::TemexError;
use crate::temex_trace::TemexTrace;
use polars::frame::row::Row;
use polars::prelude::*;
use std::io;
use std::io::prelude::*;
/// Verifies that every column of `df` has the `Boolean` dtype.
///
/// # Errors
/// Returns `TemexError::PolarsNonBooleanColumn` holding the name of the first
/// field yielded by `df.fields()` whose dtype is not `Boolean`.
fn is_boolean(df: &polars::frame::DataFrame) -> Result<(), TemexError> {
    let offending = df
        .fields()
        .into_iter()
        .find(|candidate| candidate.dtype != polars::datatypes::DataType::Boolean);
    match offending {
        Some(bad) => Err(TemexError::PolarsNonBooleanColumn(
            bad.name.as_str().to_owned(),
        )),
        None => Ok(()),
    }
}
/// Returns the column names of `df` as owned `String`s, in the order reported
/// by `DataFrame::get_column_names`.
fn get_df_labels(df: &polars::frame::DataFrame) -> Vec<String> {
    let names = df.get_column_names();
    names.into_iter().map(|name| name.to_string()).collect()
}
/// Maps each input position to that label's position in sorted order.
///
/// Entry `i` of the result is the index at which `input_labels[i]` first
/// appears once the labels are sorted with `String` ordering. Duplicate labels
/// therefore map to the same index.
///
/// Runs in O(n log n): the labels are sorted once (by reference, without
/// cloning the strings) and each label is then located with a binary search
/// rather than a linear scan.
fn get_idx_mapping(input_labels: &[String]) -> Vec<usize> {
    let mut sorted_labels: Vec<&String> = input_labels.iter().collect();
    sorted_labels.sort_unstable();
    input_labels
        .iter()
        // `partition_point` yields the first index whose label is not less
        // than `label`. `label` is itself present in `sorted_labels`, so this
        // is exactly its first occurrence.
        .map(|label| sorted_labels.partition_point(|x| *x < label))
        .collect()
}
impl TryFrom<polars::frame::DataFrame> for TemexTrace {
type Error = TemexError;
fn try_from(df: polars::frame::DataFrame) -> Result<Self, Self::Error> {
is_boolean(&df)?;
let labels = get_df_labels(&df);
let input_idx_to_sorted = get_idx_mapping(&labels);
let width = labels.len();
let mut data: Vec<u8> = vec![];
let mut row = polars::frame::row::Row::new(vec![AnyValue::Boolean(false); width]);
for i in 0..df.height() {
df.get_row_amortized(i, &mut row)?;
let mut trace_element: Vec<u8> = vec![b'0'; width];
for i in 0..row.0.len() {
if row.0[i] == polars::datatypes::AnyValue::Boolean(true) {
trace_element[input_idx_to_sorted[i]] = b'1';
} else {
}
}
trace_element.push(b'\n'); data.append(&mut trace_element);
}
Ok(TemexTrace { labels, data })
}
}
impl TryFrom<TemexTrace> for polars::frame::DataFrame {
    type Error = TemexError;

    /// Rebuilds a boolean DataFrame from a `TemexTrace`.
    ///
    /// Each line of `trace.data` becomes one row. Character `j` of a line
    /// (`'1'` or `'0'`) becomes a `true`/`false` value in column `j`, and the
    /// columns are then named by `trace.labels` in order.
    ///
    /// # Errors
    /// - An `io::Error` of kind `InvalidData` (converted into `TemexError` via
    ///   `?`) if a line contains any character other than `'0'` or `'1'`.
    /// - Any I/O error produced while splitting `trace.data` into lines.
    /// - Any polars error from `DataFrame::from_rows` or `set_column_names`
    ///   (converted via `?`).
    fn try_from(trace: TemexTrace) -> Result<Self, Self::Error> {
        let mut rows: Vec<Row> = vec![];
        let reader = io::BufReader::new(trace.data.as_slice());
        for line_result in reader.lines() {
            let line = line_result?;
            let row_vec: Vec<AnyValue> = line
                .chars()
                .map(|c| match c {
                    '1' => Ok(AnyValue::Boolean(true)),
                    '0' => Ok(AnyValue::Boolean(false)),
                    // Nothing in this conversion guarantees `trace.data` holds
                    // only '0'/'1', so report stray characters instead of
                    // panicking.
                    other => Err(io::Error::new(
                        io::ErrorKind::InvalidData,
                        format!(
                            "invalid trace character {:?}: expected '0' or '1'",
                            other
                        ),
                    )),
                })
                .collect::<Result<_, io::Error>>()?;
            rows.push(Row::new(row_vec));
        }
        let mut df = polars::frame::DataFrame::from_rows(&rows)?;
        df.set_column_names(&trace.labels)?;
        Ok(df)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use polars::df;

    /// Round-trips a small boolean frame through `TemexTrace` and checks the
    /// reconstructed frame, the stored labels and the encoded bytes.
    #[test]
    fn try_from_polars_works() {
        let frame = df!(
            "p1" => &[true, true, true],
            "p2" => &[false, false, false],
            "p3" => &[true, false, true]
        )
        .unwrap();

        let trace = TemexTrace::try_from(frame.clone()).unwrap();

        // Converting back must reproduce the original frame exactly.
        let rebuilt = polars::frame::DataFrame::try_from(trace.clone()).unwrap();
        assert_eq!(frame, rebuilt);

        let expected_labels: Vec<String> = ["p1", "p2", "p3"]
            .iter()
            .map(|name| name.to_string())
            .collect();
        assert_eq!(trace.labels, expected_labels);

        // One newline-terminated line per row, one byte per column.
        let encoded = std::string::String::from_utf8(trace.data).unwrap();
        assert_eq!(encoded, "101\n100\n101\n".to_string());
    }
}