temex 0.10.0

Regex-like temporal expressions for evaluating systems that change over time
Documentation
use crate::temex_error::TemexError;
use crate::temex_trace::TemexTrace;
use polars::frame::row::Row;
use polars::prelude::*;
use std::io;
use std::io::prelude::*;

// Verify that every column of the dataframe has the Boolean dtype.
//
// Returns `Ok(())` when all columns are boolean. Otherwise returns
// `TemexError::PolarsNonBooleanColumn` carrying the name of the first
// non-boolean column encountered.
fn is_boolean(df: &polars::frame::DataFrame) -> Result<(), TemexError> {
    let first_non_boolean = df
        .fields()
        .into_iter()
        .find(|field| field.dtype != polars::datatypes::DataType::Boolean);

    match first_non_boolean {
        Some(field) => Err(TemexError::PolarsNonBooleanColumn(
            field.name.as_str().to_owned(),
        )),
        None => Ok(()),
    }
}

// Collect owned copies of the dataframe's column names, in the order the
// dataframe reports them.
fn get_df_labels(df: &polars::frame::DataFrame) -> Vec<String> {
    let mut labels = Vec::new();
    for name in df.get_column_names() {
        labels.push(name.to_owned());
    }
    labels
}

// Get a mapping from the column indexes of the dataframe, which can be in
// arbitrary order, to what will be the indexes of the TemexTrace, whose
// columns will be in lexicographic order.
//
// Element `i` of the returned vector is the position that `input_labels[i]`
// occupies once the labels are sorted. The result is always a permutation of
// `0..input_labels.len()`: equal labels keep their relative input order
// instead of collapsing onto the same index.
fn get_idx_mapping(input_labels: &[String]) -> Vec<usize> {
    // Argsort: `order[sorted_idx]` is the input index of the label that lands
    // at `sorted_idx`. Sorting indices avoids cloning every label, and the
    // stable sort keeps duplicates in input order.
    let mut order: Vec<usize> = (0..input_labels.len()).collect();
    order.sort_by(|&a, &b| input_labels[a].cmp(&input_labels[b]));

    // Invert the permutation so it can be looked up by input index.
    let mut input_idx_to_sorted = vec![0; input_labels.len()];
    for (sorted_idx, &input_idx) in order.iter().enumerate() {
        input_idx_to_sorted[input_idx] = sorted_idx;
    }
    input_idx_to_sorted
}

// Build a TemexTrace from a dataframe whose columns are all boolean.
//
// Each dataframe row becomes one trace element: a run of ASCII '0'/'1' bytes,
// one per column, terminated by b'\n'. Columns are emitted in lexicographic
// label order, and the returned `labels` are sorted the same way so that
// `labels[k]` always names the k-th byte of every trace element.
//
// Errors: `TemexError::PolarsNonBooleanColumn` if any column is not boolean,
// or whatever error `get_row_amortized` reports while reading a row.
impl TryFrom<polars::frame::DataFrame> for TemexTrace {
    type Error = TemexError;

    fn try_from(df: polars::frame::DataFrame) -> Result<Self, Self::Error> {
        // ensure all the columns are boolean
        is_boolean(&df)?;

        let mut labels = get_df_labels(&df);
        // Compute the mapping while `labels` is still in dataframe order ...
        let input_idx_to_sorted = get_idx_mapping(&labels);
        // ... then sort the labels so they line up with the sorted data
        // columns written below.
        labels.sort_unstable();
        let width = labels.len();

        // Every trace element is `width` digits plus the newline delimiter.
        let mut data: Vec<u8> = Vec::with_capacity(df.height() * (width + 1));

        // Scratch row reused across iterations to avoid reallocating.
        let mut row = polars::frame::row::Row::new(vec![AnyValue::Boolean(false); width]);

        for row_idx in 0..df.height() {
            df.get_row_amortized(row_idx, &mut row)?;

            // Start from all-'0'; only cells that are exactly `true` flip to
            // '1', so any other value (e.g. a null) is recorded as '0'.
            let mut trace_element: Vec<u8> = vec![b'0'; width];
            for (value, &sorted_idx) in row.0.iter().zip(&input_idx_to_sorted) {
                if matches!(value, AnyValue::Boolean(true)) {
                    trace_element[sorted_idx] = b'1';
                }
            }

            data.extend_from_slice(&trace_element);
            data.push(b'\n'); // trace element delimiter
        }
        Ok(TemexTrace { labels, data })
    }
}

// Rebuild a boolean dataframe from a TemexTrace.
//
// `trace.data` is read line by line; each line becomes one row in which '1'
// maps to `true` and '0' maps to `false`. The resulting columns are then
// named after `trace.labels`, in order.
//
// Errors: an I/O error (converted to `TemexError`) if a line cannot be read
// or contains any character other than '0' or '1', or the polars error raised
// while building the dataframe or assigning the column names.
impl TryFrom<TemexTrace> for polars::frame::DataFrame {
    type Error = TemexError;

    fn try_from(trace: TemexTrace) -> Result<Self, Self::Error> {
        let mut rows: Vec<Row> = vec![];

        // `&[u8]` already implements `BufRead`, so no extra buffering layer
        // is needed to iterate over the lines.
        for line_result in trace.data.as_slice().lines() {
            let line = line_result?;

            let mut row_vec: Vec<AnyValue> = Vec::with_capacity(line.len());
            for c in line.chars() {
                let value = match c {
                    '1' => true,
                    '0' => false,
                    other => {
                        // Report malformed trace data to the caller instead
                        // of panicking.
                        return Err(io::Error::new(
                            io::ErrorKind::InvalidData,
                            format!("unexpected character {:?} in trace data", other),
                        )
                        .into());
                    }
                };
                row_vec.push(AnyValue::Boolean(value));
            }
            rows.push(Row::new(row_vec));
        }

        let mut df = polars::frame::DataFrame::from_rows(&rows)?;
        df.set_column_names(&trace.labels)?;

        Ok(df)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use polars::df;

    // Converts a three-column boolean dataframe to a trace and back, checking
    // the round-tripped dataframe, the trace labels and the raw trace bytes.
    #[test]
    fn try_from_polars_works() {
        let original = df!("p1" => &[true, true, true],
                           "p2" => &[false, false, false],
                           "p3" => &[true, false, true])
        .unwrap();

        let trace = TemexTrace::try_from(original.clone()).unwrap();

        let round_tripped = polars::frame::DataFrame::try_from(trace.clone()).unwrap();
        assert_eq!(original, round_tripped);

        let expected_labels: Vec<String> = ["p1", "p2", "p3"]
            .iter()
            .map(|label| label.to_string())
            .collect();
        assert_eq!(trace.labels, expected_labels);

        let encoded = std::string::String::from_utf8(trace.data).unwrap();
        assert_eq!(encoded, "101\n100\n101\n".to_string());
    }
}