trs-dataframe 0.11.1

Dataframe library for Teiresias
Documentation
/// Row-oriented candidate storage and conversion utilities.
pub mod candidate;
/// Column-oriented dataframe storage, joins, keys, and indexing.
pub mod dataframe;
/// Error types used throughout the crate.
pub mod error;
pub use candidate::CandidateData;
pub use data_value;
pub use data_value::DataValue;
#[cfg(feature = "python")]
pub use dataframe::python::DataFrameOrDict;
pub use dataframe::DataFrame;
/// Expression-based row filtering for dataframes.
pub mod filter;
pub use dataframe::join::{JoinBy, JoinById, JoinRelation};
pub use dataframe::{
    column_store::{
        typed_array::{TypedData, TypedDataArray},
        ColumnFrame, KeyIndex, MaybeView,
    },
    index::hash_datavalue,
    key::Key,
};
/// Convenience alias for a string-keyed map of `DataValue` vectors.
pub type MLChefMap = halfbrown::HashMap<smartstring::alias::String, Vec<DataValue>>;
pub use ndarray;

#[cfg(feature = "jmalloc")]
#[global_allocator]
static ALLOC: jemallocator::Jemalloc = jemallocator::Jemalloc;

#[cfg(feature = "polars-df")]
pub use polars;
/// Discriminant for the primitive type stored in a column or [`DataValue`].
#[derive(
    Debug, Clone, Copy, serde::Serialize, serde::Deserialize, PartialEq, Eq, Hash, Default,
)]
#[cfg_attr(feature = "python", pyo3::pyclass(eq, eq_int))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum DataType {
    Bool,
    U32,
    I32,
    U8,
    U64,
    I64,
    F32,
    F64,
    I128,
    U128,
    String,
    Bytes,
    #[default]
    Unknown,
    Vec,
    Map,
}

#[inline]
/// Autodetector for the data type from [`DataValue`]
pub fn detect_dtype(value: &DataValue) -> DataType {
    use DataValue::*;
    match value {
        Bool(_) => DataType::Bool,
        I32(_) => DataType::I32,
        U32(_) => DataType::U32,
        I64(_) => DataType::I64,
        U64(_) => DataType::U64,
        F32(_) => DataType::F32,
        F64(_) => DataType::F64,
        I128(_) => DataType::I128,
        U128(_) => DataType::U128,
        String(_) => DataType::String,
        Bytes(_) => DataType::Bytes,
        Vec(_) => DataType::Vec,
        Map(_) => DataType::Map,
        _ => DataType::Unknown,
    }
}
/// Scans a slice of [`DataValue`]s and returns the dominant [`DataType`].
///
/// Inspects up to three consecutive values of the same type before
/// short-circuiting. Heterogeneous slices return the type of the last
/// observed change.
pub fn detect_dtype_arr(value: &[DataValue]) -> DataType {
    let mut dtype = DataType::Unknown;
    let mut find_count = 3;
    for val in value {
        let new_dtype = detect_dtype(val);
        if new_dtype != dtype {
            dtype = new_dtype;
        } else if new_dtype == dtype && !matches!(dtype, DataType::Unknown) {
            find_count -= 1;
        }
        if find_count == 0 {
            break;
        }
    }

    dtype
}

#[cfg(feature = "python")]
use pyo3::prelude::*;

#[cfg(feature = "python")]
///
/// ```
/// use pyo3::prelude::*;
///
///  fn main() {
///     let result = pyo3::Python::with_gil(|py| -> PyResult<()> {
///         let module = PyModule::new(py, "trs_dataframe")?;
///         let _m = trs_dataframe::trs_dataframe(py, module)?;
///         Ok(())
///         });
///     assert!(result.is_ok(), "{:?}", result);
///  }
///
/// ```
#[pymodule]
pub fn trs_dataframe(_py: pyo3::Python<'_>, m: pyo3::Bound<'_, PyModule>) -> pyo3::PyResult<()> {
    m.add_class::<DataFrame>()?;
    m.add_class::<JoinRelation>()?;
    m.add_class::<Key>()?;
    Ok(())
}
#[cfg(test)]
mod test {
    use crate::dataframe::column_store::convert_data_value;

    use super::*;
    use rstest::*;

    #[rstest]
    #[case(DataType::Bool, DataValue::Bool(true))]
    #[case(DataType::I32, DataValue::I32(1))]
    #[case(DataType::U32, DataValue::U32(1))]
    #[case(DataType::I64, DataValue::I64(1))]
    #[case(DataType::U64, DataValue::U64(1))]
    #[case(DataType::F32, DataValue::F32(1.0))]
    #[case(DataType::F64, DataValue::F64(1.0))]
    #[case(DataType::U128, DataValue::U128(1))]
    #[case(DataType::I128, DataValue::I128(1))]
    #[case(DataType::String, DataValue::String("1".into()))]
    #[case(DataType::Bytes, DataValue::Bytes(b"1".to_vec()))]
    #[case(DataType::Vec, DataValue::Vec(vec![DataValue::I32(1)]))]
    #[case(DataType::Map, DataValue::Map(std::collections::HashMap::new()))]
    #[case(DataType::Unknown, DataValue::Null)]
    fn detection_test(#[case] dtype: DataType, #[case] value: DataValue) {
        assert_eq!(detect_dtype(&value), dtype);
        let serde_dtype: DataType =
            serde_json::from_str(&serde_json::to_string(&dtype).expect("BUG: cannot serialize"))
                .expect("BUG: cannot deserialize");
        assert_eq!(serde_dtype, dtype);
        let dt = convert_data_value(value.clone(), dtype);
        assert_eq!(dt, value);
    }

    #[test]
    fn detect_dtype_arr_unknown_for_empty() {
        assert_eq!(detect_dtype_arr(&[]), DataType::Unknown);
    }

    #[test]
    fn detect_dtype_arr_settles_on_repeated_dtype() {
        // Three matching readings short-circuit the loop and lock the dtype.
        let arr = vec![
            DataValue::I32(1),
            DataValue::I32(2),
            DataValue::I32(3),
            DataValue::I32(4),
            DataValue::F64(5.0), // never observed because the loop exits first.
        ];
        assert_eq!(detect_dtype_arr(&arr), DataType::I32);
    }

    #[test]
    fn detect_dtype_arr_overrides_on_change() {
        // The dtype keeps updating until 3 consecutive matches are seen.
        let arr = vec![DataValue::I32(1), DataValue::Null, DataValue::F64(1.0)];
        assert_eq!(detect_dtype_arr(&arr), DataType::F64);
    }

    #[test]
    fn detect_dtype_arr_ignores_repeated_unknown() {
        // Repeated Unknown does not decrement the find counter — verifies the
        // `!matches!(dtype, DataType::Unknown)` guard.
        let arr = vec![DataValue::Null, DataValue::Null, DataValue::Null];
        assert_eq!(detect_dtype_arr(&arr), DataType::Unknown);
    }
}