potato_type/
lib.rs

1pub mod error;
2
3use crate::error::TypeError;
4use pyo3::prelude::*;
5use schemars::JsonSchema;
6use serde::de::Error;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::any::type_name;
10use std::fmt;
11use std::fmt::Display;
12use std::path::{Path, PathBuf};
13use tracing::error;
14
15#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
16#[pyclass]
17pub enum Model {
18    Undefined,
19}
20
21impl Model {
22    pub fn as_str(&self) -> &str {
23        match self {
24            Model::Undefined => "undefined",
25        }
26    }
27
28    pub fn from_string(s: &str) -> Result<Self, TypeError> {
29        match s.to_lowercase().as_str() {
30            "undefined" => Ok(Model::Undefined),
31            _ => Err(TypeError::UnknownModelError(s.to_string())),
32        }
33    }
34}
35
/// Shared placeholder values. Only `Undefined` is declared here.
///
/// Derives the same comparison/debug traits as the sibling enums in this file
/// so values can be logged, compared and used as map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Common {
    Undefined,
}

impl Common {
    /// Returns the lowercase string identifier of this value.
    pub fn as_str(&self) -> &str {
        match self {
            Self::Undefined => "undefined",
        }
    }
}

impl Display for Common {
    /// Writes the same text as [`Common::as_str`].
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(self.as_str())
    }
}
53
54#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
55#[pyclass]
56pub enum Provider {
57    OpenAI,
58    Gemini,
59    Undefined, // Added Undefined for better error handling
60}
61
62impl Provider {
63    pub fn url(&self) -> &str {
64        match self {
65            Provider::OpenAI => "https://api.openai.com/v1",
66            Provider::Gemini => "https://generativelanguage.googleapis.com/v1beta/models",
67            Provider::Undefined => {
68                error!("Undefined provider URL requested");
69                "https://undefined.provider.url"
70            }
71        }
72    }
73
74    pub fn from_string(s: &str) -> Result<Self, TypeError> {
75        match s.to_lowercase().as_str() {
76            "openai" => Ok(Provider::OpenAI),
77            "gemini" => Ok(Provider::Gemini),
78            "undefined" => Ok(Provider::Undefined), // Handle undefined case
79            _ => Err(TypeError::UnknownProviderError(s.to_string())),
80        }
81    }
82
83    /// Extract provider from a PyAny object
84    ///
85    /// # Arguments
86    /// * `provider` - PyAny object
87    ///
88    /// # Returns
89    /// * `Result<Provider, AgentError>` - Result
90    ///
91    /// # Errors
92    /// * `AgentError` - Error
93    pub fn extract_provider(provider: &Bound<'_, PyAny>) -> Result<Provider, TypeError> {
94        match provider.is_instance_of::<Provider>() {
95            true => Ok(provider.extract::<Provider>().inspect_err(|e| {
96                error!("Failed to extract provider: {}", e);
97            })?),
98            false => {
99                let provider = provider.extract::<String>().unwrap();
100                Ok(Provider::from_string(&provider).inspect_err(|e| {
101                    error!("Failed to convert string to provider: {}", e);
102                })?)
103            }
104        }
105    }
106
107    pub fn as_str(&self) -> &str {
108        match self {
109            Provider::OpenAI => "openai",
110            Provider::Gemini => "gemini",
111            Provider::Undefined => "undefined", // Added Undefined case
112        }
113    }
114}
115
116#[pyclass(eq, eq_int)]
117#[derive(Debug, PartialEq, Clone)]
118pub enum SaveName {
119    Prompt,
120}
121
122#[pymethods]
123impl SaveName {
124    #[staticmethod]
125    pub fn from_string(s: &str) -> Option<Self> {
126        match s {
127            "prompt" => Some(SaveName::Prompt),
128
129            _ => None,
130        }
131    }
132
133    pub fn as_string(&self) -> &str {
134        match self {
135            SaveName::Prompt => "prompt",
136        }
137    }
138
139    pub fn __str__(&self) -> String {
140        self.to_string()
141    }
142}
143
144impl Display for SaveName {
145    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
146        write!(f, "{}", self.as_string())
147    }
148}
149
150impl AsRef<Path> for SaveName {
151    fn as_ref(&self) -> &Path {
152        match self {
153            SaveName::Prompt => Path::new("prompt"),
154        }
155    }
156}
157
158// impl PathBuf: From<SaveName>
159impl From<SaveName> for PathBuf {
160    fn from(save_name: SaveName) -> Self {
161        PathBuf::from(save_name.as_ref())
162    }
163}
164
165/// A trait for structured output types that can be used with potatohead prompts agents and workflows.
166///
167/// # Example
168/// ```rust
169/// use potato_macros::StructureOutput;
170/// use serde::{Serialize, Deserialize};
171/// use schemars::JsonSchema;
172///
173/// #[derive(Serialize, Deserialize, JsonSchema)]
174/// struct MyOutput {
175///     message: String,
176///     value: i32,
177/// }
178///
179/// impl StructuredOutput for MyOutput {}
180///
181/// let schema = MyOutput::get_structured_output_schema();
182/// ```
183pub trait StructuredOutput: for<'de> serde::Deserialize<'de> + JsonSchema {
184    fn type_name() -> &'static str {
185        type_name::<Self>().rsplit("::").next().unwrap_or("Unknown")
186    }
187
188    /// Validates and deserializes a JSON value into its struct type.
189    ///
190    /// # Arguments
191    /// * `value` - The JSON value to deserialize
192    ///
193    /// # Returns
194    /// * `Result<Self, serde_json::Error>` - The deserialized value or error
195    fn model_validate_json_value(value: &Value) -> Result<Self, serde_json::Error> {
196        match &value {
197            Value::String(json_str) => Self::model_validate_json_str(json_str),
198            Value::Object(_) => {
199                // Content is already a JSON object
200                serde_json::from_value(value.clone())
201            }
202            _ => {
203                // If the value is not a string or object, we cannot deserialize it
204                Err(Error::custom("Expected a JSON string or object"))
205            }
206        }
207    }
208
209    fn model_validate_json_str(value: &str) -> Result<Self, serde_json::Error> {
210        serde_json::from_str(value)
211    }
212
213    /// Generates an OpenAI-compatible JSON schema.
214    ///
215    /// # Returns
216    /// * `Value` - The JSON schema wrapped in OpenAI's format
217    fn get_structured_output_schema() -> Value {
218        let schema = ::schemars::schema_for!(Self);
219        schema.into()
220    }
221    // add fallback parsing logic
222}