rig/providers/gemini/streaming.rs

use async_stream::stream;
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use tracing::{Level, enabled, info_span};
use tracing_futures::Instrument;

use super::completion::gemini_api_types::{ContentCandidate, Part, PartKind};
use super::completion::{
    CompletionModel, create_request_body, resolve_request_model, streaming_endpoint,
};
use crate::completion::message::ReasoningContent;
use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
use crate::http_client::HttpClientExt;
use crate::http_client::sse::{Event, GenericEventSource};
use crate::streaming;
use crate::telemetry::SpanCombinator;

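/// Gemini `usageMetadata` as it appears on streaming chunks. Counts other
/// than `totalTokenCount` and `promptTokenCount` may be absent.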
#[derive(Debug, Deserialize, Serialize, Default, Clone)]
#[serde(rename_all = "camelCase")]
pub struct PartialUsage {
    pub total_token_count: i32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_content_token_count: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidates_token_count: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thoughts_token_count: Option<i32>,
    pub prompt_token_count: i32,
}

impl GetTokenUsage for PartialUsage {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();

        usage.input_tokens = self.prompt_token_count as u64;
        // Aggregate candidate, cached-content, and thought tokens into the
        // output count; any missing count contributes zero.
        usage.output_tokens = (self.cached_content_token_count.unwrap_or_default()
            + self.candidates_token_count.unwrap_or_default()
            + self.thoughts_token_count.unwrap_or_default()) as u64;
        usage.total_tokens = usage.input_tokens + usage.output_tokens;

        Some(usage)
    }
}

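/// A single chunk of the `streamGenerateContent` SSE stream.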
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct StreamGenerateContentResponse {
    /// Candidate responses from the model.
    pub candidates: Vec<ContentCandidate>,
    /// Version of the model that produced this chunk, when reported.
    pub model_version: Option<String>,
    /// Token usage; read on the chunk that carries a finish reason.
    pub usage_metadata: Option<PartialUsage>,
}

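/// Final payload surfaced to callers via `RawStreamingChoice::FinalResponse`
/// once the stream ends.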
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct StreamingCompletionResponse {
    pub usage_metadata: PartialUsage,
}

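// Note: unlike the `PartialUsage` impl above, which sums cached, candidate,
// and thought tokens into the output count, this impl reports
// `candidates_token_count` alone and trusts the API-provided total.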
impl GetTokenUsage for StreamingCompletionResponse {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();
        usage.total_tokens = self.usage_metadata.total_token_count as u64;
        usage.output_tokens = self
            .usage_metadata
            .candidates_token_count
            .map(|x| x as u64)
            .unwrap_or(0);
        usage.input_tokens = self.usage_metadata.prompt_token_count as u64;
        Some(usage)
    }
}

impl<T> CompletionModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
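    /// Opens an SSE connection to the Gemini streaming endpoint and converts
    /// each `StreamGenerateContentResponse` chunk into `RawStreamingChoice`
    /// items: plain text, reasoning (with thought signatures), tool calls,
    /// and a final usage payload.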
    pub(crate) async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
    {
        let request_model = resolve_request_model(&self.model, &completion_request);
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "gcp.gemini",
                gen_ai.request.model = &request_model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = &request_model,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };
        let request = create_request_body(completion_request)?;

        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::streaming",
                "Gemini streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post_sse(streaming_endpoint(&request_model))?
            .header("Content-Type", "application/json")
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        let mut event_source = GenericEventSource::new(self.client.clone(), req);

        let stream = stream! {
            let mut final_usage = None;
            while let Some(event_result) = event_source.next().await {
                match event_result {
                    Ok(Event::Open) => {
                        tracing::debug!("SSE connection opened");
                        continue;
                    }
                    Ok(Event::Message(message)) => {
                        // Skip heartbeat messages or empty data
                        if message.data.trim().is_empty() {
                            continue;
                        }

                        let data = match serde_json::from_str::<StreamGenerateContentResponse>(&message.data) {
                            Ok(d) => d,
                            Err(error) => {
                                tracing::error!(?error, message = message.data, "Failed to parse SSE message");
                                continue;
                            }
                        };

                        // Process the response data
                        let Some(choice) = data.candidates.into_iter().next() else {
                            tracing::debug!("Streaming chunk contained no candidates");
                            continue;
                        };

                        let Some(content) = choice.content else {
                            tracing::debug!(finish_reason = ?choice.finish_reason, "Streaming candidate missing content");
                            continue;
                        };

                        if content.parts.is_empty() {
                            tracing::trace!(reason = ?choice.finish_reason, "Streaming content contained no parts");
                        }

                        for part in content.parts {
                            match part {
                                Part {
                                    part: PartKind::Text(text),
                                    thought: Some(true),
                                    thought_signature,
                                    ..
                                } => {
                                    if !text.is_empty() {
                                        if thought_signature.is_some() {
                                            // Signature arrives on the final chunk of a
                                            // thinking block; emit a full Reasoning so the
                                            // core accumulator captures the signature for
                                            // Gemini 3+ roundtrip.
                                            yield Ok(streaming::RawStreamingChoice::Reasoning {
                                                id: None,
                                                content: ReasoningContent::Text {
                                                    text,
                                                    signature: thought_signature,
                                                },
                                            });
                                        } else {
                                            yield Ok(streaming::RawStreamingChoice::ReasoningDelta {
                                                id: None,
                                                reasoning: text,
                                            });
                                        }
                                    }
                                },
                                Part {
                                    part: PartKind::Text(text),
                                    ..
                                } => {
                                    if !text.is_empty() {
                                        yield Ok(streaming::RawStreamingChoice::Message(text));
                                    }
                                },
                                Part {
                                    part: PartKind::FunctionCall(function_call),
                                    thought_signature,
                                    ..
                                } => {
                                    // Gemini does not return a tool-call id, so the
                                    // function name doubles as the call id here.
                                    yield Ok(streaming::RawStreamingChoice::ToolCall(
                                        streaming::RawStreamingToolCall::new(function_call.name.clone(), function_call.name.clone(), function_call.args.clone())
                                            .with_signature(thought_signature)
                                    ));
                                },
                                part => {
                                    tracing::warn!(?part, "Unsupported part type in streaming response");
                                }
                            }
                        }

                        // A finish reason marks the final content chunk
                        if choice.finish_reason.is_some() {
                            let span = tracing::Span::current();
                            span.record_token_usage(&data.usage_metadata);
                            final_usage = data.usage_metadata;
                            break;
                        }
                    }
                    Err(crate::http_client::Error::StreamEnded) => {
                        break;
                    }
                    Err(error) => {
                        tracing::error!(?error, "SSE error");
                        yield Err(CompletionError::ProviderError(error.to_string()));
                        break;
                    }
                }
            }

            // Ensure event source is closed when stream ends
            event_source.close();

            yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
                usage_metadata: final_usage.unwrap_or_default()
            }));
        }.instrument(span);

        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
            stream,
        )))
    }
}
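
// Rough shape of the raw stream produced above (illustrative summary, not an
// exhaustive contract; `stream` is crate-internal and is normally reached
// through rig's public streaming entry points):
//
//     Message(text)               plain content deltas
//     ReasoningDelta / Reasoning  thinking text; a full Reasoning carries the
//                                 thought signature from the block's last chunk
//     ToolCall(..)                function calls requested by the model
//     FinalResponse(..)           usage metadata, yielded once at the end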

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_deserialize_stream_response_with_single_text_part() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "Hello, world!"}
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 5,
                "totalTokenCount": 15
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        assert_eq!(response.candidates.len(), 1);
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 1);

        if let Part {
            part: PartKind::Text(text),
            ..
        } = &content.parts[0]
        {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected text part");
        }
    }

    #[test]
    fn test_deserialize_stream_response_with_multiple_text_parts() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "Hello, "},
                        {"text": "world!"},
                        {"text": " How are you?"}
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 8,
                "totalTokenCount": 18
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        assert_eq!(response.candidates.len(), 1);
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 3);

        // Verify all three text parts are present
        for (i, expected_text) in ["Hello, ", "world!", " How are you?"].iter().enumerate() {
            if let Part {
                part: PartKind::Text(text),
                ..
            } = &content.parts[i]
            {
                assert_eq!(text, expected_text);
            } else {
                panic!("Expected text part at index {}", i);
            }
        }
    }

    #[test]
    fn test_deserialize_stream_response_with_multiple_tool_calls() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {
                            "functionCall": {
                                "name": "get_weather",
                                "args": {"city": "San Francisco"}
                            }
                        },
                        {
                            "functionCall": {
                                "name": "get_temperature",
                                "args": {"location": "New York"}
                            }
                        }
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 50,
                "candidatesTokenCount": 20,
                "totalTokenCount": 70
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 2);

        // Verify first tool call
        if let Part {
            part: PartKind::FunctionCall(call),
            ..
        } = &content.parts[0]
        {
            assert_eq!(call.name, "get_weather");
        } else {
            panic!("Expected function call at index 0");
        }

        // Verify second tool call
        if let Part {
            part: PartKind::FunctionCall(call),
            ..
        } = &content.parts[1]
        {
            assert_eq!(call.name, "get_temperature");
        } else {
            panic!("Expected function call at index 1");
        }
    }

    #[test]
    fn test_deserialize_stream_response_with_mixed_parts() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {
                            "text": "Let me think about this...",
                            "thought": true
                        },
                        {
                            "text": "Here's my response: "
                        },
                        {
                            "functionCall": {
                                "name": "search",
                                "args": {"query": "rust async"}
                            }
                        },
                        {
                            "text": "I found the answer!"
                        }
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 100,
                "candidatesTokenCount": 50,
                "thoughtsTokenCount": 15,
                "totalTokenCount": 165
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        let parts = &content.parts;
        assert_eq!(parts.len(), 4);

        // Verify reasoning (thought) part
        if let Part {
            part: PartKind::Text(text),
            thought: Some(true),
            ..
        } = &parts[0]
        {
            assert_eq!(text, "Let me think about this...");
        } else {
            panic!("Expected thought part at index 0");
        }

        // Verify regular text
        if let Part {
            part: PartKind::Text(text),
            thought,
            ..
        } = &parts[1]
        {
            assert_eq!(text, "Here's my response: ");
            assert!(thought.is_none() || thought == &Some(false));
        } else {
            panic!("Expected text part at index 1");
        }

        // Verify tool call
        if let Part {
            part: PartKind::FunctionCall(call),
            ..
        } = &parts[2]
        {
            assert_eq!(call.name, "search");
        } else {
            panic!("Expected function call at index 2");
        }

        // Verify final text
        if let Part {
            part: PartKind::Text(text),
            ..
        } = &parts[3]
        {
            assert_eq!(text, "I found the answer!");
        } else {
            panic!("Expected text part at index 3");
        }
    }

    #[test]
    fn test_deserialize_stream_response_with_empty_parts() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 0,
                "totalTokenCount": 10
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 0);
    }

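    // `usageMetadata` and `finishReason` are both optional on intermediate
    // chunks (the stream loop reads `finish_reason` as an `Option`), so a
    // bare-content chunk must still deserialize cleanly.
    #[test]
    fn test_deserialize_stream_response_without_usage_metadata() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [{"text": "partial"}],
                    "role": "model"
                },
                "index": 0
            }]
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        assert!(response.usage_metadata.is_none());
        assert!(response.model_version.is_none());
        assert_eq!(response.candidates.len(), 1);
    }
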
    #[test]
    fn test_partial_usage_token_calculation() {
        let usage = PartialUsage {
            total_token_count: 100,
            cached_content_token_count: Some(20),
            candidates_token_count: Some(30),
            thoughts_token_count: Some(10),
            prompt_token_count: 40,
        };

        let token_usage = usage.token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 40);
        assert_eq!(token_usage.output_tokens, 60); // 20 + 30 + 10
        assert_eq!(token_usage.total_tokens, 100);
    }

    #[test]
    fn test_partial_usage_with_missing_counts() {
        let usage = PartialUsage {
            total_token_count: 50,
            cached_content_token_count: None,
            candidates_token_count: Some(30),
            thoughts_token_count: None,
            prompt_token_count: 20,
        };

        let token_usage = usage.token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 20);
        assert_eq!(token_usage.output_tokens, 30); // Only candidates_token_count
        assert_eq!(token_usage.total_tokens, 50);
    }

    #[test]
    fn test_streaming_completion_response_token_usage() {
        let response = StreamingCompletionResponse {
            usage_metadata: PartialUsage {
                total_token_count: 150,
                cached_content_token_count: None,
                candidates_token_count: Some(75),
                thoughts_token_count: None,
                prompt_token_count: 75,
            },
        };

        let token_usage = response.token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 75);
        assert_eq!(token_usage.output_tokens, 75);
        assert_eq!(token_usage.total_tokens, 150);
    }
559}