// potato_type/google/chat.rs

1use crate::{SettingsType, TypeError};
2use potato_util::{json_to_pydict, pyobject_to_json, PyHelperFuncs, UtilError};
3use pyo3::prelude::*;
4use pyo3::types::PyDict;
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7use std::collections::HashMap;
8
9#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
10#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
11pub enum SchemaType {
12    TypeUnspecified,
13    String,
14    Number,
15    Integer,
16    Boolean,
17    Array,
18    Object,
19    Null,
20}
21
22#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)]
23#[serde(rename_all = "camelCase", default)]
24pub struct Schema {
25    #[serde(skip_serializing_if = "Option::is_none")]
26    pub r#type: Option<SchemaType>,
27    #[serde(skip_serializing_if = "Option::is_none")]
28    pub format: Option<String>,
29    #[serde(skip_serializing_if = "Option::is_none")]
30    pub title: Option<String>,
31    #[serde(skip_serializing_if = "Option::is_none")]
32    pub description: Option<String>,
33    #[serde(skip_serializing_if = "Option::is_none")]
34    pub nullable: Option<bool>,
35    #[serde(skip_serializing_if = "Option::is_none")]
36    pub default: Option<Value>,
37    #[serde(skip_serializing_if = "Option::is_none")]
38    pub items: Option<Box<Schema>>,
39    #[serde(skip_serializing_if = "Option::is_none")]
40    pub min_items: Option<String>,
41    #[serde(skip_serializing_if = "Option::is_none")]
42    pub max_items: Option<String>,
43    #[serde(skip_serializing_if = "Option::is_none")]
44    pub r#enum: Option<Vec<String>>,
45    #[serde(skip_serializing_if = "Option::is_none")]
46    pub properties: Option<HashMap<String, Schema>>,
47    #[serde(skip_serializing_if = "Option::is_none")]
48    pub property_ordering: Option<Vec<String>>,
49    #[serde(skip_serializing_if = "Option::is_none")]
50    pub required: Option<Vec<String>>,
51    #[serde(skip_serializing_if = "Option::is_none")]
52    pub min_properties: Option<String>,
53    #[serde(skip_serializing_if = "Option::is_none")]
54    pub max_properties: Option<String>,
55    #[serde(skip_serializing_if = "Option::is_none")]
56    pub minimum: Option<f64>,
57    #[serde(skip_serializing_if = "Option::is_none")]
58    pub maximum: Option<f64>,
59    #[serde(skip_serializing_if = "Option::is_none")]
60    pub min_length: Option<String>,
61    #[serde(skip_serializing_if = "Option::is_none")]
62    pub max_length: Option<String>,
63    #[serde(skip_serializing_if = "Option::is_none")]
64    pub pattern: Option<String>,
65    #[serde(skip_serializing_if = "Option::is_none")]
66    pub example: Option<Value>,
67    #[serde(skip_serializing_if = "Option::is_none")]
68    pub any_of: Option<Vec<Schema>>,
69    #[serde(skip_serializing_if = "Option::is_none")]
70    pub additional_properties: Option<Value>,
71    #[serde(rename = "ref", skip_serializing_if = "Option::is_none")]
72    pub ref_path: Option<String>,
73    #[serde(skip_serializing_if = "Option::is_none")]
74    pub defs: Option<HashMap<String, Schema>>,
75}
76
77#[pyclass]
78#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, Default)]
79#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
80pub enum HarmCategory {
81    #[default]
82    HarmCategoryUnspecified,
83    HarmCategoryHateSpeech,
84    HarmCategoryDangerousContent,
85    HarmCategoryHarassment,
86    HarmCategorySexuallyExplicit,
87    HarmCategoryImageHate,
88    HarmCategoryImageDangerousContent,
89    HarmCategoryImageHarassment,
90    HarmCategoryImageSexuallyExplicit,
91}
92
93/// Probability-based threshold levels for blocking.
94#[pyclass]
95#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, Default)]
96#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
97pub enum HarmBlockThreshold {
98    HarmBlockThresholdUnspecified,
99    BlockLowAndAbove,
100    BlockMediumAndAbove,
101    BlockOnlyHigh,
102    BlockNone,
103    #[default]
104    Off,
105}
106
107/// Specifies whether the threshold is used for probability or severity score.
108#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
109#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
110#[pyclass]
111pub enum HarmBlockMethod {
112    HarmBlockMethodUnspecified,
113    Severity,
114    Probability,
115}
116
117/// Safety settings for harm blocking.
118#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
119#[serde(rename_all = "camelCase")]
120#[pyclass]
121pub struct SafetySetting {
122    /// Required. The harm category.
123    #[pyo3(get)]
124    pub category: HarmCategory,
125    /// Required. The harm block threshold.
126    #[pyo3(get)]
127    pub threshold: HarmBlockThreshold,
128    /// Optional. Specify if the threshold is used for probability or severity score.
129    #[serde(skip_serializing_if = "Option::is_none")]
130    #[pyo3(get)]
131    pub method: Option<HarmBlockMethod>,
132}
133
134#[pymethods]
135impl SafetySetting {
136    #[new]
137    #[pyo3(signature = (category, threshold, method=None))]
138    pub fn new(
139        category: HarmCategory,
140        threshold: HarmBlockThreshold,
141        method: Option<HarmBlockMethod>,
142    ) -> Self {
143        SafetySetting {
144            category,
145            threshold,
146            method,
147        }
148    }
149}
150
151#[pyclass]
152#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
153#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
154pub enum Modality {
155    ModalityUnspecified,
156    Text,
157    Image,
158    Audio,
159}
160
161#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
162#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
163#[pyclass]
164pub enum MediaResolution {
165    MediaResolutionUnspecified,
166    MediaResolutionLow,
167    MediaResolutionMedium,
168    MediaResolutionHigh,
169}
170
171#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
172#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
173pub enum ModelRoutingPreference {
174    Unknown,
175    PrioritizeQuality,
176    Balanced,
177    PrioritizeCost,
178}
179
180#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq, Eq)]
181#[serde(rename_all = "camelCase", default)]
182#[pyclass]
183pub struct ThinkingConfig {
184    #[serde(skip_serializing_if = "Option::is_none")]
185    pub include_thoughts: Option<bool>,
186    #[serde(skip_serializing_if = "Option::is_none")]
187    pub thinking_budget: Option<i32>,
188}
189
190#[pymethods]
191impl ThinkingConfig {
192    #[new]
193    #[pyo3(signature = (include_thoughts=None, thinking_budget=None))]
194    pub fn new(include_thoughts: Option<bool>, thinking_budget: Option<i32>) -> Self {
195        ThinkingConfig {
196            include_thoughts,
197            thinking_budget,
198        }
199    }
200}
201
202#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
203#[serde(rename_all = "camelCase")]
204pub struct AutoRoutingMode {
205    #[serde(skip_serializing_if = "Option::is_none")]
206    pub model_routing_preference: Option<ModelRoutingPreference>,
207}
208
209#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
210#[serde(rename_all = "camelCase")]
211pub struct ManualRoutingMode {
212    pub model_name: String,
213}
214
215#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
216#[serde(rename_all = "camelCase")]
217#[serde(untagged)]
218pub enum RoutingConfigMode {
219    AutoMode(AutoRoutingMode),
220    ManualMode(ManualRoutingMode),
221}
222
223#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
224#[serde(rename_all = "camelCase")]
225pub struct RoutingConfig {
226    #[serde(flatten)]
227    pub routing_config: RoutingConfigMode,
228}
229
230#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
231#[serde(rename_all = "camelCase")]
232#[pyclass]
233pub struct PrebuiltVoiceConfig {
234    pub voice_name: String,
235}
236
237#[pymethods]
238impl PrebuiltVoiceConfig {
239    #[new]
240    pub fn new(voice_name: String) -> Self {
241        PrebuiltVoiceConfig { voice_name }
242    }
243}
244
245#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
246#[serde(rename_all = "camelCase")]
247#[serde(untagged)]
248#[pyclass]
249pub enum VoiceConfigMode {
250    PrebuiltVoiceConfig(PrebuiltVoiceConfig),
251}
252
253#[pymethods]
254impl VoiceConfigMode {
255    #[new]
256    pub fn new(prebuilt_voice_config: PrebuiltVoiceConfig) -> Self {
257        VoiceConfigMode::PrebuiltVoiceConfig(prebuilt_voice_config)
258    }
259}
260
261#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
262#[serde(rename_all = "camelCase")]
263#[pyclass]
264pub struct VoiceConfig {
265    #[serde(flatten)]
266    pub voice_config: VoiceConfigMode,
267}
268
269#[pymethods]
270impl VoiceConfig {
271    #[new]
272    pub fn new(voice_config: VoiceConfigMode) -> Self {
273        VoiceConfig { voice_config }
274    }
275}
276
277#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq, Eq)]
278#[serde(rename_all = "camelCase", default)]
279#[pyclass]
280pub struct SpeechConfig {
281    #[serde(skip_serializing_if = "Option::is_none")]
282    pub voice_config: Option<VoiceConfig>,
283    #[serde(skip_serializing_if = "Option::is_none")]
284    pub language_code: Option<String>,
285}
286
287#[pymethods]
288impl SpeechConfig {
289    #[new]
290    pub fn new(voice_config: Option<VoiceConfig>, language_code: Option<String>) -> Self {
291        SpeechConfig {
292            voice_config,
293            language_code,
294        }
295    }
296}
297
298#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)]
299#[serde(rename_all = "camelCase", default)]
300#[pyclass]
301pub struct GenerationConfig {
302    #[serde(skip_serializing_if = "Option::is_none")]
303    #[pyo3(get)]
304    pub stop_sequences: Option<Vec<String>>,
305
306    #[serde(skip_serializing_if = "Option::is_none")]
307    #[pyo3(get)]
308    pub response_mime_type: Option<String>,
309
310    #[serde(skip_serializing_if = "Option::is_none")]
311    #[pyo3(get)]
312    pub response_modalities: Option<Vec<Modality>>,
313
314    #[serde(skip_serializing_if = "Option::is_none")]
315    #[pyo3(get)]
316    pub thinking_config: Option<ThinkingConfig>,
317
318    #[serde(skip_serializing_if = "Option::is_none")]
319    #[pyo3(get)]
320    pub temperature: Option<f32>,
321
322    #[serde(skip_serializing_if = "Option::is_none")]
323    #[pyo3(get)]
324    pub top_p: Option<f32>,
325
326    #[serde(skip_serializing_if = "Option::is_none")]
327    #[pyo3(get)]
328    pub top_k: Option<i32>,
329
330    #[serde(skip_serializing_if = "Option::is_none")]
331    #[pyo3(get)]
332    pub candidate_count: Option<i32>,
333
334    #[serde(skip_serializing_if = "Option::is_none")]
335    #[pyo3(get)]
336    pub max_output_tokens: Option<i32>,
337
338    #[serde(skip_serializing_if = "Option::is_none")]
339    #[pyo3(get)]
340    pub response_logprobs: Option<bool>,
341
342    #[serde(skip_serializing_if = "Option::is_none")]
343    #[pyo3(get)]
344    pub logprobs: Option<i32>,
345
346    #[serde(skip_serializing_if = "Option::is_none")]
347    #[pyo3(get)]
348    pub presence_penalty: Option<f32>,
349
350    #[serde(skip_serializing_if = "Option::is_none")]
351    #[pyo3(get)]
352    pub frequency_penalty: Option<f32>,
353
354    #[serde(skip_serializing_if = "Option::is_none")]
355    #[pyo3(get)]
356    pub seed: Option<i32>,
357
358    #[serde(skip_serializing_if = "Option::is_none")]
359    pub response_schema: Option<Schema>,
360
361    #[serde(skip_serializing_if = "Option::is_none")]
362    pub response_json_schema: Option<Value>,
363
364    #[serde(skip_serializing_if = "Option::is_none")]
365    pub routing_config: Option<RoutingConfig>,
366
367    #[serde(skip_serializing_if = "Option::is_none")]
368    #[pyo3(get)]
369    pub audio_timestamp: Option<bool>,
370
371    #[serde(skip_serializing_if = "Option::is_none")]
372    #[pyo3(get)]
373    pub media_resolution: Option<MediaResolution>,
374
375    #[serde(skip_serializing_if = "Option::is_none")]
376    #[pyo3(get)]
377    pub speech_config: Option<SpeechConfig>,
378
379    #[serde(skip_serializing_if = "Option::is_none")]
380    #[pyo3(get)]
381    pub enable_affective_dialog: Option<bool>,
382}
383
384#[pymethods]
385impl GenerationConfig {
386    #[new]
387    #[pyo3(signature = (stop_sequences=None, response_mime_type=None, response_modalities=None, thinking_config=None, temperature=None, top_p=None, top_k=None, candidate_count=None, max_output_tokens=None, response_logprobs=None, logprobs=None, presence_penalty=None, frequency_penalty=None, seed=None, audio_timestamp=None, media_resolution=None, speech_config=None, enable_affective_dialog=None))]
388    #[allow(clippy::too_many_arguments)]
389    pub fn new(
390        stop_sequences: Option<Vec<String>>,
391        response_mime_type: Option<String>,
392        response_modalities: Option<Vec<Modality>>,
393        thinking_config: Option<ThinkingConfig>,
394        temperature: Option<f32>,
395        top_p: Option<f32>,
396        top_k: Option<i32>,
397        candidate_count: Option<i32>,
398        max_output_tokens: Option<i32>,
399        response_logprobs: Option<bool>,
400        logprobs: Option<i32>,
401        presence_penalty: Option<f32>,
402        frequency_penalty: Option<f32>,
403        seed: Option<i32>,
404        //TODO: revisit this later
405        //response_schema: Option<Schema>,
406        //response_json_schema: Option<Value>,
407        //routing_config: Option<RoutingConfig>,
408        audio_timestamp: Option<bool>,
409        media_resolution: Option<MediaResolution>,
410        speech_config: Option<SpeechConfig>,
411        enable_affective_dialog: Option<bool>,
412    ) -> Self {
413        Self {
414            stop_sequences,
415            response_mime_type,
416            response_modalities,
417            thinking_config,
418            temperature,
419            top_p,
420            top_k,
421            candidate_count,
422            max_output_tokens,
423            response_logprobs,
424            logprobs,
425            presence_penalty,
426            frequency_penalty,
427            seed,
428            audio_timestamp,
429            media_resolution,
430            speech_config,
431            enable_affective_dialog,
432            ..Default::default()
433        }
434    }
435
436    pub fn __str__(&self) -> String {
437        PyHelperFuncs::__str__(self)
438    }
439}
440
441#[pyclass]
442#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)]
443#[serde(rename_all = "camelCase", default)]
444pub struct ModelArmorConfig {
445    #[serde(skip_serializing_if = "Option::is_none")]
446    pub prompt_template_name: Option<String>,
447    #[serde(skip_serializing_if = "Option::is_none")]
448    pub response_template_name: Option<String>,
449}
450
451#[pymethods]
452impl ModelArmorConfig {
453    #[new]
454    #[pyo3(signature = (prompt_template_name=None, response_template_name=None))]
455    pub fn new(
456        prompt_template_name: Option<String>,
457        response_template_name: Option<String>,
458    ) -> Self {
459        ModelArmorConfig {
460            prompt_template_name,
461            response_template_name,
462        }
463    }
464}
465
466#[pyclass]
467#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)]
468#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
469pub enum Mode {
470    ModeUnspecified,
471    Any,
472    #[default]
473    Auto,
474    #[pyo3(name = "None_Mode")]
475    None,
476}
477
478#[pyclass]
479#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)]
480#[serde(rename_all = "camelCase", default)]
481pub struct FunctionCallingConfig {
482    #[pyo3(get)]
483    pub mode: Option<Mode>,
484    #[pyo3(get)]
485    pub allowed_function_names: Option<Vec<String>>,
486}
487
488#[pymethods]
489impl FunctionCallingConfig {
490    #[new]
491    pub fn new(mode: Option<Mode>, allowed_function_names: Option<Vec<String>>) -> Self {
492        FunctionCallingConfig {
493            mode,
494            allowed_function_names,
495        }
496    }
497}
498
499#[pyclass]
500#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)]
501#[serde(rename_all = "camelCase", default)]
502pub struct LatLng {
503    #[pyo3(get)]
504    pub latitude: f64,
505    #[pyo3(get)]
506    pub longitude: f64,
507}
508
509#[pymethods]
510impl LatLng {
511    #[new]
512    pub fn new(latitude: f64, longitude: f64) -> Self {
513        LatLng {
514            latitude,
515            longitude,
516        }
517    }
518}
519
520#[pyclass]
521#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)]
522#[serde(rename_all = "camelCase", default)]
523pub struct RetrievalConfig {
524    #[pyo3(get)]
525    pub lat_lng: LatLng,
526
527    #[pyo3(get)]
528    pub language_code: String,
529}
530
531#[pymethods]
532impl RetrievalConfig {
533    #[new]
534    pub fn new(lat_lng: LatLng, language_code: String) -> Self {
535        RetrievalConfig {
536            lat_lng,
537            language_code,
538        }
539    }
540}
541
542#[pyclass]
543#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)]
544#[serde(rename_all = "camelCase", default)]
545pub struct ToolConfig {
546    #[pyo3(get)]
547    function_calling_config: Option<FunctionCallingConfig>,
548    #[pyo3(get)]
549    retrieval_config: Option<RetrievalConfig>,
550}
551
552#[pymethods]
553impl ToolConfig {
554    #[new]
555    #[pyo3(signature = (function_calling_config=None, retrieval_config=None))]
556    pub fn new(
557        function_calling_config: Option<FunctionCallingConfig>,
558        retrieval_config: Option<RetrievalConfig>,
559    ) -> Self {
560        ToolConfig {
561            function_calling_config,
562            retrieval_config,
563        }
564    }
565}
566
567#[pyclass]
568#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)]
569pub struct GeminiSettings {
570    #[pyo3(get)]
571    #[serde(skip_serializing_if = "Option::is_none")]
572    pub labels: Option<HashMap<String, String>>,
573
574    #[pyo3(get)]
575    #[serde(skip_serializing_if = "Option::is_none")]
576    pub tool_config: Option<ToolConfig>,
577
578    #[pyo3(get)]
579    #[serde(skip_serializing_if = "Option::is_none")]
580    pub generation_config: Option<GenerationConfig>,
581
582    #[pyo3(get)]
583    #[serde(skip_serializing_if = "Option::is_none")]
584    pub safety_settings: Option<Vec<SafetySetting>>,
585
586    #[pyo3(get)]
587    #[serde(skip_serializing_if = "Option::is_none")]
588    pub model_armor_config: Option<ModelArmorConfig>,
589
590    #[serde(skip_serializing_if = "Option::is_none")]
591    pub extra_body: Option<Value>,
592}
593
594#[pymethods]
595impl GeminiSettings {
596    #[new]
597    #[pyo3(signature = (labels=None, tool_config=None, generation_config=None, safety_settings=None, model_armor_config=None, extra_body=None))]
598    pub fn new(
599        labels: Option<HashMap<String, String>>,
600        tool_config: Option<ToolConfig>,
601        generation_config: Option<GenerationConfig>,
602        safety_settings: Option<Vec<SafetySetting>>,
603        model_armor_config: Option<ModelArmorConfig>,
604        extra_body: Option<&Bound<'_, PyAny>>,
605    ) -> Result<Self, UtilError> {
606        let extra = match extra_body {
607            Some(obj) => Some(pyobject_to_json(obj)?),
608            None => None,
609        };
610
611        Ok(GeminiSettings {
612            labels,
613            tool_config,
614            generation_config,
615            safety_settings,
616            model_armor_config,
617            extra_body: extra,
618        })
619    }
620
621    #[getter]
622    pub fn extra_body<'py>(
623        &self,
624        py: Python<'py>,
625    ) -> Result<Option<Bound<'py, PyDict>>, UtilError> {
626        // error if extra body is None
627        self.extra_body
628            .as_ref()
629            .map(|v| {
630                let pydict = PyDict::new(py);
631                json_to_pydict(py, v, &pydict)
632            })
633            .transpose()
634    }
635
636    pub fn __str__(&self) -> String {
637        PyHelperFuncs::__str__(self)
638    }
639
640    pub fn model_dump<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyDict>, TypeError> {
641        // iterate over each field in model_settings and add to the dict if it is not None
642        let json = serde_json::to_value(self)?;
643        let pydict = PyDict::new(py);
644        json_to_pydict(py, &json, &pydict)?;
645        Ok(pydict)
646    }
647
648    pub fn settings_type(&self) -> SettingsType {
649        SettingsType::GoogleChat
650    }
651}
652
653impl GeminiSettings {
654    pub fn configure_for_structured_output(&mut self) {
655        // Ensure generation_config exists and set response_mime_type
656        match self.generation_config.as_mut() {
657            Some(generation_config) => {
658                generation_config.response_mime_type = Some("application/json".to_string());
659            }
660            None => {
661                self.generation_config = Some(GenerationConfig {
662                    response_mime_type: Some("application/json".to_string()),
663                    ..Default::default()
664                });
665            }
666        }
667    }
668}