1pub mod error;
2
3pub use crate::error::TypeError;
4use pyo3::prelude::*;
5use schemars::JsonSchema;
6use serde::de::Error;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::any::type_name;
10use std::fmt;
11use std::fmt::Display;
12use std::path::{Path, PathBuf};
13use tracing::error;
14pub mod google;
15pub mod openai;
16
17#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
18#[pyclass]
19pub enum Model {
20 Undefined,
21}
22
23impl Model {
24 pub fn as_str(&self) -> &str {
25 match self {
26 Model::Undefined => "undefined",
27 }
28 }
29
30 pub fn from_string(s: &str) -> Result<Self, TypeError> {
31 match s.to_lowercase().as_str() {
32 "undefined" => Ok(Model::Undefined),
33 _ => Err(TypeError::UnknownModelError(s.to_string())),
34 }
35 }
36}
37
/// Enum with a single `Undefined` variant that renders as the string `"undefined"`.
pub enum Common {
    Undefined,
}

impl Common {
    /// Returns the lowercase string name of this variant.
    pub fn as_str(&self) -> &str {
        match *self {
            Self::Undefined => "undefined",
        }
    }
}

impl Display for Common {
    /// Writes the same text as [`Common::as_str`]; width and fill flags are not applied.
    fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
        out.write_str(self.as_str())
    }
}
55
56#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
57#[pyclass]
58pub enum Provider {
59 OpenAI,
60 Gemini,
61 Undefined, }
63
64impl Provider {
65 pub fn url(&self) -> &str {
66 match self {
67 Provider::OpenAI => "https://api.openai.com/v1",
68
69 Provider::Gemini => "https://generativelanguage.googleapis.com/v1beta/models",
71
72 Provider::Undefined => {
73 error!("Undefined provider URL requested");
74 "https://undefined.provider.url"
75 }
76 }
77 }
78
79 pub fn from_string(s: &str) -> Result<Self, TypeError> {
80 match s.to_lowercase().as_str() {
81 "openai" => Ok(Provider::OpenAI),
82 "gemini" => Ok(Provider::Gemini),
83 "undefined" => Ok(Provider::Undefined), _ => Err(TypeError::UnknownProviderError(s.to_string())),
85 }
86 }
87
88 pub fn extract_provider(provider: &Bound<'_, PyAny>) -> Result<Provider, TypeError> {
99 match provider.is_instance_of::<Provider>() {
100 true => Ok(provider.extract::<Provider>().inspect_err(|e| {
101 error!("Failed to extract provider: {}", e);
102 })?),
103 false => {
104 let provider = provider.extract::<String>().unwrap();
105 Ok(Provider::from_string(&provider).inspect_err(|e| {
106 error!("Failed to convert string to provider: {}", e);
107 })?)
108 }
109 }
110 }
111
112 pub fn as_str(&self) -> &str {
113 match self {
114 Provider::OpenAI => "openai",
115 Provider::Gemini => "gemini",
116 Provider::Undefined => "undefined", }
118 }
119}
120
121impl Display for Provider {
122 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
123 write!(f, "{}", self.as_str())
124 }
125}
126
127#[pyclass(eq, eq_int)]
128#[derive(Debug, PartialEq, Clone)]
129pub enum SaveName {
130 Prompt,
131}
132
133#[pymethods]
134impl SaveName {
135 #[staticmethod]
136 pub fn from_string(s: &str) -> Option<Self> {
137 match s {
138 "prompt" => Some(SaveName::Prompt),
139
140 _ => None,
141 }
142 }
143
144 pub fn as_string(&self) -> &str {
145 match self {
146 SaveName::Prompt => "prompt",
147 }
148 }
149
150 pub fn __str__(&self) -> String {
151 self.to_string()
152 }
153}
154
155impl Display for SaveName {
156 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
157 write!(f, "{}", self.as_string())
158 }
159}
160
161impl AsRef<Path> for SaveName {
162 fn as_ref(&self) -> &Path {
163 match self {
164 SaveName::Prompt => Path::new("prompt"),
165 }
166 }
167}
168
169impl From<SaveName> for PathBuf {
171 fn from(save_name: SaveName) -> Self {
172 PathBuf::from(save_name.as_ref())
173 }
174}
175
176pub trait StructuredOutput: for<'de> serde::Deserialize<'de> + JsonSchema {
195 fn type_name() -> &'static str {
196 type_name::<Self>().rsplit("::").next().unwrap_or("Unknown")
197 }
198
199 fn model_validate_json_value(value: &Value) -> Result<Self, serde_json::Error> {
207 match &value {
208 Value::String(json_str) => Self::model_validate_json_str(json_str),
209 Value::Object(_) => {
210 serde_json::from_value(value.clone())
212 }
213 _ => {
214 Err(Error::custom("Expected a JSON string or object"))
216 }
217 }
218 }
219
220 fn model_validate_json_str(value: &str) -> Result<Self, serde_json::Error> {
221 serde_json::from_str(value)
222 }
223
224 fn get_structured_output_schema() -> Value {
229 let schema = ::schemars::schema_for!(Self);
230 schema.into()
231 }
232 }
234
235#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
236#[pyclass]
237pub enum SettingsType {
238 GoogleChat,
239 OpenAIChat,
240 ModelSettings,
241}