// rig/providers/gemini/interactions_api/streaming.rs

use std::pin::Pin;

use async_stream::stream;
use futures::{Stream, StreamExt};
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use tracing::{Level, enabled, info_span};
use tracing_futures::Instrument;

use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
use crate::http_client::HttpClientExt;
use crate::http_client::Request;
use crate::http_client::sse::{Event, GenericEventSource};
use crate::streaming;
use crate::telemetry::SpanCombinator;

use super::InteractionsCompletionModel;
use super::create_request_body;
use super::interactions_api_types::{
    Content, ContentDelta, FunctionCallContent, FunctionCallDelta, Interaction,
    InteractionSseEvent, InteractionUsage, TextContent, TextDelta, ThoughtSummaryContent,
    ThoughtSummaryDelta,
};

23#[derive(Debug, Serialize, Deserialize, Default, Clone)]
25pub struct StreamingCompletionResponse {
26 pub usage: Option<InteractionUsage>,
27 pub interaction: Option<Interaction>,
28}
29
30#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
31pub type InteractionEventStream =
32 Pin<Box<dyn Stream<Item = Result<InteractionSseEvent, CompletionError>> + Send>>;
33
34#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
35pub type InteractionEventStream =
36 Pin<Box<dyn Stream<Item = Result<InteractionSseEvent, CompletionError>>>>;
37
38impl GetTokenUsage for StreamingCompletionResponse {
39 fn token_usage(&self) -> Option<crate::completion::Usage> {
40 self.usage.as_ref().and_then(|usage| usage.token_usage())
41 }
42}
43
44impl<T> InteractionsCompletionModel<T>
45where
46 T: HttpClientExt + Clone + Default + std::fmt::Debug + 'static,
47{
48 pub(crate) async fn stream(
49 &self,
50 completion_request: CompletionRequest,
51 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
52 {
53 let span = if tracing::Span::current().is_disabled() {
54 info_span!(
55 target: "rig::completions",
56 "interactions_streaming",
57 gen_ai.operation.name = "interactions_streaming",
58 gen_ai.provider.name = "gcp.gemini",
59 gen_ai.request.model = self.model,
60 gen_ai.system_instructions = &completion_request.preamble,
61 gen_ai.response.id = tracing::field::Empty,
62 gen_ai.response.model = tracing::field::Empty,
63 gen_ai.usage.output_tokens = tracing::field::Empty,
64 gen_ai.usage.input_tokens = tracing::field::Empty,
65 )
66 } else {
67 tracing::Span::current()
68 };
69
70 let request = create_request_body(self.model.clone(), completion_request, Some(true))?;
71
72 if enabled!(Level::TRACE) {
73 tracing::trace!(
74 target: "rig::streaming",
75 "Gemini interactions streaming request: {}",
76 serde_json::to_string_pretty(&request)?
77 );
78 }
79
80 let body = serde_json::to_vec(&request)?;
81 let req = self
82 .client
83 .post_sse("/v1beta/interactions")?
84 .header("Content-Type", "application/json")
85 .body(body)
86 .map_err(|e| CompletionError::HttpError(e.into()))?;
87
88 let mut event_source = GenericEventSource::new(self.client.clone(), req);
89
90 let stream = stream! {
91 let mut final_interaction: Option<Interaction> = None;
92 let mut final_usage: Option<InteractionUsage> = None;
93
94 while let Some(event_result) = event_source.next().await {
95 match event_result {
96 Ok(Event::Open) => {
97 tracing::debug!("SSE connection opened");
98 continue;
99 }
100 Ok(Event::Message(message)) => {
101 if message.data.trim().is_empty() {
102 continue;
103 }
104
105 let data = match serde_json::from_str::<InteractionSseEvent>(&message.data)
106 {
107 Ok(data) => data,
108 Err(err) => {
109 tracing::debug!(
110 "Failed to deserialize interactions SSE event: {err}"
111 );
112 continue;
113 }
114 };
115
116 match data {
117 InteractionSseEvent::ContentDelta { delta, .. } => {
118 if let Some(choice) = content_delta_to_choice(delta) {
119 yield Ok(choice);
120 }
121 }
122 InteractionSseEvent::ContentStart { content, .. } => {
123 if let Some(choice) = content_start_to_choice(content) {
124 yield Ok(choice);
125 }
126 }
127 InteractionSseEvent::InteractionComplete { interaction, .. } => {
128 let span = tracing::Span::current();
129 span.record("gen_ai.response.id", &interaction.id);
130 if let Some(model) = interaction.model.clone() {
131 span.record("gen_ai.response.model", model);
132 }
133
134 if let Some(usage) = interaction.usage.clone() {
135 span.record_token_usage(&usage);
136 final_usage = Some(usage);
137 }
138 final_interaction = Some(interaction);
139 }
140 InteractionSseEvent::Error { error, .. } => {
141 yield Err(CompletionError::ProviderError(error.message));
142 break;
143 }
144 _ => continue,
145 }
146 }
147 Err(crate::http_client::Error::StreamEnded) => {
148 break;
149 }
150 Err(error) => {
151 tracing::error!(?error, "SSE error");
152 yield Err(CompletionError::ProviderError(error.to_string()));
153 break;
154 }
155 }
156 }
157
158 event_source.close();
159
160 yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
161 usage: final_usage.or_else(|| final_interaction.as_ref().and_then(|i| i.usage.clone())),
162 interaction: final_interaction,
163 }));
164 }
165 .instrument(span);
166
167 Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
168 stream,
169 )))
170 }
171}
172
173pub(crate) fn stream_interaction_events<T>(
174 client: super::InteractionsClient<T>,
175 request: Request<Vec<u8>>,
176) -> InteractionEventStream
177where
178 T: HttpClientExt + Clone + Default + std::fmt::Debug + 'static,
179{
180 let mut event_source = GenericEventSource::new(client.clone(), request);
181
182 let stream = stream! {
183 while let Some(event_result) = event_source.next().await {
184 match event_result {
185 Ok(Event::Open) => continue,
186 Ok(Event::Message(message)) => {
187 if message.data.trim().is_empty() {
188 continue;
189 }
190
191 let data = serde_json::from_str::<InteractionSseEvent>(&message.data);
192 let Ok(data) = data else {
193 let err = data.unwrap_err();
194 tracing::debug!("Failed to deserialize interactions SSE event: {err}");
195 continue;
196 };
197
198 yield Ok(data);
199 }
200 Err(crate::http_client::Error::StreamEnded) => break,
201 Err(error) => {
202 tracing::error!(?error, "SSE error");
203 yield Err(CompletionError::ProviderError(error.to_string()));
204 break;
205 }
206 }
207 }
208
209 event_source.close();
210 };
211
212 Box::pin(stream)
213}
214
215fn content_start_to_choice(
216 content: Content,
217) -> Option<streaming::RawStreamingChoice<StreamingCompletionResponse>> {
218 match content {
219 Content::Text(TextContent { text, .. }) => {
220 if text.is_empty() {
221 None
222 } else {
223 Some(streaming::RawStreamingChoice::Message(text))
224 }
225 }
226 Content::FunctionCall(FunctionCallContent {
227 name,
228 arguments,
229 id,
230 }) => {
231 let name = name?;
232 let call_id = id.unwrap_or_else(|| name.clone());
233 Some(streaming::RawStreamingChoice::ToolCall(
234 streaming::RawStreamingToolCall::new(
235 name.clone(),
236 name,
237 arguments.unwrap_or(Value::Object(Map::new())),
238 )
239 .with_call_id(call_id),
240 ))
241 }
242 _ => None,
243 }
244}
245
246fn content_delta_to_choice(
247 delta: ContentDelta,
248) -> Option<streaming::RawStreamingChoice<StreamingCompletionResponse>> {
249 match delta {
250 ContentDelta::Text(TextDelta {
251 text: Some(text), ..
252 }) => Some(streaming::RawStreamingChoice::Message(text)),
253 ContentDelta::FunctionCall(FunctionCallDelta {
254 name,
255 arguments,
256 id,
257 }) => {
258 let name = name?;
259 let call_id = id.unwrap_or_else(|| name.clone());
260 Some(streaming::RawStreamingChoice::ToolCall(
261 streaming::RawStreamingToolCall::new(
262 name.clone(),
263 name,
264 arguments.unwrap_or(Value::Object(Map::new())),
265 )
266 .with_call_id(call_id),
267 ))
268 }
269 ContentDelta::ThoughtSummary(ThoughtSummaryDelta { content }) => {
270 let text = match content {
271 ThoughtSummaryContent::Text(text) => text.text,
272 _ => return None,
273 };
274 Some(streaming::RawStreamingChoice::ReasoningDelta {
275 id: None,
276 reasoning: text,
277 })
278 }
279 _ => None,
280 }
281}
282
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Wraps `delta` in a `content.delta` event at index 0, deserializes it and returns
    /// the parsed delta; panics if the event is not a content delta.
    fn parse_delta(delta: Value) -> ContentDelta {
        let event_json = json!({
            "event_type": "content.delta",
            "index": 0,
            "delta": delta
        });

        let event: InteractionSseEvent = serde_json::from_value(event_json).unwrap();
        let InteractionSseEvent::ContentDelta { delta, .. } = event else {
            panic!("expected content delta");
        };
        delta
    }

    /// Converts `delta` and asserts the result is a tool call, returning it.
    fn expect_tool_call(delta: ContentDelta) -> crate::streaming::RawStreamingToolCall {
        match content_delta_to_choice(delta).expect("choice should exist") {
            crate::streaming::RawStreamingChoice::ToolCall(call) => call,
            other => panic!("unexpected choice: {other:?}"),
        }
    }

    #[test]
    fn test_content_delta_text_event() {
        let delta = parse_delta(json!({
            "type": "text",
            "text": "Hello"
        }));

        match content_delta_to_choice(delta).expect("choice should exist") {
            crate::streaming::RawStreamingChoice::Message(text) => {
                assert_eq!(text, "Hello");
            }
            other => panic!("unexpected choice: {other:?}"),
        }
    }

    #[test]
    fn test_content_delta_function_call_event() {
        let call = expect_tool_call(parse_delta(json!({
            "type": "function_call",
            "name": "get_weather",
            "arguments": {"location": "Paris"},
            "id": "call-1"
        })));

        assert_eq!(call.name, "get_weather");
        assert_eq!(call.call_id.as_deref(), Some("call-1"));
    }

    #[test]
    fn test_content_delta_function_call_without_id_falls_back_to_name() {
        let call = expect_tool_call(parse_delta(json!({
            "type": "function_call",
            "name": "get_weather",
            "arguments": {"location": "Paris"}
        })));

        assert_eq!(call.name, "get_weather");
        assert_eq!(call.call_id.as_deref(), Some("get_weather"));
        assert_eq!(call.arguments, json!({"location": "Paris"}));
    }

    #[test]
    fn test_content_delta_function_call_without_arguments_uses_empty_object() {
        let call = expect_tool_call(parse_delta(json!({
            "type": "function_call",
            "name": "ping",
            "id": "call-2"
        })));

        assert_eq!(call.name, "ping");
        assert_eq!(call.arguments, json!({}));
    }

    #[test]
    fn test_content_delta_function_call_without_name_is_skipped() {
        let delta = parse_delta(json!({
            "type": "function_call",
            "arguments": {"location": "Paris"},
            "id": "call-3"
        }));

        assert!(content_delta_to_choice(delta).is_none());
    }
}