1pub mod error;
2
3use crate::error::TypeError;
4use pyo3::prelude::*;
5use schemars::JsonSchema;
6use serde::de::Error;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::any::type_name;
10use std::fmt;
11use std::fmt::Display;
12use std::path::{Path, PathBuf};
13use tracing::error;
14
15#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
16#[pyclass]
17pub enum Model {
18 Undefined,
19}
20
21impl Model {
22 pub fn as_str(&self) -> &str {
23 match self {
24 Model::Undefined => "undefined",
25 }
26 }
27
28 pub fn from_string(s: &str) -> Result<Self, TypeError> {
29 match s.to_lowercase().as_str() {
30 "undefined" => Ok(Model::Undefined),
31 _ => Err(TypeError::UnknownModelError(s.to_string())),
32 }
33 }
34}
35
/// Shared placeholder value; only `Undefined` is defined so far.
///
/// A fieldless enum, so it is cheap to copy and compare. The derives make it
/// debuggable and comparable like the other enums in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Common {
    Undefined,
}
39
40impl Common {
41 pub fn as_str(&self) -> &str {
42 match self {
43 Common::Undefined => "undefined",
44 }
45 }
46}
47
48impl Display for Common {
49 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
50 write!(f, "{}", self.as_str())
51 }
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
55#[pyclass]
56pub enum Provider {
57 OpenAI,
58 Gemini,
59 Undefined, }
61
62impl Provider {
63 pub fn url(&self) -> &str {
64 match self {
65 Provider::OpenAI => "https://api.openai.com/v1",
66 Provider::Gemini => "https://generativelanguage.googleapis.com/v1beta/models",
67 Provider::Undefined => {
68 error!("Undefined provider URL requested");
69 "https://undefined.provider.url"
70 }
71 }
72 }
73
74 pub fn from_string(s: &str) -> Result<Self, TypeError> {
75 match s.to_lowercase().as_str() {
76 "openai" => Ok(Provider::OpenAI),
77 "gemini" => Ok(Provider::Gemini),
78 "undefined" => Ok(Provider::Undefined), _ => Err(TypeError::UnknownProviderError(s.to_string())),
80 }
81 }
82
83 pub fn extract_provider(provider: &Bound<'_, PyAny>) -> Result<Provider, TypeError> {
94 match provider.is_instance_of::<Provider>() {
95 true => Ok(provider.extract::<Provider>().inspect_err(|e| {
96 error!("Failed to extract provider: {}", e);
97 })?),
98 false => {
99 let provider = provider.extract::<String>().unwrap();
100 Ok(Provider::from_string(&provider).inspect_err(|e| {
101 error!("Failed to convert string to provider: {}", e);
102 })?)
103 }
104 }
105 }
106
107 pub fn as_str(&self) -> &str {
108 match self {
109 Provider::OpenAI => "openai",
110 Provider::Gemini => "gemini",
111 Provider::Undefined => "undefined", }
113 }
114}
115
116#[pyclass(eq, eq_int)]
117#[derive(Debug, PartialEq, Clone)]
118pub enum SaveName {
119 Prompt,
120}
121
122#[pymethods]
123impl SaveName {
124 #[staticmethod]
125 pub fn from_string(s: &str) -> Option<Self> {
126 match s {
127 "prompt" => Some(SaveName::Prompt),
128
129 _ => None,
130 }
131 }
132
133 pub fn as_string(&self) -> &str {
134 match self {
135 SaveName::Prompt => "prompt",
136 }
137 }
138
139 pub fn __str__(&self) -> String {
140 self.to_string()
141 }
142}
143
144impl Display for SaveName {
145 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
146 write!(f, "{}", self.as_string())
147 }
148}
149
150impl AsRef<Path> for SaveName {
151 fn as_ref(&self) -> &Path {
152 match self {
153 SaveName::Prompt => Path::new("prompt"),
154 }
155 }
156}
157
158impl From<SaveName> for PathBuf {
160 fn from(save_name: SaveName) -> Self {
161 PathBuf::from(save_name.as_ref())
162 }
163}
164
165pub trait StructuredOutput: for<'de> serde::Deserialize<'de> + JsonSchema {
184 fn type_name() -> &'static str {
185 type_name::<Self>().rsplit("::").next().unwrap_or("Unknown")
186 }
187
188 fn model_validate_json_value(value: &Value) -> Result<Self, serde_json::Error> {
196 match &value {
197 Value::String(json_str) => Self::model_validate_json_str(json_str),
198 Value::Object(_) => {
199 serde_json::from_value(value.clone())
201 }
202 _ => {
203 Err(Error::custom("Expected a JSON string or object"))
205 }
206 }
207 }
208
209 fn model_validate_json_str(value: &str) -> Result<Self, serde_json::Error> {
210 serde_json::from_str(value)
211 }
212
213 fn get_structured_output_schema() -> Value {
218 let schema = ::schemars::schema_for!(Self);
219 schema.into()
220 }
221 }