potato_type/
lib.rs

1pub mod error;
2
3pub use crate::error::TypeError;
4use pyo3::prelude::*;
5use schemars::JsonSchema;
6use serde::de::Error;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::any::type_name;
10use std::fmt;
11use std::fmt::Display;
12use std::path::{Path, PathBuf};
13use tracing::error;
14pub mod google;
15pub mod openai;
16
17#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
18#[pyclass]
19pub enum Model {
20    Undefined,
21}
22
23impl Model {
24    pub fn as_str(&self) -> &str {
25        match self {
26            Model::Undefined => "undefined",
27        }
28    }
29
30    pub fn from_string(s: &str) -> Result<Self, TypeError> {
31        match s.to_lowercase().as_str() {
32            "undefined" => Ok(Model::Undefined),
33            _ => Err(TypeError::UnknownModelError(s.to_string())),
34        }
35    }
36}
37
/// Enum whose single variant, `Undefined`, renders as `"undefined"`.
pub enum Common {
    Undefined,
}

impl Common {
    /// Lower-case textual form of the variant.
    pub fn as_str(&self) -> &str {
        match *self {
            Self::Undefined => "undefined",
        }
    }
}

impl Display for Common {
    /// Writes exactly the text returned by [`Common::as_str`].
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(self.as_str())
    }
}
55
56#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
57#[pyclass]
58pub enum Provider {
59    OpenAI,
60    Gemini,
61    Undefined, // Added Undefined for better error handling
62}
63
64impl Provider {
65    pub fn url(&self) -> &str {
66        match self {
67            Provider::OpenAI => "https://api.openai.com/v1",
68
69            //https://cloud.google.com/vertex-ai/generative-ai/docs/migrate/migrate-google-ai
70            Provider::Gemini => "https://generativelanguage.googleapis.com/v1beta/models",
71
72            Provider::Undefined => {
73                error!("Undefined provider URL requested");
74                "https://undefined.provider.url"
75            }
76        }
77    }
78
79    pub fn from_string(s: &str) -> Result<Self, TypeError> {
80        match s.to_lowercase().as_str() {
81            "openai" => Ok(Provider::OpenAI),
82            "gemini" => Ok(Provider::Gemini),
83            "undefined" => Ok(Provider::Undefined), // Handle undefined case
84            _ => Err(TypeError::UnknownProviderError(s.to_string())),
85        }
86    }
87
88    /// Extract provider from a PyAny object
89    ///
90    /// # Arguments
91    /// * `provider` - PyAny object
92    ///
93    /// # Returns
94    /// * `Result<Provider, AgentError>` - Result
95    ///
96    /// # Errors
97    /// * `AgentError` - Error
98    pub fn extract_provider(provider: &Bound<'_, PyAny>) -> Result<Provider, TypeError> {
99        match provider.is_instance_of::<Provider>() {
100            true => Ok(provider.extract::<Provider>().inspect_err(|e| {
101                error!("Failed to extract provider: {}", e);
102            })?),
103            false => {
104                let provider = provider.extract::<String>().unwrap();
105                Ok(Provider::from_string(&provider).inspect_err(|e| {
106                    error!("Failed to convert string to provider: {}", e);
107                })?)
108            }
109        }
110    }
111
112    pub fn as_str(&self) -> &str {
113        match self {
114            Provider::OpenAI => "openai",
115            Provider::Gemini => "gemini",
116            Provider::Undefined => "undefined", // Added Undefined case
117        }
118    }
119}
120
121impl Display for Provider {
122    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
123        write!(f, "{}", self.as_str())
124    }
125}
126
127#[pyclass(eq, eq_int)]
128#[derive(Debug, PartialEq, Clone)]
129pub enum SaveName {
130    Prompt,
131}
132
133#[pymethods]
134impl SaveName {
135    #[staticmethod]
136    pub fn from_string(s: &str) -> Option<Self> {
137        match s {
138            "prompt" => Some(SaveName::Prompt),
139
140            _ => None,
141        }
142    }
143
144    pub fn as_string(&self) -> &str {
145        match self {
146            SaveName::Prompt => "prompt",
147        }
148    }
149
150    pub fn __str__(&self) -> String {
151        self.to_string()
152    }
153}
154
155impl Display for SaveName {
156    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
157        write!(f, "{}", self.as_string())
158    }
159}
160
161impl AsRef<Path> for SaveName {
162    fn as_ref(&self) -> &Path {
163        match self {
164            SaveName::Prompt => Path::new("prompt"),
165        }
166    }
167}
168
169// impl PathBuf: From<SaveName>
170impl From<SaveName> for PathBuf {
171    fn from(save_name: SaveName) -> Self {
172        PathBuf::from(save_name.as_ref())
173    }
174}
175
176/// A trait for structured output types that can be used with potatohead prompts agents and workflows.
177///
178/// # Example
179/// ```rust
180/// use potato_macros::StructureOutput;
181/// use serde::{Serialize, Deserialize};
182/// use schemars::JsonSchema;
183///
184/// #[derive(Serialize, Deserialize, JsonSchema)]
185/// struct MyOutput {
186///     message: String,
187///     value: i32,
188/// }
189///
190/// impl StructuredOutput for MyOutput {}
191///
192/// let schema = MyOutput::get_structured_output_schema();
193/// ```
194pub trait StructuredOutput: for<'de> serde::Deserialize<'de> + JsonSchema {
195    fn type_name() -> &'static str {
196        type_name::<Self>().rsplit("::").next().unwrap_or("Unknown")
197    }
198
199    /// Validates and deserializes a JSON value into its struct type.
200    ///
201    /// # Arguments
202    /// * `value` - The JSON value to deserialize
203    ///
204    /// # Returns
205    /// * `Result<Self, serde_json::Error>` - The deserialized value or error
206    fn model_validate_json_value(value: &Value) -> Result<Self, serde_json::Error> {
207        match &value {
208            Value::String(json_str) => Self::model_validate_json_str(json_str),
209            Value::Object(_) => {
210                // Content is already a JSON object
211                serde_json::from_value(value.clone())
212            }
213            _ => {
214                // If the value is not a string or object, we cannot deserialize it
215                Err(Error::custom("Expected a JSON string or object"))
216            }
217        }
218    }
219
220    fn model_validate_json_str(value: &str) -> Result<Self, serde_json::Error> {
221        serde_json::from_str(value)
222    }
223
224    /// Generates an OpenAI-compatible JSON schema.
225    ///
226    /// # Returns
227    /// * `Value` - The JSON schema wrapped in OpenAI's format
228    fn get_structured_output_schema() -> Value {
229        let schema = ::schemars::schema_for!(Self);
230        schema.into()
231    }
232    // add fallback parsing logic
233}
234
235#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
236#[pyclass]
237pub enum SettingsType {
238    GoogleChat,
239    OpenAIChat,
240    ModelSettings,
241}