1pub mod error;
2
3pub use crate::error::TypeError;
4use pyo3::prelude::*;
5use schemars::JsonSchema;
6use serde::de::Error;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::any::type_name;
10use std::fmt;
11use std::fmt::Display;
12use std::path::{Path, PathBuf};
13use tracing::error;
14pub mod google;
15pub mod openai;
16
17#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
18#[pyclass]
19pub enum Model {
20 Undefined,
21}
22
23impl Model {
24 pub fn as_str(&self) -> &str {
25 match self {
26 Model::Undefined => "undefined",
27 }
28 }
29
30 pub fn from_string(s: &str) -> Result<Self, TypeError> {
31 match s.to_lowercase().as_str() {
32 "undefined" => Ok(Model::Undefined),
33 _ => Err(TypeError::UnknownModelError(s.to_string())),
34 }
35 }
36}
37
/// Shared placeholder values.
///
/// Derives the usual value-type traits so it can be debug-printed, compared,
/// copied and hashed like the other enums in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Common {
    /// Placeholder variant; its string form is `"undefined"`.
    Undefined,
}

impl Common {
    /// Returns the lowercase string form of this value.
    ///
    /// The slice is `'static`, so callers may keep it after `self` is dropped.
    pub fn as_str(&self) -> &'static str {
        match self {
            Common::Undefined => "undefined",
        }
    }
}

impl Display for Common {
    /// Writes the same text as [`Common::as_str`].
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{}", self.as_str())
    }
}
55
56#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
57#[pyclass]
58pub enum Provider {
59 OpenAI,
60 Gemini,
61 Google,
62 Vertex,
63 Undefined, }
65
66impl Provider {
67 pub fn from_string(s: &str) -> Result<Self, TypeError> {
68 match s.to_lowercase().as_str() {
69 "openai" => Ok(Provider::OpenAI),
70 "gemini" => Ok(Provider::Gemini),
71 "google" => Ok(Provider::Google),
72 "vertex" => Ok(Provider::Vertex),
73 "undefined" => Ok(Provider::Undefined), _ => Err(TypeError::UnknownProviderError(s.to_string())),
75 }
76 }
77
78 pub fn extract_provider(provider: &Bound<'_, PyAny>) -> Result<Provider, TypeError> {
89 match provider.is_instance_of::<Provider>() {
90 true => Ok(provider.extract::<Provider>().inspect_err(|e| {
91 error!("Failed to extract provider: {}", e);
92 })?),
93 false => {
94 let provider = provider.extract::<String>().unwrap();
95 Ok(Provider::from_string(&provider).inspect_err(|e| {
96 error!("Failed to convert string to provider: {}", e);
97 })?)
98 }
99 }
100 }
101
102 pub fn as_str(&self) -> &str {
103 match self {
104 Provider::OpenAI => "openai",
105 Provider::Gemini => "gemini",
106 Provider::Vertex => "vertex",
107 Provider::Google => "google",
108 Provider::Undefined => "undefined", }
110 }
111}
112
113impl Display for Provider {
114 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
115 write!(f, "{}", self.as_str())
116 }
117}
118
119#[pyclass(eq, eq_int)]
120#[derive(Debug, PartialEq, Clone)]
121pub enum SaveName {
122 Prompt,
123}
124
125#[pymethods]
126impl SaveName {
127 #[staticmethod]
128 pub fn from_string(s: &str) -> Option<Self> {
129 match s {
130 "prompt" => Some(SaveName::Prompt),
131
132 _ => None,
133 }
134 }
135
136 pub fn as_string(&self) -> &str {
137 match self {
138 SaveName::Prompt => "prompt",
139 }
140 }
141
142 pub fn __str__(&self) -> String {
143 self.to_string()
144 }
145}
146
147impl Display for SaveName {
148 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
149 write!(f, "{}", self.as_string())
150 }
151}
152
153impl AsRef<Path> for SaveName {
154 fn as_ref(&self) -> &Path {
155 match self {
156 SaveName::Prompt => Path::new("prompt"),
157 }
158 }
159}
160
161impl From<SaveName> for PathBuf {
163 fn from(save_name: SaveName) -> Self {
164 PathBuf::from(save_name.as_ref())
165 }
166}
167
168pub trait StructuredOutput: for<'de> serde::Deserialize<'de> + JsonSchema {
187 fn type_name() -> &'static str {
188 type_name::<Self>().rsplit("::").next().unwrap_or("Unknown")
189 }
190
191 fn model_validate_json_value(value: &Value) -> Result<Self, serde_json::Error> {
199 match &value {
200 Value::String(json_str) => Self::model_validate_json_str(json_str),
201 Value::Object(_) => {
202 serde_json::from_value(value.clone())
204 }
205 _ => {
206 Err(Error::custom("Expected a JSON string or object"))
208 }
209 }
210 }
211
212 fn model_validate_json_str(value: &str) -> Result<Self, serde_json::Error> {
213 serde_json::from_str(value)
214 }
215
216 fn get_structured_output_schema() -> Value {
221 let schema = ::schemars::schema_for!(Self);
222 schema.into()
223 }
224 }
226
227#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
228#[pyclass]
229pub enum SettingsType {
230 GoogleChat,
231 OpenAIChat,
232 ModelSettings,
233}