1pub mod error;
2
3pub use crate::error::TypeError;
4use pyo3::prelude::*;
5use schemars::JsonSchema;
6use serde::de::Error;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::any::type_name;
10use std::fmt;
11use std::fmt::Display;
12use std::path::{Path, PathBuf};
13use tracing::error;
14pub mod google;
15pub mod openai;
16
17#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
18#[pyclass]
19pub enum Model {
20 Undefined,
21}
22
23impl Model {
24 pub fn as_str(&self) -> &str {
25 match self {
26 Model::Undefined => "undefined",
27 }
28 }
29
30 pub fn from_string(s: &str) -> Result<Self, TypeError> {
31 match s.to_lowercase().as_str() {
32 "undefined" => Ok(Model::Undefined),
33 _ => Err(TypeError::UnknownModelError(s.to_string())),
34 }
35 }
36}
37
/// Enum with a single `Undefined` variant; its string form comes from `Common::as_str`.
///
/// Derives the standard comparison/debug traits so values can be logged and compared.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Common {
    /// Rendered as `"undefined"`.
    Undefined,
}
41
42impl Common {
43 pub fn as_str(&self) -> &str {
44 match self {
45 Common::Undefined => "undefined",
46 }
47 }
48}
49
50impl Display for Common {
51 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
52 write!(f, "{}", self.as_str())
53 }
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
57#[pyclass]
58pub enum Provider {
59 OpenAI,
60 Gemini,
61 Vertex,
62 Undefined, }
64
65impl Provider {
66 pub fn from_string(s: &str) -> Result<Self, TypeError> {
67 match s.to_lowercase().as_str() {
68 "openai" => Ok(Provider::OpenAI),
69 "gemini" => Ok(Provider::Gemini),
70 "vertex" => Ok(Provider::Vertex),
71 "undefined" => Ok(Provider::Undefined), _ => Err(TypeError::UnknownProviderError(s.to_string())),
73 }
74 }
75
76 pub fn extract_provider(provider: &Bound<'_, PyAny>) -> Result<Provider, TypeError> {
87 match provider.is_instance_of::<Provider>() {
88 true => Ok(provider.extract::<Provider>().inspect_err(|e| {
89 error!("Failed to extract provider: {}", e);
90 })?),
91 false => {
92 let provider = provider.extract::<String>().unwrap();
93 Ok(Provider::from_string(&provider).inspect_err(|e| {
94 error!("Failed to convert string to provider: {}", e);
95 })?)
96 }
97 }
98 }
99
100 pub fn as_str(&self) -> &str {
101 match self {
102 Provider::OpenAI => "openai",
103 Provider::Gemini => "gemini",
104 Provider::Vertex => "vertex",
105 Provider::Undefined => "undefined", }
107 }
108}
109
110impl Display for Provider {
111 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
112 write!(f, "{}", self.as_str())
113 }
114}
115
116#[pyclass(eq, eq_int)]
117#[derive(Debug, PartialEq, Clone)]
118pub enum SaveName {
119 Prompt,
120}
121
122#[pymethods]
123impl SaveName {
124 #[staticmethod]
125 pub fn from_string(s: &str) -> Option<Self> {
126 match s {
127 "prompt" => Some(SaveName::Prompt),
128
129 _ => None,
130 }
131 }
132
133 pub fn as_string(&self) -> &str {
134 match self {
135 SaveName::Prompt => "prompt",
136 }
137 }
138
139 pub fn __str__(&self) -> String {
140 self.to_string()
141 }
142}
143
144impl Display for SaveName {
145 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
146 write!(f, "{}", self.as_string())
147 }
148}
149
150impl AsRef<Path> for SaveName {
151 fn as_ref(&self) -> &Path {
152 match self {
153 SaveName::Prompt => Path::new("prompt"),
154 }
155 }
156}
157
158impl From<SaveName> for PathBuf {
160 fn from(save_name: SaveName) -> Self {
161 PathBuf::from(save_name.as_ref())
162 }
163}
164
165pub trait StructuredOutput: for<'de> serde::Deserialize<'de> + JsonSchema {
184 fn type_name() -> &'static str {
185 type_name::<Self>().rsplit("::").next().unwrap_or("Unknown")
186 }
187
188 fn model_validate_json_value(value: &Value) -> Result<Self, serde_json::Error> {
196 match &value {
197 Value::String(json_str) => Self::model_validate_json_str(json_str),
198 Value::Object(_) => {
199 serde_json::from_value(value.clone())
201 }
202 _ => {
203 Err(Error::custom("Expected a JSON string or object"))
205 }
206 }
207 }
208
209 fn model_validate_json_str(value: &str) -> Result<Self, serde_json::Error> {
210 serde_json::from_str(value)
211 }
212
213 fn get_structured_output_schema() -> Value {
218 let schema = ::schemars::schema_for!(Self);
219 schema.into()
220 }
221 }
223
224#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
225#[pyclass]
226pub enum SettingsType {
227 GoogleChat,
228 OpenAIChat,
229 ModelSettings,
230}