potato_type/
lib.rs

1pub mod error;
2
3pub use crate::error::TypeError;
4use pyo3::prelude::*;
5use schemars::JsonSchema;
6use serde::de::Error;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::any::type_name;
10use std::fmt;
11use std::fmt::Display;
12use std::path::{Path, PathBuf};
13use tracing::error;
14pub mod google;
15pub mod openai;
16
17#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
18#[pyclass]
19pub enum Model {
20    Undefined,
21}
22
23impl Model {
24    pub fn as_str(&self) -> &str {
25        match self {
26            Model::Undefined => "undefined",
27        }
28    }
29
30    pub fn from_string(s: &str) -> Result<Self, TypeError> {
31        match s.to_lowercase().as_str() {
32            "undefined" => Ok(Model::Undefined),
33            _ => Err(TypeError::UnknownModelError(s.to_string())),
34        }
35    }
36}
37
/// Common identifiers; currently only [`Common::Undefined`].
///
/// Derives the same comparison/hashing traits as the other enums in this module so
/// values can be debug-printed, compared, cloned and used as map/set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Common {
    /// Placeholder value; its string form is `"undefined"`.
    Undefined,
}
41
42impl Common {
43    pub fn as_str(&self) -> &str {
44        match self {
45            Common::Undefined => "undefined",
46        }
47    }
48}
49
50impl Display for Common {
51    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
52        write!(f, "{}", self.as_str())
53    }
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
57#[pyclass]
58pub enum Provider {
59    OpenAI,
60    Gemini,
61    Vertex,
62    Undefined, // Added Undefined for better error handling
63}
64
65impl Provider {
66    pub fn from_string(s: &str) -> Result<Self, TypeError> {
67        match s.to_lowercase().as_str() {
68            "openai" => Ok(Provider::OpenAI),
69            "gemini" => Ok(Provider::Gemini),
70            "vertex" => Ok(Provider::Vertex),
71            "undefined" => Ok(Provider::Undefined), // Handle undefined case
72            _ => Err(TypeError::UnknownProviderError(s.to_string())),
73        }
74    }
75
76    /// Extract provider from a PyAny object
77    ///
78    /// # Arguments
79    /// * `provider` - PyAny object
80    ///
81    /// # Returns
82    /// * `Result<Provider, AgentError>` - Result
83    ///
84    /// # Errors
85    /// * `AgentError` - Error
86    pub fn extract_provider(provider: &Bound<'_, PyAny>) -> Result<Provider, TypeError> {
87        match provider.is_instance_of::<Provider>() {
88            true => Ok(provider.extract::<Provider>().inspect_err(|e| {
89                error!("Failed to extract provider: {}", e);
90            })?),
91            false => {
92                let provider = provider.extract::<String>().unwrap();
93                Ok(Provider::from_string(&provider).inspect_err(|e| {
94                    error!("Failed to convert string to provider: {}", e);
95                })?)
96            }
97        }
98    }
99
100    pub fn as_str(&self) -> &str {
101        match self {
102            Provider::OpenAI => "openai",
103            Provider::Gemini => "gemini",
104            Provider::Vertex => "vertex",
105            Provider::Undefined => "undefined", // Added Undefined case
106        }
107    }
108}
109
110impl Display for Provider {
111    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
112        write!(f, "{}", self.as_str())
113    }
114}
115
116#[pyclass(eq, eq_int)]
117#[derive(Debug, PartialEq, Clone)]
118pub enum SaveName {
119    Prompt,
120}
121
122#[pymethods]
123impl SaveName {
124    #[staticmethod]
125    pub fn from_string(s: &str) -> Option<Self> {
126        match s {
127            "prompt" => Some(SaveName::Prompt),
128
129            _ => None,
130        }
131    }
132
133    pub fn as_string(&self) -> &str {
134        match self {
135            SaveName::Prompt => "prompt",
136        }
137    }
138
139    pub fn __str__(&self) -> String {
140        self.to_string()
141    }
142}
143
144impl Display for SaveName {
145    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
146        write!(f, "{}", self.as_string())
147    }
148}
149
150impl AsRef<Path> for SaveName {
151    fn as_ref(&self) -> &Path {
152        match self {
153            SaveName::Prompt => Path::new("prompt"),
154        }
155    }
156}
157
158// impl PathBuf: From<SaveName>
159impl From<SaveName> for PathBuf {
160    fn from(save_name: SaveName) -> Self {
161        PathBuf::from(save_name.as_ref())
162    }
163}
164
165/// A trait for structured output types that can be used with potatohead prompts agents and workflows.
166///
167/// # Example
168/// ```rust
169/// use potato_macros::StructureOutput;
170/// use serde::{Serialize, Deserialize};
171/// use schemars::JsonSchema;
172///
173/// #[derive(Serialize, Deserialize, JsonSchema)]
174/// struct MyOutput {
175///     message: String,
176///     value: i32,
177/// }
178///
179/// impl StructuredOutput for MyOutput {}
180///
181/// let schema = MyOutput::get_structured_output_schema();
182/// ```
183pub trait StructuredOutput: for<'de> serde::Deserialize<'de> + JsonSchema {
184    fn type_name() -> &'static str {
185        type_name::<Self>().rsplit("::").next().unwrap_or("Unknown")
186    }
187
188    /// Validates and deserializes a JSON value into its struct type.
189    ///
190    /// # Arguments
191    /// * `value` - The JSON value to deserialize
192    ///
193    /// # Returns
194    /// * `Result<Self, serde_json::Error>` - The deserialized value or error
195    fn model_validate_json_value(value: &Value) -> Result<Self, serde_json::Error> {
196        match &value {
197            Value::String(json_str) => Self::model_validate_json_str(json_str),
198            Value::Object(_) => {
199                // Content is already a JSON object
200                serde_json::from_value(value.clone())
201            }
202            _ => {
203                // If the value is not a string or object, we cannot deserialize it
204                Err(Error::custom("Expected a JSON string or object"))
205            }
206        }
207    }
208
209    fn model_validate_json_str(value: &str) -> Result<Self, serde_json::Error> {
210        serde_json::from_str(value)
211    }
212
213    /// Generates an OpenAI-compatible JSON schema.
214    ///
215    /// # Returns
216    /// * `Value` - The JSON schema wrapped in OpenAI's format
217    fn get_structured_output_schema() -> Value {
218        let schema = ::schemars::schema_for!(Self);
219        schema.into()
220    }
221    // add fallback parsing logic
222}
223
224#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
225#[pyclass]
226pub enum SettingsType {
227    GoogleChat,
228    OpenAIChat,
229    ModelSettings,
230}