Skip to main content

potato_type/
lib.rs

1pub mod error;
2
3pub use crate::error::TypeError;
4use pyo3::prelude::*;
5use schemars::JsonSchema;
6use serde::de::Error;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::any::type_name;
10use std::fmt;
11use std::fmt::Display;
12use std::path::{Path, PathBuf};
13use tracing::error;
14pub mod anthropic;
15pub mod common;
16pub mod google;
17pub mod openai;
18pub mod prompt;
19pub mod tools;
20pub mod traits;
21
/// Model identifier exposed to Python as a `#[pyclass]`.
///
/// Currently only the placeholder `Undefined` variant exists; its string form
/// is produced by [`Model::as_str`] and parsed by [`Model::from_string`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[pyclass]
pub enum Model {
    /// Placeholder model; rendered as / parsed from `"undefined"`.
    Undefined,
}
27
28impl Model {
29    pub fn as_str(&self) -> &str {
30        match self {
31            Model::Undefined => "undefined",
32        }
33    }
34
35    pub fn from_string(s: &str) -> Result<Self, TypeError> {
36        match s.to_lowercase().as_str() {
37            "undefined" => Ok(Model::Undefined),
38            _ => Err(TypeError::UnknownModelError(s.to_string())),
39        }
40    }
41}
42
/// Placeholder enum with a single `Undefined` variant.
///
/// Derives the same standard traits as the crate's other enums so it can be
/// debugged, compared, hashed and copied like them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Common {
    /// Rendered as `"undefined"`.
    Undefined,
}

impl Common {
    /// Returns the lowercase string identifier for this value.
    pub fn as_str(&self) -> &str {
        match self {
            Common::Undefined => "undefined",
        }
    }
}

