allms/llm_models/
google.rs

1#![allow(deprecated)]
2
3use anyhow::{anyhow, Result};
4use async_trait::async_trait;
5use futures::stream::StreamExt;
6use log::{error, info};
7use reqwest::{header, Client};
8use serde::{Deserialize, Serialize};
9use serde_json::{json, Value};
10
11use crate::apis::GoogleApiEndpoints;
12use crate::completions::ThinkingLevel;
13use crate::constants::{
14    GOOGLE_GEMINI_API_URL, GOOGLE_GEMINI_BETA_API_URL, GOOGLE_VERTEX_API_URL,
15    GOOGLE_VERTEX_ENDPOINT_API_URL,
16};
17use crate::domain::{GoogleGeminiProApiResp, RateLimit};
18use crate::llm_models::tools::{GeminiCodeInterpreterConfig, GeminiWebSearchConfig};
19use crate::llm_models::{LLMModel, LLMTools};
20
#[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq)]
/// Google Gemini models that can be called through this module.
///
/// Google Docs: <https://ai.google.dev/gemini-api/docs/models/gemini>
/// Google Vertex Docs: <https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini>
pub enum GoogleModels {
    // 3.0
    /// Reported by `as_str` as `gemini-3-pro-preview`.
    Gemini3Pro,
    // 2.5
    /// Reported by `as_str` as `gemini-2.5-pro`.
    Gemini2_5Pro,
    /// Reported by `as_str` as `gemini-2.5-flash`.
    Gemini2_5Flash,
    /// Reported by `as_str` as `gemini-2.5-flash-lite`.
    Gemini2_5FlashLite,
    // 2.0
    /// Reported by `as_str` as `gemini-2.0-flash`.
    Gemini2_0Flash,
    /// Reported by `as_str` as `gemini-2.0-flash-lite`.
    Gemini2_0FlashLite,
    // 2.0 - Experimental
    /// Reported by `as_str` as `gemini-2.0-pro-exp-02-05`.
    Gemini2_0ProExp,
    /// Reported by `as_str` as `gemini-2.0-flash-thinking-exp-01-21`.
    Gemini2_0FlashThinkingExp,
    // 1.5
    /// Reported by `as_str` as `gemini-1.5-flash`.
    Gemini1_5Flash,
    /// Reported by `as_str` as `gemini-1.5-flash-8b`.
    Gemini1_5Flash8B,
    /// Reported by `as_str` as `gemini-1.5-pro`.
    Gemini1_5Pro,
    // Fine-tuned models
    /// A fine-tuned model addressed by endpoint name; `as_str` returns `name` verbatim
    /// and it is interpolated into the Vertex endpoint URL.
    FineTunedEndpoint {
        name: String,
    },
    // Legacy approach to Vertex models
    // Each `*Vertex` variant shares the identifier of its non-Vertex counterpart and is
    // always routed to the Vertex streaming endpoint, whatever API version is requested.
    #[deprecated(
        since = "0.19.0",
        note = "Starting 0.19.0 `allms` allows to set the API version to `google-vertex` or `google-studio` instead of using the model name to call the right API."
    )]
    Gemini1_5FlashVertex,
    #[deprecated(
        since = "0.19.0",
        note = "Starting 0.19.0 `allms` allows to set the API version to `google-vertex` or `google-studio` instead of using the model name to call the right API."
    )]
    Gemini1_5Flash8BVertex,
    #[deprecated(
        since = "0.19.0",
        note = "Starting 0.19.0 `allms` allows to set the API version to `google-vertex` or `google-studio` instead of using the model name to call the right API."
    )]
    Gemini1_5ProVertex,
    #[deprecated(
        since = "0.19.0",
        note = "Starting 0.19.0 `allms` allows to set the API version to `google-vertex` or `google-studio` instead of using the model name to call the right API."
    )]
    Gemini2_0FlashVertex,
    #[deprecated(
        since = "0.19.0",
        note = "Starting 0.19.0 `allms` allows to set the API version to `google-vertex` or `google-studio` instead of using the model name to call the right API."
    )]
    Gemini2_0FlashLiteVertex,
    #[deprecated(
        since = "0.19.0",
        note = "Starting 0.19.0 `allms` allows to set the API version to `google-vertex` or `google-studio` instead of using the model name to call the right API."
    )]
    Gemini2_0ProExpVertex,
    #[deprecated(
        since = "0.19.0",
        note = "Starting 0.19.0 `allms` allows to set the API version to `google-vertex` or `google-studio` instead of using the model name to call the right API."
    )]
    Gemini2_0FlashThinkingExpVertex,
}
82
83#[async_trait(?Send)]
84impl LLMModel for GoogleModels {
85    fn as_str(&self) -> &str {
86        match self {
87            GoogleModels::Gemini1_5Pro | GoogleModels::Gemini1_5ProVertex => "gemini-1.5-pro",
88            GoogleModels::Gemini1_5Flash | GoogleModels::Gemini1_5FlashVertex => "gemini-1.5-flash",
89            GoogleModels::Gemini1_5Flash8B | GoogleModels::Gemini1_5Flash8BVertex => {
90                "gemini-1.5-flash-8b"
91            }
92            GoogleModels::Gemini2_0Flash | GoogleModels::Gemini2_0FlashVertex => "gemini-2.0-flash",
93            GoogleModels::Gemini2_0FlashLite | GoogleModels::Gemini2_0FlashLiteVertex => {
94                "gemini-2.0-flash-lite"
95            }
96            GoogleModels::Gemini2_0ProExp | GoogleModels::Gemini2_0ProExpVertex => {
97                "gemini-2.0-pro-exp-02-05"
98            }
99            GoogleModels::Gemini2_0FlashThinkingExp
100            | GoogleModels::Gemini2_0FlashThinkingExpVertex => {
101                "gemini-2.0-flash-thinking-exp-01-21"
102            }
103            GoogleModels::Gemini2_5Flash => "gemini-2.5-flash",
104            GoogleModels::Gemini2_5Pro => "gemini-2.5-pro",
105            GoogleModels::Gemini2_5FlashLite => "gemini-2.5-flash-lite",
106            GoogleModels::Gemini3Pro => "gemini-3-pro-preview",
107            GoogleModels::FineTunedEndpoint { name } => name,
108        }
109    }
110
111    fn try_from_str(name: &str) -> Option<Self> {
112        match name.to_lowercase().as_str() {
113            "gemini-1.5-pro" => Some(GoogleModels::Gemini1_5Pro),
114            "gemini-1.5-pro-vertex" => Some(GoogleModels::Gemini1_5ProVertex),
115            "gemini-1.5-flash" => Some(GoogleModels::Gemini1_5Flash),
116            "gemini-1.5-flash-vertex" => Some(GoogleModels::Gemini1_5FlashVertex),
117            "gemini-1.5-flash-8b" => Some(GoogleModels::Gemini1_5Flash8B),
118            "gemini-1.5-flash-8b-vertex" => Some(GoogleModels::Gemini1_5Flash8BVertex),
119            "gemini-2.0-flash" => Some(GoogleModels::Gemini2_0Flash),
120            "gemini-2.0-flash-vertex" => Some(GoogleModels::Gemini2_0FlashVertex),
121            "gemini-2.0-flash-lite" => Some(GoogleModels::Gemini2_0FlashLite),
122            "gemini-2.0-flash-lite-vertex" => Some(GoogleModels::Gemini2_0FlashLiteVertex),
123            "gemini-2.0-pro" => Some(GoogleModels::Gemini2_0ProExp),
124            "gemini-2.0-pro-exp" => Some(GoogleModels::Gemini2_0ProExp),
125            "gemini-2.0-pro-vertex" => Some(GoogleModels::Gemini2_0ProExpVertex),
126            "gemini-2.0-flash-thinking" => Some(GoogleModels::Gemini2_0FlashThinkingExp),
127            "gemini-2.0-flash-thinking-exp" => Some(GoogleModels::Gemini2_0FlashThinkingExp),
128            "gemini-2.0-flash-thinking-vertex" => {
129                Some(GoogleModels::Gemini2_0FlashThinkingExpVertex)
130            }
131            "gemini-2.5-flash" => Some(GoogleModels::Gemini2_5Flash),
132            "gemini-2.5-pro" => Some(GoogleModels::Gemini2_5Pro),
133            "gemini-2.5-flash-lite" => Some(GoogleModels::Gemini2_5FlashLite),
134            "gemini-3-pro-preview" => Some(GoogleModels::Gemini3Pro),
135            "gemini-3-pro" => Some(GoogleModels::Gemini3Pro),
136            // Gemini 1.0 Pro is deprecated starting 2/15/2025. We are re-routing to 1.5 Pro for the model
137            "gemini-pro" => Some(GoogleModels::Gemini1_5Pro),
138            "gemini-1.0-pro" => Some(GoogleModels::Gemini1_5Pro),
139            "gemini-pro-vertex" => Some(GoogleModels::Gemini1_5ProVertex),
140            "gemini-1.0-pro-vertex" => Some(GoogleModels::Gemini1_5ProVertex),
141            // Fine-tuned models need to be constructed via the endpoint method
142            _ => None,
143        }
144    }
145
146    fn default_max_tokens(&self) -> usize {
147        // Docs: https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models
148        match self {
149            GoogleModels::Gemini1_5Pro | GoogleModels::Gemini1_5ProVertex => 2_097_152,
150            GoogleModels::Gemini1_5Flash | GoogleModels::Gemini1_5FlashVertex => 1_048_576,
151            GoogleModels::Gemini1_5Flash8B | GoogleModels::Gemini1_5Flash8BVertex => 1_048_576,
152            GoogleModels::Gemini2_0Flash | GoogleModels::Gemini2_0FlashVertex => 1_048_576,
153            GoogleModels::Gemini2_0FlashLite | GoogleModels::Gemini2_0FlashLiteVertex => 1_048_576,
154            GoogleModels::Gemini2_0ProExp | GoogleModels::Gemini2_0ProExpVertex => 2_097_152,
155            GoogleModels::Gemini2_0FlashThinkingExp
156            | GoogleModels::Gemini2_0FlashThinkingExpVertex => 1_048_576,
157            GoogleModels::Gemini2_5Flash => 1_048_576,
158            GoogleModels::Gemini2_5Pro => 1_048_576,
159            GoogleModels::Gemini2_5FlashLite => 1_048_576,
160            GoogleModels::Gemini3Pro => 1_048_576,
161            // TODO: Is this a good assumption?
162            GoogleModels::FineTunedEndpoint { .. } => 1_048_576,
163        }
164    }
165
166    fn get_version_endpoint(&self, version: Option<String>) -> String {
167        // If no version provided default to Google Studio API
168        let version = version
169            .map(|version| GoogleApiEndpoints::from_str(&version))
170            .unwrap_or_default();
171
172        match (self, version) {
173            // Google Studio API
174            (
175                GoogleModels::Gemini1_5Pro
176                | GoogleModels::Gemini1_5Flash
177                | GoogleModels::Gemini1_5Flash8B
178                | GoogleModels::Gemini2_0Flash
179                | GoogleModels::Gemini2_0FlashLite
180                | GoogleModels::Gemini2_0ProExp
181                | GoogleModels::Gemini2_0FlashThinkingExp,
182                GoogleApiEndpoints::GoogleStudio,
183            ) => format!(
184                "{}/{}:generateContent",
185                &*GOOGLE_GEMINI_API_URL,
186                self.as_str()
187            ),
188            // 2.5 models are only available in the beta API
189            (
190                GoogleModels::Gemini2_5Flash
191                | GoogleModels::Gemini2_5Pro
192                | GoogleModels::Gemini2_5FlashLite
193                | GoogleModels::Gemini3Pro,
194                GoogleApiEndpoints::GoogleStudio,
195            ) => format!(
196                "{}/{}:generateContent",
197                &*GOOGLE_GEMINI_BETA_API_URL,
198                self.as_str()
199            ),
200            // Fine-tuned models are only available in the Vertex API
201            // TODO: Explore fine-tuned models in the Studio API
202            (GoogleModels::FineTunedEndpoint { .. }, GoogleApiEndpoints::GoogleStudio) => {
203                // Construct Vertex URL when needed
204                format!(
205                    "{}/{}:generateContent",
206                    &*GOOGLE_VERTEX_ENDPOINT_API_URL,
207                    self.as_str()
208                )
209            }
210            // Google Vertex API
211            (
212                GoogleModels::Gemini1_5Pro
213                | GoogleModels::Gemini1_5Flash
214                | GoogleModels::Gemini1_5Flash8B
215                | GoogleModels::Gemini2_0Flash
216                | GoogleModels::Gemini2_0FlashLite
217                | GoogleModels::Gemini2_0ProExp
218                | GoogleModels::Gemini2_0FlashThinkingExp
219                | GoogleModels::Gemini2_5Flash
220                | GoogleModels::Gemini2_5Pro
221                | GoogleModels::Gemini2_5FlashLite,
222                GoogleApiEndpoints::GoogleVertex,
223            ) => {
224                // Construct Vertex URL when needed
225                format!(
226                    "{}/{}:streamGenerateContent?alt=sse",
227                    &*GOOGLE_VERTEX_API_URL,
228                    self.as_str()
229                )
230            }
231            // Google Vertex does not support Gemini 3 Pro. Rerouting to Studio API
232            // Docs: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/learn/model-versions
233            // TODO: Add Gemini 3 Pro to the Google Vertex API once it is supported
234            (GoogleModels::Gemini3Pro, GoogleApiEndpoints::GoogleVertex) => {
235                error!("[allms][Google Vertex] Gemini 3 Pro is not supported in the Google Vertex API. Rerouting to Studio API.");
236                format!(
237                    "{}/{}:generateContent",
238                    &*GOOGLE_GEMINI_BETA_API_URL,
239                    self.as_str()
240                )
241            }
242            // Google Vertex API for fine-tuned models
243            (GoogleModels::FineTunedEndpoint { .. }, GoogleApiEndpoints::GoogleVertex) => {
244                // Construct Vertex URL when needed
245                format!(
246                    "{}/{}:generateContent",
247                    &*GOOGLE_VERTEX_ENDPOINT_API_URL,
248                    self.as_str()
249                )
250            }
251            // Legacy Google Vertex API implementation
252            #[allow(deprecated)]
253            (
254                GoogleModels::Gemini1_5ProVertex
255                | GoogleModels::Gemini1_5FlashVertex
256                | GoogleModels::Gemini1_5Flash8BVertex
257                | GoogleModels::Gemini2_0FlashVertex
258                | GoogleModels::Gemini2_0FlashLiteVertex
259                | GoogleModels::Gemini2_0ProExpVertex
260                | GoogleModels::Gemini2_0FlashThinkingExpVertex,
261                _,
262            ) => {
263                // Construct Vertex URL when needed
264                format!(
265                    "{}/{}:streamGenerateContent?alt=sse",
266                    &*GOOGLE_VERTEX_API_URL,
267                    self.as_str()
268                )
269            }
270        }
271    }
272
273    //This method prepares the body of the API call for different models
274    fn get_body(
275        &self,
276        instructions: &str,
277        json_schema: &Value,
278        function_call: bool,
279        _max_tokens: &usize,
280        temperature: &f32,
281        tools: Option<&[LLMTools]>,
282        thinking_level: Option<&ThinkingLevel>,
283    ) -> serde_json::Value {
284        //Prepare the 'messages' part of the body
285        let base_instructions_json = json!({
286            "text": self.get_base_instructions(Some(function_call))
287        });
288
289        let output_instructions_json = json!({ "text": format!("<output json schema>
290                {json_schema}
291                </output json schema>") });
292
293        let user_instructions_json = json!({
294            "text": format!("<instructions>
295                {instructions}
296                </instructions>"),
297        });
298
299        let mut message_parts = vec![
300            base_instructions_json,
301            output_instructions_json,
302            user_instructions_json,
303        ];
304
305        // If the `URL context` tool was configured we include a part with the URLs to be used as context
306        if let Some(tools_inner) = tools {
307            if let Some(LLMTools::GeminiWebSearch(config)) = tools_inner
308                .iter()
309                .find(|tool| matches!(tool, LLMTools::GeminiWebSearch(_)))
310            {
311                let urls = config.get_context_urls();
312                if !urls.is_empty() {
313                    message_parts.push(json!({
314                        "text": format!("<url_context>
315                            {:?}
316                            </url_context>",
317                        urls),
318                    }));
319                }
320            }
321        }
322
323        let contents = json!({
324            "role": "user",
325            "parts": message_parts,
326        });
327
328        let generation_config = json!({
329            "temperature": temperature,
330        });
331
332        let mut body = json!({
333            "contents": contents,
334            "generationConfig": generation_config,
335        });
336
337        // Include tools if provided
338        if let Some(tools_inner) = tools {
339            let processed_tools: Vec<Value> = tools_inner
340                .iter()
341                .filter_map(|tool| {
342                    self.get_supported_tools()
343                        .iter()
344                        .find(|supported| {
345                            std::mem::discriminant(tool) == std::mem::discriminant(supported)
346                        })
347                        .and_then(|_| tool.get_config_json())
348                })
349                .collect();
350
351            if !processed_tools.is_empty() {
352                body["tools"] = json!(processed_tools);
353            }
354        }
355
356        // Include thinking level if provided
357        if self.thinking_level_supported() {
358            if let Some(thinking_level) = thinking_level {
359                body["generationConfig"]["thinkingConfig"]["thinkingLevel"] =
360                    json!(thinking_level.as_str());
361            }
362        }
363
364        body
365    }
366
367    /*
368     * This function leverages Google API to perform any query as per the provided body.
369     *
370     * It returns a String the Response object that needs to be parsed based on the self.model.
371     */
372    async fn call_api(
373        &self,
374        api_key: &str,
375        version: Option<String>,
376        body: &serde_json::Value,
377        debug: bool,
378        _tools: Option<&[LLMTools]>,
379    ) -> Result<String> {
380        // If no version provided default to Google Studio API
381        let api_version = version
382            .as_ref()
383            .map(|version| GoogleApiEndpoints::from_str(version))
384            .unwrap_or_default();
385
386        match (self, api_version) {
387            // Google Studio API
388            (
389                GoogleModels::Gemini1_5Pro
390                | GoogleModels::Gemini1_5Flash
391                | GoogleModels::Gemini1_5Flash8B
392                | GoogleModels::Gemini2_0Flash
393                | GoogleModels::Gemini2_0FlashLite
394                | GoogleModels::Gemini2_0ProExp
395                | GoogleModels::Gemini2_0FlashThinkingExp
396                | GoogleModels::Gemini2_5Flash
397                | GoogleModels::Gemini2_5Pro
398                | GoogleModels::Gemini2_5FlashLite
399                | GoogleModels::Gemini3Pro,
400                GoogleApiEndpoints::GoogleStudio,
401            ) => self.call_api_studio(api_key, version, body, debug).await,
402            // Fine-tuned models are only available in the Vertex API
403            // TODO: Explore fine-tuned models in the Studio API
404            (GoogleModels::FineTunedEndpoint { .. }, GoogleApiEndpoints::GoogleStudio) => {
405                self.call_api_vertex(api_key, version, body, debug).await
406            }
407            // Google Vertex API
408            (
409                GoogleModels::Gemini1_5Pro
410                | GoogleModels::Gemini1_5Flash
411                | GoogleModels::Gemini1_5Flash8B
412                | GoogleModels::Gemini2_0Flash
413                | GoogleModels::Gemini2_0FlashLite
414                | GoogleModels::Gemini2_0ProExp
415                | GoogleModels::Gemini2_0FlashThinkingExp
416                | GoogleModels::Gemini2_5Flash
417                | GoogleModels::Gemini2_5Pro
418                | GoogleModels::Gemini2_5FlashLite,
419                GoogleApiEndpoints::GoogleVertex,
420            ) => {
421                self.call_api_vertex_stream(api_key, version, body, debug)
422                    .await
423            }
424            // Google Vertex API for Gemini 3
425            // Gemini 3 is currently only available via Studio
426            // Docs: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/learn/model-versions
427            (GoogleModels::Gemini3Pro, GoogleApiEndpoints::GoogleVertex) => {
428                self.call_api_studio(api_key, version, body, debug).await
429            }
430            // Google Vertex API for fine-tuned models
431            (GoogleModels::FineTunedEndpoint { .. }, GoogleApiEndpoints::GoogleVertex) => {
432                self.call_api_vertex(api_key, version, body, debug).await
433            }
434            // Legacy approach to Google Vertex API
435            #[allow(deprecated)]
436            (
437                GoogleModels::Gemini1_5ProVertex
438                | GoogleModels::Gemini1_5FlashVertex
439                | GoogleModels::Gemini1_5Flash8BVertex
440                | GoogleModels::Gemini2_0FlashVertex
441                | GoogleModels::Gemini2_0FlashLiteVertex
442                | GoogleModels::Gemini2_0ProExpVertex
443                | GoogleModels::Gemini2_0FlashThinkingExpVertex,
444                _,
445            ) => {
446                self.call_api_vertex_stream(api_key, version, body, debug)
447                    .await
448            }
449        }
450    }
451
452    fn get_version_data(
453        &self,
454        response_text: &str,
455        _function_call: bool,
456        version: Option<String>,
457    ) -> Result<String> {
458        // If no version provided default to Google Studio API
459        let version = version
460            .map(|version| GoogleApiEndpoints::from_str(&version))
461            .unwrap_or_default();
462
463        match (self, version) {
464            // Google Studio API
465            (
466                GoogleModels::Gemini1_5Pro
467                | GoogleModels::Gemini1_5Flash
468                | GoogleModels::Gemini1_5Flash8B
469                | GoogleModels::Gemini2_0Flash
470                | GoogleModels::Gemini2_0FlashLite
471                | GoogleModels::Gemini2_0ProExp
472                | GoogleModels::Gemini2_0FlashThinkingExp
473                | GoogleModels::Gemini2_5Flash
474                | GoogleModels::Gemini2_5Pro
475                | GoogleModels::Gemini2_5FlashLite
476                | GoogleModels::Gemini3Pro,
477                GoogleApiEndpoints::GoogleStudio,
478            ) => self.get_generate_content_data(response_text),
479            // Fine-tuned models are only available in the Vertex API
480            // TODO: Explore fine-tuned models in the Studio API
481            (GoogleModels::FineTunedEndpoint { .. }, GoogleApiEndpoints::GoogleStudio) => {
482                self.get_generate_content_data(response_text)
483            }
484            // Because for Vertex we are using streaming the extraction of data/text is handled in call_api method. Here we only pass the input forward
485            (
486                GoogleModels::Gemini1_5Pro
487                | GoogleModels::Gemini1_5Flash
488                | GoogleModels::Gemini1_5Flash8B
489                | GoogleModels::Gemini2_0Flash
490                | GoogleModels::Gemini2_0FlashLite
491                | GoogleModels::Gemini2_0ProExp
492                | GoogleModels::Gemini2_0FlashThinkingExp
493                | GoogleModels::Gemini2_5Flash
494                | GoogleModels::Gemini2_5Pro
495                | GoogleModels::Gemini2_5FlashLite
496                | GoogleModels::Gemini3Pro,
497                GoogleApiEndpoints::GoogleVertex,
498            ) => Ok(response_text.to_string()),
499            // Google Vertex API for fine-tuned models
500            (GoogleModels::FineTunedEndpoint { .. }, GoogleApiEndpoints::GoogleVertex) => {
501                self.get_generate_content_data(response_text)
502            }
503            // Legacy approach to Vertex API implementation
504            #[allow(deprecated)]
505            (
506                GoogleModels::Gemini1_5ProVertex
507                | GoogleModels::Gemini1_5FlashVertex
508                | GoogleModels::Gemini1_5Flash8BVertex
509                | GoogleModels::Gemini2_0FlashVertex
510                | GoogleModels::Gemini2_0FlashLiteVertex
511                | GoogleModels::Gemini2_0ProExpVertex
512                | GoogleModels::Gemini2_0FlashThinkingExpVertex,
513                _,
514            ) => Ok(response_text.to_string()),
515        }
516    }
517
518    //This function allows to check the rate limits for different models
519    fn get_rate_limit(&self) -> RateLimit {
520        //Docs: https://ai.google.dev/gemini-api/docs/rate-limits#tier-3
521        match self {
522            GoogleModels::Gemini1_5Flash | GoogleModels::Gemini1_5FlashVertex => RateLimit {
523                tpm: 4_000_000,
524                rpm: 2_000,
525            },
526            GoogleModels::Gemini1_5Flash8B | GoogleModels::Gemini1_5Flash8BVertex => RateLimit {
527                tpm: 4_000_000,
528                rpm: 4_000,
529            },
530            GoogleModels::Gemini1_5Pro | GoogleModels::Gemini1_5ProVertex => RateLimit {
531                tpm: 4_000_000,
532                rpm: 1_000,
533            },
534            GoogleModels::Gemini2_0Flash | GoogleModels::Gemini2_0FlashVertex => RateLimit {
535                tpm: 30_000_000,
536                rpm: 30_000,
537            },
538            GoogleModels::Gemini2_0FlashLite | GoogleModels::Gemini2_0FlashLiteVertex => {
539                RateLimit {
540                    tpm: 30_000_000,
541                    rpm: 30_000,
542                }
543            }
544            GoogleModels::Gemini2_5Flash => RateLimit {
545                tpm: 8_000_000,
546                rpm: 10_000,
547            },
548            GoogleModels::Gemini2_5Pro => RateLimit {
549                tpm: 8_000_000,
550                rpm: 2_000,
551            },
552            GoogleModels::Gemini2_5FlashLite => RateLimit {
553                tpm: 30_000_000,
554                rpm: 30_000,
555            },
556            GoogleModels::Gemini3Pro => RateLimit {
557                tpm: 8_000_000,
558                rpm: 2_000,
559            },
560            // Fine-tuned models use 2.0 Flash and Flash Lite rate limits
561            GoogleModels::FineTunedEndpoint { .. } => RateLimit {
562                tpm: 30_000_000,
563                rpm: 30_000,
564            },
565            // TODO: No rate limits published for experimental models
566            _ => RateLimit {
567                tpm: 120_000,
568                rpm: 360,
569            },
570        }
571    }
572
573    fn get_default_temperature(&self) -> f32 {
574        // For Gemini 3, we strongly recommend keeping the temperature parameter at its default value of 1.0.
575        // Docs: https://ai.google.dev/gemini-api/docs/gemini-3?thinking=high#temperature
576        if self == &GoogleModels::Gemini3Pro {
577            1.0f32
578        } else {
579            0.0f32
580        }
581    }
582}
583
584impl GoogleModels {
585    /// Constructor of the fine-tuned model endpoint
586    /// Fine-tuned models are available in the Vertex API via the endpoint ID
587    pub fn endpoint(name: &str) -> Self {
588        GoogleModels::FineTunedEndpoint {
589            name: name.to_string(),
590        }
591    }
592
593    // Specialized function for calling AI Studio API
594    async fn call_api_studio(
595        &self,
596        api_key: &str,
597        version: Option<String>,
598        body: &serde_json::Value,
599        debug: bool,
600    ) -> Result<String> {
601        //Get the API url
602        let model_url = self.get_version_endpoint(version);
603
604        //Make the API call
605        let client = Client::new();
606
607        //Send request
608        let url_with_key = format!("{}?key={}", model_url, api_key);
609        let response = client
610            .post(url_with_key)
611            .header(header::CONTENT_TYPE, "application/json")
612            .json(&body)
613            .send()
614            .await?;
615
616        let response_status = response.status();
617        let response_text = response.text().await?;
618
619        if debug {
620            info!(
621                "[allms][Google AI Studio] API response: [{}] {:#?}",
622                &response_status, &response_text
623            );
624        }
625
626        Ok(response_text)
627    }
628
629    // Specialized function for calling Vertex API with streaming
630    async fn call_api_vertex_stream(
631        &self,
632        api_key: &str,
633        version: Option<String>,
634        body: &serde_json::Value,
635        debug: bool,
636    ) -> Result<String> {
637        //Get the API url
638        let model_url = self.get_version_endpoint(version);
639
640        //Make the API call
641        let client = Client::new();
642
643        //Send request
644        let response = client
645            .post(model_url)
646            .header(header::CONTENT_TYPE, "application/json")
647            .bearer_auth(api_key)
648            .json(&body)
649            .send()
650            .await?;
651
652        //For Vertex we are streaming that data so we need to deserialize each chunk separately
653        // Check if the API uses streaming
654        if response.status().is_success() {
655            let mut stream = response.bytes_stream();
656            let mut streamed_response = String::new();
657
658            while let Some(chunk) = stream.next().await {
659                let chunk = chunk?;
660
661                // Convert the chunk (Bytes) to a String
662                let mut chunk_str = String::from_utf8(chunk.to_vec()).map_err(|e| anyhow!(e))?;
663
664                // The chunk response starts with "data: " that needs to be remove
665                if chunk_str.starts_with("data: ") {
666                    // Remove the first 6 characters ("data: ")
667                    chunk_str = chunk_str[6..].to_string();
668                }
669
670                //Convert response chunk to struct representing expected response format
671                let gemini_response: GoogleGeminiProApiResp = serde_json::from_str(&chunk_str)?;
672
673                //Extract the data part from the response
674                let part_text = gemini_response
675                    .candidates
676                    .iter()
677                    .filter(|candidate| candidate.content.role.as_deref() == Some("model"))
678                    .flat_map(|candidate| &candidate.content.parts)
679                    .filter_map(|part| part.text.as_ref())
680                    .fold(String::new(), |mut acc, text| {
681                        acc.push_str(text);
682                        acc
683                    });
684
685                //Add the chunk response to output string
686                streamed_response.push_str(&part_text);
687
688                // Debug log each chunk if needed
689                if debug {
690                    info!(
691                        "[allms][Google Vertex AI] Received response chunk: {:?}",
692                        chunk
693                    );
694                }
695            }
696            Ok(self.sanitize_json_response(&streamed_response))
697        } else {
698            let response_status = response.status();
699            let response_txt = response.text().await?;
700            Err(anyhow!(
701                "[allms][Google][{}] Response body: {:#?}",
702                response_status,
703                response_txt
704            ))
705        }
706    }
707
708    // Specialized function for calling Vertex API without streaming (used for fine-tuned models)
709    async fn call_api_vertex(
710        &self,
711        api_key: &str,
712        version: Option<String>,
713        body: &serde_json::Value,
714        debug: bool,
715    ) -> Result<String> {
716        //Get the API url
717        let model_url = self.get_version_endpoint(version);
718
719        //Make the API call
720        let client = Client::new();
721
722        //Send request
723        let response = client
724            .post(model_url)
725            .header(header::CONTENT_TYPE, "application/json")
726            .bearer_auth(api_key)
727            .json(&body)
728            .send()
729            .await?;
730
731        let response_status = response.status();
732        let response_text = response.text().await?;
733
734        if debug {
735            info!(
736                "[allms][Google AI Vertex][Fine-tuned] API response: [{}] {:#?}",
737                &response_status, &response_text
738            );
739        }
740
741        Ok(response_text)
742    }
743
744    // Specialized function for parsing response of the generateContent API (non-streaming)
745    fn get_generate_content_data(&self, response_text: &str) -> Result<String> {
746        //Convert response to struct representing expected response format
747        let gemini_response: GoogleGeminiProApiResp = serde_json::from_str(response_text)?;
748
749        //Extract the data part from the response
750        let data = gemini_response
751            .candidates
752            .iter()
753            .filter(|candidate| candidate.content.role.as_deref() == Some("model"))
754            .flat_map(|candidate| &candidate.content.parts)
755            .filter_map(|part| part.text.as_ref())
756            .fold(String::new(), |mut acc, text| {
757                acc.push_str(text);
758                acc
759            });
760
761        Ok(self.sanitize_json_response(&data))
762    }
763
764    fn get_supported_tools(&self) -> Vec<LLMTools> {
765        match self {
766            GoogleModels::Gemini2_5Pro
767            | GoogleModels::Gemini2_5Flash
768            | GoogleModels::Gemini2_5FlashLite
769            | GoogleModels::Gemini2_0Flash
770            | GoogleModels::Gemini3Pro => vec![
771                LLMTools::GeminiCodeInterpreter(GeminiCodeInterpreterConfig::new()),
772                LLMTools::GeminiWebSearch(GeminiWebSearchConfig::new()),
773            ],
774            _ => vec![],
775        }
776    }
777
778    fn thinking_level_supported(&self) -> bool {
779        match self {
780            GoogleModels::Gemini3Pro => true,
781            GoogleModels::Gemini2_5Pro
782            | GoogleModels::Gemini2_5Flash
783            | GoogleModels::Gemini2_5FlashLite
784            | GoogleModels::Gemini2_0Flash
785            | GoogleModels::Gemini2_0FlashLite
786            | GoogleModels::Gemini2_0ProExp
787            | GoogleModels::Gemini2_0FlashThinkingExp
788            | GoogleModels::Gemini1_5Flash
789            | GoogleModels::Gemini1_5Flash8B
790            | GoogleModels::Gemini1_5Pro
791            | GoogleModels::FineTunedEndpoint { .. }
792            | GoogleModels::Gemini1_5FlashVertex
793            | GoogleModels::Gemini1_5Flash8BVertex
794            | GoogleModels::Gemini1_5ProVertex
795            | GoogleModels::Gemini2_0FlashVertex
796            | GoogleModels::Gemini2_0FlashLiteVertex
797            | GoogleModels::Gemini2_0ProExpVertex
798            | GoogleModels::Gemini2_0FlashThinkingExpVertex => false,
799        }
800    }
801}
802
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Default model for tests that do not exercise model-specific behaviour
    fn create_test_model() -> GoogleModels {
        GoogleModels::Gemini1_5Pro
    }

    // Minimal object schema exposing a single string property
    fn create_test_schema() -> Value {
        json!({
            "type": "object",
            "properties": {
                "answer": { "type": "string" }
            }
        })
    }

    // Text of the first entry of `contents.parts` whose text contains `marker`.
    // Panics if `contents.parts` is not an array.
    fn find_part_text<'a>(body: &'a Value, marker: &str) -> Option<&'a str> {
        body["contents"]["parts"]
            .as_array()
            .expect("`contents.parts` should be an array")
            .iter()
            .filter_map(|part| part["text"].as_str())
            .find(|text| text.contains(marker))
    }

    //
    // get_body
    //
    #[test]
    fn test_get_body_basic_functionality() {
        let model = create_test_model();
        let schema = create_test_schema();

        let body = model.get_body(
            "Test instructions",
            &schema,
            false, // function_call
            &1000, // max_tokens
            &0.7,  // temperature
            None,  // tools
            None,  // thinking_level
        );

        // The body itself must be a JSON object
        assert!(body.is_object());

        // `contents` carries the user role and an array of parts
        let contents = &body["contents"];
        assert!(contents.is_object());
        assert_eq!(contents["role"], "user");
        assert!(contents["parts"].is_array());

        // `generationConfig` is present and echoes the temperature (within tolerance)
        let generation_config = &body["generationConfig"];
        assert!(generation_config.is_object());
        let reported_temperature = generation_config["temperature"].as_f64().unwrap();
        assert!((reported_temperature - 0.7).abs() < 0.001);

        // Without tools in the request there must be no `tools` field
        assert!(body["tools"].is_null());
    }

    #[test]
    fn test_get_body_with_instructions_content() {
        let model = create_test_model();
        let schema = create_test_schema();
        let prompt = "Please analyze this data and provide insights";

        let body = model.get_body(
            prompt,
            &schema,
            false, // function_call
            &1000, // max_tokens
            &0.5,  // temperature
            None,  // tools
            None,  // thinking_level
        );

        // Locate the part that wraps the user instructions
        let text = find_part_text(&body, "<instructions>")
            .expect("a part wrapping the user instructions should be present");

        for needle in [prompt, "<instructions>", "</instructions>"] {
            assert!(text.contains(needle));
        }
    }

    #[test]
    fn test_get_body_with_json_schema() {
        let model = create_test_model();
        let schema = json!({
            "type": "object",
            "properties": {
                "result": {
                    "type": "string",
                    "description": "The result"
                }
            }
        });

        let body = model.get_body(
            "Test",
            &schema,
            false, // function_call
            &1000, // max_tokens
            &0.3,  // temperature
            None,  // tools
            None,  // thinking_level
        );

        let text = find_part_text(&body, "<output json schema>")
            .expect("a part wrapping the output schema should be present");

        // The schema is embedded in serialized form, so look for pieces of that serialization
        for needle in [
            "<output json schema>",
            "</output json schema>",
            "type",
            "object",
        ] {
            assert!(text.contains(needle));
        }
    }

    #[test]
    fn test_get_body_with_tools() {
        // Gemini 2.5 Pro supports tools
        let model = GoogleModels::Gemini2_5Pro;
        let schema = create_test_schema();
        let requested = [
            LLMTools::GeminiWebSearch(GeminiWebSearchConfig::new()),
            LLMTools::GeminiCodeInterpreter(GeminiCodeInterpreterConfig::new()),
        ];

        let body = model.get_body(
            "Search for information",
            &schema,
            false, // function_call
            &1000, // max_tokens
            &0.7,  // temperature
            Some(&requested[..]),
            None, // thinking_level
        );

        // The `tools` field must be a non-empty array
        let emitted = body["tools"]
            .as_array()
            .expect("`tools` should be an array");
        assert!(!emitted.is_empty());
    }

    #[test]
    fn test_get_body_with_unsupported_tools() {
        // Gemini 1.5 Flash does not support tools
        let model = GoogleModels::Gemini1_5Flash;
        let schema = create_test_schema();
        let requested = [LLMTools::GeminiWebSearch(GeminiWebSearchConfig::new())];

        let body = model.get_body(
            "Test",
            &schema,
            false, // function_call
            &1000, // max_tokens
            &0.7,  // temperature
            Some(&requested[..]),
            None, // thinking_level
        );

        // Requested tools are dropped for models that do not support them
        assert!(body["tools"].is_null());
    }

    #[test]
    fn test_get_body_with_web_search_context() {
        let model = GoogleModels::Gemini2_5Pro;
        let schema = create_test_schema();

        // Web search config as created by `new()`; no URLs are set on it here,
        // so only the overall structure of the body is checked.
        let requested = [LLMTools::GeminiWebSearch(GeminiWebSearchConfig::new())];

        let body = model.get_body(
            "Search for information",
            &schema,
            false, // function_call
            &1000, // max_tokens
            &0.7,  // temperature
            Some(&requested[..]),
            None, // thinking_level
        );

        assert!(body["contents"].is_object());
        assert!(body["generationConfig"].is_object());
    }

    #[test]
    fn test_get_body_temperature_values() {
        let model = create_test_model();
        let schema = create_test_schema();

        // Each temperature must be echoed unchanged in `generationConfig`
        for temp in [0.0, 0.5, 1.0, 1.5] {
            let body = model.get_body(
                "Test",
                &schema,
                false, // function_call
                &1000, // max_tokens
                &temp,
                None, // tools
                None, // thinking_level
            );

            assert_eq!(body["generationConfig"]["temperature"], temp);
        }
    }

    #[test]
    fn test_get_body_function_call_true() {
        let model = create_test_model();
        let schema = create_test_schema();

        let body = model.get_body(
            "Test with function calling",
            &schema,
            true,  // function_call
            &1000, // max_tokens
            &0.7,  // temperature
            None,  // tools
            None,  // thinking_level
        );

        // `function_call` influences the base instructions; here we only verify that
        // the parts array exists and that some part mentions "text".
        assert!(body["contents"]["parts"].is_array());
        assert!(find_part_text(&body, "text").is_some());
    }

    #[test]
    fn test_get_body_empty_instructions() {
        let model = create_test_model();
        let schema = create_test_schema();

        let body = model.get_body(
            "",
            &schema,
            false, // function_call
            &1000, // max_tokens
            &0.7,  // temperature
            None,  // tools
            None,  // thinking_level
        );

        // Empty instructions must still yield a well-formed body
        assert!(body.is_object());
        assert!(body["contents"].is_object());
        assert!(body["generationConfig"].is_object());
    }

    #[test]
    fn test_get_body_complex_json_schema() {
        let model = create_test_model();
        let schema = json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string",
                    "description": "The name"
                },
                "age": {
                    "type": "integer",
                    "minimum": 0
                },
                "items": {
                    "type": "array",
                    "items": {
                        "type": "string"
                    }
                }
            },
            "required": ["name"]
        });

        let body = model.get_body(
            "Test",
            &schema,
            false, // function_call
            &1000, // max_tokens
            &0.7,  // temperature
            None,  // tools
            None,  // thinking_level
        );

        // The complex schema must be embedded in the output schema part
        let text = find_part_text(&body, "<output json schema>")
            .expect("a part wrapping the output schema should be present");

        // The schema is embedded in serialized form, so look for pieces of that serialization
        for needle in ["type", "object", "required", "name"] {
            assert!(text.contains(needle));
        }
    }
}