//! rig/providers/gemini/streaming.rs

1use async_stream::stream;
2use futures::StreamExt;
3use serde::{Deserialize, Serialize};
4use tracing::{Level, enabled, info_span};
5use tracing_futures::Instrument;
6
7use super::completion::gemini_api_types::{ContentCandidate, Part, PartKind};
8use super::completion::{CompletionModel, create_request_body};
9use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
10use crate::http_client::HttpClientExt;
11use crate::http_client::sse::{Event, GenericEventSource};
12use crate::streaming;
13use crate::telemetry::SpanCombinator;
14
/// Token-usage metadata as reported incrementally by the Gemini streaming API.
///
/// Field names map to the camelCase `usageMetadata` object in the provider's
/// streamed response chunks; optional counts are omitted on re-serialization.
#[derive(Debug, Deserialize, Serialize, Default, Clone)]
#[serde(rename_all = "camelCase")]
pub struct PartialUsage {
    // Provider-reported grand total across all token categories.
    pub total_token_count: i32,
    // Tokens served from cached content, when the provider reports them.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_content_token_count: Option<i32>,
    // Tokens in the generated candidate responses, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidates_token_count: Option<i32>,
    // Tokens spent on model "thinking", when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thoughts_token_count: Option<i32>,
    // Tokens in the input prompt.
    pub prompt_token_count: i32,
}
27
28impl GetTokenUsage for PartialUsage {
29    fn token_usage(&self) -> Option<crate::completion::Usage> {
30        let mut usage = crate::completion::Usage::new();
31
32        usage.input_tokens = self.prompt_token_count as u64;
33        usage.output_tokens = (self.cached_content_token_count.unwrap_or_default()
34            + self.candidates_token_count.unwrap_or_default()
35            + self.thoughts_token_count.unwrap_or_default()) as u64;
36        usage.total_tokens = usage.input_tokens + usage.output_tokens;
37
38        Some(usage)
39    }
40}
41
/// One chunk of Gemini's `streamGenerateContent` SSE response stream.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct StreamGenerateContentResponse {
    /// Candidate responses from the model.
    pub candidates: Vec<ContentCandidate>,
    // Model version string reported by the API, when present.
    pub model_version: Option<String>,
    // Usage metadata; read by the stream loop on the finishing chunk.
    pub usage_metadata: Option<PartialUsage>,
}
50
/// Final aggregate handed to callers once the Gemini stream completes.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct StreamingCompletionResponse {
    // Usage captured from the last streamed chunk; defaulted to zeros when
    // the stream ended without reporting usage.
    pub usage_metadata: PartialUsage,
}
55
56impl GetTokenUsage for StreamingCompletionResponse {
57    fn token_usage(&self) -> Option<crate::completion::Usage> {
58        let mut usage = crate::completion::Usage::new();
59        usage.total_tokens = self.usage_metadata.total_token_count as u64;
60        usage.output_tokens = self
61            .usage_metadata
62            .candidates_token_count
63            .map(|x| x as u64)
64            .unwrap_or(0);
65        usage.input_tokens = self.usage_metadata.prompt_token_count as u64;
66        Some(usage)
67    }
68}
69
impl<T> CompletionModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
    /// Streams a completion from Gemini's `streamGenerateContent` SSE endpoint.
    ///
    /// Text, reasoning ("thought") and tool-call parts are yielded as they
    /// arrive; once a candidate carries a `finish_reason`, the loop stops and a
    /// `FinalResponse` with the last-seen usage metadata is emitted.
    ///
    /// # Errors
    /// Fails early if the request body cannot be built or serialized, or if the
    /// HTTP request cannot be constructed; mid-stream failures are yielded as
    /// `Err` items on the returned stream instead of aborting this call.
    pub(crate) async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
    {
        // Reuse an already-active span; otherwise open a fresh GenAI-style
        // "chat_streaming" span so telemetry is recorded either way.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "gcp.gemini",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = self.model,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };
        let request = create_request_body(completion_request)?;

        // Guard the pretty-print behind TRACE so the serialization work only
        // happens when the log level actually captures it.
        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::streaming",
                "Gemini streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post_sse(format!(
                "/v1beta/models/{}:streamGenerateContent",
                self.model
            ))?
            .header("Content-Type", "application/json")
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        let mut event_source = GenericEventSource::new(self.client.clone(), req);

        let stream = stream! {
            // Usage metadata from the chunk that carried the finish_reason;
            // stays None if the stream ended without one.
            let mut final_usage = None;
            while let Some(event_result) = event_source.next().await {
                match event_result {
                    Ok(Event::Open) => {
                        tracing::debug!("SSE connection opened");
                        continue;
                    }
                    Ok(Event::Message(message)) => {
                        // Skip heartbeat messages or empty data
                        if message.data.trim().is_empty() {
                            continue;
                        }

                        // Malformed chunks are logged and skipped rather than
                        // terminating the whole stream.
                        let data = match serde_json::from_str::<StreamGenerateContentResponse>(&message.data) {
                            Ok(d) => d,
                            Err(error) => {
                                tracing::error!(?error, message = message.data, "Failed to parse SSE message");
                                continue;
                            }
                        };

                        // Process the response data
                        // Only the first candidate is consumed; any others in
                        // the chunk are ignored.
                        let Some(choice) = data.candidates.into_iter().next() else {
                            tracing::debug!("There is no content candidate");
                            continue;
                        };

                        let Some(content) = choice.content else {
                            tracing::debug!(finish_reason = ?choice.finish_reason, "Streaming candidate missing content");
                            continue;
                        };

                        if content.parts.is_empty() {
                            tracing::trace!(reason = ?choice.finish_reason, "There is no part in the streaming content");
                        }

                        for part in content.parts {
                            match part {
                                // Text flagged as a "thought" is surfaced as a
                                // reasoning delta rather than a message delta.
                                Part {
                                    part: PartKind::Text(text),
                                    thought: Some(true),
                                    ..
                                } => {
                                    if !text.is_empty() {
                                        yield Ok(streaming::RawStreamingChoice::ReasoningDelta {
                                            id: None,
                                            reasoning: text,
                                        });
                                    }
                                },
                                Part {
                                    part: PartKind::Text(text),
                                    ..
                                } => {
                                    if !text.is_empty() {
                                        yield Ok(streaming::RawStreamingChoice::Message(text));
                                    }
                                },
                                Part {
                                    part: PartKind::FunctionCall(function_call),
                                    thought_signature,
                                    ..
                                } => {
                                    // NOTE(review): the function name is passed
                                    // for both the first two arguments —
                                    // presumably id and name, with Gemini
                                    // supplying no separate call id; confirm
                                    // against RawStreamingToolCall::new.
                                    yield Ok(streaming::RawStreamingChoice::ToolCall(
                                        streaming::RawStreamingToolCall::new(function_call.name.clone(), function_call.name.clone(), function_call.args.clone())
                                            .with_signature(thought_signature)
                                    ));
                                },
                                part => {
                                    tracing::warn!(?part, "Unsupported response type with streaming");
                                }
                            }
                        }

                        // Check if this is the final response
                        if choice.finish_reason.is_some() {
                            let span = tracing::Span::current();
                            span.record_token_usage(&data.usage_metadata);
                            final_usage = data.usage_metadata;
                            break;
                        }
                    }
                    Err(crate::http_client::Error::StreamEnded) => {
                        break;
                    }
                    Err(error) => {
                        tracing::error!(?error, "SSE error");
                        yield Err(CompletionError::ProviderError(error.to_string()));
                        break;
                    }
                }
            }

            // Ensure event source is closed when stream ends
            event_source.close();

            // Always emit a final response; usage defaults to zeros if the
            // provider never reported any.
            yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
                usage_metadata: final_usage.unwrap_or_default()
            }));
        }.instrument(span);

        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
            stream,
        )))
    }
}
226
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Deserializes `payload` and returns the parts of the first candidate's
    /// content, panicking when the payload is malformed.
    fn first_candidate_parts(payload: serde_json::Value) -> Vec<Part> {
        let response: StreamGenerateContentResponse =
            serde_json::from_value(payload).expect("payload should deserialize");
        response
            .candidates
            .into_iter()
            .next()
            .expect("response should contain a candidate")
            .content
            .expect("candidate should contain content")
            .parts
    }

    /// Returns the text of a part, panicking when it is not a text part.
    fn expect_text(part: &Part) -> &str {
        match &part.part {
            PartKind::Text(text) => text,
            _ => panic!("expected a text part"),
        }
    }

    /// Returns the function name of a part, panicking when it is not a call.
    fn expect_function_name(part: &Part) -> &str {
        match &part.part {
            PartKind::FunctionCall(call) => &call.name,
            _ => panic!("expected a function call part"),
        }
    }

    #[test]
    fn test_deserialize_stream_response_with_single_text_part() {
        let payload = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "Hello, world!"}
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 5,
                "totalTokenCount": 15
            }
        });

        let parts = first_candidate_parts(payload);
        assert_eq!(parts.len(), 1);
        assert_eq!(expect_text(&parts[0]), "Hello, world!");
    }

    #[test]
    fn test_deserialize_stream_response_with_multiple_text_parts() {
        let payload = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "Hello, "},
                        {"text": "world!"},
                        {"text": " How are you?"}
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 8,
                "totalTokenCount": 18
            }
        });

        let parts = first_candidate_parts(payload);
        assert_eq!(parts.len(), 3);

        // Collect every part's text and compare in one shot.
        let texts: Vec<&str> = parts.iter().map(expect_text).collect();
        assert_eq!(texts, ["Hello, ", "world!", " How are you?"]);
    }

    #[test]
    fn test_deserialize_stream_response_with_multiple_tool_calls() {
        let payload = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {
                            "functionCall": {
                                "name": "get_weather",
                                "args": {"city": "San Francisco"}
                            }
                        },
                        {
                            "functionCall": {
                                "name": "get_temperature",
                                "args": {"location": "New York"}
                            }
                        }
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 50,
                "candidatesTokenCount": 20,
                "totalTokenCount": 70
            }
        });

        let parts = first_candidate_parts(payload);
        assert_eq!(parts.len(), 2);
        assert_eq!(expect_function_name(&parts[0]), "get_weather");
        assert_eq!(expect_function_name(&parts[1]), "get_temperature");
    }

    #[test]
    fn test_deserialize_stream_response_with_mixed_parts() {
        let payload = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {
                            "text": "Let me think about this...",
                            "thought": true
                        },
                        {
                            "text": "Here's my response: "
                        },
                        {
                            "functionCall": {
                                "name": "search",
                                "args": {"query": "rust async"}
                            }
                        },
                        {
                            "text": "I found the answer!"
                        }
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 100,
                "candidatesTokenCount": 50,
                "thoughtsTokenCount": 15,
                "totalTokenCount": 165
            }
        });

        let parts = first_candidate_parts(payload);
        assert_eq!(parts.len(), 4);

        // Index 0: a reasoning ("thought") part.
        let Part {
            part: PartKind::Text(text),
            thought: Some(true),
            ..
        } = &parts[0]
        else {
            panic!("Expected thought part at index 0");
        };
        assert_eq!(text, "Let me think about this...");

        // Index 1: ordinary text, explicitly not flagged as a thought.
        let Part {
            part: PartKind::Text(text),
            thought,
            ..
        } = &parts[1]
        else {
            panic!("Expected text part at index 1");
        };
        assert_eq!(text, "Here's my response: ");
        assert!(!matches!(thought, Some(true)));

        // Index 2: a tool call.
        assert_eq!(expect_function_name(&parts[2]), "search");

        // Index 3: trailing text.
        assert_eq!(expect_text(&parts[3]), "I found the answer!");
    }

    #[test]
    fn test_deserialize_stream_response_with_empty_parts() {
        let payload = json!({
            "candidates": [{
                "content": {
                    "parts": [],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 0,
                "totalTokenCount": 10
            }
        });

        assert!(first_candidate_parts(payload).is_empty());
    }

    #[test]
    fn test_partial_usage_token_calculation() {
        let usage = PartialUsage {
            total_token_count: 100,
            cached_content_token_count: Some(20),
            candidates_token_count: Some(30),
            thoughts_token_count: Some(10),
            prompt_token_count: 40,
        };

        let computed = usage.token_usage().expect("usage should be present");
        assert_eq!(computed.input_tokens, 40);
        // 20 cached + 30 candidates + 10 thoughts
        assert_eq!(computed.output_tokens, 60);
        assert_eq!(computed.total_tokens, 100);
    }

    #[test]
    fn test_partial_usage_with_missing_counts() {
        let usage = PartialUsage {
            total_token_count: 50,
            candidates_token_count: Some(30),
            prompt_token_count: 20,
            ..PartialUsage::default()
        };

        let computed = usage.token_usage().expect("usage should be present");
        assert_eq!(computed.input_tokens, 20);
        // Only candidates_token_count contributes to the output total here.
        assert_eq!(computed.output_tokens, 30);
        assert_eq!(computed.total_tokens, 50);
    }

    #[test]
    fn test_streaming_completion_response_token_usage() {
        let response = StreamingCompletionResponse {
            usage_metadata: PartialUsage {
                total_token_count: 150,
                candidates_token_count: Some(75),
                prompt_token_count: 75,
                ..PartialUsage::default()
            },
        };

        let computed = response.token_usage().expect("usage should be present");
        assert_eq!(computed.input_tokens, 75);
        assert_eq!(computed.output_tokens, 75);
        assert_eq!(computed.total_tokens, 150);
    }
}