Skip to main content

potato_type/prompt/
settings.rs

1use crate::anthropic::v1::request::AnthropicSettings;
2use crate::error::TypeError;
3use crate::{
4    google::v1::generate::request::GeminiSettings, openai::v1::chat::settings::OpenAIChatSettings,
5};
6use crate::{Provider, SettingsType};
7use potato_util::PyHelperFuncs;
8use pyo3::prelude::*;
9use pyo3::IntoPyObjectExt;
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12
/// Provider-specific model settings wrapped in a single type.
///
/// Each variant holds the settings struct for one provider family.
// `untagged`: serde serializes the inner settings without any variant tag, and when
// deserializing it tries the variants in declaration order, keeping the first one that
// parses. Variant order is therefore significant — do not reorder the variants without
// confirming each settings type's JSON can still be told apart.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
// NOTE(review): presumably allowed because one settings struct is much larger than the
// others; boxing the large variants would shrink the enum — confirm with `size_of` first.
#[allow(clippy::large_enum_variant)]
pub enum ModelSettings {
    /// Settings for the OpenAI chat API.
    OpenAIChat(OpenAIChatSettings),
    /// Gemini settings; also the variant accepted for the Vertex and Google providers.
    GoogleChat(GeminiSettings),
    /// Settings for the Anthropic API.
    AnthropicChat(AnthropicSettings),
}
22
23impl Default for ModelSettings {
24    fn default() -> Self {
25        ModelSettings::OpenAIChat(OpenAIChatSettings::default())
26    }
27}
28
29#[pymethods]
30impl ModelSettings {
31    #[new]
32    pub fn new(settings: &Bound<'_, PyAny>) -> Result<Self, TypeError> {
33        potatohead_macro::try_extract_py_object!(
34            settings,
35            OpenAIChatSettings => ModelSettings::OpenAIChat,
36            GeminiSettings => ModelSettings::GoogleChat,
37            AnthropicSettings => ModelSettings::AnthropicChat,
38        );
39
40        // If none matched, return error
41        Err(TypeError::InvalidModelSettings)
42    }
43
44    #[getter]
45    pub fn settings<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
46        match self {
47            ModelSettings::OpenAIChat(settings) => {
48                Ok(Py::new(py, settings.clone())?.into_bound_py_any(py)?)
49            }
50            ModelSettings::GoogleChat(settings) => {
51                Ok(Py::new(py, settings.clone())?.into_bound_py_any(py)?)
52            }
53            ModelSettings::AnthropicChat(settings) => {
54                Ok(Py::new(py, settings.clone())?.into_bound_py_any(py)?)
55            }
56        }
57    }
58
59    pub fn model_dump_json(&self) -> String {
60        serde_json::to_string(self).unwrap()
61    }
62
63    pub fn model_dump<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
64        match self {
65            ModelSettings::OpenAIChat(settings) => Ok(settings.model_dump(py)?),
66            ModelSettings::GoogleChat(settings) => Ok(settings.model_dump(py)?),
67            ModelSettings::AnthropicChat(settings) => Ok(settings.model_dump(py)?),
68        }
69    }
70
71    pub fn settings_type(&self) -> SettingsType {
72        SettingsType::ModelSettings
73    }
74
75    pub fn __str__(&self) -> String {
76        PyHelperFuncs::__str__(self)
77    }
78}
79
80impl ModelSettings {
81    pub fn validate_provider(&self, provider: &Provider) -> Result<(), TypeError> {
82        match provider {
83            Provider::OpenAI => match self {
84                ModelSettings::OpenAIChat(_) => Ok(()),
85                _ => Err(TypeError::InvalidModelSettings),
86            },
87            Provider::Gemini => match self {
88                ModelSettings::GoogleChat(_) => Ok(()),
89                _ => Err(TypeError::InvalidModelSettings),
90            },
91            Provider::Vertex => match self {
92                ModelSettings::GoogleChat(_) => Ok(()),
93                _ => Err(TypeError::InvalidModelSettings),
94            },
95            Provider::Google => match self {
96                ModelSettings::GoogleChat(_) => Ok(()),
97                _ => Err(TypeError::InvalidModelSettings),
98            },
99            Provider::Anthropic => match self {
100                ModelSettings::AnthropicChat(_) => Ok(()),
101                _ => Err(TypeError::InvalidModelSettings),
102            },
103            Provider::Undefined => match self {
104                ModelSettings::OpenAIChat(_) => Ok(()),
105                ModelSettings::GoogleChat(_) => Ok(()),
106                ModelSettings::AnthropicChat(_) => Ok(()),
107            },
108        }
109    }
110
111    pub fn provider_default_settings(provider: &Provider) -> Self {
112        match provider {
113            Provider::OpenAI => ModelSettings::OpenAIChat(OpenAIChatSettings::default()),
114            Provider::Gemini => ModelSettings::GoogleChat(GeminiSettings::default()),
115            _ => ModelSettings::OpenAIChat(OpenAIChatSettings::default()), // Fallback to OpenAI settings
116        }
117    }
118
119    pub fn get_openai_settings(&self) -> Option<OpenAIChatSettings> {
120        match self {
121            ModelSettings::OpenAIChat(settings) => {
122                let mut cloned_settings = settings.clone();
123                // set extra body to None
124                cloned_settings.extra_body = None;
125                Some(cloned_settings)
126            }
127            _ => None,
128        }
129    }
130
131    pub fn get_gemini_settings(&self) -> Option<GeminiSettings> {
132        match self {
133            ModelSettings::GoogleChat(settings) => {
134                let mut cloned_settings = settings.clone();
135                // set extra body to None
136                cloned_settings.extra_body = None;
137                Some(cloned_settings)
138            }
139            _ => None,
140        }
141    }
142
143    pub fn get_anthropic_settings(&self) -> AnthropicSettings {
144        match self {
145            ModelSettings::AnthropicChat(settings) => {
146                let mut cloned_settings = settings.clone();
147                // set extra body to None
148                cloned_settings.extra_body = None;
149                cloned_settings
150            }
151            _ => AnthropicSettings::default(),
152        }
153    }
154
155    pub fn extra_body(&self) -> Option<&Value> {
156        match self {
157            ModelSettings::OpenAIChat(settings) => settings.extra_body.as_ref(),
158            ModelSettings::GoogleChat(settings) => settings.extra_body.as_ref(),
159            ModelSettings::AnthropicChat(settings) => settings.extra_body.as_ref(),
160        }
161    }
162
163    pub fn provider(&self) -> Provider {
164        match self {
165            ModelSettings::OpenAIChat(_) => Provider::OpenAI,
166            ModelSettings::GoogleChat(_) => Provider::Gemini,
167            ModelSettings::AnthropicChat(_) => Provider::Anthropic,
168        }
169    }
170}