//! Streaming completion support for the Gemini provider
//! (`rig/providers/gemini/streaming.rs`).

1use crate::telemetry::SpanCombinator;
2use async_stream::stream;
3use futures::StreamExt;
4use reqwest_eventsource::{Event, RequestBuilderExt};
5use serde::{Deserialize, Serialize};
6use tracing::info_span;
7
8use super::completion::{
9    CompletionModel, create_request_body,
10    gemini_api_types::{ContentCandidate, Part, PartKind},
11};
12use crate::{
13    completion::{CompletionError, CompletionRequest, GetTokenUsage},
14    streaming::{self},
15};
16
/// Token usage metadata as reported incrementally by Gemini's
/// `streamGenerateContent` SSE chunks (the `usageMetadata` object).
///
/// Field names map to the camelCase wire format via serde. The per-category
/// counts are optional because early/partial chunks may omit them; the
/// prompt and running-total counts are always present.
#[derive(Debug, Deserialize, Serialize, Default, Clone)]
#[serde(rename_all = "camelCase")]
pub struct PartialUsage {
    // Running total of tokens for the request so far.
    pub total_token_count: i32,
    // Tokens served from cached content — per the Gemini API this is the
    // cached portion of the prompt; presumably already included in
    // `prompt_token_count` (TODO confirm against API docs).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_content_token_count: Option<i32>,
    // Tokens in the generated candidate responses.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidates_token_count: Option<i32>,
    // Internal "thinking" tokens produced by reasoning-capable models.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thoughts_token_count: Option<i32>,
    // Tokens in the input prompt.
    pub prompt_token_count: i32,
}
29
30impl GetTokenUsage for PartialUsage {
31    fn token_usage(&self) -> Option<crate::completion::Usage> {
32        let mut usage = crate::completion::Usage::new();
33
34        usage.input_tokens = self.prompt_token_count as u64;
35        usage.output_tokens = (self.cached_content_token_count.unwrap_or_default()
36            + self.candidates_token_count.unwrap_or_default()
37            + self.thoughts_token_count.unwrap_or_default()) as u64;
38        usage.total_tokens = usage.input_tokens + usage.output_tokens;
39
40        Some(usage)
41    }
42}
43
/// One deserialized SSE chunk from Gemini's `streamGenerateContent`
/// endpoint (camelCase wire format).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct StreamGenerateContentResponse {
    /// Candidate responses from the model.
    pub candidates: Vec<ContentCandidate>,
    // Model version string; not present on every chunk.
    pub model_version: Option<String>,
    // Usage counts; typically populated on later/final chunks — TODO confirm
    // whether intermediate chunks ever carry it.
    pub usage_metadata: Option<PartialUsage>,
}
52
/// Final aggregated response surfaced to callers when a Gemini stream
/// finishes; carries only the token-usage metadata from the last chunk.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct StreamingCompletionResponse {
    pub usage_metadata: PartialUsage,
}
57
58impl GetTokenUsage for StreamingCompletionResponse {
59    fn token_usage(&self) -> Option<crate::completion::Usage> {
60        let mut usage = crate::completion::Usage::new();
61        usage.total_tokens = self.usage_metadata.total_token_count as u64;
62        usage.output_tokens = self
63            .usage_metadata
64            .candidates_token_count
65            .map(|x| x as u64)
66            .unwrap_or(0);
67        usage.input_tokens = self.usage_metadata.prompt_token_count as u64;
68        Some(usage)
69    }
70}
71
impl CompletionModel<reqwest::Client> {
    /// Sends `completion_request` to Gemini's `streamGenerateContent`
    /// endpoint via server-sent events and adapts the event stream into
    /// rig's provider-agnostic streaming response.
    ///
    /// # Errors
    /// Returns `CompletionError` if the request body cannot be built or
    /// serialized up front; transport/SSE failures that occur mid-stream
    /// are yielded *inside* the stream as
    /// `CompletionError::ResponseError`.
    pub(crate) async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
    {
        // Reuse an ambient tracing span if the caller already opened one;
        // otherwise start a fresh GenAI-convention span for this call.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "gcp.gemini",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = self.model,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = tracing::field::Empty,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };
        let request = create_request_body(completion_request)?;

        span.record_model_input(&request.contents);

        tracing::debug!(
            "Sending completion request to Gemini API {}",
            serde_json::to_string_pretty(&request)?
        );

        // Build the request with proper headers for SSE
        let mut event_source = self
            .client
            .post_sse(&format!(
                "/v1beta/models/{}:streamGenerateContent",
                self.model
            ))
            .json(&request)
            .eventsource()
            // `eventsource()` only fails if the request body is a non-clonable
            // stream; ours is JSON, so cloning cannot fail.
            .expect("Cloning request must always succeed");

        let stream = stream! {
            // Accumulates the full text reply so it can be recorded as a
            // single model output when the stream finishes.
            let mut text_response = String::new();
            // Parts (tool calls + final text) recorded on the span at the end.
            let mut model_outputs: Vec<Part> = Vec::new();
            while let Some(event_result) = event_source.next().await {
                match event_result {
                    Ok(Event::Open) => {
                        tracing::trace!("SSE connection opened");
                        continue;
                    }
                    Ok(Event::Message(message)) => {
                        // Skip heartbeat messages or empty data
                        if message.data.trim().is_empty() {
                            continue;
                        }

                        // A malformed chunk is logged and skipped rather than
                        // aborting the whole stream.
                        let data = match serde_json::from_str::<StreamGenerateContentResponse>(&message.data) {
                            Ok(d) => d,
                            Err(error) => {
                                tracing::error!(?error, message = message.data, "Failed to parse SSE message");
                                continue;
                            }
                        };

                        // Process the response data
                        let Some(choice) = data.candidates.first() else {
                            tracing::debug!("There is no content candidate");
                            continue;
                        };

                        // NOTE(review): only the first part of the first
                        // candidate is inspected per chunk — assumes Gemini
                        // never packs multiple parts into one streaming
                        // chunk; TODO confirm against the API behavior.
                        match choice.content.parts.first() {
                            // `thought: Some(true)` marks internal reasoning
                            // text; must be matched before the plain-text arm.
                            Some(Part {
                                part: PartKind::Text(text),
                                thought: Some(true),
                                ..
                            }) => {
                                yield Ok(streaming::RawStreamingChoice::Reasoning { reasoning: text.clone(), id: None });
                            },
                            Some(Part {
                                part: PartKind::Text(text),
                                ..
                            }) => {
                                text_response += text;
                                yield Ok(streaming::RawStreamingChoice::Message(text.clone()));
                            },
                            Some(Part {
                                part: PartKind::FunctionCall(function_call),
                                ..
                            }) => {
                                // Safe: this arm is only reached when
                                // `parts.first()` is `Some`.
                                model_outputs.push(choice.content.parts.first().cloned().expect("This should never fail"));
                                // Gemini supplies no call id here, so the
                                // function name doubles as the id.
                                yield Ok(streaming::RawStreamingChoice::ToolCall {
                                    name: function_call.name.clone(),
                                    id: function_call.name.clone(),
                                    arguments: function_call.args.clone(),
                                    call_id: None
                                });
                            },
                            Some(part) => {
                                tracing::warn!(?part, "Unsupported response type with streaming");
                            }
                            None => tracing::trace!(reason = ?choice.finish_reason, "There is no part in the streaming content"),
                        }

                        // Check if this is the final response
                        if choice.finish_reason.is_some() {
                            // Record the accumulated text as one output part.
                            if !text_response.is_empty() {
                                model_outputs.push(Part { thought: None, thought_signature: None, part: PartKind::Text(text_response), additional_params: None });
                            }
                            let span = tracing::Span::current();
                            span.record_model_output(&model_outputs);
                            span.record_token_usage(&data.usage_metadata);
                            yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
                                usage_metadata: data.usage_metadata.unwrap_or_default()
                            }));
                            break;
                        }
                    }
                    // Normal end-of-stream: terminate quietly.
                    Err(reqwest_eventsource::Error::StreamEnded) => {
                        break;
                    }
                    // Any other transport error is surfaced to the consumer
                    // and ends the stream.
                    Err(error) => {
                        tracing::error!(?error, "SSE error");
                        yield Err(CompletionError::ResponseError(error.to_string()));
                        break;
                    }
                }
            }

            // Ensure event source is closed when stream ends
            event_source.close();
        };

        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
            stream,
        )))
    }
}