potato_agent/agents/
types.rs1use crate::agents::error::AgentError;
2use potato_provider::ChatResponse;
3use potato_util::utils::ResponseLogProbs;
4use potato_util::utils::TokenLogProbs;
5use potato_util::PyHelperFuncs;
6use pyo3::prelude::*;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use tracing::instrument;
10use tracing::warn;
11
12#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
13pub struct AgentResponse {
14 pub id: String,
15 pub response: ChatResponse,
16}
17
18impl AgentResponse {
19 pub fn token_usage<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, AgentError> {
20 Ok(self.response.token_usage(py)?)
21 }
22
23 pub fn response<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, AgentError> {
25 Ok(self.response.to_bound_py_object(py)?)
26 }
27}
28
29impl AgentResponse {
30 pub fn new(id: String, response: ChatResponse) -> Self {
31 Self { id, response }
32 }
33
34 pub fn log_probs(&self) -> Vec<TokenLogProbs> {
35 self.response.get_log_probs()
36 }
37
38 pub fn structured_output<'py>(
39 &self,
40 py: Python<'py>,
41 output_type: Option<&Bound<'py, PyAny>>,
42 ) -> Result<Bound<'py, PyAny>, AgentError> {
43 Ok(self.response.structured_output(py, output_type)?)
44 }
45
46 pub fn response_text(&self) -> String {
47 self.response.response_text()
48 }
49
50 pub fn response_value(&self) -> Option<Value> {
51 self.response.extract_structured_data()
52 }
53}
54
55#[pyclass(name = "AgentResponse")]
56#[derive(Debug, Serialize)]
57pub struct PyAgentResponse {
58 pub inner: AgentResponse,
59
60 #[serde(skip_serializing)]
61 pub output_type: Option<Py<PyAny>>,
62
63 #[pyo3(get)]
64 pub failed_conversion: bool,
65}
66
67#[pymethods]
68impl PyAgentResponse {
69 #[getter]
70 pub fn id(&self) -> &str {
71 &self.inner.id
72 }
73
74 #[getter]
76 pub fn token_usage<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, AgentError> {
77 self.inner.token_usage(py)
78 }
79
80 #[getter]
82 pub fn response<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, AgentError> {
83 self.inner.response(py)
84 }
85
86 #[getter]
87 pub fn log_probs(&self) -> ResponseLogProbs {
88 ResponseLogProbs {
89 tokens: self.inner.log_probs(),
90 }
91 }
92
93 #[getter]
94 #[instrument(skip_all)]
95 pub fn structured_output<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, AgentError> {
96 let bound = self
97 .output_type
98 .as_ref()
99 .map(|output_type| output_type.bind(py));
100 self.inner.structured_output(py, bound)
101 }
102
103 pub fn __str__(&self) -> String {
104 PyHelperFuncs::__str__(self)
105 }
106
107 pub fn response_text(&self) -> String {
108 self.inner.response_text()
109 }
110
111 #[classmethod]
112 pub fn __class_getitem__<'a>(
113 cls: Bound<'a, pyo3::types::PyType>,
114 item: &'a Bound<'a, PyAny>,
115 ) -> PyResult<Bound<'a, PyAny>> {
116 let py = cls.py();
117 let types = py.import("types")?;
118 let generic_alias = types.getattr("GenericAlias")?;
119 generic_alias.call1((cls, item))
120 }
121}
122
123impl PyAgentResponse {
124 pub fn new(response: AgentResponse, output_type: Option<Py<PyAny>>) -> Self {
125 Self {
126 inner: response,
127 output_type,
128 failed_conversion: false,
129 }
130 }
131}