1pub mod error;
2
3pub use crate::error::TypeError;
4use pyo3::prelude::*;
5use schemars::JsonSchema;
6use serde::de::Error;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::any::type_name;
10use std::fmt;
11use std::fmt::Display;
12use std::path::{Path, PathBuf};
13use tracing::error;
14pub mod anthropic;
15pub mod common;
16pub mod google;
17pub mod openai;
18pub mod prompt;
19pub mod tools;
20pub mod traits;
21
22#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
23#[pyclass]
24pub enum Model {
25 Undefined,
26}
27
28impl Model {
29 pub fn as_str(&self) -> &str {
30 match self {
31 Model::Undefined => "undefined",
32 }
33 }
34
35 pub fn from_string(s: &str) -> Result<Self, TypeError> {
36 match s.to_lowercase().as_str() {
37 "undefined" => Ok(Model::Undefined),
38 _ => Err(TypeError::UnknownModelError(s.to_string())),
39 }
40 }
41}
42
/// Generic sentinel values; currently only `Undefined`.
///
/// Derives the usual value-type traits so it can be debug-printed, compared,
/// copied and used as a map key like the other public enums in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Common {
    Undefined,
}

impl Common {
    /// Returns the canonical lowercase string for this value.
    pub fn as_str(&self) -> &str {
        match self {
            Common::Undefined => "undefined",
        }
    }
}

impl Display for Common {
    /// Writes the same text as [`Common::as_str`].
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{}", self.as_str())
    }
}
60
61#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
62#[pyclass(eq, eq_int)]
63pub enum Provider {
64 OpenAI,
65 Gemini,
66 Google,
67 Vertex,
68 Anthropic,
69 GoogleAdk,
70 Undefined, }
72
73impl Provider {
74 pub fn from_string(s: &str) -> Result<Self, TypeError> {
75 match s.to_lowercase().as_str() {
76 "openai" => Ok(Provider::OpenAI),
77 "gemini" => Ok(Provider::Gemini),
78 "google" => Ok(Provider::Google),
79 "vertex" => Ok(Provider::Vertex),
80 "anthropic" => Ok(Provider::Anthropic),
81 "google_adk" => Ok(Provider::GoogleAdk),
82 "undefined" => Ok(Provider::Undefined), _ => Err(TypeError::UnknownProviderError(s.to_string())),
84 }
85 }
86
87 pub fn extract_provider(provider: &Bound<'_, PyAny>) -> Result<Provider, TypeError> {
98 match provider.is_instance_of::<Provider>() {
99 true => Ok(provider.extract::<Provider>().inspect_err(|e| {
100 error!("Failed to extract provider: {}", e);
101 })?),
102 false => {
103 let provider = provider.extract::<String>().unwrap();
104 Ok(Provider::from_string(&provider).inspect_err(|e| {
105 error!("Failed to convert string to provider: {}", e);
106 })?)
107 }
108 }
109 }
110
111 pub fn as_str(&self) -> &str {
112 match self {
113 Provider::OpenAI => "openai",
114 Provider::Gemini => "gemini",
115 Provider::Vertex => "vertex",
116 Provider::Google => "google",
117 Provider::Anthropic => "anthropic",
118 Provider::GoogleAdk => "google_adk",
119 Provider::Undefined => "undefined", }
121 }
122}
123
124impl Display for Provider {
125 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
126 write!(f, "{}", self.as_str())
127 }
128}
129
130#[pyclass(eq, eq_int)]
131#[derive(Debug, PartialEq, Clone)]
132pub enum SaveName {
133 Prompt,
134}
135
136#[pymethods]
137impl SaveName {
138 #[staticmethod]
139 pub fn from_string(s: &str) -> Option<Self> {
140 match s {
141 "prompt" => Some(SaveName::Prompt),
142
143 _ => None,
144 }
145 }
146
147 pub fn as_string(&self) -> &str {
148 match self {
149 SaveName::Prompt => "prompt",
150 }
151 }
152
153 pub fn __str__(&self) -> String {
154 self.to_string()
155 }
156}
157
158impl Display for SaveName {
159 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
160 write!(f, "{}", self.as_string())
161 }
162}
163
164impl AsRef<Path> for SaveName {
165 fn as_ref(&self) -> &Path {
166 match self {
167 SaveName::Prompt => Path::new("prompt"),
168 }
169 }
170}
171
172impl From<SaveName> for PathBuf {
174 fn from(save_name: SaveName) -> Self {
175 PathBuf::from(save_name.as_ref())
176 }
177}
178
179pub trait StructuredOutput: for<'de> serde::Deserialize<'de> + JsonSchema {
198 fn type_name() -> &'static str {
199 type_name::<Self>().rsplit("::").next().unwrap_or("Unknown")
200 }
201
202 fn model_validate_json_value(value: &Value) -> Result<Self, serde_json::Error> {
210 match &value {
211 Value::String(json_str) => Self::model_validate_json_str(json_str),
212 Value::Object(_) => {
213 serde_json::from_value(value.clone())
215 }
216 _ => {
217 Err(Error::custom("Expected a JSON string or object"))
219 }
220 }
221 }
222
223 fn model_validate_json_str(value: &str) -> Result<Self, serde_json::Error> {
224 serde_json::from_str(value)
225 }
226
227 fn get_structured_output_schema() -> Value {
232 let schema = ::schemars::schema_for!(Self);
233 schema.into()
234 }
235 }
237
238#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
239#[pyclass]
240pub enum SettingsType {
241 GoogleChat,
242 OpenAIChat,
243 ModelSettings,
244 Anthropic,
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    /// `"google_adk"` parses to `GoogleAdk` and formats back unchanged.
    #[test]
    fn test_provider_google_adk_round_trip() {
        let provider = Provider::from_string("google_adk").expect("google_adk should parse");
        assert_eq!(provider, Provider::GoogleAdk);
        assert_eq!(provider.as_str(), "google_adk");
    }

    /// Every canonical name parses to its variant, and `as_str` returns it verbatim.
    #[test]
    fn test_provider_all_variants_round_trip() {
        let cases = [
            ("openai", Provider::OpenAI),
            ("gemini", Provider::Gemini),
            ("google", Provider::Google),
            ("vertex", Provider::Vertex),
            ("anthropic", Provider::Anthropic),
            ("google_adk", Provider::GoogleAdk),
            ("undefined", Provider::Undefined),
        ];
        for (name, expected) in cases {
            let parsed = Provider::from_string(name)
                .unwrap_or_else(|err| panic!("failed to parse {name:?}: {err}"));
            assert_eq!(parsed, expected);
            assert_eq!(parsed.as_str(), name);
        }
    }

    /// Names outside the known set are rejected.
    #[test]
    fn test_provider_unknown_string_errors() {
        let result = Provider::from_string("not_a_provider");
        assert!(result.is_err(), "unknown provider names must be rejected");
    }
}