Skip to main content

rig_core/providers/gemini/interactions_api/
streaming.rs

1use async_stream::stream;
2use futures::{Stream, StreamExt};
3use serde::{Deserialize, Serialize};
4use std::pin::Pin;
5use tracing::{Level, enabled, info_span};
6use tracing_futures::Instrument;
7
8use super::InteractionsCompletionModel;
9use super::create_request_body;
10use super::interactions_api_types::{
11    Content, ContentDelta, FunctionCallContent, FunctionCallDelta, Interaction,
12    InteractionSseEvent, InteractionUsage, Step, TextDelta, ThoughtSummaryContent,
13    ThoughtSummaryDelta,
14};
15use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
16use crate::http_client::HttpClientExt;
17use crate::http_client::Request;
18use crate::http_client::sse::{Event, GenericEventSource};
19use crate::streaming;
20use crate::telemetry::SpanCombinator;
21use serde_json::{Map, Value};
22
23/// Final metadata yielded by an Interactions streaming response.
24#[derive(Debug, Serialize, Deserialize, Default, Clone)]
25pub struct StreamingCompletionResponse {
26    pub usage: Option<InteractionUsage>,
27    pub interaction: Option<Interaction>,
28    /// Resolved model identifier (e.g. `gemini-2.5-pro-preview-05-06`), extracted from
29    /// `Interaction.model`. The Interactions API has no `FinishReason` field; use
30    /// `interaction.status` for lifecycle state.
31    #[serde(skip_serializing_if = "Option::is_none")]
32    pub model_version: Option<String>,
33}
34
/// Boxed, pinned stream of decoded Interactions SSE events, each item being either an
/// [`InteractionSseEvent`] or a [`CompletionError`].
///
/// This variant is compiled everywhere except the `wasm` feature on `wasm32`, and
/// requires the underlying stream to be `Send`.
#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
pub type InteractionEventStream =
    Pin<Box<dyn Stream<Item = Result<InteractionSseEvent, CompletionError>> + Send>>;

/// Boxed, pinned stream of decoded Interactions SSE events, each item being either an
/// [`InteractionSseEvent`] or a [`CompletionError`].
///
/// This variant is compiled only for the `wasm` feature on `wasm32`; it carries no
/// `Send` bound.
#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
pub type InteractionEventStream =
    Pin<Box<dyn Stream<Item = Result<InteractionSseEvent, CompletionError>>>>;
