datafusion_python/
utils.rs1use crate::{
19 common::data_type::PyScalarValue,
20 errors::{PyDataFusionError, PyDataFusionResult},
21 TokioRuntime,
22};
23use datafusion::{
24 common::ScalarValue, execution::context::SessionContext, logical_expr::Volatility,
25};
26use pyo3::prelude::*;
27use pyo3::{exceptions::PyValueError, types::PyCapsule};
28use std::{future::Future, sync::OnceLock, time::Duration};
29use tokio::{runtime::Runtime, time::sleep};
30#[inline]
32pub(crate) fn get_tokio_runtime() -> &'static TokioRuntime {
33 static RUNTIME: OnceLock<TokioRuntime> = OnceLock::new();
39 RUNTIME.get_or_init(|| TokioRuntime(tokio::runtime::Runtime::new().unwrap()))
40}
41
42#[inline]
43pub(crate) fn is_ipython_env(py: Python) -> &'static bool {
44 static IS_IPYTHON_ENV: OnceLock<bool> = OnceLock::new();
45 IS_IPYTHON_ENV.get_or_init(|| {
46 py.import("IPython")
47 .and_then(|ipython| ipython.call_method0("get_ipython"))
48 .map(|ipython| !ipython.is_none())
49 .unwrap_or(false)
50 })
51}
52
53#[inline]
55pub(crate) fn get_global_ctx() -> &'static SessionContext {
56 static CTX: OnceLock<SessionContext> = OnceLock::new();
57 CTX.get_or_init(SessionContext::new)
58}
59
/// Runs `fut` to completion on the shared Tokio runtime, blocking the calling
/// thread while the GIL is released.
///
/// While `fut` is pending, the GIL is re-acquired roughly once per
/// `INTERVAL_CHECK_SIGNALS` (one second) to call `Python::check_signals`. This
/// lets a pending signal such as Ctrl-C (`KeyboardInterrupt`) abort the wait
/// instead of hanging until the future finishes.
///
/// # Errors
///
/// Returns the `PyErr` produced by `check_signals` if a signal handler raises.
/// In that case the pinned `fut` is dropped without having completed.
///
/// NOTE(review): `Runtime::block_on` panics when called from inside an async
/// execution context; presumably callers only invoke this from plain Python
/// threads — confirm.
pub fn wait_for_future<F>(py: Python, fut: F) -> PyResult<F::Output>
where
    F: Future + Send,
    F::Output: Send,
{
    let runtime: &Runtime = &get_tokio_runtime().0;
    // How long to wait on `fut` before briefly waking up to poll Python signals.
    const INTERVAL_CHECK_SIGNALS: Duration = Duration::from_millis(1_000);

    // Release the GIL so other Python threads can run while this one blocks.
    py.allow_threads(|| {
        runtime.block_on(async {
            // `select!` polls `fut` by `&mut` across loop iterations, which
            // requires the future to be pinned.
            tokio::pin!(fut);
            loop {
                tokio::select! {
                    res = &mut fut => break Ok(res),
                    // A fresh timer is created each iteration; when it fires
                    // first, check signals and go back to waiting on `fut`.
                    _ = sleep(INTERVAL_CHECK_SIGNALS) => {
                        // `check_signals` needs the GIL; `?` propagates a raised
                        // exception and abandons the still-pending `fut`.
                        Python::with_gil(|py| py.check_signals())?;
                    }
                }
            }
        })
    })
}
86
87pub(crate) fn parse_volatility(value: &str) -> PyDataFusionResult<Volatility> {
88 Ok(match value {
89 "immutable" => Volatility::Immutable,
90 "stable" => Volatility::Stable,
91 "volatile" => Volatility::Volatile,
92 value => {
93 return Err(PyDataFusionError::Common(format!(
94 "Unsupportad volatility type: `{value}`, supported \
95 values are: immutable, stable and volatile."
96 )))
97 }
98 })
99}
100
101pub(crate) fn validate_pycapsule(capsule: &Bound<PyCapsule>, name: &str) -> PyResult<()> {
102 let capsule_name = capsule.name()?;
103 if capsule_name.is_none() {
104 return Err(PyValueError::new_err(
105 "Expected schema PyCapsule to have name set.",
106 ));
107 }
108
109 let capsule_name = capsule_name.unwrap().to_str()?;
110 if capsule_name != name {
111 return Err(PyValueError::new_err(format!(
112 "Expected name '{name}' in PyCapsule, instead got '{capsule_name}'"
113 )));
114 }
115
116 Ok(())
117}
118
119pub(crate) fn py_obj_to_scalar_value(py: Python, obj: PyObject) -> PyResult<ScalarValue> {
120 let pa = py.import("pyarrow")?;
123
124 let scalar = pa.call_method1("scalar", (obj,))?;
126
127 let py_scalar = PyScalarValue::extract_bound(scalar.as_ref())
129 .map_err(|e| PyValueError::new_err(format!("Failed to extract PyScalarValue: {e}")))?;
130
131 Ok(py_scalar.into())
133}