Skip to main content

google_gemini_rs/google/
request.rs

1//! Request types and wrappers for Google AI Models. See: https://ai.google.dev/api/generate-content
2
3use std::collections::HashMap;
4
5use rust_mcp_sdk::{error::McpSdkError, schema::ToolInputSchema};
6use serde::{Deserialize, Serialize};
7use serde_json::{Value, json};
8use thiserror::Error;
9
10use super::common::{Content, HarmCategory, Modality};
11
/// Errors produced while building or decoding Gemini request types.
#[derive(Debug, Error)]
pub enum Error {
    /// Error propagated from the MCP SDK (converted via `From`).
    #[error(transparent)]
    McpSdk(#[from] McpSdkError),
    /// A lookup or name-mapping failure (e.g. an unmappable function name).
    /// The payload is the complete human-readable message and is displayed
    /// verbatim.
    #[error("{0}")]
    NotFound(String),
    /// JSON serialization or deserialization failure (converted via `From`).
    #[error(transparent)]
    Serde(#[from] serde_json::Error),
}
21
/// Data type of a [`Schema`] node.
///
/// Serialized in SCREAMING_SNAKE_CASE (e.g. `TYPE_UNSPECIFIED`, `STRING`).
/// When deserializing, every variant also accepts an all-lowercase alias
/// (e.g. `string`, `object`), which lets JSON Schema documents such as MCP
/// tool input schemas be read directly into a [`Schema`].
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum Type {
    /// The derived `Default` value.
    // Each alias is the variant name lowercased with no separators.
    #[serde(alias = "typeunspecified")]
    #[default]
    TypeUnspecified,
    #[serde(alias = "string")]
    String,
    #[serde(alias = "number")]
    Number,
    #[serde(alias = "integer")]
    Integer,
    #[serde(alias = "boolean")]
    Boolean,
    #[serde(alias = "array")]
    Array,
    #[serde(alias = "object")]
    Object,
    #[serde(alias = "null")]
    Null,
}
43
44#[derive(Clone, Debug, Serialize, Deserialize, Default)]
45#[serde(rename_all = "camelCase")]
46pub struct Schema {
47    pub r#type: Type,
48    #[serde(default, skip_serializing_if = "Option::is_none")]
49    pub format: Option<String>,
50    #[serde(default, skip_serializing_if = "Option::is_none")]
51    pub title: Option<String>,
52    #[serde(default, skip_serializing_if = "Option::is_none")]
53    pub description: Option<String>,
54    #[serde(default, skip_serializing_if = "Option::is_none")]
55    pub nullable: Option<bool>,
56    #[serde(default, skip_serializing_if = "Vec::is_empty")]
57    pub r#enum: Vec<String>,
58    #[serde(default, skip_serializing_if = "Option::is_none")]
59    pub max_items: Option<String>,
60    #[serde(default, skip_serializing_if = "Option::is_none")]
61    pub min_items: Option<String>,
62    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
63    pub properties: HashMap<String, Schema>,
64    #[serde(default, skip_serializing_if = "Vec::is_empty")]
65    pub required: Vec<String>,
66    #[serde(default, skip_serializing_if = "Option::is_none")]
67    pub min_properties: Option<String>,
68    #[serde(default, skip_serializing_if = "Option::is_none")]
69    pub max_properties: Option<String>,
70    #[serde(default, skip_serializing_if = "Option::is_none")]
71    pub min_length: Option<String>,
72    #[serde(default, skip_serializing_if = "Option::is_none")]
73    pub max_length: Option<String>,
74    #[serde(default, skip_serializing_if = "Option::is_none")]
75    pub pattern: Option<String>,
76    #[serde(default, skip_serializing_if = "Option::is_none")]
77    pub example: Option<Value>,
78    #[serde(default, skip_serializing_if = "Vec::is_empty")]
79    pub any_of: Vec<Schema>,
80    #[serde(default, skip_serializing_if = "Vec::is_empty")]
81    pub property_ordering: Vec<String>,
82    #[serde(default, skip_serializing_if = "Option::is_none")]
83    pub default: Option<Value>,
84    #[serde(default, skip_serializing_if = "Option::is_none")]
85    pub items: Option<Box<Schema>>,
86    #[serde(default, skip_serializing_if = "Option::is_none")]
87    pub minimum: Option<f32>,
88    #[serde(default, skip_serializing_if = "Option::is_none")]
89    pub maximum: Option<f32>,
90}
91
92impl TryFrom<ToolInputSchema> for Schema {
93    type Error = Error;
94
95    fn try_from(value: ToolInputSchema) -> Result<Self, Error> {
96        // Behold the power of serde: convert the MCP tool schema to the
97        // Gemini tool schema.
98        Ok(serde_json::from_value::<Schema>(json!(value))?)
99    }
100}
101
/// A function declaration listed in [`Tool::function_declarations`].
///
/// Serialized with camelCase keys; `parameters` and `response` are omitted
/// when `None`.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionDeclaration {
    /// Function name as presented to the model.
    pub name: String,
    /// Function description; always serialized.
    pub description: String,
    /// Schema of the call arguments, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parameters: Option<Schema>,
    /// Schema of the function's response, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response: Option<Schema>,
}
112
/// Encodes a tool's position into its name as `{index}_{name}`.
///
/// The decimal prefix contains no `_`. [`unmap_fn_name`] splits at the first
/// `_`, so it recovers `name` unchanged even when `name` itself contains
/// underscores.
pub fn map_fn_name(index: usize, name: &str) -> String {
    let mut mapped = index.to_string();
    mapped.push('_');
    mapped.push_str(name);
    mapped
}
116
117pub fn unmap_fn_name(name: &str) -> Result<String, Error> {
118    Ok(name
119        .split_once('_')
120        .ok_or_else(|| Error::NotFound("Function name: {name}".to_string()))?
121        .1
122        .to_string())
123}
124
125impl From<&rust_mcp_sdk::schema::Tool> for FunctionDeclaration {
126    fn from(value: &rust_mcp_sdk::schema::Tool) -> Self {
127        Self {
128            name: value.name.clone(),
129            description: value
130                .description
131                .clone()
132                .unwrap_or_else(|| "None".to_string()),
133            parameters: value.input_schema.clone().try_into().ok(),
134            response: None,
135        }
136    }
137}
138
/// Mode value serialized in SCREAMING_SNAKE_CASE (`MODE_UNSPECIFIED`,
/// `MODE_DYNAMIC`).
///
/// Used by both [`DynamicRetrievalConfig::mode`] and
/// [`FunctionCallingConfig::mode`].
///
/// NOTE(review): these two values match the dynamic-retrieval mode. The
/// Gemini API reference appears to list a different value set for function
/// calling (e.g. `AUTO`, `ANY`, `NONE`). Confirm before using this enum in a
/// tool config.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum Mode {
    ModeUnspecified,
    ModeDynamic,
}
145
/// Dynamic-retrieval settings carried by [`GoogleSearchRetrieval`].
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct DynamicRetrievalConfig {
    /// Serialized as `mode`.
    pub mode: Mode,
    /// Serialized as `dynamicThreshold`.
    ///
    /// NOTE(review): the API reference appears to describe this threshold as
    /// a floating-point value. An `i32` cannot express fractional thresholds;
    /// confirm and consider widening the type.
    pub dynamic_threshold: i32,
}
152
/// Google Search retrieval tool (see [`Tool::google_search_retrieval`]),
/// wrapping its [`DynamicRetrievalConfig`].
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GoogleSearchRetrieval {
    /// Serialized as `dynamicRetrievalConfig`; always present.
    pub dynamic_retrieval_config: DynamicRetrievalConfig,
}
158
/// URL-context tool marker (see [`Tool::url_context`]).
///
/// It has no fields and serializes as an empty JSON object `{}`.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct UrlContext {}
162
163#[derive(Clone, Debug, Serialize, Deserialize)]
164#[serde(rename_all = "camelCase")]
165pub struct Tool {
166    #[serde(skip_serializing_if = "Vec::is_empty")]
167    pub function_declarations: Vec<FunctionDeclaration>,
168    #[serde(skip_serializing_if = "Option::is_none")]
169    pub google_search_retrieval: Option<GoogleSearchRetrieval>,
170    #[serde(skip_serializing_if = "Option::is_none")]
171    pub code_execution: Option<Value>,
172    #[serde(skip_serializing_if = "Option::is_none")]
173    pub google_search: Option<Value>,
174    #[serde(skip_serializing_if = "Option::is_none")]
175    pub url_context: Option<UrlContext>,
176}
177
178impl From<Vec<rust_mcp_sdk::schema::Tool>> for Tool {
179    fn from(value: Vec<rust_mcp_sdk::schema::Tool>) -> Self {
180        Self {
181            function_declarations: value.iter().map(|t| t.into()).collect(),
182            google_search_retrieval: None,
183            code_execution: None,
184            google_search: None,
185            url_context: None,
186        }
187    }
188}
189
190#[derive(Clone, Debug, Serialize, Deserialize)]
191#[serde(rename_all = "camelCase")]
192pub struct FunctionCallingConfig {
193    #[serde(skip_serializing_if = "Option::is_none")]
194    pub mode: Option<Mode>,
195    #[serde(skip_serializing_if = "Vec::is_empty")]
196    pub allowed_function_names: Vec<String>,
197}
198
/// Request-level tool configuration (see
/// [`GenerateContentRequest::tool_config`]).
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolConfig {
    /// Serialized as `functionCallingConfig`; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function_calling_config: Option<FunctionCallingConfig>,
}
205
/// Blocking threshold paired with a [`HarmCategory`] in [`SafetySettings`].
///
/// Serialized in SCREAMING_SNAKE_CASE (e.g. `BLOCK_MEDIUM_AND_ABOVE`).
/// Note that the derived `Default` is `BlockLowAndAbove`, not
/// `HarmBlockThresholdUnspecified`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum HarmBlockThreshold {
    BlockNone,
    BlockOnlyHigh,
    BlockMediumAndAbove,
    #[default]
    BlockLowAndAbove,
    HarmBlockThresholdUnspecified,
    Off,
}
217
/// One safety setting: the blocking threshold applied to a single harm
/// category.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetySettings {
    pub category: HarmCategory,
    pub threshold: HarmBlockThreshold,
}
224
/// Selects a prebuilt voice by name.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PrebuiltVoiceConfig {
    /// Serialized as `voiceName`.
    pub voice_name: String,
}
230
/// Voice selection used by [`SpeechConfig`]; wraps a [`PrebuiltVoiceConfig`].
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct VoiceConfig {
    /// Serialized as `prebuiltVoiceConfig`.
    pub prebuilt_voice_config: PrebuiltVoiceConfig,
}
236
/// Speech settings carried by [`GenerationConfig::speech_config`].
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SpeechConfig {
    /// Serialized as `voiceConfig`; always present.
    pub voice_config: VoiceConfig,
    /// Serialized as `languageCode`; omitted when `None`.
    /// Presumably a BCP-47 tag such as `en-US`; confirm against the API.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub language_code: Option<String>,
}
244
/// Thinking settings carried by [`GenerationConfig::thinking_config`].
///
/// Both fields are always serialized.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ThinkingConfig {
    /// Serialized as `includeThoughts`.
    pub include_thoughts: bool,
    /// Serialized as `thinkingBudget` (assumed to be a token count; confirm
    /// against the API).
    pub thinking_budget: i32,
}
251
/// Media resolution setting for [`GenerationConfig::media_resolution`].
///
/// Serialized in SCREAMING_SNAKE_CASE (e.g. `MEDIA_RESOLUTION_LOW`).
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum MediaResolution {
    MediaResolutionUnspecified,
    MediaResolutionLow,
    MediaResolutionMedium,
    MediaResolutionHigh,
}
260
261#[derive(Clone, Debug, Serialize, Deserialize, Default)]
262#[serde(rename_all = "camelCase")]
263pub struct GenerationConfig {
264    #[serde(skip_serializing_if = "Vec::is_empty")]
265    pub stop_sequences: Vec<String>,
266    #[serde(skip_serializing_if = "Option::is_none")]
267    pub response_mime_type: Option<String>,
268    #[serde(skip_serializing_if = "Option::is_none")]
269    pub response_schema: Option<Schema>,
270    #[serde(skip_serializing_if = "Vec::is_empty")]
271    pub response_modalities: Vec<Modality>,
272    #[serde(skip_serializing_if = "Option::is_none")]
273    pub candidate_count: Option<i32>,
274    #[serde(skip_serializing_if = "Option::is_none")]
275    pub max_output_tokens: Option<i32>,
276    #[serde(skip_serializing_if = "Option::is_none")]
277    pub temperature: Option<f32>,
278    #[serde(skip_serializing_if = "Option::is_none")]
279    pub top_p: Option<f32>,
280    #[serde(skip_serializing_if = "Option::is_none")]
281    pub top_k: Option<i32>,
282    #[serde(skip_serializing_if = "Option::is_none")]
283    pub seed: Option<i32>,
284    #[serde(skip_serializing_if = "Option::is_none")]
285    pub presence_penalty: Option<f32>,
286    #[serde(skip_serializing_if = "Option::is_none")]
287    pub frequency_penalty: Option<f32>,
288    #[serde(skip_serializing_if = "Option::is_none")]
289    pub response_logprobs: Option<bool>,
290    #[serde(skip_serializing_if = "Option::is_none")]
291    pub logprobs: Option<i32>,
292    #[serde(skip_serializing_if = "Option::is_none")]
293    pub enable_enhanced_civic_answers: Option<bool>,
294    #[serde(skip_serializing_if = "Option::is_none")]
295    pub speech_config: Option<SpeechConfig>,
296    #[serde(skip_serializing_if = "Option::is_none")]
297    pub thinking_config: Option<ThinkingConfig>,
298    #[serde(skip_serializing_if = "Option::is_none")]
299    pub media_resolution: Option<MediaResolution>,
300}
301
/// Helper enum describing an update to a single field of [`GenerationConfig`].
///
/// Each variant is named after the matching `GenerationConfig` field and
/// carries a value of that field's exact type. NOTE(review): the code that
/// applies an update is not in this file section. Presumably it overwrites
/// the matching field; confirm there.
#[derive(Clone, Debug)]
pub enum UpdateGenConfig {
    StopSequences(Vec<String>),
    ResponseMimeType(Option<String>),
    ResponseSchema(Option<Schema>),
    ResponseModalities(Vec<Modality>),
    CandidateCount(Option<i32>),
    MaxOutputTokens(Option<i32>),
    Temperature(Option<f32>),
    TopP(Option<f32>),
    TopK(Option<i32>),
    Seed(Option<i32>),
    PresencePenalty(Option<f32>),
    FrequencyPenalty(Option<f32>),
    ResponseLogprobs(Option<bool>),
    Logprobs(Option<i32>),
    EnableEnhancedCivicAnswers(Option<bool>),
    SpeechConfig(Option<SpeechConfig>),
    ThinkingConfig(Option<ThinkingConfig>),
    MediaResolution(Option<MediaResolution>),
}
324
325#[derive(Clone, Debug, Serialize, Deserialize)]
326#[serde(rename_all = "camelCase")]
327pub struct GenerateContentRequest {
328    #[serde(skip_serializing_if = "Option::is_none")]
329    pub system_instruction: Option<Content>,
330    /// The chat history (user and model, except for the last user message which is a message)
331    pub contents: Vec<Content>,
332    /// The tools to use
333    #[serde(skip_serializing_if = "Vec::is_empty")]
334    pub tools: Vec<Tool>,
335    #[serde(skip_serializing_if = "Option::is_none")]
336    pub tool_config: Option<ToolConfig>,
337    #[serde(skip_serializing_if = "Vec::is_empty")]
338    pub safety_settings: Vec<SafetySettings>,
339    #[serde(skip_serializing_if = "Option::is_none")]
340    pub generation_config: Option<GenerationConfig>,
341    #[serde(skip_serializing_if = "Option::is_none")]
342    pub cached_content: Option<String>,
343}