//! Request and response data models for the Gemini API client (`gemini_rust/models.rs`).

1use serde::{Deserialize, Serialize};
2
3/// Role of a message in a conversation
4#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
5#[serde(rename_all = "lowercase")]
6pub enum Role {
7    /// Message from the user
8    User,
9    /// Message from the model
10    Model,
11}
12
13/// Content part that can be included in a message
14#[derive(Debug, Clone, Serialize, Deserialize)]
15#[serde(untagged)]
16pub enum Part {
17    /// Text content
18    Text {
19        /// The text content
20        text: String,
21    },
22    InlineData {
23        /// The blob data
24        #[serde(rename = "inlineData")]
25        inline_data: Blob,
26    },
27    /// Function call from the model
28    FunctionCall {
29        /// The function call details
30        #[serde(rename = "functionCall")]
31        function_call: super::tools::FunctionCall,
32    },
33    /// Function response (results from executing a function call)
34    FunctionResponse {
35        /// The function response details
36        #[serde(rename = "functionResponse")]
37        function_response: super::tools::FunctionResponse,
38    },
39}
40
41/// Blob for a message part
42#[derive(Debug, Clone, Serialize, Deserialize)]
43#[serde(rename_all = "camelCase")]
44pub struct Blob {
45    pub mime_type: String,
46    pub data: String, // Base64 encoded data
47}
48
49impl Blob {
50    /// Create a new blob with mime type and data
51    pub fn new(mime_type: impl Into<String>, data: impl Into<String>) -> Self {
52        Self {
53            mime_type: mime_type.into(),
54            data: data.into(),
55        }
56    }
57}
58
59/// Content of a message
60#[derive(Debug, Default, Clone, Serialize, Deserialize)]
61pub struct Content {
62    /// Parts of the content
63    pub parts: Vec<Part>,
64    /// Role of the content
65    #[serde(skip_serializing_if = "Option::is_none")]
66    pub role: Option<Role>,
67}
68
69impl Content {
70    /// Create a new text content
71    pub fn text(text: impl Into<String>) -> Self {
72        Self {
73            parts: vec![Part::Text { text: text.into() }],
74            role: None,
75        }
76    }
77
78    /// Create a new content with a function call
79    pub fn function_call(function_call: super::tools::FunctionCall) -> Self {
80        Self {
81            parts: vec![Part::FunctionCall { function_call }],
82            role: None,
83        }
84    }
85
86    /// Create a new content with a function response
87    pub fn function_response(function_response: super::tools::FunctionResponse) -> Self {
88        Self {
89            parts: vec![Part::FunctionResponse { function_response }],
90            role: None,
91        }
92    }
93
94    /// Create a new content with a function response from name and JSON value
95    pub fn function_response_json(name: impl Into<String>, response: serde_json::Value) -> Self {
96        Self {
97            parts: vec![Part::FunctionResponse {
98                function_response: super::tools::FunctionResponse::new(name, response),
99            }],
100            role: None,
101        }
102    }
103
104    /// Create a new content with inline data (blob data)
105    pub fn inline_data(mime_type: impl Into<String>, data: impl Into<String>) -> Self {
106        Self {
107            parts: vec![Part::InlineData {
108                inline_data: Blob::new(mime_type, data),
109            }],
110            role: None,
111        }
112    }
113
114    /// Add a role to this content
115    pub fn with_role(mut self, role: Role) -> Self {
116        self.role = Some(role);
117        self
118    }
119}
120
121/// Message in a conversation
122#[derive(Debug, Clone, Serialize, Deserialize)]
123pub struct Message {
124    /// Content of the message
125    pub content: Content,
126    /// Role of the message
127    pub role: Role,
128}
129
130impl Message {
131    /// Create a new user message with text content
132    pub fn user(text: impl Into<String>) -> Self {
133        Self {
134            content: Content::text(text).with_role(Role::User),
135            role: Role::User,
136        }
137    }
138
139    /// Create a new model message with text content
140    pub fn model(text: impl Into<String>) -> Self {
141        Self {
142            content: Content::text(text).with_role(Role::Model),
143            role: Role::Model,
144        }
145    }
146
147    pub fn embed(text: impl Into<String>) -> Self {
148        Self {
149            content: Content::text(text),
150            role: Role::Model,
151        }
152    }
153
154    /// Create a new function message with function response content from JSON
155    pub fn function(name: impl Into<String>, response: serde_json::Value) -> Self {
156        Self {
157            content: Content::function_response_json(name, response).with_role(Role::Model),
158            role: Role::Model,
159        }
160    }
161
162    /// Create a new function message with function response from a JSON string
163    pub fn function_str(
164        name: impl Into<String>,
165        response: impl Into<String>,
166    ) -> Result<Self, serde_json::Error> {
167        let response_str = response.into();
168        let json = serde_json::from_str(&response_str)?;
169        Ok(Self {
170            content: Content::function_response_json(name, json).with_role(Role::Model),
171            role: Role::Model,
172        })
173    }
174}
175
176/// Safety rating for content
177#[derive(Debug, Clone, Serialize, Deserialize)]
178pub struct SafetyRating {
179    /// The category of the safety rating
180    pub category: String,
181    /// The probability that the content is harmful
182    pub probability: String,
183}
184
185/// Citation metadata for content
186#[derive(Debug, Clone, Serialize, Deserialize)]
187#[serde(rename_all = "camelCase")]
188pub struct CitationMetadata {
189    /// The citation sources
190    pub citation_sources: Vec<CitationSource>,
191}
192
193/// Citation source
194#[derive(Debug, Clone, Serialize, Deserialize)]
195#[serde(rename_all = "camelCase")]
196pub struct CitationSource {
197    /// The URI of the citation source
198    pub uri: Option<String>,
199    /// The title of the citation source
200    pub title: Option<String>,
201    /// The start index of the citation in the response
202    pub start_index: Option<i32>,
203    /// The end index of the citation in the response
204    pub end_index: Option<i32>,
205    /// The license of the citation source
206    pub license: Option<String>,
207    /// The publication date of the citation source
208    pub publication_date: Option<String>,
209}
210
211/// A candidate response
212#[derive(Debug, Clone, Serialize, Deserialize)]
213#[serde(rename_all = "camelCase")]
214pub struct Candidate {
215    /// The content of the candidate
216    pub content: Content,
217    /// The safety ratings for the candidate
218    #[serde(skip_serializing_if = "Option::is_none")]
219    pub safety_ratings: Option<Vec<SafetyRating>>,
220    /// The citation metadata for the candidate
221    #[serde(skip_serializing_if = "Option::is_none")]
222    pub citation_metadata: Option<CitationMetadata>,
223    /// The finish reason for the candidate
224    #[serde(skip_serializing_if = "Option::is_none")]
225    pub finish_reason: Option<String>,
226    /// The tokens used in the response
227    #[serde(skip_serializing_if = "Option::is_none")]
228    pub usage_metadata: Option<UsageMetadata>,
229}
230
231/// Metadata about token usage
232#[derive(Debug, Clone, Serialize, Deserialize)]
233#[serde(rename_all = "camelCase")]
234pub struct UsageMetadata {
235    /// The number of prompt tokens
236    pub prompt_token_count: i32,
237    /// The number of response tokens
238    pub candidates_token_count: i32,
239    /// The total number of tokens
240    pub total_token_count: i32,
241}
242
243/// Response from the Gemini API for content generation
244#[derive(Debug, Clone, Serialize, Deserialize)]
245#[serde(rename_all = "camelCase")]
246pub struct GenerationResponse {
247    /// The candidates generated
248    pub candidates: Vec<Candidate>,
249    /// The prompt feedback
250    #[serde(skip_serializing_if = "Option::is_none")]
251    pub prompt_feedback: Option<PromptFeedback>,
252    /// Usage metadata
253    #[serde(skip_serializing_if = "Option::is_none")]
254    pub usage_metadata: Option<UsageMetadata>,
255}
256
257/// Content of the embedding
258#[derive(Debug, Clone, Serialize, Deserialize)]
259pub struct ContentEmbedding {
260    /// The values generated
261    pub values: Vec<f32>, //Maybe Quantize this
262}
263
264/// Response from the Gemini API for content embedding
265#[derive(Debug, Clone, Serialize, Deserialize)]
266pub struct ContentEmbeddingResponse {
267    /// The embeddings generated
268    pub embedding: ContentEmbedding,
269}
270
271/// Response from the Gemini API for batch content embedding
272#[derive(Debug, Clone, Serialize, Deserialize)]
273pub struct BatchContentEmbeddingResponse {
274    /// The embeddings generated
275    pub embeddings: Vec<ContentEmbedding>,
276}
277
278/// Feedback about the prompt
279#[derive(Debug, Clone, Serialize, Deserialize)]
280#[serde(rename_all = "camelCase")]
281pub struct PromptFeedback {
282    /// The safety ratings for the prompt
283    pub safety_ratings: Vec<SafetyRating>,
284    /// The block reason if the prompt was blocked
285    #[serde(skip_serializing_if = "Option::is_none")]
286    pub block_reason: Option<String>,
287}
288
289impl GenerationResponse {
290    /// Get the text of the first candidate
291    pub fn text(&self) -> String {
292        self.candidates
293            .first()
294            .and_then(|c| {
295                c.content.parts.first().and_then(|p| match p {
296                    Part::Text { text } => Some(text.clone()),
297                    _ => None,
298                })
299            })
300            .unwrap_or_default()
301    }
302
303    /// Get function calls from the response
304    pub fn function_calls(&self) -> Vec<&super::tools::FunctionCall> {
305        self.candidates
306            .iter()
307            .flat_map(|c| {
308                c.content.parts.iter().filter_map(|p| match p {
309                    Part::FunctionCall { function_call } => Some(function_call),
310                    _ => None,
311                })
312            })
313            .collect()
314    }
315}
316
317/// Request to generate content
318#[derive(Debug, Clone, Serialize, Deserialize)]
319pub struct GenerateContentRequest {
320    /// The contents to generate content from
321    pub contents: Vec<Content>,
322    /// The generation config
323    #[serde(skip_serializing_if = "Option::is_none")]
324    pub generation_config: Option<GenerationConfig>,
325    /// The safety settings
326    #[serde(skip_serializing_if = "Option::is_none")]
327    pub safety_settings: Option<Vec<SafetySetting>>,
328    /// The tools that the model can use
329    #[serde(skip_serializing_if = "Option::is_none")]
330    pub tools: Option<Vec<super::tools::Tool>>,
331    /// The tool config
332    #[serde(skip_serializing_if = "Option::is_none")]
333    pub tool_config: Option<ToolConfig>,
334    /// The system instruction
335    #[serde(skip_serializing_if = "Option::is_none")]
336    pub system_instruction: Option<Content>,
337}
338
339/// Request to embed words
340#[derive(Debug, Clone, Serialize, Deserialize)]
341pub struct EmbedContentRequest {
342    /// The specified embedding model
343    pub model: String,
344    /// The chunks content to generate embeddings
345    pub content: Content,
346    /// The embedding task type (optional)
347    #[serde(skip_serializing_if = "Option::is_none")]
348    pub task_type: Option<TaskType>,
349    /// The title of the document (optional)
350    #[serde(skip_serializing_if = "Option::is_none")]
351    pub title: Option<String>,
352    /// The output_dimensionality (optional)
353    #[serde(skip_serializing_if = "Option::is_none")]
354    pub output_dimensionality: Option<i32>,
355}
356
357/// Request to batch embed requests
358#[derive(Debug, Clone, Serialize, Deserialize)]
359pub struct BatchEmbedContentsRequest {
360    /// The list of embed requests
361    pub requests: Vec<EmbedContentRequest>,
362}
363
364/// Configuration for generation
365#[derive(Debug, Clone, Serialize, Deserialize)]
366pub struct GenerationConfig {
367    /// The temperature for the model (0.0 to 1.0)
368    ///
369    /// Controls the randomness of the output. Higher values (e.g., 0.9) make output
370    /// more random, lower values (e.g., 0.1) make output more deterministic.
371    #[serde(skip_serializing_if = "Option::is_none")]
372    pub temperature: Option<f32>,
373
374    /// The top-p value for the model (0.0 to 1.0)
375    ///
376    /// For each token generation step, the model considers the top_p percentage of
377    /// probability mass for potential token choices. Lower values are more selective,
378    /// higher values allow more variety.
379    #[serde(skip_serializing_if = "Option::is_none")]
380    pub top_p: Option<f32>,
381
382    /// The top-k value for the model
383    ///
384    /// For each token generation step, the model considers the top_k most likely tokens.
385    /// Lower values are more selective, higher values allow more variety.
386    #[serde(skip_serializing_if = "Option::is_none")]
387    pub top_k: Option<i32>,
388
389    /// The maximum number of tokens to generate
390    ///
391    /// Limits the length of the generated content. One token is roughly 4 characters.
392    #[serde(skip_serializing_if = "Option::is_none")]
393    pub max_output_tokens: Option<i32>,
394
395    /// The candidate count
396    ///
397    /// Number of alternative responses to generate.
398    #[serde(skip_serializing_if = "Option::is_none")]
399    pub candidate_count: Option<i32>,
400
401    /// Whether to stop on specific sequences
402    ///
403    /// The model will stop generating content when it encounters any of these sequences.
404    #[serde(skip_serializing_if = "Option::is_none")]
405    pub stop_sequences: Option<Vec<String>>,
406
407    /// The response mime type
408    ///
409    /// Specifies the format of the model's response.
410    #[serde(skip_serializing_if = "Option::is_none")]
411    pub response_mime_type: Option<String>,
412
413    /// The response schema
414    ///
415    /// Specifies the JSON schema for structured responses.
416    #[serde(skip_serializing_if = "Option::is_none")]
417    pub response_schema: Option<serde_json::Value>,
418}
419
420impl Default for GenerationConfig {
421    fn default() -> Self {
422        Self {
423            temperature: Some(0.7),
424            top_p: Some(0.95),
425            top_k: Some(40),
426            max_output_tokens: Some(1024),
427            candidate_count: Some(1),
428            stop_sequences: None,
429            response_mime_type: None,
430            response_schema: None,
431        }
432    }
433}
434
435/// Configuration for tools
436#[derive(Debug, Clone, Serialize, Deserialize)]
437pub struct ToolConfig {
438    /// The function calling config
439    #[serde(skip_serializing_if = "Option::is_none")]
440    pub function_calling_config: Option<FunctionCallingConfig>,
441}
442
443/// Configuration for function calling
444#[derive(Debug, Clone, Serialize, Deserialize)]
445pub struct FunctionCallingConfig {
446    /// The mode for function calling
447    pub mode: FunctionCallingMode,
448}
449
450/// Mode for function calling
451#[derive(Debug, Clone, Serialize, Deserialize)]
452#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
453pub enum FunctionCallingMode {
454    /// The model may use function calling
455    Auto,
456    /// The model must use function calling
457    Any,
458    /// The model must not use function calling
459    None,
460}
461
462/// Setting for safety
463#[derive(Debug, Clone, Serialize, Deserialize)]
464pub struct SafetySetting {
465    /// The category of content to filter
466    pub category: HarmCategory,
467    /// The threshold for filtering
468    pub threshold: HarmBlockThreshold,
469}
470
471/// Category of harmful content
472#[derive(Debug, Clone, Serialize, Deserialize)]
473#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
474pub enum HarmCategory {
475    /// Dangerous content
476    Dangerous,
477    /// Harassment content
478    Harassment,
479    /// Hate speech
480    HateSpeech,
481    /// Sexually explicit content
482    SexuallyExplicit,
483}
484
485/// Threshold for blocking harmful content
486#[allow(clippy::enum_variant_names)]
487#[derive(Debug, Clone, Serialize, Deserialize)]
488#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
489pub enum HarmBlockThreshold {
490    /// Block content with low probability of harm
491    BlockLowAndAbove,
492    /// Block content with medium probability of harm
493    BlockMediumAndAbove,
494    /// Block content with high probability of harm
495    BlockHighAndAbove,
496    /// Block content with maximum probability of harm
497    BlockOnlyHigh,
498    /// Never block content
499    BlockNone,
500}
501
502/// Embedding Task types
503#[derive(Debug, Clone, Serialize, Deserialize)]
504#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
505pub enum TaskType {
506    ///Used to generate embeddings that are optimized to assess text similarity
507    SemanticSimilarity,
508    ///Used to generate embeddings that are optimized to classify texts according to preset labels
509    Classification,
510    ///Used to generate embeddings that are optimized to cluster texts based on their similarities
511    Clustering,
512
513    ///Used to generate embeddings that are optimized for document search or information retrieval.
514    RetrievalDocument,
515    RetrievalQuery,
516    QuestionAnswering,
517    FactVerification,
518
519    /// Used to retrieve a code block based on a natural language query, such as sort an array or reverse a linked list.
520    /// Embeddings of the code blocks are computed using RETRIEVAL_DOCUMENT.
521    CodeRetrievalQuery,
522}