gemini_rust/
models.rs

1use serde::{Deserialize, Serialize};
2
/// Role of a message in a conversation.
///
/// Variants are (de)serialized in lowercase (`"user"`, `"model"`) because of
/// `rename_all = "lowercase"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// Message from the user
    User,
    /// Message from the model
    Model,
}
12
/// Content part that can be included in a message.
///
/// The enum is `untagged`: there is no discriminator in the JSON, so on
/// deserialization serde tries the variants in declaration order and keeps the
/// first one whose fields match.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum Part {
    /// Text content
    Text {
        /// The text content
        text: String,
        /// Whether this is a thought summary (Gemini 2.5 series only).
        /// Omitted from the serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        thought: Option<bool>,
    },
    /// Inline binary data, carried as a [`Blob`]
    InlineData {
        /// The blob data (serialized under the `inlineData` key)
        #[serde(rename = "inlineData")]
        inline_data: Blob,
    },
    /// Function call from the model
    FunctionCall {
        /// The function call details (serialized under the `functionCall` key)
        #[serde(rename = "functionCall")]
        function_call: super::tools::FunctionCall,
    },
    /// Function response (results from executing a function call)
    FunctionResponse {
        /// The function response details (serialized under the `functionResponse` key)
        #[serde(rename = "functionResponse")]
        function_response: super::tools::FunctionResponse,
    },
}
43
/// Blob for a message part.
///
/// Field names are serialized in camelCase (`mimeType`, `data`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct Blob {
    /// MIME type describing `data`
    pub mime_type: String,
    /// The payload, expected to be Base64 encoded (this type does not encode
    /// or validate it)
    pub data: String,
}
51
52impl Blob {
53    /// Create a new blob with mime type and data
54    pub fn new(mime_type: impl Into<String>, data: impl Into<String>) -> Self {
55        Self {
56            mime_type: mime_type.into(),
57            data: data.into(),
58        }
59    }
60}
61
/// Content of a message.
///
/// Both fields are optional and are left out of the serialized JSON when they
/// are `None`. `Default` yields a content with neither parts nor role.
#[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct Content {
    /// Parts of the content
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parts: Option<Vec<Part>>,
    /// Role of the content
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<Role>,
}
73
74impl Content {
75    /// Create a new text content
76    pub fn text(text: impl Into<String>) -> Self {
77        Self {
78            parts: Some(vec![Part::Text {
79                text: text.into(),
80                thought: None,
81            }]),
82            role: None,
83        }
84    }
85
86    /// Create a new content with a function call
87    pub fn function_call(function_call: super::tools::FunctionCall) -> Self {
88        Self {
89            parts: Some(vec![Part::FunctionCall { function_call }]),
90            role: None,
91        }
92    }
93
94    /// Create a new content with a function response
95    pub fn function_response(function_response: super::tools::FunctionResponse) -> Self {
96        Self {
97            parts: Some(vec![Part::FunctionResponse { function_response }]),
98            role: None,
99        }
100    }
101
102    /// Create a new content with a function response from name and JSON value
103    pub fn function_response_json(name: impl Into<String>, response: serde_json::Value) -> Self {
104        Self {
105            parts: Some(vec![Part::FunctionResponse {
106                function_response: super::tools::FunctionResponse::new(name, response),
107            }]),
108            role: None,
109        }
110    }
111
112    /// Create a new content with inline data (blob data)
113    pub fn inline_data(mime_type: impl Into<String>, data: impl Into<String>) -> Self {
114        Self {
115            parts: Some(vec![Part::InlineData {
116                inline_data: Blob::new(mime_type, data),
117            }]),
118            role: None,
119        }
120    }
121
122    /// Add a role to this content
123    pub fn with_role(mut self, role: Role) -> Self {
124        self.role = Some(role);
125        self
126    }
127}
128
/// Message in a conversation.
///
/// Pairs a [`Content`] with the [`Role`] of its author. Note that the content
/// carries its own optional role field as well.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Content of the message
    pub content: Content,
    /// Role of the message
    pub role: Role,
}
137
138impl Message {
139    /// Create a new user message with text content
140    pub fn user(text: impl Into<String>) -> Self {
141        Self {
142            content: Content::text(text).with_role(Role::User),
143            role: Role::User,
144        }
145    }
146
147    /// Create a new model message with text content
148    pub fn model(text: impl Into<String>) -> Self {
149        Self {
150            content: Content::text(text).with_role(Role::Model),
151            role: Role::Model,
152        }
153    }
154
155    pub fn embed(text: impl Into<String>) -> Self {
156        Self {
157            content: Content::text(text),
158            role: Role::Model,
159        }
160    }
161
162    /// Create a new function message with function response content from JSON
163    pub fn function(name: impl Into<String>, response: serde_json::Value) -> Self {
164        Self {
165            content: Content::function_response_json(name, response).with_role(Role::Model),
166            role: Role::Model,
167        }
168    }
169
170    /// Create a new function message with function response from a JSON string
171    pub fn function_str(
172        name: impl Into<String>,
173        response: impl Into<String>,
174    ) -> Result<Self, serde_json::Error> {
175        let response_str = response.into();
176        let json = serde_json::from_str(&response_str)?;
177        Ok(Self {
178            content: Content::function_response_json(name, json).with_role(Role::Model),
179            role: Role::Model,
180        })
181    }
182}
183
/// Safety rating for content.
///
/// Both values are kept as the raw strings found in the JSON; they are not
/// parsed into enums here.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct SafetyRating {
    /// The category of the safety rating
    pub category: String,
    /// The probability that the content is harmful
    pub probability: String,
}
192
/// Citation metadata for content.
///
/// Serialized in camelCase (`citationSources`). The field is required when
/// deserializing.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct CitationMetadata {
    /// The citation sources
    pub citation_sources: Vec<CitationSource>,
}
200
/// Citation source.
///
/// Every field is optional. Field names use camelCase in JSON
/// (e.g. `startIndex`, `publicationDate`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct CitationSource {
    /// The URI of the citation source
    pub uri: Option<String>,
    /// The title of the citation source
    pub title: Option<String>,
    /// The start index of the citation in the response
    pub start_index: Option<i32>,
    /// The end index of the citation in the response
    pub end_index: Option<i32>,
    /// The license of the citation source
    pub license: Option<String>,
    /// The publication date of the citation source
    pub publication_date: Option<String>,
}
218
219/// A candidate response
220#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
221#[serde(rename_all = "camelCase")]
222pub struct Candidate {
223    /// The content of the candidate
224    pub content: Content,
225    /// The safety ratings for the candidate
226    #[serde(skip_serializing_if = "Option::is_none")]
227    pub safety_ratings: Option<Vec<SafetyRating>>,
228    /// The citation metadata for the candidate
229    #[serde(skip_serializing_if = "Option::is_none")]
230    pub citation_metadata: Option<CitationMetadata>,
231    /// The finish reason for the candidate
232    #[serde(skip_serializing_if = "Option::is_none")]
233    pub finish_reason: Option<String>,
234    /// The index of the candidate
235    #[serde(skip_serializing_if = "Option::is_none")]
236    pub index: Option<i32>,
237}
238
/// Metadata about token usage.
///
/// Field names use camelCase in JSON. `prompt_token_count` and
/// `total_token_count` are required when deserializing; the remaining fields
/// are optional and omitted from the serialized output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct UsageMetadata {
    /// The number of prompt tokens
    pub prompt_token_count: i32,
    /// The number of response tokens
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidates_token_count: Option<i32>,
    /// The total number of tokens
    pub total_token_count: i32,
    /// The number of thinking tokens (Gemini 2.5 series only)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thoughts_token_count: Option<i32>,
    /// Detailed prompt token information
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<Vec<PromptTokenDetails>>,
}
257
/// Details about prompt tokens by modality.
///
/// Serialized in camelCase (`modality`, `tokenCount`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct PromptTokenDetails {
    /// The modality (e.g., "TEXT"), kept as the raw string
    pub modality: String,
    /// Token count for this modality
    pub token_count: i32,
}
267
268/// Response from the Gemini API for content generation
269#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
270#[serde(rename_all = "camelCase")]
271pub struct GenerationResponse {
272    /// The candidates generated
273    pub candidates: Vec<Candidate>,
274    /// The prompt feedback
275    #[serde(skip_serializing_if = "Option::is_none")]
276    pub prompt_feedback: Option<PromptFeedback>,
277    /// Usage metadata
278    #[serde(skip_serializing_if = "Option::is_none")]
279    pub usage_metadata: Option<UsageMetadata>,
280    /// Model version used
281    #[serde(skip_serializing_if = "Option::is_none")]
282    pub model_version: Option<String>,
283    /// Response ID
284    #[serde(skip_serializing_if = "Option::is_none")]
285    pub response_id: Option<String>,
286}
287
/// Content of the embedding
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContentEmbedding {
    /// The embedding vector, stored as `f32` components.
    // TODO: consider a quantized representation.
    pub values: Vec<f32>,
}
294
/// Response from the Gemini API for content embedding (a single embedding)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContentEmbeddingResponse {
    /// The embedding generated
    pub embedding: ContentEmbedding,
}
301
/// Response from the Gemini API for batch content embedding (a list of embeddings)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchContentEmbeddingResponse {
    /// The embeddings generated
    pub embeddings: Vec<ContentEmbedding>,
}
308
309/// Feedback about the prompt
310#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
311#[serde(rename_all = "camelCase")]
312pub struct PromptFeedback {
313    /// The safety ratings for the prompt
314    pub safety_ratings: Vec<SafetyRating>,
315    /// The block reason if the prompt was blocked
316    #[serde(skip_serializing_if = "Option::is_none")]
317    pub block_reason: Option<String>,
318}
319
320impl GenerationResponse {
321    /// Get the text of the first candidate
322    pub fn text(&self) -> String {
323        self.candidates
324            .first()
325            .and_then(|c| {
326                c.content.parts.as_ref().and_then(|parts| {
327                    parts.first().and_then(|p| match p {
328                        Part::Text { text, thought: _ } => Some(text.clone()),
329                        _ => None,
330                    })
331                })
332            })
333            .unwrap_or_default()
334    }
335
336    /// Get function calls from the response
337    pub fn function_calls(&self) -> Vec<&super::tools::FunctionCall> {
338        self.candidates
339            .iter()
340            .flat_map(|c| {
341                c.content
342                    .parts
343                    .as_ref()
344                    .map(|parts| {
345                        parts
346                            .iter()
347                            .filter_map(|p| match p {
348                                Part::FunctionCall { function_call } => Some(function_call),
349                                _ => None,
350                            })
351                            .collect::<Vec<_>>()
352                    })
353                    .unwrap_or_default()
354            })
355            .collect()
356    }
357
358    /// Get thought summaries from the response
359    pub fn thoughts(&self) -> Vec<String> {
360        self.candidates
361            .iter()
362            .flat_map(|c| {
363                c.content
364                    .parts
365                    .as_ref()
366                    .map(|parts| {
367                        parts
368                            .iter()
369                            .filter_map(|p| match p {
370                                Part::Text {
371                                    text,
372                                    thought: Some(true),
373                                } => Some(text.clone()),
374                                _ => None,
375                            })
376                            .collect::<Vec<_>>()
377                    })
378                    .unwrap_or_default()
379            })
380            .collect()
381    }
382
383    /// Get all text parts (both regular text and thoughts)
384    pub fn all_text(&self) -> Vec<(String, bool)> {
385        self.candidates
386            .iter()
387            .flat_map(|c| {
388                c.content
389                    .parts
390                    .as_ref()
391                    .map(|parts| {
392                        parts
393                            .iter()
394                            .filter_map(|p| match p {
395                                Part::Text { text, thought } => {
396                                    Some((text.clone(), thought.unwrap_or(false)))
397                                }
398                                _ => None,
399                            })
400                            .collect::<Vec<_>>()
401                    })
402                    .unwrap_or_default()
403            })
404            .collect()
405    }
406}
407
/// Request to generate content.
///
/// Field names use camelCase in JSON. Everything except `contents` is
/// optional and omitted from the serialized request when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentRequest {
    /// The contents to generate content from
    pub contents: Vec<Content>,
    /// The generation config
    #[serde(skip_serializing_if = "Option::is_none")]
    pub generation_config: Option<GenerationConfig>,
    /// The safety settings
    #[serde(skip_serializing_if = "Option::is_none")]
    pub safety_settings: Option<Vec<SafetySetting>>,
    /// The tools that the model can use
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<super::tools::Tool>>,
    /// The tool config
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_config: Option<ToolConfig>,
    /// The system instruction
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_instruction: Option<Content>,
}
430
/// Request to embed content.
///
/// No `rename_all` is applied to this struct, so multi-word fields are
/// serialized with their snake_case names (`task_type`,
/// `output_dimensionality`). Optional fields are omitted when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbedContentRequest {
    /// The specified embedding model
    pub model: String,
    /// The content to generate an embedding for
    pub content: Content,
    /// The embedding task type (optional)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub task_type: Option<TaskType>,
    /// The title of the document (optional)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub title: Option<String>,
    /// The requested output dimensionality (optional)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_dimensionality: Option<i32>,
}
448
/// Request to embed several contents in one call
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchEmbedContentsRequest {
    /// The list of individual embed requests
    pub requests: Vec<EmbedContentRequest>,
}
455
/// Request for batch content generation.
///
/// Top-level wrapper: the whole payload lives under the `batch` key.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct BatchGenerateContentRequest {
    /// The batch configuration
    pub batch: BatchConfig,
}
463
/// Batch configuration.
///
/// Serialized in camelCase (`displayName`, `inputConfig`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct BatchConfig {
    /// Display name for the batch
    pub display_name: String,
    /// Input configuration
    pub input_config: InputConfig,
}
473
/// Input configuration for batch requests.
///
/// Wraps a [`RequestsContainer`] under the `requests` key.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InputConfig {
    /// The requests container
    pub requests: RequestsContainer,
}
481
/// Container for requests.
///
/// Holds the actual list of batch items, again under a `requests` key, so the
/// serialized shape is `requests.requests[...]` inside [`InputConfig`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RequestsContainer {
    /// List of requests
    pub requests: Vec<BatchRequestItem>,
}
489
/// Individual batch request item
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct BatchRequestItem {
    /// The actual request
    pub request: GenerateContentRequest,
    /// Metadata for the request.
    /// There is no `skip_serializing_if` here, so `None` is serialized as `null`.
    pub metadata: Option<RequestMetadata>,
}
499
/// Metadata attached to a single request inside a batch
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RequestMetadata {
    /// Key identifying the request
    pub key: String,
}
507
/// Response from the Gemini API for batch content generation (async batch creation)
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct BatchGenerateContentResponse {
    /// The name/ID of the created batch
    pub name: String,
    /// Metadata about the batch (state, statistics, timestamps)
    pub metadata: BatchMetadata,
}
517
/// Metadata for the batch operation.
///
/// Field names use camelCase in JSON, except `type_annotation`, which is
/// mapped to the literal `@type` key. Every field except `output` is required
/// when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct BatchMetadata {
    /// Type annotation (the JSON `@type` key)
    #[serde(rename = "@type")]
    pub type_annotation: String,
    /// Model used for the batch
    pub model: String,
    /// Display name of the batch
    pub display_name: String,
    /// Creation time, kept as the raw string
    pub create_time: String,
    /// Update time, kept as the raw string
    pub update_time: String,
    /// Batch statistics
    pub batch_stats: BatchStats,
    /// Current state of the batch
    pub state: BatchState,
    /// Name of the batch (duplicate)
    pub name: String,
    /// The output configuration for the batch.
    pub output: Option<OutputConfig>,
}
542
/// The state of a batch operation.
///
/// Variants are (de)serialized in SCREAMING_SNAKE_CASE, e.g.
/// `BatchStatePending` <-> `"BATCH_STATE_PENDING"`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum BatchState {
    /// `BATCH_STATE_UNSPECIFIED`
    BatchStateUnspecified,
    /// `BATCH_STATE_PENDING`
    BatchStatePending,
    /// `BATCH_STATE_RUNNING`
    BatchStateRunning,
    /// `BATCH_STATE_SUCCEEDED`
    BatchStateSucceeded,
    /// `BATCH_STATE_FAILED`
    BatchStateFailed,
    /// `BATCH_STATE_CANCELLED`
    BatchStateCancelled,
    /// `BATCH_STATE_EXPIRED`
    BatchStateExpired,
}
555
/// Statistics for the batch.
///
/// Counts are deserialized through the `from_str_to_i64` /
/// `from_str_to_i64_optional` helpers rather than serde's default integer
/// handling. `request_count` is required; the other counts fall back to
/// `None` when the key is absent (`default`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct BatchStats {
    /// Total number of requests in the batch
    #[serde(deserialize_with = "from_str_to_i64")]
    pub request_count: i64,
    /// Number of pending requests
    #[serde(default, deserialize_with = "from_str_to_i64_optional")]
    pub pending_request_count: Option<i64>,
    /// Number of completed requests
    #[serde(default, deserialize_with = "from_str_to_i64_optional")]
    pub completed_request_count: Option<i64>,
    /// Number of failed requests
    #[serde(default, deserialize_with = "from_str_to_i64_optional")]
    pub failed_request_count: Option<i64>,
    /// Number of successful requests
    #[serde(default, deserialize_with = "from_str_to_i64_optional")]
    pub successful_request_count: Option<i64>,
}
576
577fn from_str_to_i64<'de, D>(deserializer: D) -> Result<i64, D::Error>
578where
579    D: serde::Deserializer<'de>,
580{
581    let s: String = serde::Deserialize::deserialize(deserializer)?;
582    s.parse::<i64>().map_err(serde::de::Error::custom)
583}
584
585fn from_str_to_i64_optional<'de, D>(deserializer: D) -> Result<Option<i64>, D::Error>
586where
587    D: serde::Deserializer<'de>,
588{
589    match Option::<String>::deserialize(deserializer)? {
590        Some(s) => s.parse::<i64>().map(Some).map_err(serde::de::Error::custom),
591        None => Ok(None),
592    }
593}
594
/// Configuration for thinking (Gemini 2.5 series only).
///
/// Serialized in camelCase (`thinkingBudget`, `includeThoughts`); fields left
/// as `None` are omitted from the output.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ThinkingConfig {
    /// The thinking budget (number of thinking tokens)
    ///
    /// - Set to 0 to disable thinking
    /// - Set to -1 for dynamic thinking (model decides)
    /// - Set to a positive number for a specific token budget
    ///
    /// Model-specific ranges:
    /// - 2.5 Pro: 128 to 32768 (cannot disable thinking)
    /// - 2.5 Flash: 0 to 24576
    /// - 2.5 Flash Lite: 512 to 24576
    ///
    /// These ranges are not validated by this type.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking_budget: Option<i32>,

    /// Whether to include thought summaries in the response
    ///
    /// When enabled, the response will include synthesized versions of the model's
    /// raw thoughts, providing insights into the reasoning process.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub include_thoughts: Option<bool>,
}
619
620impl ThinkingConfig {
621    /// Create a new thinking config with default settings
622    pub fn new() -> Self {
623        Self {
624            thinking_budget: None,
625            include_thoughts: None,
626        }
627    }
628
629    /// Set the thinking budget
630    pub fn with_thinking_budget(mut self, budget: i32) -> Self {
631        self.thinking_budget = Some(budget);
632        self
633    }
634
635    /// Enable dynamic thinking (model decides the budget)
636    pub fn with_dynamic_thinking(mut self) -> Self {
637        self.thinking_budget = Some(-1);
638        self
639    }
640
641    /// Include thought summaries in the response
642    pub fn with_thoughts_included(mut self, include: bool) -> Self {
643        self.include_thoughts = Some(include);
644        self
645    }
646}
647
648impl Default for ThinkingConfig {
649    fn default() -> Self {
650        Self::new()
651    }
652}
653
/// Configuration for generation.
///
/// Field names use camelCase in JSON. Every field is optional and is omitted
/// from the serialized output when `None`. Value ranges mentioned below are
/// not validated by this type.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerationConfig {
    /// The temperature for the model (0.0 to 1.0)
    ///
    /// Controls the randomness of the output. Higher values (e.g., 0.9) make output
    /// more random, lower values (e.g., 0.1) make output more deterministic.
    // NOTE(review): the upper bound quoted above may differ per model - confirm
    // against the API reference.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,

    /// The top-p value for the model (0.0 to 1.0)
    ///
    /// For each token generation step, the model considers the top_p percentage of
    /// probability mass for potential token choices. Lower values are more selective,
    /// higher values allow more variety.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,

    /// The top-k value for the model
    ///
    /// For each token generation step, the model considers the top_k most likely tokens.
    /// Lower values are more selective, higher values allow more variety.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<i32>,

    /// The maximum number of tokens to generate
    ///
    /// Limits the length of the generated content. One token is roughly 4 characters.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_output_tokens: Option<i32>,

    /// The candidate count
    ///
    /// Number of alternative responses to generate.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidate_count: Option<i32>,

    /// Sequences that stop generation
    ///
    /// The model will stop generating content when it encounters any of these sequences.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_sequences: Option<Vec<String>>,

    /// The response mime type
    ///
    /// Specifies the format of the model's response.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_mime_type: Option<String>,

    /// The response schema
    ///
    /// Specifies the JSON schema for structured responses.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_schema: Option<serde_json::Value>,

    /// The thinking configuration
    ///
    /// Configuration for the model's thinking process (Gemini 2.5 series only).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking_config: Option<ThinkingConfig>,
}
715
716impl Default for GenerationConfig {
717    fn default() -> Self {
718        Self {
719            temperature: Some(0.7),
720            top_p: Some(0.95),
721            top_k: Some(40),
722            max_output_tokens: Some(1024),
723            candidate_count: Some(1),
724            stop_sequences: None,
725            response_mime_type: None,
726            response_schema: None,
727            thinking_config: None,
728        }
729    }
730}
731
/// Configuration for tools.
///
/// No `rename_all` is applied to this struct, so the field is serialized with
/// its snake_case name (`function_calling_config`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolConfig {
    /// The function calling config (omitted from the output when `None`)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function_calling_config: Option<FunctionCallingConfig>,
}
739
/// Configuration for function calling
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCallingConfig {
    /// The mode for function calling (see [`FunctionCallingMode`])
    pub mode: FunctionCallingMode,
}
746
/// Mode for function calling.
///
/// Variants are (de)serialized in SCREAMING_SNAKE_CASE: `"AUTO"`, `"ANY"`,
/// `"NONE"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum FunctionCallingMode {
    /// The model may use function calling
    Auto,
    /// The model must use function calling
    Any,
    /// The model must not use function calling.
    /// Within this enum's scope the bare name `None` refers to this variant,
    /// not to `Option::None`.
    None,
}
758
/// Setting for safety: pairs a harm category with the threshold to apply to it
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SafetySetting {
    /// The category of content to filter
    pub category: HarmCategory,
    /// The threshold for filtering
    pub threshold: HarmBlockThreshold,
}
767
768/// Category of harmful content
769#[derive(Debug, Clone, Serialize, Deserialize)]
770#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
771pub enum HarmCategory {
772    /// Dangerous content
773    Dangerous,
774    /// Harassment content
775    Harassment,
776    /// Hate speech
777    HateSpeech,
778    /// Sexually explicit content
779    SexuallyExplicit,
780}
781
/// Threshold for blocking harmful content.
///
/// Variants are (de)serialized in SCREAMING_SNAKE_CASE, e.g.
/// `BlockLowAndAbove` <-> `"BLOCK_LOW_AND_ABOVE"`.
// Every variant shares the `Block` prefix on purpose (it mirrors the wire
// names), hence the clippy allowance.
#[allow(clippy::enum_variant_names)]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum HarmBlockThreshold {
    /// Block content with low or higher probability of harm
    BlockLowAndAbove,
    /// Block content with medium or higher probability of harm
    BlockMediumAndAbove,
    /// Block content with high or higher probability of harm.
    // NOTE(review): serializes as "BLOCK_HIGH_AND_ABOVE" - confirm the API
    // accepts this value; it looks equivalent to `BlockOnlyHigh`.
    BlockHighAndAbove,
    /// Block only content with high probability of harm
    BlockOnlyHigh,
    /// Never block content
    BlockNone,
}
798
/// Embedding task types.
///
/// Variants are (de)serialized in SCREAMING_SNAKE_CASE, e.g.
/// `RetrievalDocument` <-> `"RETRIEVAL_DOCUMENT"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum TaskType {
    /// Used to generate embeddings that are optimized to assess text similarity
    SemanticSimilarity,
    /// Used to generate embeddings that are optimized to classify texts according to preset labels
    Classification,
    /// Used to generate embeddings that are optimized to cluster texts based on their similarities
    Clustering,

    /// Used to generate embeddings that are optimized for document search or information retrieval.
    RetrievalDocument,
    /// Serialized as `RETRIEVAL_QUERY`; presumably the query-side counterpart of
    /// `RetrievalDocument` - confirm against the API reference.
    RetrievalQuery,
    /// Serialized as `QUESTION_ANSWERING`.
    QuestionAnswering,
    /// Serialized as `FACT_VERIFICATION`.
    FactVerification,

    /// Used to retrieve a code block based on a natural language query, such as sort an array or reverse a linked list.
    /// Embeddings of the code blocks are computed using RETRIEVAL_DOCUMENT.
    CodeRetrievalQuery,
}
820
/// Represents the overall status of a batch operation.
///
/// There is no failure variant: `BatchStatus::from_operation` reports a
/// failed operation as an `Err` instead.
#[derive(Debug, Clone, PartialEq)]
pub enum BatchStatus {
    /// The operation is waiting to be processed.
    Pending,
    /// The operation is currently being processed.
    Running {
        /// Number of requests still pending
        pending_count: i64,
        /// Number of requests completed so far
        completed_count: i64,
        /// Number of requests that failed so far
        failed_count: i64,
        /// Total number of requests in the batch
        total_count: i64,
    },
    /// The operation has completed successfully.
    Succeeded {
        /// Per-request results (may be empty)
        results: Vec<BatchResultItem>,
    },
    /// The operation was cancelled by the user.
    Cancelled,
    /// The operation has expired.
    Expired,
}
840
841impl BatchStatus {
842    /// Creates a `BatchStatus` from a `BatchOperation` response.
843    pub(crate) fn from_operation(operation: BatchOperation) -> crate::Result<Self> {
844        if operation.done {
845            // The operation is complete. Determine the final state.
846            match operation.result {
847                Some(OperationResult::Failure { error }) => {
848                    Err(crate::Error::BatchFailed {
849                        name: operation.name,
850                        error,
851                    })
852                }
853                Some(OperationResult::Success { response }) => {
854                    let mut results: Vec<BatchResultItem> = response
855                        .inlined_responses
856                        .inlined_responses
857                        .into_iter()
858                        .map(|item| match item {
859                            BatchGenerateContentResponseItem::Success { response, metadata } => {
860                                BatchResultItem::Success {
861                                    key: metadata.key,
862                                    response,
863                                }
864                            }
865                            BatchGenerateContentResponseItem::Error { error, metadata } => {
866                                BatchResultItem::Error {
867                                    key: metadata.key,
868                                    error,
869                                }
870                            }
871                        })
872                        .collect();
873
874                    // Sort results by key to ensure a consistent order.
875                    results.sort_by_key(|item| {
876                        let key_str = match item {
877                            BatchResultItem::Success { key, .. } => key,
878                            BatchResultItem::Error { key, .. } => key,
879                        };
880                        key_str.parse::<usize>().unwrap_or(usize::MAX)
881                    });
882
883                    Ok(BatchStatus::Succeeded { results })
884                }
885                // If `done` is true with no error, a response is expected for success.
886                // If not, it might be a successful cancellation or an inconsistent state.
887                None => match operation.metadata.state {
888                    BatchState::BatchStateCancelled => Ok(BatchStatus::Cancelled),
889                    BatchState::BatchStateExpired => Ok(BatchStatus::Expired),
890                    BatchState::BatchStateSucceeded => Ok(BatchStatus::Succeeded { results: vec![] }), // Succeeded but with no data
891                    _ => Err(crate::Error::InconsistentBatchState {
892                        description: format!(
893                            "Operation is done but has no response or error. Final state is ambiguous: {:?}.",
894                            operation.metadata.state
895                        ),
896                    }),
897                },
898            }
899        } else {
900            // The operation is still in progress.
901            match operation.metadata.state {
902                BatchState::BatchStatePending => Ok(BatchStatus::Pending),
903                BatchState::BatchStateRunning => {
904                    let total_count = operation.metadata.batch_stats.request_count;
905                    let pending_count = operation
906                        .metadata
907                        .batch_stats
908                        .pending_request_count
909                        .unwrap_or(total_count); // Assume all are pending if not specified
910                    let completed_count = operation
911                        .metadata
912                        .batch_stats
913                        .completed_request_count
914                        .unwrap_or(0);
915                    let failed_count = operation
916                        .metadata
917                        .batch_stats
918                        .failed_request_count
919                        .unwrap_or(0);
920                    Ok(BatchStatus::Running {
921                        pending_count,
922                        completed_count,
923                        failed_count,
924                        total_count,
925                    })
926                }
927                // Any other state is inconsistent with `done: false`.
928                terminal_state => Err(crate::Error::InconsistentBatchState {
929                    description: format!(
930                        "Operation is not done, but API reported a terminal state: {:?}.",
931                        terminal_state
932                    ),
933                }),
934            }
935        }
936    }
937}
938
939/// Represents a long-running operation from the Gemini API.
940#[derive(Debug, Serialize, Deserialize)]
941pub struct BatchOperation {
942    pub name: String,
943    pub metadata: BatchMetadata,
944    #[serde(default)]
945    pub done: bool,
946    #[serde(flatten)]
947    pub result: Option<OperationResult>,
948}
949
950/// Represents the result of a completed batch operation, which is either a response or an error.
951#[derive(Debug, Serialize, Deserialize)]
952#[serde(untagged)]
953pub enum OperationResult {
954    Success { response: BatchOperationResponse },
955    Failure { error: crate::error::OperationError },
956}
957
958/// Represents the response of a batch operation.
959#[derive(Debug, Clone, Serialize, Deserialize)]
960#[serde(rename_all = "camelCase")]
961pub struct BatchOperationResponse {
962    #[serde(rename = "@type")]
963    pub type_annotation: String,
964    pub inlined_responses: InlinedResponses,
965}
966
967/// Represents the output configuration of a batch operation.
968#[derive(Debug, Clone, Serialize, Deserialize)]
969#[serde(rename_all = "camelCase")]
970pub struct OutputConfig {
971    pub inlined_responses: InlinedResponses,
972}
973
974/// A container for inlined responses.
975#[derive(Debug, Clone, Serialize, Deserialize)]
976#[serde(rename_all = "camelCase")]
977pub struct InlinedResponses {
978    pub inlined_responses: Vec<BatchGenerateContentResponseItem>,
979}
980
981/// An item in a batch generate content response.
982#[derive(Debug, Clone, Serialize, Deserialize)]
983#[serde(rename_all = "camelCase")]
984#[serde(untagged)]
985pub enum BatchGenerateContentResponseItem {
986    Success {
987        response: GenerationResponse,
988        metadata: RequestMetadata,
989    },
990    Error {
991        error: IndividualRequestError,
992        metadata: RequestMetadata,
993    },
994}
995
996/// An error for an individual request in a batch.
997#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
998pub struct IndividualRequestError {
999    pub code: i32,
1000    pub message: String,
1001    /// Additional details about the error
1002    #[serde(skip_serializing_if = "Option::is_none")]
1003    pub details: Option<serde_json::Value>,
1004}
1005
1006/// The outcome of a single request in a batch operation.
1007#[derive(Debug, Clone, PartialEq)]
1008pub enum BatchResultItem {
1009    Success {
1010        key: String,
1011        response: GenerationResponse,
1012    },
1013    Error {
1014        key: String,
1015        error: IndividualRequestError,
1016    },
1017}
1018
1019/// Response from the Gemini API for listing batch operations.
1020#[derive(Debug, serde::Deserialize)]
1021#[serde(rename_all = "camelCase")]
1022pub struct ListBatchesResponse {
1023    /// A list of batch operations.
1024    pub operations: Vec<BatchOperation>,
1025    /// A token to retrieve the next page of results.
1026    pub next_page_token: Option<String>,
1027}