// potato_agent/agents/types.rs

1use crate::agents::error::AgentError;
2use potato_provider::ChatResponse;
3use potato_provider::LogProbExt;
4use potato_provider::ResponseExt;
5use potato_provider::Usage;
6use potato_util::json_to_pyobject;
7use potato_util::utils::{LogProbs, ResponseLogProbs};
8use pyo3::prelude::*;
9use pyo3::IntoPyObjectExt;
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use tracing::instrument;
13use tracing::warn;
14
15#[pyclass]
16#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
17pub struct AgentResponse {
18    pub id: String,
19    pub response: ChatResponse,
20}
21
22#[pymethods]
23impl AgentResponse {
24    pub fn token_usage(&self) -> Result<Usage, AgentError> {
25        match &self.response {
26            ChatResponse::OpenAI(resp) => Ok(resp.usage.clone()),
27            ChatResponse::Gemini(resp) => Ok(resp.get_token_usage()),
28            ChatResponse::VertexGenerate(resp) => Ok(resp.get_token_usage()),
29            _ => Err(AgentError::NotSupportedError(
30                "Token usage not supported for the vertex predict response type".to_string(),
31            )),
32        }
33    }
34
35    //pub fn result<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, AgentError> {
36    //    let pyobj = json_to_pyobject(py, &self.content())?;
37    //
38    //    // Convert plain string output to Python string
39    //    Ok(pyobj.into_bound_py_any(py)?)
40    //}
41}
42
43impl AgentResponse {
44    pub fn new(id: String, response: ChatResponse) -> Self {
45        Self { id, response }
46    }
47
48    pub fn content(&self) -> Option<String> {
49        match &self.response {
50            ChatResponse::OpenAI(resp) => resp.get_content(),
51            ChatResponse::Gemini(resp) => resp.get_content(),
52            ChatResponse::VertexGenerate(resp) => resp.get_content(),
53            _ => {
54                warn!("Content not available for this response type");
55                None
56            }
57        }
58    }
59
60    pub fn log_probs(&self) -> Vec<ResponseLogProbs> {
61        match &self.response {
62            ChatResponse::OpenAI(resp) => resp.get_log_probs(),
63            ChatResponse::Gemini(resp) => resp.get_log_probs(),
64            ChatResponse::VertexGenerate(resp) => resp.get_log_probs(),
65            _ => {
66                warn!("Log probabilities not available for this response type");
67                vec![]
68            }
69        }
70    }
71}
72
73#[pyclass(name = "AgentResponse")]
74#[derive(Debug, Serialize)]
75pub struct PyAgentResponse {
76    pub response: AgentResponse,
77
78    #[serde(skip_serializing)]
79    pub output_type: Option<Py<PyAny>>,
80
81    #[pyo3(get)]
82    pub failed_conversion: bool,
83}
84
85#[pymethods]
86impl PyAgentResponse {
87    #[getter]
88    pub fn id(&self) -> &str {
89        &self.response.id
90    }
91
92    #[getter]
93    pub fn token_usage(&self) -> Result<Usage, AgentError> {
94        self.response.token_usage()
95    }
96
97    #[getter]
98    pub fn log_probs(&self) -> LogProbs {
99        LogProbs {
100            tokens: self.response.log_probs(),
101        }
102    }
103
104    /// This will map a the content of the response to a python object.
105    /// A python object in this case will be either a passed pydantic model or support potatohead types.
106    /// If neither is porvided, an attempt is made to parse the serde Value into an appropriate Python type.
107    /// Types:
108    /// - Serde Null -> Python None
109    /// - Serde Bool -> Python bool
110    /// - Serde String -> Python str
111    /// - Serde Number -> Python int or float
112    /// - Serde Array -> Python list (with each item converted to Python type)
113    /// - Serde Object -> Python dict (with each key-value pair converted to Python type)
114    #[getter]
115    #[instrument(skip_all)]
116    pub fn result<'py>(&mut self, py: Python<'py>) -> Result<Bound<'py, PyAny>, AgentError> {
117        let content_value = self.response.content();
118
119        // If the content is None, return None
120        if content_value.is_none() {
121            return Ok(py.None().into_bound_py_any(py)?);
122        }
123        // convert content_value to string
124        let content_value = content_value.unwrap();
125
126        match &self.output_type {
127            Some(output_type) => {
128                // Match the value. For loading into pydantic models, it's expected that the api response is a JSON string.
129
130                let bound = output_type
131                    .bind(py)
132                    .call_method1("model_validate_json", (&content_value,));
133
134                match bound {
135                    Ok(obj) => {
136                        // Successfully validated the model
137                        Ok(obj)
138                    }
139                    Err(err) => {
140                        // Model validation failed
141                        // convert string to json and then to python object
142                        warn!("Failed to validate model: {}", err);
143                        self.failed_conversion = true;
144                        let val = serde_json::from_str::<Value>(&content_value)?;
145                        Ok(json_to_pyobject(py, &val)?.into_bound_py_any(py)?)
146                    }
147                }
148            }
149            None => {
150                // If no output type is provided, attempt to parse the content as JSON
151                let val = Value::String(content_value);
152                Ok(json_to_pyobject(py, &val)?.into_bound_py_any(py)?)
153            }
154        }
155    }
156}
157
158impl PyAgentResponse {
159    pub fn new(response: AgentResponse, output_type: Option<Py<PyAny>>) -> Self {
160        Self {
161            response,
162            output_type,
163            failed_conversion: false,
164        }
165    }
166}