rig_core/providers/gemini/streaming.rs

use async_stream::stream;
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use tracing::{Level, enabled, info_span};
use tracing_futures::Instrument;

use super::completion::gemini_api_types::{
    ContentCandidate, ModalityTokenCount, Part, PartKind, TrafficType,
};
use super::completion::{
    CompletionModel, create_request_body, resolve_request_model, streaming_endpoint,
};
use crate::completion::message::ReasoningContent;
use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
use crate::http_client::HttpClientExt;
use crate::http_client::sse::{Event, GenericEventSource};
use crate::streaming;
use crate::telemetry::SpanCombinator;

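/// Usage metadata as it appears on streaming chunks. Most counts are optional
/// because Gemini populates them incrementally across chunks.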
#[derive(Debug, Deserialize, Serialize, Default, Clone)]
#[serde(rename_all = "camelCase")]
pub struct PartialUsage {
    pub total_token_count: i32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_content_token_count: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidates_token_count: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thoughts_token_count: Option<i32>,
    #[serde(default)]
    pub prompt_token_count: i32,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<Vec<ModalityTokenCount>>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cache_tokens_details: Option<Vec<ModalityTokenCount>>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub candidates_tokens_details: Option<Vec<ModalityTokenCount>>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_use_prompt_token_count: Option<i32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_use_prompt_tokens_details: Option<Vec<ModalityTokenCount>>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub traffic_type: Option<TrafficType>,
}

impl GetTokenUsage for PartialUsage {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();

        usage.input_tokens = self.prompt_token_count as u64;
        usage.output_tokens = self.candidates_token_count.unwrap_or_default() as u64;
        usage.cached_input_tokens = self.cached_content_token_count.unwrap_or_default() as u64;
        usage.reasoning_tokens = self.thoughts_token_count.unwrap_or_default() as u64;
        usage.total_tokens = self.total_token_count as u64;

        Some(usage)
    }
}

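/// A single server-sent chunk of a `streamGenerateContent` response.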
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct StreamGenerateContentResponse {
    pub response_id: Option<String>,
    /// Candidate responses from the model.
    #[serde(default)]
    pub candidates: Vec<ContentCandidate>,
    pub model_version: Option<String>,
    pub usage_metadata: Option<PartialUsage>,
}

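/// Final response emitted once the stream completes, carrying the last usage
/// metadata observed on the wire.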
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct StreamingCompletionResponse {
    pub usage_metadata: PartialUsage,
}

impl GetTokenUsage for StreamingCompletionResponse {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        self.usage_metadata.token_usage()
    }
}

impl<T> CompletionModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
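    /// Streams a completion over SSE, yielding text, reasoning, and tool-call
    /// chunks, followed by a final response carrying the usage totals.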
    pub(crate) async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
    {
        let request_model = resolve_request_model(&self.model, &completion_request);
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "gcp.gemini",
                gen_ai.request.model = &request_model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = &request_model,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
                gen_ai.usage.reasoning_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };
        let request = create_request_body(completion_request)?;

        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::streaming",
                "Gemini streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post_sse(streaming_endpoint(&request_model))?
            .header("Content-Type", "application/json")
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        let mut event_source = GenericEventSource::new(self.client.clone(), req);

        let stream = stream! {
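            // Gemini repeats usage metadata across chunks; track the latest
            // value so it can be attached to the final response below.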
            let mut final_usage = None;
            while let Some(event_result) = event_source.next().await {
                match event_result {
                    Ok(Event::Open) => {
                        tracing::debug!("SSE connection opened");
                        continue;
                    }
                    Ok(Event::Message(message)) => {
                        // Skip heartbeat messages or empty data
                        if message.data.trim().is_empty() {
                            continue;
                        }

                        let data = match serde_json::from_str::<StreamGenerateContentResponse>(&message.data) {
                            Ok(d) => d,
                            Err(error) => {
                                tracing::error!(?error, message = message.data, "Failed to parse SSE message");
                                continue;
                            }
                        };

                        let span = tracing::Span::current();
                        if let Some(response_id) = data.response_id.as_deref() {
                            span.record("gen_ai.response.id", response_id);
                        }
                        if let Some(model_version) = data.model_version.as_deref() {
                            span.record("gen_ai.response.model", model_version);
                        }
                        if let Some(usage) = data.usage_metadata.as_ref() {
                            span.record_token_usage(usage);
                            final_usage = Some(usage.clone());
                        }

                        // Process the response data
                        let Some(choice) = data.candidates.into_iter().next() else {
                            tracing::debug!("Streaming chunk contained no content candidates");
                            continue;
                        };

                        let Some(content) = choice.content else {
                            tracing::debug!(finish_reason = ?choice.finish_reason, "Streaming candidate missing content");
                            continue;
                        };

                        if content.parts.is_empty() {
                            tracing::trace!(reason = ?choice.finish_reason, "Streaming content contained no parts");
                        }

                        for part in content.parts {
                            match part {
                                Part {
                                    part: PartKind::Text(text),
                                    thought: Some(true),
                                    thought_signature,
                                    ..
                                } => {
                                    if !text.is_empty() {
                                        if thought_signature.is_some() {
                                            // The signature arrives on the final chunk of a
                                            // thinking block; emit a full Reasoning so the
                                            // core accumulator captures the signature for
                                            // the Gemini 3+ roundtrip.
                                            yield Ok(streaming::RawStreamingChoice::Reasoning {
                                                id: None,
                                                content: ReasoningContent::Text {
                                                    text,
                                                    signature: thought_signature,
                                                },
                                            });
                                        } else {
                                            yield Ok(streaming::RawStreamingChoice::ReasoningDelta {
                                                id: None,
                                                reasoning: text,
                                            });
                                        }
                                    }
                                },
                                Part {
                                    part: PartKind::Text(text),
                                    ..
                                } => {
                                    if !text.is_empty() {
                                        yield Ok(streaming::RawStreamingChoice::Message(text));
                                    }
                                },
                                Part {
                                    part: PartKind::FunctionCall(function_call),
                                    thought_signature,
                                    ..
                                } => {
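                                    // Gemini does not supply a distinct tool-call
                                    // id, so the function name doubles as the id.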
                                    yield Ok(streaming::RawStreamingChoice::ToolCall(
                                        streaming::RawStreamingToolCall::new(function_call.name.clone(), function_call.name.clone(), function_call.args.clone())
                                            .with_signature(thought_signature)
                                    ));
                                },
                                part => {
                                    tracing::warn!(?part, "Unsupported part type in streaming response");
                                }
                            }
                        }

                        // Check if this is the final response
                        if choice.finish_reason.is_some() {
                            break;
                        }
                    }
                    Err(crate::http_client::Error::StreamEnded) => {
                        break;
                    }
                    Err(error) => {
                        tracing::error!(?error, "SSE error");
                        yield Err(CompletionError::ProviderError(error.to_string()));
                        break;
                    }
                }
            }

            // Ensure event source is closed when stream ends
            event_source.close();

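            // Always emit a final response so consumers receive the last-seen
            // usage totals, even if the stream ended early.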
            yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
                usage_metadata: final_usage.unwrap_or_default()
            }));
        }.instrument(span);

        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
            stream,
        )))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_deserialize_stream_response_with_single_text_part() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "Hello, world!"}
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 5,
                "totalTokenCount": 15
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        assert_eq!(response.candidates.len(), 1);
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 1);

        if let Part {
            part: PartKind::Text(text),
            ..
        } = &content.parts[0]
        {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected text part");
        }
    }

    #[test]
    fn test_deserialize_stream_response_with_usage_only_chunk() {
        let json_data = json!({
            "responseId": "response-123",
            "modelVersion": "gemini-2.0-flash-001",
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 5,
                "totalTokenCount": 15
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        assert_eq!(response.response_id.as_deref(), Some("response-123"));
        assert_eq!(
            response.model_version.as_deref(),
            Some("gemini-2.0-flash-001")
        );
        assert!(response.candidates.is_empty());

        let usage = response
            .usage_metadata
            .as_ref()
            .and_then(GetTokenUsage::token_usage)
            .unwrap();
        assert_eq!(usage.input_tokens, 10);
        assert_eq!(usage.output_tokens, 5);
        assert_eq!(usage.total_tokens, 15);
    }

    #[test]
    fn test_deserialize_stream_response_with_multiple_text_parts() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "Hello, "},
                        {"text": "world!"},
                        {"text": " How are you?"}
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 8,
                "totalTokenCount": 18
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        assert_eq!(response.candidates.len(), 1);
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 3);

        // Verify all three text parts are present
        for (i, expected_text) in ["Hello, ", "world!", " How are you?"].iter().enumerate() {
            if let Part {
                part: PartKind::Text(text),
                ..
            } = &content.parts[i]
            {
                assert_eq!(text, expected_text);
            } else {
                panic!("Expected text part at index {}", i);
            }
        }
    }

    #[test]
    fn test_deserialize_stream_response_with_multiple_tool_calls() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {
                            "functionCall": {
                                "name": "get_weather",
                                "args": {"city": "San Francisco"}
                            }
                        },
                        {
                            "functionCall": {
                                "name": "get_temperature",
                                "args": {"location": "New York"}
                            }
                        }
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 50,
                "candidatesTokenCount": 20,
                "totalTokenCount": 70
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 2);

        // Verify first tool call
        if let Part {
            part: PartKind::FunctionCall(call),
            ..
        } = &content.parts[0]
        {
            assert_eq!(call.name, "get_weather");
        } else {
            panic!("Expected function call at index 0");
        }

        // Verify second tool call
        if let Part {
            part: PartKind::FunctionCall(call),
            ..
        } = &content.parts[1]
        {
            assert_eq!(call.name, "get_temperature");
        } else {
            panic!("Expected function call at index 1");
        }
    }

    #[test]
    fn test_deserialize_stream_response_with_mixed_parts() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {
                            "text": "Let me think about this...",
                            "thought": true
                        },
                        {
                            "text": "Here's my response: "
                        },
                        {
                            "functionCall": {
                                "name": "search",
                                "args": {"query": "rust async"}
                            }
                        },
                        {
                            "text": "I found the answer!"
                        }
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 100,
                "candidatesTokenCount": 50,
                "thoughtsTokenCount": 15,
                "totalTokenCount": 165
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        let parts = &content.parts;
        assert_eq!(parts.len(), 4);

        // Verify reasoning (thought) part
        if let Part {
            part: PartKind::Text(text),
            thought: Some(true),
            ..
        } = &parts[0]
        {
            assert_eq!(text, "Let me think about this...");
        } else {
            panic!("Expected thought part at index 0");
        }

        // Verify regular text
        if let Part {
            part: PartKind::Text(text),
            thought,
            ..
        } = &parts[1]
        {
            assert_eq!(text, "Here's my response: ");
            assert!(thought.is_none() || thought == &Some(false));
        } else {
            panic!("Expected text part at index 1");
        }

        // Verify tool call
        if let Part {
            part: PartKind::FunctionCall(call),
            ..
        } = &parts[2]
        {
            assert_eq!(call.name, "search");
        } else {
            panic!("Expected function call at index 2");
        }

        // Verify final text
        if let Part {
            part: PartKind::Text(text),
            ..
        } = &parts[3]
        {
            assert_eq!(text, "I found the answer!");
        } else {
            panic!("Expected text part at index 3");
        }
    }

    #[test]
    fn test_deserialize_stream_response_with_empty_parts() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 0,
                "totalTokenCount": 10
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 0);
    }

    #[test]
    fn test_partial_usage_token_calculation() {
        let usage = PartialUsage {
            total_token_count: 100,
            cached_content_token_count: Some(20),
            candidates_token_count: Some(30),
            thoughts_token_count: Some(10),
            prompt_token_count: 40,
            prompt_tokens_details: None,
            cache_tokens_details: None,
            candidates_tokens_details: None,
            tool_use_prompt_token_count: None,
            tool_use_prompt_tokens_details: None,
            traffic_type: None,
        };

        let token_usage = usage.token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 40);
        assert_eq!(token_usage.cached_input_tokens, 20);
        assert_eq!(token_usage.output_tokens, 30);
        assert_eq!(token_usage.reasoning_tokens, 10);
        assert_eq!(token_usage.total_tokens, 100);
    }

    #[test]
    fn test_partial_usage_with_missing_counts() {
        let usage = PartialUsage {
            total_token_count: 50,
            cached_content_token_count: None,
            candidates_token_count: Some(30),
            thoughts_token_count: None,
            prompt_token_count: 20,
            prompt_tokens_details: None,
            cache_tokens_details: None,
            candidates_tokens_details: None,
            tool_use_prompt_token_count: None,
            tool_use_prompt_tokens_details: None,
            traffic_type: None,
        };

        let token_usage = usage.token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 20);
        assert_eq!(token_usage.cached_input_tokens, 0);
        assert_eq!(token_usage.output_tokens, 30);
        assert_eq!(token_usage.reasoning_tokens, 0);
        assert_eq!(token_usage.total_tokens, 50);
    }

    #[test]
    fn test_streaming_completion_response_token_usage() {
        let response = StreamingCompletionResponse {
            usage_metadata: PartialUsage {
                total_token_count: 150,
                cached_content_token_count: None,
                candidates_token_count: Some(75),
                thoughts_token_count: None,
                prompt_token_count: 75,
                prompt_tokens_details: None,
                cache_tokens_details: None,
                candidates_tokens_details: None,
                tool_use_prompt_token_count: None,
                tool_use_prompt_tokens_details: None,
                traffic_type: None,
            },
        };

        let token_usage = response.token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 75);
        assert_eq!(token_usage.output_tokens, 75);
        assert_eq!(token_usage.reasoning_tokens, 0);
        assert_eq!(token_usage.cached_input_tokens, 0);
        assert_eq!(token_usage.total_tokens, 150);
    }

    #[test]
    fn test_partial_usage_serde_roundtrip_with_all_optional_fields() {
        let json_data = serde_json::json!({
            "promptTokenCount": 100,
            "cachedContentTokenCount": 25,
            "candidatesTokenCount": 50,
            "thoughtsTokenCount": 15,
            "totalTokenCount": 190,
            "promptTokensDetails": [
                { "modality": "TEXT", "tokenCount": 80 },
                { "modality": "IMAGE", "tokenCount": 20 }
            ],
            "cacheTokensDetails": [
                { "modality": "TEXT", "tokenCount": 25 }
            ],
            "candidatesTokensDetails": [
                { "modality": "TEXT", "tokenCount": 50 }
            ],
            "toolUsePromptTokenCount": 12,
            "toolUsePromptTokensDetails": [
                { "modality": "TEXT", "tokenCount": 12 }
            ],
            "trafficType": "PROVISIONED_THROUGHPUT"
        });

        let usage: PartialUsage = serde_json::from_value(json_data).unwrap();
        assert_eq!(usage.prompt_token_count, 100);
        assert_eq!(usage.cached_content_token_count, Some(25));
        assert_eq!(usage.candidates_token_count, Some(50));
        assert_eq!(usage.thoughts_token_count, Some(15));
        assert_eq!(usage.total_token_count, 190);
        assert!(usage.prompt_tokens_details.is_some());
        assert_eq!(usage.prompt_tokens_details.as_ref().unwrap().len(), 2);
        assert!(usage.cache_tokens_details.is_some());
        assert!(usage.candidates_tokens_details.is_some());
        assert_eq!(usage.tool_use_prompt_token_count, Some(12));
        assert!(usage.tool_use_prompt_tokens_details.is_some());
        assert!(matches!(
            usage.traffic_type,
            Some(TrafficType::ProvisionedThroughput)
        ));

        let token_usage = usage.token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 100);
        assert_eq!(token_usage.cached_input_tokens, 25);
        assert_eq!(token_usage.output_tokens, 50);
        assert_eq!(token_usage.reasoning_tokens, 15);
        assert_eq!(token_usage.total_tokens, 190);
    }
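
    // Hedged addition (not in the original suite): a sketch exercising the
    // thought-signature path matched by the streaming loop above. It assumes
    // the wire field for `thought_signature` is camelCase `thoughtSignature`,
    // consistent with the camelCase convention used throughout this module.
    #[test]
    fn test_deserialize_thought_part_with_signature() {
        let json_data = json!({
            "text": "Let me think...",
            "thought": true,
            "thoughtSignature": "sig-abc"
        });

        let part: Part = serde_json::from_value(json_data).unwrap();
        if let Part {
            part: PartKind::Text(text),
            thought: Some(true),
            thought_signature: Some(signature),
            ..
        } = &part
        {
            assert_eq!(text, "Let me think...");
            assert_eq!(signature, "sig-abc");
        } else {
            panic!("Expected thought part with signature");
        }
    }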
}