//! rig/providers/gemini/streaming.rs — streaming support for the Gemini provider.

1use async_stream::stream;
2use futures::StreamExt;
3use serde::{Deserialize, Serialize};
4use tracing::info_span;
5
6use super::completion::gemini_api_types::{ContentCandidate, Part, PartKind};
7use super::completion::{CompletionModel, create_request_body};
8use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
9use crate::http_client::HttpClientExt;
10use crate::http_client::sse::{Event, GenericEventSource};
11use crate::streaming;
12use crate::telemetry::SpanCombinator;
13
/// Token usage metadata as reported by the Gemini streaming API.
///
/// Only `total_token_count` and `prompt_token_count` are guaranteed by the
/// wire format; the remaining counts may be absent in any given chunk, hence
/// the `Option` wrappers (and they are skipped on serialization when `None`).
#[derive(Debug, Deserialize, Serialize, Default, Clone)]
#[serde(rename_all = "camelCase")]
pub struct PartialUsage {
    // Total tokens for the request as reported by Gemini.
    pub total_token_count: i32,
    // Tokens attributed to cached content, when present in the chunk.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_content_token_count: Option<i32>,
    // Tokens attributed to the response candidates, when present.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidates_token_count: Option<i32>,
    // Tokens attributed to model "thoughts", when present.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thoughts_token_count: Option<i32>,
    // Tokens consumed by the prompt.
    pub prompt_token_count: i32,
}
26
27impl GetTokenUsage for PartialUsage {
28    fn token_usage(&self) -> Option<crate::completion::Usage> {
29        let mut usage = crate::completion::Usage::new();
30
31        usage.input_tokens = self.prompt_token_count as u64;
32        usage.output_tokens = (self.cached_content_token_count.unwrap_or_default()
33            + self.candidates_token_count.unwrap_or_default()
34            + self.thoughts_token_count.unwrap_or_default()) as u64;
35        usage.total_tokens = usage.input_tokens + usage.output_tokens;
36
37        Some(usage)
38    }
39}
40
/// One deserialized chunk of the Gemini `streamGenerateContent` SSE response.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct StreamGenerateContentResponse {
    /// Candidate responses from the model.
    pub candidates: Vec<ContentCandidate>,
    // Model version string, when the API includes it in the chunk.
    pub model_version: Option<String>,
    // Usage metadata; `stream` captures this from the chunk that carries a
    // finish reason.
    pub usage_metadata: Option<PartialUsage>,
}
49
/// Final payload yielded once a Gemini completion stream finishes.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct StreamingCompletionResponse {
    // Usage metadata from the last chunk seen; defaulted to zeros when the
    // stream ended without reporting any usage.
    pub usage_metadata: PartialUsage,
}
54
55impl GetTokenUsage for StreamingCompletionResponse {
56    fn token_usage(&self) -> Option<crate::completion::Usage> {
57        let mut usage = crate::completion::Usage::new();
58        usage.total_tokens = self.usage_metadata.total_token_count as u64;
59        usage.output_tokens = self
60            .usage_metadata
61            .candidates_token_count
62            .map(|x| x as u64)
63            .unwrap_or(0);
64        usage.input_tokens = self.usage_metadata.prompt_token_count as u64;
65        Some(usage)
66    }
67}
68
impl<T> CompletionModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
    /// Sends a streaming completion request to the Gemini
    /// `streamGenerateContent` endpoint and adapts the SSE events into rig's
    /// streaming choices.
    ///
    /// Text parts are yielded as `Message` chunks (and accumulated for
    /// telemetry), thought parts as `Reasoning`, and function calls as
    /// `ToolCall`s. When a chunk carries a finish reason, its usage metadata
    /// is captured and the loop ends; the stream then yields a
    /// `FinalResponse` carrying that metadata (zeroed defaults if none was
    /// seen).
    ///
    /// # Errors
    ///
    /// Returns [`CompletionError`] if building or serializing the request
    /// body fails, or if the HTTP request cannot be constructed. Transport
    /// errors encountered mid-stream are yielded as `Err` items on the
    /// stream itself rather than returned here.
    pub(crate) async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
    {
        // Reuse the caller's span when one is active; otherwise open a fresh
        // "chat_streaming" span with the standard gen_ai attributes.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "gcp.gemini",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = self.model,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = tracing::field::Empty,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };
        let request = create_request_body(completion_request)?;

        span.record_model_input(&request.contents);

        tracing::trace!(
            target: "rig::streaming",
            "Sending completion request to Gemini API {}",
            serde_json::to_string_pretty(&request)?
        );
        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post_sse(format!(
                "/v1beta/models/{}:streamGenerateContent",
                self.model
            ))?
            .header("Content-Type", "application/json")
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        let mut event_source = GenericEventSource::new(self.client.http_client().clone(), req);

        let stream = stream! {
            // Accumulated plain-text output, folded into `model_outputs` for
            // telemetry once the stream finishes.
            let mut text_response = String::new();
            let mut model_outputs: Vec<Part> = Vec::new();
            // Usage metadata from the finishing chunk, if one arrived.
            let mut final_usage = None;
            while let Some(event_result) = event_source.next().await {
                match event_result {
                    Ok(Event::Open) => {
                        tracing::debug!("SSE connection opened");
                        continue;
                    }
                    Ok(Event::Message(message)) => {
                        // Skip heartbeat messages or empty data
                        if message.data.trim().is_empty() {
                            continue;
                        }

                        // Malformed chunks are logged and skipped rather than
                        // terminating the whole stream.
                        let data = match serde_json::from_str::<StreamGenerateContentResponse>(&message.data) {
                            Ok(d) => d,
                            Err(error) => {
                                tracing::error!(?error, message = message.data, "Failed to parse SSE message");
                                continue;
                            }
                        };

                        // Process the response data. Only the first candidate
                        // is consumed; any others in the chunk are ignored.
                        let Some(choice) = data.candidates.first() else {
                            tracing::debug!("There is no content candidate");
                            continue;
                        };

                        let Some(content) = choice.content.as_ref() else {
                            tracing::debug!(finish_reason = ?choice.finish_reason, "Streaming candidate missing content");
                            continue;
                        };

                        for part in &content.parts {
                            match part {
                                // Text flagged as a thought becomes a
                                // Reasoning chunk (matched before plain text).
                                Part {
                                    part: PartKind::Text(text),
                                    thought: Some(true),
                                    ..
                                } => {
                                    yield Ok(streaming::RawStreamingChoice::Reasoning { reasoning: text.clone(), id: None, signature: None });
                                },
                                Part {
                                    part: PartKind::Text(text),
                                    ..
                                } => {
                                    text_response.push_str(text);
                                    yield Ok(streaming::RawStreamingChoice::Message(text.clone()));
                                },
                                Part {
                                    part: PartKind::FunctionCall(function_call),
                                    ..
                                } => {
                                    model_outputs.push(part.clone());
                                    // The function name doubles as the id —
                                    // presumably because Gemini supplies no
                                    // separate call id here; confirm.
                                    yield Ok(streaming::RawStreamingChoice::ToolCall {
                                        name: function_call.name.clone(),
                                        id: function_call.name.clone(),
                                        arguments: function_call.args.clone(),
                                        call_id: None
                                    });
                                },
                                part => {
                                    tracing::warn!(?part, "Unsupported response type with streaming");
                                }
                            }
                        }

                        if content.parts.is_empty() {
                            tracing::trace!(reason = ?choice.finish_reason, "There is no part in the streaming content");
                        }

                        // Check if this is the final response
                        if choice.finish_reason.is_some() {
                            if !text_response.is_empty() {
                                model_outputs.push(Part { thought: None, thought_signature: None, part: PartKind::Text(text_response), additional_params: None });
                            }
                            // NOTE(review): this records onto whatever span is
                            // current when the stream is polled, which is not
                            // necessarily the `span` built above — confirm the
                            // consumer polls this stream inside that span.
                            let span = tracing::Span::current();
                            span.record_model_output(&model_outputs);
                            span.record_token_usage(&data.usage_metadata);
                            final_usage = data.usage_metadata;
                            break;
                        }
                    }
                    Err(crate::http_client::Error::StreamEnded) => {
                        break;
                    }
                    Err(error) => {
                        tracing::error!(?error, "SSE error");
                        yield Err(CompletionError::ProviderError(error.to_string()));
                        break;
                    }
                }
            }

            // Ensure event source is closed when stream ends
            event_source.close();

            yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
                usage_metadata: final_usage.unwrap_or_default()
            }));
        };

        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
            stream,
        )))
    }
}
228
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Deserializes a raw JSON chunk into a streaming response.
    fn parse(value: serde_json::Value) -> StreamGenerateContentResponse {
        serde_json::from_value(value).unwrap()
    }

    /// Returns the parts of the first candidate, panicking when the candidate
    /// carries no content.
    fn first_parts(response: &StreamGenerateContentResponse) -> &[Part] {
        &response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content")
            .parts
    }

    /// Extracts the text of a part, panicking on any other part kind.
    fn text_of(part: &Part) -> &str {
        match &part.part {
            PartKind::Text(text) => text,
            _ => panic!("expected a text part"),
        }
    }

    /// Extracts the function-call name of a part, panicking otherwise.
    fn call_name(part: &Part) -> &str {
        match &part.part {
            PartKind::FunctionCall(call) => &call.name,
            _ => panic!("expected a function call part"),
        }
    }

    #[test]
    fn test_deserialize_stream_response_with_single_text_part() {
        let response = parse(json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "Hello, world!"}
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 5,
                "totalTokenCount": 15
            }
        }));

        assert_eq!(response.candidates.len(), 1);
        let parts = first_parts(&response);
        assert_eq!(parts.len(), 1);
        assert_eq!(text_of(&parts[0]), "Hello, world!");
    }

    #[test]
    fn test_deserialize_stream_response_with_multiple_text_parts() {
        let response = parse(json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "Hello, "},
                        {"text": "world!"},
                        {"text": " How are you?"}
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 8,
                "totalTokenCount": 18
            }
        }));

        assert_eq!(response.candidates.len(), 1);
        let parts = first_parts(&response);

        // All three text fragments must arrive in order.
        let expected = ["Hello, ", "world!", " How are you?"];
        assert_eq!(parts.len(), expected.len());
        for (i, want) in expected.iter().enumerate() {
            assert_eq!(text_of(&parts[i]), *want, "mismatch at index {i}");
        }
    }

    #[test]
    fn test_deserialize_stream_response_with_multiple_tool_calls() {
        let response = parse(json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {
                            "functionCall": {
                                "name": "get_weather",
                                "args": {"city": "San Francisco"}
                            }
                        },
                        {
                            "functionCall": {
                                "name": "get_temperature",
                                "args": {"location": "New York"}
                            }
                        }
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 50,
                "candidatesTokenCount": 20,
                "totalTokenCount": 70
            }
        }));

        let parts = first_parts(&response);
        assert_eq!(parts.len(), 2);
        assert_eq!(call_name(&parts[0]), "get_weather");
        assert_eq!(call_name(&parts[1]), "get_temperature");
    }

    #[test]
    fn test_deserialize_stream_response_with_mixed_parts() {
        let response = parse(json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {
                            "text": "Let me think about this...",
                            "thought": true
                        },
                        {
                            "text": "Here's my response: "
                        },
                        {
                            "functionCall": {
                                "name": "search",
                                "args": {"query": "rust async"}
                            }
                        },
                        {
                            "text": "I found the answer!"
                        }
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 100,
                "candidatesTokenCount": 50,
                "thoughtsTokenCount": 15,
                "totalTokenCount": 165
            }
        }));

        let parts = first_parts(&response);
        assert_eq!(parts.len(), 4);

        // Reasoning (thought) part.
        assert_eq!(parts[0].thought, Some(true), "expected thought part at index 0");
        assert_eq!(text_of(&parts[0]), "Let me think about this...");

        // Regular text part (explicitly not a thought).
        assert_ne!(parts[1].thought, Some(true), "expected plain text at index 1");
        assert_eq!(text_of(&parts[1]), "Here's my response: ");

        // Tool call part.
        assert_eq!(call_name(&parts[2]), "search");

        // Final text part.
        assert_eq!(text_of(&parts[3]), "I found the answer!");
    }

    #[test]
    fn test_deserialize_stream_response_with_empty_parts() {
        let response = parse(json!({
            "candidates": [{
                "content": {
                    "parts": [],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 0,
                "totalTokenCount": 10
            }
        }));

        assert!(first_parts(&response).is_empty());
    }

    #[test]
    fn test_partial_usage_token_calculation() {
        let usage = PartialUsage {
            total_token_count: 100,
            cached_content_token_count: Some(20),
            candidates_token_count: Some(30),
            thoughts_token_count: Some(10),
            prompt_token_count: 40,
        }
        .token_usage()
        .expect("token_usage always returns Some");

        assert_eq!(usage.input_tokens, 40);
        // output = 20 cached + 30 candidates + 10 thoughts
        assert_eq!(usage.output_tokens, 60);
        assert_eq!(usage.total_tokens, 100);
    }

    #[test]
    fn test_partial_usage_with_missing_counts() {
        let usage = PartialUsage {
            total_token_count: 50,
            cached_content_token_count: None,
            candidates_token_count: Some(30),
            thoughts_token_count: None,
            prompt_token_count: 20,
        }
        .token_usage()
        .expect("token_usage always returns Some");

        assert_eq!(usage.input_tokens, 20);
        // Only candidates_token_count contributes to output.
        assert_eq!(usage.output_tokens, 30);
        assert_eq!(usage.total_tokens, 50);
    }

    #[test]
    fn test_streaming_completion_response_token_usage() {
        let metadata = PartialUsage {
            total_token_count: 150,
            cached_content_token_count: None,
            candidates_token_count: Some(75),
            thoughts_token_count: None,
            prompt_token_count: 75,
        };
        let usage = StreamingCompletionResponse { usage_metadata: metadata }
            .token_usage()
            .expect("token_usage always returns Some");

        assert_eq!(usage.input_tokens, 75);
        assert_eq!(usage.output_tokens, 75);
        assert_eq!(usage.total_tokens, 150);
    }
}