llm_sdk/google/
model.rs

1use super::api::{
2    Content, FunctionCall, FunctionCallingConfig, FunctionCallingConfigMode, FunctionDeclaration,
3    FunctionResponse, GenerateContentConfig, GenerateContentParameters, GenerateContentResponse,
4    MediaModality, ModalityTokenCount, Part as GooglePart, PrebuiltVoiceConfig, SpeechConfig,
5    ThinkingConfig, Tool, ToolConfig, VoiceConfig,
6};
7use crate::{
8    audio_part_utils, client_utils, id_utils, source_part_utils, stream_utils, AudioPart,
9    ContentDelta, ImagePart, LanguageModel, LanguageModelError, LanguageModelInput,
10    LanguageModelMetadata, LanguageModelResult, LanguageModelStream, Message, ModelResponse,
11    ModelTokensDetails, ModelUsage, Part, PartialModelResponse, ReasoningPart,
12    ResponseFormatOption, ToolChoiceOption,
13};
14use async_stream::try_stream;
15use futures::{future::BoxFuture, StreamExt};
16use reqwest::{
17    header::{HeaderMap, HeaderName, HeaderValue},
18    Client,
19};
20use serde_json::json;
21use std::{collections::HashMap, sync::Arc};
22
/// Provider tag reported in error values and telemetry for this backend.
const PROVIDER: &str = "google";
24
/// A `LanguageModel` implementation backed by the Google Generative Language API.
pub struct GoogleModel {
    // Model name used in the request URL path (`/models/{model_id}:...`).
    model_id: String,
    // API key appended as the `key` query parameter on every request.
    api_key: String,
    // Base endpoint with any trailing '/' stripped (see `GoogleModel::new`).
    base_url: String,
    // HTTP client shared by both unary and streaming calls.
    client: Client,
    // Optional metadata (e.g. pricing) used to compute response cost.
    metadata: Option<Arc<LanguageModelMetadata>>,
    // Extra headers added to every request; validated in `request_headers`.
    headers: HashMap<String, String>,
}
33
/// Construction options for [`GoogleModel`].
#[derive(Clone, Default)]
pub struct GoogleModelOptions {
    /// API key sent as the `key` query parameter.
    pub api_key: String,
    /// Override for the API base URL; defaults to the public
    /// `generativelanguage.googleapis.com/v1beta` endpoint when `None`.
    pub base_url: Option<String>,
    /// Additional HTTP headers to send with every request.
    pub headers: Option<HashMap<String, String>>,
    /// Custom `reqwest` client; a fresh default client is created when `None`.
    pub client: Option<Client>,
}
41
42impl GoogleModel {
43    #[must_use]
44    pub fn new(model_id: impl Into<String>, options: GoogleModelOptions) -> Self {
45        let GoogleModelOptions {
46            api_key,
47            base_url,
48            headers,
49            client,
50        } = options;
51
52        let base_url = base_url
53            .unwrap_or_else(|| "https://generativelanguage.googleapis.com/v1beta".to_string())
54            .trim_end_matches('/')
55            .to_string();
56        let client = client.unwrap_or_else(Client::new);
57        let headers = headers.unwrap_or_default();
58
59        Self {
60            model_id: model_id.into(),
61            api_key,
62            base_url,
63            client,
64            metadata: None,
65            headers,
66        }
67    }
68
69    #[must_use]
70    pub fn with_metadata(mut self, metadata: LanguageModelMetadata) -> Self {
71        self.metadata = Some(Arc::new(metadata));
72        self
73    }
74
75    fn request_headers(&self) -> LanguageModelResult<HeaderMap> {
76        let mut headers = HeaderMap::new();
77
78        for (key, value) in &self.headers {
79            let header_name = HeaderName::from_bytes(key.as_bytes()).map_err(|error| {
80                LanguageModelError::InvalidInput(format!(
81                    "Invalid Google header name '{key}': {error}"
82                ))
83            })?;
84            let header_value = HeaderValue::from_str(value).map_err(|error| {
85                LanguageModelError::InvalidInput(format!(
86                    "Invalid Google header value for '{key}': {error}"
87                ))
88            })?;
89            headers.insert(header_name, header_value);
90        }
91
92        Ok(headers)
93    }
94}
95
impl LanguageModel for GoogleModel {
    fn provider(&self) -> &'static str {
        PROVIDER
    }

    fn model_id(&self) -> String {
        self.model_id.clone()
    }

    fn metadata(&self) -> Option<&LanguageModelMetadata> {
        self.metadata.as_deref()
    }

    /// Performs a unary `generateContent` call and maps the first candidate
    /// into a `ModelResponse`.
    fn generate(
        &self,
        input: LanguageModelInput,
    ) -> BoxFuture<'_, LanguageModelResult<ModelResponse>> {
        Box::pin(async move {
            // The closure does the actual work; trace_generate wraps it with
            // OpenTelemetry instrumentation.
            crate::opentelemetry::trace_generate(
                self.provider(),
                &self.model_id(),
                input,
                |input| async move {
                    let params = convert_to_generate_content_parameters(input, &self.model_id)?;

                    // The API key travels as a `key` query parameter, not a header.
                    let url = format!(
                        "{}/models/{}:generateContent?key={}",
                        self.base_url, self.model_id, self.api_key
                    );

                    let headers = self.request_headers()?;
                    let response: GenerateContentResponse =
                        client_utils::send_json(&self.client, &url, &params, headers).await?;

                    // Only the first candidate is used; an empty candidate list
                    // is treated as a provider invariant violation.
                    let candidate = response
                        .candidates
                        .and_then(|c| c.into_iter().next())
                        .ok_or_else(|| {
                            LanguageModelError::Invariant(
                                PROVIDER,
                                "No candidate in response".to_string(),
                            )
                        })?;

                    // A candidate with no content/parts maps to empty content.
                    let content = map_google_content(
                        candidate.content.and_then(|c| c.parts).unwrap_or_default(),
                    )?;

                    let usage = response
                        .usage_metadata
                        .map(|u| map_google_usage_metadata(&u));

                    // Cost can only be computed when both usage and pricing
                    // metadata are available.
                    let cost = if let (Some(usage), Some(pricing)) = (
                        usage.as_ref(),
                        self.metadata().and_then(|m| m.pricing.as_ref()),
                    ) {
                        Some(usage.calculate_cost(pricing))
                    } else {
                        None
                    };

                    Ok(ModelResponse {
                        content,
                        usage,
                        cost,
                    })
                },
            )
            .await
        })
    }

    /// Performs a streaming `streamGenerateContent` call over SSE, yielding
    /// `PartialModelResponse` items as chunks arrive.
    fn stream(
        &self,
        input: LanguageModelInput,
    ) -> BoxFuture<'_, LanguageModelResult<LanguageModelStream>> {
        Box::pin(async move {
            crate::opentelemetry::trace_stream(
                self.provider(),
                &self.model_id(),
                input,
                |input| async move {
                    let params = convert_to_generate_content_parameters(input, &self.model_id)?;
                    // Cloned so the stream below does not borrow `self`.
                    let metadata = self.metadata.clone();

                    // `alt=sse` asks the API to respond as server-sent events.
                    let url = format!(
                        "{}/models/{}:streamGenerateContent?key={}&alt=sse",
                        self.base_url, self.model_id, self.api_key
                    );

                    let headers = self.request_headers()?;
                    let mut chunk_stream = client_utils::send_sse_stream::<
                        _,
                        GenerateContentResponse,
                    >(
                        &self.client, &url, &params, headers, self.provider()
                    )
                    .await?;

                    let stream = try_stream! {
                        // Running history of all deltas emitted so far; used to
                        // assign a stable index to each new part (see
                        // `map_google_content_to_delta`).
                        let mut all_content_deltas: Vec<ContentDelta> = Vec::new();

                        while let Some(chunk) = chunk_stream.next().await {
                            let response = chunk?;

                            // As in `generate`, only the first candidate of
                            // each chunk is considered.
                            let candidate = response
                                .candidates
                                .and_then(|c| c.into_iter().next());

                            if let Some(candidate) = candidate {
                                if let Some(content) = candidate.content {
                                    if let Some(parts) = content.parts {
                                        let incoming_deltas = map_google_content_to_delta(
                                            parts,
                                            &all_content_deltas,
                                        )?;

                                        all_content_deltas.extend(incoming_deltas.clone());

                                        for delta in incoming_deltas {
                                            yield PartialModelResponse {
                                                delta: Some(delta),
                                                usage: None,
                                                cost: None,
                                            };
                                        }
                                    }
                                }
                            }

                            // Usage metadata (typically on the final chunk) is
                            // emitted as its own partial response, with cost
                            // when pricing metadata is configured.
                            if let Some(usage_metadata) = response.usage_metadata {
                                let usage = map_google_usage_metadata(&usage_metadata);
                                yield PartialModelResponse {
                                    delta: None,
                                    cost: metadata
                                        .as_ref()
                                        .and_then(|m| m.pricing.as_ref())
                                        .map(|pricing| usage.calculate_cost(pricing)),
                                    usage: Some(usage),
                                };
                            }
                        }
                    };

                    Ok(LanguageModelStream::from_stream(stream))
                },
            )
            .await
        })
    }
}
247
/// Translates a provider-agnostic `LanguageModelInput` into Google
/// `GenerateContentParameters` for the given `model_id`.
///
/// Absent optional inputs leave the corresponding config fields untouched;
/// `input.extra` is passed through verbatim at the end.
fn convert_to_generate_content_parameters(
    input: LanguageModelInput,
    model_id: &str,
) -> LanguageModelResult<GenerateContentParameters> {
    let messages = convert_to_google_contents(input.messages)?;

    let mut params = GenerateContentParameters {
        contents: messages,
        model: model_id.to_string(),
        ..Default::default()
    };
    let mut config = GenerateContentConfig::default();

    // The system prompt becomes a dedicated `system_instruction` content,
    // not a regular conversation message.
    if let Some(system_prompt) = input.system_prompt {
        params.system_instruction = Some(Content {
            role: Some("system".to_string()),
            parts: Some(vec![GooglePart {
                text: Some(system_prompt),
                ..Default::default()
            }]),
        });
    }

    // Sampling / generation knobs: copied over only when explicitly set.
    if let Some(temp) = input.temperature {
        config.temperature = Some(temp);
    }
    if let Some(top_p) = input.top_p {
        config.top_p = Some(top_p);
    }
    if let Some(top_k) = input.top_k {
        config.top_k = Some(top_k);
    }
    if let Some(presence_penalty) = input.presence_penalty {
        config.presence_penalty = Some(presence_penalty);
    }
    if let Some(frequency_penalty) = input.frequency_penalty {
        config.frequency_penalty = Some(frequency_penalty);
    }
    if let Some(seed) = input.seed {
        config.seed = Some(seed);
    }
    if let Some(max_tokens) = input.max_tokens {
        config.max_output_tokens = Some(max_tokens);
    }

    // All tools are exposed as function declarations within a single Tool entry.
    if let Some(tools) = input.tools {
        let function_declarations = tools
            .into_iter()
            .map(|tool| FunctionDeclaration {
                name: Some(tool.name),
                description: Some(tool.description),
                parameters_json_schema: Some(tool.parameters),
                ..Default::default()
            })
            .collect();

        params.tools = Some(vec![Tool {
            function_declarations: Some(function_declarations),
        }]);
    }

    if let Some(tool_choice) = input.tool_choice {
        params.tool_config = Some(ToolConfig {
            function_calling_config: Some(convert_to_google_function_calling_config(tool_choice)),
        });
    }

    // Structured output: mime type plus optional JSON schema.
    if let Some(response_format) = input.response_format {
        let (response_mime_type, response_json_schema) =
            convert_to_google_response_schema(response_format);
        config.response_mime_type = Some(response_mime_type);
        config.response_json_schema = response_json_schema;
    }

    // Google expects modality names in upper case.
    if let Some(modalities) = input.modalities {
        config.response_modalities = Some(
            modalities
                .into_iter()
                .map(|m| match m {
                    crate::Modality::Text => "TEXT".to_string(),
                    crate::Modality::Image => "IMAGE".to_string(),
                    crate::Modality::Audio => "AUDIO".to_string(),
                })
                .collect(),
        );
    }

    // Speech config is only set when a specific voice was requested; audio
    // options without a voice are otherwise ignored here.
    if let Some(audio) = input.audio {
        if let Some(voice) = audio.voice {
            config.speech_config = Some(SpeechConfig {
                voice_config: Some(VoiceConfig {
                    prebuilt_voice_config: Some(PrebuiltVoiceConfig {
                        voice_name: Some(voice),
                    }),
                }),
                language_code: audio.language,
                multi_speaker_voice_config: None,
            });
        }
    }

    if let Some(reasoning) = input.reasoning {
        config.thinking_config = Some(ThinkingConfig {
            include_thoughts: Some(reasoning.enabled),
            // NOTE(review): a budget that does not fit in i32 silently becomes
            // 0 here — confirm 0 is the intended fallback rather than i32::MAX.
            thinking_budget: reasoning
                .budget_tokens
                .map(|t| i32::try_from(t).unwrap_or(0)),
        });
    }

    params.generation_config = Some(config);

    // Provider-specific extras pass through untouched.
    params.extra = input.extra;

    Ok(params)
}
364
365fn convert_to_google_contents(messages: Vec<Message>) -> LanguageModelResult<Vec<Content>> {
366    messages
367        .into_iter()
368        .map(|message| match message {
369            Message::User(user_message) => Ok(Content {
370                role: Some("user".to_string()),
371                parts: Some(
372                    user_message
373                        .content
374                        .into_iter()
375                        .flat_map(convert_to_google_parts)
376                        .collect(),
377                ),
378            }),
379            Message::Assistant(assistant_message) => Ok(Content {
380                role: Some("model".to_string()),
381                parts: Some(
382                    assistant_message
383                        .content
384                        .into_iter()
385                        .flat_map(convert_to_google_parts)
386                        .collect(),
387                ),
388            }),
389            Message::Tool(tool_message) => Ok(Content {
390                role: Some("user".to_string()),
391                parts: Some(
392                    tool_message
393                        .content
394                        .into_iter()
395                        .flat_map(convert_to_google_parts)
396                        .collect(),
397                ),
398            }),
399        })
400        .collect()
401}
402
403fn convert_to_google_parts(part: Part) -> Vec<GooglePart> {
404    match part {
405        Part::Text(text_part) => vec![GooglePart {
406            text: Some(text_part.text),
407            ..Default::default()
408        }],
409        Part::Image(image_part) => vec![GooglePart {
410            inline_data: Some(super::api::Blob2 {
411                data: Some(image_part.data),
412                mime_type: Some(image_part.mime_type),
413                display_name: None,
414            }),
415            ..Default::default()
416        }],
417        Part::Audio(audio_part) => vec![GooglePart {
418            inline_data: Some(super::api::Blob2 {
419                data: Some(audio_part.data),
420                mime_type: Some(audio_part_utils::map_audio_format_to_mime_type(
421                    &audio_part.format,
422                )),
423                display_name: None,
424            }),
425            ..Default::default()
426        }],
427        Part::Reasoning(reasoning_part) => vec![GooglePart {
428            text: Some(reasoning_part.text),
429            thought: Some(true),
430            thought_signature: reasoning_part.signature,
431            ..Default::default()
432        }],
433        Part::Source(source_part) => source_part
434            .content
435            .into_iter()
436            .flat_map(convert_to_google_parts)
437            .collect(),
438        Part::ToolCall(tool_call_part) => vec![GooglePart {
439            function_call: Some(FunctionCall {
440                name: Some(tool_call_part.tool_name),
441                args: Some(tool_call_part.args),
442                id: Some(tool_call_part.tool_call_id),
443            }),
444            ..Default::default()
445        }],
446        Part::ToolResult(tool_result_part) => vec![GooglePart {
447            function_response: Some(FunctionResponse {
448                id: Some(tool_result_part.tool_call_id),
449                name: Some(tool_result_part.tool_name),
450                response: Some(convert_to_google_function_response(
451                    tool_result_part.content,
452                    tool_result_part.is_error.unwrap_or(false),
453                )),
454            }),
455            ..Default::default()
456        }],
457    }
458}
459
460fn convert_to_google_function_response(
461    parts: Vec<Part>,
462    is_error: bool,
463) -> HashMap<String, serde_json::Value> {
464    let compatible_parts = source_part_utils::get_compatible_parts_without_source_parts(parts);
465    let text_parts: Vec<String> = compatible_parts
466        .into_iter()
467        .filter_map(|part| {
468            if let Part::Text(text_part) = part {
469                Some(text_part.text)
470            } else {
471                None
472            }
473        })
474        .collect();
475
476    let responses: Vec<serde_json::Value> = text_parts
477        .into_iter()
478        .map(|text| serde_json::from_str(&text).unwrap_or_else(|_| json!({ "data": text })))
479        .collect();
480
481    // Use "output" key to specify function output and "error" key to specify error
482    // details, as per Google API specification
483    let mut result = HashMap::new();
484    let key = if is_error { "error" } else { "output" };
485    let value = if responses.len() == 1 {
486        responses.into_iter().next().unwrap_or(json!({}))
487    } else {
488        json!(responses)
489    };
490    result.insert(key.to_string(), value);
491    result
492}
493
494fn convert_to_google_function_calling_config(
495    tool_choice: ToolChoiceOption,
496) -> FunctionCallingConfig {
497    match tool_choice {
498        ToolChoiceOption::Auto => FunctionCallingConfig {
499            mode: Some(FunctionCallingConfigMode::Auto),
500            allowed_function_names: None,
501        },
502        ToolChoiceOption::None => FunctionCallingConfig {
503            mode: Some(FunctionCallingConfigMode::None),
504            allowed_function_names: None,
505        },
506        ToolChoiceOption::Required => FunctionCallingConfig {
507            mode: Some(FunctionCallingConfigMode::Any),
508            allowed_function_names: None,
509        },
510        ToolChoiceOption::Tool(tool) => FunctionCallingConfig {
511            mode: Some(FunctionCallingConfigMode::Any),
512            allowed_function_names: Some(vec![tool.tool_name]),
513        },
514    }
515}
516
517fn convert_to_google_response_schema(
518    response_format: ResponseFormatOption,
519) -> (String, Option<serde_json::Value>) {
520    match response_format {
521        ResponseFormatOption::Text => ("text/plain".to_string(), None),
522        ResponseFormatOption::Json(json_format) => {
523            ("application/json".to_string(), json_format.schema)
524        }
525    }
526}
527
/// Maps Google response parts back into SDK `Part`s.
///
/// Precedence per part: text (plain or thought) > inline data > function
/// call; parts matching none of these are silently dropped. Malformed parts
/// (missing blob fields, unsupported audio mime, nameless function call)
/// produce `Invariant` errors.
fn map_google_content(parts: Vec<GooglePart>) -> LanguageModelResult<Vec<Part>> {
    parts
        .into_iter()
        .filter_map(|part| {
            if let Some(text) = part.text {
                // Text flagged as a thought becomes a reasoning part, keeping
                // the signature when present.
                if part.thought.unwrap_or(false) {
                    let mut reasoning_part = ReasoningPart::new(text);
                    if let Some(signature) = part.thought_signature {
                        reasoning_part = reasoning_part.with_signature(signature);
                    }
                    Some(Ok(reasoning_part.into()))
                } else {
                    Some(Ok(Part::text(text)))
                }
            } else if let Some(inline_data) = part.inline_data {
                if let (Some(data), Some(mime_type)) = (inline_data.data, inline_data.mime_type) {
                    // Dispatch on mime type prefix: images and audio are
                    // supported; other inline data is dropped below.
                    if mime_type.starts_with("image/") {
                        Some(Ok(Part::Image(ImagePart {
                            data,
                            mime_type,
                            width: None,
                            height: None,
                            id: None,
                        })))
                    } else if mime_type.starts_with("audio/") {
                        if let Ok(format) =
                            audio_part_utils::map_mime_type_to_audio_format(&mime_type)
                        {
                            Some(Ok(Part::Audio(AudioPart {
                                data,
                                format,
                                sample_rate: None,
                                channels: None,
                                id: None,
                                transcript: None,
                            })))
                        } else {
                            Some(Err(LanguageModelError::Invariant(
                                PROVIDER,
                                format!("Unsupported audio mime type: {mime_type}"),
                            )))
                        }
                    } else {
                        // Unrecognized inline mime types are ignored.
                        None
                    }
                } else {
                    Some(Err(LanguageModelError::Invariant(
                        PROVIDER,
                        "Inline data missing data or mime type".to_string(),
                    )))
                }
            } else if let Some(function_call) = part.function_call {
                if let Some(name) = function_call.name {
                    Some(Ok(Part::ToolCall(crate::ToolCallPart {
                        tool_call_id: function_call
                            .id
                            // Google does not always return id, generate one if missing
                            .unwrap_or_else(|| id_utils::generate_string(10)),
                        tool_name: name,
                        args: json!(function_call.args.unwrap_or_default()),
                        id: None,
                    })))
                } else {
                    Some(Err(LanguageModelError::Invariant(
                        PROVIDER,
                        "Function call missing name".to_string(),
                    )))
                }
            } else {
                // Part carried nothing we recognize; drop it.
                None
            }
        })
        .collect()
}
602
603fn map_google_content_to_delta(
604    parts: Vec<GooglePart>,
605    existing_deltas: &[ContentDelta],
606) -> LanguageModelResult<Vec<ContentDelta>> {
607    let mut deltas = Vec::new();
608
609    let parts = map_google_content(parts)?;
610
611    for part in parts {
612        let all_content_deltas = existing_deltas
613            .iter()
614            .chain(deltas.iter())
615            .collect::<Vec<_>>();
616        let part_delta = stream_utils::loosely_convert_part_to_part_delta(part)?;
617        let guessed_index = stream_utils::guess_delta_index(&part_delta, &all_content_deltas, None);
618        deltas.push(ContentDelta {
619            index: guessed_index,
620            part: part_delta,
621        });
622    }
623
624    Ok(deltas)
625}
626
627fn map_google_usage_metadata(
628    usage: &super::api::GenerateContentResponseUsageMetadata,
629) -> ModelUsage {
630    let input_tokens = usage.prompt_token_count.unwrap_or(0);
631    let output_tokens = usage.candidates_token_count.unwrap_or(0);
632
633    let input_tokens_details = map_modality_token_counts(
634        usage.prompt_tokens_details.as_ref(),
635        usage.cache_tokens_details.as_ref(),
636    );
637
638    let output_tokens_details =
639        map_modality_token_counts(usage.candidates_tokens_details.as_ref(), None);
640
641    ModelUsage {
642        input_tokens,
643        output_tokens,
644        input_tokens_details,
645        output_tokens_details,
646    }
647}
648
649fn map_modality_token_counts(
650    details: Option<&Vec<ModalityTokenCount>>,
651    cached_details: Option<&Vec<ModalityTokenCount>>,
652) -> Option<ModelTokensDetails> {
653    if details.is_none() && cached_details.is_none() {
654        return None;
655    }
656
657    let mut tokens_details = ModelTokensDetails {
658        text_tokens: None,
659        cached_text_tokens: None,
660        audio_tokens: None,
661        cached_audio_tokens: None,
662        image_tokens: None,
663        cached_image_tokens: None,
664    };
665
666    if let Some(details) = details {
667        for detail in details {
668            if let (Some(modality), Some(count)) = (&detail.modality, detail.token_count) {
669                match modality {
670                    MediaModality::Text => {
671                        *tokens_details.text_tokens.get_or_insert_default() += count;
672                    }
673                    MediaModality::Audio => {
674                        *tokens_details.audio_tokens.get_or_insert_default() += count;
675                    }
676                    MediaModality::Image => {
677                        *tokens_details.image_tokens.get_or_insert_default() += count;
678                    }
679                    _ => {}
680                }
681            }
682        }
683    }
684
685    if let Some(cached) = cached_details {
686        for detail in cached {
687            if let (Some(modality), Some(count)) = (&detail.modality, detail.token_count) {
688                match modality {
689                    MediaModality::Text => {
690                        *tokens_details.cached_text_tokens.get_or_insert_default() += count;
691                    }
692                    MediaModality::Audio => {
693                        *tokens_details.cached_audio_tokens.get_or_insert_default() += count;
694                    }
695                    MediaModality::Image => {
696                        *tokens_details.cached_image_tokens.get_or_insert_default() += count;
697                    }
698                    _ => {}
699                }
700            }
701        }
702    }
703
704    Some(tokens_details)
705}