Skip to main content

rig/providers/gemini/interactions_api/
streaming.rs

1use async_stream::stream;
2use futures::{Stream, StreamExt};
3use serde::{Deserialize, Serialize};
4use std::pin::Pin;
5use tracing::{Level, enabled, info_span};
6use tracing_futures::Instrument;
7
8use super::InteractionsCompletionModel;
9use super::create_request_body;
10use super::interactions_api_types::{
11    Content, ContentDelta, FunctionCallContent, FunctionCallDelta, Interaction,
12    InteractionSseEvent, InteractionUsage, TextContent, TextDelta, ThoughtSummaryContent,
13    ThoughtSummaryDelta,
14};
15use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
16use crate::http_client::HttpClientExt;
17use crate::http_client::Request;
18use crate::http_client::sse::{Event, GenericEventSource};
19use crate::streaming;
20use crate::telemetry::SpanCombinator;
21use serde_json::{Map, Value};
22
23/// Final metadata yielded by an Interactions streaming response.
24#[derive(Debug, Serialize, Deserialize, Default, Clone)]
25pub struct StreamingCompletionResponse {
26    pub usage: Option<InteractionUsage>,
27    pub interaction: Option<Interaction>,
28}
29
30#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
31pub type InteractionEventStream =
32    Pin<Box<dyn Stream<Item = Result<InteractionSseEvent, CompletionError>> + Send>>;
33
34#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
35pub type InteractionEventStream =
36    Pin<Box<dyn Stream<Item = Result<InteractionSseEvent, CompletionError>>>>;
37
38impl GetTokenUsage for StreamingCompletionResponse {
39    fn token_usage(&self) -> Option<crate::completion::Usage> {
40        self.usage.as_ref().and_then(|usage| usage.token_usage())
41    }
42}
43
44impl<T> InteractionsCompletionModel<T>
45where
46    T: HttpClientExt + Clone + Default + std::fmt::Debug + 'static,
47{
48    pub(crate) async fn stream(
49        &self,
50        completion_request: CompletionRequest,
51    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
52    {
53        let span = if tracing::Span::current().is_disabled() {
54            info_span!(
55                target: "rig::completions",
56                "interactions_streaming",
57                gen_ai.operation.name = "interactions_streaming",
58                gen_ai.provider.name = "gcp.gemini",
59                gen_ai.request.model = self.model,
60                gen_ai.system_instructions = &completion_request.preamble,
61                gen_ai.response.id = tracing::field::Empty,
62                gen_ai.response.model = tracing::field::Empty,
63                gen_ai.usage.output_tokens = tracing::field::Empty,
64                gen_ai.usage.input_tokens = tracing::field::Empty,
65                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
66            )
67        } else {
68            tracing::Span::current()
69        };
70
71        let request = create_request_body(self.model.clone(), completion_request, Some(true))?;
72
73        if enabled!(Level::TRACE) {
74            tracing::trace!(
75                target: "rig::streaming",
76                "Gemini interactions streaming request: {}",
77                serde_json::to_string_pretty(&request)?
78            );
79        }
80
81        let body = serde_json::to_vec(&request)?;
82        let req = self
83            .client
84            .post_sse("/v1beta/interactions")?
85            .header("Content-Type", "application/json")
86            .body(body)
87            .map_err(|e| CompletionError::HttpError(e.into()))?;
88
89        let mut event_source = GenericEventSource::new(self.client.clone(), req);
90
91        let stream = stream! {
92            let mut final_interaction: Option<Interaction> = None;
93            let mut final_usage: Option<InteractionUsage> = None;
94
95            while let Some(event_result) = event_source.next().await {
96                match event_result {
97                    Ok(Event::Open) => {
98                        tracing::debug!("SSE connection opened");
99                        continue;
100                    }
101                    Ok(Event::Message(message)) => {
102                        if message.data.trim().is_empty() {
103                            continue;
104                        }
105
106                        let data = match serde_json::from_str::<InteractionSseEvent>(&message.data)
107                        {
108                            Ok(data) => data,
109                            Err(err) => {
110                                tracing::debug!(
111                                    "Failed to deserialize interactions SSE event: {err}"
112                                );
113                                continue;
114                            }
115                        };
116
117                        match data {
118                            InteractionSseEvent::ContentDelta { delta, .. } => {
119                                if let Some(choice) = content_delta_to_choice(delta) {
120                                    yield Ok(choice);
121                                }
122                            }
123                            InteractionSseEvent::ContentStart { content, .. } => {
124                                if let Some(choice) = content_start_to_choice(content) {
125                                    yield Ok(choice);
126                                }
127                            }
128                            InteractionSseEvent::InteractionComplete { interaction, .. } => {
129                                let span = tracing::Span::current();
130                                span.record("gen_ai.response.id", &interaction.id);
131                                if let Some(model) = interaction.model.clone() {
132                                    span.record("gen_ai.response.model", model);
133                                }
134
135                                if let Some(usage) = interaction.usage.clone() {
136                                    span.record_token_usage(&usage);
137                                    final_usage = Some(usage);
138                                }
139                                final_interaction = Some(interaction);
140                            }
141                            InteractionSseEvent::Error { error, .. } => {
142                                yield Err(CompletionError::ProviderError(error.message));
143                                break;
144                            }
145                            _ => continue,
146                        }
147                    }
148                    Err(crate::http_client::Error::StreamEnded) => {
149                        break;
150                    }
151                    Err(error) => {
152                        tracing::error!(?error, "SSE error");
153                        yield Err(CompletionError::ProviderError(error.to_string()));
154                        break;
155                    }
156                }
157            }
158
159            event_source.close();
160
161            yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
162                usage: final_usage.or_else(|| final_interaction.as_ref().and_then(|i| i.usage.clone())),
163                interaction: final_interaction,
164            }));
165        }
166        .instrument(span);
167
168        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
169            stream,
170        )))
171    }
172}
173
174pub(crate) fn stream_interaction_events<T>(
175    client: super::InteractionsClient<T>,
176    request: Request<Vec<u8>>,
177) -> InteractionEventStream
178where
179    T: HttpClientExt + Clone + Default + std::fmt::Debug + 'static,
180{
181    let mut event_source = GenericEventSource::new(client.clone(), request);
182
183    let stream = stream! {
184        while let Some(event_result) = event_source.next().await {
185            match event_result {
186                Ok(Event::Open) => continue,
187                Ok(Event::Message(message)) => {
188                    if message.data.trim().is_empty() {
189                        continue;
190                    }
191
192                    let data = serde_json::from_str::<InteractionSseEvent>(&message.data);
193                    let Ok(data) = data else {
194                        let Err(err) = data else {
195                            continue;
196                        };
197                        tracing::debug!("Failed to deserialize interactions SSE event: {err}");
198                        continue;
199                    };
200
201                    yield Ok(data);
202                }
203                Err(crate::http_client::Error::StreamEnded) => break,
204                Err(error) => {
205                    tracing::error!(?error, "SSE error");
206                    yield Err(CompletionError::ProviderError(error.to_string()));
207                    break;
208                }
209            }
210        }
211
212        event_source.close();
213    };
214
215    Box::pin(stream)
216}
217
218fn content_start_to_choice(
219    content: Content,
220) -> Option<streaming::RawStreamingChoice<StreamingCompletionResponse>> {
221    match content {
222        Content::Text(TextContent { text, .. }) => {
223            if text.is_empty() {
224                None
225            } else {
226                Some(streaming::RawStreamingChoice::Message(text))
227            }
228        }
229        Content::FunctionCall(FunctionCallContent {
230            name,
231            arguments,
232            id,
233        }) => {
234            let name = name?;
235            let call_id = id.unwrap_or_else(|| name.clone());
236            Some(streaming::RawStreamingChoice::ToolCall(
237                streaming::RawStreamingToolCall::new(
238                    name.clone(),
239                    name,
240                    arguments.unwrap_or(Value::Object(Map::new())),
241                )
242                .with_call_id(call_id),
243            ))
244        }
245        _ => None,
246    }
247}
248
249fn content_delta_to_choice(
250    delta: ContentDelta,
251) -> Option<streaming::RawStreamingChoice<StreamingCompletionResponse>> {
252    match delta {
253        ContentDelta::Text(TextDelta {
254            text: Some(text), ..
255        }) => Some(streaming::RawStreamingChoice::Message(text)),
256        ContentDelta::FunctionCall(FunctionCallDelta {
257            name,
258            arguments,
259            id,
260        }) => {
261            let name = name?;
262            let call_id = id.unwrap_or_else(|| name.clone());
263            Some(streaming::RawStreamingChoice::ToolCall(
264                streaming::RawStreamingToolCall::new(
265                    name.clone(),
266                    name,
267                    arguments.unwrap_or(Value::Object(Map::new())),
268                )
269                .with_call_id(call_id),
270            ))
271        }
272        ContentDelta::ThoughtSummary(ThoughtSummaryDelta { content }) => {
273            let text = match content {
274                ThoughtSummaryContent::Text(text) => text.text,
275                _ => return None,
276            };
277            Some(streaming::RawStreamingChoice::ReasoningDelta {
278                id: None,
279                reasoning: text,
280            })
281        }
282        _ => None,
283    }
284}
285
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Deserializes a raw SSE payload and extracts the delta of a `content.delta` event.
    fn parse_delta(event_json: Value) -> ContentDelta {
        let event: InteractionSseEvent =
            serde_json::from_value(event_json).expect("event should deserialize");
        let InteractionSseEvent::ContentDelta { delta, .. } = event else {
            panic!("expected content delta");
        };
        delta
    }

    #[test]
    fn test_content_delta_text_event() {
        let delta = parse_delta(json!({
            "event_type": "content.delta",
            "index": 0,
            "delta": {
                "type": "text",
                "text": "Hello"
            }
        }));

        match content_delta_to_choice(delta).expect("choice should exist") {
            crate::streaming::RawStreamingChoice::Message(text) => {
                assert_eq!(text, "Hello");
            }
            other => panic!("unexpected choice: {other:?}"),
        }
    }

    #[test]
    fn test_content_delta_function_call_event() {
        let delta = parse_delta(json!({
            "event_type": "content.delta",
            "index": 0,
            "delta": {
                "type": "function_call",
                "name": "get_weather",
                "arguments": {"location": "Paris"},
                "id": "call-1"
            }
        }));

        match content_delta_to_choice(delta).expect("choice should exist") {
            crate::streaming::RawStreamingChoice::ToolCall(call) => {
                assert_eq!(call.name, "get_weather");
                assert_eq!(call.call_id.as_deref(), Some("call-1"));
            }
            other => panic!("unexpected choice: {other:?}"),
        }
    }

    #[test]
    fn test_content_delta_function_call_without_id_uses_name_as_call_id() {
        let delta = parse_delta(json!({
            "event_type": "content.delta",
            "index": 0,
            "delta": {
                "type": "function_call",
                "name": "get_weather",
                "arguments": {"location": "Paris"}
            }
        }));

        match content_delta_to_choice(delta).expect("choice should exist") {
            crate::streaming::RawStreamingChoice::ToolCall(call) => {
                assert_eq!(call.name, "get_weather");
                assert_eq!(call.call_id.as_deref(), Some("get_weather"));
            }
            other => panic!("unexpected choice: {other:?}"),
        }
    }

    #[test]
    fn test_content_delta_function_call_without_name_is_skipped() {
        let delta = parse_delta(json!({
            "event_type": "content.delta",
            "index": 0,
            "delta": {
                "type": "function_call",
                "arguments": {"location": "Paris"},
                "id": "call-1"
            }
        }));

        assert!(content_delta_to_choice(delta).is_none());
    }
}