potato_type/google/
predict.rs

1use crate::TypeError;
2use potato_util::PyHelperFuncs;
3use potato_util::{json_to_pyobject, pyobject_to_json};
4use pyo3::prelude::*;
5use pyo3::IntoPyObjectExt;
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8
9#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
10#[pyclass]
11#[serde(rename_all = "camelCase", default)]
12pub struct PredictRequest {
13    pub instances: Value,
14    pub parameters: Value,
15}
16
17#[pymethods]
18impl PredictRequest {
19    #[getter]
20    pub fn instances<'py>(&self, py: Python<'py>) -> Result<PyObject, TypeError> {
21        let obj = json_to_pyobject(py, &self.instances)?;
22        Ok(obj)
23    }
24
25    #[getter]
26    pub fn parameters<'py>(&self, py: Python<'py>) -> Result<PyObject, TypeError> {
27        let obj = json_to_pyobject(py, &self.parameters)?;
28        Ok(obj)
29    }
30
31    #[new]
32    #[pyo3(signature = (instances, parameters=None))]
33    pub fn new(instances: Bound<'_, PyAny>, parameters: Option<Bound<'_, PyAny>>) -> Self {
34        // check if instances is a PyList, if not,
35        let instances = pyobject_to_json(&instances).unwrap_or(Value::Null);
36        let parameters =
37            parameters.map_or(Value::Null, |p| pyobject_to_json(&p).unwrap_or(Value::Null));
38
39        Self {
40            instances,
41            parameters,
42        }
43    }
44
45    pub fn __str__(&self) -> String {
46        PyHelperFuncs::__str__(self)
47    }
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
51#[pyclass]
52#[serde(rename_all = "camelCase", default)]
53pub struct PredictResponse {
54    pub predictions: Value,
55    pub metadata: Value,
56    #[pyo3(get)]
57    pub deployed_model_id: String,
58    #[pyo3(get)]
59    pub model: String,
60    #[pyo3(get)]
61    pub model_version_id: String,
62    #[pyo3(get)]
63    pub model_display_name: String,
64}
65
66#[pymethods]
67impl PredictResponse {
68    #[getter]
69    pub fn predictions<'py>(&self, py: Python<'py>) -> Result<PyObject, TypeError> {
70        let obj = json_to_pyobject(py, &self.predictions)?;
71        Ok(obj)
72    }
73
74    #[getter]
75    pub fn metadata<'py>(&self, py: Python<'py>) -> Result<PyObject, TypeError> {
76        let obj = json_to_pyobject(py, &self.metadata)?;
77        Ok(obj)
78    }
79
80    pub fn __str__(&self) -> String {
81        PyHelperFuncs::__str__(self)
82    }
83}
84
85impl PredictResponse {
86    pub fn into_py_bound_any<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
87        let bound = Py::new(py, self.clone())?;
88        Ok(bound.into_bound_py_any(py)?)
89    }
90}