// datafusion_python/utils.rs
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use std::future::Future;
use std::sync::{Arc, OnceLock};
use std::time::Duration;

use datafusion::datasource::TableProvider;
use datafusion::execution::context::SessionContext;
use datafusion::logical_expr::Volatility;
use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
use datafusion_ffi::table_provider::FFI_TableProvider;
use pyo3::IntoPyObjectExt;
use pyo3::exceptions::{PyImportError, PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{PyCapsule, PyType};
use tokio::runtime::Runtime;
use tokio::task::JoinHandle;
use tokio::time::sleep;

use crate::TokioRuntime;
use crate::context::PySessionContext;
use crate::errors::{PyDataFusionError, PyDataFusionResult, py_datafusion_err, to_datafusion_err};

39/// Utility to get the Tokio Runtime from Python
40#[inline]
41pub(crate) fn get_tokio_runtime() -> &'static TokioRuntime {
42    // NOTE: Other pyo3 python libraries have had issues with using tokio
43    // behind a forking app-server like `gunicorn`
44    // If we run into that problem, in the future we can look to `delta-rs`
45    // which adds a check in that disallows calls from a forked process
46    // https://github.com/delta-io/delta-rs/blob/87010461cfe01563d91a4b9cd6fa468e2ad5f283/python/src/utils.rs#L10-L31
47    static RUNTIME: OnceLock<TokioRuntime> = OnceLock::new();
48    RUNTIME.get_or_init(|| TokioRuntime(tokio::runtime::Runtime::new().unwrap()))
49}
50
51#[inline]
52pub(crate) fn is_ipython_env(py: Python) -> &'static bool {
53    static IS_IPYTHON_ENV: OnceLock<bool> = OnceLock::new();
54    IS_IPYTHON_ENV.get_or_init(|| {
55        py.import("IPython")
56            .and_then(|ipython| ipython.call_method0("get_ipython"))
57            .map(|ipython| !ipython.is_none())
58            .unwrap_or(false)
59    })
60}
61
/// Utility to get the global DataFusion [`SessionContext`]
///
/// The context is created lazily on first access and shared for the
/// lifetime of the process.
#[inline]
pub(crate) fn get_global_ctx() -> &'static Arc<SessionContext> {
    static CTX: OnceLock<Arc<SessionContext>> = OnceLock::new();
    CTX.get_or_init(|| Arc::new(SessionContext::new()))
}

/// Utility to collect rust futures with GIL released and respond to
/// Python interrupts such as ``KeyboardInterrupt``. If a signal is
/// received while the future is running, the future is aborted and the
/// corresponding Python exception is raised.
pub fn wait_for_future<F>(py: Python, fut: F) -> PyResult<F::Output>
where
    F: Future + Send,
    F::Output: Send,
{
    let runtime: &Runtime = &get_tokio_runtime().0;
    // How often the future is paused so Python can process pending signals.
    const INTERVAL_CHECK_SIGNALS: Duration = Duration::from_millis(1_000);

    // Some fast running processes that generate many `wait_for_future` calls like
    // PartitionedDataFrameStreamReader::next require checking for interrupts early
    py.run(cr"pass", None, None)?;
    py.check_signals()?;

    // Release the GIL while the Tokio runtime drives the future; re-acquire
    // it only briefly inside the select loop to poll for signals.
    py.detach(|| {
        runtime.block_on(async {
            tokio::pin!(fut);
            loop {
                tokio::select! {
                    res = &mut fut => break Ok(res),
                    _ = sleep(INTERVAL_CHECK_SIGNALS) => {
                        // Propagating an Err here exits `block_on`, dropping the
                        // pinned future — which is how the work gets aborted on
                        // interrupt.
                        Python::attach(|py| {
                                // Execute a no-op Python statement to trigger signal processing.
                                // This is necessary because py.check_signals() alone doesn't
                                // actually check for signals - it only raises an exception if
                                // a signal was already set during a previous Python API call.
                                // Running even trivial Python code forces the interpreter to
                                // process any pending signals (like KeyboardInterrupt).
                                py.run(cr"pass", None, None)?;
                                py.check_signals()
                        })?;
                    }
                }
            }
        })
    })
}

