potato_agent/agents/
types.rs1use crate::agents::error::AgentError;
2use potato_provider::ChatResponse;
3use potato_provider::LogProbExt;
4use potato_provider::ResponseExt;
5use potato_provider::Usage;
6use potato_util::json_to_pyobject;
7use potato_util::utils::{LogProbs, ResponseLogProbs};
8use pyo3::prelude::*;
9use pyo3::IntoPyObjectExt;
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use tracing::instrument;
13use tracing::warn;
14
15#[pyclass]
16#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
17pub struct AgentResponse {
18 pub id: String,
19 pub response: ChatResponse,
20}
21
22#[pymethods]
23impl AgentResponse {
24 pub fn token_usage(&self) -> Result<Usage, AgentError> {
25 match &self.response {
26 ChatResponse::OpenAI(resp) => Ok(resp.usage.clone()),
27 ChatResponse::Gemini(resp) => Ok(resp.get_token_usage()),
28 ChatResponse::VertexGenerate(resp) => Ok(resp.get_token_usage()),
29 _ => Err(AgentError::NotSupportedError(
30 "Token usage not supported for the vertex predict response type".to_string(),
31 )),
32 }
33 }
34
35 }
42
43impl AgentResponse {
44 pub fn new(id: String, response: ChatResponse) -> Self {
45 Self { id, response }
46 }
47
48 pub fn content(&self) -> Option<String> {
49 match &self.response {
50 ChatResponse::OpenAI(resp) => resp.get_content(),
51 ChatResponse::Gemini(resp) => resp.get_content(),
52 ChatResponse::VertexGenerate(resp) => resp.get_content(),
53 _ => {
54 warn!("Content not available for this response type");
55 None
56 }
57 }
58 }
59
60 pub fn log_probs(&self) -> Vec<ResponseLogProbs> {
61 match &self.response {
62 ChatResponse::OpenAI(resp) => resp.get_log_probs(),
63 ChatResponse::Gemini(resp) => resp.get_log_probs(),
64 ChatResponse::VertexGenerate(resp) => resp.get_log_probs(),
65 _ => {
66 warn!("Log probabilities not available for this response type");
67 vec![]
68 }
69 }
70 }
71}
72
73#[pyclass(name = "AgentResponse")]
74#[derive(Debug, Serialize)]
75pub struct PyAgentResponse {
76 pub response: AgentResponse,
77
78 #[serde(skip_serializing)]
79 pub output_type: Option<Py<PyAny>>,
80
81 #[pyo3(get)]
82 pub failed_conversion: bool,
83}
84
85#[pymethods]
86impl PyAgentResponse {
87 #[getter]
88 pub fn id(&self) -> &str {
89 &self.response.id
90 }
91
92 #[getter]
93 pub fn token_usage(&self) -> Result<Usage, AgentError> {
94 self.response.token_usage()
95 }
96
97 #[getter]
98 pub fn log_probs(&self) -> LogProbs {
99 LogProbs {
100 tokens: self.response.log_probs(),
101 }
102 }
103
104 #[getter]
115 #[instrument(skip_all)]
116 pub fn result<'py>(&mut self, py: Python<'py>) -> Result<Bound<'py, PyAny>, AgentError> {
117 let content_value = self.response.content();
118
119 if content_value.is_none() {
121 return Ok(py.None().into_bound_py_any(py)?);
122 }
123 let content_value = content_value.unwrap();
125
126 match &self.output_type {
127 Some(output_type) => {
128 let bound = output_type
131 .bind(py)
132 .call_method1("model_validate_json", (&content_value,));
133
134 match bound {
135 Ok(obj) => {
136 Ok(obj)
138 }
139 Err(err) => {
140 warn!("Failed to validate model: {}", err);
143 self.failed_conversion = true;
144 let val = serde_json::from_str::<Value>(&content_value)?;
145 Ok(json_to_pyobject(py, &val)?.into_bound_py_any(py)?)
146 }
147 }
148 }
149 None => {
150 let val = Value::String(content_value);
152 Ok(json_to_pyobject(py, &val)?.into_bound_py_any(py)?)
153 }
154 }
155 }
156}
157
158impl PyAgentResponse {
159 pub fn new(response: AgentResponse, output_type: Option<Py<PyAny>>) -> Self {
160 Self {
161 response,
162 output_type,
163 failed_conversion: false,
164 }
165 }
166}