1use crate::{SettingsType, TypeError};
2use potato_util::{json_to_pydict, pyobject_to_json, PyHelperFuncs, UtilError};
3use pyo3::prelude::*;
4use pyo3::types::PyDict;
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7use std::collections::HashMap;
8
9#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
10#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
11pub enum SchemaType {
12 TypeUnspecified,
13 String,
14 Number,
15 Integer,
16 Boolean,
17 Array,
18 Object,
19 Null,
20}
21
22#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)]
23#[serde(rename_all = "camelCase", default)]
24pub struct Schema {
25 #[serde(skip_serializing_if = "Option::is_none")]
26 pub r#type: Option<SchemaType>,
27 #[serde(skip_serializing_if = "Option::is_none")]
28 pub format: Option<String>,
29 #[serde(skip_serializing_if = "Option::is_none")]
30 pub title: Option<String>,
31 #[serde(skip_serializing_if = "Option::is_none")]
32 pub description: Option<String>,
33 #[serde(skip_serializing_if = "Option::is_none")]
34 pub nullable: Option<bool>,
35 #[serde(skip_serializing_if = "Option::is_none")]
36 pub default: Option<Value>,
37 #[serde(skip_serializing_if = "Option::is_none")]
38 pub items: Option<Box<Schema>>,
39 #[serde(skip_serializing_if = "Option::is_none")]
40 pub min_items: Option<String>,
41 #[serde(skip_serializing_if = "Option::is_none")]
42 pub max_items: Option<String>,
43 #[serde(skip_serializing_if = "Option::is_none")]
44 pub r#enum: Option<Vec<String>>,
45 #[serde(skip_serializing_if = "Option::is_none")]
46 pub properties: Option<HashMap<String, Schema>>,
47 #[serde(skip_serializing_if = "Option::is_none")]
48 pub property_ordering: Option<Vec<String>>,
49 #[serde(skip_serializing_if = "Option::is_none")]
50 pub required: Option<Vec<String>>,
51 #[serde(skip_serializing_if = "Option::is_none")]
52 pub min_properties: Option<String>,
53 #[serde(skip_serializing_if = "Option::is_none")]
54 pub max_properties: Option<String>,
55 #[serde(skip_serializing_if = "Option::is_none")]
56 pub minimum: Option<f64>,
57 #[serde(skip_serializing_if = "Option::is_none")]
58 pub maximum: Option<f64>,
59 #[serde(skip_serializing_if = "Option::is_none")]
60 pub min_length: Option<String>,
61 #[serde(skip_serializing_if = "Option::is_none")]
62 pub max_length: Option<String>,
63 #[serde(skip_serializing_if = "Option::is_none")]
64 pub pattern: Option<String>,
65 #[serde(skip_serializing_if = "Option::is_none")]
66 pub example: Option<Value>,
67 #[serde(skip_serializing_if = "Option::is_none")]
68 pub any_of: Option<Vec<Schema>>,
69 #[serde(skip_serializing_if = "Option::is_none")]
70 pub additional_properties: Option<Value>,
71 #[serde(rename = "ref", skip_serializing_if = "Option::is_none")]
72 pub ref_path: Option<String>,
73 #[serde(skip_serializing_if = "Option::is_none")]
74 pub defs: Option<HashMap<String, Schema>>,
75}
76
77#[pyclass]
78#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, Default)]
79#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
80pub enum HarmCategory {
81 #[default]
82 HarmCategoryUnspecified,
83 HarmCategoryHateSpeech,
84 HarmCategoryDangerousContent,
85 HarmCategoryHarassment,
86 HarmCategorySexuallyExplicit,
87 HarmCategoryImageHate,
88 HarmCategoryImageDangerousContent,
89 HarmCategoryImageHarassment,
90 HarmCategoryImageSexuallyExplicit,
91}
92
93#[pyclass]
95#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, Default)]
96#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
97pub enum HarmBlockThreshold {
98 HarmBlockThresholdUnspecified,
99 BlockLowAndAbove,
100 BlockMediumAndAbove,
101 BlockOnlyHigh,
102 BlockNone,
103 #[default]
104 Off,
105}
106
107#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
109#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
110#[pyclass]
111pub enum HarmBlockMethod {
112 HarmBlockMethodUnspecified,
113 Severity,
114 Probability,
115}
116
117#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
119#[serde(rename_all = "camelCase")]
120#[pyclass]
121pub struct SafetySetting {
122 #[pyo3(get)]
124 pub category: HarmCategory,
125 #[pyo3(get)]
127 pub threshold: HarmBlockThreshold,
128 #[serde(skip_serializing_if = "Option::is_none")]
130 #[pyo3(get)]
131 pub method: Option<HarmBlockMethod>,
132}
133
134#[pymethods]
135impl SafetySetting {
136 #[new]
137 #[pyo3(signature = (category, threshold, method=None))]
138 pub fn new(
139 category: HarmCategory,
140 threshold: HarmBlockThreshold,
141 method: Option<HarmBlockMethod>,
142 ) -> Self {
143 SafetySetting {
144 category,
145 threshold,
146 method,
147 }
148 }
149}
150
151#[pyclass]
152#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
153#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
154pub enum Modality {
155 ModalityUnspecified,
156 Text,
157 Image,
158 Audio,
159}
160
161#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
162#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
163#[pyclass]
164pub enum MediaResolution {
165 MediaResolutionUnspecified,
166 MediaResolutionLow,
167 MediaResolutionMedium,
168 MediaResolutionHigh,
169}
170
171#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
172#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
173pub enum ModelRoutingPreference {
174 Unknown,
175 PrioritizeQuality,
176 Balanced,
177 PrioritizeCost,
178}
179
180#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq, Eq)]
181#[serde(rename_all = "camelCase", default)]
182#[pyclass]
183pub struct ThinkingConfig {
184 #[serde(skip_serializing_if = "Option::is_none")]
185 pub include_thoughts: Option<bool>,
186 #[serde(skip_serializing_if = "Option::is_none")]
187 pub thinking_budget: Option<i32>,
188}
189
190#[pymethods]
191impl ThinkingConfig {
192 #[new]
193 #[pyo3(signature = (include_thoughts=None, thinking_budget=None))]
194 pub fn new(include_thoughts: Option<bool>, thinking_budget: Option<i32>) -> Self {
195 ThinkingConfig {
196 include_thoughts,
197 thinking_budget,
198 }
199 }
200}
201
202#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
203#[serde(rename_all = "camelCase")]
204pub struct AutoRoutingMode {
205 #[serde(skip_serializing_if = "Option::is_none")]
206 pub model_routing_preference: Option<ModelRoutingPreference>,
207}
208
209#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
210#[serde(rename_all = "camelCase")]
211pub struct ManualRoutingMode {
212 pub model_name: String,
213}
214
215#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
216#[serde(rename_all = "camelCase")]
217#[serde(untagged)]
218pub enum RoutingConfigMode {
219 AutoMode(AutoRoutingMode),
220 ManualMode(ManualRoutingMode),
221}
222
223#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
224#[serde(rename_all = "camelCase")]
225pub struct RoutingConfig {
226 #[serde(flatten)]
227 pub routing_config: RoutingConfigMode,
228}
229
230#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
231#[serde(rename_all = "camelCase")]
232#[pyclass]
233pub struct PrebuiltVoiceConfig {
234 pub voice_name: String,
235}
236
237#[pymethods]
238impl PrebuiltVoiceConfig {
239 #[new]
240 pub fn new(voice_name: String) -> Self {
241 PrebuiltVoiceConfig { voice_name }
242 }
243}
244
245#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
246#[serde(rename_all = "camelCase")]
247#[serde(untagged)]
248#[pyclass]
249pub enum VoiceConfigMode {
250 PrebuiltVoiceConfig(PrebuiltVoiceConfig),
251}
252
253#[pymethods]
254impl VoiceConfigMode {
255 #[new]
256 pub fn new(prebuilt_voice_config: PrebuiltVoiceConfig) -> Self {
257 VoiceConfigMode::PrebuiltVoiceConfig(prebuilt_voice_config)
258 }
259}
260
261#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
262#[serde(rename_all = "camelCase")]
263#[pyclass]
264pub struct VoiceConfig {
265 #[serde(flatten)]
266 pub voice_config: VoiceConfigMode,
267}
268
269#[pymethods]
270impl VoiceConfig {
271 #[new]
272 pub fn new(voice_config: VoiceConfigMode) -> Self {
273 VoiceConfig { voice_config }
274 }
275}
276
277#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq, Eq)]
278#[serde(rename_all = "camelCase", default)]
279#[pyclass]
280pub struct SpeechConfig {
281 #[serde(skip_serializing_if = "Option::is_none")]
282 pub voice_config: Option<VoiceConfig>,
283 #[serde(skip_serializing_if = "Option::is_none")]
284 pub language_code: Option<String>,
285}
286
287#[pymethods]
288impl SpeechConfig {
289 #[new]
290 pub fn new(voice_config: Option<VoiceConfig>, language_code: Option<String>) -> Self {
291 SpeechConfig {
292 voice_config,
293 language_code,
294 }
295 }
296}
297
298#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)]
299#[serde(rename_all = "camelCase", default)]
300#[pyclass]
301pub struct GenerationConfig {
302 #[serde(skip_serializing_if = "Option::is_none")]
303 #[pyo3(get)]
304 pub stop_sequences: Option<Vec<String>>,
305
306 #[serde(skip_serializing_if = "Option::is_none")]
307 #[pyo3(get)]
308 pub response_mime_type: Option<String>,
309
310 #[serde(skip_serializing_if = "Option::is_none")]
311 #[pyo3(get)]
312 pub response_modalities: Option<Vec<Modality>>,
313
314 #[serde(skip_serializing_if = "Option::is_none")]
315 #[pyo3(get)]
316 pub thinking_config: Option<ThinkingConfig>,
317
318 #[serde(skip_serializing_if = "Option::is_none")]
319 #[pyo3(get)]
320 pub temperature: Option<f32>,
321
322 #[serde(skip_serializing_if = "Option::is_none")]
323 #[pyo3(get)]
324 pub top_p: Option<f32>,
325
326 #[serde(skip_serializing_if = "Option::is_none")]
327 #[pyo3(get)]
328 pub top_k: Option<i32>,
329
330 #[serde(skip_serializing_if = "Option::is_none")]
331 #[pyo3(get)]
332 pub candidate_count: Option<i32>,
333
334 #[serde(skip_serializing_if = "Option::is_none")]
335 #[pyo3(get)]
336 pub max_output_tokens: Option<i32>,
337
338 #[serde(skip_serializing_if = "Option::is_none")]
339 #[pyo3(get)]
340 pub response_logprobs: Option<bool>,
341
342 #[serde(skip_serializing_if = "Option::is_none")]
343 #[pyo3(get)]
344 pub logprobs: Option<i32>,
345
346 #[serde(skip_serializing_if = "Option::is_none")]
347 #[pyo3(get)]
348 pub presence_penalty: Option<f32>,
349
350 #[serde(skip_serializing_if = "Option::is_none")]
351 #[pyo3(get)]
352 pub frequency_penalty: Option<f32>,
353
354 #[serde(skip_serializing_if = "Option::is_none")]
355 #[pyo3(get)]
356 pub seed: Option<i32>,
357
358 #[serde(skip_serializing_if = "Option::is_none")]
359 pub response_schema: Option<Schema>,
360
361 #[serde(skip_serializing_if = "Option::is_none")]
362 pub response_json_schema: Option<Value>,
363
364 #[serde(skip_serializing_if = "Option::is_none")]
365 pub routing_config: Option<RoutingConfig>,
366
367 #[serde(skip_serializing_if = "Option::is_none")]
368 #[pyo3(get)]
369 pub audio_timestamp: Option<bool>,
370
371 #[serde(skip_serializing_if = "Option::is_none")]
372 #[pyo3(get)]
373 pub media_resolution: Option<MediaResolution>,
374
375 #[serde(skip_serializing_if = "Option::is_none")]
376 #[pyo3(get)]
377 pub speech_config: Option<SpeechConfig>,
378
379 #[serde(skip_serializing_if = "Option::is_none")]
380 #[pyo3(get)]
381 pub enable_affective_dialog: Option<bool>,
382}
383
384#[pymethods]
385impl GenerationConfig {
386 #[new]
387 #[pyo3(signature = (stop_sequences=None, response_mime_type=None, response_modalities=None, thinking_config=None, temperature=None, top_p=None, top_k=None, candidate_count=None, max_output_tokens=None, response_logprobs=None, logprobs=None, presence_penalty=None, frequency_penalty=None, seed=None, audio_timestamp=None, media_resolution=None, speech_config=None, enable_affective_dialog=None))]
388 #[allow(clippy::too_many_arguments)]
389 pub fn new(
390 stop_sequences: Option<Vec<String>>,
391 response_mime_type: Option<String>,
392 response_modalities: Option<Vec<Modality>>,
393 thinking_config: Option<ThinkingConfig>,
394 temperature: Option<f32>,
395 top_p: Option<f32>,
396 top_k: Option<i32>,
397 candidate_count: Option<i32>,
398 max_output_tokens: Option<i32>,
399 response_logprobs: Option<bool>,
400 logprobs: Option<i32>,
401 presence_penalty: Option<f32>,
402 frequency_penalty: Option<f32>,
403 seed: Option<i32>,
404 audio_timestamp: Option<bool>,
409 media_resolution: Option<MediaResolution>,
410 speech_config: Option<SpeechConfig>,
411 enable_affective_dialog: Option<bool>,
412 ) -> Self {
413 Self {
414 stop_sequences,
415 response_mime_type,
416 response_modalities,
417 thinking_config,
418 temperature,
419 top_p,
420 top_k,
421 candidate_count,
422 max_output_tokens,
423 response_logprobs,
424 logprobs,
425 presence_penalty,
426 frequency_penalty,
427 seed,
428 audio_timestamp,
429 media_resolution,
430 speech_config,
431 enable_affective_dialog,
432 ..Default::default()
433 }
434 }
435
436 pub fn __str__(&self) -> String {
437 PyHelperFuncs::__str__(self)
438 }
439}
440
441#[pyclass]
442#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)]
443#[serde(rename_all = "camelCase", default)]
444pub struct ModelArmorConfig {
445 #[serde(skip_serializing_if = "Option::is_none")]
446 pub prompt_template_name: Option<String>,
447 #[serde(skip_serializing_if = "Option::is_none")]
448 pub response_template_name: Option<String>,
449}
450
451#[pymethods]
452impl ModelArmorConfig {
453 #[new]
454 #[pyo3(signature = (prompt_template_name=None, response_template_name=None))]
455 pub fn new(
456 prompt_template_name: Option<String>,
457 response_template_name: Option<String>,
458 ) -> Self {
459 ModelArmorConfig {
460 prompt_template_name,
461 response_template_name,
462 }
463 }
464}
465
466#[pyclass]
467#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)]
468#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
469pub enum Mode {
470 ModeUnspecified,
471 Any,
472 #[default]
473 Auto,
474 #[pyo3(name = "None_Mode")]
475 None,
476}
477
478#[pyclass]
479#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)]
480#[serde(rename_all = "camelCase", default)]
481pub struct FunctionCallingConfig {
482 #[pyo3(get)]
483 pub mode: Option<Mode>,
484 #[pyo3(get)]
485 pub allowed_function_names: Option<Vec<String>>,
486}
487
488#[pymethods]
489impl FunctionCallingConfig {
490 #[new]
491 pub fn new(mode: Option<Mode>, allowed_function_names: Option<Vec<String>>) -> Self {
492 FunctionCallingConfig {
493 mode,
494 allowed_function_names,
495 }
496 }
497}
498
499#[pyclass]
500#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)]
501#[serde(rename_all = "camelCase", default)]
502pub struct LatLng {
503 #[pyo3(get)]
504 pub latitude: f64,
505 #[pyo3(get)]
506 pub longitude: f64,
507}
508
509#[pymethods]
510impl LatLng {
511 #[new]
512 pub fn new(latitude: f64, longitude: f64) -> Self {
513 LatLng {
514 latitude,
515 longitude,
516 }
517 }
518}
519
520#[pyclass]
521#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)]
522#[serde(rename_all = "camelCase", default)]
523pub struct RetrievalConfig {
524 #[pyo3(get)]
525 pub lat_lng: LatLng,
526
527 #[pyo3(get)]
528 pub language_code: String,
529}
530
531#[pymethods]
532impl RetrievalConfig {
533 #[new]
534 pub fn new(lat_lng: LatLng, language_code: String) -> Self {
535 RetrievalConfig {
536 lat_lng,
537 language_code,
538 }
539 }
540}
541
542#[pyclass]
543#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)]
544#[serde(rename_all = "camelCase", default)]
545pub struct ToolConfig {
546 #[pyo3(get)]
547 function_calling_config: Option<FunctionCallingConfig>,
548 #[pyo3(get)]
549 retrieval_config: Option<RetrievalConfig>,
550}
551
552#[pymethods]
553impl ToolConfig {
554 #[new]
555 #[pyo3(signature = (function_calling_config=None, retrieval_config=None))]
556 pub fn new(
557 function_calling_config: Option<FunctionCallingConfig>,
558 retrieval_config: Option<RetrievalConfig>,
559 ) -> Self {
560 ToolConfig {
561 function_calling_config,
562 retrieval_config,
563 }
564 }
565}
566
567#[pyclass]
568#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)]
569pub struct GeminiSettings {
570 #[pyo3(get)]
571 #[serde(skip_serializing_if = "Option::is_none")]
572 pub labels: Option<HashMap<String, String>>,
573
574 #[pyo3(get)]
575 #[serde(skip_serializing_if = "Option::is_none")]
576 pub tool_config: Option<ToolConfig>,
577
578 #[pyo3(get)]
579 #[serde(skip_serializing_if = "Option::is_none")]
580 pub generation_config: Option<GenerationConfig>,
581
582 #[pyo3(get)]
583 #[serde(skip_serializing_if = "Option::is_none")]
584 pub safety_settings: Option<Vec<SafetySetting>>,
585
586 #[pyo3(get)]
587 #[serde(skip_serializing_if = "Option::is_none")]
588 pub model_armor_config: Option<ModelArmorConfig>,
589
590 #[serde(skip_serializing_if = "Option::is_none")]
591 pub extra_body: Option<Value>,
592}
593
594#[pymethods]
595impl GeminiSettings {
596 #[new]
597 #[pyo3(signature = (labels=None, tool_config=None, generation_config=None, safety_settings=None, model_armor_config=None, extra_body=None))]
598 pub fn new(
599 labels: Option<HashMap<String, String>>,
600 tool_config: Option<ToolConfig>,
601 generation_config: Option<GenerationConfig>,
602 safety_settings: Option<Vec<SafetySetting>>,
603 model_armor_config: Option<ModelArmorConfig>,
604 extra_body: Option<&Bound<'_, PyAny>>,
605 ) -> Result<Self, UtilError> {
606 let extra = match extra_body {
607 Some(obj) => Some(pyobject_to_json(obj)?),
608 None => None,
609 };
610
611 Ok(GeminiSettings {
612 labels,
613 tool_config,
614 generation_config,
615 safety_settings,
616 model_armor_config,
617 extra_body: extra,
618 })
619 }
620
621 #[getter]
622 pub fn extra_body<'py>(
623 &self,
624 py: Python<'py>,
625 ) -> Result<Option<Bound<'py, PyDict>>, UtilError> {
626 self.extra_body
628 .as_ref()
629 .map(|v| {
630 let pydict = PyDict::new(py);
631 json_to_pydict(py, v, &pydict)
632 })
633 .transpose()
634 }
635
636 pub fn __str__(&self) -> String {
637 PyHelperFuncs::__str__(self)
638 }
639
640 pub fn model_dump<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyDict>, TypeError> {
641 let json = serde_json::to_value(self)?;
643 let pydict = PyDict::new(py);
644 json_to_pydict(py, &json, &pydict)?;
645 Ok(pydict)
646 }
647
648 pub fn settings_type(&self) -> SettingsType {
649 SettingsType::GoogleChat
650 }
651}
652
653impl GeminiSettings {
654 pub fn configure_for_structured_output(&mut self) {
655 match self.generation_config.as_mut() {
657 Some(generation_config) => {
658 generation_config.response_mime_type = Some("application/json".to_string());
659 }
660 None => {
661 self.generation_config = Some(GenerationConfig {
662 response_mime_type: Some("application/json".to_string()),
663 ..Default::default()
664 });
665 }
666 }
667 }
668}