// rig/providers/gemini/streaming.rs

1use async_stream::stream;
2use futures::StreamExt;
3use serde::{Deserialize, Serialize};
4use tracing::{Level, enabled, info_span};
5use tracing_futures::Instrument;
6
7use super::completion::gemini_api_types::{ContentCandidate, Part, PartKind};
8use super::completion::{CompletionModel, create_request_body};
9use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
10use crate::http_client::HttpClientExt;
11use crate::http_client::sse::{Event, GenericEventSource};
12use crate::streaming;
13use crate::telemetry::SpanCombinator;
14
/// Token-usage metadata as reported incrementally by Gemini's streaming API
/// (`usageMetadata` in each `streamGenerateContent` chunk, hence camelCase).
///
/// The optional counts may be absent on intermediate chunks, so they are
/// modeled as `Option` and skipped on serialization when `None`.
#[derive(Debug, Deserialize, Serialize, Default, Clone)]
#[serde(rename_all = "camelCase")]
pub struct PartialUsage {
    /// Total token count reported by the API for the request so far.
    pub total_token_count: i32,
    /// Tokens served from cached content, when the API reports them.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_content_token_count: Option<i32>,
    /// Tokens in the generated candidates, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidates_token_count: Option<i32>,
    /// Tokens spent on model "thoughts" (reasoning), when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thoughts_token_count: Option<i32>,
    /// Tokens in the prompt sent to the model.
    pub prompt_token_count: i32,
}
27
28impl GetTokenUsage for PartialUsage {
29    fn token_usage(&self) -> Option<crate::completion::Usage> {
30        let mut usage = crate::completion::Usage::new();
31
32        usage.input_tokens = self.prompt_token_count as u64;
33        usage.output_tokens = (self.cached_content_token_count.unwrap_or_default()
34            + self.candidates_token_count.unwrap_or_default()
35            + self.thoughts_token_count.unwrap_or_default()) as u64;
36        usage.total_tokens = usage.input_tokens + usage.output_tokens;
37
38        Some(usage)
39    }
40}
41
/// One deserialized chunk of a Gemini `streamGenerateContent` SSE response.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct StreamGenerateContentResponse {
    /// Candidate responses from the model.
    pub candidates: Vec<ContentCandidate>,
    /// Model version string, when the API includes it in the chunk.
    pub model_version: Option<String>,
    /// Usage metadata; typically populated on the final chunk.
    pub usage_metadata: Option<PartialUsage>,
}
50
/// Final response emitted when a Gemini completion stream ends, carrying the
/// accumulated usage metadata (defaulted if the stream never reported any).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct StreamingCompletionResponse {
    /// Usage metadata taken from the last chunk that carried a finish reason.
    pub usage_metadata: PartialUsage,
}
55
56impl GetTokenUsage for StreamingCompletionResponse {
57    fn token_usage(&self) -> Option<crate::completion::Usage> {
58        let mut usage = crate::completion::Usage::new();
59        usage.total_tokens = self.usage_metadata.total_token_count as u64;
60        usage.output_tokens = self
61            .usage_metadata
62            .candidates_token_count
63            .map(|x| x as u64)
64            .unwrap_or(0);
65        usage.input_tokens = self.usage_metadata.prompt_token_count as u64;
66        Some(usage)
67    }
68}
69
impl<T> CompletionModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
    /// Sends `completion_request` to Gemini's `streamGenerateContent` endpoint
    /// and adapts the SSE stream into a [`streaming::StreamingCompletionResponse`].
    ///
    /// Per chunk: text parts flagged `thought: Some(true)` are yielded as
    /// reasoning deltas, other text parts as message deltas, and function
    /// calls as tool calls. The chunk carrying a `finish_reason` terminates
    /// the loop; its usage metadata is recorded on the tracing span and
    /// emitted as the stream's final response.
    ///
    /// # Errors
    /// Returns [`CompletionError`] if the request body cannot be built or
    /// serialized, or if the HTTP request cannot be constructed. SSE-level
    /// errors are yielded inside the stream as `ProviderError`.
    pub(crate) async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
    {
        // Reuse the caller's active span when there is one; otherwise create a
        // fresh gen_ai-style span for this streaming chat call. Usage fields
        // start Empty and are recorded when the final chunk arrives.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "gcp.gemini",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = self.model,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };
        let request = create_request_body(completion_request)?;

        // Guard the pretty-printing behind a TRACE check so the JSON
        // serialization cost is only paid when the trace level is enabled.
        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::streaming",
                "Gemini streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post_sse(format!(
                "/v1beta/models/{}:streamGenerateContent",
                self.model
            ))?
            .header("Content-Type", "application/json")
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        let mut event_source = GenericEventSource::new(self.client.clone(), req);

        let stream = stream! {
            // Holds the usage metadata of the finishing chunk, if any.
            let mut final_usage = None;
            while let Some(event_result) = event_source.next().await {
                match event_result {
                    Ok(Event::Open) => {
                        tracing::debug!("SSE connection opened");
                        continue;
                    }
                    Ok(Event::Message(message)) => {
                        // Skip heartbeat messages or empty data
                        if message.data.trim().is_empty() {
                            continue;
                        }

                        // Malformed chunks are logged and skipped rather than
                        // terminating the whole stream.
                        let data = match serde_json::from_str::<StreamGenerateContentResponse>(&message.data) {
                            Ok(d) => d,
                            Err(error) => {
                                tracing::error!(?error, message = message.data, "Failed to parse SSE message");
                                continue;
                            }
                        };

                        // Process the response data.
                        // NOTE(review): only the first candidate is consumed;
                        // any additional candidates in the chunk are dropped.
                        let Some(choice) = data.candidates.into_iter().next() else {
                            tracing::debug!("There is no content candidate");
                            continue;
                        };

                        let Some(content) = choice.content else {
                            tracing::debug!(finish_reason = ?choice.finish_reason, "Streaming candidate missing content");
                            continue;
                        };

                        if content.parts.is_empty() {
                            tracing::trace!(reason = ?choice.finish_reason, "There is no part in the streaming content");
                        }

                        for part in content.parts {
                            match part {
                                // Reasoning ("thought") text — must be matched
                                // before the plain-text arm below.
                                Part {
                                    part: PartKind::Text(text),
                                    thought: Some(true),
                                    ..
                                } => {
                                    yield Ok(streaming::RawStreamingChoice::ReasoningDelta {
                                        id: None,
                                        reasoning: text,
                                    });
                                },
                                // Ordinary assistant text delta.
                                Part {
                                    part: PartKind::Text(text),
                                    ..
                                } => {
                                    yield Ok(streaming::RawStreamingChoice::Message(text));
                                },
                                // NOTE(review): the function name is passed as
                                // both the call id and the name — presumably
                                // Gemini supplies no separate call id; confirm.
                                Part {
                                    part: PartKind::FunctionCall(function_call),
                                    thought_signature,
                                    ..
                                } => {
                                    yield Ok(streaming::RawStreamingChoice::ToolCall(
                                        streaming::RawStreamingToolCall::new(function_call.name.clone(), function_call.name.clone(), function_call.args.clone())
                                            .with_signature(thought_signature)
                                    ));
                                },
                                part => {
                                    tracing::warn!(?part, "Unsupported response type with streaming");
                                }
                            }
                        }

                        // Check if this is the final response: a finish_reason
                        // marks the end, so record usage and stop reading.
                        if choice.finish_reason.is_some() {
                            let span = tracing::Span::current();
                            span.record_token_usage(&data.usage_metadata);
                            final_usage = data.usage_metadata;
                            break;
                        }
                    }
                    // Normal end-of-stream from the SSE transport.
                    Err(crate::http_client::Error::StreamEnded) => {
                        break;
                    }
                    Err(error) => {
                        tracing::error!(?error, "SSE error");
                        yield Err(CompletionError::ProviderError(error.to_string()));
                        break;
                    }
                }
            }

            // Ensure event source is closed when stream ends
            event_source.close();

            // Always emit a final response; usage defaults to zeros if the
            // stream ended without reporting any metadata.
            yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
                usage_metadata: final_usage.unwrap_or_default()
            }));
        }.instrument(span);

        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
            stream,
        )))
    }
}
222
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // A single text part deserializes into one candidate with one Text part.
    #[test]
    fn test_deserialize_stream_response_with_single_text_part() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "Hello, world!"}
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 5,
                "totalTokenCount": 15
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        assert_eq!(response.candidates.len(), 1);
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 1);

        if let Part {
            part: PartKind::Text(text),
            ..
        } = &content.parts[0]
        {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected text part");
        }
    }

    // Multiple text parts in one chunk are preserved in order.
    #[test]
    fn test_deserialize_stream_response_with_multiple_text_parts() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "Hello, "},
                        {"text": "world!"},
                        {"text": " How are you?"}
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 8,
                "totalTokenCount": 18
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        assert_eq!(response.candidates.len(), 1);
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 3);

        // Verify all three text parts are present
        for (i, expected_text) in ["Hello, ", "world!", " How are you?"].iter().enumerate() {
            if let Part {
                part: PartKind::Text(text),
                ..
            } = &content.parts[i]
            {
                assert_eq!(text, expected_text);
            } else {
                panic!("Expected text part at index {}", i);
            }
        }
    }

    // Two functionCall parts deserialize as two FunctionCall variants.
    #[test]
    fn test_deserialize_stream_response_with_multiple_tool_calls() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {
                            "functionCall": {
                                "name": "get_weather",
                                "args": {"city": "San Francisco"}
                            }
                        },
                        {
                            "functionCall": {
                                "name": "get_temperature",
                                "args": {"location": "New York"}
                            }
                        }
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 50,
                "candidatesTokenCount": 20,
                "totalTokenCount": 70
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 2);

        // Verify first tool call
        if let Part {
            part: PartKind::FunctionCall(call),
            ..
        } = &content.parts[0]
        {
            assert_eq!(call.name, "get_weather");
        } else {
            panic!("Expected function call at index 0");
        }

        // Verify second tool call
        if let Part {
            part: PartKind::FunctionCall(call),
            ..
        } = &content.parts[1]
        {
            assert_eq!(call.name, "get_temperature");
        } else {
            panic!("Expected function call at index 1");
        }
    }

    // A realistic chunk mixing a thought, plain text, and a tool call keeps
    // the `thought` flag distinct from ordinary text parts.
    #[test]
    fn test_deserialize_stream_response_with_mixed_parts() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {
                            "text": "Let me think about this...",
                            "thought": true
                        },
                        {
                            "text": "Here's my response: "
                        },
                        {
                            "functionCall": {
                                "name": "search",
                                "args": {"query": "rust async"}
                            }
                        },
                        {
                            "text": "I found the answer!"
                        }
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 100,
                "candidatesTokenCount": 50,
                "thoughtsTokenCount": 15,
                "totalTokenCount": 165
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        let parts = &content.parts;
        assert_eq!(parts.len(), 4);

        // Verify reasoning (thought) part
        if let Part {
            part: PartKind::Text(text),
            thought: Some(true),
            ..
        } = &parts[0]
        {
            assert_eq!(text, "Let me think about this...");
        } else {
            panic!("Expected thought part at index 0");
        }

        // Verify regular text
        if let Part {
            part: PartKind::Text(text),
            thought,
            ..
        } = &parts[1]
        {
            assert_eq!(text, "Here's my response: ");
            assert!(thought.is_none() || thought == &Some(false));
        } else {
            panic!("Expected text part at index 1");
        }

        // Verify tool call
        if let Part {
            part: PartKind::FunctionCall(call),
            ..
        } = &parts[2]
        {
            assert_eq!(call.name, "search");
        } else {
            panic!("Expected function call at index 2");
        }

        // Verify final text
        if let Part {
            part: PartKind::Text(text),
            ..
        } = &parts[3]
        {
            assert_eq!(text, "I found the answer!");
        } else {
            panic!("Expected text part at index 3");
        }
    }

    // An empty `parts` array is valid and deserializes to zero parts.
    #[test]
    fn test_deserialize_stream_response_with_empty_parts() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 0,
                "totalTokenCount": 10
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 0);
    }

    // Pins PartialUsage::token_usage: output is the sum of cached +
    // candidates + thoughts counts; total is input + output.
    #[test]
    fn test_partial_usage_token_calculation() {
        let usage = PartialUsage {
            total_token_count: 100,
            cached_content_token_count: Some(20),
            candidates_token_count: Some(30),
            thoughts_token_count: Some(10),
            prompt_token_count: 40,
        };

        let token_usage = usage.token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 40);
        assert_eq!(token_usage.output_tokens, 60); // 20 + 30 + 10
        assert_eq!(token_usage.total_tokens, 100);
    }

    // Missing optional counts are treated as zero in the output sum.
    #[test]
    fn test_partial_usage_with_missing_counts() {
        let usage = PartialUsage {
            total_token_count: 50,
            cached_content_token_count: None,
            candidates_token_count: Some(30),
            thoughts_token_count: None,
            prompt_token_count: 20,
        };

        let token_usage = usage.token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 20);
        assert_eq!(token_usage.output_tokens, 30); // Only candidates_token_count
        assert_eq!(token_usage.total_tokens, 50);
    }

    // Pins StreamingCompletionResponse::token_usage: total comes straight from
    // total_token_count, output from candidates_token_count only.
    #[test]
    fn test_streaming_completion_response_token_usage() {
        let response = StreamingCompletionResponse {
            usage_metadata: PartialUsage {
                total_token_count: 150,
                cached_content_token_count: None,
                candidates_token_count: Some(75),
                thoughts_token_count: None,
                prompt_token_count: 75,
            },
        };

        let token_usage = response.token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 75);
        assert_eq!(token_usage.output_tokens, 75);
        assert_eq!(token_usage.total_tokens, 150);
    }
}