pub mod candidate;
pub mod dataframe;
pub mod error;
pub use candidate::CandidateData;
pub use data_value;
pub use data_value::DataValue;
pub use dataframe::DataFrame;
pub mod filter;
pub use dataframe::join::{JoinBy, JoinById, JoinRelation};
pub use dataframe::{
column_store::{ColumnFrame, KeyIndex},
index::hash_datavalue,
key::Key,
};
/// String-keyed map of `DataValue` vectors, built on `halfbrown::HashMap` with
/// `smartstring` compact keys.
///
/// NOTE(review): presumably keys are column names and values the column's
/// cells — confirm against callers.
pub type MLChefMap = halfbrown::HashMap<smartstring::alias::String, Vec<DataValue>>;
// Re-export of the `ndarray` crate for downstream users of this crate.
pub use ndarray;
// With the `jmalloc` feature enabled, jemalloc is installed as the global
// allocator for every allocation made by a binary linking this crate.
#[cfg(feature = "jmalloc")]
#[global_allocator]
static ALLOC: jemallocator::Jemalloc = jemallocator::Jemalloc;
#[cfg(feature = "polars-df")]
pub use polars;
/// Type tag describing which kind of `DataValue` a value (or column) holds.
///
/// Serializable with serde. It is also exposed as a Python class when the
/// `python` feature is enabled and as an OpenAPI schema with `utoipa`.
///
/// With the `python` feature, `eq_int` lets Python compare variants to integers
/// using their implicit discriminants. Declaration order is therefore
/// presumably part of the Python-facing contract, so avoid reordering variants.
#[derive(
    Debug, Clone, Copy, serde::Serialize, serde::Deserialize, PartialEq, Eq, Hash, Default,
)]
#[cfg_attr(feature = "python", pyo3::pyclass(eq, eq_int))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum DataType {
    /// Boolean (`DataValue::Bool`).
    Bool,
    /// Unsigned 32-bit integer (`DataValue::U32`).
    U32,
    /// Signed 32-bit integer (`DataValue::I32`).
    I32,
    /// Unsigned 8-bit integer.
    ///
    /// NOTE(review): `detect_dtype` has no arm that produces this tag — confirm
    /// whether a matching `DataValue` variant exists.
    U8,
    /// Unsigned 64-bit integer (`DataValue::U64`).
    U64,
    /// Signed 64-bit integer (`DataValue::I64`).
    I64,
    /// 32-bit float (`DataValue::F32`).
    F32,
    /// 64-bit float (`DataValue::F64`).
    F64,
    /// Signed 128-bit integer (`DataValue::I128`).
    I128,
    /// Unsigned 128-bit integer (`DataValue::U128`).
    U128,
    /// Text (`DataValue::String`).
    String,
    /// Raw bytes (`DataValue::Bytes`).
    Bytes,
    /// Fallback tag and the `Default`. `detect_dtype` returns it for values
    /// without a dedicated tag, such as `DataValue::Null`.
    #[default]
    Unknown,
    /// Nested list of values (`DataValue::Vec`).
    Vec,
    /// Nested map of values (`DataValue::Map`).
    Map,
}
pub fn detect_dtype(value: &DataValue) -> DataType {
use DataValue::*;
match value {
Bool(_) => DataType::Bool,
I32(_) => DataType::I32,
U32(_) => DataType::U32,
I64(_) => DataType::I64,
U64(_) => DataType::U64,
F32(_) => DataType::F32,
F64(_) => DataType::F64,
I128(_) => DataType::I128,
U128(_) => DataType::U128,
String(_) => DataType::String,
Bytes(_) => DataType::Bytes,
Vec(_) => DataType::Vec,
Map(_) => DataType::Map,
_ => DataType::Unknown,
}
}
pub fn detect_dtype_arr(value: &[DataValue]) -> DataType {
let mut dtype = DataType::Unknown;
let mut find_count = 3;
for val in value {
let new_dtype = detect_dtype(val);
if new_dtype != dtype {
dtype = new_dtype;
} else if new_dtype == dtype && !matches!(dtype, DataType::Unknown) {
find_count -= 1;
}
if find_count == 0 {
break;
}
}
dtype
}
#[cfg(feature = "python")]
use pyo3::prelude::*;
#[cfg(feature = "python")]
#[pymodule]
pub fn trs_dataframe(_py: pyo3::Python<'_>, m: pyo3::Bound<'_, PyModule>) -> pyo3::PyResult<()> {
    // Register the Python-visible classes in order. The first failure aborts
    // module initialisation and its error is returned to the interpreter.
    m.add_class::<DataFrame>()?;
    m.add_class::<JoinRelation>()?;
    m.add_class::<Key>()
}
#[cfg(test)]
mod test {
    use crate::dataframe::column_store::convert_data_value;
    use super::*;
    use rstest::*;

    // For each `DataType`, check three things:
    //   1. `detect_dtype` maps a representative `DataValue` to that tag,
    //   2. the tag survives a JSON serialize/deserialize round trip, and
    //   3. `convert_data_value` into that tag returns the value unchanged.
    #[rstest]
    #[case(DataType::Bool, DataValue::Bool(true))]
    #[case(DataType::I32, DataValue::I32(1))]
    #[case(DataType::U32, DataValue::U32(1))]
    #[case(DataType::I64, DataValue::I64(1))]
    #[case(DataType::U64, DataValue::U64(1))]
    #[case(DataType::F32, DataValue::F32(1.0))]
    #[case(DataType::F64, DataValue::F64(1.0))]
    #[case(DataType::U128, DataValue::U128(1))]
    #[case(DataType::I128, DataValue::I128(1))]
    #[case(DataType::String, DataValue::String("1".into()))]
    #[case(DataType::Bytes, DataValue::Bytes(b"1".to_vec()))]
    #[case(DataType::Vec, DataValue::Vec(vec![DataValue::I32(1)]))]
    #[case(DataType::Map, DataValue::Map(std::collections::HashMap::new()))]
    #[case(DataType::Unknown, DataValue::Null)]
    fn detection_test(#[case] dtype: DataType, #[case] value: DataValue) {
        // 1. Detection.
        let detected = detect_dtype(&value);
        assert_eq!(detected, dtype);

        // 2. Serde round trip.
        let encoded = serde_json::to_string(&dtype).expect("BUG: cannot serialize");
        let decoded: DataType = serde_json::from_str(&encoded).expect("BUG: cannot deserialize");
        assert_eq!(decoded, dtype);

        // 3. Conversion into the matching type is the identity.
        let converted = convert_data_value(value.clone(), dtype);
        assert_eq!(converted, value);
    }
}