genius_core_client/
python.rs

1use crate::auth::retrieve_auth_token_client_credentials as retrieve_auth_token_client_credentials_rs;
2use crate::client::inference::{
3    clear_observations as clear_observations_rs, get_probability as get_probability_rs,
4    ObservationValue,
5};
6use crate::client::{Client, TimeoutAndRetries};
7use crate::types::entity::HSMLEntity;
8use crate::types::static_schema::entity_schema::ENTITY_SCHEMA_SWID;
9use crate::types::static_schema::link_schema::LINK_SCHEMA_SWID;
10use crate::utils;
11use once_cell::sync::Lazy;
12use pyo3::prelude::*;
13use pyo3::types::IntoPyDict;
14use pyo3::types::{PyBool, PyDict, PyFloat, PyInt, PyList, PyLong, PyString};
15use pyo3::wrap_pyfunction;
16use serde_json::Value;
17use std::collections::HashMap;
18use std::sync::Arc;
19use tokio::sync::Mutex;
20
21// Process-wide client slot. It starts out as `None` and is filled in by
22// `new_with_oauth2_token`. The `#[pymethods]` in this file read it from inside
23// their futures instead of borrowing `self`, because class methods cannot be
24// async: a future handed to Python may outlive the object it was called on.
25// NOTE(review): `static mut` access is unsynchronized; every use needs `unsafe`.
26static mut CLIENT: Lazy<Option<Arc<Mutex<Client>>>> = Lazy::new(|| None);
27
28#[pymodule]
29fn genius_core_client(_py: Python, m: &PyModule) -> PyResult<()> {
30    m.add_function(wrap_pyfunction!(new_with_oauth2_token, m)?)?;
31    m.add_function(wrap_pyfunction!(make_swid, m)?)?;
32    m.add_class::<PyClient>()?;
33    m.add_class::<PyHSMLEntity>()?;
34
35    let auth_module = PyModule::new(_py, "auth")?;
36    let auth_utils_module = PyModule::new(_py, "utils")?;
37    auth_utils_module.add_function(wrap_pyfunction!(retrieve_auth_token_client_credentials, m)?)?;
38    auth_module.add_submodule(auth_utils_module)?;
39    m.add_submodule(auth_module)?;
40    Ok(())
41}
42
43#[pymodule]
44fn auth_utils(_py: Python, m: &PyModule) -> PyResult<()> {
45    m.add_function(wrap_pyfunction!(retrieve_auth_token_client_credentials, m)?)?;
46    Ok(())
47}
48
49#[pyfunction]
50pub fn make_swid(class: String) -> String {
51    utils::make_swid(&class)
52}
53
54#[pyfunction]
55pub fn new_with_oauth2_token(
56    py: Python,
57    protocol: String,
58    host: String,
59    port: String,
60    token: String,
61    timeout: Option<u64>,
62    retries: Option<u32>,
63) -> PyResult<&PyAny> {
64    pyo3_asyncio::tokio::future_into_py(py, async move {
65        let timeout_and_retries = TimeoutAndRetries {
66            timeout: tokio::time::Duration::from_secs(timeout.unwrap_or(30)),
67            retries: retries.unwrap_or(3),
68        };
69        let result = Client::new_with_oauth2_token(
70            crate::client::Protocol::from(protocol.as_str()),
71            host,
72            port,
73            token,
74            Some(timeout_and_retries),
75        )
76        .await;
77        match result {
78            Ok(client) => {
79                // Initialize client because class methods cannot be async
80                // which is a limitation of Pyo3 / Python interop due to lifetime
81                // of async functions being potentially longer lived than
82                // the object we're defining methods on
83                unsafe {
84                    *CLIENT = Some(Arc::new(Mutex::new(client)));
85                }
86                // Pass back to give a convenient and familiar OOP-like interface
87                // that people are familiar with and that way
88                // there is no confusion when using static functions
89                // whether the client has been initialized because the very fact
90                // of the object being returned generally means successful
91                // client creation
92                Ok(PyClient {
93                    inner: unsafe { CLIENT.as_ref().unwrap().clone() },
94                    inference: PyInference {},
95                })
96            }
97            Err(err) => Err(PyErr::new::<pyo3::exceptions::PyException, _>(format!(
98                "{}",
99                err
100            ))),
101        }
102    })
103}
104
105#[pyclass]
106#[derive(Clone)]
107pub struct PyInference {}
108
109#[pymethods]
110impl PyInference {
111    pub fn get_probability(
112        &self,
113        py: Python,
114        variables: Vec<String>,
115        evidence: Option<&PyDict>,
116    ) -> PyResult<PyObject> {
117        let evidence_map: Option<HashMap<String, ObservationValue>> = evidence.map(|dict| {
118            dict.into_iter()
119                .map(|(key, val)| {
120                    let key: String = key.extract().unwrap();
121                    let val_dict: &PyDict = val.extract().expect("Failed to extract PyDict");
122                    let element = val_dict
123                        .get_item("element")
124                        .unwrap()
125                        .unwrap()
126                        .extract::<String>()
127                        .ok();
128                    let distribution = val_dict
129                        .get_item("distribution")
130                        .unwrap()
131                        .unwrap()
132                        .extract::<Vec<f64>>()
133                        .ok();
134                    let none = val_dict
135                        .get_item("none")
136                        .unwrap()
137                        .unwrap()
138                        .extract::<bool>()
139                        .ok();
140
141                    let val = match (element, distribution, none) {
142                        (Some(e), _, _) => ObservationValue::Element(e),
143                        (_, Some(d), _) => ObservationValue::Distribution(d),
144                        (_, _, Some(_)) => ObservationValue::None,
145                        _ => panic!("Invalid type"),
146                    };
147
148                    (key, val)
149                })
150                .collect()
151        });
152
153        let result = pyo3_asyncio::tokio::future_into_py(py, async move {
154            let mut client = unsafe { CLIENT.as_ref().unwrap().lock().await };
155            let result = get_probability_rs(&mut client, variables, evidence_map).await;
156            Python::with_gil(|py| {
157                match result {
158                    Ok(result) => {
159                        // Convert HashMap<String, Vec<f64>> into pyo3 Python object
160                        let dict = PyDict::new(py);
161                        for (key, val) in result {
162                            let py_list = PyList::new(py, &val);
163                            dict.set_item(key, py_list).unwrap();
164                        }
165                        Ok(dict.to_object(py))
166                    }
167                    Err(err) => Err(PyErr::new::<pyo3::exceptions::PyException, _>(format!(
168                        "{:#?}",
169                        err
170                    ))),
171                }
172            })
173        });
174        Ok(result.unwrap().to_object(py))
175    }
176
177    pub fn clear_observations(
178        &self,
179        py: Python,
180        variables: Option<Vec<String>>,
181    ) -> PyResult<PyObject> {
182        let result = pyo3_asyncio::tokio::future_into_py(py, async move {
183            let mut client = unsafe { CLIENT.as_ref().unwrap().lock().await };
184            match clear_observations_rs(&mut client, variables).await {
185                Ok(result) => {
186                    // Convert the result into strings here, inside the async block
187                    let result: Vec<String> =
188                        result.into_iter().map(|value| value.to_string()).collect();
189                    Ok(result)
190                }
191                Err(err) => Err(PyErr::new::<pyo3::exceptions::PyException, _>(format!(
192                    "{:#?}",
193                    err
194                ))),
195            }
196        });
197        Ok(result.unwrap().to_object(py))
198    }
199}
200
201#[pyclass]
202pub struct PyClient {
203    pub inner: Arc<Mutex<Client>>,
204    inference: PyInference,
205}
206
207#[pymethods]
208impl PyClient {
209    #[staticmethod]
210    fn query(py: Python, query: String) -> PyResult<&PyAny> {
211        pyo3_asyncio::tokio::future_into_py(py, async move {
212            let mut client = unsafe { CLIENT.as_ref().unwrap().lock().await };
213            let result = client.query(query).await;
214            match result {
215                Ok(value) => {
216                    // Convert the HSML Entity into a Python Dictionary
217                    // to be usable from Python without have to
218                    // pass in a generic parameter to return an abstracted
219                    // HSML Entity type, which would not allow it to
220                    // be typed as any shape of data structure, and only
221                    // the strict type of the generic parameter passed in.
222                    let value_str = serde_json::to_string(&value).unwrap();
223                    Python::with_gil(|py| {
224                        let py_value_str = PyString::new(py, &value_str);
225                        let json_module = PyModule::import(py, "json")?;
226                        let parsed = json_module.getattr("loads")?.call1((py_value_str,))?;
227                        Ok(parsed.to_object(py))
228                    })
229                }
230                Err(err) => {
231                    let err_msg = format!("{}", err);
232                    Python::with_gil(|_py| {
233                        let py_err = PyErr::new::<pyo3::exceptions::PyException, _>(err_msg);
234                        Err(py_err)
235                    })
236                }
237            }
238        })
239    }
240
241    #[getter]
242    fn get_inference(&self) -> PyResult<PyInference> {
243        Ok(self.inference.clone())
244    }
245}
246
247#[pyfunction]
248pub fn retrieve_auth_token_client_credentials(
249    client_id: String,
250    client_secret: String,
251    token_url: String,
252    audience: Option<String>,
253    scope: Option<String>,
254) -> PyResult<PyObject> {
255    Python::with_gil(|py| {
256        let result = tokio::runtime::Runtime::new().unwrap().block_on(
257            retrieve_auth_token_client_credentials_rs(
258                client_id,
259                client_secret,
260                token_url,
261                audience,
262                scope,
263            ),
264        );
265
266        match result {
267            Ok(token_response) => {
268                let dict = [("access_token", token_response.access_token)].into_py_dict(py);
269                Ok(dict.to_object(py))
270            }
271            Err(err) => Err(PyErr::new::<pyo3::exceptions::PyException, _>(format!(
272                "{}",
273                err
274            ))),
275        }
276    })
277}
278
279#[pyclass]
280pub struct PyHSMLEntity {
281    pub inner: HSMLEntity,
282}
283
284#[pymethods]
285impl PyHSMLEntity {
286    #[new]
287    fn new(kwargs: Option<&PyDict>) -> Self {
288        let mut entity = HSMLEntity::new(String::from(""));
289        if let Some(kwargs) = kwargs {
290            for (key, val) in kwargs {
291                match key.to_string().as_str() {
292                    "swid" => entity.swid = val.extract().unwrap(),
293                    "__archived" => entity.__archived = val.extract().unwrap(),
294                    "schema" => entity.schema = val.extract().unwrap(),
295                    "name" => entity.name = val.extract().unwrap(),
296                    "source_swid" => {
297                        // Make sure entity schema type is link if contains source_swid
298                        entity.schema =
299                            vec![ENTITY_SCHEMA_SWID.to_string(), LINK_SCHEMA_SWID.to_string()];
300                        if let Ok(py_list) = val.downcast::<PyList>() {
301                            let vec: Vec<Value> = py_list
302                                .into_iter()
303                                .map(|item| {
304                                    // Recursively convert each item in the list to a Value
305                                    // You may need to handle different types here as well
306                                    Value::String(item.extract::<String>().unwrap())
307                                })
308                                .collect();
309                            entity.source_swid = Some(Value::Array(vec));
310                        } else {
311                            panic!("Invalid type")
312                        }
313                    }
314                    "destination_swid" => {
315                        // Make sure entity schema type is link if contains destination_swid
316                        entity.schema =
317                            vec![ENTITY_SCHEMA_SWID.to_string(), LINK_SCHEMA_SWID.to_string()];
318                        if let Ok(py_list) = val.downcast::<PyList>() {
319                            let vec: Vec<Value> = py_list
320                                .into_iter()
321                                .map(|item| {
322                                    // Recursively convert each item in the list to a Value
323                                    // You may need to handle different types here as well
324                                    Value::String(item.extract::<String>().unwrap())
325                                })
326                                .collect();
327                            entity.destination_swid = Some(Value::Array(vec));
328                        } else {
329                            panic!("Invalid type")
330                        }
331                    }
332                    _ => {
333                        let value = if let Ok(py_str) = val.downcast::<PyString>() {
334                            Value::String(py_str.to_string())
335                        } else if let Ok(py_bool) = val.downcast::<PyBool>() {
336                            Value::Bool(py_bool.is_true())
337                        } else if let Ok(py_int) = val.downcast::<PyInt>() {
338                            Value::Number(py_int.extract::<i64>().unwrap().into())
339                        } else if let Ok(py_int) = val.downcast::<PyLong>() {
340                            Value::Number(py_int.extract::<i64>().unwrap().into())
341                        } else if let Ok(py_float) = val.downcast::<PyFloat>() {
342                            Value::Number(
343                                serde_json::Number::from_f64(py_float.extract::<f64>().unwrap())
344                                    .unwrap(),
345                            )
346                        } else if let Ok(py_list) = val.downcast::<PyList>() {
347                            let vec: Vec<Value> = py_list
348                                .into_iter()
349                                .map(|item| {
350                                    // Recursively convert each item in the list to a Value
351                                    // You may need to handle different types here as well
352                                    Value::String(item.extract::<String>().unwrap())
353                                })
354                                .collect();
355                            Value::Array(vec)
356                        } else {
357                            panic!("Invalid type")
358                        };
359                        entity.extra_fields.insert(key.to_string(), value);
360                    }
361                }
362            }
363        }
364        PyHSMLEntity { inner: entity }
365    }
366
367    // Out of scope:
368    // Implement this if we want a convenience function
369    // to instantiate a new link entity
370    // fn new_link(kwargs: Option<&PyDict>) -> Self {
371    //     let entity = HSMLEntity::new_link(String::from(""), &[], &[], None);
372    //     PyHSMLEntity { inner: }
373    // }
374
375    // Out of scope:
376    // Implement this if you want to give users a way to print
377    // the entity to a dictionary
378    // #[text_signature = "(self)"]
379    // fn to_dict(&self, py: Python) -> PyResult<&PyDict> {
380    //     let dict = PyDict::new(py);
381    //     dict.set_item("swid", self.inner.swid.clone())?;
382    //     // Add other fields of the HSMLEntity struct here
383    //     // dict.set_item("field_name", self.inner.field_name.clone())?;
384    //     Ok(dict)
385    // }
386
387    #[getter]
388    fn get_swid(&self) -> PyResult<String> {
389        Ok(self.inner.swid.clone())
390    }
391
392    #[setter]
393    fn set_swid(&mut self, swid: String) {
394        self.inner.swid = swid;
395    }
396
397    #[getter]
398    fn get_destination_swid(&self) -> PyResult<Py<PyAny>> {
399        Python::with_gil(|py| {
400            let list = PyList::empty(py);
401            for item in self
402                .inner
403                .destination_swid
404                .clone()
405                .unwrap()
406                .as_array()
407                .unwrap()
408            {
409                list.append(PyString::new(py, item.as_str().unwrap()))
410                    .unwrap();
411            }
412            Ok(list.to_object(py))
413        })
414    }
415
416    #[setter]
417    fn set_destination_swid(&mut self, destination_swid: &PyList) {
418        let vec: Vec<Value> = destination_swid
419            .into_iter()
420            .map(|item| {
421                // Recursively convert each item in the list to a Value
422                // You may need to handle different types here as well
423                Value::String(item.extract::<String>().unwrap())
424            })
425            .collect();
426        self.inner.destination_swid = Some(Value::Array(vec));
427    }
428}