potato_type/
lib.rs

1pub mod error;
2
3pub use crate::error::TypeError;
4use pyo3::prelude::*;
5use schemars::JsonSchema;
6use serde::de::Error;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::any::type_name;
10use std::fmt;
11use std::fmt::Display;
12use std::path::{Path, PathBuf};
13use tracing::error;
14pub mod google;
15pub mod openai;
16
17#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
18#[pyclass]
19pub enum Model {
20    Undefined,
21}
22
23impl Model {
24    pub fn as_str(&self) -> &str {
25        match self {
26            Model::Undefined => "undefined",
27        }
28    }
29
30    pub fn from_string(s: &str) -> Result<Self, TypeError> {
31        match s.to_lowercase().as_str() {
32            "undefined" => Ok(Model::Undefined),
33            _ => Err(TypeError::UnknownModelError(s.to_string())),
34        }
35    }
36}
37
/// Common placeholder values.
///
/// Currently the only value is [`Common::Undefined`], rendered as `"undefined"`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Common {
    /// No value has been specified.
    Undefined,
}

impl Common {
    /// Returns the canonical lowercase name of this value.
    ///
    /// The name is a string literal, so the returned slice is `'static` and
    /// may outlive the `Common` value it was obtained from.
    pub fn as_str(&self) -> &'static str {
        match self {
            Common::Undefined => "undefined",
        }
    }
}

impl Display for Common {
    /// Writes the canonical name, honoring width, fill and alignment flags
    /// (e.g. `{:>12}`) supplied by the caller.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `write!(f, "{}", ..)` would start a fresh format spec and silently
        // drop the caller's padding options; `pad` applies them.
        f.pad(self.as_str())
    }
}
55
56#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
57#[pyclass]
58pub enum Provider {
59    OpenAI,
60    Gemini,
61    Google,
62    Vertex,
63    Undefined, // Added Undefined for better error handling
64}
65
66impl Provider {
67    pub fn from_string(s: &str) -> Result<Self, TypeError> {
68        match s.to_lowercase().as_str() {
69            "openai" => Ok(Provider::OpenAI),
70            "gemini" => Ok(Provider::Gemini),
71            "google" => Ok(Provider::Google),
72            "vertex" => Ok(Provider::Vertex),
73            "undefined" => Ok(Provider::Undefined), // Handle undefined case
74            _ => Err(TypeError::UnknownProviderError(s.to_string())),
75        }
76    }
77
78    /// Extract provider from a PyAny object
79    ///
80    /// # Arguments
81    /// * `provider` - PyAny object
82    ///
83    /// # Returns
84    /// * `Result<Provider, AgentError>` - Result
85    ///
86    /// # Errors
87    /// * `AgentError` - Error
88    pub fn extract_provider(provider: &Bound<'_, PyAny>) -> Result<Provider, TypeError> {
89        match provider.is_instance_of::<Provider>() {
90            true => Ok(provider.extract::<Provider>().inspect_err(|e| {
91                error!("Failed to extract provider: {}", e);
92            })?),
93            false => {
94                let provider = provider.extract::<String>().unwrap();
95                Ok(Provider::from_string(&provider).inspect_err(|e| {
96                    error!("Failed to convert string to provider: {}", e);
97                })?)
98            }
99        }
100    }
101
102    pub fn as_str(&self) -> &str {
103        match self {
104            Provider::OpenAI => "openai",
105            Provider::Gemini => "gemini",
106            Provider::Vertex => "vertex",
107            Provider::Google => "google",
108            Provider::Undefined => "undefined", // Added Undefined case
109        }
110    }
111}
112
113impl Display for Provider {
114    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
115        write!(f, "{}", self.as_str())
116    }
117}
118
119#[pyclass(eq, eq_int)]
120#[derive(Debug, PartialEq, Clone)]
121pub enum SaveName {
122    Prompt,
123}
124
125#[pymethods]
126impl SaveName {
127    #[staticmethod]
128    pub fn from_string(s: &str) -> Option<Self> {
129        match s {
130            "prompt" => Some(SaveName::Prompt),
131
132            _ => None,
133        }
134    }
135
136    pub fn as_string(&self) -> &str {
137        match self {
138            SaveName::Prompt => "prompt",
139        }
140    }
141
142    pub fn __str__(&self) -> String {
143        self.to_string()
144    }
145}
146
147impl Display for SaveName {
148    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
149        write!(f, "{}", self.as_string())
150    }
151}
152
153impl AsRef<Path> for SaveName {
154    fn as_ref(&self) -> &Path {
155        match self {
156            SaveName::Prompt => Path::new("prompt"),
157        }
158    }
159}
160
161// impl PathBuf: From<SaveName>
162impl From<SaveName> for PathBuf {
163    fn from(save_name: SaveName) -> Self {
164        PathBuf::from(save_name.as_ref())
165    }
166}
167
168/// A trait for structured output types that can be used with potatohead prompts agents and workflows.
169///
170/// # Example
171/// ```rust
172/// use potato_macros::StructureOutput;
173/// use serde::{Serialize, Deserialize};
174/// use schemars::JsonSchema;
175///
176/// #[derive(Serialize, Deserialize, JsonSchema)]
177/// struct MyOutput {
178///     message: String,
179///     value: i32,
180/// }
181///
182/// impl StructuredOutput for MyOutput {}
183///
184/// let schema = MyOutput::get_structured_output_schema();
185/// ```
186pub trait StructuredOutput: for<'de> serde::Deserialize<'de> + JsonSchema {
187    fn type_name() -> &'static str {
188        type_name::<Self>().rsplit("::").next().unwrap_or("Unknown")
189    }
190
191    /// Validates and deserializes a JSON value into its struct type.
192    ///
193    /// # Arguments
194    /// * `value` - The JSON value to deserialize
195    ///
196    /// # Returns
197    /// * `Result<Self, serde_json::Error>` - The deserialized value or error
198    fn model_validate_json_value(value: &Value) -> Result<Self, serde_json::Error> {
199        match &value {
200            Value::String(json_str) => Self::model_validate_json_str(json_str),
201            Value::Object(_) => {
202                // Content is already a JSON object
203                serde_json::from_value(value.clone())
204            }
205            _ => {
206                // If the value is not a string or object, we cannot deserialize it
207                Err(Error::custom("Expected a JSON string or object"))
208            }
209        }
210    }
211
212    fn model_validate_json_str(value: &str) -> Result<Self, serde_json::Error> {
213        serde_json::from_str(value)
214    }
215
216    /// Generates an OpenAI-compatible JSON schema.
217    ///
218    /// # Returns
219    /// * `Value` - The JSON schema wrapped in OpenAI's format
220    fn get_structured_output_schema() -> Value {
221        let schema = ::schemars::schema_for!(Self);
222        schema.into()
223    }
224    // add fallback parsing logic
225}
226
227#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
228#[pyclass]
229pub enum SettingsType {
230    GoogleChat,
231    OpenAIChat,
232    ModelSettings,
233}