rig/providers/gemini/streaming.rs

use crate::{
    http_client::{
        HttpClientExt,
        sse::{Event, GenericEventSource},
    },
    telemetry::SpanCombinator,
};
use async_stream::stream;
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use tracing::info_span;

use super::completion::{
    CompletionModel, create_request_body,
    gemini_api_types::{ContentCandidate, Part, PartKind},
};
use crate::{
    completion::{CompletionError, CompletionRequest, GetTokenUsage},
    streaming::{self},
};

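/// Token counts reported in Gemini's `usageMetadata` field. Optional counts
/// are omitted from serialization when the API does not report them.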
#[derive(Debug, Deserialize, Serialize, Default, Clone)]
#[serde(rename_all = "camelCase")]
pub struct PartialUsage {
    pub total_token_count: i32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_content_token_count: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidates_token_count: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thoughts_token_count: Option<i32>,
    pub prompt_token_count: i32,
}

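// Note: this aggregation counts cached-content tokens toward output alongside
// candidate and thought tokens, and recomputes `total_tokens` from the parts
// rather than taking the API-reported `total_token_count`.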
impl GetTokenUsage for PartialUsage {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();

        usage.input_tokens = self.prompt_token_count as u64;
        usage.output_tokens = (self.cached_content_token_count.unwrap_or_default()
            + self.candidates_token_count.unwrap_or_default()
            + self.thoughts_token_count.unwrap_or_default()) as u64;
        usage.total_tokens = usage.input_tokens + usage.output_tokens;

        Some(usage)
    }
}

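/// A single chunk of a `streamGenerateContent` SSE response.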
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct StreamGenerateContentResponse {
    /// Candidate responses from the model.
    pub candidates: Vec<ContentCandidate>,
    pub model_version: Option<String>,
    pub usage_metadata: Option<PartialUsage>,
}

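/// Final-response payload yielded to callers when the stream finishes,
/// carrying the usage metadata from the last chunk.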
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct StreamingCompletionResponse {
    pub usage_metadata: PartialUsage,
}

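// Unlike `PartialUsage::token_usage`, this impl trusts the API-reported
// `total_token_count` instead of recomputing it from the per-kind counts.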
impl GetTokenUsage for StreamingCompletionResponse {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();
        usage.total_tokens = self.usage_metadata.total_token_count as u64;
        usage.output_tokens = self
            .usage_metadata
            .candidates_token_count
            .map(|x| x as u64)
            .unwrap_or(0);
        usage.input_tokens = self.usage_metadata.prompt_token_count as u64;
        Some(usage)
    }
}

impl<T> CompletionModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
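    /// Opens an SSE connection to `streamGenerateContent` and adapts the
    /// incoming chunks into a [`streaming::StreamingCompletionResponse`],
    /// yielding text, reasoning, and tool-call choices as they arrive.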
    pub(crate) async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
    {
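        // Reuse the caller's active span if there is one; otherwise open a new
        // `chat_streaming` span with GenAI semantic-convention fields.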
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "gcp.gemini",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = self.model,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = tracing::field::Empty,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };
        let request = create_request_body(completion_request)?;

        span.record_model_input(&request.contents);

        tracing::debug!(
            "Sending completion request to Gemini API: {}",
            serde_json::to_string_pretty(&request)?
        );
        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post_sse(&format!(
                "/v1beta/models/{}:streamGenerateContent",
                self.model
            ))
            .header("Content-Type", "application/json")
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        let mut event_source = GenericEventSource::new(self.client.http_client.clone(), req);

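        // Accumulate text and tool-call parts as chunks arrive so the full
        // model output and token usage can be recorded on the span at the end.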
        let stream = stream! {
            let mut text_response = String::new();
            let mut model_outputs: Vec<Part> = Vec::new();
            while let Some(event_result) = event_source.next().await {
                match event_result {
                    Ok(Event::Open) => {
                        tracing::debug!("SSE connection opened");
                        continue;
                    }
                    Ok(Event::Message(message)) => {
                        // Skip heartbeat messages or empty data
                        if message.data.trim().is_empty() {
                            continue;
                        }

                        let data = match serde_json::from_str::<StreamGenerateContentResponse>(&message.data) {
                            Ok(d) => d,
                            Err(error) => {
                                tracing::error!(?error, message = message.data, "Failed to parse SSE message");
                                continue;
                            }
                        };

                        // Process the response data
                        let Some(choice) = data.candidates.first() else {
                            tracing::debug!("No content candidate in this chunk");
                            continue;
                        };

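                        // Map each Gemini part onto the matching streaming
                        // choice: thoughts become Reasoning, plain text becomes
                        // Message, and function calls become ToolCall.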
                        for part in &choice.content.parts {
                            match part {
                                Part {
                                    part: PartKind::Text(text),
                                    thought: Some(true),
                                    ..
                                } => {
                                    yield Ok(streaming::RawStreamingChoice::Reasoning { reasoning: text.clone(), id: None, signature: None });
                                },
                                Part {
                                    part: PartKind::Text(text),
                                    ..
                                } => {
                                    text_response += text;
                                    yield Ok(streaming::RawStreamingChoice::Message(text.clone()));
                                },
                                Part {
                                    part: PartKind::FunctionCall(function_call),
                                    ..
                                } => {
                                    model_outputs.push(part.clone());
                                    // No separate call id is parsed here, so the
                                    // function name doubles as the id.
                                    yield Ok(streaming::RawStreamingChoice::ToolCall {
                                        name: function_call.name.clone(),
                                        id: function_call.name.clone(),
                                        arguments: function_call.args.clone(),
                                        call_id: None
                                    });
                                },
                                part => {
                                    tracing::warn!(?part, "Unsupported part kind in streaming response");
                                }
                            }
                        }

                        if choice.content.parts.is_empty() {
                            tracing::trace!(reason = ?choice.finish_reason, "No parts in this streaming chunk");
                        }

                        // Check if this is the final response
                        if choice.finish_reason.is_some() {
                            if !text_response.is_empty() {
                                model_outputs.push(Part { thought: None, thought_signature: None, part: PartKind::Text(text_response), additional_params: None });
                            }
                            let span = tracing::Span::current();
                            span.record_model_output(&model_outputs);
                            span.record_token_usage(&data.usage_metadata);
                            yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
                                usage_metadata: data.usage_metadata.unwrap_or_default()
                            }));
                            break;
                        }
                    }
                    Err(crate::http_client::Error::StreamEnded) => {
                        break;
                    }
                    Err(error) => {
                        tracing::error!(?error, "SSE error");
                        yield Err(CompletionError::ResponseError(error.to_string()));
                        break;
                    }
                }
            }

            // Ensure event source is closed when stream ends
            event_source.close();
        };

        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
            stream,
        )))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_deserialize_stream_response_with_single_text_part() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "Hello, world!"}
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 5,
                "totalTokenCount": 15
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        assert_eq!(response.candidates.len(), 1);
        assert_eq!(response.candidates[0].content.parts.len(), 1);

        if let Part {
            part: PartKind::Text(text),
            ..
        } = &response.candidates[0].content.parts[0]
        {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected text part");
        }
    }

    #[test]
    fn test_deserialize_stream_response_with_multiple_text_parts() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "Hello, "},
                        {"text": "world!"},
                        {"text": " How are you?"}
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 8,
                "totalTokenCount": 18
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        assert_eq!(response.candidates.len(), 1);
        assert_eq!(response.candidates[0].content.parts.len(), 3);

        // Verify all three text parts are present
        for (i, expected_text) in ["Hello, ", "world!", " How are you?"].iter().enumerate() {
            if let Part {
                part: PartKind::Text(text),
                ..
            } = &response.candidates[0].content.parts[i]
            {
                assert_eq!(text, expected_text);
            } else {
                panic!("Expected text part at index {}", i);
            }
        }
    }

    #[test]
    fn test_deserialize_stream_response_with_multiple_tool_calls() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {
                            "functionCall": {
                                "name": "get_weather",
                                "args": {"city": "San Francisco"}
                            }
                        },
                        {
                            "functionCall": {
                                "name": "get_temperature",
                                "args": {"location": "New York"}
                            }
                        }
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 50,
                "candidatesTokenCount": 20,
                "totalTokenCount": 70
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        assert_eq!(response.candidates[0].content.parts.len(), 2);

        // Verify first tool call
        if let Part {
            part: PartKind::FunctionCall(call),
            ..
        } = &response.candidates[0].content.parts[0]
        {
            assert_eq!(call.name, "get_weather");
        } else {
            panic!("Expected function call at index 0");
        }

        // Verify second tool call
        if let Part {
            part: PartKind::FunctionCall(call),
            ..
        } = &response.candidates[0].content.parts[1]
        {
            assert_eq!(call.name, "get_temperature");
        } else {
            panic!("Expected function call at index 1");
        }
    }

    #[test]
    fn test_deserialize_stream_response_with_mixed_parts() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {
                            "text": "Let me think about this...",
                            "thought": true
                        },
                        {
                            "text": "Here's my response: "
                        },
                        {
                            "functionCall": {
                                "name": "search",
                                "args": {"query": "rust async"}
                            }
                        },
                        {
                            "text": "I found the answer!"
                        }
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 100,
                "candidatesTokenCount": 50,
                "thoughtsTokenCount": 15,
                "totalTokenCount": 165
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        let parts = &response.candidates[0].content.parts;
        assert_eq!(parts.len(), 4);

        // Verify reasoning (thought) part
        if let Part {
            part: PartKind::Text(text),
            thought: Some(true),
            ..
        } = &parts[0]
        {
            assert_eq!(text, "Let me think about this...");
        } else {
            panic!("Expected thought part at index 0");
        }

        // Verify regular text
        if let Part {
            part: PartKind::Text(text),
            thought,
            ..
        } = &parts[1]
        {
            assert_eq!(text, "Here's my response: ");
            assert!(thought.is_none() || thought == &Some(false));
        } else {
            panic!("Expected text part at index 1");
        }

        // Verify tool call
        if let Part {
            part: PartKind::FunctionCall(call),
            ..
        } = &parts[2]
        {
            assert_eq!(call.name, "search");
        } else {
            panic!("Expected function call at index 2");
        }

        // Verify final text
        if let Part {
            part: PartKind::Text(text),
            ..
        } = &parts[3]
        {
            assert_eq!(text, "I found the answer!");
        } else {
            panic!("Expected text part at index 3");
        }
    }

    #[test]
    fn test_deserialize_stream_response_with_empty_parts() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 0,
                "totalTokenCount": 10
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        assert_eq!(response.candidates[0].content.parts.len(), 0);
    }

    #[test]
    fn test_partial_usage_token_calculation() {
        let usage = PartialUsage {
            total_token_count: 100,
            cached_content_token_count: Some(20),
            candidates_token_count: Some(30),
            thoughts_token_count: Some(10),
            prompt_token_count: 40,
        };

        let token_usage = usage.token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 40);
        assert_eq!(token_usage.output_tokens, 60); // 20 + 30 + 10
        assert_eq!(token_usage.total_tokens, 100);
    }

    #[test]
    fn test_partial_usage_with_missing_counts() {
        let usage = PartialUsage {
            total_token_count: 50,
            cached_content_token_count: None,
            candidates_token_count: Some(30),
            thoughts_token_count: None,
            prompt_token_count: 20,
        };

        let token_usage = usage.token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 20);
        assert_eq!(token_usage.output_tokens, 30); // Only candidates_token_count
        assert_eq!(token_usage.total_tokens, 50);
    }

    #[test]
    fn test_streaming_completion_response_token_usage() {
        let response = StreamingCompletionResponse {
            usage_metadata: PartialUsage {
                total_token_count: 150,
                cached_content_token_count: None,
                candidates_token_count: Some(75),
                thoughts_token_count: None,
                prompt_token_count: 75,
            },
        };

        let token_usage = response.token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 75);
        assert_eq!(token_usage.output_tokens, 75);
        assert_eq!(token_usage.total_tokens, 150);
    }
}