Skip to main content

rig_core/providers/gemini/
streaming.rs

1use async_stream::stream;
2use futures::StreamExt;
3use serde::{Deserialize, Serialize};
4use tracing::{Level, enabled, info_span};
5use tracing_futures::Instrument;
6
7use super::completion::gemini_api_types::{
8    ContentCandidate, FinishReason, ModalityTokenCount, Part, PartKind, TrafficType,
9};
10use super::completion::{
11    CompletionModel, create_request_body, function_call_finish_reason_error, resolve_request_model,
12    streaming_endpoint,
13};
14use crate::completion::message::ReasoningContent;
15use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
16use crate::http_client::HttpClientExt;
17use crate::http_client::sse::{Event, GenericEventSource};
18use crate::streaming;
19use crate::telemetry::SpanCombinator;
20
21#[derive(Debug, Deserialize, Serialize, Default, Clone)]
22#[serde(rename_all = "camelCase")]
23pub struct PartialUsage {
24    pub total_token_count: i32,
25    #[serde(skip_serializing_if = "Option::is_none")]
26    pub cached_content_token_count: Option<i32>,
27    #[serde(skip_serializing_if = "Option::is_none")]
28    pub candidates_token_count: Option<i32>,
29    #[serde(skip_serializing_if = "Option::is_none")]
30    pub thoughts_token_count: Option<i32>,
31    #[serde(default)]
32    pub prompt_token_count: i32,
33    #[serde(default, skip_serializing_if = "Option::is_none")]
34    pub prompt_tokens_details: Option<Vec<ModalityTokenCount>>,
35    #[serde(default, skip_serializing_if = "Option::is_none")]
36    pub cache_tokens_details: Option<Vec<ModalityTokenCount>>,
37    #[serde(default, skip_serializing_if = "Option::is_none")]
38    pub candidates_tokens_details: Option<Vec<ModalityTokenCount>>,
39    #[serde(default, skip_serializing_if = "Option::is_none")]
40    pub tool_use_prompt_token_count: Option<i32>,
41    #[serde(default, skip_serializing_if = "Option::is_none")]
42    pub tool_use_prompt_tokens_details: Option<Vec<ModalityTokenCount>>,
43    #[serde(default, skip_serializing_if = "Option::is_none")]
44    pub traffic_type: Option<TrafficType>,
45}
46
47impl GetTokenUsage for PartialUsage {
48    fn token_usage(&self) -> Option<crate::completion::Usage> {
49        let mut usage = crate::completion::Usage::new();
50
51        usage.input_tokens = self.prompt_token_count as u64;
52        usage.output_tokens = self.candidates_token_count.unwrap_or_default() as u64;
53        usage.cached_input_tokens = self.cached_content_token_count.unwrap_or_default() as u64;
54        usage.reasoning_tokens = self.thoughts_token_count.unwrap_or_default() as u64;
55        usage.tool_use_prompt_tokens = self.tool_use_prompt_token_count.unwrap_or_default() as u64;
56        usage.total_tokens = self.total_token_count as u64;
57
58        Some(usage)
59    }
60}
61
62#[derive(Debug, Deserialize)]
63#[serde(rename_all = "camelCase")]
64pub struct StreamGenerateContentResponse {
65    pub response_id: Option<String>,
66    /// Candidate responses from the model.
67    #[serde(default)]
68    pub candidates: Vec<ContentCandidate>,
69    pub model_version: Option<String>,
70    pub usage_metadata: Option<PartialUsage>,
71}
72
73#[derive(Clone, Debug, Serialize, Deserialize)]
74pub struct StreamingCompletionResponse {
75    pub usage_metadata: PartialUsage,
76    #[serde(skip_serializing_if = "Option::is_none")]
77    pub finish_reason: Option<FinishReason>,
78    #[serde(skip_serializing_if = "Option::is_none")]
79    pub finish_message: Option<String>,
80    #[serde(skip_serializing_if = "Option::is_none")]
81    pub model_version: Option<String>,
82}
83
84impl GetTokenUsage for StreamingCompletionResponse {
85    fn token_usage(&self) -> Option<crate::completion::Usage> {
86        self.usage_metadata.token_usage()
87    }
88}
89
90fn tool_protocol_finish_reason_error(choice: &ContentCandidate) -> Option<CompletionError> {
91    let reason = choice.finish_reason.as_ref()?;
92    function_call_finish_reason_error(reason, choice.finish_message.as_deref())
93}
94
95impl<T> CompletionModel<T>
96where
97    T: HttpClientExt + Clone + 'static,
98{
99    pub(crate) async fn stream(
100        &self,
101        completion_request: CompletionRequest,
102    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
103    {
104        let request_model = resolve_request_model(&self.model, &completion_request);
105        let span = if tracing::Span::current().is_disabled() {
106            info_span!(
107                target: "rig::completions",
108                "chat_streaming",
109                gen_ai.operation.name = "chat_streaming",
110                gen_ai.provider.name = "gcp.gemini",
111                gen_ai.request.model = &request_model,
112                gen_ai.system_instructions = &completion_request.preamble,
113                gen_ai.response.id = tracing::field::Empty,
114                gen_ai.response.model = &request_model,
115                gen_ai.usage.output_tokens = tracing::field::Empty,
116                gen_ai.usage.input_tokens = tracing::field::Empty,
117                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
118                gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
119                gen_ai.usage.tool_use_prompt_tokens = tracing::field::Empty,
120                gen_ai.usage.reasoning_tokens = tracing::field::Empty,
121            )
122        } else {
123            tracing::Span::current()
124        };
125        let request = create_request_body(completion_request)?;
126
127        if enabled!(Level::TRACE) {
128            tracing::trace!(
129                target: "rig::streaming",
130                "Gemini streaming completion request: {}",
131                serde_json::to_string_pretty(&request)?
132            );
133        }
134
135        let body = serde_json::to_vec(&request)?;
136
137        let req = self
138            .client
139            .post_sse(streaming_endpoint(&request_model))?
140            .header("Content-Type", "application/json")
141            .body(body)
142            .map_err(|e| CompletionError::HttpError(e.into()))?;
143
144        let mut event_source = GenericEventSource::new(self.client.clone(), req);
145
146        let stream = stream! {
147            let mut final_usage = None;
148            let mut final_finish_reason: Option<FinishReason> = None;
149            let mut final_finish_message: Option<String> = None;
150            let mut final_model_version: Option<String> = None;
151            let mut stream_failed = false;
152            while let Some(event_result) = event_source.next().await {
153                match event_result {
154                    Ok(Event::Open) => {
155                        tracing::debug!("SSE connection opened");
156                        continue;
157                    }
158                    Ok(Event::Message(message)) => {
159                        // Skip heartbeat messages or empty data
160                        if message.data.trim().is_empty() {
161                            continue;
162                        }
163
164                        let data = match serde_json::from_str::<StreamGenerateContentResponse>(&message.data) {
165                            Ok(d) => d,
166                            Err(error) => {
167                                tracing::error!(?error, message = message.data, "Failed to parse SSE message");
168                                stream_failed = true;
169                                yield Err(CompletionError::JsonError(error));
170                                break;
171                            }
172                        };
173
174                        let span = tracing::Span::current();
175                        if let Some(response_id) = data.response_id.as_deref() {
176                            span.record("gen_ai.response.id", response_id);
177                        }
178                        if let Some(model_version) = &data.model_version {
179                            span.record("gen_ai.response.model", model_version.as_str());
180                            final_model_version = Some(model_version.clone());
181                        }
182                        if let Some(usage) = data.usage_metadata.as_ref() {
183                            span.record_token_usage(usage);
184                            final_usage = Some(usage.clone());
185                        }
186
187                        // Process the response data
188                        let Some(choice) = data.candidates.into_iter().next() else {
189                            tracing::debug!("There is no content candidate");
190                            continue;
191                        };
192
193                        // Capture before partial moves of choice fields
194                        let should_stop = choice.finish_reason.is_some();
195                        if let Some(fr) = &choice.finish_reason {
196                            final_finish_reason = Some(fr.clone());
197                        }
198                        if let Some(message) = &choice.finish_message {
199                            final_finish_message = Some(message.clone());
200                        }
201
202                        if let Some(err) = tool_protocol_finish_reason_error(&choice) {
203                            stream_failed = true;
204                            yield Err(err);
205                            break;
206                        }
207
208                        let Some(content) = choice.content else {
209                            tracing::debug!(finish_reason = ?final_finish_reason, "Streaming candidate missing content");
210                            // Gemini's final chunk may carry finishReason with no content — break instead of skip
211                            if should_stop {
212                                break;
213                            }
214                            continue;
215                        };
216
217                        if content.parts.is_empty() {
218                            tracing::trace!(reason = ?choice.finish_reason, "There is no part in the streaming content");
219                        }
220
221                        for part in content.parts {
222                            match part {
223                                Part {
224                                    part: PartKind::Text(text),
225                                    thought: Some(true),
226                                    thought_signature,
227                                    ..
228                                } => {
229                                    if !text.is_empty() {
230                                        if thought_signature.is_some() {
231                                            // Signature arrives on the final chunk of a
232                                            // thinking block; emit a full Reasoning so the
233                                            // core accumulator captures the signature for
234                                            // Gemini 3+ roundtrip.
235                                            yield Ok(streaming::RawStreamingChoice::Reasoning {
236                                                id: None,
237                                                content: ReasoningContent::Text {
238                                                    text,
239                                                    signature: thought_signature,
240                                                },
241                                            });
242                                        } else {
243                                            yield Ok(streaming::RawStreamingChoice::ReasoningDelta {
244                                                id: None,
245                                                reasoning: text,
246                                            });
247                                        }
248                                    }
249                                },
250                                Part {
251                                    part: PartKind::Text(text),
252                                    ..
253                                } => {
254                                    if !text.is_empty() {
255                                        yield Ok(streaming::RawStreamingChoice::Message(text));
256                                    }
257                                },
258                                Part {
259                                    part: PartKind::FunctionCall(function_call),
260                                    thought_signature,
261                                    ..
262                                } => {
263                                    yield Ok(streaming::RawStreamingChoice::ToolCall(
264                                        streaming::RawStreamingToolCall::new(function_call.name.clone(), function_call.name.clone(), function_call.args.clone())
265                                            .with_signature(thought_signature)
266                                    ));
267                                },
268                                part => {
269                                    tracing::warn!(?part, "Unsupported response type with streaming");
270                                }
271                            }
272                        }
273
274                        // Check if this is the final response
275                        if should_stop {
276                            break;
277                        }
278                    }
279                    Err(crate::http_client::Error::StreamEnded) => {
280                        break;
281                    }
282                    Err(error) => {
283                        tracing::error!(?error, "SSE error");
284                        stream_failed = true;
285                        yield Err(CompletionError::ProviderError(error.to_string()));
286                        break;
287                    }
288                }
289            }
290
291            // Ensure event source is closed when stream ends
292            event_source.close();
293
294            if !stream_failed {
295                yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
296                    usage_metadata: final_usage.unwrap_or_default(),
297                    finish_reason: final_finish_reason,
298                    finish_message: final_finish_message,
299                    model_version: final_model_version,
300                }));
301            }
302        }.instrument(span);
303
304        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
305            stream,
306        )))
307    }
308}
309
310#[cfg(test)]
311mod tests {
312    use super::*;
313    use serde_json::json;
314
315    #[test]
316    fn test_deserialize_stream_response_with_single_text_part() {
317        let json_data = json!({
318            "candidates": [{
319                "content": {
320                    "parts": [
321                        {"text": "Hello, world!"}
322                    ],
323                    "role": "model"
324                },
325                "finishReason": "STOP",
326                "index": 0
327            }],
328            "usageMetadata": {
329                "promptTokenCount": 10,
330                "candidatesTokenCount": 5,
331                "totalTokenCount": 15
332            }
333        });
334
335        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
336        assert_eq!(response.candidates.len(), 1);
337        assert!(matches!(
338            response.candidates[0].finish_reason,
339            Some(FinishReason::Stop)
340        ));
341        let content = response.candidates[0]
342            .content
343            .as_ref()
344            .expect("candidate should contain content");
345        assert_eq!(content.parts.len(), 1);
346
347        if let Part {
348            part: PartKind::Text(text),
349            ..
350        } = &content.parts[0]
351        {
352            assert_eq!(text, "Hello, world!");
353        } else {
354            panic!("Expected text part");
355        }
356    }
357
358    #[test]
359    fn test_streaming_tool_protocol_finish_reason_returns_response_error() {
360        for (finish_reason, reason_name, finish_message) in [
361            (
362                "MALFORMED_FUNCTION_CALL",
363                "MalformedFunctionCall",
364                "malformed function call: default_api",
365            ),
366            (
367                "UNEXPECTED_TOOL_CALL",
368                "UnexpectedToolCall",
369                "unexpected tool call: default_api",
370            ),
371            (
372                "MISSING_THOUGHT_SIGNATURE",
373                "MissingThoughtSignature",
374                "missing thought signature for tool call",
375            ),
376            (
377                "TOO_MANY_TOOL_CALLS",
378                "TooManyToolCalls",
379                "too many tool calls in response",
380            ),
381            (
382                "MALFORMED_RESPONSE",
383                "MalformedResponse",
384                "malformed response from provider",
385            ),
386        ] {
387            let json_data = json!({
388                "candidates": [{
389                    "finishReason": finish_reason,
390                    "finishMessage": finish_message,
391                    "index": 0
392                }]
393            });
394
395            let response: StreamGenerateContentResponse =
396                serde_json::from_value(json_data).unwrap();
397            let candidate = response
398                .candidates
399                .first()
400                .expect("expected terminal candidate");
401            let err = tool_protocol_finish_reason_error(candidate)
402                .expect("tool protocol finish reason should be an error");
403
404            assert!(matches!(
405                err,
406                CompletionError::ResponseError(message)
407                    if message.contains(reason_name)
408                        && message.contains(finish_message)
409            ));
410        }
411    }
412
413    #[test]
414    fn test_deserialize_stream_response_with_usage_only_chunk() {
415        let json_data = json!({
416            "responseId": "response-123",
417            "modelVersion": "gemini-2.0-flash-001",
418            "usageMetadata": {
419                "promptTokenCount": 10,
420                "candidatesTokenCount": 5,
421                "totalTokenCount": 15
422            }
423        });
424
425        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
426        assert_eq!(response.response_id.as_deref(), Some("response-123"));
427        assert_eq!(
428            response.model_version.as_deref(),
429            Some("gemini-2.0-flash-001")
430        );
431        assert!(response.candidates.is_empty());
432
433        let usage = response
434            .usage_metadata
435            .as_ref()
436            .and_then(GetTokenUsage::token_usage)
437            .unwrap();
438        assert_eq!(usage.input_tokens, 10);
439        assert_eq!(usage.output_tokens, 5);
440        assert_eq!(usage.total_tokens, 15);
441    }
442
443    #[test]
444    fn test_deserialize_stream_response_with_multiple_text_parts() {
445        let json_data = json!({
446            "candidates": [{
447                "content": {
448                    "parts": [
449                        {"text": "Hello, "},
450                        {"text": "world!"},
451                        {"text": " How are you?"}
452                    ],
453                    "role": "model"
454                },
455                "finishReason": "STOP",
456                "index": 0
457            }],
458            "usageMetadata": {
459                "promptTokenCount": 10,
460                "candidatesTokenCount": 8,
461                "totalTokenCount": 18
462            }
463        });
464
465        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
466        assert_eq!(response.candidates.len(), 1);
467        let content = response.candidates[0]
468            .content
469            .as_ref()
470            .expect("candidate should contain content");
471        assert_eq!(content.parts.len(), 3);
472
473        // Verify all three text parts are present
474        for (i, expected_text) in ["Hello, ", "world!", " How are you?"].iter().enumerate() {
475            if let Part {
476                part: PartKind::Text(text),
477                ..
478            } = &content.parts[i]
479            {
480                assert_eq!(text, expected_text);
481            } else {
482                panic!("Expected text part at index {}", i);
483            }
484        }
485    }
486
487    #[test]
488    fn test_deserialize_stream_response_with_multiple_tool_calls() {
489        let json_data = json!({
490            "candidates": [{
491                "content": {
492                    "parts": [
493                        {
494                            "functionCall": {
495                                "name": "get_weather",
496                                "args": {"city": "San Francisco"}
497                            }
498                        },
499                        {
500                            "functionCall": {
501                                "name": "get_temperature",
502                                "args": {"location": "New York"}
503                            }
504                        }
505                    ],
506                    "role": "model"
507                },
508                "finishReason": "STOP",
509                "index": 0
510            }],
511            "usageMetadata": {
512                "promptTokenCount": 50,
513                "candidatesTokenCount": 20,
514                "totalTokenCount": 70
515            }
516        });
517
518        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
519        let content = response.candidates[0]
520            .content
521            .as_ref()
522            .expect("candidate should contain content");
523        assert_eq!(content.parts.len(), 2);
524
525        // Verify first tool call
526        if let Part {
527            part: PartKind::FunctionCall(call),
528            ..
529        } = &content.parts[0]
530        {
531            assert_eq!(call.name, "get_weather");
532        } else {
533            panic!("Expected function call at index 0");
534        }
535
536        // Verify second tool call
537        if let Part {
538            part: PartKind::FunctionCall(call),
539            ..
540        } = &content.parts[1]
541        {
542            assert_eq!(call.name, "get_temperature");
543        } else {
544            panic!("Expected function call at index 1");
545        }
546    }
547
548    #[test]
549    fn test_deserialize_stream_response_with_mixed_parts() {
550        let json_data = json!({
551            "candidates": [{
552                "content": {
553                    "parts": [
554                        {
555                            "text": "Let me think about this...",
556                            "thought": true
557                        },
558                        {
559                            "text": "Here's my response: "
560                        },
561                        {
562                            "functionCall": {
563                                "name": "search",
564                                "args": {"query": "rust async"}
565                            }
566                        },
567                        {
568                            "text": "I found the answer!"
569                        }
570                    ],
571                    "role": "model"
572                },
573                "finishReason": "STOP",
574                "index": 0
575            }],
576            "usageMetadata": {
577                "promptTokenCount": 100,
578                "candidatesTokenCount": 50,
579                "thoughtsTokenCount": 15,
580                "totalTokenCount": 165
581            }
582        });
583
584        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
585        let content = response.candidates[0]
586            .content
587            .as_ref()
588            .expect("candidate should contain content");
589        let parts = &content.parts;
590        assert_eq!(parts.len(), 4);
591
592        // Verify reasoning (thought) part
593        if let Part {
594            part: PartKind::Text(text),
595            thought: Some(true),
596            ..
597        } = &parts[0]
598        {
599            assert_eq!(text, "Let me think about this...");
600        } else {
601            panic!("Expected thought part at index 0");
602        }
603
604        // Verify regular text
605        if let Part {
606            part: PartKind::Text(text),
607            thought,
608            ..
609        } = &parts[1]
610        {
611            assert_eq!(text, "Here's my response: ");
612            assert!(thought.is_none() || thought == &Some(false));
613        } else {
614            panic!("Expected text part at index 1");
615        }
616
617        // Verify tool call
618        if let Part {
619            part: PartKind::FunctionCall(call),
620            ..
621        } = &parts[2]
622        {
623            assert_eq!(call.name, "search");
624        } else {
625            panic!("Expected function call at index 2");
626        }
627
628        // Verify final text
629        if let Part {
630            part: PartKind::Text(text),
631            ..
632        } = &parts[3]
633        {
634            assert_eq!(text, "I found the answer!");
635        } else {
636            panic!("Expected text part at index 3");
637        }
638    }
639
640    #[test]
641    fn test_deserialize_stream_response_with_empty_parts() {
642        let json_data = json!({
643            "candidates": [{
644                "content": {
645                    "parts": [],
646                    "role": "model"
647                },
648                "finishReason": "STOP",
649                "index": 0
650            }],
651            "usageMetadata": {
652                "promptTokenCount": 10,
653                "candidatesTokenCount": 0,
654                "totalTokenCount": 10
655            }
656        });
657
658        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
659        let content = response.candidates[0]
660            .content
661            .as_ref()
662            .expect("candidate should contain content");
663        assert_eq!(content.parts.len(), 0);
664    }
665
666    #[test]
667    fn test_partial_usage_token_calculation() {
668        let usage = PartialUsage {
669            total_token_count: 100,
670            cached_content_token_count: Some(20),
671            candidates_token_count: Some(30),
672            thoughts_token_count: Some(10),
673            prompt_token_count: 40,
674            prompt_tokens_details: None,
675            cache_tokens_details: None,
676            candidates_tokens_details: None,
677            tool_use_prompt_token_count: Some(12),
678            tool_use_prompt_tokens_details: None,
679            traffic_type: None,
680        };
681
682        let token_usage = usage.token_usage().unwrap();
683        assert_eq!(token_usage.input_tokens, 40);
684        assert_eq!(token_usage.cached_input_tokens, 20);
685        assert_eq!(token_usage.output_tokens, 30);
686        assert_eq!(token_usage.reasoning_tokens, 10);
687        assert_eq!(token_usage.tool_use_prompt_tokens, 12);
688        assert_eq!(token_usage.total_tokens, 100);
689    }
690
691    #[test]
692    fn test_partial_usage_with_missing_counts() {
693        let usage = PartialUsage {
694            total_token_count: 50,
695            cached_content_token_count: None,
696            candidates_token_count: Some(30),
697            thoughts_token_count: None,
698            prompt_token_count: 20,
699            prompt_tokens_details: None,
700            cache_tokens_details: None,
701            candidates_tokens_details: None,
702            tool_use_prompt_token_count: None,
703            tool_use_prompt_tokens_details: None,
704            traffic_type: None,
705        };
706
707        let token_usage = usage.token_usage().unwrap();
708        assert_eq!(token_usage.input_tokens, 20);
709        assert_eq!(token_usage.cached_input_tokens, 0);
710        assert_eq!(token_usage.output_tokens, 30);
711        assert_eq!(token_usage.reasoning_tokens, 0);
712        assert_eq!(token_usage.total_tokens, 50);
713    }
714
715    #[test]
716    fn test_streaming_completion_response_has_finish_reason_and_model_version() {
717        use super::super::completion::gemini_api_types::FinishReason;
718
719        let response = StreamingCompletionResponse {
720            usage_metadata: PartialUsage::default(),
721            finish_reason: Some(FinishReason::Stop),
722            finish_message: None,
723            model_version: Some("gemini-2.5-pro-preview-05-06".to_string()),
724        };
725
726        assert!(matches!(response.finish_reason, Some(FinishReason::Stop)));
727        assert_eq!(
728            response.model_version.as_deref(),
729            Some("gemini-2.5-pro-preview-05-06")
730        );
731
732        let json = serde_json::to_string(&response).unwrap();
733        let deserialized: StreamingCompletionResponse = serde_json::from_str(&json).unwrap();
734        assert!(matches!(
735            deserialized.finish_reason,
736            Some(FinishReason::Stop)
737        ));
738        assert_eq!(
739            deserialized.model_version.as_deref(),
740            Some("gemini-2.5-pro-preview-05-06")
741        );
742    }
743
744    #[test]
745    fn test_streaming_completion_response_token_usage() {
746        let response = StreamingCompletionResponse {
747            usage_metadata: PartialUsage {
748                total_token_count: 150,
749                cached_content_token_count: None,
750                candidates_token_count: Some(75),
751                thoughts_token_count: None,
752                prompt_token_count: 75,
753                prompt_tokens_details: None,
754                cache_tokens_details: None,
755                candidates_tokens_details: None,
756                tool_use_prompt_token_count: None,
757                tool_use_prompt_tokens_details: None,
758                traffic_type: None,
759            },
760            finish_reason: Some(FinishReason::Stop),
761            finish_message: None,
762            model_version: Some("gemini-2.0-flash-001".to_string()),
763        };
764
765        let token_usage = response.token_usage().unwrap();
766        assert_eq!(token_usage.input_tokens, 75);
767        assert_eq!(token_usage.output_tokens, 75);
768        assert_eq!(token_usage.reasoning_tokens, 0);
769        assert_eq!(token_usage.cached_input_tokens, 0);
770        assert_eq!(token_usage.total_tokens, 150);
771        assert!(matches!(response.finish_reason, Some(FinishReason::Stop)));
772        assert_eq!(
773            response.model_version.as_deref(),
774            Some("gemini-2.0-flash-001")
775        );
776    }
777
778    #[test]
779    fn test_partial_usage_serde_roundtrip_with_all_optional_fields() {
780        let json_data = serde_json::json!({
781            "promptTokenCount": 100,
782            "cachedContentTokenCount": 25,
783            "candidatesTokenCount": 50,
784            "thoughtsTokenCount": 15,
785            "totalTokenCount": 190,
786            "promptTokensDetails": [
787                { "modality": "TEXT", "tokenCount": 80 },
788                { "modality": "IMAGE", "tokenCount": 20 }
789            ],
790            "cacheTokensDetails": [
791                { "modality": "TEXT", "tokenCount": 25 }
792            ],
793            "candidatesTokensDetails": [
794                { "modality": "TEXT", "tokenCount": 50 }
795            ],
796            "toolUsePromptTokenCount": 12,
797            "toolUsePromptTokensDetails": [
798                { "modality": "TEXT", "tokenCount": 12 }
799            ],
800            "trafficType": "PROVISIONED_THROUGHPUT"
801        });
802
803        let usage: PartialUsage = serde_json::from_value(json_data).unwrap();
804        assert_eq!(usage.prompt_token_count, 100);
805        assert_eq!(usage.cached_content_token_count, Some(25));
806        assert_eq!(usage.candidates_token_count, Some(50));
807        assert_eq!(usage.thoughts_token_count, Some(15));
808        assert_eq!(usage.total_token_count, 190);
809        assert!(usage.prompt_tokens_details.is_some());
810        assert_eq!(usage.prompt_tokens_details.as_ref().unwrap().len(), 2);
811        assert!(usage.cache_tokens_details.is_some());
812        assert!(usage.candidates_tokens_details.is_some());
813        assert_eq!(usage.tool_use_prompt_token_count, Some(12));
814        assert!(usage.tool_use_prompt_tokens_details.is_some());
815        assert!(matches!(
816            usage.traffic_type,
817            Some(TrafficType::ProvisionedThroughput)
818        ));
819
820        let token_usage = usage.token_usage().unwrap();
821        assert_eq!(token_usage.input_tokens, 100);
822        assert_eq!(token_usage.cached_input_tokens, 25);
823        assert_eq!(token_usage.output_tokens, 50);
824        assert_eq!(token_usage.reasoning_tokens, 15);
825        assert_eq!(token_usage.tool_use_prompt_tokens, 12);
826        assert_eq!(token_usage.total_tokens, 190);
827    }
828}