querent_synapse/querent/
py_runtime.rs1use crate::{
2 callbacks::PyEventCallbackInterface,
3 config::{Config, Neo4jQueryConfig},
4 cross::{CLRepr, CLReprPython},
5 querent::errors::QuerentError,
6 tokio_runtime,
7};
8use log::{error, trace};
9use once_cell::sync::OnceCell;
10use pyo3::{
11 prelude::*,
12 types::{PyFunction, PyTuple},
13};
14use std::{fmt::Formatter, future::Future, pin::Pin};
15use tokio::sync::oneshot;
16
17#[derive(Debug)]
18pub struct PyAsyncFun {
19 fun: Py<PyFunction>,
20 args: Vec<CLRepr>,
21 callback: PyAsyncCallback,
22 config: Option<Config>,
23 query_config: Option<Neo4jQueryConfig>,
24}
25
26pub enum PyAsyncCallback {
27 Channel(oneshot::Sender<Result<CLRepr, QuerentError>>),
28}
29
30impl std::fmt::Debug for PyAsyncCallback {
31 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
32 match self {
33 PyAsyncCallback::Channel(_) => write!(f, "Channel<hidden>"),
34 }
35 }
36}
37
38impl PyAsyncFun {
39 pub fn split(
40 self,
41 ) -> (Py<PyFunction>, Vec<CLRepr>, PyAsyncCallback, Option<Config>, Option<Neo4jQueryConfig>) {
42 (self.fun, self.args, self.callback, self.config, self.query_config)
43 }
44}
45
46enum PyAsyncFunResult {
47 Poll(Pin<Box<dyn Future<Output = PyResult<PyObject>> + Send>>),
48}
49
50pub struct PyRuntime {
51 sender: tokio::sync::mpsc::Sender<PyAsyncFun>,
52}
53
54impl PyRuntime {
55 pub async fn call_async(
56 &self,
57 fun: Py<PyFunction>,
58 args: Vec<CLRepr>,
59 config: Option<Config>,
60 query_config: Option<Neo4jQueryConfig>,
61 ) -> Result<CLRepr, QuerentError> {
62 let (rx, tx) = oneshot::channel();
63
64 self.sender
65 .send(PyAsyncFun {
66 fun,
67 args,
68 callback: PyAsyncCallback::Channel(rx),
69 config,
70 query_config,
71 })
72 .await
73 .map_err(|err| {
74 QuerentError::internal(format!("Unable to schedule python function call: {}", err))
75 })?;
76
77 tx.await?
78 }
79
80 fn process_coroutines(task: PyAsyncFun) -> Result<(), QuerentError> {
81 let (fun, args, callback, config, query_config) = task.split();
82
83 let task_result = Python::with_gil(move |py| -> PyResult<PyAsyncFunResult> {
84 let mut args_tuple = Vec::with_capacity(args.len());
85
86 if let Some(config) = config {
88 args_tuple.push(config.to_object(py));
89 } else if let Some(query_config) = query_config {
90 args_tuple.push(query_config.to_object(py));
91 }
92
93 for arg in args {
94 args_tuple.push(arg.into_py(py)?);
95 }
96 let args = PyTuple::new(py, args_tuple);
97 let call_res = fun.call1(py, args)?;
98 let fut = pyo3_asyncio::tokio::into_future(call_res.as_ref(py))?;
99 Ok(PyAsyncFunResult::Poll(Box::pin(fut)))
100 });
101 let task_result = match task_result {
102 Ok(r) => r,
103 Err(err) => {
104 match callback {
105 PyAsyncCallback::Channel(chan) => {
106 let send_res = chan
107 .send(Err(QuerentError::internal(format!("Python error: {}", err))));
108 if send_res.is_err() {
109 return Err(QuerentError::internal(
110 "Unable to send result back to consumer".to_string(),
111 ));
112 }
113 },
114 };
115
116 return Ok(());
117 },
118 };
119
120 match task_result {
121 PyAsyncFunResult::Poll(fut) => {
122 tokio::spawn(async move {
123 let fut_res = fut.await;
124
125 let res = Python::with_gil(move |py| -> Result<CLRepr, PyErr> {
126 let res = match fut_res {
127 Ok(r) => CLRepr::from_python_ref(r.as_ref(py)),
128 Err(err) => Err(err),
129 };
130
131 res
132 });
133
134 match callback {
135 PyAsyncCallback::Channel(chan) => {
136 let _ = match res {
137 Ok(r) => chan.send(Ok(r)),
138 Err(err) => chan.send(Err(QuerentError::internal(format!(
139 "Python error: {}",
140 err
141 )))),
142 };
143 },
144 }
145 });
146 },
147 };
148
149 Ok(())
150 }
151
152 pub fn new() -> Self {
153 let (sender, mut receiver) = tokio::sync::mpsc::channel::<PyAsyncFun>(1024);
154
155 trace!("New Python runtime");
156
157 std::thread::spawn(|| {
158 trace!("Initializing executor in a separate thread");
159
160 std::thread::spawn(|| {
161 pyo3_asyncio::tokio::get_runtime()
162 .block_on(pyo3_asyncio::tokio::re_exports::pending::<()>())
163 });
164
165 let res = Python::with_gil(|py| -> Result<(), PyErr> {
166 pyo3_asyncio::tokio::run(py, async move {
167 loop {
168 if let Some(task) = receiver.recv().await {
169 trace!("New task");
170
171 if let Err(err) = Self::process_coroutines(task) {
172 error!("Error while processing python task: {:?}", err)
173 };
174 }
175 }
176 })
177 });
178 match res {
179 Ok(_) => trace!("Python runtime loop was closed without error"),
180 Err(err) => error!("Critical error while processing python call: {}", err),
181 }
182 });
183
184 Self { sender }
185 }
186}
187
188static PY_RUNTIME: OnceCell<PyRuntime> = OnceCell::new();
189
190pub fn py_runtime() -> Result<&'static PyRuntime, QuerentError> {
191 if let Some(runtime) = PY_RUNTIME.get() {
192 Ok(runtime)
193 } else {
194 let runtime = PyRuntime::new();
195 PY_RUNTIME
196 .set(runtime)
197 .map(|_| PY_RUNTIME.get().unwrap())
198 .map_err(|_| QuerentError::internal("Unable to set PyRuntime".to_string()))
199 }
200}
201
202pub fn call_async(
203 fun: Py<PyFunction>,
204 args: Vec<CLRepr>,
205 config: Option<Config>,
206 query_config: Option<Neo4jQueryConfig>,
207) -> Result<impl Future<Output = Result<CLRepr, QuerentError>>, QuerentError> {
208 let runtime = py_runtime()?;
209 Ok(runtime.call_async(fun, args, config, query_config))
210}
211
212pub fn py_runtime_init() -> Result<(), QuerentError> {
213 if PY_RUNTIME.get().is_some() {
214 return Ok(());
215 }
216
217 let runtime = tokio_runtime()?;
218
219 pyo3::prepare_freethreaded_python();
220
221 pyo3_asyncio::tokio::init_with_runtime(runtime)
222 .map_err(|_| QuerentError::internal("Unable to initialize Python runtime".to_string()))?;
223 if PY_RUNTIME.set(PyRuntime::new()).is_err() {
224 Err(QuerentError::internal("Unable to set PyRuntime".to_string()))
225 } else {
226 Ok(())
227 }
228}