1use super::api::{
2 Content, FunctionCall, FunctionCallingConfig, FunctionCallingConfigMode, FunctionDeclaration,
3 FunctionResponse, GenerateContentConfig, GenerateContentParameters, GenerateContentResponse,
4 MediaModality, ModalityTokenCount, Part as GooglePart, PrebuiltVoiceConfig, SpeechConfig,
5 ThinkingConfig, Tool, ToolConfig, VoiceConfig,
6};
7use crate::{
8 audio_part_utils, client_utils, id_utils, source_part_utils, stream_utils, AudioPart,
9 ContentDelta, ImagePart, LanguageModel, LanguageModelError, LanguageModelInput,
10 LanguageModelMetadata, LanguageModelResult, LanguageModelStream, Message, ModelResponse,
11 ModelTokensDetails, ModelUsage, Part, PartialModelResponse, ReasoningPart,
12 ResponseFormatOption, ToolChoiceOption,
13};
14use async_stream::try_stream;
15use futures::{future::BoxFuture, StreamExt};
16use reqwest::{
17 header::{HeaderMap, HeaderName, HeaderValue},
18 Client,
19};
20use serde_json::json;
21use std::{collections::HashMap, sync::Arc};
22
/// Provider identifier reported in errors, telemetry, and `LanguageModel::provider`.
const PROVIDER: &str = "google";
24
/// Client for Google's Generative Language ("Gemini") API.
pub struct GoogleModel {
    // Model name (e.g. "gemini-2.0-flash"); interpolated into request URLs.
    model_id: String,
    // API key appended as the `key` query parameter on every request.
    api_key: String,
    // Base endpoint; any trailing '/' is stripped at construction time.
    base_url: String,
    client: Client,
    // Optional metadata (e.g. pricing) used to compute response cost.
    metadata: Option<Arc<LanguageModelMetadata>>,
    // Extra HTTP headers sent with every request.
    headers: HashMap<String, String>,
}
33
/// Construction options for [`GoogleModel`].
#[derive(Clone, Default)]
pub struct GoogleModelOptions {
    pub api_key: String,
    // Defaults to the public v1beta endpoint when `None`.
    pub base_url: Option<String>,
    // Extra headers added to every request when present.
    pub headers: Option<HashMap<String, String>>,
    // Defaults to a fresh `reqwest::Client` when `None`.
    pub client: Option<Client>,
}
41
42impl GoogleModel {
43 #[must_use]
44 pub fn new(model_id: impl Into<String>, options: GoogleModelOptions) -> Self {
45 let GoogleModelOptions {
46 api_key,
47 base_url,
48 headers,
49 client,
50 } = options;
51
52 let base_url = base_url
53 .unwrap_or_else(|| "https://generativelanguage.googleapis.com/v1beta".to_string())
54 .trim_end_matches('/')
55 .to_string();
56 let client = client.unwrap_or_else(Client::new);
57 let headers = headers.unwrap_or_default();
58
59 Self {
60 model_id: model_id.into(),
61 api_key,
62 base_url,
63 client,
64 metadata: None,
65 headers,
66 }
67 }
68
69 #[must_use]
70 pub fn with_metadata(mut self, metadata: LanguageModelMetadata) -> Self {
71 self.metadata = Some(Arc::new(metadata));
72 self
73 }
74
75 fn request_headers(&self) -> LanguageModelResult<HeaderMap> {
76 let mut headers = HeaderMap::new();
77
78 for (key, value) in &self.headers {
79 let header_name = HeaderName::from_bytes(key.as_bytes()).map_err(|error| {
80 LanguageModelError::InvalidInput(format!(
81 "Invalid Google header name '{key}': {error}"
82 ))
83 })?;
84 let header_value = HeaderValue::from_str(value).map_err(|error| {
85 LanguageModelError::InvalidInput(format!(
86 "Invalid Google header value for '{key}': {error}"
87 ))
88 })?;
89 headers.insert(header_name, header_value);
90 }
91
92 Ok(headers)
93 }
94}
95
96impl LanguageModel for GoogleModel {
97 fn provider(&self) -> &'static str {
98 PROVIDER
99 }
100
101 fn model_id(&self) -> String {
102 self.model_id.clone()
103 }
104
105 fn metadata(&self) -> Option<&LanguageModelMetadata> {
106 self.metadata.as_deref()
107 }
108
109 fn generate(
110 &self,
111 input: LanguageModelInput,
112 ) -> BoxFuture<'_, LanguageModelResult<ModelResponse>> {
113 Box::pin(async move {
114 crate::opentelemetry::trace_generate(
115 self.provider(),
116 &self.model_id(),
117 input,
118 |input| async move {
119 let params = convert_to_generate_content_parameters(input, &self.model_id)?;
120
121 let url = format!(
122 "{}/models/{}:generateContent?key={}",
123 self.base_url, self.model_id, self.api_key
124 );
125
126 let headers = self.request_headers()?;
127 let response: GenerateContentResponse =
128 client_utils::send_json(&self.client, &url, ¶ms, headers).await?;
129
130 let candidate = response
131 .candidates
132 .and_then(|c| c.into_iter().next())
133 .ok_or_else(|| {
134 LanguageModelError::Invariant(
135 PROVIDER,
136 "No candidate in response".to_string(),
137 )
138 })?;
139
140 let content = map_google_content(
141 candidate.content.and_then(|c| c.parts).unwrap_or_default(),
142 )?;
143
144 let usage = response
145 .usage_metadata
146 .map(|u| map_google_usage_metadata(&u));
147
148 let cost = if let (Some(usage), Some(pricing)) = (
149 usage.as_ref(),
150 self.metadata().and_then(|m| m.pricing.as_ref()),
151 ) {
152 Some(usage.calculate_cost(pricing))
153 } else {
154 None
155 };
156
157 Ok(ModelResponse {
158 content,
159 usage,
160 cost,
161 })
162 },
163 )
164 .await
165 })
166 }
167
168 fn stream(
169 &self,
170 input: LanguageModelInput,
171 ) -> BoxFuture<'_, LanguageModelResult<LanguageModelStream>> {
172 Box::pin(async move {
173 crate::opentelemetry::trace_stream(
174 self.provider(),
175 &self.model_id(),
176 input,
177 |input| async move {
178 let params = convert_to_generate_content_parameters(input, &self.model_id)?;
179 let metadata = self.metadata.clone();
180
181 let url = format!(
182 "{}/models/{}:streamGenerateContent?key={}&alt=sse",
183 self.base_url, self.model_id, self.api_key
184 );
185
186 let headers = self.request_headers()?;
187 let mut chunk_stream = client_utils::send_sse_stream::<
188 _,
189 GenerateContentResponse,
190 >(
191 &self.client, &url, ¶ms, headers, self.provider()
192 )
193 .await?;
194
195 let stream = try_stream! {
196 let mut all_content_deltas: Vec<ContentDelta> = Vec::new();
197
198 while let Some(chunk) = chunk_stream.next().await {
199 let response = chunk?;
200
201 let candidate = response
202 .candidates
203 .and_then(|c| c.into_iter().next());
204
205 if let Some(candidate) = candidate {
206 if let Some(content) = candidate.content {
207 if let Some(parts) = content.parts {
208 let incoming_deltas = map_google_content_to_delta(
209 parts,
210 &all_content_deltas,
211 )?;
212
213 all_content_deltas.extend(incoming_deltas.clone());
214
215 for delta in incoming_deltas {
216 yield PartialModelResponse {
217 delta: Some(delta),
218 usage: None,
219 cost: None,
220 };
221 }
222 }
223 }
224 }
225
226 if let Some(usage_metadata) = response.usage_metadata {
227 let usage = map_google_usage_metadata(&usage_metadata);
228 yield PartialModelResponse {
229 delta: None,
230 cost: metadata
231 .as_ref()
232 .and_then(|m| m.pricing.as_ref())
233 .map(|pricing| usage.calculate_cost(pricing)),
234 usage: Some(usage),
235 };
236 }
237 }
238 };
239
240 Ok(LanguageModelStream::from_stream(stream))
241 },
242 )
243 .await
244 })
245 }
246}
247
248fn convert_to_generate_content_parameters(
249 input: LanguageModelInput,
250 model_id: &str,
251) -> LanguageModelResult<GenerateContentParameters> {
252 let messages = convert_to_google_contents(input.messages)?;
253
254 let mut params = GenerateContentParameters {
255 contents: messages,
256 model: model_id.to_string(),
257 ..Default::default()
258 };
259 let mut config = GenerateContentConfig::default();
260
261 if let Some(system_prompt) = input.system_prompt {
262 params.system_instruction = Some(Content {
263 role: Some("system".to_string()),
264 parts: Some(vec![GooglePart {
265 text: Some(system_prompt),
266 ..Default::default()
267 }]),
268 });
269 }
270
271 if let Some(temp) = input.temperature {
272 config.temperature = Some(temp);
273 }
274 if let Some(top_p) = input.top_p {
275 config.top_p = Some(top_p);
276 }
277 if let Some(top_k) = input.top_k {
278 config.top_k = Some(top_k);
279 }
280 if let Some(presence_penalty) = input.presence_penalty {
281 config.presence_penalty = Some(presence_penalty);
282 }
283 if let Some(frequency_penalty) = input.frequency_penalty {
284 config.frequency_penalty = Some(frequency_penalty);
285 }
286 if let Some(seed) = input.seed {
287 config.seed = Some(seed);
288 }
289 if let Some(max_tokens) = input.max_tokens {
290 config.max_output_tokens = Some(max_tokens);
291 }
292
293 if let Some(tools) = input.tools {
294 let function_declarations = tools
295 .into_iter()
296 .map(|tool| FunctionDeclaration {
297 name: Some(tool.name),
298 description: Some(tool.description),
299 parameters_json_schema: Some(tool.parameters),
300 ..Default::default()
301 })
302 .collect();
303
304 params.tools = Some(vec![Tool {
305 function_declarations: Some(function_declarations),
306 }]);
307 }
308
309 if let Some(tool_choice) = input.tool_choice {
310 params.tool_config = Some(ToolConfig {
311 function_calling_config: Some(convert_to_google_function_calling_config(tool_choice)),
312 });
313 }
314
315 if let Some(response_format) = input.response_format {
316 let (response_mime_type, response_json_schema) =
317 convert_to_google_response_schema(response_format);
318 config.response_mime_type = Some(response_mime_type);
319 config.response_json_schema = response_json_schema;
320 }
321
322 if let Some(modalities) = input.modalities {
323 config.response_modalities = Some(
324 modalities
325 .into_iter()
326 .map(|m| match m {
327 crate::Modality::Text => "TEXT".to_string(),
328 crate::Modality::Image => "IMAGE".to_string(),
329 crate::Modality::Audio => "AUDIO".to_string(),
330 })
331 .collect(),
332 );
333 }
334
335 if let Some(audio) = input.audio {
336 if let Some(voice) = audio.voice {
337 config.speech_config = Some(SpeechConfig {
338 voice_config: Some(VoiceConfig {
339 prebuilt_voice_config: Some(PrebuiltVoiceConfig {
340 voice_name: Some(voice),
341 }),
342 }),
343 language_code: audio.language,
344 multi_speaker_voice_config: None,
345 });
346 }
347 }
348
349 if let Some(reasoning) = input.reasoning {
350 config.thinking_config = Some(ThinkingConfig {
351 include_thoughts: Some(reasoning.enabled),
352 thinking_budget: reasoning
353 .budget_tokens
354 .map(|t| i32::try_from(t).unwrap_or(0)),
355 });
356 }
357
358 params.generation_config = Some(config);
359
360 params.extra = input.extra;
361
362 Ok(params)
363}
364
365fn convert_to_google_contents(messages: Vec<Message>) -> LanguageModelResult<Vec<Content>> {
366 messages
367 .into_iter()
368 .map(|message| match message {
369 Message::User(user_message) => Ok(Content {
370 role: Some("user".to_string()),
371 parts: Some(
372 user_message
373 .content
374 .into_iter()
375 .flat_map(convert_to_google_parts)
376 .collect(),
377 ),
378 }),
379 Message::Assistant(assistant_message) => Ok(Content {
380 role: Some("model".to_string()),
381 parts: Some(
382 assistant_message
383 .content
384 .into_iter()
385 .flat_map(convert_to_google_parts)
386 .collect(),
387 ),
388 }),
389 Message::Tool(tool_message) => Ok(Content {
390 role: Some("user".to_string()),
391 parts: Some(
392 tool_message
393 .content
394 .into_iter()
395 .flat_map(convert_to_google_parts)
396 .collect(),
397 ),
398 }),
399 })
400 .collect()
401}
402
403fn convert_to_google_parts(part: Part) -> Vec<GooglePart> {
404 match part {
405 Part::Text(text_part) => vec![GooglePart {
406 text: Some(text_part.text),
407 ..Default::default()
408 }],
409 Part::Image(image_part) => vec![GooglePart {
410 inline_data: Some(super::api::Blob2 {
411 data: Some(image_part.data),
412 mime_type: Some(image_part.mime_type),
413 display_name: None,
414 }),
415 ..Default::default()
416 }],
417 Part::Audio(audio_part) => vec![GooglePart {
418 inline_data: Some(super::api::Blob2 {
419 data: Some(audio_part.data),
420 mime_type: Some(audio_part_utils::map_audio_format_to_mime_type(
421 &audio_part.format,
422 )),
423 display_name: None,
424 }),
425 ..Default::default()
426 }],
427 Part::Reasoning(reasoning_part) => vec![GooglePart {
428 text: Some(reasoning_part.text),
429 thought: Some(true),
430 thought_signature: reasoning_part.signature,
431 ..Default::default()
432 }],
433 Part::Source(source_part) => source_part
434 .content
435 .into_iter()
436 .flat_map(convert_to_google_parts)
437 .collect(),
438 Part::ToolCall(tool_call_part) => vec![GooglePart {
439 function_call: Some(FunctionCall {
440 name: Some(tool_call_part.tool_name),
441 args: Some(tool_call_part.args),
442 id: Some(tool_call_part.tool_call_id),
443 }),
444 ..Default::default()
445 }],
446 Part::ToolResult(tool_result_part) => vec![GooglePart {
447 function_response: Some(FunctionResponse {
448 id: Some(tool_result_part.tool_call_id),
449 name: Some(tool_result_part.tool_name),
450 response: Some(convert_to_google_function_response(
451 tool_result_part.content,
452 tool_result_part.is_error.unwrap_or(false),
453 )),
454 }),
455 ..Default::default()
456 }],
457 }
458}
459
460fn convert_to_google_function_response(
461 parts: Vec<Part>,
462 is_error: bool,
463) -> HashMap<String, serde_json::Value> {
464 let compatible_parts = source_part_utils::get_compatible_parts_without_source_parts(parts);
465 let text_parts: Vec<String> = compatible_parts
466 .into_iter()
467 .filter_map(|part| {
468 if let Part::Text(text_part) = part {
469 Some(text_part.text)
470 } else {
471 None
472 }
473 })
474 .collect();
475
476 let responses: Vec<serde_json::Value> = text_parts
477 .into_iter()
478 .map(|text| serde_json::from_str(&text).unwrap_or_else(|_| json!({ "data": text })))
479 .collect();
480
481 let mut result = HashMap::new();
484 let key = if is_error { "error" } else { "output" };
485 let value = if responses.len() == 1 {
486 responses.into_iter().next().unwrap_or(json!({}))
487 } else {
488 json!(responses)
489 };
490 result.insert(key.to_string(), value);
491 result
492}
493
494fn convert_to_google_function_calling_config(
495 tool_choice: ToolChoiceOption,
496) -> FunctionCallingConfig {
497 match tool_choice {
498 ToolChoiceOption::Auto => FunctionCallingConfig {
499 mode: Some(FunctionCallingConfigMode::Auto),
500 allowed_function_names: None,
501 },
502 ToolChoiceOption::None => FunctionCallingConfig {
503 mode: Some(FunctionCallingConfigMode::None),
504 allowed_function_names: None,
505 },
506 ToolChoiceOption::Required => FunctionCallingConfig {
507 mode: Some(FunctionCallingConfigMode::Any),
508 allowed_function_names: None,
509 },
510 ToolChoiceOption::Tool(tool) => FunctionCallingConfig {
511 mode: Some(FunctionCallingConfigMode::Any),
512 allowed_function_names: Some(vec![tool.tool_name]),
513 },
514 }
515}
516
517fn convert_to_google_response_schema(
518 response_format: ResponseFormatOption,
519) -> (String, Option<serde_json::Value>) {
520 match response_format {
521 ResponseFormatOption::Text => ("text/plain".to_string(), None),
522 ResponseFormatOption::Json(json_format) => {
523 ("application/json".to_string(), json_format.schema)
524 }
525 }
526}
527
528fn map_google_content(parts: Vec<GooglePart>) -> LanguageModelResult<Vec<Part>> {
529 parts
530 .into_iter()
531 .filter_map(|part| {
532 if let Some(text) = part.text {
533 if part.thought.unwrap_or(false) {
534 let mut reasoning_part = ReasoningPart::new(text);
535 if let Some(signature) = part.thought_signature {
536 reasoning_part = reasoning_part.with_signature(signature);
537 }
538 Some(Ok(reasoning_part.into()))
539 } else {
540 Some(Ok(Part::text(text)))
541 }
542 } else if let Some(inline_data) = part.inline_data {
543 if let (Some(data), Some(mime_type)) = (inline_data.data, inline_data.mime_type) {
544 if mime_type.starts_with("image/") {
545 Some(Ok(Part::Image(ImagePart {
546 data,
547 mime_type,
548 width: None,
549 height: None,
550 id: None,
551 })))
552 } else if mime_type.starts_with("audio/") {
553 if let Ok(format) =
554 audio_part_utils::map_mime_type_to_audio_format(&mime_type)
555 {
556 Some(Ok(Part::Audio(AudioPart {
557 data,
558 format,
559 sample_rate: None,
560 channels: None,
561 id: None,
562 transcript: None,
563 })))
564 } else {
565 Some(Err(LanguageModelError::Invariant(
566 PROVIDER,
567 format!("Unsupported audio mime type: {mime_type}"),
568 )))
569 }
570 } else {
571 None
572 }
573 } else {
574 Some(Err(LanguageModelError::Invariant(
575 PROVIDER,
576 "Inline data missing data or mime type".to_string(),
577 )))
578 }
579 } else if let Some(function_call) = part.function_call {
580 if let Some(name) = function_call.name {
581 Some(Ok(Part::ToolCall(crate::ToolCallPart {
582 tool_call_id: function_call
583 .id
584 .unwrap_or_else(|| id_utils::generate_string(10)),
586 tool_name: name,
587 args: json!(function_call.args.unwrap_or_default()),
588 id: None,
589 })))
590 } else {
591 Some(Err(LanguageModelError::Invariant(
592 PROVIDER,
593 "Function call missing name".to_string(),
594 )))
595 }
596 } else {
597 None
598 }
599 })
600 .collect()
601}
602
603fn map_google_content_to_delta(
604 parts: Vec<GooglePart>,
605 existing_deltas: &[ContentDelta],
606) -> LanguageModelResult<Vec<ContentDelta>> {
607 let mut deltas = Vec::new();
608
609 let parts = map_google_content(parts)?;
610
611 for part in parts {
612 let all_content_deltas = existing_deltas
613 .iter()
614 .chain(deltas.iter())
615 .collect::<Vec<_>>();
616 let part_delta = stream_utils::loosely_convert_part_to_part_delta(part)?;
617 let guessed_index = stream_utils::guess_delta_index(&part_delta, &all_content_deltas, None);
618 deltas.push(ContentDelta {
619 index: guessed_index,
620 part: part_delta,
621 });
622 }
623
624 Ok(deltas)
625}
626
627fn map_google_usage_metadata(
628 usage: &super::api::GenerateContentResponseUsageMetadata,
629) -> ModelUsage {
630 let input_tokens = usage.prompt_token_count.unwrap_or(0);
631 let output_tokens = usage.candidates_token_count.unwrap_or(0);
632
633 let input_tokens_details = map_modality_token_counts(
634 usage.prompt_tokens_details.as_ref(),
635 usage.cache_tokens_details.as_ref(),
636 );
637
638 let output_tokens_details =
639 map_modality_token_counts(usage.candidates_tokens_details.as_ref(), None);
640
641 ModelUsage {
642 input_tokens,
643 output_tokens,
644 input_tokens_details,
645 output_tokens_details,
646 }
647}
648
649fn map_modality_token_counts(
650 details: Option<&Vec<ModalityTokenCount>>,
651 cached_details: Option<&Vec<ModalityTokenCount>>,
652) -> Option<ModelTokensDetails> {
653 if details.is_none() && cached_details.is_none() {
654 return None;
655 }
656
657 let mut tokens_details = ModelTokensDetails {
658 text_tokens: None,
659 cached_text_tokens: None,
660 audio_tokens: None,
661 cached_audio_tokens: None,
662 image_tokens: None,
663 cached_image_tokens: None,
664 };
665
666 if let Some(details) = details {
667 for detail in details {
668 if let (Some(modality), Some(count)) = (&detail.modality, detail.token_count) {
669 match modality {
670 MediaModality::Text => {
671 *tokens_details.text_tokens.get_or_insert_default() += count;
672 }
673 MediaModality::Audio => {
674 *tokens_details.audio_tokens.get_or_insert_default() += count;
675 }
676 MediaModality::Image => {
677 *tokens_details.image_tokens.get_or_insert_default() += count;
678 }
679 _ => {}
680 }
681 }
682 }
683 }
684
685 if let Some(cached) = cached_details {
686 for detail in cached {
687 if let (Some(modality), Some(count)) = (&detail.modality, detail.token_count) {
688 match modality {
689 MediaModality::Text => {
690 *tokens_details.cached_text_tokens.get_or_insert_default() += count;
691 }
692 MediaModality::Audio => {
693 *tokens_details.cached_audio_tokens.get_or_insert_default() += count;
694 }
695 MediaModality::Image => {
696 *tokens_details.cached_image_tokens.get_or_insert_default() += count;
697 }
698 _ => {}
699 }
700 }
701 }
702 }
703
704 Some(tokens_details)
705}