//! Streaming completion support for the Gemini provider
//! (`rig/providers/gemini/streaming.rs`).
use async_stream::stream;
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use tracing::{Level, enabled, info_span};
use tracing_futures::Instrument;

use super::completion::gemini_api_types::{ContentCandidate, Part, PartKind};
use super::completion::{
    CompletionModel, create_request_body, resolve_request_model, streaming_endpoint,
};
use crate::completion::message::ReasoningContent;
use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
use crate::http_client::HttpClientExt;
use crate::http_client::sse::{Event, GenericEventSource};
use crate::streaming;
use crate::telemetry::SpanCombinator;
17
/// Subset of Gemini's `usageMetadata` object that can arrive on streaming
/// chunks. All counters are camelCase on the wire.
#[derive(Debug, Deserialize, Serialize, Default, Clone)]
#[serde(rename_all = "camelCase")]
pub struct PartialUsage {
    /// Total token count reported by Gemini for the request.
    pub total_token_count: i32,
    /// Tokens served from cached content, when reported.
    /// NOTE(review): `token_usage` below sums this into *output* tokens —
    /// confirm cached tokens really belong on the output side.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_content_token_count: Option<i32>,
    /// Tokens in the generated candidates, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidates_token_count: Option<i32>,
    /// Tokens spent on model "thoughts" (reasoning), when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thoughts_token_count: Option<i32>,
    /// Prompt (input) token count; defaults to 0 when the field is absent.
    #[serde(default)]
    pub prompt_token_count: i32,
}
31
32impl GetTokenUsage for PartialUsage {
33    fn token_usage(&self) -> Option<crate::completion::Usage> {
34        let mut usage = crate::completion::Usage::new();
35
36        usage.input_tokens = self.prompt_token_count as u64;
37        usage.output_tokens = (self.cached_content_token_count.unwrap_or_default()
38            + self.candidates_token_count.unwrap_or_default()
39            + self.thoughts_token_count.unwrap_or_default()) as u64;
40        usage.total_tokens = usage.input_tokens + usage.output_tokens;
41
42        Some(usage)
43    }
44}
45
/// One decoded SSE chunk from Gemini's `streamGenerateContent` endpoint.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct StreamGenerateContentResponse {
    /// Candidate responses from the model.
    pub candidates: Vec<ContentCandidate>,
    /// Model version echoed by the API, when present.
    pub model_version: Option<String>,
    /// Usage counters; `stream` reads this from the chunk that carries a
    /// finish reason.
    pub usage_metadata: Option<PartialUsage>,
}
54
/// Final payload yielded once the Gemini stream completes: the usage
/// metadata captured from the last chunk (defaulted when none arrived).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct StreamingCompletionResponse {
    /// Usage counters from the final streaming chunk.
    pub usage_metadata: PartialUsage,
}
59
60impl GetTokenUsage for StreamingCompletionResponse {
61    fn token_usage(&self) -> Option<crate::completion::Usage> {
62        let mut usage = crate::completion::Usage::new();
63        usage.total_tokens = self.usage_metadata.total_token_count as u64;
64        usage.output_tokens = self
65            .usage_metadata
66            .candidates_token_count
67            .map(|x| x as u64)
68            .unwrap_or(0);
69        usage.input_tokens = self.usage_metadata.prompt_token_count as u64;
70        Some(usage)
71    }
72}
73
impl<T> CompletionModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
    /// Opens a streaming (SSE) completion against Gemini and adapts each
    /// server chunk into [`streaming::RawStreamingChoice`] items.
    ///
    /// Emitted items:
    /// - `Reasoning` / `ReasoningDelta` for thought parts — a full
    ///   `Reasoning` (with signature) is emitted only when the chunk carries
    ///   a `thought_signature`,
    /// - `Message` for plain text parts,
    /// - `ToolCall` for function-call parts,
    /// - one closing `FinalResponse` carrying the usage metadata captured
    ///   from the finishing chunk (defaults when none was seen).
    ///
    /// # Errors
    /// Returns `CompletionError` when the request body cannot be built or
    /// serialized, or when the HTTP request cannot be constructed. Failures
    /// *during* streaming are yielded as `Err` items inside the stream.
    pub(crate) async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
    {
        let request_model = resolve_request_model(&self.model, &completion_request);
        // Reuse the caller's span when one is active; otherwise open a fresh
        // `chat_streaming` span whose usage/response fields start Empty and
        // are recorded from the final chunk.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "gcp.gemini",
                gen_ai.request.model = &request_model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = &request_model,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cached_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };
        let request = create_request_body(completion_request)?;

        // Pretty-printing the request body is only worth the cost when TRACE
        // logging is actually enabled.
        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::streaming",
                "Gemini streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post_sse(streaming_endpoint(&request_model))?
            .header("Content-Type", "application/json")
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        let mut event_source = GenericEventSource::new(self.client.clone(), req);

        let stream = stream! {
            // Usage metadata from the chunk carrying the finish reason.
            let mut final_usage = None;
            while let Some(event_result) = event_source.next().await {
                match event_result {
                    Ok(Event::Open) => {
                        tracing::debug!("SSE connection opened");
                        continue;
                    }
                    Ok(Event::Message(message)) => {
                        // Skip heartbeat messages or empty data
                        if message.data.trim().is_empty() {
                            continue;
                        }

                        // A malformed chunk is logged and skipped rather than
                        // aborting the whole stream.
                        let data = match serde_json::from_str::<StreamGenerateContentResponse>(&message.data) {
                            Ok(d) => d,
                            Err(error) => {
                                tracing::error!(?error, message = message.data, "Failed to parse SSE message");
                                continue;
                            }
                        };

                        // Process the response data
                        // Only the first candidate is consumed; any further
                        // candidates on the chunk are dropped here.
                        let Some(choice) = data.candidates.into_iter().next() else {
                            tracing::debug!("There is no content candidate");
                            continue;
                        };

                        let Some(content) = choice.content else {
                            tracing::debug!(finish_reason = ?choice.finish_reason, "Streaming candidate missing content");
                            continue;
                        };

                        if content.parts.is_empty() {
                            tracing::trace!(reason = ?choice.finish_reason, "There is no part in the streaming content");
                        }

                        for part in content.parts {
                            match part {
                                // Thought (reasoning) text. Must be matched
                                // before the plain-text arm below, which
                                // would otherwise swallow it.
                                Part {
                                    part: PartKind::Text(text),
                                    thought: Some(true),
                                    thought_signature,
                                    ..
                                } => {
                                    if !text.is_empty() {
                                        if thought_signature.is_some() {
                                            // Signature arrives on the final chunk of a
                                            // thinking block; emit a full Reasoning so the
                                            // core accumulator captures the signature for
                                            // Gemini 3+ roundtrip.
                                            yield Ok(streaming::RawStreamingChoice::Reasoning {
                                                id: None,
                                                content: ReasoningContent::Text {
                                                    text,
                                                    signature: thought_signature,
                                                },
                                            });
                                        } else {
                                            yield Ok(streaming::RawStreamingChoice::ReasoningDelta {
                                                id: None,
                                                reasoning: text,
                                            });
                                        }
                                    }
                                },
                                // Plain (non-thought, non-empty) text.
                                Part {
                                    part: PartKind::Text(text),
                                    ..
                                } => {
                                    if !text.is_empty() {
                                        yield Ok(streaming::RawStreamingChoice::Message(text));
                                    }
                                },
                                // NOTE(review): the function name doubles as
                                // the tool-call id — confirm Gemini supplies
                                // no separate call id.
                                Part {
                                    part: PartKind::FunctionCall(function_call),
                                    thought_signature,
                                    ..
                                } => {
                                    yield Ok(streaming::RawStreamingChoice::ToolCall(
                                        streaming::RawStreamingToolCall::new(function_call.name.clone(), function_call.name.clone(), function_call.args.clone())
                                            .with_signature(thought_signature)
                                    ));
                                },
                                part => {
                                    tracing::warn!(?part, "Unsupported response type with streaming");
                                }
                            }
                        }

                        // Check if this is the final response
                        if choice.finish_reason.is_some() {
                            // Record usage on the active span before leaving
                            // the read loop.
                            let span = tracing::Span::current();
                            span.record_token_usage(&data.usage_metadata);
                            final_usage = data.usage_metadata;
                            break;
                        }
                    }
                    Err(crate::http_client::Error::StreamEnded) => {
                        // Normal end-of-stream: fall through to the final yield.
                        break;
                    }
                    Err(error) => {
                        tracing::error!(?error, "SSE error");
                        yield Err(CompletionError::ProviderError(error.to_string()));
                        break;
                    }
                }
            }

            // Ensure event source is closed when stream ends
            event_source.close();

            yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
                usage_metadata: final_usage.unwrap_or_default()
            }));
        }.instrument(span);

        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
            stream,
        )))
    }
}
244
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Returns the parts of the first candidate, panicking when the
    /// candidate carries no content.
    fn candidate_parts(response: &StreamGenerateContentResponse) -> &[Part] {
        &response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content")
            .parts
    }

    /// Extracts the text payload of a part, panicking on any other kind.
    fn expect_text(part: &Part) -> &str {
        match &part.part {
            PartKind::Text(text) => text,
            other => panic!("expected text part, got {other:?}"),
        }
    }

    /// Extracts the function-call name of a part, panicking on any other kind.
    fn expect_function_call_name(part: &Part) -> &str {
        match &part.part {
            PartKind::FunctionCall(call) => &call.name,
            other => panic!("expected function call part, got {other:?}"),
        }
    }

    #[test]
    fn test_deserialize_stream_response_with_single_text_part() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "Hello, world!"}
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 5,
                "totalTokenCount": 15
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        assert_eq!(response.candidates.len(), 1);
        let parts = candidate_parts(&response);
        assert_eq!(parts.len(), 1);
        assert_eq!(expect_text(&parts[0]), "Hello, world!");
    }

    #[test]
    fn test_deserialize_stream_response_with_multiple_text_parts() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "Hello, "},
                        {"text": "world!"},
                        {"text": " How are you?"}
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 8,
                "totalTokenCount": 18
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        assert_eq!(response.candidates.len(), 1);
        let parts = candidate_parts(&response);
        assert_eq!(parts.len(), 3);

        // Verify all three text parts are present, in order.
        let expected = ["Hello, ", "world!", " How are you?"];
        for (part, expected_text) in parts.iter().zip(expected) {
            assert_eq!(expect_text(part), expected_text);
        }
    }

    #[test]
    fn test_deserialize_stream_response_with_multiple_tool_calls() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {
                            "functionCall": {
                                "name": "get_weather",
                                "args": {"city": "San Francisco"}
                            }
                        },
                        {
                            "functionCall": {
                                "name": "get_temperature",
                                "args": {"location": "New York"}
                            }
                        }
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 50,
                "candidatesTokenCount": 20,
                "totalTokenCount": 70
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        let parts = candidate_parts(&response);
        assert_eq!(parts.len(), 2);
        assert_eq!(expect_function_call_name(&parts[0]), "get_weather");
        assert_eq!(expect_function_call_name(&parts[1]), "get_temperature");
    }

    #[test]
    fn test_deserialize_stream_response_with_mixed_parts() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {
                            "text": "Let me think about this...",
                            "thought": true
                        },
                        {
                            "text": "Here's my response: "
                        },
                        {
                            "functionCall": {
                                "name": "search",
                                "args": {"query": "rust async"}
                            }
                        },
                        {
                            "text": "I found the answer!"
                        }
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 100,
                "candidatesTokenCount": 50,
                "thoughtsTokenCount": 15,
                "totalTokenCount": 165
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        let parts = candidate_parts(&response);
        assert_eq!(parts.len(), 4);

        // Reasoning (thought) part.
        assert_eq!(parts[0].thought, Some(true));
        assert_eq!(expect_text(&parts[0]), "Let me think about this...");

        // Regular text must not be flagged as a thought.
        assert!(!parts[1].thought.unwrap_or(false));
        assert_eq!(expect_text(&parts[1]), "Here's my response: ");

        // Tool call.
        assert_eq!(expect_function_call_name(&parts[2]), "search");

        // Final text.
        assert_eq!(expect_text(&parts[3]), "I found the answer!");
    }

    #[test]
    fn test_deserialize_stream_response_with_empty_parts() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 0,
                "totalTokenCount": 10
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        assert!(candidate_parts(&response).is_empty());
    }

    #[test]
    fn test_partial_usage_token_calculation() {
        let usage = PartialUsage {
            total_token_count: 100,
            cached_content_token_count: Some(20),
            candidates_token_count: Some(30),
            thoughts_token_count: Some(10),
            prompt_token_count: 40,
        };

        let token_usage = usage.token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 40);
        assert_eq!(token_usage.output_tokens, 60); // 20 + 30 + 10
        assert_eq!(token_usage.total_tokens, 100);
    }

    #[test]
    fn test_partial_usage_with_missing_counts() {
        let usage = PartialUsage {
            total_token_count: 50,
            cached_content_token_count: None,
            candidates_token_count: Some(30),
            thoughts_token_count: None,
            prompt_token_count: 20,
        };

        let token_usage = usage.token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 20);
        assert_eq!(token_usage.output_tokens, 30); // Only candidates_token_count
        assert_eq!(token_usage.total_tokens, 50);
    }

    #[test]
    fn test_streaming_completion_response_token_usage() {
        let response = StreamingCompletionResponse {
            usage_metadata: PartialUsage {
                total_token_count: 150,
                cached_content_token_count: None,
                candidates_token_count: Some(75),
                thoughts_token_count: None,
                prompt_token_count: 75,
            },
        };

        let token_usage = response.token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 75);
        assert_eq!(token_usage.output_tokens, 75);
        assert_eq!(token_usage.total_tokens, 150);
    }
}