// querent_synapse/querent/py_runtime.rs

1use crate::{
2	callbacks::PyEventCallbackInterface,
3	config::{Config, Neo4jQueryConfig},
4	cross::{CLRepr, CLReprPython},
5	querent::errors::QuerentError,
6	tokio_runtime,
7};
8use log::{error, trace};
9use once_cell::sync::OnceCell;
10use pyo3::{
11	prelude::*,
12	types::{PyFunction, PyTuple},
13};
14use std::{fmt::Formatter, future::Future, pin::Pin};
15use tokio::sync::oneshot;
16
17#[derive(Debug)]
18pub struct PyAsyncFun {
19	fun: Py<PyFunction>,
20	args: Vec<CLRepr>,
21	callback: PyAsyncCallback,
22	config: Option<Config>,
23	query_config: Option<Neo4jQueryConfig>,
24}
25
26pub enum PyAsyncCallback {
27	Channel(oneshot::Sender<Result<CLRepr, QuerentError>>),
28}
29
30impl std::fmt::Debug for PyAsyncCallback {
31	fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
32		match self {
33			PyAsyncCallback::Channel(_) => write!(f, "Channel<hidden>"),
34		}
35	}
36}
37
38impl PyAsyncFun {
39	pub fn split(
40		self,
41	) -> (Py<PyFunction>, Vec<CLRepr>, PyAsyncCallback, Option<Config>, Option<Neo4jQueryConfig>) {
42		(self.fun, self.args, self.callback, self.config, self.query_config)
43	}
44}
45
46enum PyAsyncFunResult {
47	Poll(Pin<Box<dyn Future<Output = PyResult<PyObject>> + Send>>),
48}
49
50pub struct PyRuntime {
51	sender: tokio::sync::mpsc::Sender<PyAsyncFun>,
52}
53
54impl PyRuntime {
55	pub async fn call_async(
56		&self,
57		fun: Py<PyFunction>,
58		args: Vec<CLRepr>,
59		config: Option<Config>,
60		query_config: Option<Neo4jQueryConfig>,
61	) -> Result<CLRepr, QuerentError> {
62		let (rx, tx) = oneshot::channel();
63
64		self.sender
65			.send(PyAsyncFun {
66				fun,
67				args,
68				callback: PyAsyncCallback::Channel(rx),
69				config,
70				query_config,
71			})
72			.await
73			.map_err(|err| {
74				QuerentError::internal(format!("Unable to schedule python function call: {}", err))
75			})?;
76
77		tx.await?
78	}
79
80	fn process_coroutines(task: PyAsyncFun) -> Result<(), QuerentError> {
81		let (fun, args, callback, config, query_config) = task.split();
82
83		let task_result = Python::with_gil(move |py| -> PyResult<PyAsyncFunResult> {
84			let mut args_tuple = Vec::with_capacity(args.len());
85
86			// TODO simplify this code
87			if let Some(config) = config {
88				args_tuple.push(config.to_object(py));
89			} else if let Some(query_config) = query_config {
90				args_tuple.push(query_config.to_object(py));
91			}
92
93			for arg in args {
94				args_tuple.push(arg.into_py(py)?);
95			}
96			let args = PyTuple::new(py, args_tuple);
97			let call_res = fun.call1(py, args)?;
98			let fut = pyo3_asyncio::tokio::into_future(call_res.as_ref(py))?;
99			Ok(PyAsyncFunResult::Poll(Box::pin(fut)))
100		});
101		let task_result = match task_result {
102			Ok(r) => r,
103			Err(err) => {
104				match callback {
105					PyAsyncCallback::Channel(chan) => {
106						let send_res = chan
107							.send(Err(QuerentError::internal(format!("Python error: {}", err))));
108						if send_res.is_err() {
109							return Err(QuerentError::internal(
110								"Unable to send result back to consumer".to_string(),
111							));
112						}
113					},
114				};
115
116				return Ok(());
117			},
118		};
119
120		match task_result {
121			PyAsyncFunResult::Poll(fut) => {
122				tokio::spawn(async move {
123					let fut_res = fut.await;
124
125					let res = Python::with_gil(move |py| -> Result<CLRepr, PyErr> {
126						let res = match fut_res {
127							Ok(r) => CLRepr::from_python_ref(r.as_ref(py)),
128							Err(err) => Err(err),
129						};
130
131						res
132					});
133
134					match callback {
135						PyAsyncCallback::Channel(chan) => {
136							let _ = match res {
137								Ok(r) => chan.send(Ok(r)),
138								Err(err) => chan.send(Err(QuerentError::internal(format!(
139									"Python error: {}",
140									err
141								)))),
142							};
143						},
144					}
145				});
146			},
147		};
148
149		Ok(())
150	}
151
152	pub fn new() -> Self {
153		let (sender, mut receiver) = tokio::sync::mpsc::channel::<PyAsyncFun>(1024);
154
155		trace!("New Python runtime");
156
157		std::thread::spawn(|| {
158			trace!("Initializing executor in a separate thread");
159
160			std::thread::spawn(|| {
161				pyo3_asyncio::tokio::get_runtime()
162					.block_on(pyo3_asyncio::tokio::re_exports::pending::<()>())
163			});
164
165			let res = Python::with_gil(|py| -> Result<(), PyErr> {
166				pyo3_asyncio::tokio::run(py, async move {
167					loop {
168						if let Some(task) = receiver.recv().await {
169							trace!("New task");
170
171							if let Err(err) = Self::process_coroutines(task) {
172								error!("Error while processing python task: {:?}", err)
173							};
174						}
175					}
176				})
177			});
178			match res {
179				Ok(_) => trace!("Python runtime loop was closed without error"),
180				Err(err) => error!("Critical error while processing python call: {}", err),
181			}
182		});
183
184		Self { sender }
185	}
186}
187
188static PY_RUNTIME: OnceCell<PyRuntime> = OnceCell::new();
189
190pub fn py_runtime() -> Result<&'static PyRuntime, QuerentError> {
191	if let Some(runtime) = PY_RUNTIME.get() {
192		Ok(runtime)
193	} else {
194		let runtime = PyRuntime::new();
195		PY_RUNTIME
196			.set(runtime)
197			.map(|_| PY_RUNTIME.get().unwrap())
198			.map_err(|_| QuerentError::internal("Unable to set PyRuntime".to_string()))
199	}
200}
201
202pub fn call_async(
203	fun: Py<PyFunction>,
204	args: Vec<CLRepr>,
205	config: Option<Config>,
206	query_config: Option<Neo4jQueryConfig>,
207) -> Result<impl Future<Output = Result<CLRepr, QuerentError>>, QuerentError> {
208	let runtime = py_runtime()?;
209	Ok(runtime.call_async(fun, args, config, query_config))
210}
211
212pub fn py_runtime_init() -> Result<(), QuerentError> {
213	if PY_RUNTIME.get().is_some() {
214		return Ok(());
215	}
216
217	let runtime = tokio_runtime()?;
218
219	pyo3::prepare_freethreaded_python();
220
221	pyo3_asyncio::tokio::init_with_runtime(runtime)
222		.map_err(|_| QuerentError::internal("Unable to initialize Python runtime".to_string()))?;
223	if PY_RUNTIME.set(PyRuntime::new()).is_err() {
224		Err(QuerentError::internal("Unable to set PyRuntime".to_string()))
225	} else {
226		Ok(())
227	}
228}