datafusion_python/
utils.rs1use std::future::Future;
19use std::sync::{Arc, OnceLock};
20use std::time::Duration;
21
22use datafusion::datasource::TableProvider;
23use datafusion::execution::context::SessionContext;
24use datafusion::logical_expr::Volatility;
25use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
26use datafusion_ffi::table_provider::FFI_TableProvider;
27use pyo3::IntoPyObjectExt;
28use pyo3::exceptions::{PyImportError, PyTypeError, PyValueError};
29use pyo3::prelude::*;
30use pyo3::types::{PyCapsule, PyType};
31use tokio::runtime::Runtime;
32use tokio::task::JoinHandle;
33use tokio::time::sleep;
34
35use crate::TokioRuntime;
36use crate::context::PySessionContext;
37use crate::errors::{PyDataFusionError, PyDataFusionResult, py_datafusion_err, to_datafusion_err};
38
39#[inline]
41pub(crate) fn get_tokio_runtime() -> &'static TokioRuntime {
42 static RUNTIME: OnceLock<TokioRuntime> = OnceLock::new();
48 RUNTIME.get_or_init(|| TokioRuntime(tokio::runtime::Runtime::new().unwrap()))
49}
50
51#[inline]
52pub(crate) fn is_ipython_env(py: Python) -> &'static bool {
53 static IS_IPYTHON_ENV: OnceLock<bool> = OnceLock::new();
54 IS_IPYTHON_ENV.get_or_init(|| {
55 py.import("IPython")
56 .and_then(|ipython| ipython.call_method0("get_ipython"))
57 .map(|ipython| !ipython.is_none())
58 .unwrap_or(false)
59 })
60}
61
62#[inline]
64pub(crate) fn get_global_ctx() -> &'static Arc<SessionContext> {
65 static CTX: OnceLock<Arc<SessionContext>> = OnceLock::new();
66 CTX.get_or_init(|| Arc::new(SessionContext::new()))
67}
68
69pub fn wait_for_future<F>(py: Python, fut: F) -> PyResult<F::Output>
74where
75 F: Future + Send,
76 F::Output: Send,
77{
78 let runtime: &Runtime = &get_tokio_runtime().0;
79 const INTERVAL_CHECK_SIGNALS: Duration = Duration::from_millis(1_000);
80
81 py.run(cr"pass", None, None)?;
84 py.check_signals()?;
85
86 py.detach(|| {
87 runtime.block_on(async {
88 tokio::pin!(fut);
89 loop {
90 tokio::select! {
91 res = &mut fut => break Ok(res),
92 _ = sleep(INTERVAL_CHECK_SIGNALS) => {
93 Python::attach(|py| {
94 py.run(cr"pass", None, None)?;
101 py.check_signals()
102 })?;
103 }
104 }
105 }
106 })
107 })
108}
109
110pub(crate) fn spawn_future<F, T>(py: Python, fut: F) -> PyDataFusionResult<T>
113where
114 F: Future<Output = datafusion::common::Result<T>> + Send + 'static,
115 T: Send + 'static,
116{
117 let rt = &get_tokio_runtime().0;
118 let handle: JoinHandle<datafusion::common::Result<T>> = rt.spawn(fut);
119 let inner_result = wait_for_future(py, async {
124 match handle.await {
128 Ok(inner) => inner,
129 Err(join_err) => Err(to_datafusion_err(join_err)),
130 }
131 })?; Ok(inner_result?)
137}
138
139pub(crate) fn parse_volatility(value: &str) -> PyDataFusionResult<Volatility> {
140 Ok(match value {
141 "immutable" => Volatility::Immutable,
142 "stable" => Volatility::Stable,
143 "volatile" => Volatility::Volatile,
144 value => {
145 return Err(PyDataFusionError::Common(format!(
146 "Unsupported volatility type: `{value}`, supported \
147 values are: immutable, stable and volatile."
148 )));
149 }
150 })
151}
152
153pub(crate) fn validate_pycapsule(capsule: &Bound<PyCapsule>, name: &str) -> PyResult<()> {
154 let capsule_name = capsule.name()?;
155 if capsule_name.is_none() {
156 return Err(PyValueError::new_err(format!(
157 "Expected {name} PyCapsule to have name set."
158 )));
159 }
160
161 let capsule_name = capsule_name.unwrap().to_str()?;
162 if capsule_name != name {
163 return Err(PyValueError::new_err(format!(
164 "Expected name '{name}' in PyCapsule, instead got '{capsule_name}'"
165 )));
166 }
167
168 Ok(())
169}
170
171pub(crate) fn table_provider_from_pycapsule<'py>(
172 mut obj: Bound<'py, PyAny>,
173 session: Bound<'py, PyAny>,
174) -> PyResult<Option<Arc<dyn TableProvider>>> {
175 if obj.hasattr("__datafusion_table_provider__")? {
176 obj = obj
177 .getattr("__datafusion_table_provider__")?
178 .call1((session,)).map_err(|err| {
179 let py = obj.py();
180 if err.get_type(py).is(PyType::new::<PyTypeError>(py)) {
181 PyImportError::new_err("Incompatible libraries. DataFusion 52.0.0 introduced an incompatible signature change for table providers. Either downgrade DataFusion or upgrade your function library.")
182 } else {
183 err
184 }
185 })?;
186 }
187
188 if let Ok(capsule) = obj.downcast::<PyCapsule>().map_err(py_datafusion_err) {
189 validate_pycapsule(capsule, "datafusion_table_provider")?;
190
191 let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
192 let provider: Arc<dyn TableProvider> = provider.into();
193
194 Ok(Some(provider))
195 } else {
196 Ok(None)
197 }
198}
199
200pub(crate) fn extract_logical_extension_codec(
201 py: Python,
202 obj: Option<Bound<PyAny>>,
203) -> PyResult<Arc<FFI_LogicalExtensionCodec>> {
204 let obj = match obj {
205 Some(obj) => obj,
206 None => PySessionContext::global_ctx()?.into_bound_py_any(py)?,
207 };
208 let capsule = if obj.hasattr("__datafusion_logical_extension_codec__")? {
209 obj.getattr("__datafusion_logical_extension_codec__")?
210 .call0()?
211 } else {
212 obj
213 };
214 let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
215
216 validate_pycapsule(capsule, "datafusion_logical_extension_codec")?;
217
218 let codec = unsafe { capsule.reference::<FFI_LogicalExtensionCodec>() };
219 Ok(Arc::new(codec.clone()))
220}
221
222pub(crate) fn create_logical_extension_capsule<'py>(
223 py: Python<'py>,
224 codec: &FFI_LogicalExtensionCodec,
225) -> PyResult<Bound<'py, PyCapsule>> {
226 let name = cr"datafusion_logical_extension_codec".into();
227 let codec = codec.clone();
228
229 PyCapsule::new(py, codec, Some(name))
230}