querent_synapse/querent/query/
query_engine.rs1use crate::{
2 callbacks::{EventCallbackInterface, PyEventCallbackInterface},
3 comm::ChannelHandler,
4 config::{Config, Neo4jQueryConfig},
5 cross::{CLRepr, CLReprPython},
6 querent::{py_runtime, PyRuntime, QuerentError},
7 tokio_runtime,
8};
9use futures::TryFutureExt;
10use log;
11use pyo3::{prelude::*, types::PyFunction};
12use std::{collections::HashMap, sync::Mutex};
13use tokio::runtime::Runtime;
14
15#[derive(Debug, Clone)]
17#[pyclass]
18pub struct QueryEngine {
19 pub name: String,
20 pub id: String,
21 pub import: String,
22 pub attr: String,
23 pub code: Option<String>,
24 pub arguments: Vec<CLRepr>,
25 pub config: Option<Neo4jQueryConfig>,
26}
27
28pub struct QueryEngineManager {
29 pub workflows: Mutex<HashMap<String, QueryEngine>>,
31 pub runtime: &'static PyRuntime,
33}
34
35impl QueryEngineManager {
36 pub fn new() -> Result<Self, String> {
38 let runtime = py_runtime().map_err(|e| e.to_string())?;
39 Ok(Self { workflows: Mutex::new(HashMap::new()), runtime })
40 }
41 pub fn add_workflow(&self, workflow: QueryEngine) -> Result<(), String> {
43 let mut workflows =
44 self.workflows.lock().map_err(|e| format!("Mutex lock failed: {}", e))?;
45 if workflows.contains_key(&workflow.id) {
46 return Err("Workflow with the same ID already exists.".to_string());
47 } else {
48 workflows.insert(workflow.id.clone(), workflow.clone());
49 }
50 Ok(())
51 }
52
53 pub fn get_workflows(&self) -> Vec<QueryEngine> {
55 let workflows = self.workflows.lock().unwrap();
56 workflows.values().cloned().collect()
57 }
58
59 pub async fn start_workflows(&self) -> Result<(), QuerentError> {
61 let workflows = self.get_workflows();
62 let handles: Vec<_> = workflows
63 .iter()
64 .map(|_workflow| {
65 let args = _workflow.arguments.clone();
66 let res = match &_workflow.code {
67 None => Python::with_gil(|py| {
68 let async_mod = py.import(_workflow.import.as_str()).map_err(|e| {
69 log::error!("Failed to import module {}: {}", _workflow.import, e);
70 QuerentError::internal(e.to_string())
71 })?;
72
73 let coroutine =
74 async_mod.getattr(_workflow.attr.as_str()).map_err(|_| {
75 log::error!("Failed to find start function.");
76 QuerentError::internal("Failed to find start function.".to_string())
77 })?;
78
79 let querent_py_fun: Py<PyFunction> = coroutine.extract().map_err(|e| {
80 log::error!("Failed to extract function: {}", e);
81 QuerentError::internal(e.to_string())
82 })?;
83
84 let call_future = self.runtime.call_async(
85 querent_py_fun,
86 args,
87 None,
88 _workflow.config.clone(),
89 );
90 Ok(call_future)
91 }),
92 Some(code) => {
93 let module_file: String = _workflow.id.clone() + ".py";
94 Python::with_gil(|py| {
95 let dynamic_module = PyModule::from_code(
96 py,
97 code.as_str(),
98 module_file.as_str(),
99 _workflow.name.as_str(),
100 )
101 .map_err(|e| {
102 log::error!("Failed to import module {}: {}", _workflow.import, e);
103 QuerentError::internal(e.to_string())
104 })?;
105
106 let attr_fun =
107 dynamic_module.getattr(_workflow.attr.as_str()).map_err(|_| {
108 log::error!("Failed to find start function.");
109 QuerentError::internal(
110 "Failed to find start function.".to_string(),
111 )
112 })?;
113
114 let querent_py_fun: Py<PyFunction> =
115 attr_fun.extract().map_err(|e| {
116 log::error!("Failed to extract function: {}", e);
117 QuerentError::internal(e.to_string())
118 })?;
119
120 let call_future = self.runtime.call_async(
121 querent_py_fun,
122 args,
123 None,
124 _workflow.config.clone(),
125 );
126 Ok(call_future)
127 })
128 },
129 };
130 res
131 })
132 .collect();
133 for handle in handles {
134 match handle {
135 Ok(future) => match future.await {
136 Ok(_) => log::info!("Workflow started."),
137 Err(e) => {
138 log::error!("Failed to start workflow: {}", e);
139 return Err(QuerentError::internal(e.to_string()));
140 },
141 },
142 Err(e) => {
143 log::error!("Failed to start workflow: {}", e);
144 return Err(e);
145 },
146 }
147 }
148 Ok(())
149 }
150}
151
152impl Drop for QueryEngineManager {
153 fn drop(&mut self) {
155 log::info!("Dropping WorkflowManager");
156 let _ = self.runtime;
157 }
158}