// rig_core/providers/gemini/interactions_api/
streaming.rs1use async_stream::stream;
2use futures::{Stream, StreamExt};
3use serde::{Deserialize, Serialize};
4use std::pin::Pin;
5use tracing::{Level, enabled, info_span};
6use tracing_futures::Instrument;
7
8use super::InteractionsCompletionModel;
9use super::create_request_body;
10use super::interactions_api_types::{
11 Content, ContentDelta, FunctionCallContent, FunctionCallDelta, Interaction,
12 InteractionSseEvent, InteractionUsage, TextContent, TextDelta, ThoughtSummaryContent,
13 ThoughtSummaryDelta,
14};
15use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
16use crate::http_client::HttpClientExt;
17use crate::http_client::Request;
18use crate::http_client::sse::{Event, GenericEventSource};
19use crate::streaming;
20use crate::telemetry::SpanCombinator;
21use serde_json::{Map, Value};
22
23#[derive(Debug, Serialize, Deserialize, Default, Clone)]
25pub struct StreamingCompletionResponse {
26 pub usage: Option<InteractionUsage>,
27 pub interaction: Option<Interaction>,
28 #[serde(skip_serializing_if = "Option::is_none")]
32 pub model_version: Option<String>,
33}
34
35#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
36pub type InteractionEventStream =
37 Pin<Box<dyn Stream<Item = Result<InteractionSseEvent, CompletionError>> + Send>>;
38
39#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
40pub type InteractionEventStream =
41 Pin<Box<dyn Stream<Item = Result<InteractionSseEvent, CompletionError>>>>;
42
43impl GetTokenUsage for StreamingCompletionResponse {
44 fn token_usage(&self) -> Option<crate::completion::Usage> {
45 self.usage.as_ref().and_then(|usage| usage.token_usage())
46 }
47}
48
49impl<T> InteractionsCompletionModel<T>
50where
51 T: HttpClientExt + Clone + Default + std::fmt::Debug + 'static,
52{
53 pub(crate) async fn stream(
54 &self,
55 completion_request: CompletionRequest,
56 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
57 {
58 let span = if tracing::Span::current().is_disabled() {
59 info_span!(
60 target: "rig::completions",
61 "interactions_streaming",
62 gen_ai.operation.name = "interactions_streaming",
63 gen_ai.provider.name = "gcp.gemini",
64 gen_ai.request.model = self.model,
65 gen_ai.system_instructions = &completion_request.preamble,
66 gen_ai.response.id = tracing::field::Empty,
67 gen_ai.response.model = tracing::field::Empty,
68 gen_ai.usage.output_tokens = tracing::field::Empty,
69 gen_ai.usage.input_tokens = tracing::field::Empty,
70 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
71 gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
72 gen_ai.usage.tool_use_prompt_tokens = tracing::field::Empty,
73 gen_ai.usage.reasoning_tokens = tracing::field::Empty,
74 )
75 } else {
76 tracing::Span::current()
77 };
78
79 let request = create_request_body(self.model.clone(), completion_request, Some(true))?;
80
81 if enabled!(Level::TRACE) {
82 tracing::trace!(
83 target: "rig::streaming",
84 "Gemini interactions streaming request: {}",
85 serde_json::to_string_pretty(&request)?
86 );
87 }
88
89 let body = serde_json::to_vec(&request)?;
90 let req = self
91 .client
92 .post_sse("/v1beta/interactions")?
93 .header("Content-Type", "application/json")
94 .body(body)
95 .map_err(|e| CompletionError::HttpError(e.into()))?;
96
97 let mut event_source = GenericEventSource::new(self.client.clone(), req);
98
99 let stream = stream! {
100 let mut final_interaction: Option<Interaction> = None;
101 let mut final_usage: Option<InteractionUsage> = None;
102
103 while let Some(event_result) = event_source.next().await {
104 match event_result {
105 Ok(Event::Open) => {
106 tracing::debug!("SSE connection opened");
107 continue;
108 }
109 Ok(Event::Message(message)) => {
110 if message.data.trim().is_empty() {
111 continue;
112 }
113
114 let data = match serde_json::from_str::<InteractionSseEvent>(&message.data)
115 {
116 Ok(data) => data,
117 Err(err) => {
118 tracing::debug!(
119 "Failed to deserialize interactions SSE event: {err}"
120 );
121 continue;
122 }
123 };
124
125 match data {
126 InteractionSseEvent::ContentDelta { delta, .. } => {
127 if let Some(choice) = content_delta_to_choice(delta) {
128 yield Ok(choice);
129 }
130 }
131 InteractionSseEvent::ContentStart { content, .. } => {
132 if let Some(choice) = content_start_to_choice(content) {
133 yield Ok(choice);
134 }
135 }
136 InteractionSseEvent::InteractionComplete { interaction, .. } => {
137 let span = tracing::Span::current();
138 span.record("gen_ai.response.id", &interaction.id);
139 if let Some(model) = interaction.model.clone() {
140 span.record("gen_ai.response.model", model);
141 }
142
143 if let Some(usage) = interaction.usage.clone() {
144 span.record_token_usage(&usage);
145 final_usage = Some(usage);
146 }
147 final_interaction = Some(interaction);
148 }
149 InteractionSseEvent::Error { error, .. } => {
150 yield Err(CompletionError::ProviderError(error.message));
151 break;
152 }
153 _ => continue,
154 }
155 }
156 Err(crate::http_client::Error::StreamEnded) => {
157 break;
158 }
159 Err(error) => {
160 tracing::error!(?error, "SSE error");
161 yield Err(CompletionError::ProviderError(error.to_string()));
162 break;
163 }
164 }
165 }
166
167 event_source.close();
168
169 let model_version = final_interaction.as_ref().and_then(|i| i.model.clone());
170 yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
171 usage: final_usage.or_else(|| final_interaction.as_ref().and_then(|i| i.usage.clone())),
172 interaction: final_interaction,
173 model_version,
174 }));
175 }
176 .instrument(span);
177
178 Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
179 stream,
180 )))
181 }
182}
183
184pub(crate) fn stream_interaction_events<T>(
185 client: super::InteractionsClient<T>,
186 request: Request<Vec<u8>>,
187) -> InteractionEventStream
188where
189 T: HttpClientExt + Clone + Default + std::fmt::Debug + 'static,
190{
191 let mut event_source = GenericEventSource::new(client.clone(), request);
192
193 let stream = stream! {
194 while let Some(event_result) = event_source.next().await {
195 match event_result {
196 Ok(Event::Open) => continue,
197 Ok(Event::Message(message)) => {
198 if message.data.trim().is_empty() {
199 continue;
200 }
201
202 let data = serde_json::from_str::<InteractionSseEvent>(&message.data);
203 let Ok(data) = data else {
204 let Err(err) = data else {
205 continue;
206 };
207 tracing::debug!("Failed to deserialize interactions SSE event: {err}");
208 continue;
209 };
210
211 yield Ok(data);
212 }
213 Err(crate::http_client::Error::StreamEnded) => break,
214 Err(error) => {
215 tracing::error!(?error, "SSE error");
216 yield Err(CompletionError::ProviderError(error.to_string()));
217 break;
218 }
219 }
220 }
221
222 event_source.close();
223 };
224
225 Box::pin(stream)
226}
227
228fn content_start_to_choice(
229 content: Content,
230) -> Option<streaming::RawStreamingChoice<StreamingCompletionResponse>> {
231 match content {
232 Content::Text(TextContent { text, .. }) => {
233 if text.is_empty() {
234 None
235 } else {
236 Some(streaming::RawStreamingChoice::Message(text))
237 }
238 }
239 Content::FunctionCall(FunctionCallContent {
240 name,
241 arguments,
242 id,
243 }) => {
244 let name = name?;
245 let call_id = id.unwrap_or_else(|| name.clone());
246 Some(streaming::RawStreamingChoice::ToolCall(
247 streaming::RawStreamingToolCall::new(
248 name.clone(),
249 name,
250 arguments.unwrap_or(Value::Object(Map::new())),
251 )
252 .with_call_id(call_id),
253 ))
254 }
255 _ => None,
256 }
257}
258
259fn content_delta_to_choice(
260 delta: ContentDelta,
261) -> Option<streaming::RawStreamingChoice<StreamingCompletionResponse>> {
262 match delta {
263 ContentDelta::Text(TextDelta {
264 text: Some(text), ..
265 }) => Some(streaming::RawStreamingChoice::Message(text)),
266 ContentDelta::FunctionCall(FunctionCallDelta {
267 name,
268 arguments,
269 id,
270 }) => {
271 let name = name?;
272 let call_id = id.unwrap_or_else(|| name.clone());
273 Some(streaming::RawStreamingChoice::ToolCall(
274 streaming::RawStreamingToolCall::new(
275 name.clone(),
276 name,
277 arguments.unwrap_or(Value::Object(Map::new())),
278 )
279 .with_call_id(call_id),
280 ))
281 }
282 ContentDelta::ThoughtSummary(ThoughtSummaryDelta { content }) => {
283 let text = match content {
284 ThoughtSummaryContent::Text(text) => text.text,
285 _ => return None,
286 };
287 Some(streaming::RawStreamingChoice::ReasoningDelta {
288 id: None,
289 reasoning: text,
290 })
291 }
292 _ => None,
293 }
294}
295
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    const MODEL_VERSION: &str = "gemini-2.5-pro-preview-05-06";

    /// Parses a raw SSE payload and extracts the delta of a `ContentDelta` event,
    /// panicking if the payload is some other event kind.
    fn delta_from(payload: Value) -> ContentDelta {
        let event: InteractionSseEvent = serde_json::from_value(payload).unwrap();
        match event {
            InteractionSseEvent::ContentDelta { delta, .. } => delta,
            _ => panic!("expected content delta"),
        }
    }

    #[test]
    fn test_streaming_completion_response_has_model_version() {
        let original = StreamingCompletionResponse {
            model_version: Some(MODEL_VERSION.to_string()),
            ..Default::default()
        };
        assert_eq!(original.model_version.as_deref(), Some(MODEL_VERSION));

        // The field must survive a JSON round trip.
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: StreamingCompletionResponse = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.model_version.as_deref(), Some(MODEL_VERSION));
    }

    #[test]
    fn test_content_delta_text_event() {
        let delta = delta_from(json!({
            "event_type": "content.delta",
            "index": 0,
            "delta": {
                "type": "text",
                "text": "Hello"
            }
        }));

        match content_delta_to_choice(delta).expect("choice should exist") {
            crate::streaming::RawStreamingChoice::Message(text) => assert_eq!(text, "Hello"),
            other => panic!("unexpected choice: {other:?}"),
        }
    }

    #[test]
    fn test_content_delta_function_call_event() {
        let delta = delta_from(json!({
            "event_type": "content.delta",
            "index": 0,
            "delta": {
                "type": "function_call",
                "name": "get_weather",
                "arguments": {"location": "Paris"},
                "id": "call-1"
            }
        }));

        match content_delta_to_choice(delta).expect("choice should exist") {
            crate::streaming::RawStreamingChoice::ToolCall(call) => {
                assert_eq!(call.name, "get_weather");
                assert_eq!(call.call_id.as_deref(), Some("call-1"));
            }
            other => panic!("unexpected choice: {other:?}"),
        }
    }
}