datafusion_python/
utils.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::{
19    common::data_type::PyScalarValue,
20    errors::{PyDataFusionError, PyDataFusionResult},
21    TokioRuntime,
22};
23use datafusion::{
24    common::ScalarValue, execution::context::SessionContext, logical_expr::Volatility,
25};
26use pyo3::prelude::*;
27use pyo3::{exceptions::PyValueError, types::PyCapsule};
28use std::{future::Future, sync::OnceLock, time::Duration};
29use tokio::{runtime::Runtime, time::sleep};
30/// Utility to get the Tokio Runtime from Python
31#[inline]
32pub(crate) fn get_tokio_runtime() -> &'static TokioRuntime {
33    // NOTE: Other pyo3 python libraries have had issues with using tokio
34    // behind a forking app-server like `gunicorn`
35    // If we run into that problem, in the future we can look to `delta-rs`
36    // which adds a check in that disallows calls from a forked process
37    // https://github.com/delta-io/delta-rs/blob/87010461cfe01563d91a4b9cd6fa468e2ad5f283/python/src/utils.rs#L10-L31
38    static RUNTIME: OnceLock<TokioRuntime> = OnceLock::new();
39    RUNTIME.get_or_init(|| TokioRuntime(tokio::runtime::Runtime::new().unwrap()))
40}
41
42#[inline]
43pub(crate) fn is_ipython_env(py: Python) -> &'static bool {
44    static IS_IPYTHON_ENV: OnceLock<bool> = OnceLock::new();
45    IS_IPYTHON_ENV.get_or_init(|| {
46        py.import("IPython")
47            .and_then(|ipython| ipython.call_method0("get_ipython"))
48            .map(|ipython| !ipython.is_none())
49            .unwrap_or(false)
50    })
51}
52
53/// Utility to get the Global Datafussion CTX
54#[inline]
55pub(crate) fn get_global_ctx() -> &'static SessionContext {
56    static CTX: OnceLock<SessionContext> = OnceLock::new();
57    CTX.get_or_init(SessionContext::new)
58}
59
60/// Utility to collect rust futures with GIL released and respond to
61/// Python interrupts such as ``KeyboardInterrupt``. If a signal is
62/// received while the future is running, the future is aborted and the
63/// corresponding Python exception is raised.
64pub fn wait_for_future<F>(py: Python, fut: F) -> PyResult<F::Output>
65where
66    F: Future + Send,
67    F::Output: Send,
68{
69    let runtime: &Runtime = &get_tokio_runtime().0;
70    const INTERVAL_CHECK_SIGNALS: Duration = Duration::from_millis(1_000);
71
72    py.allow_threads(|| {
73        runtime.block_on(async {
74            tokio::pin!(fut);
75            loop {
76                tokio::select! {
77                    res = &mut fut => break Ok(res),
78                    _ = sleep(INTERVAL_CHECK_SIGNALS) => {
79                        Python::with_gil(|py| py.check_signals())?;
80                    }
81                }
82            }
83        })
84    })
85}
86
87pub(crate) fn parse_volatility(value: &str) -> PyDataFusionResult<Volatility> {
88    Ok(match value {
89        "immutable" => Volatility::Immutable,
90        "stable" => Volatility::Stable,
91        "volatile" => Volatility::Volatile,
92        value => {
93            return Err(PyDataFusionError::Common(format!(
94                "Unsupportad volatility type: `{value}`, supported \
95                 values are: immutable, stable and volatile."
96            )))
97        }
98    })
99}
100
101pub(crate) fn validate_pycapsule(capsule: &Bound<PyCapsule>, name: &str) -> PyResult<()> {
102    let capsule_name = capsule.name()?;
103    if capsule_name.is_none() {
104        return Err(PyValueError::new_err(
105            "Expected schema PyCapsule to have name set.",
106        ));
107    }
108
109    let capsule_name = capsule_name.unwrap().to_str()?;
110    if capsule_name != name {
111        return Err(PyValueError::new_err(format!(
112            "Expected name '{name}' in PyCapsule, instead got '{capsule_name}'"
113        )));
114    }
115
116    Ok(())
117}
118
119pub(crate) fn py_obj_to_scalar_value(py: Python, obj: PyObject) -> PyResult<ScalarValue> {
120    // convert Python object to PyScalarValue to ScalarValue
121
122    let pa = py.import("pyarrow")?;
123
124    // Convert Python object to PyArrow scalar
125    let scalar = pa.call_method1("scalar", (obj,))?;
126
127    // Convert PyArrow scalar to PyScalarValue
128    let py_scalar = PyScalarValue::extract_bound(scalar.as_ref())
129        .map_err(|e| PyValueError::new_err(format!("Failed to extract PyScalarValue: {e}")))?;
130
131    // Convert PyScalarValue to ScalarValue
132    Ok(py_scalar.into())
133}