1use crate::auth::retrieve_auth_token_client_credentials as retrieve_auth_token_client_credentials_rs;
2use crate::client::inference::{
3 clear_observations as clear_observations_rs, get_probability as get_probability_rs,
4 ObservationValue,
5};
6use crate::client::{Client, TimeoutAndRetries};
7use crate::types::entity::HSMLEntity;
8use crate::types::static_schema::entity_schema::ENTITY_SCHEMA_SWID;
9use crate::types::static_schema::link_schema::LINK_SCHEMA_SWID;
10use crate::utils;
11use once_cell::sync::Lazy;
12use pyo3::prelude::*;
13use pyo3::types::IntoPyDict;
14use pyo3::types::{PyBool, PyDict, PyFloat, PyInt, PyList, PyLong, PyString};
15use pyo3::wrap_pyfunction;
16use serde_json::Value;
17use std::collections::HashMap;
18use std::sync::Arc;
19use tokio::sync::Mutex;
20
/// Process-wide slot holding the most recently created client.
///
/// Written by `new_with_oauth2_token` and read by the `PyInference` and
/// `PyClient` methods in this file. Starts out as `None`.
// NOTE(review): the inner `tokio::sync::Mutex` serializes use of the `Client`,
// but nothing synchronizes replacement of the `Option` itself because this is
// a `static mut`. Every access therefore needs `unsafe`; consider a
// `OnceCell`/`RwLock`-based slot if concurrent re-initialization is possible.
static mut CLIENT: Lazy<Option<Arc<Mutex<Client>>>> = Lazy::new(|| None);
27
28#[pymodule]
29fn genius_core_client(_py: Python, m: &PyModule) -> PyResult<()> {
30 m.add_function(wrap_pyfunction!(new_with_oauth2_token, m)?)?;
31 m.add_function(wrap_pyfunction!(make_swid, m)?)?;
32 m.add_class::<PyClient>()?;
33 m.add_class::<PyHSMLEntity>()?;
34
35 let auth_module = PyModule::new(_py, "auth")?;
36 let auth_utils_module = PyModule::new(_py, "utils")?;
37 auth_utils_module.add_function(wrap_pyfunction!(retrieve_auth_token_client_credentials, m)?)?;
38 auth_module.add_submodule(auth_utils_module)?;
39 m.add_submodule(auth_module)?;
40 Ok(())
41}
42
43#[pymodule]
44fn auth_utils(_py: Python, m: &PyModule) -> PyResult<()> {
45 m.add_function(wrap_pyfunction!(retrieve_auth_token_client_credentials, m)?)?;
46 Ok(())
47}
48
49#[pyfunction]
50pub fn make_swid(class: String) -> String {
51 utils::make_swid(&class)
52}
53
54#[pyfunction]
55pub fn new_with_oauth2_token(
56 py: Python,
57 protocol: String,
58 host: String,
59 port: String,
60 token: String,
61 timeout: Option<u64>,
62 retries: Option<u32>,
63) -> PyResult<&PyAny> {
64 pyo3_asyncio::tokio::future_into_py(py, async move {
65 let timeout_and_retries = TimeoutAndRetries {
66 timeout: tokio::time::Duration::from_secs(timeout.unwrap_or(30)),
67 retries: retries.unwrap_or(3),
68 };
69 let result = Client::new_with_oauth2_token(
70 crate::client::Protocol::from(protocol.as_str()),
71 host,
72 port,
73 token,
74 Some(timeout_and_retries),
75 )
76 .await;
77 match result {
78 Ok(client) => {
79 unsafe {
84 *CLIENT = Some(Arc::new(Mutex::new(client)));
85 }
86 Ok(PyClient {
93 inner: unsafe { CLIENT.as_ref().unwrap().clone() },
94 inference: PyInference {},
95 })
96 }
97 Err(err) => Err(PyErr::new::<pyo3::exceptions::PyException, _>(format!(
98 "{}",
99 err
100 ))),
101 }
102 })
103}
104
/// Python-visible handle for the inference API.
///
/// The struct carries no state of its own; its methods operate on the client
/// stored in the global `CLIENT` slot.
#[pyclass]
#[derive(Clone)]
pub struct PyInference {}
108
109#[pymethods]
110impl PyInference {
111 pub fn get_probability(
112 &self,
113 py: Python,
114 variables: Vec<String>,
115 evidence: Option<&PyDict>,
116 ) -> PyResult<PyObject> {
117 let evidence_map: Option<HashMap<String, ObservationValue>> = evidence.map(|dict| {
118 dict.into_iter()
119 .map(|(key, val)| {
120 let key: String = key.extract().unwrap();
121 let val_dict: &PyDict = val.extract().expect("Failed to extract PyDict");
122 let element = val_dict
123 .get_item("element")
124 .unwrap()
125 .unwrap()
126 .extract::<String>()
127 .ok();
128 let distribution = val_dict
129 .get_item("distribution")
130 .unwrap()
131 .unwrap()
132 .extract::<Vec<f64>>()
133 .ok();
134 let none = val_dict
135 .get_item("none")
136 .unwrap()
137 .unwrap()
138 .extract::<bool>()
139 .ok();
140
141 let val = match (element, distribution, none) {
142 (Some(e), _, _) => ObservationValue::Element(e),
143 (_, Some(d), _) => ObservationValue::Distribution(d),
144 (_, _, Some(_)) => ObservationValue::None,
145 _ => panic!("Invalid type"),
146 };
147
148 (key, val)
149 })
150 .collect()
151 });
152
153 let result = pyo3_asyncio::tokio::future_into_py(py, async move {
154 let mut client = unsafe { CLIENT.as_ref().unwrap().lock().await };
155 let result = get_probability_rs(&mut client, variables, evidence_map).await;
156 Python::with_gil(|py| {
157 match result {
158 Ok(result) => {
159 let dict = PyDict::new(py);
161 for (key, val) in result {
162 let py_list = PyList::new(py, &val);
163 dict.set_item(key, py_list).unwrap();
164 }
165 Ok(dict.to_object(py))
166 }
167 Err(err) => Err(PyErr::new::<pyo3::exceptions::PyException, _>(format!(
168 "{:#?}",
169 err
170 ))),
171 }
172 })
173 });
174 Ok(result.unwrap().to_object(py))
175 }
176
177 pub fn clear_observations(
178 &self,
179 py: Python,
180 variables: Option<Vec<String>>,
181 ) -> PyResult<PyObject> {
182 let result = pyo3_asyncio::tokio::future_into_py(py, async move {
183 let mut client = unsafe { CLIENT.as_ref().unwrap().lock().await };
184 match clear_observations_rs(&mut client, variables).await {
185 Ok(result) => {
186 let result: Vec<String> =
188 result.into_iter().map(|value| value.to_string()).collect();
189 Ok(result)
190 }
191 Err(err) => Err(PyErr::new::<pyo3::exceptions::PyException, _>(format!(
192 "{:#?}",
193 err
194 ))),
195 }
196 });
197 Ok(result.unwrap().to_object(py))
198 }
199}
200
/// Python-visible client object returned by `new_with_oauth2_token`.
#[pyclass]
pub struct PyClient {
    /// Shared handle to the underlying client.
    // NOTE(review): `query` is a static method that reads the global `CLIENT`
    // slot rather than this field — confirm whether that is intended.
    pub inner: Arc<Mutex<Client>>,
    /// Inference handle returned (cloned) by the `inference` getter.
    inference: PyInference,
}
206
207#[pymethods]
208impl PyClient {
209 #[staticmethod]
210 fn query(py: Python, query: String) -> PyResult<&PyAny> {
211 pyo3_asyncio::tokio::future_into_py(py, async move {
212 let mut client = unsafe { CLIENT.as_ref().unwrap().lock().await };
213 let result = client.query(query).await;
214 match result {
215 Ok(value) => {
216 let value_str = serde_json::to_string(&value).unwrap();
223 Python::with_gil(|py| {
224 let py_value_str = PyString::new(py, &value_str);
225 let json_module = PyModule::import(py, "json")?;
226 let parsed = json_module.getattr("loads")?.call1((py_value_str,))?;
227 Ok(parsed.to_object(py))
228 })
229 }
230 Err(err) => {
231 let err_msg = format!("{}", err);
232 Python::with_gil(|_py| {
233 let py_err = PyErr::new::<pyo3::exceptions::PyException, _>(err_msg);
234 Err(py_err)
235 })
236 }
237 }
238 })
239 }
240
241 #[getter]
242 fn get_inference(&self) -> PyResult<PyInference> {
243 Ok(self.inference.clone())
244 }
245}
246
247#[pyfunction]
248pub fn retrieve_auth_token_client_credentials(
249 client_id: String,
250 client_secret: String,
251 token_url: String,
252 audience: Option<String>,
253 scope: Option<String>,
254) -> PyResult<PyObject> {
255 Python::with_gil(|py| {
256 let result = tokio::runtime::Runtime::new().unwrap().block_on(
257 retrieve_auth_token_client_credentials_rs(
258 client_id,
259 client_secret,
260 token_url,
261 audience,
262 scope,
263 ),
264 );
265
266 match result {
267 Ok(token_response) => {
268 let dict = [("access_token", token_response.access_token)].into_py_dict(py);
269 Ok(dict.to_object(py))
270 }
271 Err(err) => Err(PyErr::new::<pyo3::exceptions::PyException, _>(format!(
272 "{}",
273 err
274 ))),
275 }
276 })
277}
278
/// Python-visible wrapper around an `HSMLEntity`.
#[pyclass]
pub struct PyHSMLEntity {
    /// The wrapped entity that the constructor populates and the
    /// getters/setters read and modify.
    pub inner: HSMLEntity,
}
283
284#[pymethods]
285impl PyHSMLEntity {
286 #[new]
287 fn new(kwargs: Option<&PyDict>) -> Self {
288 let mut entity = HSMLEntity::new(String::from(""));
289 if let Some(kwargs) = kwargs {
290 for (key, val) in kwargs {
291 match key.to_string().as_str() {
292 "swid" => entity.swid = val.extract().unwrap(),
293 "__archived" => entity.__archived = val.extract().unwrap(),
294 "schema" => entity.schema = val.extract().unwrap(),
295 "name" => entity.name = val.extract().unwrap(),
296 "source_swid" => {
297 entity.schema =
299 vec![ENTITY_SCHEMA_SWID.to_string(), LINK_SCHEMA_SWID.to_string()];
300 if let Ok(py_list) = val.downcast::<PyList>() {
301 let vec: Vec<Value> = py_list
302 .into_iter()
303 .map(|item| {
304 Value::String(item.extract::<String>().unwrap())
307 })
308 .collect();
309 entity.source_swid = Some(Value::Array(vec));
310 } else {
311 panic!("Invalid type")
312 }
313 }
314 "destination_swid" => {
315 entity.schema =
317 vec![ENTITY_SCHEMA_SWID.to_string(), LINK_SCHEMA_SWID.to_string()];
318 if let Ok(py_list) = val.downcast::<PyList>() {
319 let vec: Vec<Value> = py_list
320 .into_iter()
321 .map(|item| {
322 Value::String(item.extract::<String>().unwrap())
325 })
326 .collect();
327 entity.destination_swid = Some(Value::Array(vec));
328 } else {
329 panic!("Invalid type")
330 }
331 }
332 _ => {
333 let value = if let Ok(py_str) = val.downcast::<PyString>() {
334 Value::String(py_str.to_string())
335 } else if let Ok(py_bool) = val.downcast::<PyBool>() {
336 Value::Bool(py_bool.is_true())
337 } else if let Ok(py_int) = val.downcast::<PyInt>() {
338 Value::Number(py_int.extract::<i64>().unwrap().into())
339 } else if let Ok(py_int) = val.downcast::<PyLong>() {
340 Value::Number(py_int.extract::<i64>().unwrap().into())
341 } else if let Ok(py_float) = val.downcast::<PyFloat>() {
342 Value::Number(
343 serde_json::Number::from_f64(py_float.extract::<f64>().unwrap())
344 .unwrap(),
345 )
346 } else if let Ok(py_list) = val.downcast::<PyList>() {
347 let vec: Vec<Value> = py_list
348 .into_iter()
349 .map(|item| {
350 Value::String(item.extract::<String>().unwrap())
353 })
354 .collect();
355 Value::Array(vec)
356 } else {
357 panic!("Invalid type")
358 };
359 entity.extra_fields.insert(key.to_string(), value);
360 }
361 }
362 }
363 }
364 PyHSMLEntity { inner: entity }
365 }
366
367 #[getter]
388 fn get_swid(&self) -> PyResult<String> {
389 Ok(self.inner.swid.clone())
390 }
391
392 #[setter]
393 fn set_swid(&mut self, swid: String) {
394 self.inner.swid = swid;
395 }
396
397 #[getter]
398 fn get_destination_swid(&self) -> PyResult<Py<PyAny>> {
399 Python::with_gil(|py| {
400 let list = PyList::empty(py);
401 for item in self
402 .inner
403 .destination_swid
404 .clone()
405 .unwrap()
406 .as_array()
407 .unwrap()
408 {
409 list.append(PyString::new(py, item.as_str().unwrap()))
410 .unwrap();
411 }
412 Ok(list.to_object(py))
413 })
414 }
415
416 #[setter]
417 fn set_destination_swid(&mut self, destination_swid: &PyList) {
418 let vec: Vec<Value> = destination_swid
419 .into_iter()
420 .map(|item| {
421 Value::String(item.extract::<String>().unwrap())
424 })
425 .collect();
426 self.inner.destination_swid = Some(Value::Array(vec));
427 }
428}