Skip to main content

trs_dataframe/dataframe/
python.rs

1use std::collections::HashMap;
2
3use crate::{DataFrame, DataValue, JoinRelation, Key};
4use data_value::Extract as _;
5use ndarray::Array1;
6use numpy::{IntoPyArray, PyArray2};
7use pyo3::{
8    exceptions::PyTypeError,
9    prelude::*,
10    types::{PyBytes, PyList},
11    IntoPyObjectExt,
12};
13use tracing::trace;
14
15impl DataFrame {
16    fn select_data(
17        &self,
18        keys: Option<Vec<String>>,
19        transposed: Option<bool>,
20    ) -> Result<ndarray::Array2<DataValue>, crate::error::Error> {
21        let keys = keys
22            .map(|x| x.into_iter().map(Key::from).collect::<Vec<Key>>())
23            .unwrap_or(self.keys());
24        if transposed.unwrap_or(false) {
25            self.select(Some(keys.as_slice()))
26        } else {
27            self.select_transposed(Some(keys.as_slice()))
28        }
29    }
30}
31
/// A Python operand accepted by the in-place arithmetic operators
/// (`+=`, `-=`, `*=`, `/=`) of [`DataFrame`].
pub enum DataFrameOrDict {
    /// A whole DataFrame, either passed directly or built from a dict of column lists.
    DataFrame(DataFrame),
    /// A single record given as a mapping from column name to scalar value.
    Dict(HashMap<String, DataValue>),
}
36
37impl DataFrameOrDict {
38    pub fn new(object: Bound<'_, PyAny>) -> Result<DataFrameOrDict, PyErr> {
39        if let Ok(df) = object.extract::<DataFrame>() {
40            Ok(DataFrameOrDict::DataFrame(df))
41        } else if let Ok(df) = object.extract::<HashMap<String, Vec<DataValue>>>() {
42            Ok(DataFrameOrDict::DataFrame(DataFrame::from_dict(df)))
43        } else {
44            let dict: HashMap<String, DataValue> = object.extract()?;
45            Ok(DataFrameOrDict::Dict(dict))
46        }
47    }
48}
49
50impl From<DataFrameOrDict> for DataFrame {
51    fn from(value: DataFrameOrDict) -> Self {
52        match value {
53            DataFrameOrDict::DataFrame(df) => df,
54            DataFrameOrDict::Dict(dict) => DataFrame::from_dict(
55                dict.into_iter()
56                    .map(|(key, value)| (key, vec![value]))
57                    .collect::<HashMap<String, Vec<DataValue>>>(),
58            ),
59        }
60    }
61}
62#[pymethods]
63impl DataFrame {
64    /// Create a new empty DataFrame.
65    #[new]
66    pub fn init() -> Self {
67        Self::default()
68    }
69
70    /// Create a DataFrame from a polars dataframe in python.
71    /// ```text
72    /// import polars as pl
73    /// df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
74    /// tei_df = tdf.DataFrame.from_polars(df)
75    /// ```
76    #[cfg(feature = "polars-df")]
77    #[staticmethod]
78    pub fn from_polars(df: pyo3_polars::PyDataFrame) -> Self {
79        df.0.into()
80    }
81
82    /// Create a DataFrame from a dictionary.
83    /// ```text
84    /// df = tdf.DataFrame.from_dict({"a": [1, 2, 3], "b": [4, 5, 6]})
85    /// ```
86    #[staticmethod]
87    pub fn from_dict(df: HashMap<String, Vec<DataValue>>) -> Self {
88        let mut result_df: Vec<(Key, Vec<DataValue>)> = Vec::new();
89        for (key, value) in df.into_iter() {
90            let dtype = crate::detect_dtype_arr(&value);
91            let key = Key::new(key.as_str(), dtype);
92            result_df.push((key, value));
93        }
94
95        result_df.into()
96    }
97
98    /// Returns the keys of the DataFrame.
99    pub fn keys(&self) -> Vec<Key> {
100        self.dataframe.keys().to_vec()
101    }
102
103    /// Convert the DataFrame to polars DataFrame.
104    /// ```text
105    /// df = tdf.DataFrame.from_dict({"a": [1, 2, 3], "b": [4, 5, 6]});
106    /// df.set_dtype_for_column("a", trs.DataType.I32)
107    /// ```  
108    pub fn set_dtype_for_column(&mut self, key: String, dtype: crate::DataType) -> PyResult<()> {
109        self.dataframe
110            .enforce_dtype_for_column(key.as_str(), dtype)
111            .map_err(|e| {
112                PyErr::new::<PyTypeError, _>(format!("Cannot set dtype for columnĀ {key}: {e}"))
113            })
114    }
115
116    /// Convert the DataFrame to polars DataFrame.
117    /// This requires the `polars-df` feature to be enabled.
118    /// ```text
119    /// import polars as pl
120    /// original_df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
121    /// df = tdf.DataFrame.from_polars(original_df)
122    /// polars_df = df.as_polars()
123    /// assert polars_df.frame_equal(original_df)
124    /// ```  
125    #[cfg(feature = "polars-df")]
126    #[pyo3(name = "as_polars")]
127    pub fn py_as_polars(&self) -> PyResult<pyo3_polars::PyDataFrame> {
128        let df = self
129            .as_polars()
130            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot prepare polars DF: {e}")))?;
131        Ok(pyo3_polars::PyDataFrame(df))
132    }
133
134    /// Apply a function to the DataFrame.
135    /// The function should accept a DataFrame and return a DataFrame.
136    /// ```text
137    /// def my_function(df):
138    ///    # Perform some operations on the DataFrame
139    ///   return df
140    /// /// df = tdf.DataFrame.init()
141    /// df.apply(my_function)
142    /// ```
143    pub fn apply(&mut self, function: Bound<'_, PyAny>) -> Result<(), PyErr> {
144        let df: DataFrame = pyo3::Python::attach(|py| {
145            let self_ = self
146                .clone()
147                .into_pyobject(py)
148                .expect("BUG: cannot convert to PyObject");
149            let result = function.call1((self_,)).expect("BUG: cannot call function");
150            result
151                .extract::<Bound<DataFrame>>()
152                .expect("BUG: cannot extract data frame")
153                .unbind()
154                .extract(py)
155                .expect("BUG: cannot extract data frame")
156        });
157        self.dataframe = df.dataframe;
158        Ok(())
159    }
160
161    /// Returns slice from dataframe as numpy.array of uint32 of the given keys.
162    /// If `transposed` is true, the keys will be transposed.
163    /// If `keys` is None, all keys will be used.
164    /// ```text
165    /// import numpy as np
166    /// df = tdf.DataFrame.init()
167    /// df.push({"key1": 1, "key2": 2})
168    /// df.push({"key1": 11, "key2": 21})
169    /// a_np = df.as_numpy_u32(['key1', 'key2'])
170    /// assert np.array_equal(a_np, np.array([[1, 11], [2, 21]], dtype=np.uint32))
171    /// ```
172    #[pyo3(signature = (keys=None, transposed=None))]
173    pub fn as_numpy_u32<'py>(
174        &self,
175        keys: Option<Vec<String>>,
176        transposed: Option<bool>,
177        py: Python<'py>,
178    ) -> PyResult<Bound<'py, numpy::PyArray2<u32>>> {
179        let data = self
180            .select_data(keys, transposed)
181            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
182        Ok(PyArray2::from_array(py, &data.mapv(|x| u32::extract(&x))))
183    }
184
185    /// Returns slice from dataframe as numpy.array of uint64 of the given keys.
186    /// If `transposed` is true, the keys will be transposed.
187    /// If `keys` is None, all keys will be used.
188    /// ```text
189    /// import numpy as np
190    /// df = tdf.DataFrame.init()
191    /// df.push({"key1": 1, "key2": 2})
192    /// df.push({"key1": 11, "key2": 21})
193    /// a_np = df.as_numpy_u64(['key1', 'key2'])
194    /// assert np.array_equal(a_np, np.array([[1, 11], [2, 21]], dtype=np.uint64))
195    /// ```
196    #[pyo3(signature = (keys=None, transposed=None))]
197    pub fn as_numpy_u64<'py>(
198        &self,
199        keys: Option<Vec<String>>,
200        transposed: Option<bool>,
201        py: Python<'py>,
202    ) -> PyResult<Bound<'py, numpy::PyArray2<u64>>> {
203        let data = self
204            .select_data(keys, transposed)
205            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
206        Ok(PyArray2::from_array(py, &data.mapv(|x| u64::extract(&x))))
207    }
208
209    /// Returns slice from dataframe as numpy.array of int32 of the given keys.
210    /// If `transposed` is true, the keys will be transposed.
211    /// If `keys` is None, all keys will be used.
212    /// ```text
213    /// import numpy as np
214    /// df = tdf.DataFrame.init()
215    /// df.push({"key1": 1, "key2": 2})
216    /// df.push({"key1": 11, "key2": 21})
217    /// a_np = df.as_numpy_i32(['key1', 'key2'])
218    /// assert np.array_equal(a_np, np.array([[1, 11], [2, 21]], dtype=np.int32))
219    /// ```
220    #[pyo3(signature = (keys=None, transposed=None))]
221    pub fn as_numpy_i32<'py>(
222        &self,
223        keys: Option<Vec<String>>,
224        transposed: Option<bool>,
225        py: Python<'py>,
226    ) -> PyResult<Bound<'py, numpy::PyArray2<i32>>> {
227        let data = self
228            .select_data(keys, transposed)
229            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
230        Ok(PyArray2::from_array(py, &data.mapv(|x| i32::extract(&x))))
231    }
232
233    /// Returns slice from dataframe as numpy.array of int64 of the given keys.
234    /// If `transposed` is true, the keys will be transposed.
235    /// If `keys` is None, all keys will be used.
236    /// ```text
237    /// import numpy as np
238    /// df = tdf.DataFrame.init()
239    /// df.push({"key1": 1, "key2": 2})
240    /// df.push({"key1": 11, "key2": 21})
241    /// a_np = df.as_numpy_i64(['key1', 'key2'])
242    /// assert np.array_equal(a_np, np.array([[1, 11], [2, 21]], dtype=np.int64))
243    /// ```
244    #[pyo3(signature = (keys=None, transposed=None))]
245    pub fn as_numpy_i64<'py>(
246        &self,
247        keys: Option<Vec<String>>,
248        transposed: Option<bool>,
249        py: Python<'py>,
250    ) -> PyResult<Bound<'py, numpy::PyArray2<i64>>> {
251        let data = self
252            .select_data(keys, transposed)
253            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
254        Ok(PyArray2::from_array(py, &data.mapv(|x| i64::extract(&x))))
255    }
256
257    /// Returns slice from dataframe as numpy.array of float32 of the given keys.
258    /// If `transposed` is true, the keys will be transposed.
259    /// If `keys` is None, all keys will be used.
260    /// ```text
261    /// import numpy as np
262    /// df = tdf.DataFrame.init()
263    /// df.push({"key1": 1, "key2": 2})
264    /// df.push({"key1": 11, "key2": 21})
265    /// a_np = df.as_numpy_f32(['key1', 'key2'])
266    /// assert np.array_equal(a_np, np.array([[1, 11], [2, 21]], dtype=np.float32))
267    /// ```
268    #[pyo3(signature = (keys=None, transposed=None))]
269    pub fn as_numpy_f32<'py>(
270        &self,
271        keys: Option<Vec<String>>,
272        transposed: Option<bool>,
273        py: Python<'py>,
274    ) -> PyResult<Bound<'py, numpy::PyArray2<f32>>> {
275        let data = self
276            .select_data(keys, transposed)
277            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
278        Ok(PyArray2::from_array(py, &data.mapv(|x| f32::extract(&x))))
279    }
280
281    /// Returns slice from dataframe as numpy.array of float64 of the given keys.
282    /// If `transposed` is true, the keys will be transposed.
283    /// If `keys` is None, all keys will be used.
284    /// ```text
285    /// import numpy as np
286    /// df = tdf.DataFrame.init()
287    /// df.push({"key1": 1, "key2": 2})
288    /// df.push({"key1": 11, "key2": 21})
289    /// a_np = df.as_numpy_f64(['key1', 'key2'])
290    /// assert np.array_equal(a_np, np.array([[1, 11], [2, 21]], dtype=np.float64))
291    /// ```
292    #[pyo3(signature = (keys=None, transposed=None))]
293    pub fn as_numpy_f64<'py>(
294        &self,
295        keys: Option<Vec<String>>,
296        transposed: Option<bool>,
297        py: Python<'py>,
298    ) -> PyResult<Bound<'py, numpy::PyArray2<f64>>> {
299        let data = self
300            .select_data(keys, transposed)
301            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
302        Ok(PyArray2::from_array(py, &data.mapv(|x| f64::extract(&x))))
303    }
304
305    /// Returns slice from dataframe as numpy.array of objects of the given keys.
306    /// If `transposed` is true, the keys will be transposed.
307    /// If `keys` is None, all keys will be used.
308    /// ```text
309    /// import numpy as np
310    /// df = tdf.DataFrame.init()
311    /// df.push({"key1": "a", "key2": "b"})
312    /// df.push({"key1": "d", "key2": "c"})
313    /// a_np = df.as_numpy_f64(['key1', 'key2'])
314    /// assert np.array_equal(a_np, np.array([["a", d""], ["b", "c"]], dtype=np.object))
315    /// ```
316    #[pyo3(signature = (keys=None, transposed=None))]
317    pub fn as_numpy_str<'py>(
318        &self,
319        keys: Option<Vec<String>>,
320        transposed: Option<bool>,
321        py: Python<'py>,
322    ) -> PyResult<Bound<'py, numpy::PyArray2<Py<PyAny>>>> {
323        self.as_numpy(keys, transposed, py)
324    }
325
326    /// Returns slice from dataframe as numpy.array of float64 of the given keys.
327    /// If `transposed` is true, the keys will be transposed.
328    /// If `keys` is None, all keys will be used.
329    /// ```text
330    /// import numpy as np
331    /// df = tdf.DataFrame.init()
332    /// df.push({"key1": "a", "key2": 1})
333    /// df.push({"key1": "d", "key2": 2})
334    /// a_np = df.as_numpy_f64(['key1', 'key2'])
335    /// assert np.array_equal(a_np, np.array([["a", d""], [1, 2]], dtype=np.object))
336    /// ```
337    #[pyo3(signature = (keys=None, transposed=None))]
338    pub fn as_numpy<'py>(
339        &self,
340        keys: Option<Vec<String>>,
341        transposed: Option<bool>,
342        py: Python<'py>,
343    ) -> PyResult<Bound<'py, numpy::PyArray2<Py<PyAny>>>> {
344        let data = self
345            .select_data(keys, transposed)
346            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
347        let data = data.mapv(|x| {
348            String::extract(&x)
349                .into_py_any(py)
350                .expect("cannot convert string to py object")
351        });
352        Ok(data.into_pyarray(py))
353    }
354
355    #[pyo3(name = "shrink")]
356    pub fn py_shrink(&mut self) {
357        self.dataframe.shrink();
358    }
359
360    #[pyo3(name = "add_metadata")]
361    pub fn py_add_metadata(&mut self, key: String, value: DataValue) {
362        self.metadata.insert(key, value);
363    }
364
365    #[pyo3(name = "get_metadata")]
366    pub fn py_get_metadata(&self, key: &str) -> Option<DataValue> {
367        self.metadata.get(key).cloned()
368    }
369
370    #[pyo3(name = "rename_key")]
371    pub fn py_rename_key(&mut self, key: &str, new_name: &str) -> Result<(), PyErr> {
372        // fixme this may have a problem when the type is different and checked
373        self.dataframe
374            .rename_key(key, new_name.into())
375            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("{e}")))
376    }
377
378    #[pyo3(name = "add_alias")]
379    pub fn py_add_alias(&mut self, key: &str, new_name: &str) -> Result<(), PyErr> {
380        self.dataframe
381            .add_alias(key, new_name)
382            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("{e}")))
383    }
384
385    /// Selects data from the DataFrame.
386    /// If `keys` is None, all keys will be used.
387    /// If `keys` is provided, only the specified keys will be selected.
388    /// Returns a list of lists, where each inner list represents a row of data.
389    /// ```text
390    /// import trs_dataframe as tdf
391    /// df = tdf.DataFrame.init()
392    /// df.push({"key1": 1, "key2": 2})
393    /// df.push({"key1": 11, "key2": 21})
394    /// # selected = df.select(["key1", "key2"])
395    /// # assert selected == [[1, 2], [11, 21]]
396    /// # selected = df.select()
397    #[pyo3(name = "select", signature = (keys=None, transposed=None))]
398    pub fn py_select<'py>(
399        &self,
400        py: Python<'py>,
401        keys: Option<Vec<String>>,
402        transposed: Option<bool>,
403    ) -> Result<Bound<'py, PyList>, PyErr> {
404        let keys = keys
405            .map(|x| x.into_iter().map(Key::from).collect::<Vec<Key>>())
406            .unwrap_or(self.keys());
407
408        let selected = if transposed.unwrap_or_default() {
409            self.select_transposed(Some(keys.as_slice()))
410                .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?
411        } else {
412            self.select(Some(keys.as_slice()))
413                .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?
414        };
415
416        let list = PyList::empty(py);
417        for rows in selected.rows() {
418            let row = PyList::empty(py);
419            for value in rows.iter() {
420                row.append(value.clone())
421                    .expect("BUG: cannot append to list");
422            }
423            list.append(row).expect("BUG: cannot append to list");
424        }
425        Ok(list)
426    }
427
428    /// Selects a column from the DataFrame.
429    /// If the column does not exist, it will raise a TypeError.
430    /// Returns a list of values in the selected column.
431    /// ```text
432    /// import trs_dataframe as tdf
433    /// df = tdf.DataFrame.init()
434    /// df.push({"key1": 1, "key2": 2})
435    /// df.push({"key1": 11, "key2": 21})
436    /// # selected = df.select_column("key1")
437    /// # assert selected == [1, 11]
438    /// # selected = df.select_column("key2")
439    /// # assert selected == [2, 21]
440    /// # selected = df.select_column("non_existing_key")  # Raises TypeError
441    /// ```
442    #[pyo3(name = "select_column")]
443    #[allow(deprecated)]
444    pub fn py_select_column<'py>(
445        &self,
446        py: Python<'py>,
447        key: String,
448    ) -> Result<Bound<'py, PyList>, PyErr> {
449        let selected = self
450            .select_column(Key::from(key))
451            .ok_or_else(|| PyErr::new::<PyTypeError, _>("Cannot select column"))?;
452
453        let list = PyList::empty(py);
454        for x in selected.to_vec().into_iter() {
455            list.append(x)?;
456        }
457
458        Ok(list)
459    }
460
461    /// Joins the current DataFrame with another DataFrame.
462    /// The join type is specified by the `join_type` parameter.
463    /// see [`JoinRelation`] for available join types.
464    /// ```text
465    /// import trs_dataframe as tdf
466    /// df1 = tdf.DataFrame.init()
467    /// df1.push({"key1": 1, "key2": 2})
468    /// df1.push({"key1": 11, "key2": 21})
469    /// df2 = tdf.DataFrame.init()
470    /// df2.push({"key1": 1, "key2": 3})
471    /// df2.push({"key1": 11, "key2": 23})
472    /// df1.join(df2, tei.JoinRelation.extend())
473    /// assert df1.select(["key1", "key2"]) == [[1, 2], [11, 21], [1, 3], [11, 23]]
474    /// ```
475    #[pyo3(name = "join")]
476    pub fn py_join(&mut self, other: DataFrame, join_type: JoinRelation) -> Result<(), PyErr> {
477        self.dataframe
478            .join(other.dataframe, &join_type)
479            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot join data: {e}")))?;
480
481        Ok(())
482    }
483
484    /// Pushes a new row of data into the DataFrame.
485    /// The data should be provided as a dictionary where keys are column names and values are the corresponding data values.
486    /// ```text
487    /// import trs_dataframe as tdf
488    /// df = tdf.DataFrame.init()
489    /// df.push({"key1": 1, "key2": 2})
490    /// df.push({"key1": 11, "key2": 21})
491    /// ```
492    #[pyo3(name = "push")]
493    pub fn py_push(&mut self, data: HashMap<Key, DataValue>) -> Result<(), PyErr> {
494        self.dataframe
495            .push(data)
496            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot join data: {e}")))?;
497        Ok(())
498    }
499
500    /// Adds a new column to the DataFrame.
501    /// The column is specified by a key and a vector of data values.
502    /// If the length of the data vector does not match the number of rows in the DataFrame, it will raise a TypeError.
503    /// ```text
504    /// import trs_dataframe as tdf
505    /// df = tdf.DataFrame.init()
506    /// df.push({"key1": 1, "key2": 2})
507    /// df.push({"key1": 11, "key2": 21})
508    /// df.add_column("key3", [3, 4])
509    /// assert df.select(["key1", "key2", "key3"]) == [[1, 2, 3], [11, 21, 4]]
510    /// ```
511    #[pyo3(name = "add_column")]
512    pub fn py_add_column(&mut self, key: Key, data: Vec<DataValue>) -> Result<(), PyErr> {
513        self.dataframe
514            .add_single_column(key, Array1::from_vec(data))
515            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot join data: {e}")))?;
516        Ok(())
517    }
518
519    pub fn add_constant(&mut self, key: Key, feature: DataValue) -> Result<(), PyErr> {
520        self.constants.insert(key, feature);
521        Ok(())
522    }
523
524    /// Filters the DataFrame by a given expression.
525    /// The expression should be a string that can be parsed by the DataFrame's filter method
526    ///
527    /// ```text
528    /// import trs_dataframe as tdf
529    /// df = tdf.DataFrame.init()
530    /// df.push({"key1": 1, "key2": 2})
531    /// df.push({"key1": 11, "key2": 21})
532    /// df.filter_by_expression("key1 > 5")
533    /// assert df.select(["key1", "key2"]) == [[11, 21 ]]
534    /// ```
535    pub fn filter_by_expression(&mut self, expression: String) -> Result<Self, PyErr> {
536        let filter = crate::filter::FilterRules::try_from(expression.as_str())
537            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot parse expression: {e}")))?;
538        self.filter(&filter)
539            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot filter data: {e}")))
540    }
541
542    fn __repr__(&self) -> String {
543        self.to_string()
544    }
545
546    fn __str__(&self) -> String {
547        self.to_string()
548    }
549
550    pub fn __iadd__(&mut self, object: Bound<'_, PyAny>) -> Result<(), PyErr> {
551        trace!("{object:?}");
552        let df_or_dict = DataFrameOrDict::new(object)?;
553        match df_or_dict {
554            DataFrameOrDict::DataFrame(df) => {
555                self.dataframe += df.dataframe;
556            }
557            DataFrameOrDict::Dict(dict) => {
558                self.dataframe += dict;
559            }
560        }
561        Ok(())
562    }
563
564    pub fn __isub__(&mut self, object: Bound<'_, PyAny>) -> Result<(), PyErr> {
565        trace!("{object:?}");
566
567        let df_or_dict = DataFrameOrDict::new(object)?;
568        match df_or_dict {
569            DataFrameOrDict::DataFrame(df) => {
570                self.dataframe -= df.dataframe;
571            }
572            DataFrameOrDict::Dict(dict) => {
573                self.dataframe -= dict;
574            }
575        }
576        Ok(())
577    }
578
579    pub fn __imul__(&mut self, object: Bound<'_, PyAny>) -> Result<(), PyErr> {
580        trace!("{object:?}");
581        let df_or_dict = DataFrameOrDict::new(object)?;
582        match df_or_dict {
583            DataFrameOrDict::DataFrame(df) => {
584                self.dataframe *= df.dataframe;
585            }
586            DataFrameOrDict::Dict(dict) => {
587                self.dataframe *= dict;
588            }
589        }
590        Ok(())
591    }
592
593    pub fn __itruediv__(&mut self, object: Bound<'_, PyAny>) -> Result<(), PyErr> {
594        trace!("{object:?}");
595        let df_or_dict = DataFrameOrDict::new(object)?;
596        match df_or_dict {
597            DataFrameOrDict::DataFrame(df) => {
598                self.dataframe /= df.dataframe;
599            }
600            DataFrameOrDict::Dict(dict) => {
601                self.dataframe /= dict;
602            }
603        }
604        Ok(())
605    }
606
607    pub fn __len__(&mut self) -> Result<usize, PyErr> {
608        Ok(self.dataframe.nrows())
609    }
610
611    pub fn serialize_to_json_string(&self) -> String {
612        serde_json::to_string(self).expect("Cannot serialize to strinng")
613    }
614
615    #[staticmethod]
616    pub fn deserialize_from_json_string(json_df: String) -> Self {
617        let mut df: DataFrame =
618            serde_json::from_str(json_df.as_str()).expect("Cannot deserialize from str");
619        let _ = df.dataframe.try_fix_dtype();
620
621        df
622    }
623
624    // derive Serialize and Deserialize
625    pub fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
626        let s: DataFrame = rmp_serde::decode::from_slice(state.as_bytes()).map_err(|e| {
627            pyo3::PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
628                "Cannot deserialize object {e}"
629            ))
630        })?;
631        *self = s;
632        self.dataframe.try_fix_dtype().map_err(|e| {
633            pyo3::PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
634                "Cannot deserialize object {e}"
635            ))
636        })?;
637        Ok(())
638    }
639    pub fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
640        let buf = rmp_serde::encode::to_vec(self).map_err(|e| {
641            // let buf = serde_json::to_string(self).map_err(|e| {
642            pyo3::PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
643                "Cannot deserialize object {e}"
644            ))
645        })?;
646        Ok(PyBytes::new(py, &buf))
647    }
648
649    pub fn __del__(&mut self) {
650        self.dataframe = Default::default();
651    }
652
653    // pub fn __reduce__(&self, py: Python<'_>) -> PyResult<(Py<PyAny>, Py<PyAny>)> {
654    //     let cls = py.get_type::<Self>();
655    //     Ok((
656    //         cls.into(),
657    //         pyo3::types::PyTuple::new(py, &[self.__getstate__(py)?])?
658    //             .into_any()
659    //             .unbind(),
660    //     ))
661    // }
662}
663
664#[cfg(test)]
665mod test {
666
667    use super::*;
668    use crate::DataType;
669    use data_value::{stdhashmap, DataValue};
670    use halfbrown::hashmap;
671    use pyo3::ffi::c_str;
672    use rstest::*;
673    use tracing_test::traced_test;
674
675    #[fixture]
676    fn df() -> DataFrame {
677        let mut df = DataFrame::init();
678        assert!(df
679            .push(hashmap! {
680                Key::new("key1", DataType::U32) => DataValue::U32(1),
681                Key::new("key2", DataType::U32) => DataValue::U32(2),
682            })
683            .is_ok());
684        assert!(df
685            .push(hashmap! {
686                Key::from("key1") => DataValue::U32(11),
687                Key::from("key2") => DataValue::U32(21),
688            })
689            .is_ok());
690        df
691    }
692
693    #[fixture]
694    fn hm() -> HashMap<String, DataValue> {
695        stdhashmap!(
696            "key1".to_string() => DataValue::U32(2),
697            "key2".to_string() => DataValue::U32(3),
698        )
699    }
700
701    #[rstest]
702    fn serde_py(df: DataFrame) {
703        let str_df = df.serialize_to_json_string();
704        assert!(!str_df.is_empty());
705
706        let loaded = DataFrame::deserialize_from_json_string(str_df);
707
708        assert_eq!(loaded, df);
709    }
710    #[cfg(feature = "python")]
711    #[rstest]
712    fn pickle_py(df: DataFrame) {
713        pyo3::Python::attach(|py| {
714            let bytes = df.__getstate__(py);
715            assert!(bytes.is_ok());
716
717            let mut deser = DataFrame::default();
718            assert!(deser.__setstate__(bytes.unwrap().into()).is_ok());
719            assert_eq!(deser, df);
720        });
721    }
722    #[rstest]
723    fn test_select_data(df: DataFrame) {
724        let data = df.select_data(Some(vec!["key1".into(), "key2".into()]), Some(false));
725        assert!(data.is_ok());
726        assert_eq!(
727            data.unwrap(),
728            ndarray::array![[1u32.into(), 11u32.into()], [2u32.into(), 21u32.into()]]
729        );
730
731        let data = df.select_data(Some(vec!["key1".into(), "key2".into()]), Some(true));
732        assert!(data.is_ok());
733        assert_eq!(
734            data.unwrap(),
735            ndarray::array![[1u32.into(), 2u32.into()], [11u32.into(), 21u32.into()]]
736        );
737    }
738
739    #[cfg(feature = "python")]
740    #[rstest]
741    fn test_from_create() {
742        pyo3::Python::attach(|_py| {
743            let mut hm: HashMap<String, Vec<DataValue>> = Default::default();
744            let value: Vec<DataValue> = vec![1i32.into(), 22i32.into()];
745            hm.insert("a".into(), value);
746
747            let mut df = DataFrame::from_dict(hm);
748            assert_eq!(
749                df.select(Some(&["a".into()])),
750                Ok(ndarray::array![
751                    [DataValue::from(1i32)],
752                    [DataValue::from(22i32)]
753                ]),
754            );
755            assert!(df.set_dtype_for_column("a".into(), DataType::U32).is_ok());
756            assert_eq!(
757                df.select(Some(&["a".into()])),
758                Ok(ndarray::array![
759                    [DataValue::from(1u32)],
760                    [DataValue::from(22u32)]
761                ]),
762            );
763        });
764        #[cfg(feature = "polars-df")]
765        {
766            let pdf = polars::df!(
767                "a" => [1u64, 2u64, 3u64],
768                "b" => [4f64, 5f64, 6f64],
769                "c" => [7i64, 8i64, 9i64]
770            )
771            .expect("BUG: should be ok");
772            let df = DataFrame::from_polars(pyo3_polars::PyDataFrame(pdf));
773            assert_eq!(
774                df.select(Some(&["a".into(), "b".into(), "c".into()])),
775                crate::df! {
776                    "a" => [1u64, 2u64, 3u64],
777                    "b" => [4f64, 5f64, 6f64],
778                    "c" => [7i64, 8i64, 9i64]
779                }
780                .select(Some(&["a".into(), "b".into(), "c".into()])),
781            );
782            let keys = df.keys();
783            assert_eq!(
784                keys,
785                vec![
786                    Key::new("a", DataType::U64),
787                    Key::new("b", DataType::F64),
788                    Key::new("c", DataType::I64),
789                ]
790            )
791        }
792    }
793
794    #[rstest]
795    #[traced_test]
796    fn basic_ops_add(mut df: DataFrame, hm: HashMap<String, DataValue>) {
797        let mut df_expect = df.clone();
798        let df2 = df.clone();
799        let exec = Python::attach(|py| -> PyResult<()> {
800            df.__iadd__(df.clone().into_pyobject(py)?.into_any())?;
801            df_expect.dataframe += df2.dataframe;
802            tracing::trace!("{} vs {}", df, df_expect);
803            assert_eq!(df.dataframe, df_expect.dataframe);
804
805            df.__iadd__(hm.clone().into_pyobject(py)?.into_any())?;
806            df_expect.dataframe += hm;
807            tracing::trace!("{} vs {}", df, df_expect);
808            assert_eq!(df.dataframe, df_expect.dataframe);
809
810            Ok(())
811        });
812
813        assert!(exec.is_ok(), "{:?}", exec);
814    }
815
816    #[rstest]
817    #[traced_test]
818    fn basic_ops_sub(mut df: DataFrame, hm: HashMap<String, DataValue>) {
819        let mut df_expect = df.clone();
820        let df2 = df.clone();
821        let exec = Python::attach(|py| -> PyResult<()> {
822            df.__isub__(df.clone().into_pyobject(py)?.into_any())?;
823            df_expect.dataframe -= df2.dataframe;
824            tracing::trace!("{} vs {}", df, df_expect);
825            assert_eq!(df.dataframe, df_expect.dataframe);
826
827            df.__isub__(hm.clone().into_pyobject(py)?.into_any())?;
828            df_expect.dataframe -= hm;
829            tracing::trace!("{} vs {}", df, df_expect);
830            assert_eq!(df.dataframe, df_expect.dataframe);
831
832            Ok(())
833        });
834
835        assert!(exec.is_ok(), "{:?}", exec);
836    }
837
838    #[rstest]
839    #[traced_test]
840    fn basic_ops_mul(mut df: DataFrame, hm: HashMap<String, DataValue>) {
841        let mut df_expect = df.clone();
842        let df2 = df.clone();
843        let exec = Python::attach(|py| -> PyResult<()> {
844            df.__imul__(df.clone().into_pyobject(py)?.into_any())?;
845            df_expect.dataframe *= df2.dataframe;
846            tracing::trace!("{} vs {}", df, df_expect);
847            assert_eq!(df.dataframe, df_expect.dataframe);
848
849            df.__imul__(hm.clone().into_pyobject(py)?.into_any())?;
850            df_expect.dataframe *= hm;
851            tracing::trace!("{} vs {}", df, df_expect);
852            assert_eq!(df.dataframe, df_expect.dataframe);
853            Ok(())
854        });
855
856        assert!(exec.is_ok(), "{:?}", exec);
857    }
858
859    #[rstest]
860    #[traced_test]
861    fn basic_ops_div(mut df: DataFrame, hm: HashMap<String, DataValue>) {
862        let mut df_expect = df.clone();
863        let df2 = df.clone();
864        let exec = Python::attach(|py| -> PyResult<()> {
865            df.__itruediv__(df.clone().into_pyobject(py)?.into_any())?;
866            df_expect.dataframe /= df2.dataframe;
867            tracing::trace!("{} vs {}", df, df_expect);
868            assert_eq!(df.dataframe, df_expect.dataframe);
869
870            df.__itruediv__(hm.clone().into_pyobject(py)?.into_any())?;
871            df_expect.dataframe /= hm;
872            tracing::trace!("{} vs {}", df, df_expect);
873            assert_eq!(df.dataframe, df_expect.dataframe);
874            Ok(())
875        });
876
877        assert!(exec.is_ok(), "{:?}", exec);
878    }
879
880    #[rstest]
881    #[traced_test]
882    #[rstest]
883    fn test_numpy(mut df: DataFrame) {
884        let exec = Python::attach(|py| -> PyResult<()> {
885            let code = c_str!(
886                r#"
887def example(df):
888    import numpy as np
889    a_np = df.as_numpy_f32(['key1', 'key2'])
890    print(a_np)
891    b_np = df.as_numpy_u32(['key1', 'key'])
892    print(b_np)
893    b_np = df.as_numpy_i32(['key1', 'key'])
894    print(b_np)
895    b_np = df.as_numpy_i64(['key1', 'key'])
896    print(b_np)
897    b_np = df.as_numpy_u64(['key1', 'key'])
898    print(b_np)
899    b_np = df.as_numpy_f64(['key1', 'key'])
900    print(b_np)
901    b_np = df.as_numpy_f64(['key1', 'key'], transposed=True)
902    print(b_np)
903    b_np = df.as_numpy(['key1', 'key'], transposed=True)
904    print(b_np)
905    b_np = df.as_numpy_str(['key1', 'key'], transposed=True)
906    print(b_np)
907    return df
908            "#
909            );
910            let fun: Py<PyAny> = PyModule::from_code(py, code, c_str!(""), c_str!(""))?
911                .getattr("example")?
912                .into();
913            let result = fun.call1(py, (df.clone(),));
914            assert!(df.py_join(df.clone(), JoinRelation::default()).is_ok());
915            // user may not have installed polars, we need to get an error in that
916            // case
917            if py.import("numpy").is_ok() {
918                assert!(result.is_ok(), "{:?}", result);
919            } else {
920                assert!(result.is_err(), "{:?}", result);
921            }
922            Ok(())
923        });
924        assert!(exec.is_ok(), "{:?}", exec);
925    }
926
927    #[rstest]
928    #[traced_test]
929    #[rstest]
930    fn test_fill_from_python(df: DataFrame) {
931        let exec = Python::attach(|_py| -> PyResult<()> {
932            let hm = stdhashmap!(
933                Key::from("key1") => DataValue::U32(1),
934                Key::from("key2") => DataValue::U32(2),
935            );
936            let mut df2 = DataFrame::init();
937            assert!(df2.py_push(hm).is_ok());
938            assert!(df2
939                .py_push(stdhashmap!(
940                    Key::from("key1") => DataValue::U32(11),
941                    Key::from("key2") => DataValue::U32(21),
942                ))
943                .is_ok());
944
945            assert_eq!(df, df2);
946
947            let mut df2 = DataFrame::init();
948            assert!(df2
949                .py_add_column(
950                    Key::from("key1"),
951                    vec![DataValue::U32(1), DataValue::U32(11)]
952                )
953                .is_ok());
954            assert!(df2
955                .py_add_column(
956                    Key::from("key2"),
957                    vec![DataValue::U32(2), DataValue::U32(21)]
958                )
959                .is_ok());
960
961            assert_eq!(df, df2);
962            Ok(())
963        });
964        assert!(exec.is_ok(), "{:?}", exec);
965    }
966
967    #[rstest]
968    fn basic_python_dataframe(mut df: DataFrame) {
969        let exec = Python::attach(|py| -> PyResult<()> {
970            let fun: Py<PyAny> = PyModule::from_code(
971                py,
972                c_str!(
973                    "
974def example(df):
975    print(df)
976    df.shrink()
977    assert len(df) == 2
978    df.add_alias('key1', 'key1-alias')
979    a = df.select(['key1', 'key2'])
980    print(a)
981    b = df.select(['key1-alias', 'key2'])
982    print(b)
983    df.rename_key('key1', 'key1new')
984    df.rename_key('key1new', 'key1')
985    assert a == [[1, 2], [11, 21]]
986    assert a == b
987    df.add_metadata('test', 1)
988    m = df.get_metadata('test')
989    assert m == 1
990    b = df.select_transposed(['key1', 'key2'])
991    print(b)
992    assert b == [[1, 11], [2, 21]]
993    c = df.select_column('key1')
994    print(c)
995    assert c == [1, 11]
996
997    a += b
998    print(a)
999    assert a == [[2, 13], [4, 23]]
1000    a -= b
1001    print(a)
1002    assert e == a
1003    f = e * b
1004    print(f)
1005    assert f == [[1, 22], [44, 441]]
1006    g = f / b
1007    print(g)
1008    assert g == e
1009
1010                "
1011                ),
1012                c_str!(""),
1013                c_str!(""),
1014            )?
1015            .getattr("example")?
1016            .into();
1017            let _ = fun.call1(py, (df.clone(),));
1018            assert!(df.py_join(df.clone(), JoinRelation::default()).is_ok());
1019            Ok(())
1020        });
1021        assert!(exec.is_ok(), "{:?}", exec);
1022    }
1023
1024    #[rstest]
1025    fn dummy_test_apply(mut df: DataFrame) {
1026        let exec = Python::attach(|py| -> PyResult<()> {
1027            let fun: Py<PyAny> = PyModule::from_code(
1028                py,
1029                c_str!(
1030                    r#"
1031def multiply_by_ten(x):
1032    print(x)
1033    x *= {"key1": 10}
1034    print(x)
1035    return x
1036
1037def example(df):
1038    print(df)
1039    df.apply(multiply_by_ten)
1040                "#
1041                ),
1042                c_str!(""),
1043                c_str!(""),
1044            )?
1045            .getattr("example")?
1046            .into();
1047            let _ = fun.call1(py, (df.clone(),));
1048            assert!(df.py_join(df.clone(), JoinRelation::default()).is_ok());
1049            Ok(())
1050        });
1051        assert!(exec.is_ok(), "{:?}", exec);
1052    }
1053}