gemini_rust/
models.rs

1use serde::{Deserialize, Serialize};
2
3/// Role of a message in a conversation
4#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
5#[serde(rename_all = "lowercase")]
6pub enum Role {
7    /// Message from the user
8    User,
9    /// Message from the model
10    Model,
11    /// Function response
12    Function,
13}
14
15/// Content part that can be included in a message
16#[derive(Debug, Clone, Serialize, Deserialize)]
17#[serde(untagged)]
18pub enum Part {
19    /// Text content
20    Text {
21        /// The text content
22        text: String,
23    },
24    /// Function call from the model
25    FunctionCall {
26        /// The function call details
27        #[serde(rename = "functionCall")]
28        function_call: super::tools::FunctionCall,
29    },
30    /// Function response (results from executing a function call)
31    FunctionResponse {
32        /// The function response details
33        #[serde(rename = "functionResponse")]
34        function_response: super::tools::FunctionResponse,
35    },
36}
37
38/// Content of a message
39#[derive(Debug, Default, Clone, Serialize, Deserialize)]
40pub struct Content {
41    /// Parts of the content
42    pub parts: Vec<Part>,
43    /// Role of the content
44    #[serde(skip_serializing_if = "Option::is_none")]
45    pub role: Option<Role>,
46}
47
48impl Content {
49    /// Create a new text content
50    pub fn text(text: impl Into<String>) -> Self {
51        Self {
52            parts: vec![Part::Text { text: text.into() }],
53            role: None,
54        }
55    }
56
57    /// Create a new content with a function call
58    pub fn function_call(function_call: super::tools::FunctionCall) -> Self {
59        Self {
60            parts: vec![Part::FunctionCall { function_call }],
61            role: None,
62        }
63    }
64
65    /// Create a new content with a function response
66    pub fn function_response(function_response: super::tools::FunctionResponse) -> Self {
67        Self {
68            parts: vec![Part::FunctionResponse { function_response }],
69            role: None,
70        }
71    }
72
73    /// Create a new content with a function response from name and JSON value
74    pub fn function_response_json(name: impl Into<String>, response: serde_json::Value) -> Self {
75        Self {
76            parts: vec![Part::FunctionResponse {
77                function_response: super::tools::FunctionResponse::new(name, response),
78            }],
79            role: None,
80        }
81    }
82
83    /// Add a role to this content
84    pub fn with_role(mut self, role: Role) -> Self {
85        self.role = Some(role);
86        self
87    }
88}
89
90/// Message in a conversation
91#[derive(Debug, Clone, Serialize, Deserialize)]
92pub struct Message {
93    /// Content of the message
94    pub content: Content,
95    /// Role of the message
96    pub role: Role,
97}
98
99impl Message {
100    /// Create a new user message with text content
101    pub fn user(text: impl Into<String>) -> Self {
102        Self {
103            content: Content::text(text),
104            role: Role::User,
105        }
106    }
107
108    /// Create a new model message with text content
109    pub fn model(text: impl Into<String>) -> Self {
110        Self {
111            content: Content::text(text),
112            role: Role::Model,
113        }
114    }
115
116    /// Create a new function message with function response content from JSON
117    pub fn function(name: impl Into<String>, response: serde_json::Value) -> Self {
118        Self {
119            content: Content::function_response_json(name, response),
120            role: Role::Function,
121        }
122    }
123
124    /// Create a new function message with function response from a JSON string
125    pub fn function_str(
126        name: impl Into<String>,
127        response: impl Into<String>,
128    ) -> Result<Self, serde_json::Error> {
129        let response_str = response.into();
130        let json = serde_json::from_str(&response_str)?;
131        Ok(Self {
132            content: Content::function_response_json(name, json),
133            role: Role::Function,
134        })
135    }
136}
137
138/// Safety rating for content
139#[derive(Debug, Clone, Serialize, Deserialize)]
140pub struct SafetyRating {
141    /// The category of the safety rating
142    pub category: String,
143    /// The probability that the content is harmful
144    pub probability: String,
145}
146
147/// Citation metadata for content
148#[derive(Debug, Clone, Serialize, Deserialize)]
149pub struct CitationMetadata {
150    /// The citation sources
151    pub citation_sources: Vec<CitationSource>,
152}
153
154/// Citation source
155#[derive(Debug, Clone, Serialize, Deserialize)]
156pub struct CitationSource {
157    /// The URI of the citation source
158    pub uri: Option<String>,
159    /// The title of the citation source
160    pub title: Option<String>,
161    /// The start index of the citation in the response
162    pub start_index: Option<i32>,
163    /// The end index of the citation in the response
164    pub end_index: Option<i32>,
165    /// The license of the citation source
166    pub license: Option<String>,
167    /// The publication date of the citation source
168    pub publication_date: Option<String>,
169}
170
171/// A candidate response
172#[derive(Debug, Clone, Serialize, Deserialize)]
173pub struct Candidate {
174    /// The content of the candidate
175    pub content: Content,
176    /// The safety ratings for the candidate
177    #[serde(skip_serializing_if = "Option::is_none")]
178    pub safety_ratings: Option<Vec<SafetyRating>>,
179    /// The citation metadata for the candidate
180    #[serde(skip_serializing_if = "Option::is_none")]
181    pub citation_metadata: Option<CitationMetadata>,
182    /// The finish reason for the candidate
183    #[serde(skip_serializing_if = "Option::is_none")]
184    pub finish_reason: Option<String>,
185    /// The tokens used in the response
186    #[serde(skip_serializing_if = "Option::is_none")]
187    pub usage_metadata: Option<UsageMetadata>,
188}
189
190/// Metadata about token usage
191#[derive(Debug, Clone, Serialize, Deserialize)]
192pub struct UsageMetadata {
193    /// The number of prompt tokens
194    pub prompt_token_count: i32,
195    /// The number of response tokens
196    pub candidates_token_count: i32,
197    /// The total number of tokens
198    pub total_token_count: i32,
199}
200
201/// Response from the Gemini API for content generation
202#[derive(Debug, Clone, Serialize, Deserialize)]
203pub struct GenerationResponse {
204    /// The candidates generated
205    pub candidates: Vec<Candidate>,
206    /// The prompt feedback
207    #[serde(skip_serializing_if = "Option::is_none")]
208    pub prompt_feedback: Option<PromptFeedback>,
209    /// Usage metadata
210    #[serde(skip_serializing_if = "Option::is_none")]
211    pub usage_metadata: Option<UsageMetadata>,
212}
213
214/// Feedback about the prompt
215#[derive(Debug, Clone, Serialize, Deserialize)]
216pub struct PromptFeedback {
217    /// The safety ratings for the prompt
218    pub safety_ratings: Vec<SafetyRating>,
219    /// The block reason if the prompt was blocked
220    #[serde(skip_serializing_if = "Option::is_none")]
221    pub block_reason: Option<String>,
222}
223
224impl GenerationResponse {
225    /// Get the text of the first candidate
226    pub fn text(&self) -> String {
227        self.candidates
228            .first()
229            .and_then(|c| {
230                c.content.parts.first().and_then(|p| match p {
231                    Part::Text { text } => Some(text.clone()),
232                    _ => None,
233                })
234            })
235            .unwrap_or_default()
236    }
237
238    /// Get function calls from the response
239    pub fn function_calls(&self) -> Vec<&super::tools::FunctionCall> {
240        self.candidates
241            .iter()
242            .flat_map(|c| {
243                c.content.parts.iter().filter_map(|p| match p {
244                    Part::FunctionCall { function_call } => Some(function_call),
245                    _ => None,
246                })
247            })
248            .collect()
249    }
250}
251
252/// Request to generate content
253#[derive(Debug, Clone, Serialize, Deserialize)]
254pub struct GenerateContentRequest {
255    /// The contents to generate content from
256    pub contents: Vec<Content>,
257    /// The generation config
258    #[serde(skip_serializing_if = "Option::is_none")]
259    pub generation_config: Option<GenerationConfig>,
260    /// The safety settings
261    #[serde(skip_serializing_if = "Option::is_none")]
262    pub safety_settings: Option<Vec<SafetySetting>>,
263    /// The tools that the model can use
264    #[serde(skip_serializing_if = "Option::is_none")]
265    pub tools: Option<Vec<super::tools::Tool>>,
266    /// The tool config
267    #[serde(skip_serializing_if = "Option::is_none")]
268    pub tool_config: Option<ToolConfig>,
269    /// The system instruction
270    #[serde(skip_serializing_if = "Option::is_none")]
271    pub system_instruction: Option<Content>,
272}
273
274/// Configuration for generation
275#[derive(Debug, Clone, Serialize, Deserialize)]
276pub struct GenerationConfig {
277    /// The temperature for the model (0.0 to 1.0)
278    ///
279    /// Controls the randomness of the output. Higher values (e.g., 0.9) make output
280    /// more random, lower values (e.g., 0.1) make output more deterministic.
281    #[serde(skip_serializing_if = "Option::is_none")]
282    pub temperature: Option<f32>,
283
284    /// The top-p value for the model (0.0 to 1.0)
285    ///
286    /// For each token generation step, the model considers the top_p percentage of
287    /// probability mass for potential token choices. Lower values are more selective,
288    /// higher values allow more variety.
289    #[serde(skip_serializing_if = "Option::is_none")]
290    pub top_p: Option<f32>,
291
292    /// The top-k value for the model
293    ///
294    /// For each token generation step, the model considers the top_k most likely tokens.
295    /// Lower values are more selective, higher values allow more variety.
296    #[serde(skip_serializing_if = "Option::is_none")]
297    pub top_k: Option<i32>,
298
299    /// The maximum number of tokens to generate
300    ///
301    /// Limits the length of the generated content. One token is roughly 4 characters.
302    #[serde(skip_serializing_if = "Option::is_none")]
303    pub max_output_tokens: Option<i32>,
304
305    /// The candidate count
306    ///
307    /// Number of alternative responses to generate.
308    #[serde(skip_serializing_if = "Option::is_none")]
309    pub candidate_count: Option<i32>,
310
311    /// Whether to stop on specific sequences
312    ///
313    /// The model will stop generating content when it encounters any of these sequences.
314    #[serde(skip_serializing_if = "Option::is_none")]
315    pub stop_sequences: Option<Vec<String>>,
316
317    /// The response mime type
318    ///
319    /// Specifies the format of the model's response.
320    #[serde(skip_serializing_if = "Option::is_none")]
321    pub response_mime_type: Option<String>,
322
323    /// The response schema
324    ///
325    /// Specifies the JSON schema for structured responses.
326    #[serde(skip_serializing_if = "Option::is_none")]
327    pub response_schema: Option<serde_json::Value>,
328}
329
330impl Default for GenerationConfig {
331    fn default() -> Self {
332        Self {
333            temperature: Some(0.7),
334            top_p: Some(0.95),
335            top_k: Some(40),
336            max_output_tokens: Some(1024),
337            candidate_count: Some(1),
338            stop_sequences: None,
339            response_mime_type: None,
340            response_schema: None,
341        }
342    }
343}
344
345/// Configuration for tools
346#[derive(Debug, Clone, Serialize, Deserialize)]
347pub struct ToolConfig {
348    /// The function calling config
349    #[serde(skip_serializing_if = "Option::is_none")]
350    pub function_calling_config: Option<FunctionCallingConfig>,
351}
352
353/// Configuration for function calling
354#[derive(Debug, Clone, Serialize, Deserialize)]
355pub struct FunctionCallingConfig {
356    /// The mode for function calling
357    pub mode: FunctionCallingMode,
358}
359
360/// Mode for function calling
361#[derive(Debug, Clone, Serialize, Deserialize)]
362#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
363pub enum FunctionCallingMode {
364    /// The model may use function calling
365    Auto,
366    /// The model must use function calling
367    Any,
368    /// The model must not use function calling
369    None,
370}
371
372/// Setting for safety
373#[derive(Debug, Clone, Serialize, Deserialize)]
374pub struct SafetySetting {
375    /// The category of content to filter
376    pub category: HarmCategory,
377    /// The threshold for filtering
378    pub threshold: HarmBlockThreshold,
379}
380
381/// Category of harmful content
382#[derive(Debug, Clone, Serialize, Deserialize)]
383#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
384pub enum HarmCategory {
385    /// Dangerous content
386    Dangerous,
387    /// Harassment content
388    Harassment,
389    /// Hate speech
390    HateSpeech,
391    /// Sexually explicit content
392    SexuallyExplicit,
393}
394
395/// Threshold for blocking harmful content
396#[derive(Debug, Clone, Serialize, Deserialize)]
397#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
398pub enum HarmBlockThreshold {
399    /// Block content with low probability of harm
400    BlockLowAndAbove,
401    /// Block content with medium probability of harm
402    BlockMediumAndAbove,
403    /// Block content with high probability of harm
404    BlockHighAndAbove,
405    /// Block content with maximum probability of harm
406    BlockOnlyHigh,
407    /// Never block content
408    BlockNone,
409}