// rig — Gemini provider: streaming completion support.
// File: rig/providers/gemini/streaming.rs
1use async_stream::stream;
2use futures::StreamExt;
3use serde::{Deserialize, Serialize};
4use tracing::{Level, enabled, info_span};
5use tracing_futures::Instrument;
6
7use super::completion::gemini_api_types::{ContentCandidate, Part, PartKind};
8use super::completion::{
9    CompletionModel, create_request_body, resolve_request_model, streaming_endpoint,
10};
11use crate::completion::message::ReasoningContent;
12use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
13use crate::http_client::HttpClientExt;
14use crate::http_client::sse::{Event, GenericEventSource};
15use crate::streaming;
16use crate::telemetry::SpanCombinator;
17
/// Subset of Gemini's `usageMetadata` payload that can appear on streaming
/// chunks. Field names follow the wire format (camelCase via serde rename).
#[derive(Debug, Deserialize, Serialize, Default, Clone)]
#[serde(rename_all = "camelCase")]
pub struct PartialUsage {
    // Total token count for the exchange as reported by the API.
    pub total_token_count: i32,
    // Tokens served from cached content; omitted from serialization when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_content_token_count: Option<i32>,
    // Tokens in the generated candidate(s); omitted from serialization when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidates_token_count: Option<i32>,
    // Tokens spent on model "thinking" output; omitted from serialization when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thoughts_token_count: Option<i32>,
    // Tokens in the prompt (input side).
    pub prompt_token_count: i32,
}
30
31impl GetTokenUsage for PartialUsage {
32    fn token_usage(&self) -> Option<crate::completion::Usage> {
33        let mut usage = crate::completion::Usage::new();
34
35        usage.input_tokens = self.prompt_token_count as u64;
36        usage.output_tokens = (self.cached_content_token_count.unwrap_or_default()
37            + self.candidates_token_count.unwrap_or_default()
38            + self.thoughts_token_count.unwrap_or_default()) as u64;
39        usage.total_tokens = usage.input_tokens + usage.output_tokens;
40
41        Some(usage)
42    }
43}
44
/// One decoded chunk of a Gemini `streamGenerateContent` SSE response.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct StreamGenerateContentResponse {
    /// Candidate responses from the model.
    pub candidates: Vec<ContentCandidate>,
    // Model version string, when the API includes one on the chunk.
    pub model_version: Option<String>,
    // Usage metadata; the stream loop below reads it from the chunk that
    // carries a finish reason.
    pub usage_metadata: Option<PartialUsage>,
}
53
/// Final aggregate payload yielded as the `FinalResponse` item when a Gemini
/// completion stream ends.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct StreamingCompletionResponse {
    // Usage captured from the finishing chunk, or `PartialUsage::default()`
    // if the stream ended without reporting usage.
    pub usage_metadata: PartialUsage,
}
58
59impl GetTokenUsage for StreamingCompletionResponse {
60    fn token_usage(&self) -> Option<crate::completion::Usage> {
61        let mut usage = crate::completion::Usage::new();
62        usage.total_tokens = self.usage_metadata.total_token_count as u64;
63        usage.output_tokens = self
64            .usage_metadata
65            .candidates_token_count
66            .map(|x| x as u64)
67            .unwrap_or(0);
68        usage.input_tokens = self.usage_metadata.prompt_token_count as u64;
69        Some(usage)
70    }
71}
72
impl<T> CompletionModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
    /// Sends a streaming completion request to the Gemini API and returns a
    /// stream of decoded partial responses.
    ///
    /// Text, thought ("reasoning"), and function-call parts of each SSE chunk
    /// are translated into the matching `RawStreamingChoice` variants; a
    /// `FinalResponse` carrying the captured usage metadata is always yielded
    /// once the stream ends.
    ///
    /// # Errors
    /// Returns `CompletionError` if the request body cannot be serialized or
    /// the HTTP request cannot be built. Mid-stream transport errors are
    /// surfaced as `Err` items on the returned stream; per-chunk JSON parse
    /// failures are logged and skipped rather than terminating the stream.
    pub(crate) async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
    {
        let request_model = resolve_request_model(&self.model, &completion_request);
        // Reuse the caller's span when one is active; otherwise open a fresh
        // GenAI-convention span so telemetry is always attached to the stream.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "gcp.gemini",
                gen_ai.request.model = &request_model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = &request_model,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cached_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };
        let request = create_request_body(completion_request)?;

        // Guard the pretty-print serialization behind the TRACE check so the
        // cost is only paid when the trace level is actually enabled.
        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::streaming",
                "Gemini streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post_sse(streaming_endpoint(&request_model))?
            .header("Content-Type", "application/json")
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        let mut event_source = GenericEventSource::new(self.client.clone(), req);

        let stream = stream! {
            // Usage from the finishing chunk, if one is observed.
            let mut final_usage = None;
            while let Some(event_result) = event_source.next().await {
                match event_result {
                    Ok(Event::Open) => {
                        tracing::debug!("SSE connection opened");
                        continue;
                    }
                    Ok(Event::Message(message)) => {
                        // Skip heartbeat messages or empty data
                        if message.data.trim().is_empty() {
                            continue;
                        }

                        // A malformed chunk is logged and skipped; it does not
                        // terminate the stream.
                        let data = match serde_json::from_str::<StreamGenerateContentResponse>(&message.data) {
                            Ok(d) => d,
                            Err(error) => {
                                tracing::error!(?error, message = message.data, "Failed to parse SSE message");
                                continue;
                            }
                        };

                        // Process the response data. Only the first candidate
                        // is consumed; any additional candidates are ignored.
                        let Some(choice) = data.candidates.into_iter().next() else {
                            tracing::debug!("There is no content candidate");
                            continue;
                        };

                        let Some(content) = choice.content else {
                            tracing::debug!(finish_reason = ?choice.finish_reason, "Streaming candidate missing content");
                            continue;
                        };

                        if content.parts.is_empty() {
                            tracing::trace!(reason = ?choice.finish_reason, "There is no part in the streaming content");
                        }

                        for part in content.parts {
                            match part {
                                // Thought (reasoning) text. Must be matched
                                // before the plain-text arm below, which would
                                // otherwise also match `thought: Some(true)`.
                                Part {
                                    part: PartKind::Text(text),
                                    thought: Some(true),
                                    thought_signature,
                                    ..
                                } => {
                                    if !text.is_empty() {
                                        if thought_signature.is_some() {
                                            // Signature arrives on the final chunk of a
                                            // thinking block; emit a full Reasoning so the
                                            // core accumulator captures the signature for
                                            // Gemini 3+ roundtrip.
                                            yield Ok(streaming::RawStreamingChoice::Reasoning {
                                                id: None,
                                                content: ReasoningContent::Text {
                                                    text,
                                                    signature: thought_signature,
                                                },
                                            });
                                        } else {
                                            yield Ok(streaming::RawStreamingChoice::ReasoningDelta {
                                                id: None,
                                                reasoning: text,
                                            });
                                        }
                                    }
                                },
                                // Plain assistant text; empty deltas are dropped.
                                Part {
                                    part: PartKind::Text(text),
                                    ..
                                } => {
                                    if !text.is_empty() {
                                        yield Ok(streaming::RawStreamingChoice::Message(text));
                                    }
                                },
                                Part {
                                    part: PartKind::FunctionCall(function_call),
                                    thought_signature,
                                    ..
                                } => {
                                    // NOTE(review): the function name is passed as both the
                                    // call id and the name — presumably because Gemini does
                                    // not supply a distinct call id here; confirm.
                                    yield Ok(streaming::RawStreamingChoice::ToolCall(
                                        streaming::RawStreamingToolCall::new(function_call.name.clone(), function_call.name.clone(), function_call.args.clone())
                                            .with_signature(thought_signature)
                                    ));
                                },
                                part => {
                                    tracing::warn!(?part, "Unsupported response type with streaming");
                                }
                            }
                        }

                        // Check if this is the final response
                        if choice.finish_reason.is_some() {
                            // Record token usage on the active span before
                            // capturing it for the FinalResponse item.
                            let span = tracing::Span::current();
                            span.record_token_usage(&data.usage_metadata);
                            final_usage = data.usage_metadata;
                            break;
                        }
                    }
                    // Normal end-of-stream from the transport: stop quietly.
                    Err(crate::http_client::Error::StreamEnded) => {
                        break;
                    }
                    Err(error) => {
                        tracing::error!(?error, "SSE error");
                        yield Err(CompletionError::ProviderError(error.to_string()));
                        break;
                    }
                }
            }

            // Ensure event source is closed when stream ends
            event_source.close();

            // Always emit a FinalResponse, defaulting usage to zeroes if the
            // stream ended before a finishing chunk was seen.
            yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
                usage_metadata: final_usage.unwrap_or_default()
            }));
        }.instrument(span);

        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
            stream,
        )))
    }
}
243
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // A single text part deserializes into one candidate with one Text part.
    #[test]
    fn test_deserialize_stream_response_with_single_text_part() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "Hello, world!"}
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 5,
                "totalTokenCount": 15
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        assert_eq!(response.candidates.len(), 1);
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 1);

        if let Part {
            part: PartKind::Text(text),
            ..
        } = &content.parts[0]
        {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected text part");
        }
    }

    // Multiple text parts in one chunk are preserved in order.
    #[test]
    fn test_deserialize_stream_response_with_multiple_text_parts() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "Hello, "},
                        {"text": "world!"},
                        {"text": " How are you?"}
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 8,
                "totalTokenCount": 18
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        assert_eq!(response.candidates.len(), 1);
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 3);

        // Verify all three text parts are present
        for (i, expected_text) in ["Hello, ", "world!", " How are you?"].iter().enumerate() {
            if let Part {
                part: PartKind::Text(text),
                ..
            } = &content.parts[i]
            {
                assert_eq!(text, expected_text);
            } else {
                panic!("Expected text part at index {}", i);
            }
        }
    }

    // Two functionCall parts in one chunk both deserialize with their names.
    #[test]
    fn test_deserialize_stream_response_with_multiple_tool_calls() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {
                            "functionCall": {
                                "name": "get_weather",
                                "args": {"city": "San Francisco"}
                            }
                        },
                        {
                            "functionCall": {
                                "name": "get_temperature",
                                "args": {"location": "New York"}
                            }
                        }
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 50,
                "candidatesTokenCount": 20,
                "totalTokenCount": 70
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 2);

        // Verify first tool call
        if let Part {
            part: PartKind::FunctionCall(call),
            ..
        } = &content.parts[0]
        {
            assert_eq!(call.name, "get_weather");
        } else {
            panic!("Expected function call at index 0");
        }

        // Verify second tool call
        if let Part {
            part: PartKind::FunctionCall(call),
            ..
        } = &content.parts[1]
        {
            assert_eq!(call.name, "get_temperature");
        } else {
            panic!("Expected function call at index 1");
        }
    }

    // A chunk mixing thought text, plain text, and a function call keeps each
    // part's kind and the `thought` flag distinct.
    #[test]
    fn test_deserialize_stream_response_with_mixed_parts() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {
                            "text": "Let me think about this...",
                            "thought": true
                        },
                        {
                            "text": "Here's my response: "
                        },
                        {
                            "functionCall": {
                                "name": "search",
                                "args": {"query": "rust async"}
                            }
                        },
                        {
                            "text": "I found the answer!"
                        }
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 100,
                "candidatesTokenCount": 50,
                "thoughtsTokenCount": 15,
                "totalTokenCount": 165
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        let parts = &content.parts;
        assert_eq!(parts.len(), 4);

        // Verify reasoning (thought) part
        if let Part {
            part: PartKind::Text(text),
            thought: Some(true),
            ..
        } = &parts[0]
        {
            assert_eq!(text, "Let me think about this...");
        } else {
            panic!("Expected thought part at index 0");
        }

        // Verify regular text
        if let Part {
            part: PartKind::Text(text),
            thought,
            ..
        } = &parts[1]
        {
            assert_eq!(text, "Here's my response: ");
            assert!(thought.is_none() || thought == &Some(false));
        } else {
            panic!("Expected text part at index 1");
        }

        // Verify tool call
        if let Part {
            part: PartKind::FunctionCall(call),
            ..
        } = &parts[2]
        {
            assert_eq!(call.name, "search");
        } else {
            panic!("Expected function call at index 2");
        }

        // Verify final text
        if let Part {
            part: PartKind::Text(text),
            ..
        } = &parts[3]
        {
            assert_eq!(text, "I found the answer!");
        } else {
            panic!("Expected text part at index 3");
        }
    }

    // An empty `parts` array is valid and yields zero parts.
    #[test]
    fn test_deserialize_stream_response_with_empty_parts() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 0,
                "totalTokenCount": 10
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 0);
    }

    // PartialUsage::token_usage sums cached + candidates + thoughts as output.
    #[test]
    fn test_partial_usage_token_calculation() {
        let usage = PartialUsage {
            total_token_count: 100,
            cached_content_token_count: Some(20),
            candidates_token_count: Some(30),
            thoughts_token_count: Some(10),
            prompt_token_count: 40,
        };

        let token_usage = usage.token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 40);
        assert_eq!(token_usage.output_tokens, 60); // 20 + 30 + 10
        assert_eq!(token_usage.total_tokens, 100);
    }

    // Missing optional counters are treated as zero.
    #[test]
    fn test_partial_usage_with_missing_counts() {
        let usage = PartialUsage {
            total_token_count: 50,
            cached_content_token_count: None,
            candidates_token_count: Some(30),
            thoughts_token_count: None,
            prompt_token_count: 20,
        };

        let token_usage = usage.token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 20);
        assert_eq!(token_usage.output_tokens, 30); // Only candidates_token_count
        assert_eq!(token_usage.total_tokens, 50);
    }

    // StreamingCompletionResponse::token_usage reads total directly and counts
    // only candidates as output.
    #[test]
    fn test_streaming_completion_response_token_usage() {
        let response = StreamingCompletionResponse {
            usage_metadata: PartialUsage {
                total_token_count: 150,
                cached_content_token_count: None,
                candidates_token_count: Some(75),
                thoughts_token_count: None,
                prompt_token_count: 75,
            },
        };

        let token_usage = response.token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 75);
        assert_eq!(token_usage.output_tokens, 75);
        assert_eq!(token_usage.total_tokens, 150);
    }
}