1pub mod candidate;
2pub mod dataframe;
3pub mod error;
4pub use candidate::CandidateData;
5pub use data_value;
6pub use data_value::DataValue;
7pub use dataframe::DataFrame;
8pub mod filter;
9pub mod utils;
10pub use dataframe::join::{JoinBy, JoinById, JoinRelation};
11pub use dataframe::{
12 column_store::{ColumnFrame, KeyIndex},
13 key::Key,
14};
/// A `halfbrown` hash map from `smartstring` string keys to vectors of [`DataValue`].
///
/// NOTE(review): keys presumably name columns/features — confirm against callers.
pub type MLChefMap = halfbrown::HashMap<smartstring::alias::String, Vec<DataValue>>;
16pub use ndarray;
17
// With the `jmalloc` feature enabled, install jemalloc (via the `jemallocator` crate)
// as the program's `#[global_allocator]`.
// NOTE(review): a global allocator declared in a library applies to every binary that
// links this crate with the feature on. Linking fails if another crate also declares one.
#[cfg(feature = "jmalloc")]
#[global_allocator]
static ALLOC: jemallocator::Jemalloc = jemallocator::Jemalloc;
21
22#[cfg(feature = "polars-df")]
23pub use polars;
/// Type tag describing which kind of value a [`DataValue`] holds.
///
/// [`detect_dtype`] derives a tag from a value. [`DataType::Unknown`] is the `Default`
/// and is also returned for values without a dedicated tag (e.g. `DataValue::Null`).
///
/// NOTE(review): the variant order determines the implicit discriminants. With the
/// `python` feature the enum is a `pyclass(eq, eq_int)`, so reordering or inserting
/// variants may change the integer values seen from Python. Append new variants instead.
#[derive(
    Debug, Clone, Copy, serde::Serialize, serde::Deserialize, PartialEq, Eq, Hash, Default,
)]
#[cfg_attr(feature = "python", pyo3::pyclass(eq, eq_int))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum DataType {
    /// Tag for `DataValue::Bool`.
    Bool,
    /// Tag for `DataValue::U32`.
    U32,
    /// Tag for `DataValue::I32`.
    I32,
    /// 8-bit unsigned tag.
    /// NOTE(review): no arm of `detect_dtype` returns this variant. Confirm whether a
    /// `DataValue::U8` exists and should map here.
    U8,
    /// Tag for `DataValue::U64`.
    U64,
    /// Tag for `DataValue::I64`.
    I64,
    /// Tag for `DataValue::F32`.
    F32,
    /// Tag for `DataValue::F64`.
    F64,
    /// Tag for `DataValue::I128`.
    I128,
    /// Tag for `DataValue::U128`.
    U128,
    /// Tag for `DataValue::String`.
    String,
    /// Tag for `DataValue::Bytes`.
    Bytes,
    /// Fallback tag for values with no dedicated variant; also the `Default`.
    #[default]
    Unknown,
    /// Tag for `DataValue::Vec`.
    Vec,
    /// Tag for `DataValue::Map`.
    Map,
}
48
49pub fn detect_dtype(value: &DataValue) -> DataType {
51 use DataValue::*;
52 match value {
53 Bool(_) => DataType::Bool,
54 I32(_) => DataType::I32,
55 U32(_) => DataType::U32,
56 I64(_) => DataType::I64,
57 U64(_) => DataType::U64,
58 F32(_) => DataType::F32,
59 F64(_) => DataType::F64,
60 I128(_) => DataType::I128,
61 U128(_) => DataType::U128,
62 String(_) => DataType::String,
63 Bytes(_) => DataType::Bytes,
64 Vec(_) => DataType::Vec,
65 Map(_) => DataType::Map,
66 _ => DataType::Unknown,
67 }
68}
69
70#[cfg(feature = "python")]
71use pyo3::prelude::*;
72
/// Entry point of the `trs_dataframe` Python extension module.
///
/// Registers the crate's Python-visible classes on the module so Python code can
/// reference them by name.
#[cfg(feature = "python")]
#[pymodule]
pub fn trs_dataframe(_py: pyo3::Python<'_>, m: pyo3::Bound<'_, PyModule>) -> pyo3::PyResult<()> {
    m.add_class::<DataFrame>()?;
    m.add_class::<JoinRelation>()?;
    m.add_class::<Key>()?;
    // `DataType` is a pyclass under the same `python` feature (see its `cfg_attr`).
    // Register it so Python can name the tags directly, e.g. `DataType.F32`.
    m.add_class::<DataType>()?;
    Ok(())
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::dataframe::column_store::convert_data_value;
    use rstest::*;

    /// For each (dtype, value) pair, checks three things:
    /// - `detect_dtype` recovers the dtype.
    /// - The dtype survives a JSON round trip.
    /// - Converting the value to its own dtype returns an equal value.
    #[rstest]
    #[case(DataType::Bool, DataValue::Bool(true))]
    #[case(DataType::I32, DataValue::I32(1))]
    #[case(DataType::U32, DataValue::U32(1))]
    #[case(DataType::I64, DataValue::I64(1))]
    #[case(DataType::U64, DataValue::U64(1))]
    #[case(DataType::F32, DataValue::F32(1.0))]
    #[case(DataType::F64, DataValue::F64(1.0))]
    #[case(DataType::U128, DataValue::U128(1))]
    #[case(DataType::I128, DataValue::I128(1))]
    #[case(DataType::String, DataValue::String("1".into()))]
    #[case(DataType::Bytes, DataValue::Bytes(b"1".to_vec()))]
    #[case(DataType::Vec, DataValue::Vec(vec![DataValue::I32(1)]))]
    #[case(DataType::Map, DataValue::Map(std::collections::HashMap::new()))]
    #[case(DataType::Unknown, DataValue::Null)]
    fn detection_test(#[case] dtype: DataType, #[case] value: DataValue) {
        // Step 1: detection must yield the expected tag.
        let detected = detect_dtype(&value);
        assert_eq!(detected, dtype);

        // Step 2: serializing and deserializing the tag must give it back unchanged.
        let encoded = serde_json::to_string(&dtype).expect("BUG: cannot serialize");
        let decoded: DataType = serde_json::from_str(&encoded).expect("BUG: cannot deserialize");
        assert_eq!(decoded, dtype);

        // Step 3: converting to the value's own dtype must leave it equal to the input.
        let converted = convert_data_value(value.clone(), dtype);
        assert_eq!(converted, value);
    }
}