rig_core/providers/gemini/interactions_api/
streaming.rs1use async_stream::stream;
2use futures::{Stream, StreamExt};
3use serde::{Deserialize, Serialize};
4use std::pin::Pin;
5use tracing::{Level, enabled, info_span};
6use tracing_futures::Instrument;
7
8use super::InteractionsCompletionModel;
9use super::create_request_body;
10use super::interactions_api_types::{
11 Content, ContentDelta, FunctionCallContent, FunctionCallDelta, Interaction,
12 InteractionSseEvent, InteractionUsage, TextContent, TextDelta, ThoughtSummaryContent,
13 ThoughtSummaryDelta,
14};
15use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
16use crate::http_client::HttpClientExt;
17use crate::http_client::Request;
18use crate::http_client::sse::{Event, GenericEventSource};
19use crate::streaming;
20use crate::telemetry::SpanCombinator;
21use serde_json::{Map, Value};
22
23#[derive(Debug, Serialize, Deserialize, Default, Clone)]
25pub struct StreamingCompletionResponse {
26 pub usage: Option<InteractionUsage>,
27 pub interaction: Option<Interaction>,
28}
29
30#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
31pub type InteractionEventStream =
32 Pin<Box<dyn Stream<Item = Result<InteractionSseEvent, CompletionError>> + Send>>;
33
34#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
35pub type InteractionEventStream =
36 Pin<Box<dyn Stream<Item = Result<InteractionSseEvent, CompletionError>>>>;
37
38impl GetTokenUsage for StreamingCompletionResponse {
39 fn token_usage(&self) -> Option<crate::completion::Usage> {
40 self.usage.as_ref().and_then(|usage| usage.token_usage())
41 }
42}
43
44impl<T> InteractionsCompletionModel<T>
45where
46 T: HttpClientExt + Clone + Default + std::fmt::Debug + 'static,
47{
48 pub(crate) async fn stream(
49 &self,
50 completion_request: CompletionRequest,
51 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
52 {
53 let span = if tracing::Span::current().is_disabled() {
54 info_span!(
55 target: "rig::completions",
56 "interactions_streaming",
57 gen_ai.operation.name = "interactions_streaming",
58 gen_ai.provider.name = "gcp.gemini",
59 gen_ai.request.model = self.model,
60 gen_ai.system_instructions = &completion_request.preamble,
61 gen_ai.response.id = tracing::field::Empty,
62 gen_ai.response.model = tracing::field::Empty,
63 gen_ai.usage.output_tokens = tracing::field::Empty,
64 gen_ai.usage.input_tokens = tracing::field::Empty,
65 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
66 gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
67 gen_ai.usage.reasoning_tokens = tracing::field::Empty,
68 )
69 } else {
70 tracing::Span::current()
71 };
72
73 let request = create_request_body(self.model.clone(), completion_request, Some(true))?;
74
75 if enabled!(Level::TRACE) {
76 tracing::trace!(
77 target: "rig::streaming",
78 "Gemini interactions streaming request: {}",
79 serde_json::to_string_pretty(&request)?
80 );
81 }
82
83 let body = serde_json::to_vec(&request)?;
84 let req = self
85 .client
86 .post_sse("/v1beta/interactions")?
87 .header("Content-Type", "application/json")
88 .body(body)
89 .map_err(|e| CompletionError::HttpError(e.into()))?;
90
91 let mut event_source = GenericEventSource::new(self.client.clone(), req);
92
93 let stream = stream! {
94 let mut final_interaction: Option<Interaction> = None;
95 let mut final_usage: Option<InteractionUsage> = None;
96
97 while let Some(event_result) = event_source.next().await {
98 match event_result {
99 Ok(Event::Open) => {
100 tracing::debug!("SSE connection opened");
101 continue;
102 }
103 Ok(Event::Message(message)) => {
104 if message.data.trim().is_empty() {
105 continue;
106 }
107
108 let data = match serde_json::from_str::<InteractionSseEvent>(&message.data)
109 {
110 Ok(data) => data,
111 Err(err) => {
112 tracing::debug!(
113 "Failed to deserialize interactions SSE event: {err}"
114 );
115 continue;
116 }
117 };
118
119 match data {
120 InteractionSseEvent::ContentDelta { delta, .. } => {
121 if let Some(choice) = content_delta_to_choice(delta) {
122 yield Ok(choice);
123 }
124 }
125 InteractionSseEvent::ContentStart { content, .. } => {
126 if let Some(choice) = content_start_to_choice(content) {
127 yield Ok(choice);
128 }
129 }
130 InteractionSseEvent::InteractionComplete { interaction, .. } => {
131 let span = tracing::Span::current();
132 span.record("gen_ai.response.id", &interaction.id);
133 if let Some(model) = interaction.model.clone() {
134 span.record("gen_ai.response.model", model);
135 }
136
137 if let Some(usage) = interaction.usage.clone() {
138 span.record_token_usage(&usage);
139 final_usage = Some(usage);
140 }
141 final_interaction = Some(interaction);
142 }
143 InteractionSseEvent::Error { error, .. } => {
144 yield Err(CompletionError::ProviderError(error.message));
145 break;
146 }
147 _ => continue,
148 }
149 }
150 Err(crate::http_client::Error::StreamEnded) => {
151 break;
152 }
153 Err(error) => {
154 tracing::error!(?error, "SSE error");
155 yield Err(CompletionError::ProviderError(error.to_string()));
156 break;
157 }
158 }
159 }
160
161 event_source.close();
162
163 yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
164 usage: final_usage.or_else(|| final_interaction.as_ref().and_then(|i| i.usage.clone())),
165 interaction: final_interaction,
166 }));
167 }
168 .instrument(span);
169
170 Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
171 stream,
172 )))
173 }
174}
175
176pub(crate) fn stream_interaction_events<T>(
177 client: super::InteractionsClient<T>,
178 request: Request<Vec<u8>>,
179) -> InteractionEventStream
180where
181 T: HttpClientExt + Clone + Default + std::fmt::Debug + 'static,
182{
183 let mut event_source = GenericEventSource::new(client.clone(), request);
184
185 let stream = stream! {
186 while let Some(event_result) = event_source.next().await {
187 match event_result {
188 Ok(Event::Open) => continue,
189 Ok(Event::Message(message)) => {
190 if message.data.trim().is_empty() {
191 continue;
192 }
193
194 let data = serde_json::from_str::<InteractionSseEvent>(&message.data);
195 let Ok(data) = data else {
196 let Err(err) = data else {
197 continue;
198 };
199 tracing::debug!("Failed to deserialize interactions SSE event: {err}");
200 continue;
201 };
202
203 yield Ok(data);
204 }
205 Err(crate::http_client::Error::StreamEnded) => break,
206 Err(error) => {
207 tracing::error!(?error, "SSE error");
208 yield Err(CompletionError::ProviderError(error.to_string()));
209 break;
210 }
211 }
212 }
213
214 event_source.close();
215 };
216
217 Box::pin(stream)
218}
219
220fn content_start_to_choice(
221 content: Content,
222) -> Option<streaming::RawStreamingChoice<StreamingCompletionResponse>> {
223 match content {
224 Content::Text(TextContent { text, .. }) => {
225 if text.is_empty() {
226 None
227 } else {
228 Some(streaming::RawStreamingChoice::Message(text))
229 }
230 }
231 Content::FunctionCall(FunctionCallContent {
232 name,
233 arguments,
234 id,
235 }) => {
236 let name = name?;
237 let call_id = id.unwrap_or_else(|| name.clone());
238 Some(streaming::RawStreamingChoice::ToolCall(
239 streaming::RawStreamingToolCall::new(
240 name.clone(),
241 name,
242 arguments.unwrap_or(Value::Object(Map::new())),
243 )
244 .with_call_id(call_id),
245 ))
246 }
247 _ => None,
248 }
249}
250
251fn content_delta_to_choice(
252 delta: ContentDelta,
253) -> Option<streaming::RawStreamingChoice<StreamingCompletionResponse>> {
254 match delta {
255 ContentDelta::Text(TextDelta {
256 text: Some(text), ..
257 }) => Some(streaming::RawStreamingChoice::Message(text)),
258 ContentDelta::FunctionCall(FunctionCallDelta {
259 name,
260 arguments,
261 id,
262 }) => {
263 let name = name?;
264 let call_id = id.unwrap_or_else(|| name.clone());
265 Some(streaming::RawStreamingChoice::ToolCall(
266 streaming::RawStreamingToolCall::new(
267 name.clone(),
268 name,
269 arguments.unwrap_or(Value::Object(Map::new())),
270 )
271 .with_call_id(call_id),
272 ))
273 }
274 ContentDelta::ThoughtSummary(ThoughtSummaryDelta { content }) => {
275 let text = match content {
276 ThoughtSummaryContent::Text(text) => text.text,
277 _ => return None,
278 };
279 Some(streaming::RawStreamingChoice::ReasoningDelta {
280 id: None,
281 reasoning: text,
282 })
283 }
284 _ => None,
285 }
286}
287
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Deserializes a `content.delta` SSE payload and returns its delta, panicking otherwise.
    fn parse_delta(event_json: Value) -> ContentDelta {
        let event: InteractionSseEvent = serde_json::from_value(event_json).unwrap();
        let InteractionSseEvent::ContentDelta { delta, .. } = event else {
            panic!("expected content delta");
        };
        delta
    }

    #[test]
    fn test_content_delta_text_event() {
        let delta = parse_delta(json!({
            "event_type": "content.delta",
            "index": 0,
            "delta": {
                "type": "text",
                "text": "Hello"
            }
        }));

        let choice = content_delta_to_choice(delta).expect("choice should exist");
        match choice {
            crate::streaming::RawStreamingChoice::Message(text) => {
                assert_eq!(text, "Hello");
            }
            other => panic!("unexpected choice: {other:?}"),
        }
    }

    #[test]
    fn test_content_delta_function_call_event() {
        let delta = parse_delta(json!({
            "event_type": "content.delta",
            "index": 0,
            "delta": {
                "type": "function_call",
                "name": "get_weather",
                "arguments": {"location": "Paris"},
                "id": "call-1"
            }
        }));

        let choice = content_delta_to_choice(delta).expect("choice should exist");
        match choice {
            crate::streaming::RawStreamingChoice::ToolCall(call) => {
                assert_eq!(call.name, "get_weather");
                assert_eq!(call.call_id.as_deref(), Some("call-1"));
            }
            other => panic!("unexpected choice: {other:?}"),
        }
    }

    #[test]
    fn test_content_delta_function_call_without_id_falls_back_to_name() {
        let delta = parse_delta(json!({
            "event_type": "content.delta",
            "index": 0,
            "delta": {
                "type": "function_call",
                "name": "get_weather"
            }
        }));

        let choice = content_delta_to_choice(delta).expect("choice should exist");
        match choice {
            crate::streaming::RawStreamingChoice::ToolCall(call) => {
                assert_eq!(call.name, "get_weather");
                assert_eq!(call.call_id.as_deref(), Some("get_weather"));
                assert_eq!(call.arguments, json!({}));
            }
            other => panic!("unexpected choice: {other:?}"),
        }
    }

    #[test]
    fn test_content_delta_function_call_without_name_is_skipped() {
        let delta = parse_delta(json!({
            "event_type": "content.delta",
            "index": 0,
            "delta": {
                "type": "function_call",
                "arguments": {"location": "Paris"},
                "id": "call-1"
            }
        }));

        assert!(content_delta_to_choice(delta).is_none());
    }
}