Skip to main content

rig_core/providers/gemini/interactions_api/
streaming.rs

1use async_stream::stream;
2use futures::{Stream, StreamExt};
3use serde::{Deserialize, Serialize};
4use std::pin::Pin;
5use tracing::{Level, enabled, info_span};
6use tracing_futures::Instrument;
7
8use super::InteractionsCompletionModel;
9use super::create_request_body;
10use super::interactions_api_types::{
11    Content, ContentDelta, FunctionCallContent, FunctionCallDelta, Interaction,
12    InteractionSseEvent, InteractionUsage, TextContent, TextDelta, ThoughtSummaryContent,
13    ThoughtSummaryDelta,
14};
15use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
16use crate::http_client::HttpClientExt;
17use crate::http_client::Request;
18use crate::http_client::sse::{Event, GenericEventSource};
19use crate::streaming;
20use crate::telemetry::SpanCombinator;
21use serde_json::{Map, Value};
22
23/// Final metadata yielded by an Interactions streaming response.
24#[derive(Debug, Serialize, Deserialize, Default, Clone)]
25pub struct StreamingCompletionResponse {
26    pub usage: Option<InteractionUsage>,
27    pub interaction: Option<Interaction>,
28}
29
30#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
31pub type InteractionEventStream =
32    Pin<Box<dyn Stream<Item = Result<InteractionSseEvent, CompletionError>> + Send>>;
33
34#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
35pub type InteractionEventStream =
36    Pin<Box<dyn Stream<Item = Result<InteractionSseEvent, CompletionError>>>>;
37
38impl GetTokenUsage for StreamingCompletionResponse {
39    fn token_usage(&self) -> Option<crate::completion::Usage> {
40        self.usage.as_ref().and_then(|usage| usage.token_usage())
41    }
42}
43
44impl<T> InteractionsCompletionModel<T>
45where
46    T: HttpClientExt + Clone + Default + std::fmt::Debug + 'static,
47{
48    pub(crate) async fn stream(
49        &self,
50        completion_request: CompletionRequest,
51    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
52    {
53        let span = if tracing::Span::current().is_disabled() {
54            info_span!(
55                target: "rig::completions",
56                "interactions_streaming",
57                gen_ai.operation.name = "interactions_streaming",
58                gen_ai.provider.name = "gcp.gemini",
59                gen_ai.request.model = self.model,
60                gen_ai.system_instructions = &completion_request.preamble,
61                gen_ai.response.id = tracing::field::Empty,
62                gen_ai.response.model = tracing::field::Empty,
63                gen_ai.usage.output_tokens = tracing::field::Empty,
64                gen_ai.usage.input_tokens = tracing::field::Empty,
65                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
66                gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
67                gen_ai.usage.reasoning_tokens = tracing::field::Empty,
68            )
69        } else {
70            tracing::Span::current()
71        };
72
73        let request = create_request_body(self.model.clone(), completion_request, Some(true))?;
74
75        if enabled!(Level::TRACE) {
76            tracing::trace!(
77                target: "rig::streaming",
78                "Gemini interactions streaming request: {}",
79                serde_json::to_string_pretty(&request)?
80            );
81        }
82
83        let body = serde_json::to_vec(&request)?;
84        let req = self
85            .client
86            .post_sse("/v1beta/interactions")?
87            .header("Content-Type", "application/json")
88            .body(body)
89            .map_err(|e| CompletionError::HttpError(e.into()))?;
90
91        let mut event_source = GenericEventSource::new(self.client.clone(), req);
92
93        let stream = stream! {
94            let mut final_interaction: Option<Interaction> = None;
95            let mut final_usage: Option<InteractionUsage> = None;
96
97            while let Some(event_result) = event_source.next().await {
98                match event_result {
99                    Ok(Event::Open) => {
100                        tracing::debug!("SSE connection opened");
101                        continue;
102                    }
103                    Ok(Event::Message(message)) => {
104                        if message.data.trim().is_empty() {
105                            continue;
106                        }
107
108                        let data = match serde_json::from_str::<InteractionSseEvent>(&message.data)
109                        {
110                            Ok(data) => data,
111                            Err(err) => {
112                                tracing::debug!(
113                                    "Failed to deserialize interactions SSE event: {err}"
114                                );
115                                continue;
116                            }
117                        };
118
119                        match data {
120                            InteractionSseEvent::ContentDelta { delta, .. } => {
121                                if let Some(choice) = content_delta_to_choice(delta) {
122                                    yield Ok(choice);
123                                }
124                            }
125                            InteractionSseEvent::ContentStart { content, .. } => {
126                                if let Some(choice) = content_start_to_choice(content) {
127                                    yield Ok(choice);
128                                }
129                            }
130                            InteractionSseEvent::InteractionComplete { interaction, .. } => {
131                                let span = tracing::Span::current();
132                                span.record("gen_ai.response.id", &interaction.id);
133                                if let Some(model) = interaction.model.clone() {
134                                    span.record("gen_ai.response.model", model);
135                                }
136
137                                if let Some(usage) = interaction.usage.clone() {
138                                    span.record_token_usage(&usage);
139                                    final_usage = Some(usage);
140                                }
141                                final_interaction = Some(interaction);
142                            }
143                            InteractionSseEvent::Error { error, .. } => {
144                                yield Err(CompletionError::ProviderError(error.message));
145                                break;
146                            }
147                            _ => continue,
148                        }
149                    }
150                    Err(crate::http_client::Error::StreamEnded) => {
151                        break;
152                    }
153                    Err(error) => {
154                        tracing::error!(?error, "SSE error");
155                        yield Err(CompletionError::ProviderError(error.to_string()));
156                        break;
157                    }
158                }
159            }
160
161            event_source.close();
162
163            yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
164                usage: final_usage.or_else(|| final_interaction.as_ref().and_then(|i| i.usage.clone())),
165                interaction: final_interaction,
166            }));
167        }
168        .instrument(span);
169
170        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
171            stream,
172        )))
173    }
174}
175
176pub(crate) fn stream_interaction_events<T>(
177    client: super::InteractionsClient<T>,
178    request: Request<Vec<u8>>,
179) -> InteractionEventStream
180where
181    T: HttpClientExt + Clone + Default + std::fmt::Debug + 'static,
182{
183    let mut event_source = GenericEventSource::new(client.clone(), request);
184
185    let stream = stream! {
186        while let Some(event_result) = event_source.next().await {
187            match event_result {
188                Ok(Event::Open) => continue,
189                Ok(Event::Message(message)) => {
190                    if message.data.trim().is_empty() {
191                        continue;
192                    }
193
194                    let data = serde_json::from_str::<InteractionSseEvent>(&message.data);
195                    let Ok(data) = data else {
196                        let Err(err) = data else {
197                            continue;
198                        };
199                        tracing::debug!("Failed to deserialize interactions SSE event: {err}");
200                        continue;
201                    };
202
203                    yield Ok(data);
204                }
205                Err(crate::http_client::Error::StreamEnded) => break,
206                Err(error) => {
207                    tracing::error!(?error, "SSE error");
208                    yield Err(CompletionError::ProviderError(error.to_string()));
209                    break;
210                }
211            }
212        }
213
214        event_source.close();
215    };
216
217    Box::pin(stream)
218}
219
220fn content_start_to_choice(
221    content: Content,
222) -> Option<streaming::RawStreamingChoice<StreamingCompletionResponse>> {
223    match content {
224        Content::Text(TextContent { text, .. }) => {
225            if text.is_empty() {
226                None
227            } else {
228                Some(streaming::RawStreamingChoice::Message(text))
229            }
230        }
231        Content::FunctionCall(FunctionCallContent {
232            name,
233            arguments,
234            id,
235        }) => {
236            let name = name?;
237            let call_id = id.unwrap_or_else(|| name.clone());
238            Some(streaming::RawStreamingChoice::ToolCall(
239                streaming::RawStreamingToolCall::new(
240                    name.clone(),
241                    name,
242                    arguments.unwrap_or(Value::Object(Map::new())),
243                )
244                .with_call_id(call_id),
245            ))
246        }
247        _ => None,
248    }
249}
250
251fn content_delta_to_choice(
252    delta: ContentDelta,
253) -> Option<streaming::RawStreamingChoice<StreamingCompletionResponse>> {
254    match delta {
255        ContentDelta::Text(TextDelta {
256            text: Some(text), ..
257        }) => Some(streaming::RawStreamingChoice::Message(text)),
258        ContentDelta::FunctionCall(FunctionCallDelta {
259            name,
260            arguments,
261            id,
262        }) => {
263            let name = name?;
264            let call_id = id.unwrap_or_else(|| name.clone());
265            Some(streaming::RawStreamingChoice::ToolCall(
266                streaming::RawStreamingToolCall::new(
267                    name.clone(),
268                    name,
269                    arguments.unwrap_or(Value::Object(Map::new())),
270                )
271                .with_call_id(call_id),
272            ))
273        }
274        ContentDelta::ThoughtSummary(ThoughtSummaryDelta { content }) => {
275            let text = match content {
276                ThoughtSummaryContent::Text(text) => text.text,
277                _ => return None,
278            };
279            Some(streaming::RawStreamingChoice::ReasoningDelta {
280                id: None,
281                reasoning: text,
282            })
283        }
284        _ => None,
285    }
286}
287
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Deserializes `event_json` as an SSE event and returns its content delta,
    /// panicking when the event is of any other kind.
    fn parse_content_delta(event_json: serde_json::Value) -> ContentDelta {
        match serde_json::from_value::<InteractionSseEvent>(event_json).unwrap() {
            InteractionSseEvent::ContentDelta { delta, .. } => delta,
            _ => panic!("expected content delta"),
        }
    }

    /// A text delta is surfaced as a plain message chunk.
    #[test]
    fn test_content_delta_text_event() {
        let delta = parse_content_delta(json!({
            "event_type": "content.delta",
            "index": 0,
            "delta": {
                "type": "text",
                "text": "Hello"
            }
        }));

        match content_delta_to_choice(delta).expect("choice should exist") {
            crate::streaming::RawStreamingChoice::Message(text) => assert_eq!(text, "Hello"),
            other => panic!("unexpected choice: {other:?}"),
        }
    }

    /// A function-call delta is surfaced as a tool call carrying the provided call id.
    #[test]
    fn test_content_delta_function_call_event() {
        let delta = parse_content_delta(json!({
            "event_type": "content.delta",
            "index": 0,
            "delta": {
                "type": "function_call",
                "name": "get_weather",
                "arguments": {"location": "Paris"},
                "id": "call-1"
            }
        }));

        match content_delta_to_choice(delta).expect("choice should exist") {
            crate::streaming::RawStreamingChoice::ToolCall(call) => {
                assert_eq!(call.name, "get_weather");
                assert_eq!(call.call_id.as_deref(), Some("call-1"));
            }
            other => panic!("unexpected choice: {other:?}"),
        }
    }
}