1#![allow(deprecated)]
2
3use anyhow::{anyhow, Result};
4use async_trait::async_trait;
5use futures::stream::StreamExt;
6use log::{error, info};
7use reqwest::{header, Client};
8use serde::{Deserialize, Serialize};
9use serde_json::{json, Value};
10
11use crate::apis::GoogleApiEndpoints;
12use crate::completions::ThinkingLevel;
13use crate::constants::{
14 GOOGLE_GEMINI_API_URL, GOOGLE_GEMINI_BETA_API_URL, GOOGLE_VERTEX_API_URL,
15 GOOGLE_VERTEX_ENDPOINT_API_URL,
16};
17use crate::domain::{GoogleGeminiProApiResp, RateLimit};
18use crate::llm_models::tools::{GeminiCodeInterpreterConfig, GeminiWebSearchConfig};
19use crate::llm_models::{LLMModel, LLMTools};
20
/// Google Gemini models that can be called through this provider.
///
/// The API model identifier of each variant is returned by `LLMModel::as_str`. Which Google API
/// is used (AI Studio or Vertex AI) is chosen from the `version` argument passed to the
/// `LLMModel` methods; the deprecated `*Vertex` variants ignore it and always use Vertex.
#[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq)]
pub enum GoogleModels {
    /// `gemini-3-pro-preview`.
    Gemini3Pro,
    /// `gemini-2.5-pro`.
    Gemini2_5Pro,
    /// `gemini-2.5-flash`.
    Gemini2_5Flash,
    /// `gemini-2.5-flash-lite`.
    Gemini2_5FlashLite,
    /// `gemini-2.0-flash`.
    Gemini2_0Flash,
    /// `gemini-2.0-flash-lite`.
    Gemini2_0FlashLite,
    /// `gemini-2.0-pro-exp-02-05` (experimental).
    Gemini2_0ProExp,
    /// `gemini-2.0-flash-thinking-exp-01-21` (experimental).
    Gemini2_0FlashThinkingExp,
    /// `gemini-1.5-flash`.
    Gemini1_5Flash,
    /// `gemini-1.5-flash-8b`.
    Gemini1_5Flash8B,
    /// `gemini-1.5-pro`.
    Gemini1_5Pro,
    /// A model served from a dedicated (e.g. fine-tuned) endpoint.
    FineTunedEndpoint {
        /// Endpoint name; returned as-is by `as_str` and used as the model segment of the URL.
        name: String,
    },
    // The variants below select the Vertex AI API through the model name. They are kept for
    // backward compatibility only; new code should pick the API via the `version` argument.
    #[deprecated(
        since = "0.19.0",
        note = "Starting 0.19.0 `allms` allows to set the API version to `google-vertex` or `google-studio` instead of using the model name to call the right API."
    )]
    Gemini1_5FlashVertex,
    #[deprecated(
        since = "0.19.0",
        note = "Starting 0.19.0 `allms` allows to set the API version to `google-vertex` or `google-studio` instead of using the model name to call the right API."
    )]
    Gemini1_5Flash8BVertex,
    #[deprecated(
        since = "0.19.0",
        note = "Starting 0.19.0 `allms` allows to set the API version to `google-vertex` or `google-studio` instead of using the model name to call the right API."
    )]
    Gemini1_5ProVertex,
    #[deprecated(
        since = "0.19.0",
        note = "Starting 0.19.0 `allms` allows to set the API version to `google-vertex` or `google-studio` instead of using the model name to call the right API."
    )]
    Gemini2_0FlashVertex,
    #[deprecated(
        since = "0.19.0",
        note = "Starting 0.19.0 `allms` allows to set the API version to `google-vertex` or `google-studio` instead of using the model name to call the right API."
    )]
    Gemini2_0FlashLiteVertex,
    #[deprecated(
        since = "0.19.0",
        note = "Starting 0.19.0 `allms` allows to set the API version to `google-vertex` or `google-studio` instead of using the model name to call the right API."
    )]
    Gemini2_0ProExpVertex,
    #[deprecated(
        since = "0.19.0",
        note = "Starting 0.19.0 `allms` allows to set the API version to `google-vertex` or `google-studio` instead of using the model name to call the right API."
    )]
    Gemini2_0FlashThinkingExpVertex,
}
82
83#[async_trait(?Send)]
84impl LLMModel for GoogleModels {
85 fn as_str(&self) -> &str {
86 match self {
87 GoogleModels::Gemini1_5Pro | GoogleModels::Gemini1_5ProVertex => "gemini-1.5-pro",
88 GoogleModels::Gemini1_5Flash | GoogleModels::Gemini1_5FlashVertex => "gemini-1.5-flash",
89 GoogleModels::Gemini1_5Flash8B | GoogleModels::Gemini1_5Flash8BVertex => {
90 "gemini-1.5-flash-8b"
91 }
92 GoogleModels::Gemini2_0Flash | GoogleModels::Gemini2_0FlashVertex => "gemini-2.0-flash",
93 GoogleModels::Gemini2_0FlashLite | GoogleModels::Gemini2_0FlashLiteVertex => {
94 "gemini-2.0-flash-lite"
95 }
96 GoogleModels::Gemini2_0ProExp | GoogleModels::Gemini2_0ProExpVertex => {
97 "gemini-2.0-pro-exp-02-05"
98 }
99 GoogleModels::Gemini2_0FlashThinkingExp
100 | GoogleModels::Gemini2_0FlashThinkingExpVertex => {
101 "gemini-2.0-flash-thinking-exp-01-21"
102 }
103 GoogleModels::Gemini2_5Flash => "gemini-2.5-flash",
104 GoogleModels::Gemini2_5Pro => "gemini-2.5-pro",
105 GoogleModels::Gemini2_5FlashLite => "gemini-2.5-flash-lite",
106 GoogleModels::Gemini3Pro => "gemini-3-pro-preview",
107 GoogleModels::FineTunedEndpoint { name } => name,
108 }
109 }
110
111 fn try_from_str(name: &str) -> Option<Self> {
112 match name.to_lowercase().as_str() {
113 "gemini-1.5-pro" => Some(GoogleModels::Gemini1_5Pro),
114 "gemini-1.5-pro-vertex" => Some(GoogleModels::Gemini1_5ProVertex),
115 "gemini-1.5-flash" => Some(GoogleModels::Gemini1_5Flash),
116 "gemini-1.5-flash-vertex" => Some(GoogleModels::Gemini1_5FlashVertex),
117 "gemini-1.5-flash-8b" => Some(GoogleModels::Gemini1_5Flash8B),
118 "gemini-1.5-flash-8b-vertex" => Some(GoogleModels::Gemini1_5Flash8BVertex),
119 "gemini-2.0-flash" => Some(GoogleModels::Gemini2_0Flash),
120 "gemini-2.0-flash-vertex" => Some(GoogleModels::Gemini2_0FlashVertex),
121 "gemini-2.0-flash-lite" => Some(GoogleModels::Gemini2_0FlashLite),
122 "gemini-2.0-flash-lite-vertex" => Some(GoogleModels::Gemini2_0FlashLiteVertex),
123 "gemini-2.0-pro" => Some(GoogleModels::Gemini2_0ProExp),
124 "gemini-2.0-pro-exp" => Some(GoogleModels::Gemini2_0ProExp),
125 "gemini-2.0-pro-vertex" => Some(GoogleModels::Gemini2_0ProExpVertex),
126 "gemini-2.0-flash-thinking" => Some(GoogleModels::Gemini2_0FlashThinkingExp),
127 "gemini-2.0-flash-thinking-exp" => Some(GoogleModels::Gemini2_0FlashThinkingExp),
128 "gemini-2.0-flash-thinking-vertex" => {
129 Some(GoogleModels::Gemini2_0FlashThinkingExpVertex)
130 }
131 "gemini-2.5-flash" => Some(GoogleModels::Gemini2_5Flash),
132 "gemini-2.5-pro" => Some(GoogleModels::Gemini2_5Pro),
133 "gemini-2.5-flash-lite" => Some(GoogleModels::Gemini2_5FlashLite),
134 "gemini-3-pro-preview" => Some(GoogleModels::Gemini3Pro),
135 "gemini-3-pro" => Some(GoogleModels::Gemini3Pro),
136 "gemini-pro" => Some(GoogleModels::Gemini1_5Pro),
138 "gemini-1.0-pro" => Some(GoogleModels::Gemini1_5Pro),
139 "gemini-pro-vertex" => Some(GoogleModels::Gemini1_5ProVertex),
140 "gemini-1.0-pro-vertex" => Some(GoogleModels::Gemini1_5ProVertex),
141 _ => None,
143 }
144 }
145
146 fn default_max_tokens(&self) -> usize {
147 match self {
149 GoogleModels::Gemini1_5Pro | GoogleModels::Gemini1_5ProVertex => 2_097_152,
150 GoogleModels::Gemini1_5Flash | GoogleModels::Gemini1_5FlashVertex => 1_048_576,
151 GoogleModels::Gemini1_5Flash8B | GoogleModels::Gemini1_5Flash8BVertex => 1_048_576,
152 GoogleModels::Gemini2_0Flash | GoogleModels::Gemini2_0FlashVertex => 1_048_576,
153 GoogleModels::Gemini2_0FlashLite | GoogleModels::Gemini2_0FlashLiteVertex => 1_048_576,
154 GoogleModels::Gemini2_0ProExp | GoogleModels::Gemini2_0ProExpVertex => 2_097_152,
155 GoogleModels::Gemini2_0FlashThinkingExp
156 | GoogleModels::Gemini2_0FlashThinkingExpVertex => 1_048_576,
157 GoogleModels::Gemini2_5Flash => 1_048_576,
158 GoogleModels::Gemini2_5Pro => 1_048_576,
159 GoogleModels::Gemini2_5FlashLite => 1_048_576,
160 GoogleModels::Gemini3Pro => 1_048_576,
161 GoogleModels::FineTunedEndpoint { .. } => 1_048_576,
163 }
164 }
165
166 fn get_version_endpoint(&self, version: Option<String>) -> String {
167 let version = version
169 .map(|version| GoogleApiEndpoints::from_str(&version))
170 .unwrap_or_default();
171
172 match (self, version) {
173 (
175 GoogleModels::Gemini1_5Pro
176 | GoogleModels::Gemini1_5Flash
177 | GoogleModels::Gemini1_5Flash8B
178 | GoogleModels::Gemini2_0Flash
179 | GoogleModels::Gemini2_0FlashLite
180 | GoogleModels::Gemini2_0ProExp
181 | GoogleModels::Gemini2_0FlashThinkingExp,
182 GoogleApiEndpoints::GoogleStudio,
183 ) => format!(
184 "{}/{}:generateContent",
185 &*GOOGLE_GEMINI_API_URL,
186 self.as_str()
187 ),
188 (
190 GoogleModels::Gemini2_5Flash
191 | GoogleModels::Gemini2_5Pro
192 | GoogleModels::Gemini2_5FlashLite
193 | GoogleModels::Gemini3Pro,
194 GoogleApiEndpoints::GoogleStudio,
195 ) => format!(
196 "{}/{}:generateContent",
197 &*GOOGLE_GEMINI_BETA_API_URL,
198 self.as_str()
199 ),
200 (GoogleModels::FineTunedEndpoint { .. }, GoogleApiEndpoints::GoogleStudio) => {
203 format!(
205 "{}/{}:generateContent",
206 &*GOOGLE_VERTEX_ENDPOINT_API_URL,
207 self.as_str()
208 )
209 }
210 (
212 GoogleModels::Gemini1_5Pro
213 | GoogleModels::Gemini1_5Flash
214 | GoogleModels::Gemini1_5Flash8B
215 | GoogleModels::Gemini2_0Flash
216 | GoogleModels::Gemini2_0FlashLite
217 | GoogleModels::Gemini2_0ProExp
218 | GoogleModels::Gemini2_0FlashThinkingExp
219 | GoogleModels::Gemini2_5Flash
220 | GoogleModels::Gemini2_5Pro
221 | GoogleModels::Gemini2_5FlashLite,
222 GoogleApiEndpoints::GoogleVertex,
223 ) => {
224 format!(
226 "{}/{}:streamGenerateContent?alt=sse",
227 &*GOOGLE_VERTEX_API_URL,
228 self.as_str()
229 )
230 }
231 (GoogleModels::Gemini3Pro, GoogleApiEndpoints::GoogleVertex) => {
235 error!("[allms][Google Vertex] Gemini 3 Pro is not supported in the Google Vertex API. Rerouting to Studio API.");
236 format!(
237 "{}/{}:generateContent",
238 &*GOOGLE_GEMINI_BETA_API_URL,
239 self.as_str()
240 )
241 }
242 (GoogleModels::FineTunedEndpoint { .. }, GoogleApiEndpoints::GoogleVertex) => {
244 format!(
246 "{}/{}:generateContent",
247 &*GOOGLE_VERTEX_ENDPOINT_API_URL,
248 self.as_str()
249 )
250 }
251 #[allow(deprecated)]
253 (
254 GoogleModels::Gemini1_5ProVertex
255 | GoogleModels::Gemini1_5FlashVertex
256 | GoogleModels::Gemini1_5Flash8BVertex
257 | GoogleModels::Gemini2_0FlashVertex
258 | GoogleModels::Gemini2_0FlashLiteVertex
259 | GoogleModels::Gemini2_0ProExpVertex
260 | GoogleModels::Gemini2_0FlashThinkingExpVertex,
261 _,
262 ) => {
263 format!(
265 "{}/{}:streamGenerateContent?alt=sse",
266 &*GOOGLE_VERTEX_API_URL,
267 self.as_str()
268 )
269 }
270 }
271 }
272
273 fn get_body(
275 &self,
276 instructions: &str,
277 json_schema: &Value,
278 function_call: bool,
279 _max_tokens: &usize,
280 temperature: &f32,
281 tools: Option<&[LLMTools]>,
282 thinking_level: Option<&ThinkingLevel>,
283 ) -> serde_json::Value {
284 let base_instructions_json = json!({
286 "text": self.get_base_instructions(Some(function_call))
287 });
288
289 let output_instructions_json = json!({ "text": format!("<output json schema>
290 {json_schema}
291 </output json schema>") });
292
293 let user_instructions_json = json!({
294 "text": format!("<instructions>
295 {instructions}
296 </instructions>"),
297 });
298
299 let mut message_parts = vec![
300 base_instructions_json,
301 output_instructions_json,
302 user_instructions_json,
303 ];
304
305 if let Some(tools_inner) = tools {
307 if let Some(LLMTools::GeminiWebSearch(config)) = tools_inner
308 .iter()
309 .find(|tool| matches!(tool, LLMTools::GeminiWebSearch(_)))
310 {
311 let urls = config.get_context_urls();
312 if !urls.is_empty() {
313 message_parts.push(json!({
314 "text": format!("<url_context>
315 {:?}
316 </url_context>",
317 urls),
318 }));
319 }
320 }
321 }
322
323 let contents = json!({
324 "role": "user",
325 "parts": message_parts,
326 });
327
328 let generation_config = json!({
329 "temperature": temperature,
330 });
331
332 let mut body = json!({
333 "contents": contents,
334 "generationConfig": generation_config,
335 });
336
337 if let Some(tools_inner) = tools {
339 let processed_tools: Vec<Value> = tools_inner
340 .iter()
341 .filter_map(|tool| {
342 self.get_supported_tools()
343 .iter()
344 .find(|supported| {
345 std::mem::discriminant(tool) == std::mem::discriminant(supported)
346 })
347 .and_then(|_| tool.get_config_json())
348 })
349 .collect();
350
351 if !processed_tools.is_empty() {
352 body["tools"] = json!(processed_tools);
353 }
354 }
355
356 if self.thinking_level_supported() {
358 if let Some(thinking_level) = thinking_level {
359 body["generationConfig"]["thinkingConfig"]["thinkingLevel"] =
360 json!(thinking_level.as_str());
361 }
362 }
363
364 body
365 }
366
367 async fn call_api(
373 &self,
374 api_key: &str,
375 version: Option<String>,
376 body: &serde_json::Value,
377 debug: bool,
378 _tools: Option<&[LLMTools]>,
379 ) -> Result<String> {
380 let api_version = version
382 .as_ref()
383 .map(|version| GoogleApiEndpoints::from_str(version))
384 .unwrap_or_default();
385
386 match (self, api_version) {
387 (
389 GoogleModels::Gemini1_5Pro
390 | GoogleModels::Gemini1_5Flash
391 | GoogleModels::Gemini1_5Flash8B
392 | GoogleModels::Gemini2_0Flash
393 | GoogleModels::Gemini2_0FlashLite
394 | GoogleModels::Gemini2_0ProExp
395 | GoogleModels::Gemini2_0FlashThinkingExp
396 | GoogleModels::Gemini2_5Flash
397 | GoogleModels::Gemini2_5Pro
398 | GoogleModels::Gemini2_5FlashLite
399 | GoogleModels::Gemini3Pro,
400 GoogleApiEndpoints::GoogleStudio,
401 ) => self.call_api_studio(api_key, version, body, debug).await,
402 (GoogleModels::FineTunedEndpoint { .. }, GoogleApiEndpoints::GoogleStudio) => {
405 self.call_api_vertex(api_key, version, body, debug).await
406 }
407 (
409 GoogleModels::Gemini1_5Pro
410 | GoogleModels::Gemini1_5Flash
411 | GoogleModels::Gemini1_5Flash8B
412 | GoogleModels::Gemini2_0Flash
413 | GoogleModels::Gemini2_0FlashLite
414 | GoogleModels::Gemini2_0ProExp
415 | GoogleModels::Gemini2_0FlashThinkingExp
416 | GoogleModels::Gemini2_5Flash
417 | GoogleModels::Gemini2_5Pro
418 | GoogleModels::Gemini2_5FlashLite,
419 GoogleApiEndpoints::GoogleVertex,
420 ) => {
421 self.call_api_vertex_stream(api_key, version, body, debug)
422 .await
423 }
424 (GoogleModels::Gemini3Pro, GoogleApiEndpoints::GoogleVertex) => {
428 self.call_api_studio(api_key, version, body, debug).await
429 }
430 (GoogleModels::FineTunedEndpoint { .. }, GoogleApiEndpoints::GoogleVertex) => {
432 self.call_api_vertex(api_key, version, body, debug).await
433 }
434 #[allow(deprecated)]
436 (
437 GoogleModels::Gemini1_5ProVertex
438 | GoogleModels::Gemini1_5FlashVertex
439 | GoogleModels::Gemini1_5Flash8BVertex
440 | GoogleModels::Gemini2_0FlashVertex
441 | GoogleModels::Gemini2_0FlashLiteVertex
442 | GoogleModels::Gemini2_0ProExpVertex
443 | GoogleModels::Gemini2_0FlashThinkingExpVertex,
444 _,
445 ) => {
446 self.call_api_vertex_stream(api_key, version, body, debug)
447 .await
448 }
449 }
450 }
451
452 fn get_version_data(
453 &self,
454 response_text: &str,
455 _function_call: bool,
456 version: Option<String>,
457 ) -> Result<String> {
458 let version = version
460 .map(|version| GoogleApiEndpoints::from_str(&version))
461 .unwrap_or_default();
462
463 match (self, version) {
464 (
466 GoogleModels::Gemini1_5Pro
467 | GoogleModels::Gemini1_5Flash
468 | GoogleModels::Gemini1_5Flash8B
469 | GoogleModels::Gemini2_0Flash
470 | GoogleModels::Gemini2_0FlashLite
471 | GoogleModels::Gemini2_0ProExp
472 | GoogleModels::Gemini2_0FlashThinkingExp
473 | GoogleModels::Gemini2_5Flash
474 | GoogleModels::Gemini2_5Pro
475 | GoogleModels::Gemini2_5FlashLite
476 | GoogleModels::Gemini3Pro,
477 GoogleApiEndpoints::GoogleStudio,
478 ) => self.get_generate_content_data(response_text),
479 (GoogleModels::FineTunedEndpoint { .. }, GoogleApiEndpoints::GoogleStudio) => {
482 self.get_generate_content_data(response_text)
483 }
484 (
486 GoogleModels::Gemini1_5Pro
487 | GoogleModels::Gemini1_5Flash
488 | GoogleModels::Gemini1_5Flash8B
489 | GoogleModels::Gemini2_0Flash
490 | GoogleModels::Gemini2_0FlashLite
491 | GoogleModels::Gemini2_0ProExp
492 | GoogleModels::Gemini2_0FlashThinkingExp
493 | GoogleModels::Gemini2_5Flash
494 | GoogleModels::Gemini2_5Pro
495 | GoogleModels::Gemini2_5FlashLite
496 | GoogleModels::Gemini3Pro,
497 GoogleApiEndpoints::GoogleVertex,
498 ) => Ok(response_text.to_string()),
499 (GoogleModels::FineTunedEndpoint { .. }, GoogleApiEndpoints::GoogleVertex) => {
501 self.get_generate_content_data(response_text)
502 }
503 #[allow(deprecated)]
505 (
506 GoogleModels::Gemini1_5ProVertex
507 | GoogleModels::Gemini1_5FlashVertex
508 | GoogleModels::Gemini1_5Flash8BVertex
509 | GoogleModels::Gemini2_0FlashVertex
510 | GoogleModels::Gemini2_0FlashLiteVertex
511 | GoogleModels::Gemini2_0ProExpVertex
512 | GoogleModels::Gemini2_0FlashThinkingExpVertex,
513 _,
514 ) => Ok(response_text.to_string()),
515 }
516 }
517
518 fn get_rate_limit(&self) -> RateLimit {
520 match self {
522 GoogleModels::Gemini1_5Flash | GoogleModels::Gemini1_5FlashVertex => RateLimit {
523 tpm: 4_000_000,
524 rpm: 2_000,
525 },
526 GoogleModels::Gemini1_5Flash8B | GoogleModels::Gemini1_5Flash8BVertex => RateLimit {
527 tpm: 4_000_000,
528 rpm: 4_000,
529 },
530 GoogleModels::Gemini1_5Pro | GoogleModels::Gemini1_5ProVertex => RateLimit {
531 tpm: 4_000_000,
532 rpm: 1_000,
533 },
534 GoogleModels::Gemini2_0Flash | GoogleModels::Gemini2_0FlashVertex => RateLimit {
535 tpm: 30_000_000,
536 rpm: 30_000,
537 },
538 GoogleModels::Gemini2_0FlashLite | GoogleModels::Gemini2_0FlashLiteVertex => {
539 RateLimit {
540 tpm: 30_000_000,
541 rpm: 30_000,
542 }
543 }
544 GoogleModels::Gemini2_5Flash => RateLimit {
545 tpm: 8_000_000,
546 rpm: 10_000,
547 },
548 GoogleModels::Gemini2_5Pro => RateLimit {
549 tpm: 8_000_000,
550 rpm: 2_000,
551 },
552 GoogleModels::Gemini2_5FlashLite => RateLimit {
553 tpm: 30_000_000,
554 rpm: 30_000,
555 },
556 GoogleModels::Gemini3Pro => RateLimit {
557 tpm: 8_000_000,
558 rpm: 2_000,
559 },
560 GoogleModels::FineTunedEndpoint { .. } => RateLimit {
562 tpm: 30_000_000,
563 rpm: 30_000,
564 },
565 _ => RateLimit {
567 tpm: 120_000,
568 rpm: 360,
569 },
570 }
571 }
572
573 fn get_default_temperature(&self) -> f32 {
574 if self == &GoogleModels::Gemini3Pro {
577 1.0f32
578 } else {
579 0.0f32
580 }
581 }
582}
583
584impl GoogleModels {
585 pub fn endpoint(name: &str) -> Self {
588 GoogleModels::FineTunedEndpoint {
589 name: name.to_string(),
590 }
591 }
592
593 async fn call_api_studio(
595 &self,
596 api_key: &str,
597 version: Option<String>,
598 body: &serde_json::Value,
599 debug: bool,
600 ) -> Result<String> {
601 let model_url = self.get_version_endpoint(version);
603
604 let client = Client::new();
606
607 let url_with_key = format!("{}?key={}", model_url, api_key);
609 let response = client
610 .post(url_with_key)
611 .header(header::CONTENT_TYPE, "application/json")
612 .json(&body)
613 .send()
614 .await?;
615
616 let response_status = response.status();
617 let response_text = response.text().await?;
618
619 if debug {
620 info!(
621 "[allms][Google AI Studio] API response: [{}] {:#?}",
622 &response_status, &response_text
623 );
624 }
625
626 Ok(response_text)
627 }
628
629 async fn call_api_vertex_stream(
631 &self,
632 api_key: &str,
633 version: Option<String>,
634 body: &serde_json::Value,
635 debug: bool,
636 ) -> Result<String> {
637 let model_url = self.get_version_endpoint(version);
639
640 let client = Client::new();
642
643 let response = client
645 .post(model_url)
646 .header(header::CONTENT_TYPE, "application/json")
647 .bearer_auth(api_key)
648 .json(&body)
649 .send()
650 .await?;
651
652 if response.status().is_success() {
655 let mut stream = response.bytes_stream();
656 let mut streamed_response = String::new();
657
658 while let Some(chunk) = stream.next().await {
659 let chunk = chunk?;
660
661 let mut chunk_str = String::from_utf8(chunk.to_vec()).map_err(|e| anyhow!(e))?;
663
664 if chunk_str.starts_with("data: ") {
666 chunk_str = chunk_str[6..].to_string();
668 }
669
670 let gemini_response: GoogleGeminiProApiResp = serde_json::from_str(&chunk_str)?;
672
673 let part_text = gemini_response
675 .candidates
676 .iter()
677 .filter(|candidate| candidate.content.role.as_deref() == Some("model"))
678 .flat_map(|candidate| &candidate.content.parts)
679 .filter_map(|part| part.text.as_ref())
680 .fold(String::new(), |mut acc, text| {
681 acc.push_str(text);
682 acc
683 });
684
685 streamed_response.push_str(&part_text);
687
688 if debug {
690 info!(
691 "[allms][Google Vertex AI] Received response chunk: {:?}",
692 chunk
693 );
694 }
695 }
696 Ok(self.sanitize_json_response(&streamed_response))
697 } else {
698 let response_status = response.status();
699 let response_txt = response.text().await?;
700 Err(anyhow!(
701 "[allms][Google][{}] Response body: {:#?}",
702 response_status,
703 response_txt
704 ))
705 }
706 }
707
708 async fn call_api_vertex(
710 &self,
711 api_key: &str,
712 version: Option<String>,
713 body: &serde_json::Value,
714 debug: bool,
715 ) -> Result<String> {
716 let model_url = self.get_version_endpoint(version);
718
719 let client = Client::new();
721
722 let response = client
724 .post(model_url)
725 .header(header::CONTENT_TYPE, "application/json")
726 .bearer_auth(api_key)
727 .json(&body)
728 .send()
729 .await?;
730
731 let response_status = response.status();
732 let response_text = response.text().await?;
733
734 if debug {
735 info!(
736 "[allms][Google AI Vertex][Fine-tuned] API response: [{}] {:#?}",
737 &response_status, &response_text
738 );
739 }
740
741 Ok(response_text)
742 }
743
744 fn get_generate_content_data(&self, response_text: &str) -> Result<String> {
746 let gemini_response: GoogleGeminiProApiResp = serde_json::from_str(response_text)?;
748
749 let data = gemini_response
751 .candidates
752 .iter()
753 .filter(|candidate| candidate.content.role.as_deref() == Some("model"))
754 .flat_map(|candidate| &candidate.content.parts)
755 .filter_map(|part| part.text.as_ref())
756 .fold(String::new(), |mut acc, text| {
757 acc.push_str(text);
758 acc
759 });
760
761 Ok(self.sanitize_json_response(&data))
762 }
763
764 fn get_supported_tools(&self) -> Vec<LLMTools> {
765 match self {
766 GoogleModels::Gemini2_5Pro
767 | GoogleModels::Gemini2_5Flash
768 | GoogleModels::Gemini2_5FlashLite
769 | GoogleModels::Gemini2_0Flash
770 | GoogleModels::Gemini3Pro => vec![
771 LLMTools::GeminiCodeInterpreter(GeminiCodeInterpreterConfig::new()),
772 LLMTools::GeminiWebSearch(GeminiWebSearchConfig::new()),
773 ],
774 _ => vec![],
775 }
776 }
777
778 fn thinking_level_supported(&self) -> bool {
779 match self {
780 GoogleModels::Gemini3Pro => true,
781 GoogleModels::Gemini2_5Pro
782 | GoogleModels::Gemini2_5Flash
783 | GoogleModels::Gemini2_5FlashLite
784 | GoogleModels::Gemini2_0Flash
785 | GoogleModels::Gemini2_0FlashLite
786 | GoogleModels::Gemini2_0ProExp
787 | GoogleModels::Gemini2_0FlashThinkingExp
788 | GoogleModels::Gemini1_5Flash
789 | GoogleModels::Gemini1_5Flash8B
790 | GoogleModels::Gemini1_5Pro
791 | GoogleModels::FineTunedEndpoint { .. }
792 | GoogleModels::Gemini1_5FlashVertex
793 | GoogleModels::Gemini1_5Flash8BVertex
794 | GoogleModels::Gemini1_5ProVertex
795 | GoogleModels::Gemini2_0FlashVertex
796 | GoogleModels::Gemini2_0FlashLiteVertex
797 | GoogleModels::Gemini2_0ProExpVertex
798 | GoogleModels::Gemini2_0FlashThinkingExpVertex => false,
799 }
800 }
801}
802
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Model used by tests that do not depend on model-specific behavior.
    fn create_test_model() -> GoogleModels {
        GoogleModels::Gemini1_5Pro
    }

    /// Minimal object schema with a single string property.
    fn create_test_schema() -> Value {
        json!({
            "type": "object",
            "properties": {
                "answer": {
                    "type": "string"
                }
            }
        })
    }

    /// Calls `get_body` with the arguments most tests share: 1000 max tokens and no thinking
    /// level.
    fn build_body(
        model: &GoogleModels,
        instructions: &str,
        json_schema: &Value,
        function_call: bool,
        temperature: f32,
        tools: Option<&[LLMTools]>,
    ) -> Value {
        let max_tokens: usize = 1000;
        model.get_body(
            instructions,
            json_schema,
            function_call,
            &max_tokens,
            &temperature,
            tools,
            None,
        )
    }

    /// Returns the `text` of the first prompt part whose text contains `marker`.
    fn find_part_text<'a>(body: &'a Value, marker: &str) -> Option<&'a str> {
        body["contents"]["parts"]
            .as_array()?
            .iter()
            .filter_map(|part| part["text"].as_str())
            .find(|text| text.contains(marker))
    }

    #[test]
    fn test_get_body_basic_functionality() {
        let body = build_body(
            &create_test_model(),
            "Test instructions",
            &create_test_schema(),
            false,
            0.7,
            None,
        );

        assert!(body.is_object());

        assert!(body["contents"].is_object());
        assert_eq!(body["contents"]["role"], "user");
        assert!(body["contents"]["parts"].is_array());

        assert!(body["generationConfig"].is_object());
        assert!((body["generationConfig"]["temperature"].as_f64().unwrap() - 0.7).abs() < 0.001);

        assert!(body["tools"].is_null());
    }

    #[test]
    fn test_get_body_with_instructions_content() {
        let body = build_body(
            &create_test_model(),
            "Please analyze this data and provide insights",
            &create_test_schema(),
            false,
            0.5,
            None,
        );

        assert!(body["contents"]["parts"].is_array());

        let text = find_part_text(&body, "<instructions>")
            .expect("the user instructions part should be present");
        assert!(text.contains("Please analyze this data and provide insights"));
        assert!(text.contains("<instructions>"));
        assert!(text.contains("</instructions>"));
    }

    #[test]
    fn test_get_body_with_json_schema() {
        let json_schema = json!({
            "type": "object",
            "properties": {
                "result": {
                    "type": "string",
                    "description": "The result"
                }
            }
        });
        let body = build_body(&create_test_model(), "Test", &json_schema, false, 0.3, None);

        let text = find_part_text(&body, "<output json schema>")
            .expect("the output schema part should be present");
        assert!(text.contains("<output json schema>"));
        assert!(text.contains("</output json schema>"));
        assert!(text.contains("type"));
        assert!(text.contains("object"));
    }

    #[test]
    fn test_get_body_with_tools() {
        let requested_tools = [
            LLMTools::GeminiWebSearch(GeminiWebSearchConfig::new()),
            LLMTools::GeminiCodeInterpreter(GeminiCodeInterpreterConfig::new()),
        ];
        let body = build_body(
            &GoogleModels::Gemini2_5Pro,
            "Search for information",
            &create_test_schema(),
            false,
            0.7,
            Some(&requested_tools[..]),
        );

        assert!(body["tools"].is_array());
        let tools_json = body["tools"].as_array().unwrap();
        assert!(!tools_json.is_empty());
    }

    #[test]
    fn test_get_body_with_unsupported_tools() {
        let requested_tools = [LLMTools::GeminiWebSearch(GeminiWebSearchConfig::new())];
        let body = build_body(
            &GoogleModels::Gemini1_5Flash,
            "Test",
            &create_test_schema(),
            false,
            0.7,
            Some(&requested_tools[..]),
        );

        assert!(body["tools"].is_null());
    }

    #[test]
    fn test_get_body_with_web_search_context() {
        let requested_tools = [LLMTools::GeminiWebSearch(GeminiWebSearchConfig::new())];
        let body = build_body(
            &GoogleModels::Gemini2_5Pro,
            "Search for information",
            &create_test_schema(),
            false,
            0.7,
            Some(&requested_tools[..]),
        );

        assert!(body["contents"].is_object());
        assert!(body["generationConfig"].is_object());
    }

    #[test]
    fn test_get_body_temperature_values() {
        let model = create_test_model();
        let json_schema = create_test_schema();

        for temperature in [0.0_f32, 0.5, 1.0, 1.5] {
            let body = build_body(&model, "Test", &json_schema, false, temperature, None);
            assert_eq!(body["generationConfig"]["temperature"], temperature);
        }
    }

    #[test]
    fn test_get_body_function_call_true() {
        let model = create_test_model();
        let json_schema = create_test_schema();
        let instructions = "Test with function calling";

        let with_function_call = build_body(&model, instructions, &json_schema, true, 0.7, None);
        let without_function_call =
            build_body(&model, instructions, &json_schema, false, 0.7, None);

        // Without tools the prompt is exactly: base instructions, output schema, user instructions.
        let parts = with_function_call["contents"]["parts"]
            .as_array()
            .expect("parts should be an array");
        assert_eq!(parts.len(), 3);
        assert!(parts[0]["text"].is_string());
        assert!(parts[1]["text"]
            .as_str()
            .unwrap()
            .contains("<output json schema>"));
        assert!(parts[2]["text"].as_str().unwrap().contains("<instructions>"));

        // `function_call` only feeds the base instructions; the other parts are unaffected.
        let other_parts = &without_function_call["contents"]["parts"];
        assert_eq!(parts[1], other_parts[1]);
        assert_eq!(parts[2], other_parts[2]);
    }

    #[test]
    fn test_get_body_empty_instructions() {
        let body = build_body(
            &create_test_model(),
            "",
            &create_test_schema(),
            false,
            0.7,
            None,
        );

        assert!(body.is_object());
        assert!(body["contents"].is_object());
        assert!(body["generationConfig"].is_object());
    }

    #[test]
    fn test_get_body_complex_json_schema() {
        let json_schema = json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string",
                    "description": "The name"
                },
                "age": {
                    "type": "integer",
                    "minimum": 0
                },
                "items": {
                    "type": "array",
                    "items": {
                        "type": "string"
                    }
                }
            },
            "required": ["name"]
        });
        let body = build_body(&create_test_model(), "Test", &json_schema, false, 0.7, None);

        let text = find_part_text(&body, "<output json schema>")
            .expect("the output schema part should be present");
        assert!(text.contains("type"));
        assert!(text.contains("object"));
        assert!(text.contains("required"));
        assert!(text.contains("name"));
    }

    #[test]
    fn test_get_body_without_thinking_level_has_no_thinking_config() {
        let body = build_body(
            &GoogleModels::Gemini3Pro,
            "Test",
            &create_test_schema(),
            false,
            1.0,
            None,
        );

        assert!(body["generationConfig"]["thinkingConfig"].is_null());
    }

    #[test]
    fn test_get_default_temperature() {
        assert_eq!(GoogleModels::Gemini3Pro.get_default_temperature(), 1.0);
        assert_eq!(GoogleModels::Gemini2_5Pro.get_default_temperature(), 0.0);
    }

    #[test]
    fn test_try_from_str_is_case_insensitive() {
        assert_eq!(
            GoogleModels::try_from_str("Gemini-2.5-Pro"),
            Some(GoogleModels::Gemini2_5Pro)
        );
        assert_eq!(
            GoogleModels::try_from_str("gemini-3-pro"),
            Some(GoogleModels::Gemini3Pro)
        );
        assert_eq!(GoogleModels::try_from_str("unknown-model"), None);
    }

    #[test]
    fn test_fine_tuned_endpoint_name_is_model_id() {
        let model = GoogleModels::endpoint("my-endpoint");
        assert_eq!(model.as_str(), "my-endpoint");
    }
}