Skip to main content

rig_core/providers/gemini/interactions_api/
streaming.rs

1use async_stream::stream;
2use futures::{Stream, StreamExt};
3use serde::{Deserialize, Serialize};
4use std::pin::Pin;
5use tracing::{Level, enabled, info_span};
6use tracing_futures::Instrument;
7
8use super::InteractionsCompletionModel;
9use super::create_request_body;
10use super::interactions_api_types::{
11    Content, ContentDelta, FunctionCallContent, FunctionCallDelta, Interaction,
12    InteractionSseEvent, InteractionUsage, TextContent, TextDelta, ThoughtSummaryContent,
13    ThoughtSummaryDelta,
14};
15use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
16use crate::http_client::HttpClientExt;
17use crate::http_client::Request;
18use crate::http_client::sse::{Event, GenericEventSource};
19use crate::streaming;
20use crate::telemetry::SpanCombinator;
21use serde_json::{Map, Value};
22
23/// Final metadata yielded by an Interactions streaming response.
24#[derive(Debug, Serialize, Deserialize, Default, Clone)]
25pub struct StreamingCompletionResponse {
26    pub usage: Option<InteractionUsage>,
27    pub interaction: Option<Interaction>,
28    /// Resolved model identifier (e.g. `gemini-2.5-pro-preview-05-06`), extracted from
29    /// `Interaction.model`. The Interactions API has no `FinishReason` field; use
30    /// `interaction.status` for lifecycle state.
31    #[serde(skip_serializing_if = "Option::is_none")]
32    pub model_version: Option<String>,
33}
34
35#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
36pub type InteractionEventStream =
37    Pin<Box<dyn Stream<Item = Result<InteractionSseEvent, CompletionError>> + Send>>;
38
39#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
40pub type InteractionEventStream =
41    Pin<Box<dyn Stream<Item = Result<InteractionSseEvent, CompletionError>>>>;
42
43impl GetTokenUsage for StreamingCompletionResponse {
44    fn token_usage(&self) -> Option<crate::completion::Usage> {
45        self.usage.as_ref().and_then(|usage| usage.token_usage())
46    }
47}
48
49impl<T> InteractionsCompletionModel<T>
50where
51    T: HttpClientExt + Clone + Default + std::fmt::Debug + 'static,
52{
53    pub(crate) async fn stream(
54        &self,
55        completion_request: CompletionRequest,
56    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
57    {
58        let span = if tracing::Span::current().is_disabled() {
59            info_span!(
60                target: "rig::completions",
61                "interactions_streaming",
62                gen_ai.operation.name = "interactions_streaming",
63                gen_ai.provider.name = "gcp.gemini",
64                gen_ai.request.model = self.model,
65                gen_ai.system_instructions = &completion_request.preamble,
66                gen_ai.response.id = tracing::field::Empty,
67                gen_ai.response.model = tracing::field::Empty,
68                gen_ai.usage.output_tokens = tracing::field::Empty,
69                gen_ai.usage.input_tokens = tracing::field::Empty,
70                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
71                gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
72                gen_ai.usage.tool_use_prompt_tokens = tracing::field::Empty,
73                gen_ai.usage.reasoning_tokens = tracing::field::Empty,
74            )
75        } else {
76            tracing::Span::current()
77        };
78
79        let request = create_request_body(self.model.clone(), completion_request, Some(true))?;
80
81        if enabled!(Level::TRACE) {
82            tracing::trace!(
83                target: "rig::streaming",
84                "Gemini interactions streaming request: {}",
85                serde_json::to_string_pretty(&request)?
86            );
87        }
88
89        let body = serde_json::to_vec(&request)?;
90        let req = self
91            .client
92            .post_sse("/v1beta/interactions")?
93            .header("Content-Type", "application/json")
94            .body(body)
95            .map_err(|e| CompletionError::HttpError(e.into()))?;
96
97        let mut event_source = GenericEventSource::new(self.client.clone(), req);
98
99        let stream = stream! {
100            let mut final_interaction: Option<Interaction> = None;
101            let mut final_usage: Option<InteractionUsage> = None;
102
103            while let Some(event_result) = event_source.next().await {
104                match event_result {
105                    Ok(Event::Open) => {
106                        tracing::debug!("SSE connection opened");
107                        continue;
108                    }
109                    Ok(Event::Message(message)) => {
110                        if message.data.trim().is_empty() {
111                            continue;
112                        }
113
114                        let data = match serde_json::from_str::<InteractionSseEvent>(&message.data)
115                        {
116                            Ok(data) => data,
117                            Err(err) => {
118                                tracing::debug!(
119                                    "Failed to deserialize interactions SSE event: {err}"
120                                );
121                                continue;
122                            }
123                        };
124
125                        match data {
126                            InteractionSseEvent::ContentDelta { delta, .. } => {
127                                if let Some(choice) = content_delta_to_choice(delta) {
128                                    yield Ok(choice);
129                                }
130                            }
131                            InteractionSseEvent::ContentStart { content, .. } => {
132                                if let Some(choice) = content_start_to_choice(content) {
133                                    yield Ok(choice);
134                                }
135                            }
136                            InteractionSseEvent::InteractionComplete { interaction, .. } => {
137                                let span = tracing::Span::current();
138                                span.record("gen_ai.response.id", &interaction.id);
139                                if let Some(model) = interaction.model.clone() {
140                                    span.record("gen_ai.response.model", model);
141                                }
142
143                                if let Some(usage) = interaction.usage.clone() {
144                                    span.record_token_usage(&usage);
145                                    final_usage = Some(usage);
146                                }
147                                final_interaction = Some(interaction);
148                            }
149                            InteractionSseEvent::Error { error, .. } => {
150                                yield Err(CompletionError::ProviderError(error.message));
151                                break;
152                            }
153                            _ => continue,
154                        }
155                    }
156                    Err(crate::http_client::Error::StreamEnded) => {
157                        break;
158                    }
159                    Err(error) => {
160                        tracing::error!(?error, "SSE error");
161                        yield Err(CompletionError::ProviderError(error.to_string()));
162                        break;
163                    }
164                }
165            }
166
167            event_source.close();
168
169            let model_version = final_interaction.as_ref().and_then(|i| i.model.clone());
170            yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
171                usage: final_usage.or_else(|| final_interaction.as_ref().and_then(|i| i.usage.clone())),
172                interaction: final_interaction,
173                model_version,
174            }));
175        }
176        .instrument(span);
177
178        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
179            stream,
180        )))
181    }
182}
183
184pub(crate) fn stream_interaction_events<T>(
185    client: super::InteractionsClient<T>,
186    request: Request<Vec<u8>>,
187) -> InteractionEventStream
188where
189    T: HttpClientExt + Clone + Default + std::fmt::Debug + 'static,
190{
191    let mut event_source = GenericEventSource::new(client.clone(), request);
192
193    let stream = stream! {
194        while let Some(event_result) = event_source.next().await {
195            match event_result {
196                Ok(Event::Open) => continue,
197                Ok(Event::Message(message)) => {
198                    if message.data.trim().is_empty() {
199                        continue;
200                    }
201
202                    let data = serde_json::from_str::<InteractionSseEvent>(&message.data);
203                    let Ok(data) = data else {
204                        let Err(err) = data else {
205                            continue;
206                        };
207                        tracing::debug!("Failed to deserialize interactions SSE event: {err}");
208                        continue;
209                    };
210
211                    yield Ok(data);
212                }
213                Err(crate::http_client::Error::StreamEnded) => break,
214                Err(error) => {
215                    tracing::error!(?error, "SSE error");
216                    yield Err(CompletionError::ProviderError(error.to_string()));
217                    break;
218                }
219            }
220        }
221
222        event_source.close();
223    };
224
225    Box::pin(stream)
226}
227
228fn content_start_to_choice(
229    content: Content,
230) -> Option<streaming::RawStreamingChoice<StreamingCompletionResponse>> {
231    match content {
232        Content::Text(TextContent { text, .. }) => {
233            if text.is_empty() {
234                None
235            } else {
236                Some(streaming::RawStreamingChoice::Message(text))
237            }
238        }
239        Content::FunctionCall(FunctionCallContent {
240            name,
241            arguments,
242            id,
243        }) => {
244            let name = name?;
245            let call_id = id.unwrap_or_else(|| name.clone());
246            Some(streaming::RawStreamingChoice::ToolCall(
247                streaming::RawStreamingToolCall::new(
248                    name.clone(),
249                    name,
250                    arguments.unwrap_or(Value::Object(Map::new())),
251                )
252                .with_call_id(call_id),
253            ))
254        }
255        _ => None,
256    }
257}
258
259fn content_delta_to_choice(
260    delta: ContentDelta,
261) -> Option<streaming::RawStreamingChoice<StreamingCompletionResponse>> {
262    match delta {
263        ContentDelta::Text(TextDelta {
264            text: Some(text), ..
265        }) => Some(streaming::RawStreamingChoice::Message(text)),
266        ContentDelta::FunctionCall(FunctionCallDelta {
267            name,
268            arguments,
269            id,
270        }) => {
271            let name = name?;
272            let call_id = id.unwrap_or_else(|| name.clone());
273            Some(streaming::RawStreamingChoice::ToolCall(
274                streaming::RawStreamingToolCall::new(
275                    name.clone(),
276                    name,
277                    arguments.unwrap_or(Value::Object(Map::new())),
278                )
279                .with_call_id(call_id),
280            ))
281        }
282        ContentDelta::ThoughtSummary(ThoughtSummaryDelta { content }) => {
283            let text = match content {
284                ThoughtSummaryContent::Text(text) => text.text,
285                _ => return None,
286            };
287            Some(streaming::RawStreamingChoice::ReasoningDelta {
288                id: None,
289                reasoning: text,
290            })
291        }
292        _ => None,
293    }
294}
295
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Deserializes a raw SSE payload and returns the `ContentDelta` it carries.
    fn parse_content_delta(event_json: Value) -> ContentDelta {
        let event: InteractionSseEvent =
            serde_json::from_value(event_json).expect("event should deserialize");
        let InteractionSseEvent::ContentDelta { delta, .. } = event else {
            panic!("expected content delta");
        };
        delta
    }

    #[test]
    fn test_streaming_completion_response_has_model_version() {
        let response = StreamingCompletionResponse {
            usage: None,
            interaction: None,
            model_version: Some("gemini-2.5-pro-preview-05-06".to_string()),
        };

        assert_eq!(
            response.model_version.as_deref(),
            Some("gemini-2.5-pro-preview-05-06")
        );

        let json = serde_json::to_string(&response).unwrap();
        let deserialized: StreamingCompletionResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(
            deserialized.model_version.as_deref(),
            Some("gemini-2.5-pro-preview-05-06")
        );
    }

    #[test]
    fn test_content_delta_text_event() {
        let delta = parse_content_delta(json!({
            "event_type": "content.delta",
            "index": 0,
            "delta": {
                "type": "text",
                "text": "Hello"
            }
        }));

        match content_delta_to_choice(delta).expect("choice should exist") {
            crate::streaming::RawStreamingChoice::Message(text) => {
                assert_eq!(text, "Hello");
            }
            other => panic!("unexpected choice: {other:?}"),
        }
    }

    #[test]
    fn test_content_delta_function_call_event() {
        let delta = parse_content_delta(json!({
            "event_type": "content.delta",
            "index": 0,
            "delta": {
                "type": "function_call",
                "name": "get_weather",
                "arguments": {"location": "Paris"},
                "id": "call-1"
            }
        }));

        match content_delta_to_choice(delta).expect("choice should exist") {
            crate::streaming::RawStreamingChoice::ToolCall(call) => {
                assert_eq!(call.name, "get_weather");
                assert_eq!(call.call_id.as_deref(), Some("call-1"));
            }
            other => panic!("unexpected choice: {other:?}"),
        }
    }

    #[test]
    fn test_content_delta_function_call_defaults_id_and_arguments() {
        // No `id` and no `arguments`: call id falls back to the name, args to `{}`.
        let delta = parse_content_delta(json!({
            "event_type": "content.delta",
            "index": 0,
            "delta": {
                "type": "function_call",
                "name": "get_weather"
            }
        }));

        match content_delta_to_choice(delta).expect("choice should exist") {
            crate::streaming::RawStreamingChoice::ToolCall(call) => {
                assert_eq!(call.name, "get_weather");
                assert_eq!(call.call_id.as_deref(), Some("get_weather"));
                assert_eq!(call.arguments, json!({}));
            }
            other => panic!("unexpected choice: {other:?}"),
        }
    }

    #[test]
    fn test_content_delta_function_call_without_name_is_skipped() {
        let delta = parse_content_delta(json!({
            "event_type": "content.delta",
            "index": 0,
            "delta": {
                "type": "function_call",
                "arguments": {"location": "Paris"},
                "id": "call-1"
            }
        }));

        assert!(content_delta_to_choice(delta).is_none());
    }
}