rig/providers/gemini/streaming.rs

use crate::{
    http_client::{
        HttpClientExt,
        sse::{Event, GenericEventSource},
    },
    telemetry::SpanCombinator,
};
use async_stream::stream;
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use tracing::info_span;

use super::completion::{
    CompletionModel, create_request_body,
    gemini_api_types::{ContentCandidate, Part, PartKind},
};
use crate::{
    completion::{CompletionError, CompletionRequest, GetTokenUsage},
    streaming::{self},
};

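/// Token counts reported in the `usageMetadata` field of a streaming
/// response. Only `totalTokenCount` and `promptTokenCount` are always
/// present; the remaining counts may be absent on intermediate chunks.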
#[derive(Debug, Deserialize, Serialize, Default, Clone)]
#[serde(rename_all = "camelCase")]
pub struct PartialUsage {
    pub total_token_count: i32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_content_token_count: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidates_token_count: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thoughts_token_count: Option<i32>,
    pub prompt_token_count: i32,
}

impl GetTokenUsage for PartialUsage {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();

        usage.input_tokens = self.prompt_token_count as u64;
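        // Missing counts default to zero; candidate, thought, and
        // cached-content tokens are folded into a single output figure.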
        usage.output_tokens = (self.cached_content_token_count.unwrap_or_default()
            + self.candidates_token_count.unwrap_or_default()
            + self.thoughts_token_count.unwrap_or_default()) as u64;
        usage.total_tokens = usage.input_tokens + usage.output_tokens;

        Some(usage)
    }
}

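/// A single chunk of a `streamGenerateContent` server-sent event stream.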
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct StreamGenerateContentResponse {
    /// Candidate responses from the model.
    pub candidates: Vec<ContentCandidate>,
    pub model_version: Option<String>,
    pub usage_metadata: Option<PartialUsage>,
}

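/// The final response surfaced to callers once streaming completes, carrying
/// the usage metadata reported by the last chunk.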
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct StreamingCompletionResponse {
    pub usage_metadata: PartialUsage,
}

impl GetTokenUsage for StreamingCompletionResponse {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();
        usage.total_tokens = self.usage_metadata.total_token_count as u64;
        usage.output_tokens = self
            .usage_metadata
            .candidates_token_count
            .map(|x| x as u64)
            .unwrap_or(0);
        usage.input_tokens = self.usage_metadata.prompt_token_count as u64;
        Some(usage)
    }
}

impl<T> CompletionModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
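    /// Opens an SSE connection to the `streamGenerateContent` endpoint and
    /// yields chunks as they arrive: text parts as `Message`, parts flagged
    /// `thought: true` as `Reasoning`, and `functionCall` parts as `ToolCall`.
    /// Once a candidate reports a finish reason, a `FinalResponse` carrying
    /// the usage metadata is yielded and the stream ends.
    ///
    /// A minimal consumption sketch; the caller-side bindings are
    /// illustrative, not part of this module:
    ///
    /// ```ignore
    /// let mut response = model.stream(request).await?;
    /// while let Some(chunk) = response.next().await {
    ///     println!("{chunk:?}");
    /// }
    /// ```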
    pub(crate) async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
    {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "gcp.gemini",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = self.model,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = tracing::field::Empty,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };
        let request = create_request_body(completion_request)?;

        span.record_model_input(&request.contents);

        tracing::debug!(
            "Sending completion request to Gemini API: {}",
            serde_json::to_string_pretty(&request)?
        );
        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post_sse(&format!(
                "/v1beta/models/{}:streamGenerateContent",
                self.model
            ))
            .header("Content-Type", "application/json")
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        let mut event_source = GenericEventSource::new(self.client.http_client.clone(), req);

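        // Drive the SSE connection, translating each Gemini chunk into a rig
        // streaming choice and accumulating model outputs for telemetry.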
        let stream = stream! {
            let mut text_response = String::new();
            let mut model_outputs: Vec<Part> = Vec::new();
            while let Some(event_result) = event_source.next().await {
                match event_result {
                    Ok(Event::Open) => {
                        tracing::debug!("SSE connection opened");
                        continue;
                    }
                    Ok(Event::Message(message)) => {
                        // Skip heartbeat messages or empty data
                        if message.data.trim().is_empty() {
                            continue;
                        }

                        let data = match serde_json::from_str::<StreamGenerateContentResponse>(&message.data) {
                            Ok(d) => d,
                            Err(error) => {
                                tracing::error!(?error, message = message.data, "Failed to parse SSE message");
                                continue;
                            }
                        };

                        // Process the response data
                        let Some(choice) = data.candidates.first() else {
                            tracing::debug!("No content candidate in streaming response");
                            continue;
                        };

                        let Some(content) = choice.content.as_ref() else {
                            tracing::debug!(finish_reason = ?choice.finish_reason, "Streaming candidate missing content");
                            continue;
                        };

                        for part in &content.parts {
                            match part {
                                Part {
                                    part: PartKind::Text(text),
                                    thought: Some(true),
                                    ..
                                } => {
                                    yield Ok(streaming::RawStreamingChoice::Reasoning { reasoning: text.clone(), id: None, signature: None });
                                },
                                Part {
                                    part: PartKind::Text(text),
                                    ..
                                } => {
                                    text_response.push_str(text);
                                    yield Ok(streaming::RawStreamingChoice::Message(text.clone()));
                                },
                                Part {
                                    part: PartKind::FunctionCall(function_call),
                                    ..
                                } => {
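                                    // No call id is surfaced on this path, so
                                    // the function name doubles as the id.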
                                    model_outputs.push(part.clone());
                                    yield Ok(streaming::RawStreamingChoice::ToolCall {
                                        name: function_call.name.clone(),
                                        id: function_call.name.clone(),
                                        arguments: function_call.args.clone(),
                                        call_id: None
                                    });
                                },
                                part => {
                                    tracing::warn!(?part, "Unsupported part type in streaming response");
                                }
                            }
                        }

                        if content.parts.is_empty() {
                            tracing::trace!(reason = ?choice.finish_reason, "Streaming content contained no parts");
                        }

                        // Check if this is the final response
                        if choice.finish_reason.is_some() {
                            if !text_response.is_empty() {
                                model_outputs.push(Part { thought: None, thought_signature: None, part: PartKind::Text(text_response), additional_params: None });
                            }
                            let span = tracing::Span::current();
                            span.record_model_output(&model_outputs);
                            span.record_token_usage(&data.usage_metadata);
                            yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
                                usage_metadata: data.usage_metadata.unwrap_or_default()
                            }));
                            break;
                        }
                    }
                    Err(crate::http_client::Error::StreamEnded) => {
                        break;
                    }
                    Err(error) => {
                        tracing::error!(?error, "SSE error");
                        yield Err(CompletionError::ResponseError(error.to_string()));
                        break;
                    }
                }
            }

            // Ensure event source is closed when stream ends
            event_source.close();
        };

        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
            stream,
        )))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_deserialize_stream_response_with_single_text_part() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "Hello, world!"}
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 5,
                "totalTokenCount": 15
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        assert_eq!(response.candidates.len(), 1);
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 1);

        if let Part {
            part: PartKind::Text(text),
            ..
        } = &content.parts[0]
        {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected text part");
        }
    }

    #[test]
    fn test_deserialize_stream_response_with_multiple_text_parts() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "Hello, "},
                        {"text": "world!"},
                        {"text": " How are you?"}
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 8,
                "totalTokenCount": 18
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        assert_eq!(response.candidates.len(), 1);
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 3);

        // Verify all three text parts are present
        for (i, expected_text) in ["Hello, ", "world!", " How are you?"].iter().enumerate() {
            if let Part {
                part: PartKind::Text(text),
                ..
            } = &content.parts[i]
            {
                assert_eq!(text, expected_text);
            } else {
                panic!("Expected text part at index {}", i);
            }
        }
    }

    #[test]
    fn test_deserialize_stream_response_with_multiple_tool_calls() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {
                            "functionCall": {
                                "name": "get_weather",
                                "args": {"city": "San Francisco"}
                            }
                        },
                        {
                            "functionCall": {
                                "name": "get_temperature",
                                "args": {"location": "New York"}
                            }
                        }
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 50,
                "candidatesTokenCount": 20,
                "totalTokenCount": 70
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 2);

        // Verify first tool call
        if let Part {
            part: PartKind::FunctionCall(call),
            ..
        } = &content.parts[0]
        {
            assert_eq!(call.name, "get_weather");
        } else {
            panic!("Expected function call at index 0");
        }

        // Verify second tool call
        if let Part {
            part: PartKind::FunctionCall(call),
            ..
        } = &content.parts[1]
        {
            assert_eq!(call.name, "get_temperature");
        } else {
            panic!("Expected function call at index 1");
        }
    }

    #[test]
    fn test_deserialize_stream_response_with_mixed_parts() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {
                            "text": "Let me think about this...",
                            "thought": true
                        },
                        {
                            "text": "Here's my response: "
                        },
                        {
                            "functionCall": {
                                "name": "search",
                                "args": {"query": "rust async"}
                            }
                        },
                        {
                            "text": "I found the answer!"
                        }
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 100,
                "candidatesTokenCount": 50,
                "thoughtsTokenCount": 15,
                "totalTokenCount": 165
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        let parts = &content.parts;
        assert_eq!(parts.len(), 4);

        // Verify reasoning (thought) part
        if let Part {
            part: PartKind::Text(text),
            thought: Some(true),
            ..
        } = &parts[0]
        {
            assert_eq!(text, "Let me think about this...");
        } else {
            panic!("Expected thought part at index 0");
        }

        // Verify regular text
        if let Part {
            part: PartKind::Text(text),
            thought,
            ..
        } = &parts[1]
        {
            assert_eq!(text, "Here's my response: ");
            assert!(thought.is_none() || thought == &Some(false));
        } else {
            panic!("Expected text part at index 1");
        }

        // Verify tool call
        if let Part {
            part: PartKind::FunctionCall(call),
            ..
        } = &parts[2]
        {
            assert_eq!(call.name, "search");
        } else {
            panic!("Expected function call at index 2");
        }

        // Verify final text
        if let Part {
            part: PartKind::Text(text),
            ..
        } = &parts[3]
        {
            assert_eq!(text, "I found the answer!");
        } else {
            panic!("Expected text part at index 3");
        }
    }

    #[test]
    fn test_deserialize_stream_response_with_empty_parts() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 0,
                "totalTokenCount": 10
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 0);
    }

    #[test]
    fn test_partial_usage_token_calculation() {
        let usage = PartialUsage {
            total_token_count: 100,
            cached_content_token_count: Some(20),
            candidates_token_count: Some(30),
            thoughts_token_count: Some(10),
            prompt_token_count: 40,
        };

        let token_usage = usage.token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 40);
        assert_eq!(token_usage.output_tokens, 60); // 20 + 30 + 10
        assert_eq!(token_usage.total_tokens, 100);
    }

    #[test]
    fn test_partial_usage_with_missing_counts() {
        let usage = PartialUsage {
            total_token_count: 50,
            cached_content_token_count: None,
            candidates_token_count: Some(30),
            thoughts_token_count: None,
            prompt_token_count: 20,
        };

        let token_usage = usage.token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 20);
        assert_eq!(token_usage.output_tokens, 30); // Only candidates_token_count
        assert_eq!(token_usage.total_tokens, 50);
    }
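
    // Illustrative round-trip check (not in the original suite): the
    // `skip_serializing_if` attributes should omit absent counts, and field
    // names should serialize as camelCase.
    #[test]
    fn test_partial_usage_serialization_skips_none_counts() {
        let usage = PartialUsage {
            total_token_count: 42,
            cached_content_token_count: None,
            candidates_token_count: Some(30),
            thoughts_token_count: None,
            prompt_token_count: 12,
        };

        let value = serde_json::to_value(&usage).unwrap();
        assert_eq!(value["totalTokenCount"], 42);
        assert_eq!(value["candidatesTokenCount"], 30);
        assert_eq!(value["promptTokenCount"], 12);
        assert!(value.get("cachedContentTokenCount").is_none());
        assert!(value.get("thoughtsTokenCount").is_none());
    }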

    #[test]
    fn test_streaming_completion_response_token_usage() {
        let response = StreamingCompletionResponse {
            usage_metadata: PartialUsage {
                total_token_count: 150,
                cached_content_token_count: None,
                candidates_token_count: Some(75),
                thoughts_token_count: None,
                prompt_token_count: 75,
            },
        };

        let token_usage = response.token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 75);
        assert_eq!(token_usage.output_tokens, 75);
        assert_eq!(token_usage.total_tokens, 150);
549}