datafusion_python/
utils.rs1use crate::common::data_type::PyScalarValue;
19use crate::errors::{PyDataFusionError, PyDataFusionResult};
20use crate::TokioRuntime;
21use datafusion::common::ScalarValue;
22use datafusion::execution::context::SessionContext;
23use datafusion::logical_expr::Volatility;
24use pyo3::exceptions::PyValueError;
25use pyo3::prelude::*;
26use pyo3::types::PyCapsule;
27use std::future::Future;
28use std::sync::OnceLock;
29use tokio::runtime::Runtime;
30
31#[inline]
33pub(crate) fn get_tokio_runtime() -> &'static TokioRuntime {
34 static RUNTIME: OnceLock<TokioRuntime> = OnceLock::new();
40 RUNTIME.get_or_init(|| TokioRuntime(tokio::runtime::Runtime::new().unwrap()))
41}
42
43#[inline]
45pub(crate) fn get_global_ctx() -> &'static SessionContext {
46 static CTX: OnceLock<SessionContext> = OnceLock::new();
47 CTX.get_or_init(SessionContext::new)
48}
49
50pub fn wait_for_future<F>(py: Python, f: F) -> F::Output
52where
53 F: Future + Send,
54 F::Output: Send,
55{
56 let runtime: &Runtime = &get_tokio_runtime().0;
57 py.allow_threads(|| runtime.block_on(f))
58}
59
60pub(crate) fn parse_volatility(value: &str) -> PyDataFusionResult<Volatility> {
61 Ok(match value {
62 "immutable" => Volatility::Immutable,
63 "stable" => Volatility::Stable,
64 "volatile" => Volatility::Volatile,
65 value => {
66 return Err(PyDataFusionError::Common(format!(
67 "Unsupportad volatility type: `{value}`, supported \
68 values are: immutable, stable and volatile."
69 )))
70 }
71 })
72}
73
74pub(crate) fn validate_pycapsule(capsule: &Bound<PyCapsule>, name: &str) -> PyResult<()> {
75 let capsule_name = capsule.name()?;
76 if capsule_name.is_none() {
77 return Err(PyValueError::new_err(
78 "Expected schema PyCapsule to have name set.",
79 ));
80 }
81
82 let capsule_name = capsule_name.unwrap().to_str()?;
83 if capsule_name != name {
84 return Err(PyValueError::new_err(format!(
85 "Expected name '{}' in PyCapsule, instead got '{}'",
86 name, capsule_name
87 )));
88 }
89
90 Ok(())
91}
92
93pub(crate) fn py_obj_to_scalar_value(py: Python, obj: PyObject) -> PyResult<ScalarValue> {
94 let pa = py.import("pyarrow")?;
97
98 let scalar = pa.call_method1("scalar", (obj,))?;
100
101 let py_scalar = PyScalarValue::extract_bound(scalar.as_ref())
103 .map_err(|e| PyValueError::new_err(format!("Failed to extract PyScalarValue: {}", e)))?;
104
105 Ok(py_scalar.into())
107}