110/// Spawn a [`Future`] on the Tokio runtime and wait for completion
111/// while respecting Python signal handling.
112pub(crate) fn spawn_future<F, T>(py: Python, fut: F) -> PyDataFusionResult<T>
113where
114    F: Future<Output = datafusion::common::Result<T>> + Send + 'static,
115    T: Send + 'static,
116{
117    let rt = &get_tokio_runtime().0;
118    let handle: JoinHandle<datafusion::common::Result<T>> = rt.spawn(fut);
119    // Wait for the join handle while respecting Python signal handling.
120    // We handle errors in two steps so `?` maps the error types correctly:
121    // 1) convert any Python-related error from `wait_for_future` into `PyDataFusionError`
122    // 2) convert any DataFusion error (inner result) into `PyDataFusionError`
123    let inner_result = wait_for_future(py, async {
124        // handle.await yields `Result<datafusion::common::Result<T>, JoinError>`
125        // map JoinError into a DataFusion error so the async block returns
126        // `datafusion::common::Result<T>` (i.e. Result<T, DataFusionError>)
127        match handle.await {
128            Ok(inner) => inner,
129            Err(join_err) => Err(to_datafusion_err(join_err)),
130        }
131    })?; // converts PyErr -> PyDataFusionError
132
133    // `inner_result` is `datafusion::common::Result<T>`; use `?` to convert
134    // the inner DataFusion error into `PyDataFusionError` via `From` and
135    // return the inner `T` on success.
136    Ok(inner_result?)
137}
138
139pub(crate) fn parse_volatility(value: &str) -> PyDataFusionResult<Volatility> {
140    Ok(match value {
141        "immutable" => Volatility::Immutable,
142        "stable" => Volatility::Stable,
143        "volatile" => Volatility::Volatile,
144        value => {
145            return Err(PyDataFusionError::Common(format!(
146                "Unsupported volatility type: `{value}`, supported \
147                 values are: immutable, stable and volatile."
148            )));
149        }
150    })
151}
152
153pub(crate) fn validate_pycapsule(capsule: &Bound<PyCapsule>, name: &str) -> PyResult<()> {
154    let capsule_name = capsule.name()?;
155    if capsule_name.is_none() {
156        return Err(PyValueError::new_err(format!(
157            "Expected {name} PyCapsule to have name set."
158        )));
159    }
160
161    let capsule_name = capsule_name.unwrap().to_str()?;
162    if capsule_name != name {
163        return Err(PyValueError::new_err(format!(
164            "Expected name '{name}' in PyCapsule, instead got '{capsule_name}'"
165        )));
166    }
167
168    Ok(())
169}
170
171pub(crate) fn table_provider_from_pycapsule<'py>(
172    mut obj: Bound<'py, PyAny>,
173    session: Bound<'py, PyAny>,
174) -> PyResult<Option<Arc<dyn TableProvider>>> {
175    if obj.hasattr("__datafusion_table_provider__")? {
176        obj = obj
177            .getattr("__datafusion_table_provider__")?
178            .call1((session,)).map_err(|err| {
179            let py = obj.py();
180            if err.get_type(py).is(PyType::new::<PyTypeError>(py)) {
181                PyImportError::new_err("Incompatible libraries. DataFusion 52.0.0 introduced an incompatible signature change for table providers. Either downgrade DataFusion or upgrade your function library.")
182            } else {
183                err
184            }
185        })?;
186    }
187
188    if let Ok(capsule) = obj.downcast::<PyCapsule>().map_err(py_datafusion_err) {
189        validate_pycapsule(capsule, "datafusion_table_provider")?;
190
191        let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
192        let provider: Arc<dyn TableProvider> = provider.into();
193
194        Ok(Some(provider))
195    } else {
196        Ok(None)
197    }
198}
199
/// Extract an [`FFI_LogicalExtensionCodec`] from an optional Python object.
///
/// When `obj` is `None`, the global session context is used as the source.
/// If the object implements `__datafusion_logical_extension_codec__`, that
/// method is called to obtain a PyCapsule; otherwise the object itself must
/// already be the capsule.
///
/// # Errors
/// Returns a Python error if the protocol call fails, the object is not a
/// capsule, or the capsule name does not match.
pub(crate) fn extract_logical_extension_codec(
    py: Python,
    obj: Option<Bound<PyAny>>,
) -> PyResult<Arc<FFI_LogicalExtensionCodec>> {
    let obj = match obj {
        Some(obj) => obj,
        None => PySessionContext::global_ctx()?.into_bound_py_any(py)?,
    };
    let capsule = if obj.hasattr("__datafusion_logical_extension_codec__")? {
        obj.getattr("__datafusion_logical_extension_codec__")?
            .call0()?
    } else {
        obj
    };
    let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;

    validate_pycapsule(capsule, "datafusion_logical_extension_codec")?;

    // SAFETY: the capsule name was validated above; this assumes any capsule
    // bearing that name holds a valid `FFI_LogicalExtensionCodec` (the FFI
    // capsule convention used throughout this crate).
    let codec = unsafe { capsule.reference::<FFI_LogicalExtensionCodec>() };
    Ok(Arc::new(codec.clone()))
}

222pub(crate) fn create_logical_extension_capsule<'py>(
223    py: Python<'py>,
224    codec: &FFI_LogicalExtensionCodec,
225) -> PyResult<Bound<'py, PyCapsule>> {
226    let name = cr"datafusion_logical_extension_codec".into();
227    let codec = codec.clone();
228
229    PyCapsule::new(py, codec, Some(name))
230}