potato_agent/agents/
types.rs1use crate::agents::error::AgentError;
2use crate::agents::provider::gemini::FunctionCall;
3use crate::agents::provider::gemini::GenerateContentResponse;
4use crate::agents::provider::openai::OpenAIChatResponse;
5use crate::agents::provider::openai::ToolCall;
6use crate::agents::provider::openai::Usage;
7use crate::agents::provider::traits::LogProbExt;
8use crate::agents::provider::traits::ResponseExt;
9use potato_prompt::{
10 prompt::{PromptContent, Role},
11 Message,
12};
13use potato_util::utils::{LogProbs, ResponseLogProbs};
14use potato_util::{json_to_pyobject, PyHelperFuncs};
15use pyo3::prelude::*;
16use pyo3::IntoPyObjectExt;
17use serde::{Deserialize, Serialize};
18use serde_json::Value;
19use tracing::debug;
20use tracing::instrument;
21use tracing::warn;
22
23#[pyclass]
24#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
25pub enum ChatResponse {
26 OpenAI(OpenAIChatResponse),
27 Gemini(GenerateContentResponse),
28}
29
30#[pymethods]
31impl ChatResponse {
32 pub fn to_py<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, AgentError> {
33 match self {
35 ChatResponse::OpenAI(resp) => Ok(resp.clone().into_bound_py_any(py)?),
36 ChatResponse::Gemini(resp) => Ok(resp.clone().into_bound_py_any(py)?),
37 }
38 }
39 pub fn __str__(&self) -> String {
40 match self {
41 ChatResponse::OpenAI(resp) => PyHelperFuncs::__str__(resp),
42 ChatResponse::Gemini(resp) => PyHelperFuncs::__str__(resp),
43 }
44 }
45}
46
47impl ChatResponse {
48 pub fn is_empty(&self) -> bool {
49 match self {
50 ChatResponse::OpenAI(resp) => resp.choices.is_empty(),
51 ChatResponse::Gemini(resp) => resp.candidates.is_empty(),
52 }
53 }
54
55 #[instrument(skip_all)]
56 pub fn to_message(&self, role: Role) -> Result<Vec<Message>, AgentError> {
57 debug!("Converting chat response to message with role");
58 match self {
59 ChatResponse::OpenAI(resp) => {
60 let first_choice = resp
61 .choices
62 .first()
63 .ok_or_else(|| AgentError::ClientNoResponseError)?;
64
65 let content =
66 PromptContent::Str(first_choice.message.content.clone().unwrap_or_default());
67
68 Ok(vec![Message::from(content, role)])
69 }
70
71 ChatResponse::Gemini(resp) => {
72 let content = resp
73 .candidates
74 .first()
75 .ok_or_else(|| AgentError::ClientNoResponseError)?
76 .content
77 .parts
78 .first()
79 .and_then(|part| part.text.as_ref())
80 .map(|s| s.as_str())
81 .unwrap_or("")
82 .to_string();
83
84 Ok(vec![Message::from(PromptContent::Str(content), role)])
85 }
86 }
87 }
88
89 pub fn to_python<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, AgentError> {
90 match self {
91 ChatResponse::OpenAI(resp) => Ok(resp.clone().into_bound_py_any(py)?),
92 ChatResponse::Gemini(resp) => Ok(resp.clone().into_bound_py_any(py)?),
93 }
94 }
95
96 pub fn id(&self) -> String {
97 match self {
98 ChatResponse::OpenAI(resp) => resp.id.clone(),
99 ChatResponse::Gemini(resp) => resp.response_id.clone().unwrap_or("".to_string()),
100 }
101 }
102
103 pub fn content(&self) -> Option<String> {
105 match self {
106 ChatResponse::OpenAI(resp) => {
107 resp.choices.first().and_then(|c| c.message.content.clone())
108 }
109 ChatResponse::Gemini(resp) => resp
110 .candidates
111 .first()
112 .and_then(|c| c.content.parts.first())
113 .and_then(|part| part.text.as_ref().map(|s| s.to_string())),
114 }
115 }
116
117 pub fn tool_calls(&self) -> Option<Value> {
119 match self {
120 ChatResponse::OpenAI(resp) => {
121 let tool_calls: Option<&Vec<ToolCall>> =
122 resp.choices.first().map(|c| c.message.tool_calls.as_ref());
123 tool_calls.and_then(|tc| serde_json::to_value(tc).ok())
124 }
125 ChatResponse::Gemini(resp) => {
126 let function_calls: Vec<&FunctionCall> = resp
128 .candidates
129 .first()?
130 .content
131 .parts
132 .iter()
133 .filter_map(|part| part.function_call.as_ref())
134 .collect();
135
136 if function_calls.is_empty() {
137 None
138 } else {
139 serde_json::to_value(&function_calls).ok()
140 }
141 }
142 }
143 }
144
145 pub fn extract_structured_data(&self) -> Option<Value> {
147 if let Some(content) = self.content() {
148 serde_json::from_str(&content).ok()
149 } else {
150 self.tool_calls()
151 }
152 }
153}
154
155#[pyclass]
156#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
157pub struct AgentResponse {
158 pub id: String,
159 pub response: ChatResponse,
160}
161
162#[pymethods]
163impl AgentResponse {
164 pub fn token_usage(&self) -> Usage {
165 match &self.response {
166 ChatResponse::OpenAI(resp) => resp.usage.clone(),
167 ChatResponse::Gemini(resp) => resp.get_token_usage(),
168 }
169 }
170
171 }
178
179impl AgentResponse {
180 pub fn new(id: String, response: ChatResponse) -> Self {
181 Self { id, response }
182 }
183
184 pub fn content(&self) -> Option<String> {
185 match &self.response {
186 ChatResponse::OpenAI(resp) => resp.get_content(),
187 ChatResponse::Gemini(resp) => resp.get_content(),
188 }
189 }
190
191 pub fn log_probs(&self) -> Vec<ResponseLogProbs> {
192 match &self.response {
193 ChatResponse::OpenAI(resp) => resp.get_log_probs(),
194 ChatResponse::Gemini(resp) => resp.get_log_probs(),
195 }
196 }
197}
198
199#[pyclass(name = "AgentResponse")]
200#[derive(Debug, Serialize)]
201pub struct PyAgentResponse {
202 pub response: AgentResponse,
203
204 #[serde(skip_serializing)]
205 pub output_type: Option<PyObject>,
206
207 #[pyo3(get)]
208 pub failed_conversion: bool,
209}
210
211#[pymethods]
212impl PyAgentResponse {
213 #[getter]
214 pub fn id(&self) -> &str {
215 &self.response.id
216 }
217
218 #[getter]
219 pub fn token_usage(&self) -> Usage {
220 self.response.token_usage()
221 }
222
223 #[getter]
224 pub fn log_probs(&self) -> LogProbs {
225 LogProbs {
226 tokens: self.response.log_probs(),
227 }
228 }
229
230 #[getter]
241 #[instrument(skip_all)]
242 pub fn result<'py>(&mut self, py: Python<'py>) -> Result<Bound<'py, PyAny>, AgentError> {
243 let content_value = self.response.content();
244
245 if content_value.is_none() {
247 return Ok(py.None().into_bound_py_any(py)?);
248 }
249 let content_value = content_value.unwrap();
251
252 match &self.output_type {
253 Some(output_type) => {
254 let bound = output_type
257 .bind(py)
258 .call_method1("model_validate_json", (&content_value,));
259
260 match bound {
261 Ok(obj) => {
262 Ok(obj)
264 }
265 Err(err) => {
266 warn!("Failed to validate model: {}", err);
269 self.failed_conversion = true;
270 let val = serde_json::from_str::<Value>(&content_value)?;
271 Ok(json_to_pyobject(py, &val)?.into_bound_py_any(py)?)
272 }
273 }
274 }
275 None => {
276 let val = Value::String(content_value);
278 Ok(json_to_pyobject(py, &val)?.into_bound_py_any(py)?)
279 }
280 }
281 }
282}
283
284impl PyAgentResponse {
285 pub fn new(response: AgentResponse, output_type: Option<PyObject>) -> Self {
286 Self {
287 response,
288 output_type,
289 failed_conversion: false,
290 }
291 }
292}