datafusion_python/
utils.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
use std::future::Future;
use std::sync::{Arc, OnceLock};
use std::time::Duration;

use datafusion::common::ScalarValue;
use datafusion::datasource::TableProvider;
use datafusion::execution::context::SessionContext;
use datafusion::logical_expr::Volatility;
use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::sync::OnceLockExt;
use pyo3::types::PyCapsule;
use tokio::runtime::Runtime;
use tokio::task::JoinHandle;
use tokio::time::sleep;

use crate::common::data_type::PyScalarValue;
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult};
use crate::TokioRuntime;
37
38/// Utility to get the Tokio Runtime from Python
39#[inline]
40pub(crate) fn get_tokio_runtime() -> &'static TokioRuntime {
41    // NOTE: Other pyo3 python libraries have had issues with using tokio
42    // behind a forking app-server like `gunicorn`
43    // If we run into that problem, in the future we can look to `delta-rs`
44    // which adds a check in that disallows calls from a forked process
45    // https://github.com/delta-io/delta-rs/blob/87010461cfe01563d91a4b9cd6fa468e2ad5f283/python/src/utils.rs#L10-L31
46    static RUNTIME: OnceLock<TokioRuntime> = OnceLock::new();
47    RUNTIME.get_or_init(|| TokioRuntime(tokio::runtime::Runtime::new().unwrap()))
48}
49
50#[inline]
51pub(crate) fn is_ipython_env(py: Python) -> &'static bool {
52    static IS_IPYTHON_ENV: OnceLock<bool> = OnceLock::new();
53    IS_IPYTHON_ENV.get_or_init(|| {
54        py.import("IPython")
55            .and_then(|ipython| ipython.call_method0("get_ipython"))
56            .map(|ipython| !ipython.is_none())
57            .unwrap_or(false)
58    })
59}
60
61/// Utility to get the Global Datafussion CTX
62#[inline]
63pub(crate) fn get_global_ctx() -> &'static SessionContext {
64    static CTX: OnceLock<SessionContext> = OnceLock::new();
65    CTX.get_or_init(SessionContext::new)
66}
67
68/// Utility to collect rust futures with GIL released and respond to
69/// Python interrupts such as ``KeyboardInterrupt``. If a signal is
70/// received while the future is running, the future is aborted and the
71/// corresponding Python exception is raised.
72pub fn wait_for_future<F>(py: Python, fut: F) -> PyResult<F::Output>
73where
74    F: Future + Send,
75    F::Output: Send,
76{
77    let runtime: &Runtime = &get_tokio_runtime().0;
78    const INTERVAL_CHECK_SIGNALS: Duration = Duration::from_millis(1_000);
79
80    py.detach(|| {
81        runtime.block_on(async {
82            tokio::pin!(fut);
83            loop {
84                tokio::select! {
85                    res = &mut fut => break Ok(res),
86                    _ = sleep(INTERVAL_CHECK_SIGNALS) => {
87                        Python::attach(|py| {
88                                // Execute a no-op Python statement to trigger signal processing.
89                                // This is necessary because py.check_signals() alone doesn't
90                                // actually check for signals - it only raises an exception if
91                                // a signal was already set during a previous Python API call.
92                                // Running even trivial Python code forces the interpreter to
93                                // process any pending signals (like KeyboardInterrupt).
94                                py.run(cr"pass", None, None)?;
95                                py.check_signals()
96                        })?;
97                    }
98                }
99            }
100        })
101    })
102}
103
104/// Spawn a [`Future`] on the Tokio runtime and wait for completion
105/// while respecting Python signal handling.
106pub(crate) fn spawn_future<F, T>(py: Python, fut: F) -> PyDataFusionResult<T>
107where
108    F: Future<Output = datafusion::common::Result<T>> + Send + 'static,
109    T: Send + 'static,
110{
111    let rt = &get_tokio_runtime().0;
112    let handle: JoinHandle<datafusion::common::Result<T>> = rt.spawn(fut);
113    // Wait for the join handle while respecting Python signal handling.
114    // We handle errors in two steps so `?` maps the error types correctly:
115    // 1) convert any Python-related error from `wait_for_future` into `PyDataFusionError`
116    // 2) convert any DataFusion error (inner result) into `PyDataFusionError`
117    let inner_result = wait_for_future(py, async {
118        // handle.await yields `Result<datafusion::common::Result<T>, JoinError>`
119        // map JoinError into a DataFusion error so the async block returns
120        // `datafusion::common::Result<T>` (i.e. Result<T, DataFusionError>)
121        match handle.await {
122            Ok(inner) => inner,
123            Err(join_err) => Err(to_datafusion_err(join_err)),
124        }
125    })?; // converts PyErr -> PyDataFusionError
126
127    // `inner_result` is `datafusion::common::Result<T>`; use `?` to convert
128    // the inner DataFusion error into `PyDataFusionError` via `From` and
129    // return the inner `T` on success.
130    Ok(inner_result?)
131}
132
133pub(crate) fn parse_volatility(value: &str) -> PyDataFusionResult<Volatility> {
134    Ok(match value {
135        "immutable" => Volatility::Immutable,
136        "stable" => Volatility::Stable,
137        "volatile" => Volatility::Volatile,
138        value => {
139            return Err(PyDataFusionError::Common(format!(
140                "Unsupported volatility type: `{value}`, supported \
141                 values are: immutable, stable and volatile."
142            )))
143        }
144    })
145}
146
147pub(crate) fn validate_pycapsule(capsule: &Bound<PyCapsule>, name: &str) -> PyResult<()> {
148    let capsule_name = capsule.name()?;
149    if capsule_name.is_none() {
150        return Err(PyValueError::new_err(format!(
151            "Expected {name} PyCapsule to have name set."
152        )));
153    }
154
155    let capsule_name = capsule_name.unwrap().to_str()?;
156    if capsule_name != name {
157        return Err(PyValueError::new_err(format!(
158            "Expected name '{name}' in PyCapsule, instead got '{capsule_name}'"
159        )));
160    }
161
162    Ok(())
163}
164
165pub(crate) fn table_provider_from_pycapsule(
166    obj: &Bound<PyAny>,
167) -> PyResult<Option<Arc<dyn TableProvider>>> {
168    if obj.hasattr("__datafusion_table_provider__")? {
169        let capsule = obj.getattr("__datafusion_table_provider__")?.call0()?;
170        let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
171        validate_pycapsule(capsule, "datafusion_table_provider")?;
172
173        let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
174        let provider: ForeignTableProvider = provider.into();
175
176        Ok(Some(Arc::new(provider)))
177    } else {
178        Ok(None)
179    }
180}
181
182pub(crate) fn py_obj_to_scalar_value(py: Python, obj: Py<PyAny>) -> PyResult<ScalarValue> {
183    // convert Python object to PyScalarValue to ScalarValue
184
185    let pa = py.import("pyarrow")?;
186
187    // Convert Python object to PyArrow scalar
188    let scalar = pa.call_method1("scalar", (obj,))?;
189
190    // Convert PyArrow scalar to PyScalarValue
191    let py_scalar = PyScalarValue::extract_bound(scalar.as_ref())
192        .map_err(|e| PyValueError::new_err(format!("Failed to extract PyScalarValue: {e}")))?;
193
194    // Convert PyScalarValue to ScalarValue
195    Ok(py_scalar.into())
196}