1pub mod error;
2
3pub use crate::error::TypeError;
4use pyo3::prelude::*;
5use schemars::JsonSchema;
6use serde::de::Error;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::any::type_name;
10use std::fmt;
11use std::fmt::Display;
12use std::path::{Path, PathBuf};
13use tracing::error;
14pub mod anthropic;
15pub mod common;
16pub mod google;
17pub mod openai;
18pub mod prompt;
19pub mod tools;
20pub mod traits;
21
/// Model identifier, exposed to Python through `#[pyclass]`.
///
/// The string form of each variant is defined by `Model::as_str` and
/// accepted (case-insensitively) by `Model::from_string`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[pyclass]
pub enum Model {
    /// The only variant declared here; its string form is `"undefined"`.
    Undefined,
}
27
28impl Model {
29 pub fn as_str(&self) -> &str {
30 match self {
31 Model::Undefined => "undefined",
32 }
33 }
34
35 pub fn from_string(s: &str) -> Result<Self, TypeError> {
36 match s.to_lowercase().as_str() {
37 "undefined" => Ok(Model::Undefined),
38 _ => Err(TypeError::UnknownModelError(s.to_string())),
39 }
40 }
41}
42
/// Enum with a single `Undefined` variant, whose string form is
/// `"undefined"` (see `Common::as_str`).
///
/// Derives `Debug`, `Clone`, `PartialEq` and `Eq` so values can be logged,
/// duplicated and compared like the other public enums in this file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Common {
    /// The only variant declared here.
    Undefined,
}
46
47impl Common {
48 pub fn as_str(&self) -> &str {
49 match self {
50 Common::Undefined => "undefined",
51 }
52 }
53}
54
55impl Display for Common {
56 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
57 write!(f, "{}", self.as_str())
58 }
59}
60
/// Provider identifier, exposed to Python through `#[pyclass(eq, eq_int)]`.
///
/// Every variant has a lowercase string form that `Provider::as_str`
/// returns and `Provider::from_string` accepts (case-insensitively).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[pyclass(eq, eq_int)]
pub enum Provider {
    /// String form: `"openai"`.
    OpenAI,
    /// String form: `"gemini"`.
    Gemini,
    /// String form: `"google"`.
    Google,
    /// String form: `"vertex"`.
    Vertex,
    /// String form: `"anthropic"`.
    Anthropic,
    /// String form: `"undefined"`.
    Undefined,
}
71
72impl Provider {
73 pub fn from_string(s: &str) -> Result<Self, TypeError> {
74 match s.to_lowercase().as_str() {
75 "openai" => Ok(Provider::OpenAI),
76 "gemini" => Ok(Provider::Gemini),
77 "google" => Ok(Provider::Google),
78 "vertex" => Ok(Provider::Vertex),
79 "anthropic" => Ok(Provider::Anthropic),
80 "undefined" => Ok(Provider::Undefined), _ => Err(TypeError::UnknownProviderError(s.to_string())),
82 }
83 }
84
85 pub fn extract_provider(provider: &Bound<'_, PyAny>) -> Result<Provider, TypeError> {
96 match provider.is_instance_of::<Provider>() {
97 true => Ok(provider.extract::<Provider>().inspect_err(|e| {
98 error!("Failed to extract provider: {}", e);
99 })?),
100 false => {
101 let provider = provider.extract::<String>().unwrap();
102 Ok(Provider::from_string(&provider).inspect_err(|e| {
103 error!("Failed to convert string to provider: {}", e);
104 })?)
105 }
106 }
107 }
108
109 pub fn as_str(&self) -> &str {
110 match self {
111 Provider::OpenAI => "openai",
112 Provider::Gemini => "gemini",
113 Provider::Vertex => "vertex",
114 Provider::Google => "google",
115 Provider::Anthropic => "anthropic",
116 Provider::Undefined => "undefined", }
118 }
119}
120
121impl Display for Provider {
122 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
123 write!(f, "{}", self.as_str())
124 }
125}
126
/// Enum of save names, exposed to Python through `#[pyclass(eq, eq_int)]`.
///
/// The string form of each variant is defined by `SaveName::as_string`; the
/// same text is used when a value is viewed as a `Path`.
#[pyclass(eq, eq_int)]
#[derive(Debug, PartialEq, Clone)]
pub enum SaveName {
    /// The only variant declared here; its string form is `"prompt"`.
    Prompt,
}
132
133#[pymethods]
134impl SaveName {
135 #[staticmethod]
136 pub fn from_string(s: &str) -> Option<Self> {
137 match s {
138 "prompt" => Some(SaveName::Prompt),
139
140 _ => None,
141 }
142 }
143
144 pub fn as_string(&self) -> &str {
145 match self {
146 SaveName::Prompt => "prompt",
147 }
148 }
149
150 pub fn __str__(&self) -> String {
151 self.to_string()
152 }
153}
154
155impl Display for SaveName {
156 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
157 write!(f, "{}", self.as_string())
158 }
159}
160
161impl AsRef<Path> for SaveName {
162 fn as_ref(&self) -> &Path {
163 match self {
164 SaveName::Prompt => Path::new("prompt"),
165 }
166 }
167}
168
169impl From<SaveName> for PathBuf {
171 fn from(save_name: SaveName) -> Self {
172 PathBuf::from(save_name.as_ref())
173 }
174}
175
176pub trait StructuredOutput: for<'de> serde::Deserialize<'de> + JsonSchema {
195 fn type_name() -> &'static str {
196 type_name::<Self>().rsplit("::").next().unwrap_or("Unknown")
197 }
198
199 fn model_validate_json_value(value: &Value) -> Result<Self, serde_json::Error> {
207 match &value {
208 Value::String(json_str) => Self::model_validate_json_str(json_str),
209 Value::Object(_) => {
210 serde_json::from_value(value.clone())
212 }
213 _ => {
214 Err(Error::custom("Expected a JSON string or object"))
216 }
217 }
218 }
219
220 fn model_validate_json_str(value: &str) -> Result<Self, serde_json::Error> {
221 serde_json::from_str(value)
222 }
223
224 fn get_structured_output_schema() -> Value {
229 let schema = ::schemars::schema_for!(Self);
230 schema.into()
231 }
232 }
234
/// Enum of settings kinds, exposed to Python through `#[pyclass]`.
///
/// NOTE(review): presumably each variant tags which settings structure a
/// value carries; nothing in this part of the file shows how it is used, so
/// confirm against the callers.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[pyclass]
pub enum SettingsType {
    /// NOTE(review): presumably Google chat settings; confirm.
    GoogleChat,
    /// NOTE(review): presumably OpenAI chat settings; confirm.
    OpenAIChat,
    /// NOTE(review): presumably the generic model settings wrapper; confirm.
    ModelSettings,
    /// NOTE(review): presumably Anthropic settings; confirm.
    Anthropic,
}