perspective_python/client/
pandas.rs1use pyo3::exceptions::{PyImportError, PyValueError};
14use pyo3::prelude::*;
15use pyo3::types::{PyAny, PyBytes, PyDict, PyList};
16
17use super::pyarrow;
18
19fn get_pandas_df_cls(py: Python<'_>) -> PyResult<Option<Bound<'_, PyAny>>> {
20 let sys = PyModule::import(py, "sys")?;
21 if sys.getattr("modules")?.contains("pandas")? {
22 let pandas = PyModule::import(py, "pandas")?;
23 Ok(Some(pandas.getattr("DataFrame")?.into_pyobject(py)?))
24 } else {
25 Ok(None)
26 }
27}
28
29pub fn is_pandas_df(py: Python, df: &Bound<'_, PyAny>) -> PyResult<bool> {
30 if let Some(df_class) = get_pandas_df_cls(py)? {
31 df.is_instance(&df_class)
32 } else {
33 Ok(false)
34 }
35}
36
37pub fn arrow_to_pandas(py: Python<'_>, arrow: &[u8]) -> PyResult<Py<PyAny>> {
44 let pyarrow = PyModule::import(py, "pyarrow")?;
45 let bytes = PyBytes::new(py, arrow);
46 Ok(pyarrow
47 .getattr("ipc")?
48 .getattr("open_stream")?
49 .call1((bytes,))?
50 .getattr("read_all")?
51 .call0()?
52 .getattr("to_pandas")?
53 .call0()?
54 .unbind())
55}
56
57pub fn pandas_to_arrow_bytes<'py>(
58 py: Python<'py>,
59 df: &Bound<'py, PyAny>,
60) -> PyResult<Bound<'py, PyBytes>> {
61 let pyarrow = match PyModule::import(py, "pyarrow") {
62 Ok(pyarrow) => pyarrow,
63 Err(_) => {
64 return Err(PyImportError::new_err(
65 "Perspective requires pyarrow to convert pandas DataFrames. Please install \
66 pyarrow.",
67 ));
68 },
69 };
70
71 let df_class = get_pandas_df_cls(py)?
72 .ok_or_else(|| PyValueError::new_err("Failed to import pandas.DataFrame"))?;
73
74 if !df.is_instance(&df_class)? {
75 return Err(PyValueError::new_err("Input is not a pandas.DataFrame"));
76 }
77
78 let kwargs = PyDict::new(py);
79 kwargs.set_item("preserve_index", true)?;
80
81 let table = pyarrow
82 .getattr("Table")?
83 .call_method("from_pandas", (df,), Some(&kwargs))?;
84
85 let old_names: Vec<String> = table.getattr("column_names")?.extract()?;
87 let mut new_names: Vec<String> = old_names
88 .into_iter()
89 .map(|e| {
90 if e == "__index_level_0__" {
91 "index".to_string()
92 } else {
93 e
94 }
95 })
96 .collect();
97
98 let names = PyList::new(py, new_names.clone())?;
99 let table = table.call_method1("rename_columns", (names,))?;
100
101 if new_names[new_names.len() - 1] == "index" {
103 new_names.rotate_right(1);
104 let order = PyList::new(py, new_names)?;
105 let table = table.call_method1("select", (order,))?;
106 pyarrow::to_arrow_bytes(py, &table)
107 } else {
108 pyarrow::to_arrow_bytes(py, &table)
109 }
110}