rig/providers/gemini/streaming.rs

use async_stream::stream;
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use tracing::{Level, enabled, info_span};
use tracing_futures::Instrument;

use super::completion::gemini_api_types::{ContentCandidate, Part, PartKind};
use super::completion::{CompletionModel, create_request_body};
use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
use crate::http_client::HttpClientExt;
use crate::http_client::sse::{Event, GenericEventSource};
use crate::streaming;
use crate::telemetry::SpanCombinator;

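/// Token counts reported on streamed Gemini chunks. The per-category counts
/// are `Option` because individual chunks may omit them.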
#[derive(Debug, Deserialize, Serialize, Default, Clone)]
#[serde(rename_all = "camelCase")]
pub struct PartialUsage {
    pub total_token_count: i32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_content_token_count: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidates_token_count: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thoughts_token_count: Option<i32>,
    pub prompt_token_count: i32,
}

impl GetTokenUsage for PartialUsage {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();

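        // Cached, candidate, and thought token counts are all folded into
        // output_tokens; any missing count contributes zero.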
        usage.input_tokens = self.prompt_token_count as u64;
        usage.output_tokens = (self.cached_content_token_count.unwrap_or_default()
            + self.candidates_token_count.unwrap_or_default()
            + self.thoughts_token_count.unwrap_or_default()) as u64;
        usage.total_tokens = usage.input_tokens + usage.output_tokens;

        Some(usage)
    }
}

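/// A single chunk as returned by Gemini's `streamGenerateContent` SSE endpoint.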
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct StreamGenerateContentResponse {
    /// Candidate responses from the model.
    pub candidates: Vec<ContentCandidate>,
    pub model_version: Option<String>,
    pub usage_metadata: Option<PartialUsage>,
}

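/// Final response surfaced to callers when the stream completes, carrying the
/// last usage metadata observed (or zeroed defaults if none arrived).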
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct StreamingCompletionResponse {
    pub usage_metadata: PartialUsage,
}

impl GetTokenUsage for StreamingCompletionResponse {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();
        usage.total_tokens = self.usage_metadata.total_token_count as u64;
        usage.output_tokens = self
            .usage_metadata
            .candidates_token_count
            .map(|x| x as u64)
            .unwrap_or(0);
        usage.input_tokens = self.usage_metadata.prompt_token_count as u64;
        Some(usage)
    }
}

impl<T> CompletionModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
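    /// Sends a streaming completion request to Gemini's `streamGenerateContent`
    /// endpoint and adapts each SSE chunk into `rig` streaming choices.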
    pub(crate) async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
    {
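        // Reuse the caller's span when one is active; otherwise start a fresh
        // `chat_streaming` span carrying the GenAI telemetry fields.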
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "gcp.gemini",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = self.model,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };
        let request = create_request_body(completion_request)?;

        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::streaming",
                "Gemini streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post_sse(format!(
                "/v1beta/models/{}:streamGenerateContent",
                self.model
            ))?
            .header("Content-Type", "application/json")
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

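        // Drive the SSE connection; each `Event::Message` carries one
        // JSON-encoded `StreamGenerateContentResponse` chunk.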
        let mut event_source = GenericEventSource::new(self.client.clone(), req);

        let stream = stream! {
            let mut text_response = String::new();
            let mut model_outputs: Vec<Part> = Vec::new();
            let mut final_usage = None;
            while let Some(event_result) = event_source.next().await {
                match event_result {
                    Ok(Event::Open) => {
                        tracing::debug!("SSE connection opened");
                        continue;
                    }
                    Ok(Event::Message(message)) => {
                        // Skip heartbeat messages or empty data
                        if message.data.trim().is_empty() {
                            continue;
                        }

                        let data = match serde_json::from_str::<StreamGenerateContentResponse>(&message.data) {
                            Ok(d) => d,
                            Err(error) => {
                                tracing::error!(?error, message = message.data, "Failed to parse SSE message");
                                continue;
                            }
                        };

                        // Process the response data
                        let Some(choice) = data.candidates.first() else {
                            tracing::debug!("There is no content candidate");
                            continue;
                        };

                        let Some(content) = choice.content.as_ref() else {
                            tracing::debug!(finish_reason = ?choice.finish_reason, "Streaming candidate missing content");
                            continue;
                        };

                        for part in &content.parts {
                            match part {
                                Part {
                                    part: PartKind::Text(text),
                                    thought: Some(true),
                                    ..
                                } => {
                                    yield Ok(streaming::RawStreamingChoice::Reasoning { reasoning: text.clone(), id: None, signature: None });
                                },
                                Part {
                                    part: PartKind::Text(text),
                                    ..
                                } => {
                                    text_response.push_str(text);
                                    yield Ok(streaming::RawStreamingChoice::Message(text.clone()));
                                },
                                Part {
                                    part: PartKind::FunctionCall(function_call),
                                    ..
                                } => {
                                    model_outputs.push(part.clone());
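                                    // No separate tool-call id is parsed from the
                                    // chunk, so the function name doubles as the id.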
                                    yield Ok(streaming::RawStreamingChoice::ToolCall {
                                        name: function_call.name.clone(),
                                        id: function_call.name.clone(),
                                        arguments: function_call.args.clone(),
                                        call_id: None
                                    });
                                },
                                part => {
                                    tracing::warn!(?part, "Unsupported part type in streaming response");
                                }
                            }
                        }

                        if content.parts.is_empty() {
                            tracing::trace!(reason = ?choice.finish_reason, "Streaming content contained no parts");
                        }

                        // Check if this is the final response
                        if choice.finish_reason.is_some() {
                            if !text_response.is_empty() {
                                model_outputs.push(Part { thought: None, thought_signature: None, part: PartKind::Text(text_response), additional_params: None });
                            }
                            let span = tracing::Span::current();
                            span.record_token_usage(&data.usage_metadata);
                            final_usage = data.usage_metadata;
                            break;
                        }
                    }
                    Err(crate::http_client::Error::StreamEnded) => {
                        break;
                    }
                    Err(error) => {
                        tracing::error!(?error, "SSE error");
                        yield Err(CompletionError::ProviderError(error.to_string()));
                        break;
                    }
                }
            }

            // Ensure event source is closed when stream ends
            event_source.close();

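            // If the stream ended without a finish_reason, fall back to zeroed
            // usage via `PartialUsage::default()`.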
            yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
                usage_metadata: final_usage.unwrap_or_default()
            }));
        }.instrument(span);

        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
            stream,
        )))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_deserialize_stream_response_with_single_text_part() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "Hello, world!"}
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 5,
                "totalTokenCount": 15
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        assert_eq!(response.candidates.len(), 1);
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 1);

        if let Part {
            part: PartKind::Text(text),
            ..
        } = &content.parts[0]
        {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected text part");
        }
    }

    #[test]
    fn test_deserialize_stream_response_with_multiple_text_parts() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "Hello, "},
                        {"text": "world!"},
                        {"text": " How are you?"}
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 8,
                "totalTokenCount": 18
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        assert_eq!(response.candidates.len(), 1);
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 3);

        // Verify all three text parts are present
        for (i, expected_text) in ["Hello, ", "world!", " How are you?"].iter().enumerate() {
            if let Part {
                part: PartKind::Text(text),
                ..
            } = &content.parts[i]
            {
                assert_eq!(text, expected_text);
            } else {
                panic!("Expected text part at index {}", i);
            }
        }
    }

    #[test]
    fn test_deserialize_stream_response_with_multiple_tool_calls() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {
                            "functionCall": {
                                "name": "get_weather",
                                "args": {"city": "San Francisco"}
                            }
                        },
                        {
                            "functionCall": {
                                "name": "get_temperature",
                                "args": {"location": "New York"}
                            }
                        }
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 50,
                "candidatesTokenCount": 20,
                "totalTokenCount": 70
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 2);

        // Verify first tool call
        if let Part {
            part: PartKind::FunctionCall(call),
            ..
        } = &content.parts[0]
        {
            assert_eq!(call.name, "get_weather");
        } else {
            panic!("Expected function call at index 0");
        }

        // Verify second tool call
        if let Part {
            part: PartKind::FunctionCall(call),
            ..
        } = &content.parts[1]
        {
            assert_eq!(call.name, "get_temperature");
        } else {
            panic!("Expected function call at index 1");
        }
    }

    #[test]
    fn test_deserialize_stream_response_with_mixed_parts() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {
                            "text": "Let me think about this...",
                            "thought": true
                        },
                        {
                            "text": "Here's my response: "
                        },
                        {
                            "functionCall": {
                                "name": "search",
                                "args": {"query": "rust async"}
                            }
                        },
                        {
                            "text": "I found the answer!"
                        }
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 100,
                "candidatesTokenCount": 50,
                "thoughtsTokenCount": 15,
                "totalTokenCount": 165
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        let parts = &content.parts;
        assert_eq!(parts.len(), 4);

        // Verify reasoning (thought) part
        if let Part {
            part: PartKind::Text(text),
            thought: Some(true),
            ..
        } = &parts[0]
        {
            assert_eq!(text, "Let me think about this...");
        } else {
            panic!("Expected thought part at index 0");
        }

        // Verify regular text
        if let Part {
            part: PartKind::Text(text),
            thought,
            ..
        } = &parts[1]
        {
            assert_eq!(text, "Here's my response: ");
            assert!(thought.is_none() || thought == &Some(false));
        } else {
            panic!("Expected text part at index 1");
        }

        // Verify tool call
        if let Part {
            part: PartKind::FunctionCall(call),
            ..
        } = &parts[2]
        {
            assert_eq!(call.name, "search");
        } else {
            panic!("Expected function call at index 2");
        }

        // Verify final text
        if let Part {
            part: PartKind::Text(text),
            ..
        } = &parts[3]
        {
            assert_eq!(text, "I found the answer!");
        } else {
            panic!("Expected text part at index 3");
        }
    }

    #[test]
    fn test_deserialize_stream_response_with_empty_parts() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 0,
                "totalTokenCount": 10
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 0);
    }

    #[test]
    fn test_partial_usage_token_calculation() {
        let usage = PartialUsage {
            total_token_count: 100,
            cached_content_token_count: Some(20),
            candidates_token_count: Some(30),
            thoughts_token_count: Some(10),
            prompt_token_count: 40,
        };

        let token_usage = usage.token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 40);
        assert_eq!(token_usage.output_tokens, 60); // 20 + 30 + 10
        assert_eq!(token_usage.total_tokens, 100);
    }

    #[test]
    fn test_partial_usage_with_missing_counts() {
        let usage = PartialUsage {
            total_token_count: 50,
            cached_content_token_count: None,
            candidates_token_count: Some(30),
            thoughts_token_count: None,
            prompt_token_count: 20,
        };

        let token_usage = usage.token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 20);
        assert_eq!(token_usage.output_tokens, 30); // Only candidates_token_count
        assert_eq!(token_usage.total_tokens, 50);
    }
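
    // The `camelCase` rename on `PartialUsage` should map Gemini's JSON keys
    // onto the snake_case fields, leaving absent optional counts as `None`.
    #[test]
    fn test_partial_usage_deserializes_camel_case() {
        let usage: PartialUsage = serde_json::from_value(json!({
            "promptTokenCount": 5,
            "candidatesTokenCount": 2,
            "totalTokenCount": 7
        }))
        .unwrap();

        assert_eq!(usage.prompt_token_count, 5);
        assert_eq!(usage.candidates_token_count, Some(2));
        assert_eq!(usage.total_token_count, 7);
        assert!(usage.cached_content_token_count.is_none());
        assert!(usage.thoughts_token_count.is_none());
    }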

    #[test]
    fn test_streaming_completion_response_token_usage() {
        let response = StreamingCompletionResponse {
            usage_metadata: PartialUsage {
                total_token_count: 150,
                cached_content_token_count: None,
                candidates_token_count: Some(75),
                thoughts_token_count: None,
                prompt_token_count: 75,
            },
        };

        let token_usage = response.token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 75);
        assert_eq!(token_usage.output_tokens, 75);
        assert_eq!(token_usage.total_tokens, 150);
    }
}