rig_core/providers/gemini/interactions_api/
streaming.rs1use async_stream::stream;
2use futures::{Stream, StreamExt};
3use serde::{Deserialize, Serialize};
4use std::pin::Pin;
5use tracing::{Level, enabled, info_span};
6use tracing_futures::Instrument;
7
8use super::InteractionsCompletionModel;
9use super::create_request_body;
10use super::interactions_api_types::{
11 Content, ContentDelta, FunctionCallContent, FunctionCallDelta, Interaction,
12 InteractionSseEvent, InteractionUsage, Step, TextDelta, ThoughtSummaryContent,
13 ThoughtSummaryDelta,
14};
15use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
16use crate::http_client::HttpClientExt;
17use crate::http_client::Request;
18use crate::http_client::sse::{Event, GenericEventSource};
19use crate::streaming;
20use crate::telemetry::SpanCombinator;
21use serde_json::{Map, Value};
22
23#[derive(Debug, Serialize, Deserialize, Default, Clone)]
25pub struct StreamingCompletionResponse {
26 pub usage: Option<InteractionUsage>,
27 pub interaction: Option<Interaction>,
28 #[serde(skip_serializing_if = "Option::is_none")]
32 pub model_version: Option<String>,
33}
34
35#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
36pub type InteractionEventStream =
37 Pin<Box<dyn Stream<Item = Result<InteractionSseEvent, CompletionError>> + Send>>;
38
39#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
40pub type InteractionEventStream =
41 Pin<Box<dyn Stream<Item = Result<InteractionSseEvent, CompletionError>>>>;
42
43impl GetTokenUsage for StreamingCompletionResponse {
44 fn token_usage(&self) -> crate::completion::Usage {
45 self.usage
46 .as_ref()
47 .map(|usage| usage.token_usage())
48 .unwrap_or_default()
49 }
50}
51
52impl<T> InteractionsCompletionModel<T>
53where
54 T: HttpClientExt + Clone + Default + std::fmt::Debug + 'static,
55{
56 pub(crate) async fn stream(
57 &self,
58 completion_request: CompletionRequest,
59 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
60 {
61 let span = if tracing::Span::current().is_disabled() {
62 info_span!(
63 target: "rig::completions",
64 "interactions_streaming",
65 gen_ai.operation.name = "interactions_streaming",
66 gen_ai.provider.name = "gcp.gemini",
67 gen_ai.request.model = self.model,
68 gen_ai.system_instructions = &completion_request.preamble,
69 gen_ai.response.id = tracing::field::Empty,
70 gen_ai.response.model = tracing::field::Empty,
71 gen_ai.usage.output_tokens = tracing::field::Empty,
72 gen_ai.usage.input_tokens = tracing::field::Empty,
73 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
74 gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
75 gen_ai.usage.tool_use_prompt_tokens = tracing::field::Empty,
76 gen_ai.usage.reasoning_tokens = tracing::field::Empty,
77 )
78 } else {
79 tracing::Span::current()
80 };
81
82 let request = create_request_body(self.model.clone(), completion_request, Some(true))?;
83
84 if enabled!(Level::TRACE) {
85 tracing::trace!(
86 target: "rig::streaming",
87 "Gemini interactions streaming request: {}",
88 serde_json::to_string_pretty(&request)?
89 );
90 }
91
92 let body = serde_json::to_vec(&request)?;
93 let req = self
94 .client
95 .post_sse("/v1beta/interactions")?
96 .header("Content-Type", "application/json")
97 .body(body)
98 .map_err(|e| CompletionError::HttpError(e.into()))?;
99
100 let mut event_source = GenericEventSource::new(self.client.clone(), req);
101
102 let stream = stream! {
103 let mut final_interaction: Option<Interaction> = None;
104 let mut final_usage: Option<InteractionUsage> = None;
105
106 while let Some(event_result) = event_source.next().await {
107 match event_result {
108 Ok(Event::Open) => {
109 tracing::debug!("SSE connection opened");
110 continue;
111 }
112 Ok(Event::Message(message)) => {
113 if message.data.trim().is_empty() {
114 continue;
115 }
116
117 let data = match serde_json::from_str::<InteractionSseEvent>(&message.data)
118 {
119 Ok(data) => data,
120 Err(err) => {
121 tracing::debug!(
122 "Failed to deserialize interactions SSE event: {err}"
123 );
124 continue;
125 }
126 };
127
128 match data {
129 InteractionSseEvent::StepDelta { delta, .. } => {
130 if let Some(choice) = content_delta_to_choice(delta) {
131 yield Ok(choice);
132 }
133 }
134 InteractionSseEvent::StepStart { step, .. } => {
135 if let Some(choice) = step_start_to_choice(step) {
136 yield Ok(choice);
137 }
138 }
139 InteractionSseEvent::InteractionCompleted { interaction, .. } => {
140 let span = tracing::Span::current();
141 span.record("gen_ai.response.id", &interaction.id);
142 if let Some(model) = interaction.model.clone() {
143 span.record("gen_ai.response.model", model);
144 }
145
146 if let Some(usage) = interaction.usage.clone() {
147 span.record_token_usage(&usage);
148 final_usage = Some(usage);
149 }
150 final_interaction = Some(interaction);
151 }
152 InteractionSseEvent::Error { error, .. } => {
153 yield Err(CompletionError::ProviderError(error.message));
154 break;
155 }
156 _ => continue,
157 }
158 }
159 Err(crate::http_client::Error::StreamEnded) => {
160 break;
161 }
162 Err(error) => {
163 tracing::error!(?error, "SSE error");
164 yield Err(CompletionError::ProviderError(error.to_string()));
165 break;
166 }
167 }
168 }
169
170 event_source.close();
171
172 let model_version = final_interaction.as_ref().and_then(|i| i.model.clone());
173 yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
174 usage: final_usage.or_else(|| final_interaction.as_ref().and_then(|i| i.usage.clone())),
175 interaction: final_interaction,
176 model_version,
177 }));
178 }
179 .instrument(span);
180
181 Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
182 stream,
183 )))
184 }
185}
186
187pub(crate) fn stream_interaction_events<T>(
188 client: super::InteractionsClient<T>,
189 request: Request<Vec<u8>>,
190) -> InteractionEventStream
191where
192 T: HttpClientExt + Clone + Default + std::fmt::Debug + 'static,
193{
194 let mut event_source = GenericEventSource::new(client.clone(), request);
195
196 let stream = stream! {
197 while let Some(event_result) = event_source.next().await {
198 match event_result {
199 Ok(Event::Open) => continue,
200 Ok(Event::Message(message)) => {
201 if message.data.trim().is_empty() {
202 continue;
203 }
204
205 let data = serde_json::from_str::<InteractionSseEvent>(&message.data);
206 let Ok(data) = data else {
207 let Err(err) = data else {
208 continue;
209 };
210 tracing::debug!("Failed to deserialize interactions SSE event: {err}");
211 continue;
212 };
213
214 yield Ok(data);
215 }
216 Err(crate::http_client::Error::StreamEnded) => break,
217 Err(error) => {
218 tracing::error!(?error, "SSE error");
219 yield Err(CompletionError::ProviderError(error.to_string()));
220 break;
221 }
222 }
223 }
224
225 event_source.close();
226 };
227
228 Box::pin(stream)
229}
230
231fn step_start_to_choice(
232 step: Step,
233) -> Option<streaming::RawStreamingChoice<StreamingCompletionResponse>> {
234 match step {
235 Step::ModelOutput { content } => content.into_iter().find_map(content_to_choice),
236 Step::FunctionCall(FunctionCallContent {
237 name,
238 arguments,
239 id,
240 }) => {
241 let name = name?;
242 let call_id = id.unwrap_or_else(|| name.clone());
243 Some(streaming::RawStreamingChoice::ToolCall(
244 streaming::RawStreamingToolCall::new(
245 name.clone(),
246 name,
247 arguments.unwrap_or(Value::Object(Map::new())),
248 )
249 .with_call_id(call_id),
250 ))
251 }
252 _ => None,
253 }
254}
255
256fn content_to_choice(
257 content: Content,
258) -> Option<streaming::RawStreamingChoice<StreamingCompletionResponse>> {
259 match content {
260 Content::Text(text) if !text.text.is_empty() => {
261 Some(streaming::RawStreamingChoice::Message(text.text))
262 }
263 Content::FunctionCall(content) => step_start_to_choice(Step::FunctionCall(content)),
264 _ => None,
265 }
266}
267
268fn content_delta_to_choice(
269 delta: ContentDelta,
270) -> Option<streaming::RawStreamingChoice<StreamingCompletionResponse>> {
271 match delta {
272 ContentDelta::Text(TextDelta {
273 text: Some(text), ..
274 }) => Some(streaming::RawStreamingChoice::Message(text)),
275 ContentDelta::FunctionCall(FunctionCallDelta {
276 name,
277 arguments,
278 id,
279 }) => {
280 let name = name?;
281 let call_id = id.unwrap_or_else(|| name.clone());
282 Some(streaming::RawStreamingChoice::ToolCall(
283 streaming::RawStreamingToolCall::new(
284 name.clone(),
285 name,
286 arguments.unwrap_or(Value::Object(Map::new())),
287 )
288 .with_call_id(call_id),
289 ))
290 }
291 ContentDelta::ThoughtSummary(ThoughtSummaryDelta { content }) => {
292 let text = match content {
293 ThoughtSummaryContent::Text(text) => text.text,
294 _ => return None,
295 };
296 Some(streaming::RawStreamingChoice::ReasoningDelta {
297 id: None,
298 reasoning: text,
299 })
300 }
301 _ => None,
302 }
303}
304
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Wraps `payload` in a `step.delta` SSE event, deserializes it, and returns
    /// the inner content delta.
    fn parse_step_delta(payload: Value) -> ContentDelta {
        let event: InteractionSseEvent = serde_json::from_value(json!({
            "event_type": "step.delta",
            "index": 0,
            "delta": payload,
        }))
        .unwrap();

        let InteractionSseEvent::StepDelta { delta, .. } = event else {
            panic!("expected step delta");
        };
        delta
    }

    #[test]
    fn test_streaming_completion_response_has_model_version() {
        let response = StreamingCompletionResponse {
            usage: None,
            interaction: None,
            model_version: Some("gemini-2.5-pro-preview-05-06".to_string()),
        };

        assert_eq!(
            response.model_version.as_deref(),
            Some("gemini-2.5-pro-preview-05-06")
        );

        let json = serde_json::to_string(&response).unwrap();
        let deserialized: StreamingCompletionResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(
            deserialized.model_version.as_deref(),
            Some("gemini-2.5-pro-preview-05-06")
        );
    }

    #[test]
    fn test_content_delta_text_event() {
        let delta = parse_step_delta(json!({
            "type": "text",
            "text": "Hello"
        }));

        match content_delta_to_choice(delta).expect("choice should exist") {
            crate::streaming::RawStreamingChoice::Message(text) => {
                assert_eq!(text, "Hello");
            }
            other => panic!("unexpected choice: {other:?}"),
        }
    }

    #[test]
    fn test_content_delta_function_call_event() {
        let delta = parse_step_delta(json!({
            "type": "function_call",
            "name": "get_weather",
            "arguments": {"location": "Paris"},
            "id": "call-1"
        }));

        match content_delta_to_choice(delta).expect("choice should exist") {
            crate::streaming::RawStreamingChoice::ToolCall(call) => {
                assert_eq!(call.name, "get_weather");
                assert_eq!(call.call_id.as_deref(), Some("call-1"));
            }
            other => panic!("unexpected choice: {other:?}"),
        }
    }

    #[test]
    fn test_content_delta_function_call_without_id_falls_back_to_name() {
        let delta = parse_step_delta(json!({
            "type": "function_call",
            "name": "get_weather",
            "arguments": {"location": "Paris"}
        }));

        match content_delta_to_choice(delta).expect("choice should exist") {
            crate::streaming::RawStreamingChoice::ToolCall(call) => {
                assert_eq!(call.name, "get_weather");
                assert_eq!(call.call_id.as_deref(), Some("get_weather"));
            }
            other => panic!("unexpected choice: {other:?}"),
        }
    }

    #[test]
    fn test_content_delta_function_call_without_name_is_dropped() {
        let delta = parse_step_delta(json!({
            "type": "function_call",
            "arguments": {"location": "Paris"},
            "id": "call-1"
        }));

        assert!(content_delta_to_choice(delta).is_none());
    }
}