gemini_rust/
models.rs

1use serde::{Deserialize, Serialize};
2
3/// Role of a message in a conversation
4#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
5#[serde(rename_all = "lowercase")]
6pub enum Role {
7    /// Message from the user
8    User,
9    /// Message from the model
10    Model,
11}
12
13/// Content part that can be included in a message
14#[derive(Debug, Clone, Serialize, Deserialize)]
15#[serde(untagged)]
16pub enum Part {
17    /// Text content
18    Text {
19        /// The text content
20        text: String,
21    },
22    /// Function call from the model
23    FunctionCall {
24        /// The function call details
25        #[serde(rename = "functionCall")]
26        function_call: super::tools::FunctionCall,
27    },
28    /// Function response (results from executing a function call)
29    FunctionResponse {
30        /// The function response details
31        #[serde(rename = "functionResponse")]
32        function_response: super::tools::FunctionResponse,
33    },
34}
35
36/// Content of a message
37#[derive(Debug, Default, Clone, Serialize, Deserialize)]
38pub struct Content {
39    /// Parts of the content
40    pub parts: Vec<Part>,
41    /// Role of the content
42    #[serde(skip_serializing_if = "Option::is_none")]
43    pub role: Option<Role>,
44}
45
46impl Content {
47    /// Create a new text content
48    pub fn text(text: impl Into<String>) -> Self {
49        Self {
50            parts: vec![Part::Text { text: text.into() }],
51            role: None,
52        }
53    }
54
55    /// Create a new content with a function call
56    pub fn function_call(function_call: super::tools::FunctionCall) -> Self {
57        Self {
58            parts: vec![Part::FunctionCall { function_call }],
59            role: None,
60        }
61    }
62
63    /// Create a new content with a function response
64    pub fn function_response(function_response: super::tools::FunctionResponse) -> Self {
65        Self {
66            parts: vec![Part::FunctionResponse { function_response }],
67            role: None,
68        }
69    }
70
71    /// Create a new content with a function response from name and JSON value
72    pub fn function_response_json(name: impl Into<String>, response: serde_json::Value) -> Self {
73        Self {
74            parts: vec![Part::FunctionResponse {
75                function_response: super::tools::FunctionResponse::new(name, response),
76            }],
77            role: None,
78        }
79    }
80
81    /// Add a role to this content
82    pub fn with_role(mut self, role: Role) -> Self {
83        self.role = Some(role);
84        self
85    }
86}
87
88/// Message in a conversation
89#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct Message {
91    /// Content of the message
92    pub content: Content,
93    /// Role of the message
94    pub role: Role,
95}
96
97impl Message {
98    /// Create a new user message with text content
99    pub fn user(text: impl Into<String>) -> Self {
100        Self {
101            content: Content::text(text).with_role(Role::User),
102            role: Role::User,
103        }
104    }
105
106    /// Create a new model message with text content
107    pub fn model(text: impl Into<String>) -> Self {
108        Self {
109            content: Content::text(text).with_role(Role::Model),
110            role: Role::Model,
111        }
112    }
113
114    /// Create a new function message with function response content from JSON
115    pub fn function(name: impl Into<String>, response: serde_json::Value) -> Self {
116        Self {
117            content: Content::function_response_json(name, response).with_role(Role::Model),
118            role: Role::Model,
119        }
120    }
121
122    /// Create a new function message with function response from a JSON string
123    pub fn function_str(
124        name: impl Into<String>,
125        response: impl Into<String>,
126    ) -> Result<Self, serde_json::Error> {
127        let response_str = response.into();
128        let json = serde_json::from_str(&response_str)?;
129        Ok(Self {
130            content: Content::function_response_json(name, json).with_role(Role::Model),
131            role: Role::Model,
132        })
133    }
134}
135
136/// Safety rating for content
137#[derive(Debug, Clone, Serialize, Deserialize)]
138pub struct SafetyRating {
139    /// The category of the safety rating
140    pub category: String,
141    /// The probability that the content is harmful
142    pub probability: String,
143}
144
145/// Citation metadata for content
146#[derive(Debug, Clone, Serialize, Deserialize)]
147pub struct CitationMetadata {
148    /// The citation sources
149    pub citation_sources: Vec<CitationSource>,
150}
151
152/// Citation source
153#[derive(Debug, Clone, Serialize, Deserialize)]
154pub struct CitationSource {
155    /// The URI of the citation source
156    pub uri: Option<String>,
157    /// The title of the citation source
158    pub title: Option<String>,
159    /// The start index of the citation in the response
160    pub start_index: Option<i32>,
161    /// The end index of the citation in the response
162    pub end_index: Option<i32>,
163    /// The license of the citation source
164    pub license: Option<String>,
165    /// The publication date of the citation source
166    pub publication_date: Option<String>,
167}
168
169/// A candidate response
170#[derive(Debug, Clone, Serialize, Deserialize)]
171pub struct Candidate {
172    /// The content of the candidate
173    pub content: Content,
174    /// The safety ratings for the candidate
175    #[serde(skip_serializing_if = "Option::is_none")]
176    pub safety_ratings: Option<Vec<SafetyRating>>,
177    /// The citation metadata for the candidate
178    #[serde(skip_serializing_if = "Option::is_none")]
179    pub citation_metadata: Option<CitationMetadata>,
180    /// The finish reason for the candidate
181    #[serde(skip_serializing_if = "Option::is_none")]
182    pub finish_reason: Option<String>,
183    /// The tokens used in the response
184    #[serde(skip_serializing_if = "Option::is_none")]
185    pub usage_metadata: Option<UsageMetadata>,
186}
187
188/// Metadata about token usage
189#[derive(Debug, Clone, Serialize, Deserialize)]
190pub struct UsageMetadata {
191    /// The number of prompt tokens
192    pub prompt_token_count: i32,
193    /// The number of response tokens
194    pub candidates_token_count: i32,
195    /// The total number of tokens
196    pub total_token_count: i32,
197}
198
199/// Response from the Gemini API for content generation
200#[derive(Debug, Clone, Serialize, Deserialize)]
201pub struct GenerationResponse {
202    /// The candidates generated
203    pub candidates: Vec<Candidate>,
204    /// The prompt feedback
205    #[serde(skip_serializing_if = "Option::is_none")]
206    pub prompt_feedback: Option<PromptFeedback>,
207    /// Usage metadata
208    #[serde(skip_serializing_if = "Option::is_none")]
209    pub usage_metadata: Option<UsageMetadata>,
210}
211
212/// Feedback about the prompt
213#[derive(Debug, Clone, Serialize, Deserialize)]
214pub struct PromptFeedback {
215    /// The safety ratings for the prompt
216    pub safety_ratings: Vec<SafetyRating>,
217    /// The block reason if the prompt was blocked
218    #[serde(skip_serializing_if = "Option::is_none")]
219    pub block_reason: Option<String>,
220}
221
222impl GenerationResponse {
223    /// Get the text of the first candidate
224    pub fn text(&self) -> String {
225        self.candidates
226            .first()
227            .and_then(|c| {
228                c.content.parts.first().and_then(|p| match p {
229                    Part::Text { text } => Some(text.clone()),
230                    _ => None,
231                })
232            })
233            .unwrap_or_default()
234    }
235
236    /// Get function calls from the response
237    pub fn function_calls(&self) -> Vec<&super::tools::FunctionCall> {
238        self.candidates
239            .iter()
240            .flat_map(|c| {
241                c.content.parts.iter().filter_map(|p| match p {
242                    Part::FunctionCall { function_call } => Some(function_call),
243                    _ => None,
244                })
245            })
246            .collect()
247    }
248}
249
250/// Request to generate content
251#[derive(Debug, Clone, Serialize, Deserialize)]
252pub struct GenerateContentRequest {
253    /// The contents to generate content from
254    pub contents: Vec<Content>,
255    /// The generation config
256    #[serde(skip_serializing_if = "Option::is_none")]
257    pub generation_config: Option<GenerationConfig>,
258    /// The safety settings
259    #[serde(skip_serializing_if = "Option::is_none")]
260    pub safety_settings: Option<Vec<SafetySetting>>,
261    /// The tools that the model can use
262    #[serde(skip_serializing_if = "Option::is_none")]
263    pub tools: Option<Vec<super::tools::Tool>>,
264    /// The tool config
265    #[serde(skip_serializing_if = "Option::is_none")]
266    pub tool_config: Option<ToolConfig>,
267    /// The system instruction
268    #[serde(skip_serializing_if = "Option::is_none")]
269    pub system_instruction: Option<Content>,
270}
271
272/// Configuration for generation
273#[derive(Debug, Clone, Serialize, Deserialize)]
274pub struct GenerationConfig {
275    /// The temperature for the model (0.0 to 1.0)
276    ///
277    /// Controls the randomness of the output. Higher values (e.g., 0.9) make output
278    /// more random, lower values (e.g., 0.1) make output more deterministic.
279    #[serde(skip_serializing_if = "Option::is_none")]
280    pub temperature: Option<f32>,
281
282    /// The top-p value for the model (0.0 to 1.0)
283    ///
284    /// For each token generation step, the model considers the top_p percentage of
285    /// probability mass for potential token choices. Lower values are more selective,
286    /// higher values allow more variety.
287    #[serde(skip_serializing_if = "Option::is_none")]
288    pub top_p: Option<f32>,
289
290    /// The top-k value for the model
291    ///
292    /// For each token generation step, the model considers the top_k most likely tokens.
293    /// Lower values are more selective, higher values allow more variety.
294    #[serde(skip_serializing_if = "Option::is_none")]
295    pub top_k: Option<i32>,
296
297    /// The maximum number of tokens to generate
298    ///
299    /// Limits the length of the generated content. One token is roughly 4 characters.
300    #[serde(skip_serializing_if = "Option::is_none")]
301    pub max_output_tokens: Option<i32>,
302
303    /// The candidate count
304    ///
305    /// Number of alternative responses to generate.
306    #[serde(skip_serializing_if = "Option::is_none")]
307    pub candidate_count: Option<i32>,
308
309    /// Whether to stop on specific sequences
310    ///
311    /// The model will stop generating content when it encounters any of these sequences.
312    #[serde(skip_serializing_if = "Option::is_none")]
313    pub stop_sequences: Option<Vec<String>>,
314
315    /// The response mime type
316    ///
317    /// Specifies the format of the model's response.
318    #[serde(skip_serializing_if = "Option::is_none")]
319    pub response_mime_type: Option<String>,
320
321    /// The response schema
322    ///
323    /// Specifies the JSON schema for structured responses.
324    #[serde(skip_serializing_if = "Option::is_none")]
325    pub response_schema: Option<serde_json::Value>,
326}
327
328impl Default for GenerationConfig {
329    fn default() -> Self {
330        Self {
331            temperature: Some(0.7),
332            top_p: Some(0.95),
333            top_k: Some(40),
334            max_output_tokens: Some(1024),
335            candidate_count: Some(1),
336            stop_sequences: None,
337            response_mime_type: None,
338            response_schema: None,
339        }
340    }
341}
342
343/// Configuration for tools
344#[derive(Debug, Clone, Serialize, Deserialize)]
345pub struct ToolConfig {
346    /// The function calling config
347    #[serde(skip_serializing_if = "Option::is_none")]
348    pub function_calling_config: Option<FunctionCallingConfig>,
349}
350
351/// Configuration for function calling
352#[derive(Debug, Clone, Serialize, Deserialize)]
353pub struct FunctionCallingConfig {
354    /// The mode for function calling
355    pub mode: FunctionCallingMode,
356}
357
358/// Mode for function calling
359#[derive(Debug, Clone, Serialize, Deserialize)]
360#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
361pub enum FunctionCallingMode {
362    /// The model may use function calling
363    Auto,
364    /// The model must use function calling
365    Any,
366    /// The model must not use function calling
367    None,
368}
369
370/// Setting for safety
371#[derive(Debug, Clone, Serialize, Deserialize)]
372pub struct SafetySetting {
373    /// The category of content to filter
374    pub category: HarmCategory,
375    /// The threshold for filtering
376    pub threshold: HarmBlockThreshold,
377}
378
379/// Category of harmful content
380#[derive(Debug, Clone, Serialize, Deserialize)]
381#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
382pub enum HarmCategory {
383    /// Dangerous content
384    Dangerous,
385    /// Harassment content
386    Harassment,
387    /// Hate speech
388    HateSpeech,
389    /// Sexually explicit content
390    SexuallyExplicit,
391}
392
393/// Threshold for blocking harmful content
394#[allow(clippy::enum_variant_names)]
395#[derive(Debug, Clone, Serialize, Deserialize)]
396#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
397pub enum HarmBlockThreshold {
398    /// Block content with low probability of harm
399    BlockLowAndAbove,
400    /// Block content with medium probability of harm
401    BlockMediumAndAbove,
402    /// Block content with high probability of harm
403    BlockHighAndAbove,
404    /// Block content with maximum probability of harm
405    BlockOnlyHigh,
406    /// Never block content
407    BlockNone,
408}