impl Display for Common {
    /// Writes the same text as [`Common::as_str`].
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{}", self.as_str())
    }
}
60
/// LLM provider identifier, exposed to Python as a `#[pyclass]`.
///
/// Variants use implicit discriminants in declaration order, and `eq_int`
/// lets Python compare them with integers — do not reorder variants.
/// String forms are produced by [`Provider::as_str`] and parsed by
/// [`Provider::from_string`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[pyclass(eq, eq_int)]
pub enum Provider {
    /// `"openai"`
    OpenAI,
    /// `"gemini"`
    Gemini,
    /// `"google"`
    Google,
    /// `"vertex"`
    Vertex,
    /// `"anthropic"`
    Anthropic,
    /// `"google_adk"`
    GoogleAdk,
    /// `"undefined"` — explicit "no provider" value rather than a parse failure.
    Undefined,
}
72
73impl Provider {
74    pub fn from_string(s: &str) -> Result<Self, TypeError> {
75        match s.to_lowercase().as_str() {
76            "openai" => Ok(Provider::OpenAI),
77            "gemini" => Ok(Provider::Gemini),
78            "google" => Ok(Provider::Google),
79            "vertex" => Ok(Provider::Vertex),
80            "anthropic" => Ok(Provider::Anthropic),
81            "google_adk" => Ok(Provider::GoogleAdk),
82            "undefined" => Ok(Provider::Undefined), // Handle undefined case
83            _ => Err(TypeError::UnknownProviderError(s.to_string())),
84        }
85    }
86
87    /// Extract provider from a PyAny object
88    ///
89    /// # Arguments
90    /// * `provider` - PyAny object
91    ///
92    /// # Returns
93    /// * `Result<Provider, AgentError>` - Result
94    ///
95    /// # Errors
96    /// * `AgentError` - Error
97    pub fn extract_provider(provider: &Bound<'_, PyAny>) -> Result<Provider, TypeError> {
98        match provider.is_instance_of::<Provider>() {
99            true => Ok(provider.extract::<Provider>().inspect_err(|e| {
100                error!("Failed to extract provider: {}", e);
101            })?),
102            false => {
103                let provider = provider.extract::<String>().unwrap();
104                Ok(Provider::from_string(&provider).inspect_err(|e| {
105                    error!("Failed to convert string to provider: {}", e);
106                })?)
107            }
108        }
109    }
110
111    pub fn as_str(&self) -> &str {
112        match self {
113            Provider::OpenAI => "openai",
114            Provider::Gemini => "gemini",
115            Provider::Vertex => "vertex",
116            Provider::Google => "google",
117            Provider::Anthropic => "anthropic",
118            Provider::GoogleAdk => "google_adk",
119            Provider::Undefined => "undefined", // Added Undefined case
120        }
121    }
122}
123
124impl Display for Provider {
125    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
126        write!(f, "{}", self.as_str())
127    }
128}
129
/// Names of saved artifacts, exposed to Python as a `#[pyclass]`.
///
/// Each variant maps to a fixed string (see `SaveName::as_string`) that is
/// also used as a relative path via `AsRef<Path>` / `From<SaveName> for PathBuf`.
#[pyclass(eq, eq_int)]
#[derive(Debug, PartialEq, Clone)]
pub enum SaveName {
    /// `"prompt"`
    Prompt,
}
135
136#[pymethods]
137impl SaveName {
138    #[staticmethod]
139    pub fn from_string(s: &str) -> Option<Self> {
140        match s {
141            "prompt" => Some(SaveName::Prompt),
142
143            _ => None,
144        }
145    }
146
147    pub fn as_string(&self) -> &str {
148        match self {
149            SaveName::Prompt => "prompt",
150        }
151    }
152
153    pub fn __str__(&self) -> String {
154        self.to_string()
155    }
156}
157
158impl Display for SaveName {
159    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
160        write!(f, "{}", self.as_string())
161    }
162}
163
164impl AsRef<Path> for SaveName {
165    fn as_ref(&self) -> &Path {
166        match self {
167            SaveName::Prompt => Path::new("prompt"),
168        }
169    }
170}
171
172// impl PathBuf: From<SaveName>
173impl From<SaveName> for PathBuf {
174    fn from(save_name: SaveName) -> Self {
175        PathBuf::from(save_name.as_ref())
176    }
177}
178
179/// A trait for structured output types that can be used with potatohead prompts agents and workflows.
180///
181/// # Example
182/// ```rust
183/// use potatohead_macro::StructureOutput;
184/// use serde::{Serialize, Deserialize};
185/// use schemars::JsonSchema;
186///
187/// #[derive(Serialize, Deserialize, JsonSchema)]
188/// struct MyOutput {
189///     message: String,
190///     value: i32,
191/// }
192///
193/// impl StructuredOutput for MyOutput {}
194///
195/// let schema = MyOutput::get_structured_output_schema();
196/// ```
197pub trait StructuredOutput: for<'de> serde::Deserialize<'de> + JsonSchema {
198    fn type_name() -> &'static str {
199        type_name::<Self>().rsplit("::").next().unwrap_or("Unknown")
200    }
201
202    /// Validates and deserializes a JSON value into its struct type.
203    ///
204    /// # Arguments
205    /// * `value` - The JSON value to deserialize
206    ///
207    /// # Returns
208    /// * `Result<Self, serde_json::Error>` - The deserialized value or error
209    fn model_validate_json_value(value: &Value) -> Result<Self, serde_json::Error> {
210        match &value {
211            Value::String(json_str) => Self::model_validate_json_str(json_str),
212            Value::Object(_) => {
213                // Content is already a JSON object
214                serde_json::from_value(value.clone())
215            }
216            _ => {
217                // If the value is not a string or object, we cannot deserialize it
218                Err(Error::custom("Expected a JSON string or object"))
219            }
220        }
221    }
222
223    fn model_validate_json_str(value: &str) -> Result<Self, serde_json::Error> {
224        serde_json::from_str(value)
225    }
226
227    /// Generates an OpenAI-compatible JSON schema.
228    ///
229    /// # Returns
230    /// * `Value` - The JSON schema wrapped in OpenAI's format
231    fn get_structured_output_schema() -> Value {
232        let schema = ::schemars::schema_for!(Self);
233        schema.into()
234    }
235    // add fallback parsing logic
236}
237
/// Tag naming a kind of settings, exposed to Python as a `#[pyclass]`.
///
/// NOTE(review): presumably identifies which provider-specific settings type
/// is in use (Google chat, OpenAI chat, generic model settings, Anthropic) —
/// confirm against callers.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[pyclass]
pub enum SettingsType {
    GoogleChat,
    OpenAIChat,
    ModelSettings,
    Anthropic,
}
246
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_google_adk_round_trip() {
        let p = Provider::from_string("google_adk").unwrap();
        assert_eq!(p, Provider::GoogleAdk);
        assert_eq!(p.as_str(), "google_adk");
    }

    #[test]
    fn test_provider_all_variants_round_trip() {
        for (s, variant) in [
            ("openai", Provider::OpenAI),
            ("gemini", Provider::Gemini),
            ("google", Provider::Google),
            ("vertex", Provider::Vertex),
            ("anthropic", Provider::Anthropic),
            ("google_adk", Provider::GoogleAdk),
            ("undefined", Provider::Undefined),
        ] {
            let parsed = Provider::from_string(s).unwrap();
            assert_eq!(parsed, variant);
            assert_eq!(parsed.as_str(), s);
        }
    }

    #[test]
    fn test_provider_unknown_string_errors() {
        assert!(Provider::from_string("not_a_provider").is_err());
    }

    /// `Provider::from_string` lowercases its input before matching.
    #[test]
    fn test_provider_from_string_is_case_insensitive() {
        assert_eq!(Provider::from_string("OpenAI").unwrap(), Provider::OpenAI);
        assert_eq!(
            Provider::from_string("GOOGLE_ADK").unwrap(),
            Provider::GoogleAdk
        );
    }

    #[test]
    fn test_provider_display_matches_as_str() {
        assert_eq!(Provider::Anthropic.to_string(), "anthropic");
        assert_eq!(Provider::Undefined.to_string(), Provider::Undefined.as_str());
    }

    #[test]
    fn test_model_round_trip() {
        let m = Model::from_string("UNDEFINED").unwrap();
        assert_eq!(m, Model::Undefined);
        assert_eq!(m.as_str(), "undefined");
        assert!(Model::from_string("not_a_model").is_err());
    }

    /// `SaveName::from_string` is case-sensitive, unlike the other parsers.
    #[test]
    fn test_save_name_round_trip_and_path() {
        assert_eq!(SaveName::from_string("prompt"), Some(SaveName::Prompt));
        assert_eq!(SaveName::from_string("Prompt"), None);
        assert_eq!(SaveName::Prompt.to_string(), "prompt");
        assert_eq!(PathBuf::from(SaveName::Prompt), PathBuf::from("prompt"));
    }
}
275    #[test]
276    fn test_provider_unknown_string_errors() {
277        assert!(Provider::from_string("not_a_provider").is_err());
278    }
279}