datafusion_python/
utils.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::common::data_type::PyScalarValue;
19use crate::errors::{PyDataFusionError, PyDataFusionResult};
20use crate::TokioRuntime;
21use datafusion::common::ScalarValue;
22use datafusion::execution::context::SessionContext;
23use datafusion::logical_expr::Volatility;
24use pyo3::exceptions::PyValueError;
25use pyo3::prelude::*;
26use pyo3::types::PyCapsule;
27use std::future::Future;
28use std::sync::OnceLock;
29use tokio::runtime::Runtime;
30
31/// Utility to get the Tokio Runtime from Python
32#[inline]
33pub(crate) fn get_tokio_runtime() -> &'static TokioRuntime {
34    // NOTE: Other pyo3 python libraries have had issues with using tokio
35    // behind a forking app-server like `gunicorn`
36    // If we run into that problem, in the future we can look to `delta-rs`
37    // which adds a check in that disallows calls from a forked process
38    // https://github.com/delta-io/delta-rs/blob/87010461cfe01563d91a4b9cd6fa468e2ad5f283/python/src/utils.rs#L10-L31
39    static RUNTIME: OnceLock<TokioRuntime> = OnceLock::new();
40    RUNTIME.get_or_init(|| TokioRuntime(tokio::runtime::Runtime::new().unwrap()))
41}
42
43/// Utility to get the Global Datafussion CTX
44#[inline]
45pub(crate) fn get_global_ctx() -> &'static SessionContext {
46    static CTX: OnceLock<SessionContext> = OnceLock::new();
47    CTX.get_or_init(SessionContext::new)
48}
49
50/// Utility to collect rust futures with GIL released
51pub fn wait_for_future<F>(py: Python, f: F) -> F::Output
52where
53    F: Future + Send,
54    F::Output: Send,
55{
56    let runtime: &Runtime = &get_tokio_runtime().0;
57    py.allow_threads(|| runtime.block_on(f))
58}
59
60pub(crate) fn parse_volatility(value: &str) -> PyDataFusionResult<Volatility> {
61    Ok(match value {
62        "immutable" => Volatility::Immutable,
63        "stable" => Volatility::Stable,
64        "volatile" => Volatility::Volatile,
65        value => {
66            return Err(PyDataFusionError::Common(format!(
67                "Unsupportad volatility type: `{value}`, supported \
68                 values are: immutable, stable and volatile."
69            )))
70        }
71    })
72}
73
74pub(crate) fn validate_pycapsule(capsule: &Bound<PyCapsule>, name: &str) -> PyResult<()> {
75    let capsule_name = capsule.name()?;
76    if capsule_name.is_none() {
77        return Err(PyValueError::new_err(
78            "Expected schema PyCapsule to have name set.",
79        ));
80    }
81
82    let capsule_name = capsule_name.unwrap().to_str()?;
83    if capsule_name != name {
84        return Err(PyValueError::new_err(format!(
85            "Expected name '{}' in PyCapsule, instead got '{}'",
86            name, capsule_name
87        )));
88    }
89
90    Ok(())
91}
92
93pub(crate) fn py_obj_to_scalar_value(py: Python, obj: PyObject) -> PyResult<ScalarValue> {
94    // convert Python object to PyScalarValue to ScalarValue
95
96    let pa = py.import("pyarrow")?;
97
98    // Convert Python object to PyArrow scalar
99    let scalar = pa.call_method1("scalar", (obj,))?;
100
101    // Convert PyArrow scalar to PyScalarValue
102    let py_scalar = PyScalarValue::extract_bound(scalar.as_ref())
103        .map_err(|e| PyValueError::new_err(format!("Failed to extract PyScalarValue: {}", e)))?;
104
105    // Convert PyScalarValue to ScalarValue
106    Ok(py_scalar.into())
107}