Skip to main content

potato_agent/agents/
types.rs

1use crate::agents::error::AgentError;
2use potato_provider::ChatResponse;
3use potato_util::utils::ResponseLogProbs;
4use potato_util::utils::TokenLogProbs;
5use potato_util::PyHelperFuncs;
6use pyo3::prelude::*;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use tracing::instrument;
10use tracing::warn;
11
12#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
13pub struct AgentResponse {
14    pub id: String,
15    pub response: ChatResponse,
16}
17
18impl AgentResponse {
19    pub fn token_usage<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, AgentError> {
20        Ok(self.response.token_usage(py)?)
21    }
22
23    /// Returns the response as a Python object
24    pub fn response<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, AgentError> {
25        Ok(self.response.to_bound_py_object(py)?)
26    }
27}
28
29impl AgentResponse {
30    pub fn new(id: String, response: ChatResponse) -> Self {
31        Self { id, response }
32    }
33
34    pub fn log_probs(&self) -> Vec<TokenLogProbs> {
35        self.response.get_log_probs()
36    }
37
38    pub fn structured_output<'py>(
39        &self,
40        py: Python<'py>,
41        output_type: Option<&Bound<'py, PyAny>>,
42    ) -> Result<Bound<'py, PyAny>, AgentError> {
43        Ok(self.response.structured_output(py, output_type)?)
44    }
45
46    pub fn response_text(&self) -> String {
47        self.response.response_text()
48    }
49
50    pub fn response_value(&self) -> Option<Value> {
51        self.response.extract_structured_data()
52    }
53}
54
55#[pyclass(name = "AgentResponse")]
56#[derive(Debug, Serialize)]
57pub struct PyAgentResponse {
58    pub inner: AgentResponse,
59
60    #[serde(skip_serializing)]
61    pub output_type: Option<Py<PyAny>>,
62
63    #[pyo3(get)]
64    pub failed_conversion: bool,
65}
66
67#[pymethods]
68impl PyAgentResponse {
69    #[getter]
70    pub fn id(&self) -> &str {
71        &self.inner.id
72    }
73
74    /// Return the token usage of the response
75    #[getter]
76    pub fn token_usage<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, AgentError> {
77        self.inner.token_usage(py)
78    }
79
80    /// Returns the actual response object from the provider
81    #[getter]
82    pub fn response<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, AgentError> {
83        self.inner.response(py)
84    }
85
86    #[getter]
87    pub fn log_probs(&self) -> ResponseLogProbs {
88        ResponseLogProbs {
89            tokens: self.inner.log_probs(),
90        }
91    }
92
93    #[getter]
94    #[instrument(skip_all)]
95    pub fn structured_output<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, AgentError> {
96        let bound = self
97            .output_type
98            .as_ref()
99            .map(|output_type| output_type.bind(py));
100        self.inner.structured_output(py, bound)
101    }
102
103    pub fn __str__(&self) -> String {
104        PyHelperFuncs::__str__(self)
105    }
106
107    pub fn response_text(&self) -> String {
108        self.inner.response_text()
109    }
110
111    #[classmethod]
112    pub fn __class_getitem__<'a>(
113        cls: Bound<'a, pyo3::types::PyType>,
114        item: &'a Bound<'a, PyAny>,
115    ) -> PyResult<Bound<'a, PyAny>> {
116        let py = cls.py();
117        let types = py.import("types")?;
118        let generic_alias = types.getattr("GenericAlias")?;
119        generic_alias.call1((cls, item))
120    }
121}
122
123impl PyAgentResponse {
124    pub fn new(response: AgentResponse, output_type: Option<Py<PyAny>>) -> Self {
125        Self {
126            inner: response,
127            output_type,
128            failed_conversion: false,
129        }
130    }
131}