datafusion_python/
utils.rs1use std::future::Future;
19use std::sync::{Arc, OnceLock};
20use std::time::Duration;
21
22use datafusion::common::ScalarValue;
23use datafusion::datasource::TableProvider;
24use datafusion::execution::context::SessionContext;
25use datafusion::logical_expr::Volatility;
26use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
27use pyo3::exceptions::PyValueError;
28use pyo3::prelude::*;
29use pyo3::types::PyCapsule;
30use tokio::runtime::Runtime;
31use tokio::task::JoinHandle;
32use tokio::time::sleep;
33
34use crate::common::data_type::PyScalarValue;
35use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult};
36use crate::TokioRuntime;
37
38#[inline]
40pub(crate) fn get_tokio_runtime() -> &'static TokioRuntime {
41 static RUNTIME: OnceLock<TokioRuntime> = OnceLock::new();
47 RUNTIME.get_or_init(|| TokioRuntime(tokio::runtime::Runtime::new().unwrap()))
48}
49
50#[inline]
51pub(crate) fn is_ipython_env(py: Python) -> &'static bool {
52 static IS_IPYTHON_ENV: OnceLock<bool> = OnceLock::new();
53 IS_IPYTHON_ENV.get_or_init(|| {
54 py.import("IPython")
55 .and_then(|ipython| ipython.call_method0("get_ipython"))
56 .map(|ipython| !ipython.is_none())
57 .unwrap_or(false)
58 })
59}
60
61#[inline]
63pub(crate) fn get_global_ctx() -> &'static SessionContext {
64 static CTX: OnceLock<SessionContext> = OnceLock::new();
65 CTX.get_or_init(SessionContext::new)
66}
67
68pub fn wait_for_future<F>(py: Python, fut: F) -> PyResult<F::Output>
73where
74 F: Future + Send,
75 F::Output: Send,
76{
77 let runtime: &Runtime = &get_tokio_runtime().0;
78 const INTERVAL_CHECK_SIGNALS: Duration = Duration::from_millis(1_000);
79
80 py.detach(|| {
81 runtime.block_on(async {
82 tokio::pin!(fut);
83 loop {
84 tokio::select! {
85 res = &mut fut => break Ok(res),
86 _ = sleep(INTERVAL_CHECK_SIGNALS) => {
87 Python::attach(|py| {
88 py.run(cr"pass", None, None)?;
95 py.check_signals()
96 })?;
97 }
98 }
99 }
100 })
101 })
102}
103
104pub(crate) fn spawn_future<F, T>(py: Python, fut: F) -> PyDataFusionResult<T>
107where
108 F: Future<Output = datafusion::common::Result<T>> + Send + 'static,
109 T: Send + 'static,
110{
111 let rt = &get_tokio_runtime().0;
112 let handle: JoinHandle<datafusion::common::Result<T>> = rt.spawn(fut);
113 let inner_result = wait_for_future(py, async {
118 match handle.await {
122 Ok(inner) => inner,
123 Err(join_err) => Err(to_datafusion_err(join_err)),
124 }
125 })?; Ok(inner_result?)
131}
132
133pub(crate) fn parse_volatility(value: &str) -> PyDataFusionResult<Volatility> {
134 Ok(match value {
135 "immutable" => Volatility::Immutable,
136 "stable" => Volatility::Stable,
137 "volatile" => Volatility::Volatile,
138 value => {
139 return Err(PyDataFusionError::Common(format!(
140 "Unsupported volatility type: `{value}`, supported \
141 values are: immutable, stable and volatile."
142 )))
143 }
144 })
145}
146
147pub(crate) fn validate_pycapsule(capsule: &Bound<PyCapsule>, name: &str) -> PyResult<()> {
148 let capsule_name = capsule.name()?;
149 if capsule_name.is_none() {
150 return Err(PyValueError::new_err(format!(
151 "Expected {name} PyCapsule to have name set."
152 )));
153 }
154
155 let capsule_name = capsule_name.unwrap().to_str()?;
156 if capsule_name != name {
157 return Err(PyValueError::new_err(format!(
158 "Expected name '{name}' in PyCapsule, instead got '{capsule_name}'"
159 )));
160 }
161
162 Ok(())
163}
164
165pub(crate) fn table_provider_from_pycapsule(
166 obj: &Bound<PyAny>,
167) -> PyResult<Option<Arc<dyn TableProvider>>> {
168 if obj.hasattr("__datafusion_table_provider__")? {
169 let capsule = obj.getattr("__datafusion_table_provider__")?.call0()?;
170 let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
171 validate_pycapsule(capsule, "datafusion_table_provider")?;
172
173 let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
174 let provider: ForeignTableProvider = provider.into();
175
176 Ok(Some(Arc::new(provider)))
177 } else {
178 Ok(None)
179 }
180}
181
182pub(crate) fn py_obj_to_scalar_value(py: Python, obj: Py<PyAny>) -> PyResult<ScalarValue> {
183 let pa = py.import("pyarrow")?;
186
187 let scalar = pa.call_method1("scalar", (obj,))?;
189
190 let py_scalar = PyScalarValue::extract_bound(scalar.as_ref())
192 .map_err(|e| PyValueError::new_err(format!("Failed to extract PyScalarValue: {e}")))?;
193
194 Ok(py_scalar.into())
196}