42
43impl GetTokenUsage for StreamingCompletionResponse {
44    fn token_usage(&self) -> crate::completion::Usage {
45        self.usage
46            .as_ref()
47            .map(|usage| usage.token_usage())
48            .unwrap_or_default()
49    }
50}
51
52impl<T> InteractionsCompletionModel<T>
53where
54    T: HttpClientExt + Clone + Default + std::fmt::Debug + 'static,
55{
56    pub(crate) async fn stream(
57        &self,
58        completion_request: CompletionRequest,
59    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
60    {
61        let span = if tracing::Span::current().is_disabled() {
62            info_span!(
63                target: "rig::completions",
64                "interactions_streaming",
65                gen_ai.operation.name = "interactions_streaming",
66                gen_ai.provider.name = "gcp.gemini",
67                gen_ai.request.model = self.model,
68                gen_ai.system_instructions = &completion_request.preamble,
69                gen_ai.response.id = tracing::field::Empty,
70                gen_ai.response.model = tracing::field::Empty,
71                gen_ai.usage.output_tokens = tracing::field::Empty,
72                gen_ai.usage.input_tokens = tracing::field::Empty,
73                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
74                gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
75                gen_ai.usage.tool_use_prompt_tokens = tracing::field::Empty,
76                gen_ai.usage.reasoning_tokens = tracing::field::Empty,
77            )
78        } else {
79            tracing::Span::current()
80        };
81
82        let request = create_request_body(self.model.clone(), completion_request, Some(true))?;
83
84        if enabled!(Level::TRACE) {
85            tracing::trace!(
86                target: "rig::streaming",
87                "Gemini interactions streaming request: {}",
88                serde_json::to_string_pretty(&request)?
89            );
90        }
91
92        let body = serde_json::to_vec(&request)?;
93        let req = self
94            .client
95            .post_sse("/v1beta/interactions")?
96            .header("Content-Type", "application/json")
97            .body(body)
98            .map_err(|e| CompletionError::HttpError(e.into()))?;
99
100        let mut event_source = GenericEventSource::new(self.client.clone(), req);
101
102        let stream = stream! {
103            let mut final_interaction: Option<Interaction> = None;
104            let mut final_usage: Option<InteractionUsage> = None;
105
106            while let Some(event_result) = event_source.next().await {
107                match event_result {
108                    Ok(Event::Open) => {
109                        tracing::debug!("SSE connection opened");
110                        continue;
111                    }
112                    Ok(Event::Message(message)) => {
113                        if message.data.trim().is_empty() {
114                            continue;
115                        }
116
117                        let data = match serde_json::from_str::<InteractionSseEvent>(&message.data)
118                        {
119                            Ok(data) => data,
120                            Err(err) => {
121                                tracing::debug!(
122                                    "Failed to deserialize interactions SSE event: {err}"
123                                );
124                                continue;
125                            }
126                        };
127
128                        match data {
129                            InteractionSseEvent::StepDelta { delta, .. } => {
130                                if let Some(choice) = content_delta_to_choice(delta) {
131                                    yield Ok(choice);
132                                }
133                            }
134                            InteractionSseEvent::StepStart { step, .. } => {
135                                if let Some(choice) = step_start_to_choice(step) {
136                                    yield Ok(choice);
137                                }
138                            }
139                            InteractionSseEvent::InteractionCompleted { interaction, .. } => {
140                                let span = tracing::Span::current();
141                                span.record("gen_ai.response.id", &interaction.id);
142                                if let Some(model) = interaction.model.clone() {
143                                    span.record("gen_ai.response.model", model);
144                                }
145
146                                if let Some(usage) = interaction.usage.clone() {
147                                    span.record_token_usage(&usage);
148                                    final_usage = Some(usage);
149                                }
150                                final_interaction = Some(interaction);
151                            }
152                            InteractionSseEvent::Error { error, .. } => {
153                                yield Err(CompletionError::ProviderError(error.message));
154                                break;
155                            }
156                            _ => continue,
157                        }
158                    }
159                    Err(crate::http_client::Error::StreamEnded) => {
160                        break;
161                    }
162                    Err(error) => {
163                        tracing::error!(?error, "SSE error");
164                        yield Err(CompletionError::ProviderError(error.to_string()));
165                        break;
166                    }
167                }
168            }
169
170            event_source.close();
171
172            let model_version = final_interaction.as_ref().and_then(|i| i.model.clone());
173            yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
174                usage: final_usage.or_else(|| final_interaction.as_ref().and_then(|i| i.usage.clone())),
175                interaction: final_interaction,
176                model_version,
177            }));
178        }
179        .instrument(span);
180
181        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
182            stream,
183        )))
184    }
185}
186
187pub(crate) fn stream_interaction_events<T>(
188    client: super::InteractionsClient<T>,
189    request: Request<Vec<u8>>,
190) -> InteractionEventStream
191where
192    T: HttpClientExt + Clone + Default + std::fmt::Debug + 'static,
193{
194    let mut event_source = GenericEventSource::new(client.clone(), request);
195
196    let stream = stream! {
197        while let Some(event_result) = event_source.next().await {
198            match event_result {
199                Ok(Event::Open) => continue,
200                Ok(Event::Message(message)) => {
201                    if message.data.trim().is_empty() {
202                        continue;
203                    }
204
205                    let data = serde_json::from_str::<InteractionSseEvent>(&message.data);
206                    let Ok(data) = data else {
207                        let Err(err) = data else {
208                            continue;
209                        };
210                        tracing::debug!("Failed to deserialize interactions SSE event: {err}");
211                        continue;
212                    };
213
214                    yield Ok(data);
215                }
216                Err(crate::http_client::Error::StreamEnded) => break,
217                Err(error) => {
218                    tracing::error!(?error, "SSE error");
219                    yield Err(CompletionError::ProviderError(error.to_string()));
220                    break;
221                }
222            }
223        }
224
225        event_source.close();
226    };
227
228    Box::pin(stream)
229}
230
231fn step_start_to_choice(
232    step: Step,
233) -> Option<streaming::RawStreamingChoice<StreamingCompletionResponse>> {
234    match step {
235        Step::ModelOutput { content } => content.into_iter().find_map(content_to_choice),
236        Step::FunctionCall(FunctionCallContent {
237            name,
238            arguments,
239            id,
240        }) => {
241            let name = name?;
242            let call_id = id.unwrap_or_else(|| name.clone());
243            Some(streaming::RawStreamingChoice::ToolCall(
244                streaming::RawStreamingToolCall::new(
245                    name.clone(),
246                    name,
247                    arguments.unwrap_or(Value::Object(Map::new())),
248                )
249                .with_call_id(call_id),
250            ))
251        }
252        _ => None,
253    }
254}
255
256fn content_to_choice(
257    content: Content,
258) -> Option<streaming::RawStreamingChoice<StreamingCompletionResponse>> {
259    match content {
260        Content::Text(text) if !text.text.is_empty() => {
261            Some(streaming::RawStreamingChoice::Message(text.text))
262        }
263        Content::FunctionCall(content) => step_start_to_choice(Step::FunctionCall(content)),
264        _ => None,
265    }
266}
267
268fn content_delta_to_choice(
269    delta: ContentDelta,
270) -> Option<streaming::RawStreamingChoice<StreamingCompletionResponse>> {
271    match delta {
272        ContentDelta::Text(TextDelta {
273            text: Some(text), ..
274        }) => Some(streaming::RawStreamingChoice::Message(text)),
275        ContentDelta::FunctionCall(FunctionCallDelta {
276            name,
277            arguments,
278            id,
279        }) => {
280            let name = name?;
281            let call_id = id.unwrap_or_else(|| name.clone());
282            Some(streaming::RawStreamingChoice::ToolCall(
283                streaming::RawStreamingToolCall::new(
284                    name.clone(),
285                    name,
286                    arguments.unwrap_or(Value::Object(Map::new())),
287                )
288                .with_call_id(call_id),
289            ))
290        }
291        ContentDelta::ThoughtSummary(ThoughtSummaryDelta { content }) => {
292            let text = match content {
293                ThoughtSummaryContent::Text(text) => text.text,
294                _ => return None,
295            };
296            Some(streaming::RawStreamingChoice::ReasoningDelta {
297                id: None,
298                reasoning: text,
299            })
300        }
301        _ => None,
302    }
303}
304
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Deserializes `payload` as an SSE event and returns its delta, panicking when the
    /// payload is not a step-delta event.
    fn parse_step_delta(payload: serde_json::Value) -> ContentDelta {
        match serde_json::from_value::<InteractionSseEvent>(payload).unwrap() {
            InteractionSseEvent::StepDelta { delta, .. } => delta,
            _ => panic!("expected step delta"),
        }
    }

    /// `model_version` is readable on the struct and survives a JSON round trip.
    #[test]
    fn test_streaming_completion_response_has_model_version() {
        const MODEL: &str = "gemini-2.5-pro-preview-05-06";

        let response = StreamingCompletionResponse {
            model_version: Some(MODEL.to_string()),
            ..Default::default()
        };
        assert_eq!(response.model_version.as_deref(), Some(MODEL));

        let encoded = serde_json::to_string(&response).unwrap();
        let decoded: StreamingCompletionResponse = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.model_version.as_deref(), Some(MODEL));
    }

    /// A text delta is surfaced as a `Message` choice with the same text.
    #[test]
    fn test_content_delta_text_event() {
        let delta = parse_step_delta(json!({
            "event_type": "step.delta",
            "index": 0,
            "delta": {
                "type": "text",
                "text": "Hello"
            }
        }));

        match content_delta_to_choice(delta) {
            Some(crate::streaming::RawStreamingChoice::Message(text)) => {
                assert_eq!(text, "Hello");
            }
            Some(other) => panic!("unexpected choice: {other:?}"),
            None => panic!("choice should exist"),
        }
    }

    /// A function-call delta is surfaced as a `ToolCall` choice keeping name and call id.
    #[test]
    fn test_content_delta_function_call_event() {
        let delta = parse_step_delta(json!({
            "event_type": "step.delta",
            "index": 0,
            "delta": {
                "type": "function_call",
                "name": "get_weather",
                "arguments": {"location": "Paris"},
                "id": "call-1"
            }
        }));

        match content_delta_to_choice(delta) {
            Some(crate::streaming::RawStreamingChoice::ToolCall(call)) => {
                assert_eq!(call.name, "get_weather");
                assert_eq!(call.call_id.as_deref(), Some("call-1"));
            }
            Some(other) => panic!("unexpected choice: {other:?}"),
            None => panic!("choice should exist"),
        }
    }
}