rig/providers/gemini/interactions_api/
streaming.rs1use async_stream::stream;
2use futures::{Stream, StreamExt};
3use serde::{Deserialize, Serialize};
4use std::pin::Pin;
5use tracing::{Level, enabled, info_span};
6use tracing_futures::Instrument;
7
8use super::InteractionsCompletionModel;
9use super::create_request_body;
10use super::interactions_api_types::{
11 Content, ContentDelta, FunctionCallContent, FunctionCallDelta, Interaction,
12 InteractionSseEvent, InteractionUsage, TextContent, TextDelta, ThoughtSummaryContent,
13 ThoughtSummaryDelta,
14};
15use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
16use crate::http_client::HttpClientExt;
17use crate::http_client::Request;
18use crate::http_client::sse::{Event, GenericEventSource};
19use crate::streaming;
20use crate::telemetry::SpanCombinator;
21use serde_json::{Map, Value};
22
23#[derive(Debug, Serialize, Deserialize, Default, Clone)]
25pub struct StreamingCompletionResponse {
26 pub usage: Option<InteractionUsage>,
27 pub interaction: Option<Interaction>,
28}
29
30#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
31pub type InteractionEventStream =
32 Pin<Box<dyn Stream<Item = Result<InteractionSseEvent, CompletionError>> + Send>>;
33
34#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
35pub type InteractionEventStream =
36 Pin<Box<dyn Stream<Item = Result<InteractionSseEvent, CompletionError>>>>;
37
38impl GetTokenUsage for StreamingCompletionResponse {
39 fn token_usage(&self) -> Option<crate::completion::Usage> {
40 self.usage.as_ref().and_then(|usage| usage.token_usage())
41 }
42}
43
44impl<T> InteractionsCompletionModel<T>
45where
46 T: HttpClientExt + Clone + Default + std::fmt::Debug + 'static,
47{
48 pub(crate) async fn stream(
49 &self,
50 completion_request: CompletionRequest,
51 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
52 {
53 let span = if tracing::Span::current().is_disabled() {
54 info_span!(
55 target: "rig::completions",
56 "interactions_streaming",
57 gen_ai.operation.name = "interactions_streaming",
58 gen_ai.provider.name = "gcp.gemini",
59 gen_ai.request.model = self.model,
60 gen_ai.system_instructions = &completion_request.preamble,
61 gen_ai.response.id = tracing::field::Empty,
62 gen_ai.response.model = tracing::field::Empty,
63 gen_ai.usage.output_tokens = tracing::field::Empty,
64 gen_ai.usage.input_tokens = tracing::field::Empty,
65 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
66 )
67 } else {
68 tracing::Span::current()
69 };
70
71 let request = create_request_body(self.model.clone(), completion_request, Some(true))?;
72
73 if enabled!(Level::TRACE) {
74 tracing::trace!(
75 target: "rig::streaming",
76 "Gemini interactions streaming request: {}",
77 serde_json::to_string_pretty(&request)?
78 );
79 }
80
81 let body = serde_json::to_vec(&request)?;
82 let req = self
83 .client
84 .post_sse("/v1beta/interactions")?
85 .header("Content-Type", "application/json")
86 .body(body)
87 .map_err(|e| CompletionError::HttpError(e.into()))?;
88
89 let mut event_source = GenericEventSource::new(self.client.clone(), req);
90
91 let stream = stream! {
92 let mut final_interaction: Option<Interaction> = None;
93 let mut final_usage: Option<InteractionUsage> = None;
94
95 while let Some(event_result) = event_source.next().await {
96 match event_result {
97 Ok(Event::Open) => {
98 tracing::debug!("SSE connection opened");
99 continue;
100 }
101 Ok(Event::Message(message)) => {
102 if message.data.trim().is_empty() {
103 continue;
104 }
105
106 let data = match serde_json::from_str::<InteractionSseEvent>(&message.data)
107 {
108 Ok(data) => data,
109 Err(err) => {
110 tracing::debug!(
111 "Failed to deserialize interactions SSE event: {err}"
112 );
113 continue;
114 }
115 };
116
117 match data {
118 InteractionSseEvent::ContentDelta { delta, .. } => {
119 if let Some(choice) = content_delta_to_choice(delta) {
120 yield Ok(choice);
121 }
122 }
123 InteractionSseEvent::ContentStart { content, .. } => {
124 if let Some(choice) = content_start_to_choice(content) {
125 yield Ok(choice);
126 }
127 }
128 InteractionSseEvent::InteractionComplete { interaction, .. } => {
129 let span = tracing::Span::current();
130 span.record("gen_ai.response.id", &interaction.id);
131 if let Some(model) = interaction.model.clone() {
132 span.record("gen_ai.response.model", model);
133 }
134
135 if let Some(usage) = interaction.usage.clone() {
136 span.record_token_usage(&usage);
137 final_usage = Some(usage);
138 }
139 final_interaction = Some(interaction);
140 }
141 InteractionSseEvent::Error { error, .. } => {
142 yield Err(CompletionError::ProviderError(error.message));
143 break;
144 }
145 _ => continue,
146 }
147 }
148 Err(crate::http_client::Error::StreamEnded) => {
149 break;
150 }
151 Err(error) => {
152 tracing::error!(?error, "SSE error");
153 yield Err(CompletionError::ProviderError(error.to_string()));
154 break;
155 }
156 }
157 }
158
159 event_source.close();
160
161 yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
162 usage: final_usage.or_else(|| final_interaction.as_ref().and_then(|i| i.usage.clone())),
163 interaction: final_interaction,
164 }));
165 }
166 .instrument(span);
167
168 Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
169 stream,
170 )))
171 }
172}
173
174pub(crate) fn stream_interaction_events<T>(
175 client: super::InteractionsClient<T>,
176 request: Request<Vec<u8>>,
177) -> InteractionEventStream
178where
179 T: HttpClientExt + Clone + Default + std::fmt::Debug + 'static,
180{
181 let mut event_source = GenericEventSource::new(client.clone(), request);
182
183 let stream = stream! {
184 while let Some(event_result) = event_source.next().await {
185 match event_result {
186 Ok(Event::Open) => continue,
187 Ok(Event::Message(message)) => {
188 if message.data.trim().is_empty() {
189 continue;
190 }
191
192 let data = serde_json::from_str::<InteractionSseEvent>(&message.data);
193 let Ok(data) = data else {
194 let Err(err) = data else {
195 continue;
196 };
197 tracing::debug!("Failed to deserialize interactions SSE event: {err}");
198 continue;
199 };
200
201 yield Ok(data);
202 }
203 Err(crate::http_client::Error::StreamEnded) => break,
204 Err(error) => {
205 tracing::error!(?error, "SSE error");
206 yield Err(CompletionError::ProviderError(error.to_string()));
207 break;
208 }
209 }
210 }
211
212 event_source.close();
213 };
214
215 Box::pin(stream)
216}
217
218fn content_start_to_choice(
219 content: Content,
220) -> Option<streaming::RawStreamingChoice<StreamingCompletionResponse>> {
221 match content {
222 Content::Text(TextContent { text, .. }) => {
223 if text.is_empty() {
224 None
225 } else {
226 Some(streaming::RawStreamingChoice::Message(text))
227 }
228 }
229 Content::FunctionCall(FunctionCallContent {
230 name,
231 arguments,
232 id,
233 }) => {
234 let name = name?;
235 let call_id = id.unwrap_or_else(|| name.clone());
236 Some(streaming::RawStreamingChoice::ToolCall(
237 streaming::RawStreamingToolCall::new(
238 name.clone(),
239 name,
240 arguments.unwrap_or(Value::Object(Map::new())),
241 )
242 .with_call_id(call_id),
243 ))
244 }
245 _ => None,
246 }
247}
248
249fn content_delta_to_choice(
250 delta: ContentDelta,
251) -> Option<streaming::RawStreamingChoice<StreamingCompletionResponse>> {
252 match delta {
253 ContentDelta::Text(TextDelta {
254 text: Some(text), ..
255 }) => Some(streaming::RawStreamingChoice::Message(text)),
256 ContentDelta::FunctionCall(FunctionCallDelta {
257 name,
258 arguments,
259 id,
260 }) => {
261 let name = name?;
262 let call_id = id.unwrap_or_else(|| name.clone());
263 Some(streaming::RawStreamingChoice::ToolCall(
264 streaming::RawStreamingToolCall::new(
265 name.clone(),
266 name,
267 arguments.unwrap_or(Value::Object(Map::new())),
268 )
269 .with_call_id(call_id),
270 ))
271 }
272 ContentDelta::ThoughtSummary(ThoughtSummaryDelta { content }) => {
273 let text = match content {
274 ThoughtSummaryContent::Text(text) => text.text,
275 _ => return None,
276 };
277 Some(streaming::RawStreamingChoice::ReasoningDelta {
278 id: None,
279 reasoning: text,
280 })
281 }
282 _ => None,
283 }
284}
285
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Deserializes a `content.delta` SSE payload and returns its delta.
    ///
    /// Panics if the payload does not parse or is a different event type.
    fn parse_content_delta(event_json: Value) -> ContentDelta {
        let event: InteractionSseEvent = serde_json::from_value(event_json).unwrap();
        let InteractionSseEvent::ContentDelta { delta, .. } = event else {
            panic!("expected content delta");
        };
        delta
    }

    #[test]
    fn test_content_delta_text_event() {
        let delta = parse_content_delta(json!({
            "event_type": "content.delta",
            "index": 0,
            "delta": {
                "type": "text",
                "text": "Hello"
            }
        }));

        match content_delta_to_choice(delta).expect("choice should exist") {
            crate::streaming::RawStreamingChoice::Message(text) => {
                assert_eq!(text, "Hello");
            }
            other => panic!("unexpected choice: {other:?}"),
        }
    }

    #[test]
    fn test_content_delta_function_call_event() {
        let delta = parse_content_delta(json!({
            "event_type": "content.delta",
            "index": 0,
            "delta": {
                "type": "function_call",
                "name": "get_weather",
                "arguments": {"location": "Paris"},
                "id": "call-1"
            }
        }));

        match content_delta_to_choice(delta).expect("choice should exist") {
            crate::streaming::RawStreamingChoice::ToolCall(call) => {
                assert_eq!(call.name, "get_weather");
                assert_eq!(call.call_id.as_deref(), Some("call-1"));
            }
            other => panic!("unexpected choice: {other:?}"),
        }
    }

    #[test]
    fn test_content_delta_function_call_without_id_uses_name_as_call_id() {
        let delta = parse_content_delta(json!({
            "event_type": "content.delta",
            "index": 0,
            "delta": {
                "type": "function_call",
                "name": "get_weather",
                "arguments": {"location": "Paris"}
            }
        }));

        match content_delta_to_choice(delta).expect("choice should exist") {
            crate::streaming::RawStreamingChoice::ToolCall(call) => {
                assert_eq!(call.name, "get_weather");
                assert_eq!(call.call_id.as_deref(), Some("get_weather"));
            }
            other => panic!("unexpected choice: {other:?}"),
        }
    }

    #[test]
    fn test_content_delta_function_call_without_name_is_skipped() {
        let delta = parse_content_delta(json!({
            "event_type": "content.delta",
            "index": 0,
            "delta": {
                "type": "function_call",
                "arguments": {"location": "Paris"},
                "id": "call-1"
            }
        }));

        assert!(content_delta_to_choice(delta).is_none());
    }
}