perspective_python/client/
pandas.rs

1// ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
2// ┃ ██████ ██████ ██████       █      █      █      █      █ █▄  ▀███ █       ┃
3// ┃ ▄▄▄▄▄█ █▄▄▄▄▄ ▄▄▄▄▄█  ▀▀▀▀▀█▀▀▀▀▀ █ ▀▀▀▀▀█ ████████▌▐███ ███▄  ▀█ █ ▀▀▀▀▀ ┃
4// ┃ █▀▀▀▀▀ █▀▀▀▀▀ █▀██▀▀ ▄▄▄▄▄ █ ▄▄▄▄▄█ ▄▄▄▄▄█ ████████▌▐███ █████▄   █ ▄▄▄▄▄ ┃
5// ┃ █      ██████ █  ▀█▄       █ ██████      █      ███▌▐███ ███████▄ █       ┃
6// ┣━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┫
7// ┃ Copyright (c) 2017, the Perspective Authors.                              ┃
8// ┃ ╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌ ┃
9// ┃ This file is part of the Perspective library, distributed under the terms ┃
10// ┃ of the [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0). ┃
11// ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┛
12
13use pyo3::exceptions::{PyImportError, PyValueError};
14use pyo3::prelude::*;
15use pyo3::types::{PyAny, PyBytes, PyDict, PyList};
16
17use super::pyarrow;
18
19fn get_pandas_df_cls(py: Python<'_>) -> PyResult<Option<Bound<'_, PyAny>>> {
20    let sys = PyModule::import(py, "sys")?;
21    if sys.getattr("modules")?.contains("pandas")? {
22        let pandas = PyModule::import(py, "pandas")?;
23        Ok(Some(pandas.getattr("DataFrame")?.into_pyobject(py)?))
24    } else {
25        Ok(None)
26    }
27}
28
29pub fn is_pandas_df(py: Python, df: &Bound<'_, PyAny>) -> PyResult<bool> {
30    if let Some(df_class) = get_pandas_df_cls(py)? {
31        df.is_instance(&df_class)
32    } else {
33        Ok(false)
34    }
35}
36
37// ipc_bytes = self.to_arrow()
38// table = pa.ipc.open_stream(ipc_bytes).read_all()
39// x = pd.DataFrame(table.to_pandas())
40// print("AAA", x)
41// return x
42
43pub fn arrow_to_pandas(py: Python<'_>, arrow: &[u8]) -> PyResult<Py<PyAny>> {
44    let pyarrow = PyModule::import(py, "pyarrow")?;
45    let bytes = PyBytes::new(py, arrow);
46    Ok(pyarrow
47        .getattr("ipc")?
48        .getattr("open_stream")?
49        .call1((bytes,))?
50        .getattr("read_all")?
51        .call0()?
52        .getattr("to_pandas")?
53        .call0()?
54        .unbind())
55}
56
57pub fn pandas_to_arrow_bytes<'py>(
58    py: Python<'py>,
59    df: &Bound<'py, PyAny>,
60) -> PyResult<Bound<'py, PyBytes>> {
61    let pyarrow = match PyModule::import(py, "pyarrow") {
62        Ok(pyarrow) => pyarrow,
63        Err(_) => {
64            return Err(PyImportError::new_err(
65                "Perspective requires pyarrow to convert pandas DataFrames. Please install \
66                 pyarrow.",
67            ));
68        },
69    };
70
71    let df_class = get_pandas_df_cls(py)?
72        .ok_or_else(|| PyValueError::new_err("Failed to import pandas.DataFrame"))?;
73
74    if !df.is_instance(&df_class)? {
75        return Err(PyValueError::new_err("Input is not a pandas.DataFrame"));
76    }
77
78    let kwargs = PyDict::new(py);
79    kwargs.set_item("preserve_index", true)?;
80
81    let table = pyarrow
82        .getattr("Table")?
83        .call_method("from_pandas", (df,), Some(&kwargs))?;
84
85    // rename from __index_level_0__ to index
86    let old_names: Vec<String> = table.getattr("column_names")?.extract()?;
87    let mut new_names: Vec<String> = old_names
88        .into_iter()
89        .map(|e| {
90            if e == "__index_level_0__" {
91                "index".to_string()
92            } else {
93                e
94            }
95        })
96        .collect();
97
98    let names = PyList::new(py, new_names.clone())?;
99    let table = table.call_method1("rename_columns", (names,))?;
100
101    // move the index column to be the first column.
102    if new_names[new_names.len() - 1] == "index" {
103        new_names.rotate_right(1);
104        let order = PyList::new(py, new_names)?;
105        let table = table.call_method1("select", (order,))?;
106        pyarrow::to_arrow_bytes(py, &table)
107    } else {
108        pyarrow::to_arrow_bytes(py, &table)
109    }
110}