Skip to main content

rig/providers/gemini/interactions_api/
streaming.rs

1use async_stream::stream;
2use futures::{Stream, StreamExt};
3use serde::{Deserialize, Serialize};
4use std::pin::Pin;
5use tracing::{Level, enabled, info_span};
6use tracing_futures::Instrument;
7
8use super::InteractionsCompletionModel;
9use super::create_request_body;
10use super::interactions_api_types::{
11    Content, ContentDelta, FunctionCallContent, FunctionCallDelta, Interaction,
12    InteractionSseEvent, InteractionUsage, TextContent, TextDelta, ThoughtSummaryContent,
13    ThoughtSummaryDelta,
14};
15use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
16use crate::http_client::HttpClientExt;
17use crate::http_client::Request;
18use crate::http_client::sse::{Event, GenericEventSource};
19use crate::streaming;
20use crate::telemetry::SpanCombinator;
21use serde_json::{Map, Value};
22
23/// Final metadata yielded by an Interactions streaming response.
24#[derive(Debug, Serialize, Deserialize, Default, Clone)]
25pub struct StreamingCompletionResponse {
26    pub usage: Option<InteractionUsage>,
27    pub interaction: Option<Interaction>,
28}
29
30#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
31pub type InteractionEventStream =
32    Pin<Box<dyn Stream<Item = Result<InteractionSseEvent, CompletionError>> + Send>>;
33
34#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
35pub type InteractionEventStream =
36    Pin<Box<dyn Stream<Item = Result<InteractionSseEvent, CompletionError>>>>;
37
38impl GetTokenUsage for StreamingCompletionResponse {
39    fn token_usage(&self) -> Option<crate::completion::Usage> {
40        self.usage.as_ref().and_then(|usage| usage.token_usage())
41    }
42}
43
44impl<T> InteractionsCompletionModel<T>
45where
46    T: HttpClientExt + Clone + Default + std::fmt::Debug + 'static,
47{
48    pub(crate) async fn stream(
49        &self,
50        completion_request: CompletionRequest,
51    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
52    {
53        let span = if tracing::Span::current().is_disabled() {
54            info_span!(
55                target: "rig::completions",
56                "interactions_streaming",
57                gen_ai.operation.name = "interactions_streaming",
58                gen_ai.provider.name = "gcp.gemini",
59                gen_ai.request.model = self.model,
60                gen_ai.system_instructions = &completion_request.preamble,
61                gen_ai.response.id = tracing::field::Empty,
62                gen_ai.response.model = tracing::field::Empty,
63                gen_ai.usage.output_tokens = tracing::field::Empty,
64                gen_ai.usage.input_tokens = tracing::field::Empty,
65            )
66        } else {
67            tracing::Span::current()
68        };
69
70        let request = create_request_body(self.model.clone(), completion_request, Some(true))?;
71
72        if enabled!(Level::TRACE) {
73            tracing::trace!(
74                target: "rig::streaming",
75                "Gemini interactions streaming request: {}",
76                serde_json::to_string_pretty(&request)?
77            );
78        }
79
80        let body = serde_json::to_vec(&request)?;
81        let req = self
82            .client
83            .post_sse("/v1beta/interactions")?
84            .header("Content-Type", "application/json")
85            .body(body)
86            .map_err(|e| CompletionError::HttpError(e.into()))?;
87
88        let mut event_source = GenericEventSource::new(self.client.clone(), req);
89
90        let stream = stream! {
91            let mut final_interaction: Option<Interaction> = None;
92            let mut final_usage: Option<InteractionUsage> = None;
93
94            while let Some(event_result) = event_source.next().await {
95                match event_result {
96                    Ok(Event::Open) => {
97                        tracing::debug!("SSE connection opened");
98                        continue;
99                    }
100                    Ok(Event::Message(message)) => {
101                        if message.data.trim().is_empty() {
102                            continue;
103                        }
104
105                        let data = match serde_json::from_str::<InteractionSseEvent>(&message.data)
106                        {
107                            Ok(data) => data,
108                            Err(err) => {
109                                tracing::debug!(
110                                    "Failed to deserialize interactions SSE event: {err}"
111                                );
112                                continue;
113                            }
114                        };
115
116                        match data {
117                            InteractionSseEvent::ContentDelta { delta, .. } => {
118                                if let Some(choice) = content_delta_to_choice(delta) {
119                                    yield Ok(choice);
120                                }
121                            }
122                            InteractionSseEvent::ContentStart { content, .. } => {
123                                if let Some(choice) = content_start_to_choice(content) {
124                                    yield Ok(choice);
125                                }
126                            }
127                            InteractionSseEvent::InteractionComplete { interaction, .. } => {
128                                let span = tracing::Span::current();
129                                span.record("gen_ai.response.id", &interaction.id);
130                                if let Some(model) = interaction.model.clone() {
131                                    span.record("gen_ai.response.model", model);
132                                }
133
134                                if let Some(usage) = interaction.usage.clone() {
135                                    span.record_token_usage(&usage);
136                                    final_usage = Some(usage);
137                                }
138                                final_interaction = Some(interaction);
139                            }
140                            InteractionSseEvent::Error { error, .. } => {
141                                yield Err(CompletionError::ProviderError(error.message));
142                                break;
143                            }
144                            _ => continue,
145                        }
146                    }
147                    Err(crate::http_client::Error::StreamEnded) => {
148                        break;
149                    }
150                    Err(error) => {
151                        tracing::error!(?error, "SSE error");
152                        yield Err(CompletionError::ProviderError(error.to_string()));
153                        break;
154                    }
155                }
156            }
157
158            event_source.close();
159
160            yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
161                usage: final_usage.or_else(|| final_interaction.as_ref().and_then(|i| i.usage.clone())),
162                interaction: final_interaction,
163            }));
164        }
165        .instrument(span);
166
167        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
168            stream,
169        )))
170    }
171}
172
173pub(crate) fn stream_interaction_events<T>(
174    client: super::InteractionsClient<T>,
175    request: Request<Vec<u8>>,
176) -> InteractionEventStream
177where
178    T: HttpClientExt + Clone + Default + std::fmt::Debug + 'static,
179{
180    let mut event_source = GenericEventSource::new(client.clone(), request);
181
182    let stream = stream! {
183        while let Some(event_result) = event_source.next().await {
184            match event_result {
185                Ok(Event::Open) => continue,
186                Ok(Event::Message(message)) => {
187                    if message.data.trim().is_empty() {
188                        continue;
189                    }
190
191                    let data = serde_json::from_str::<InteractionSseEvent>(&message.data);
192                    let Ok(data) = data else {
193                        let err = data.unwrap_err();
194                        tracing::debug!("Failed to deserialize interactions SSE event: {err}");
195                        continue;
196                    };
197
198                    yield Ok(data);
199                }
200                Err(crate::http_client::Error::StreamEnded) => break,
201                Err(error) => {
202                    tracing::error!(?error, "SSE error");
203                    yield Err(CompletionError::ProviderError(error.to_string()));
204                    break;
205                }
206            }
207        }
208
209        event_source.close();
210    };
211
212    Box::pin(stream)
213}
214
215fn content_start_to_choice(
216    content: Content,
217) -> Option<streaming::RawStreamingChoice<StreamingCompletionResponse>> {
218    match content {
219        Content::Text(TextContent { text, .. }) => {
220            if text.is_empty() {
221                None
222            } else {
223                Some(streaming::RawStreamingChoice::Message(text))
224            }
225        }
226        Content::FunctionCall(FunctionCallContent {
227            name,
228            arguments,
229            id,
230        }) => {
231            let name = name?;
232            let call_id = id.unwrap_or_else(|| name.clone());
233            Some(streaming::RawStreamingChoice::ToolCall(
234                streaming::RawStreamingToolCall::new(
235                    name.clone(),
236                    name,
237                    arguments.unwrap_or(Value::Object(Map::new())),
238                )
239                .with_call_id(call_id),
240            ))
241        }
242        _ => None,
243    }
244}
245
246fn content_delta_to_choice(
247    delta: ContentDelta,
248) -> Option<streaming::RawStreamingChoice<StreamingCompletionResponse>> {
249    match delta {
250        ContentDelta::Text(TextDelta {
251            text: Some(text), ..
252        }) => Some(streaming::RawStreamingChoice::Message(text)),
253        ContentDelta::FunctionCall(FunctionCallDelta {
254            name,
255            arguments,
256            id,
257        }) => {
258            let name = name?;
259            let call_id = id.unwrap_or_else(|| name.clone());
260            Some(streaming::RawStreamingChoice::ToolCall(
261                streaming::RawStreamingToolCall::new(
262                    name.clone(),
263                    name,
264                    arguments.unwrap_or(Value::Object(Map::new())),
265                )
266                .with_call_id(call_id),
267            ))
268        }
269        ContentDelta::ThoughtSummary(ThoughtSummaryDelta { content }) => {
270            let text = match content {
271                ThoughtSummaryContent::Text(text) => text.text,
272                _ => return None,
273            };
274            Some(streaming::RawStreamingChoice::ReasoningDelta {
275                id: None,
276                reasoning: text,
277            })
278        }
279        _ => None,
280    }
281}
282
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Deserializes `payload` as an SSE event and extracts its content delta,
    /// panicking if the event is of any other kind.
    fn parse_content_delta(payload: serde_json::Value) -> ContentDelta {
        let event: InteractionSseEvent = serde_json::from_value(payload).unwrap();
        match event {
            InteractionSseEvent::ContentDelta { delta, .. } => delta,
            _ => panic!("expected content delta"),
        }
    }

    /// A text delta is surfaced as a `Message` carrying the same text.
    #[test]
    fn test_content_delta_text_event() {
        let delta = parse_content_delta(json!({
            "event_type": "content.delta",
            "index": 0,
            "delta": {
                "type": "text",
                "text": "Hello"
            }
        }));

        match content_delta_to_choice(delta).expect("choice should exist") {
            crate::streaming::RawStreamingChoice::Message(text) => assert_eq!(text, "Hello"),
            other => panic!("unexpected choice: {other:?}"),
        }
    }

    /// A function-call delta is surfaced as a `ToolCall` keeping its name and id.
    #[test]
    fn test_content_delta_function_call_event() {
        let delta = parse_content_delta(json!({
            "event_type": "content.delta",
            "index": 0,
            "delta": {
                "type": "function_call",
                "name": "get_weather",
                "arguments": {"location": "Paris"},
                "id": "call-1"
            }
        }));

        match content_delta_to_choice(delta).expect("choice should exist") {
            crate::streaming::RawStreamingChoice::ToolCall(call) => {
                assert_eq!(call.name, "get_weather");
                assert_eq!(call.call_id.as_deref(), Some("call-1"));
            }
            other => panic!("unexpected choice: {other:?}"),
        }
    }
}