potato_agent/agents/
types.rs

1use crate::agents::error::AgentError;
2use crate::agents::provider::gemini::FunctionCall;
3use crate::agents::provider::gemini::GenerateContentResponse;
4use crate::agents::provider::openai::OpenAIChatResponse;
5use crate::agents::provider::openai::ToolCall;
6use crate::agents::provider::openai::Usage;
7use crate::agents::provider::traits::LogProbExt;
8use crate::agents::provider::traits::ResponseExt;
9use potato_prompt::{
10    prompt::{PromptContent, Role},
11    Message,
12};
13use potato_util::utils::{LogProbs, ResponseLogProbs};
14use potato_util::{json_to_pyobject, PyHelperFuncs};
15use pyo3::prelude::*;
16use pyo3::IntoPyObjectExt;
17use serde::{Deserialize, Serialize};
18use serde_json::Value;
19use tracing::debug;
20use tracing::instrument;
21use tracing::warn;
22
23#[pyclass]
24#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
25pub enum ChatResponse {
26    OpenAI(OpenAIChatResponse),
27    Gemini(GenerateContentResponse),
28}
29
30#[pymethods]
31impl ChatResponse {
32    pub fn to_py<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, AgentError> {
33        // try unwrapping the prompt, if it exists
34        match self {
35            ChatResponse::OpenAI(resp) => Ok(resp.clone().into_bound_py_any(py)?),
36            ChatResponse::Gemini(resp) => Ok(resp.clone().into_bound_py_any(py)?),
37        }
38    }
39    pub fn __str__(&self) -> String {
40        match self {
41            ChatResponse::OpenAI(resp) => PyHelperFuncs::__str__(resp),
42            ChatResponse::Gemini(resp) => PyHelperFuncs::__str__(resp),
43        }
44    }
45}
46
47impl ChatResponse {
48    pub fn is_empty(&self) -> bool {
49        match self {
50            ChatResponse::OpenAI(resp) => resp.choices.is_empty(),
51            ChatResponse::Gemini(resp) => resp.candidates.is_empty(),
52        }
53    }
54
55    #[instrument(skip_all)]
56    pub fn to_message(&self, role: Role) -> Result<Vec<Message>, AgentError> {
57        debug!("Converting chat response to message with role");
58        match self {
59            ChatResponse::OpenAI(resp) => {
60                let first_choice = resp
61                    .choices
62                    .first()
63                    .ok_or_else(|| AgentError::ClientNoResponseError)?;
64
65                let content =
66                    PromptContent::Str(first_choice.message.content.clone().unwrap_or_default());
67
68                Ok(vec![Message::from(content, role)])
69            }
70
71            ChatResponse::Gemini(resp) => {
72                let content = resp
73                    .candidates
74                    .first()
75                    .ok_or_else(|| AgentError::ClientNoResponseError)?
76                    .content
77                    .parts
78                    .first()
79                    .and_then(|part| part.text.as_ref())
80                    .map(|s| s.as_str())
81                    .unwrap_or("")
82                    .to_string();
83
84                Ok(vec![Message::from(PromptContent::Str(content), role)])
85            }
86        }
87    }
88
89    pub fn to_python<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, AgentError> {
90        match self {
91            ChatResponse::OpenAI(resp) => Ok(resp.clone().into_bound_py_any(py)?),
92            ChatResponse::Gemini(resp) => Ok(resp.clone().into_bound_py_any(py)?),
93        }
94    }
95
96    pub fn id(&self) -> String {
97        match self {
98            ChatResponse::OpenAI(resp) => resp.id.clone(),
99            ChatResponse::Gemini(resp) => resp.response_id.clone().unwrap_or("".to_string()),
100        }
101    }
102
103    /// Get the content of the first choice in the chat response
104    pub fn content(&self) -> Option<String> {
105        match self {
106            ChatResponse::OpenAI(resp) => {
107                resp.choices.first().and_then(|c| c.message.content.clone())
108            }
109            ChatResponse::Gemini(resp) => resp
110                .candidates
111                .first()
112                .and_then(|c| c.content.parts.first())
113                .and_then(|part| part.text.as_ref().map(|s| s.to_string())),
114        }
115    }
116
117    /// Check for tool calls in the chat response
118    pub fn tool_calls(&self) -> Option<Value> {
119        match self {
120            ChatResponse::OpenAI(resp) => {
121                let tool_calls: Option<&Vec<ToolCall>> =
122                    resp.choices.first().map(|c| c.message.tool_calls.as_ref());
123                tool_calls.and_then(|tc| serde_json::to_value(tc).ok())
124            }
125            ChatResponse::Gemini(resp) => {
126                // Collect all function calls from all parts in the first candidate
127                let function_calls: Vec<&FunctionCall> = resp
128                    .candidates
129                    .first()?
130                    .content
131                    .parts
132                    .iter()
133                    .filter_map(|part| part.function_call.as_ref())
134                    .collect();
135
136                if function_calls.is_empty() {
137                    None
138                } else {
139                    serde_json::to_value(&function_calls).ok()
140                }
141            }
142        }
143    }
144
145    /// Extracts structured data from a chat response
146    pub fn extract_structured_data(&self) -> Option<Value> {
147        if let Some(content) = self.content() {
148            serde_json::from_str(&content).ok()
149        } else {
150            self.tool_calls()
151        }
152    }
153}
154
155#[pyclass]
156#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
157pub struct AgentResponse {
158    pub id: String,
159    pub response: ChatResponse,
160}
161
162#[pymethods]
163impl AgentResponse {
164    pub fn token_usage(&self) -> Usage {
165        match &self.response {
166            ChatResponse::OpenAI(resp) => resp.usage.clone(),
167            ChatResponse::Gemini(resp) => resp.get_token_usage(),
168        }
169    }
170
171    //pub fn result<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, AgentError> {
172    //    let pyobj = json_to_pyobject(py, &self.content())?;
173    //
174    //    // Convert plain string output to Python string
175    //    Ok(pyobj.into_bound_py_any(py)?)
176    //}
177}
178
179impl AgentResponse {
180    pub fn new(id: String, response: ChatResponse) -> Self {
181        Self { id, response }
182    }
183
184    pub fn content(&self) -> Option<String> {
185        match &self.response {
186            ChatResponse::OpenAI(resp) => resp.get_content(),
187            ChatResponse::Gemini(resp) => resp.get_content(),
188        }
189    }
190
191    pub fn log_probs(&self) -> Vec<ResponseLogProbs> {
192        match &self.response {
193            ChatResponse::OpenAI(resp) => resp.get_log_probs(),
194            ChatResponse::Gemini(resp) => resp.get_log_probs(),
195        }
196    }
197}
198
199#[pyclass(name = "AgentResponse")]
200#[derive(Debug, Serialize)]
201pub struct PyAgentResponse {
202    pub response: AgentResponse,
203
204    #[serde(skip_serializing)]
205    pub output_type: Option<PyObject>,
206
207    #[pyo3(get)]
208    pub failed_conversion: bool,
209}
210
211#[pymethods]
212impl PyAgentResponse {
213    #[getter]
214    pub fn id(&self) -> &str {
215        &self.response.id
216    }
217
218    #[getter]
219    pub fn token_usage(&self) -> Usage {
220        self.response.token_usage()
221    }
222
223    #[getter]
224    pub fn log_probs(&self) -> LogProbs {
225        LogProbs {
226            tokens: self.response.log_probs(),
227        }
228    }
229
230    /// This will map a the content of the response to a python object.
231    /// A python object in this case will be either a passed pydantic model or support potatohead types.
232    /// If neither is porvided, an attempt is made to parse the serde Value into an appropriate Python type.
233    /// Types:
234    /// - Serde Null -> Python None
235    /// - Serde Bool -> Python bool
236    /// - Serde String -> Python str
237    /// - Serde Number -> Python int or float
238    /// - Serde Array -> Python list (with each item converted to Python type)
239    /// - Serde Object -> Python dict (with each key-value pair converted to Python type)
240    #[getter]
241    #[instrument(skip_all)]
242    pub fn result<'py>(&mut self, py: Python<'py>) -> Result<Bound<'py, PyAny>, AgentError> {
243        let content_value = self.response.content();
244
245        // If the content is None, return None
246        if content_value.is_none() {
247            return Ok(py.None().into_bound_py_any(py)?);
248        }
249        // convert content_value to string
250        let content_value = content_value.unwrap();
251
252        match &self.output_type {
253            Some(output_type) => {
254                // Match the value. For loading into pydantic models, it's expected that the api response is a JSON string.
255
256                let bound = output_type
257                    .bind(py)
258                    .call_method1("model_validate_json", (&content_value,));
259
260                match bound {
261                    Ok(obj) => {
262                        // Successfully validated the model
263                        Ok(obj)
264                    }
265                    Err(err) => {
266                        // Model validation failed
267                        // convert string to json and then to python object
268                        warn!("Failed to validate model: {}", err);
269                        self.failed_conversion = true;
270                        let val = serde_json::from_str::<Value>(&content_value)?;
271                        Ok(json_to_pyobject(py, &val)?.into_bound_py_any(py)?)
272                    }
273                }
274            }
275            None => {
276                // If no output type is provided, attempt to parse the content as JSON
277                let val = Value::String(content_value);
278                Ok(json_to_pyobject(py, &val)?.into_bound_py_any(py)?)
279            }
280        }
281    }
282}
283
284impl PyAgentResponse {
285    pub fn new(response: AgentResponse, output_type: Option<PyObject>) -> Self {
286        Self {
287            response,
288            output_type,
289            failed_conversion: false,
290        }
291    }
292}