1use std::sync::Arc;
44
45use crate::{
46 builder::LLMBackend,
47 chat::{
48 ChatMessage, ChatProvider, ChatResponse, ChatRole, MessageType, StructuredOutputFormat,
49 Tool, Usage,
50 },
51 completion::{CompletionProvider, CompletionRequest, CompletionResponse},
52 embedding::EmbeddingProvider,
53 error::LLMError,
54 models::{ModelListRawEntry, ModelListRequest, ModelListResponse, ModelsProvider},
55 stt::SpeechToTextProvider,
56 tts::TextToSpeechProvider,
57 FunctionCall, LLMProvider, ToolCall,
58};
59use async_trait::async_trait;
60use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
61use chrono::{DateTime, Utc};
62use futures::{stream::Stream, StreamExt};
63use reqwest::Client;
64use serde::{Deserialize, Serialize};
65use serde_json::Value;
66
67#[derive(Debug)]
69pub struct GoogleConfig {
70 pub api_key: String,
72 pub model: String,
74 pub max_tokens: Option<u32>,
76 pub temperature: Option<f32>,
78 pub system: Option<String>,
80 pub timeout_seconds: Option<u64>,
82 pub top_p: Option<f32>,
84 pub top_k: Option<u32>,
86 pub json_schema: Option<StructuredOutputFormat>,
88 pub tools: Option<Vec<Tool>>,
90}
91
92#[derive(Debug, Clone)]
100pub struct Google {
101 pub config: Arc<GoogleConfig>,
103 pub client: Client,
105}
106
107#[derive(Serialize)]
109struct GoogleChatRequest<'a> {
110 contents: Vec<GoogleChatContent<'a>>,
112 #[serde(skip_serializing_if = "Option::is_none", rename = "generationConfig")]
114 generation_config: Option<GoogleGenerationConfig>,
115 #[serde(skip_serializing_if = "Option::is_none")]
117 tools: Option<Vec<GoogleTool>>,
118}
119
120#[derive(Serialize)]
122struct GoogleChatContent<'a> {
123 role: &'a str,
125 parts: Vec<GoogleContentPart<'a>>,
127}
128
129#[derive(Serialize)]
131#[serde(rename_all = "camelCase")]
132enum GoogleContentPart<'a> {
133 #[serde(rename = "text")]
135 Text(&'a str),
136 InlineData(GoogleInlineData),
137 FunctionCall(GoogleFunctionCall),
138 #[serde(rename = "functionResponse")]
139 FunctionResponse(GoogleFunctionResponse),
140}
141
142#[derive(Serialize)]
143struct GoogleInlineData {
144 mime_type: String,
145 data: String,
146}
147
148#[derive(Serialize)]
150struct GoogleGenerationConfig {
151 #[serde(skip_serializing_if = "Option::is_none", rename = "maxOutputTokens")]
153 max_output_tokens: Option<u32>,
154 #[serde(skip_serializing_if = "Option::is_none")]
156 temperature: Option<f32>,
157 #[serde(skip_serializing_if = "Option::is_none", rename = "topP")]
159 top_p: Option<f32>,
160 #[serde(skip_serializing_if = "Option::is_none", rename = "topK")]
162 top_k: Option<u32>,
163 #[serde(skip_serializing_if = "Option::is_none")]
165 response_mime_type: Option<GoogleResponseMimeType>,
166 #[serde(skip_serializing_if = "Option::is_none")]
168 response_schema: Option<Value>,
169}
170
171#[derive(Deserialize, Debug)]
173struct GoogleChatResponse {
174 candidates: Vec<GoogleCandidate>,
176 #[serde(rename = "usageMetadata")]
178 usage_metadata: Option<GoogleUsageMetadata>,
179}
180
181#[derive(Deserialize, Debug)]
183struct GoogleUsageMetadata {
184 #[serde(rename = "promptTokenCount")]
186 prompt_token_count: Option<u32>,
187 #[serde(rename = "candidatesTokenCount")]
189 candidates_token_count: Option<u32>,
190 #[serde(rename = "totalTokenCount")]
192 total_token_count: Option<u32>,
193}
194
195#[derive(Deserialize, Debug)]
197struct GoogleStreamResponse {
198 candidates: Option<Vec<GoogleCandidate>>,
200 #[serde(rename = "usageMetadata")]
202 usage_metadata: Option<GoogleUsageMetadata>,
203}
204
205impl std::fmt::Display for GoogleChatResponse {
206 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
207 match (self.text(), self.tool_calls()) {
208 (Some(text), Some(tool_calls)) => {
209 for call in tool_calls {
210 write!(f, "{call}")?;
211 }
212 write!(f, "{text}")
213 }
214 (Some(text), None) => write!(f, "{text}"),
215 (None, Some(tool_calls)) => {
216 for call in tool_calls {
217 write!(f, "{call}")?;
218 }
219 Ok(())
220 }
221 (None, None) => write!(f, ""),
222 }
223 }
224}
225
226#[derive(Deserialize, Debug)]
228struct GoogleCandidate {
229 content: GoogleResponseContent,
231}
232
233#[derive(Deserialize, Debug)]
235struct GoogleResponseContent {
236 #[serde(default)]
238 parts: Vec<GoogleResponsePart>,
239 #[serde(skip_serializing_if = "Option::is_none")]
241 function_call: Option<GoogleFunctionCall>,
242 #[serde(skip_serializing_if = "Option::is_none")]
244 function_calls: Option<Vec<GoogleFunctionCall>>,
245}
246
247impl ChatResponse for GoogleChatResponse {
248 fn text(&self) -> Option<String> {
249 self.candidates
250 .first()
251 .map(|c| c.content.parts.iter().map(|p| p.text.clone()).collect())
252 }
253
254 fn tool_calls(&self) -> Option<Vec<ToolCall>> {
255 self.candidates.first().and_then(|c| {
256 let part_function_calls: Vec<ToolCall> = c
258 .content
259 .parts
260 .iter()
261 .filter_map(|part| {
262 part.function_call.as_ref().map(|f| ToolCall {
263 id: format!("call_{}", f.name),
264 call_type: "function".to_string(),
265 function: FunctionCall {
266 name: f.name.clone(),
267 arguments: serde_json::to_string(&f.args).unwrap_or_default(),
268 },
269 })
270 })
271 .collect();
272
273 if !part_function_calls.is_empty() {
274 return Some(part_function_calls);
275 }
276
277 if let Some(fc) = &c.content.function_calls {
279 Some(
281 fc.iter()
282 .map(|f| ToolCall {
283 id: format!("call_{}", f.name),
284 call_type: "function".to_string(),
285 function: FunctionCall {
286 name: f.name.clone(),
287 arguments: serde_json::to_string(&f.args).unwrap_or_default(),
288 },
289 })
290 .collect(),
291 )
292 } else {
293 c.content.function_call.as_ref().map(|f| {
294 vec![ToolCall {
295 id: format!("call_{}", f.name),
296 call_type: "function".to_string(),
297 function: FunctionCall {
298 name: f.name.clone(),
299 arguments: serde_json::to_string(&f.args).unwrap_or_default(),
300 },
301 }]
302 })
303 }
304 })
305 }
306
307 fn usage(&self) -> Option<Usage> {
308 self.usage_metadata.as_ref().and_then(|metadata| {
309 match (metadata.prompt_token_count, metadata.candidates_token_count) {
310 (Some(prompt_tokens), Some(completion_tokens)) => Some(Usage {
311 prompt_tokens,
312 completion_tokens,
313 total_tokens: metadata
314 .total_token_count
315 .unwrap_or(prompt_tokens + completion_tokens),
316 completion_tokens_details: None,
317 prompt_tokens_details: None,
318 }),
319 _ => None,
320 }
321 })
322 }
323}
324
325#[derive(Deserialize, Debug)]
327struct GoogleResponsePart {
328 #[serde(default)]
330 text: String,
331 #[serde(rename = "functionCall")]
333 function_call: Option<GoogleFunctionCall>,
334}
335
336#[derive(Deserialize, Debug, Serialize)]
338enum GoogleResponseMimeType {
339 #[serde(rename = "text/plain")]
341 PlainText,
342 #[serde(rename = "application/json")]
344 Json,
345 #[serde(rename = "text/x.enum")]
347 Enum,
348}
349
350#[derive(Serialize, Debug)]
352struct GoogleTool {
353 #[serde(rename = "functionDeclarations")]
355 function_declarations: Vec<GoogleFunctionDeclaration>,
356}
357
358#[derive(Serialize, Debug)]
360struct GoogleFunctionDeclaration {
361 name: String,
363 description: String,
365 parameters: GoogleFunctionParameters,
367}
368
369impl From<&crate::chat::Tool> for GoogleFunctionDeclaration {
370 fn from(tool: &crate::chat::Tool) -> Self {
371 let properties_value = tool
372 .function
373 .parameters
374 .get("properties")
375 .cloned()
376 .unwrap_or_else(|| serde_json::Value::Object(serde_json::Map::new()));
377
378 GoogleFunctionDeclaration {
379 name: tool.function.name.clone(),
380 description: tool.function.description.clone(),
381 parameters: GoogleFunctionParameters {
382 schema_type: "object".to_string(),
383 properties: properties_value,
384 required: tool
385 .function
386 .parameters
387 .get("required")
388 .and_then(|v| v.as_array())
389 .map(|arr| {
390 arr.iter()
391 .filter_map(|v| v.as_str().map(|s| s.to_string()))
392 .collect::<Vec<String>>()
393 })
394 .unwrap_or_default(),
395 },
396 }
397 }
398}
399
400#[derive(Serialize, Debug)]
402struct GoogleFunctionParameters {
403 #[serde(rename = "type")]
405 schema_type: String,
406 properties: Value,
408 required: Vec<String>,
410}
411
412#[derive(Deserialize, Debug, Serialize)]
414struct GoogleFunctionCall {
415 name: String,
417 #[serde(default)]
419 args: Value,
420}
421
422#[derive(Deserialize, Debug, Serialize)]
441struct GoogleFunctionResponse {
442 name: String,
444 response: GoogleFunctionResponseContent,
446}
447
448#[derive(Deserialize, Debug, Serialize)]
449struct GoogleFunctionResponseContent {
450 name: String,
452 content: Value,
454}
455
456#[derive(Serialize)]
458struct GoogleEmbeddingRequest<'a> {
459 model: &'a str,
460 content: GoogleEmbeddingContent<'a>,
461}
462
463#[derive(Serialize)]
464struct GoogleEmbeddingContent<'a> {
465 parts: Vec<GoogleContentPart<'a>>,
466}
467
468#[derive(Deserialize)]
470struct GoogleEmbeddingResponse {
471 embedding: GoogleEmbedding,
472}
473
474#[derive(Deserialize)]
475struct GoogleEmbedding {
476 values: Vec<f32>,
477}
478
479impl Google {
480 #[allow(clippy::too_many_arguments)]
499 pub fn new(
500 api_key: impl Into<String>,
501 model: Option<String>,
502 max_tokens: Option<u32>,
503 temperature: Option<f32>,
504 timeout_seconds: Option<u64>,
505 system: Option<String>,
506 top_p: Option<f32>,
507 top_k: Option<u32>,
508 json_schema: Option<StructuredOutputFormat>,
509 tools: Option<Vec<Tool>>,
510 ) -> Self {
511 let mut builder = Client::builder();
512 if let Some(sec) = timeout_seconds {
513 builder = builder.timeout(std::time::Duration::from_secs(sec));
514 }
515 Self::with_client(
516 builder.build().expect("Failed to build reqwest Client"),
517 api_key,
518 model,
519 max_tokens,
520 temperature,
521 timeout_seconds,
522 system,
523 top_p,
524 top_k,
525 json_schema,
526 tools,
527 )
528 }
529
530 #[allow(clippy::too_many_arguments)]
532 pub fn with_client(
533 client: Client,
534 api_key: impl Into<String>,
535 model: Option<String>,
536 max_tokens: Option<u32>,
537 temperature: Option<f32>,
538 timeout_seconds: Option<u64>,
539 system: Option<String>,
540 top_p: Option<f32>,
541 top_k: Option<u32>,
542 json_schema: Option<StructuredOutputFormat>,
543 tools: Option<Vec<Tool>>,
544 ) -> Self {
545 Self {
546 config: Arc::new(GoogleConfig {
547 api_key: api_key.into(),
548 model: model.unwrap_or_else(|| "gemini-1.5-flash".to_string()),
549 max_tokens,
550 temperature,
551 system,
552 timeout_seconds,
553 top_p,
554 top_k,
555 json_schema,
556 tools,
557 }),
558 client,
559 }
560 }
561
562 pub fn api_key(&self) -> &str {
563 &self.config.api_key
564 }
565
566 pub fn model(&self) -> &str {
567 &self.config.model
568 }
569
570 pub fn max_tokens(&self) -> Option<u32> {
571 self.config.max_tokens
572 }
573
574 pub fn temperature(&self) -> Option<f32> {
575 self.config.temperature
576 }
577
578 pub fn timeout_seconds(&self) -> Option<u64> {
579 self.config.timeout_seconds
580 }
581
582 pub fn system(&self) -> Option<&str> {
583 self.config.system.as_deref()
584 }
585
586 pub fn top_p(&self) -> Option<f32> {
587 self.config.top_p
588 }
589
590 pub fn top_k(&self) -> Option<u32> {
591 self.config.top_k
592 }
593
594 pub fn json_schema(&self) -> Option<&StructuredOutputFormat> {
595 self.config.json_schema.as_ref()
596 }
597
598 pub fn tools(&self) -> Option<&[Tool]> {
599 self.config.tools.as_deref()
600 }
601
602 pub fn client(&self) -> &Client {
603 &self.client
604 }
605}
606
607#[async_trait]
608impl ChatProvider for Google {
609 async fn chat(&self, messages: &[ChatMessage]) -> Result<Box<dyn ChatResponse>, LLMError> {
619 if self.config.api_key.is_empty() {
620 return Err(LLMError::AuthError("Missing Google API key".to_string()));
621 }
622
623 let mut chat_contents = Vec::with_capacity(messages.len());
624
625 if let Some(system) = &self.config.system {
627 chat_contents.push(GoogleChatContent {
628 role: "user",
629 parts: vec![GoogleContentPart::Text(system)],
630 });
631 }
632
633 for msg in messages {
635 let role = match &msg.message_type {
637 MessageType::ToolResult(_) => "function",
638 _ => match msg.role {
639 ChatRole::User => "user",
640 ChatRole::Assistant => "model",
641 },
642 };
643
644 chat_contents.push(GoogleChatContent {
645 role,
646 parts: match &msg.message_type {
647 MessageType::Text => vec![GoogleContentPart::Text(&msg.content)],
648 MessageType::Image((image_mime, raw_bytes)) => {
649 vec![GoogleContentPart::InlineData(GoogleInlineData {
650 mime_type: image_mime.mime_type().to_string(),
651 data: BASE64.encode(raw_bytes),
652 })]
653 }
654 MessageType::ImageURL(_) => unimplemented!(),
655 MessageType::Pdf(raw_bytes) => {
656 vec![GoogleContentPart::InlineData(GoogleInlineData {
657 mime_type: "application/pdf".to_string(),
658 data: BASE64.encode(raw_bytes),
659 })]
660 }
661 MessageType::ToolUse(calls) => calls
662 .iter()
663 .map(|call| {
664 GoogleContentPart::FunctionCall(GoogleFunctionCall {
665 name: call.function.name.clone(),
666 args: serde_json::from_str(&call.function.arguments)
667 .unwrap_or(serde_json::Value::Null),
668 })
669 })
670 .collect(),
671 MessageType::ToolResult(result) => result
672 .iter()
673 .map(|result| {
674 let parsed_args =
675 serde_json::from_str::<Value>(&result.function.arguments)
676 .unwrap_or(serde_json::Value::Null);
677
678 GoogleContentPart::FunctionResponse(GoogleFunctionResponse {
679 name: result.function.name.clone(),
680 response: GoogleFunctionResponseContent {
681 name: result.function.name.clone(),
682 content: parsed_args,
683 },
684 })
685 })
686 .collect(),
687 },
688 });
689 }
690
691 let generation_config = if self.config.max_tokens.is_none()
693 && self.config.temperature.is_none()
694 && self.config.top_p.is_none()
695 && self.config.top_k.is_none()
696 && self.config.json_schema.is_none()
697 {
698 None
699 } else {
700 let (response_mime_type, response_schema) =
703 if let Some(json_schema) = &self.config.json_schema {
704 if let Some(schema) = &json_schema.schema {
705 let mut schema = schema.clone();
707 if let Some(obj) = schema.as_object_mut() {
708 obj.remove("additionalProperties");
709 }
710 (Some(GoogleResponseMimeType::Json), Some(schema))
711 } else {
712 (None, None)
713 }
714 } else {
715 (None, None)
716 };
717 Some(GoogleGenerationConfig {
718 max_output_tokens: self.config.max_tokens,
719 temperature: self.config.temperature,
720 top_p: self.config.top_p,
721 top_k: self.config.top_k,
722 response_mime_type,
723 response_schema,
724 })
725 };
726
727 let req_body = GoogleChatRequest {
728 contents: chat_contents,
729 generation_config,
730 tools: None,
731 };
732 if log::log_enabled!(log::Level::Trace) {
733 if let Ok(json) = serde_json::to_string(&req_body) {
734 log::trace!("Google Gemini request payload: {}", json);
735 }
736 }
737
738 let url = format!(
739 "https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent?key={key}",
740 model = self.config.model,
741 key = self.config.api_key
742 );
743
744 let mut request = self.client.post(&url).json(&req_body);
745 if let Some(timeout) = self.config.timeout_seconds {
746 request = request.timeout(std::time::Duration::from_secs(timeout));
747 }
748
749 let resp = request.send().await?;
750 log::debug!("Google Gemini HTTP status: {}", resp.status());
751 let resp = resp.error_for_status()?;
752
753 let resp_text = resp.text().await?;
755
756 let json_resp: Result<GoogleChatResponse, serde_json::Error> =
758 serde_json::from_str(&resp_text);
759
760 match json_resp {
761 Ok(response) => Ok(Box::new(response)),
762 Err(e) => {
763 Err(LLMError::ResponseFormatError {
765 message: format!("Failed to decode Google API response: {e}"),
766 raw_response: resp_text,
767 })
768 }
769 }
770 }
771
772 async fn chat_with_tools(
783 &self,
784 messages: &[ChatMessage],
785 tools: Option<&[Tool]>,
786 ) -> Result<Box<dyn ChatResponse>, LLMError> {
787 if self.config.api_key.is_empty() {
788 return Err(LLMError::AuthError("Missing Google API key".to_string()));
789 }
790
791 let mut chat_contents = Vec::with_capacity(messages.len());
792
793 if let Some(system) = &self.config.system {
795 chat_contents.push(GoogleChatContent {
796 role: "user",
797 parts: vec![GoogleContentPart::Text(system)],
798 });
799 }
800
801 for msg in messages {
803 let role = match &msg.message_type {
805 MessageType::ToolResult(_) => "function",
806 _ => match msg.role {
807 ChatRole::User => "user",
808 ChatRole::Assistant => "model",
809 },
810 };
811
812 chat_contents.push(GoogleChatContent {
813 role,
814 parts: match &msg.message_type {
815 MessageType::Text => vec![GoogleContentPart::Text(&msg.content)],
816 MessageType::Image((image_mime, raw_bytes)) => {
817 vec![GoogleContentPart::InlineData(GoogleInlineData {
818 mime_type: image_mime.mime_type().to_string(),
819 data: BASE64.encode(raw_bytes),
820 })]
821 }
822 MessageType::ImageURL(_) => unimplemented!(),
823 MessageType::Pdf(raw_bytes) => {
824 vec![GoogleContentPart::InlineData(GoogleInlineData {
825 mime_type: "application/pdf".to_string(),
826 data: BASE64.encode(raw_bytes),
827 })]
828 }
829 MessageType::ToolUse(calls) => calls
830 .iter()
831 .map(|call| {
832 GoogleContentPart::FunctionCall(GoogleFunctionCall {
833 name: call.function.name.clone(),
834 args: serde_json::from_str(&call.function.arguments)
835 .unwrap_or(serde_json::Value::Null),
836 })
837 })
838 .collect(),
839 MessageType::ToolResult(result) => result
840 .iter()
841 .map(|result| {
842 let parsed_args =
843 serde_json::from_str::<Value>(&result.function.arguments)
844 .unwrap_or(serde_json::Value::Null);
845
846 GoogleContentPart::FunctionResponse(GoogleFunctionResponse {
847 name: result.function.name.clone(),
848 response: GoogleFunctionResponseContent {
849 name: result.function.name.clone(),
850 content: parsed_args,
851 },
852 })
853 })
854 .collect(),
855 },
856 });
857 }
858
859 let google_tools = tools.map(|t| {
861 vec![GoogleTool {
862 function_declarations: t.iter().map(GoogleFunctionDeclaration::from).collect(),
863 }]
864 });
865
866 let generation_config = {
868 let (response_mime_type, response_schema) =
871 if let Some(json_schema) = &self.config.json_schema {
872 if let Some(schema) = &json_schema.schema {
873 let mut schema = schema.clone();
875
876 if let Some(obj) = schema.as_object_mut() {
877 obj.remove("additionalProperties");
878 }
879
880 (Some(GoogleResponseMimeType::Json), Some(schema))
881 } else {
882 (None, None)
883 }
884 } else {
885 (None, None)
886 };
887
888 Some(GoogleGenerationConfig {
889 max_output_tokens: self.config.max_tokens,
890 temperature: self.config.temperature,
891 top_p: self.config.top_p,
892 top_k: self.config.top_k,
893 response_mime_type,
894 response_schema,
895 })
896 };
897
898 let req_body = GoogleChatRequest {
899 contents: chat_contents,
900 generation_config,
901 tools: google_tools,
902 };
903
904 if log::log_enabled!(log::Level::Trace) {
905 if let Ok(json) = serde_json::to_string(&req_body) {
906 log::trace!("Google Gemini request payload (tool): {}", json);
907 }
908 }
909
910 let url = format!(
911 "https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent?key={key}",
912 model = self.config.model,
913 key = self.config.api_key
914
915 );
916
917 let mut request = self.client.post(&url).json(&req_body);
918
919 if let Some(timeout) = self.config.timeout_seconds {
920 request = request.timeout(std::time::Duration::from_secs(timeout));
921 }
922
923 let resp = request.send().await?;
924
925 log::debug!("Google Gemini HTTP status (tool): {}", resp.status());
926
927 let resp = resp.error_for_status()?;
928
929 let resp_text = resp.text().await?;
931
932 let json_resp: Result<GoogleChatResponse, serde_json::Error> =
934 serde_json::from_str(&resp_text);
935
936 match json_resp {
937 Ok(response) => Ok(Box::new(response)),
938 Err(e) => {
939 Err(LLMError::ResponseFormatError {
941 message: format!("Failed to decode Google API response: {e}"),
942 raw_response: resp_text,
943 })
944 }
945 }
946 }
947
948 async fn chat_stream(
958 &self,
959 messages: &[ChatMessage],
960 ) -> Result<std::pin::Pin<Box<dyn Stream<Item = Result<String, LLMError>> + Send>>, LLMError>
961 {
962 let struct_stream = self.chat_stream_struct(messages).await?;
963 let content_stream = struct_stream.filter_map(|result| async move {
964 match result {
965 Ok(stream_response) => {
966 if let Some(choice) = stream_response.choices.first() {
967 if let Some(content) = &choice.delta.content {
968 if !content.is_empty() {
969 return Some(Ok(content.clone()));
970 }
971 }
972 }
973 None
974 }
975 Err(e) => Some(Err(e)),
976 }
977 });
978 Ok(Box::pin(content_stream))
979 }
980
981 async fn chat_stream_struct(
991 &self,
992 messages: &[ChatMessage],
993 ) -> Result<
994 std::pin::Pin<Box<dyn Stream<Item = Result<crate::chat::StreamResponse, LLMError>> + Send>>,
995 LLMError,
996 > {
997 if self.config.api_key.is_empty() {
998 return Err(LLMError::AuthError("Missing Google API key".to_string()));
999 }
1000 let mut chat_contents = Vec::with_capacity(messages.len());
1001 if let Some(system) = &self.config.system {
1002 chat_contents.push(GoogleChatContent {
1003 role: "user",
1004 parts: vec![GoogleContentPart::Text(system)],
1005 });
1006 }
1007 for msg in messages {
1008 let role = match msg.role {
1009 ChatRole::User => "user",
1010 ChatRole::Assistant => "model",
1011 };
1012 chat_contents.push(GoogleChatContent {
1013 role,
1014 parts: match &msg.message_type {
1015 MessageType::Text => vec![GoogleContentPart::Text(&msg.content)],
1016 MessageType::Image((image_mime, raw_bytes)) => {
1017 vec![GoogleContentPart::InlineData(GoogleInlineData {
1018 mime_type: image_mime.mime_type().to_string(),
1019 data: BASE64.encode(raw_bytes),
1020 })]
1021 }
1022 MessageType::Pdf(raw_bytes) => {
1023 vec![GoogleContentPart::InlineData(GoogleInlineData {
1024 mime_type: "application/pdf".to_string(),
1025 data: BASE64.encode(raw_bytes),
1026 })]
1027 }
1028 _ => vec![GoogleContentPart::Text(&msg.content)],
1029 },
1030 });
1031 }
1032 let generation_config = if self.config.max_tokens.is_none()
1033 && self.config.temperature.is_none()
1034 && self.config.top_p.is_none()
1035 && self.config.top_k.is_none()
1036 {
1037 None
1038 } else {
1039 Some(GoogleGenerationConfig {
1040 max_output_tokens: self.config.max_tokens,
1041 temperature: self.config.temperature,
1042 top_p: self.config.top_p,
1043 top_k: self.config.top_k,
1044 response_mime_type: None,
1045 response_schema: None,
1046 })
1047 };
1048
1049 let req_body = GoogleChatRequest {
1050 contents: chat_contents,
1051 generation_config,
1052 tools: None,
1053 };
1054 let url = format!(
1055 "https://generativelanguage.googleapis.com/v1beta/models/{model}:streamGenerateContent?alt=sse&key={key}",
1056 model = self.config.model,
1057 key = self.config.api_key
1058 );
1059
1060 let mut request = self.client.post(&url).json(&req_body);
1061 if let Some(timeout) = self.config.timeout_seconds {
1062 request = request.timeout(std::time::Duration::from_secs(timeout));
1063 }
1064 let response = request.send().await?;
1065 if !response.status().is_success() {
1066 let status = response.status();
1067 let error_text = response.text().await?;
1068 return Err(LLMError::ResponseFormatError {
1069 message: format!("Google API returned error status: {status}"),
1070 raw_response: error_text,
1071 });
1072 }
1073 Ok(create_google_sse_stream(response))
1074 }
1075}
1076
1077#[async_trait]
1078impl CompletionProvider for Google {
1079 async fn complete(&self, req: &CompletionRequest) -> Result<CompletionResponse, LLMError> {
1089 let chat_message = ChatMessage::user().content(req.prompt.clone()).build();
1090 if let Some(text) = self.chat(&[chat_message]).await?.text() {
1091 Ok(CompletionResponse { text })
1092 } else {
1093 Err(LLMError::ProviderError(
1094 "No answer returned by Google".to_string(),
1095 ))
1096 }
1097 }
1098}
1099
1100#[async_trait]
1101impl EmbeddingProvider for Google {
1102 async fn embed(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
1103 if self.config.api_key.is_empty() {
1104 return Err(LLMError::AuthError("Missing Google API key".to_string()));
1105 }
1106
1107 let mut embeddings = Vec::new();
1108
1109 for text in texts {
1111 let req_body = GoogleEmbeddingRequest {
1112 model: "models/text-embedding-004",
1113 content: GoogleEmbeddingContent {
1114 parts: vec![GoogleContentPart::Text(&text)],
1115 },
1116 };
1117
1118 let url = format!(
1119 "https://generativelanguage.googleapis.com/v1beta/models/text-embedding-004:embedContent?key={}",
1120 self.config.api_key
1121 );
1122
1123 let resp = self
1124 .client
1125 .post(&url)
1126 .json(&req_body)
1127 .send()
1128 .await?
1129 .error_for_status()?;
1130
1131 let embedding_resp: GoogleEmbeddingResponse = resp.json().await?;
1132 embeddings.push(embedding_resp.embedding.values);
1133 }
1134 Ok(embeddings)
1135 }
1136}
1137
1138#[async_trait]
1139impl SpeechToTextProvider for Google {
1140 async fn transcribe(&self, _audio: Vec<u8>) -> Result<String, LLMError> {
1141 Err(LLMError::ProviderError(
1142 "Google does not implement speech to text endpoint yet.".into(),
1143 ))
1144 }
1145}
1146
1147impl LLMProvider for Google {
1148 fn tools(&self) -> Option<&[Tool]> {
1149 self.config.tools.as_deref()
1150 }
1151}
1152
1153fn create_google_sse_stream(
1163 response: reqwest::Response,
1164) -> std::pin::Pin<Box<dyn Stream<Item = Result<crate::chat::StreamResponse, LLMError>> + Send>> {
1165 let stream = response
1166 .bytes_stream()
1167 .map(move |chunk| match chunk {
1168 Ok(bytes) => {
1169 let text = String::from_utf8_lossy(&bytes);
1170 parse_google_sse_chunk(&text)
1171 }
1172 Err(e) => Err(LLMError::HttpError(e.to_string())),
1173 })
1174 .filter_map(|result| async move {
1175 match result {
1176 Ok(Some(response)) => Some(Ok(response)),
1177 Ok(None) => None,
1178 Err(e) => Some(Err(e)),
1179 }
1180 });
1181 Box::pin(stream)
1182}
1183
1184fn parse_google_sse_chunk(chunk: &str) -> Result<Option<crate::chat::StreamResponse>, LLMError> {
1196 for line in chunk.lines() {
1197 let line = line.trim();
1198 if let Some(data) = line.strip_prefix("data: ") {
1199 match serde_json::from_str::<GoogleStreamResponse>(data) {
1200 Ok(response) => {
1201 let mut content = None;
1202 let mut usage = None;
1203 if let Some(candidates) = &response.candidates {
1205 if let Some(candidate) = candidates.first() {
1206 if let Some(part) = candidate.content.parts.first() {
1207 if !part.text.is_empty() {
1208 content = Some(part.text.clone());
1209 }
1210 }
1211 }
1212 }
1213 if let Some(usage_metadata) = &response.usage_metadata {
1215 if let (Some(prompt_tokens), Some(completion_tokens)) = (
1216 usage_metadata.prompt_token_count,
1217 usage_metadata.candidates_token_count,
1218 ) {
1219 usage = Some(Usage {
1220 prompt_tokens,
1221 completion_tokens,
1222 total_tokens: usage_metadata
1223 .total_token_count
1224 .unwrap_or(prompt_tokens + completion_tokens),
1225 completion_tokens_details: None,
1226 prompt_tokens_details: None,
1227 });
1228 }
1229 }
1230 if content.is_some() || usage.is_some() {
1232 return Ok(Some(crate::chat::StreamResponse {
1233 choices: vec![crate::chat::StreamChoice {
1234 delta: crate::chat::StreamDelta {
1235 content,
1236 tool_calls: None,
1237 },
1238 }],
1239 usage,
1240 }));
1241 }
1242 return Ok(None);
1243 }
1244 Err(_) => continue,
1245 }
1246 }
1247 }
1248 Ok(None)
1249}
1250
1251#[async_trait]
1252impl TextToSpeechProvider for Google {}
1253
1254#[derive(Clone, Debug, Deserialize)]
1255pub struct GoogleModelEntry {
1256 pub name: String,
1257 pub version: String,
1258 pub display_name: String,
1259 pub description: String,
1260 pub input_token_limit: Option<u32>,
1261 pub output_token_limit: Option<u32>,
1262 pub supported_generation_methods: Vec<String>,
1263 pub temperature: Option<f32>,
1264 pub top_p: Option<f32>,
1265 pub top_k: Option<u32>,
1266 #[serde(flatten)]
1267 pub extra: Value,
1268}
1269
1270impl ModelListRawEntry for GoogleModelEntry {
1271 fn get_id(&self) -> String {
1272 self.name.clone()
1273 }
1274
1275 fn get_created_at(&self) -> DateTime<Utc> {
1276 DateTime::<Utc>::UNIX_EPOCH
1278 }
1279
1280 fn get_raw(&self) -> Value {
1281 self.extra.clone()
1282 }
1283}
1284
1285#[derive(Clone, Debug, Deserialize)]
1286pub struct GoogleModelListResponse {
1287 pub models: Vec<GoogleModelEntry>,
1288}
1289
1290impl ModelListResponse for GoogleModelListResponse {
1291 fn get_models(&self) -> Vec<String> {
1292 self.models.iter().map(|m| m.name.clone()).collect()
1293 }
1294
1295 fn get_models_raw(&self) -> Vec<Box<dyn ModelListRawEntry>> {
1296 self.models
1297 .iter()
1298 .map(|e| Box::new(e.clone()) as Box<dyn ModelListRawEntry>)
1299 .collect()
1300 }
1301
1302 fn get_backend(&self) -> LLMBackend {
1303 LLMBackend::Google
1304 }
1305}
1306
1307#[async_trait]
1308impl ModelsProvider for Google {
1309 async fn list_models(
1310 &self,
1311 _request: Option<&ModelListRequest>,
1312 ) -> Result<Box<dyn ModelListResponse>, LLMError> {
1313 if self.config.api_key.is_empty() {
1314 return Err(LLMError::AuthError("Missing Google API key".to_string()));
1315 }
1316
1317 let url = format!(
1318 "https://generativelanguage.googleapis.com/v1beta/models?key={}",
1319 self.config.api_key
1320 );
1321
1322 let resp = self.client.get(&url).send().await?.error_for_status()?;
1323
1324 let result: GoogleModelListResponse = resp.json().await?;
1325 Ok(Box::new(result))
1326 }
1327}