potato_type/prompt/
settings.rs1use crate::anthropic::v1::request::AnthropicSettings;
2use crate::error::TypeError;
3use crate::{
4 google::v1::generate::request::GeminiSettings, openai::v1::chat::settings::OpenAIChatSettings,
5};
6use crate::{Provider, SettingsType};
7use potato_util::PyHelperFuncs;
8use pyo3::prelude::*;
9use pyo3::IntoPyObjectExt;
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12
13#[pyclass]
14#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
15#[serde(untagged)]
16#[allow(clippy::large_enum_variant)]
17pub enum ModelSettings {
18 OpenAIChat(OpenAIChatSettings),
19 GoogleChat(GeminiSettings),
20 AnthropicChat(AnthropicSettings),
21}
22
23impl Default for ModelSettings {
24 fn default() -> Self {
25 ModelSettings::OpenAIChat(OpenAIChatSettings::default())
26 }
27}
28
29#[pymethods]
30impl ModelSettings {
31 #[new]
32 pub fn new(settings: &Bound<'_, PyAny>) -> Result<Self, TypeError> {
33 potatohead_macro::try_extract_py_object!(
34 settings,
35 OpenAIChatSettings => ModelSettings::OpenAIChat,
36 GeminiSettings => ModelSettings::GoogleChat,
37 AnthropicSettings => ModelSettings::AnthropicChat,
38 );
39
40 Err(TypeError::InvalidModelSettings)
42 }
43
44 #[getter]
45 pub fn settings<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
46 match self {
47 ModelSettings::OpenAIChat(settings) => {
48 Ok(Py::new(py, settings.clone())?.into_bound_py_any(py)?)
49 }
50 ModelSettings::GoogleChat(settings) => {
51 Ok(Py::new(py, settings.clone())?.into_bound_py_any(py)?)
52 }
53 ModelSettings::AnthropicChat(settings) => {
54 Ok(Py::new(py, settings.clone())?.into_bound_py_any(py)?)
55 }
56 }
57 }
58
59 pub fn model_dump_json(&self) -> String {
60 serde_json::to_string(self).unwrap()
61 }
62
63 pub fn model_dump<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
64 match self {
65 ModelSettings::OpenAIChat(settings) => Ok(settings.model_dump(py)?),
66 ModelSettings::GoogleChat(settings) => Ok(settings.model_dump(py)?),
67 ModelSettings::AnthropicChat(settings) => Ok(settings.model_dump(py)?),
68 }
69 }
70
71 pub fn settings_type(&self) -> SettingsType {
72 SettingsType::ModelSettings
73 }
74
75 pub fn __str__(&self) -> String {
76 PyHelperFuncs::__str__(self)
77 }
78}
79
80impl ModelSettings {
81 pub fn validate_provider(&self, provider: &Provider) -> Result<(), TypeError> {
82 match provider {
83 Provider::OpenAI => match self {
84 ModelSettings::OpenAIChat(_) => Ok(()),
85 _ => Err(TypeError::InvalidModelSettings),
86 },
87 Provider::Gemini => match self {
88 ModelSettings::GoogleChat(_) => Ok(()),
89 _ => Err(TypeError::InvalidModelSettings),
90 },
91 Provider::Vertex => match self {
92 ModelSettings::GoogleChat(_) => Ok(()),
93 _ => Err(TypeError::InvalidModelSettings),
94 },
95 Provider::Google => match self {
96 ModelSettings::GoogleChat(_) => Ok(()),
97 _ => Err(TypeError::InvalidModelSettings),
98 },
99 Provider::Anthropic => match self {
100 ModelSettings::AnthropicChat(_) => Ok(()),
101 _ => Err(TypeError::InvalidModelSettings),
102 },
103 Provider::Undefined => match self {
104 ModelSettings::OpenAIChat(_) => Ok(()),
105 ModelSettings::GoogleChat(_) => Ok(()),
106 ModelSettings::AnthropicChat(_) => Ok(()),
107 },
108 }
109 }
110
111 pub fn provider_default_settings(provider: &Provider) -> Self {
112 match provider {
113 Provider::OpenAI => ModelSettings::OpenAIChat(OpenAIChatSettings::default()),
114 Provider::Gemini => ModelSettings::GoogleChat(GeminiSettings::default()),
115 _ => ModelSettings::OpenAIChat(OpenAIChatSettings::default()), }
117 }
118
119 pub fn get_openai_settings(&self) -> Option<OpenAIChatSettings> {
120 match self {
121 ModelSettings::OpenAIChat(settings) => {
122 let mut cloned_settings = settings.clone();
123 cloned_settings.extra_body = None;
125 Some(cloned_settings)
126 }
127 _ => None,
128 }
129 }
130
131 pub fn get_gemini_settings(&self) -> Option<GeminiSettings> {
132 match self {
133 ModelSettings::GoogleChat(settings) => {
134 let mut cloned_settings = settings.clone();
135 cloned_settings.extra_body = None;
137 Some(cloned_settings)
138 }
139 _ => None,
140 }
141 }
142
143 pub fn get_anthropic_settings(&self) -> AnthropicSettings {
144 match self {
145 ModelSettings::AnthropicChat(settings) => {
146 let mut cloned_settings = settings.clone();
147 cloned_settings.extra_body = None;
149 cloned_settings
150 }
151 _ => AnthropicSettings::default(),
152 }
153 }
154
155 pub fn extra_body(&self) -> Option<&Value> {
156 match self {
157 ModelSettings::OpenAIChat(settings) => settings.extra_body.as_ref(),
158 ModelSettings::GoogleChat(settings) => settings.extra_body.as_ref(),
159 ModelSettings::AnthropicChat(settings) => settings.extra_body.as_ref(),
160 }
161 }
162
163 pub fn provider(&self) -> Provider {
164 match self {
165 ModelSettings::OpenAIChat(_) => Provider::OpenAI,
166 ModelSettings::GoogleChat(_) => Provider::Gemini,
167 ModelSettings::AnthropicChat(_) => Provider::Anthropic,
168 }
169 }
170}