use async_stream::stream;
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use tracing::{Level, enabled, info_span};
use tracing_futures::Instrument;

use super::completion::gemini_api_types::{ContentCandidate, Part, PartKind};
use super::completion::{CompletionModel, create_request_body};
use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
use crate::http_client::HttpClientExt;
use crate::http_client::sse::{Event, GenericEventSource};
use crate::streaming;
use crate::telemetry::SpanCombinator;

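/// Token usage metadata as reported in the `usageMetadata` field of Gemini
/// `streamGenerateContent` chunks; field names map via the camelCase rename.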
#[derive(Debug, Deserialize, Serialize, Default, Clone)]
#[serde(rename_all = "camelCase")]
pub struct PartialUsage {
    pub total_token_count: i32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_content_token_count: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidates_token_count: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thoughts_token_count: Option<i32>,
    pub prompt_token_count: i32,
}

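// Output tokens are reported as the sum of the cached-content, candidate, and
// thought token counts; the total is recomputed from input + output rather
// than taken from `total_token_count`.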
impl GetTokenUsage for PartialUsage {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();

        usage.input_tokens = self.prompt_token_count as u64;
        usage.output_tokens = (self.cached_content_token_count.unwrap_or_default()
            + self.candidates_token_count.unwrap_or_default()
            + self.thoughts_token_count.unwrap_or_default()) as u64;
        usage.total_tokens = usage.input_tokens + usage.output_tokens;

        Some(usage)
    }
}

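/// A single chunk of a Gemini `streamGenerateContent` SSE stream.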
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct StreamGenerateContentResponse {
    pub candidates: Vec<ContentCandidate>,
    pub model_version: Option<String>,
    pub usage_metadata: Option<PartialUsage>,
}

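/// Final response yielded once the stream completes, carrying the usage
/// metadata from the last chunk (or a zeroed default if none was seen).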
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct StreamingCompletionResponse {
    pub usage_metadata: PartialUsage,
}

impl GetTokenUsage for StreamingCompletionResponse {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();
        usage.total_tokens = self.usage_metadata.total_token_count as u64;
        usage.output_tokens = self
            .usage_metadata
            .candidates_token_count
            .map(|x| x as u64)
            .unwrap_or(0);
        usage.input_tokens = self.usage_metadata.prompt_token_count as u64;
        Some(usage)
    }
}

impl<T> CompletionModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
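    /// Sends the completion request to Gemini's `streamGenerateContent`
    /// endpoint and yields text, reasoning, and tool-call deltas as the SSE
    /// events arrive.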
    pub(crate) async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
    {
        // Reuse the caller's active span if there is one; otherwise open a
        // fresh span carrying the gen_ai.* telemetry fields for this request.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "gcp.gemini",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = self.model,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };
        let request = create_request_body(completion_request)?;

        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::streaming",
                "Gemini streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post_sse(format!(
                "/v1beta/models/{}:streamGenerateContent",
                self.model
            ))?
            .header("Content-Type", "application/json")
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        let mut event_source = GenericEventSource::new(self.client.clone(), req);

        let stream = stream! {
            let mut final_usage = None;
            while let Some(event_result) = event_source.next().await {
                match event_result {
                    Ok(Event::Open) => {
                        tracing::debug!("SSE connection opened");
                        continue;
                    }
                    Ok(Event::Message(message)) => {
                        if message.data.trim().is_empty() {
                            continue;
                        }

                        let data = match serde_json::from_str::<StreamGenerateContentResponse>(&message.data) {
                            Ok(d) => d,
                            Err(error) => {
                                tracing::error!(?error, message = message.data, "Failed to parse SSE message");
                                continue;
                            }
                        };

                        let Some(choice) = data.candidates.into_iter().next() else {
                            tracing::debug!("Response chunk contained no content candidates");
                            continue;
                        };

                        let Some(content) = choice.content else {
                            tracing::debug!(finish_reason = ?choice.finish_reason, "Streaming candidate missing content");
                            continue;
                        };

                        if content.parts.is_empty() {
                            tracing::trace!(reason = ?choice.finish_reason, "Streaming content contained no parts");
                        }

                        for part in content.parts {
                            match part {
                                Part {
                                    part: PartKind::Text(text),
                                    thought: Some(true),
                                    ..
                                } => {
                                    yield Ok(streaming::RawStreamingChoice::ReasoningDelta {
                                        id: None,
                                        reasoning: text,
                                    });
                                },
                                Part {
                                    part: PartKind::Text(text),
                                    ..
                                } => {
                                    yield Ok(streaming::RawStreamingChoice::Message(text));
                                },
                                Part {
                                    part: PartKind::FunctionCall(function_call),
                                    thought_signature,
                                    ..
                                } => {
                                    // Gemini does not assign tool call ids, so
                                    // the function name doubles as the call id.
                                    yield Ok(streaming::RawStreamingChoice::ToolCall(
                                        streaming::RawStreamingToolCall::new(
                                            function_call.name.clone(),
                                            function_call.name.clone(),
                                            function_call.args.clone(),
                                        )
                                        .with_signature(thought_signature),
                                    ));
                                },
                                part => {
                                    tracing::warn!(?part, "Unsupported response part in streaming");
                                }
                            }
                        }

                        if choice.finish_reason.is_some() {
                            // The final chunk carries the finish reason and usage
                            // metadata; record it on the span and stop reading.
                            let span = tracing::Span::current();
                            span.record_token_usage(&data.usage_metadata);
                            final_usage = data.usage_metadata;
                            break;
                        }
                    }
                    Err(crate::http_client::Error::StreamEnded) => {
                        break;
                    }
                    Err(error) => {
                        tracing::error!(?error, "SSE error");
                        yield Err(CompletionError::ProviderError(error.to_string()));
                        break;
                    }
                }
            }

            event_source.close();

            yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
                usage_metadata: final_usage.unwrap_or_default(),
            }));
        }.instrument(span);

        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
            stream,
        )))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_deserialize_stream_response_with_single_text_part() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "Hello, world!"}
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 5,
                "totalTokenCount": 15
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        assert_eq!(response.candidates.len(), 1);
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 1);

        if let Part {
            part: PartKind::Text(text),
            ..
        } = &content.parts[0]
        {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected text part");
        }
    }
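
    // A minimal sketch, not part of the original suite: the optional
    // `modelVersion` field should deserialize via the struct's camelCase
    // rename. The model name used here is illustrative.
    #[test]
    fn test_deserialize_stream_response_with_model_version() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [{"text": "Hi"}],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "modelVersion": "gemini-2.0-flash",
            "usageMetadata": {
                "promptTokenCount": 1,
                "candidatesTokenCount": 1,
                "totalTokenCount": 2
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        assert_eq!(response.model_version.as_deref(), Some("gemini-2.0-flash"));
    }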

    #[test]
    fn test_deserialize_stream_response_with_multiple_text_parts() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "Hello, "},
                        {"text": "world!"},
                        {"text": " How are you?"}
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 8,
                "totalTokenCount": 18
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        assert_eq!(response.candidates.len(), 1);
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 3);

        for (i, expected_text) in ["Hello, ", "world!", " How are you?"].iter().enumerate() {
            if let Part {
                part: PartKind::Text(text),
                ..
            } = &content.parts[i]
            {
                assert_eq!(text, expected_text);
            } else {
                panic!("Expected text part at index {}", i);
            }
        }
    }

    #[test]
    fn test_deserialize_stream_response_with_multiple_tool_calls() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {
                            "functionCall": {
                                "name": "get_weather",
                                "args": {"city": "San Francisco"}
                            }
                        },
                        {
                            "functionCall": {
                                "name": "get_temperature",
                                "args": {"location": "New York"}
                            }
                        }
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 50,
                "candidatesTokenCount": 20,
                "totalTokenCount": 70
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 2);

        if let Part {
            part: PartKind::FunctionCall(call),
            ..
        } = &content.parts[0]
        {
            assert_eq!(call.name, "get_weather");
        } else {
            panic!("Expected function call at index 0");
        }

        if let Part {
            part: PartKind::FunctionCall(call),
            ..
        } = &content.parts[1]
        {
            assert_eq!(call.name, "get_temperature");
        } else {
            panic!("Expected function call at index 1");
        }
    }

    #[test]
    fn test_deserialize_stream_response_with_mixed_parts() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {
                            "text": "Let me think about this...",
                            "thought": true
                        },
                        {
                            "text": "Here's my response: "
                        },
                        {
                            "functionCall": {
                                "name": "search",
                                "args": {"query": "rust async"}
                            }
                        },
                        {
                            "text": "I found the answer!"
                        }
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 100,
                "candidatesTokenCount": 50,
                "thoughtsTokenCount": 15,
                "totalTokenCount": 165
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        let parts = &content.parts;
        assert_eq!(parts.len(), 4);

        if let Part {
            part: PartKind::Text(text),
            thought: Some(true),
            ..
        } = &parts[0]
        {
            assert_eq!(text, "Let me think about this...");
        } else {
            panic!("Expected thought part at index 0");
        }

        if let Part {
            part: PartKind::Text(text),
            thought,
            ..
        } = &parts[1]
        {
            assert_eq!(text, "Here's my response: ");
            assert!(thought.is_none() || thought == &Some(false));
        } else {
            panic!("Expected text part at index 1");
        }

        if let Part {
            part: PartKind::FunctionCall(call),
            ..
        } = &parts[2]
        {
            assert_eq!(call.name, "search");
        } else {
            panic!("Expected function call at index 2");
        }

        if let Part {
            part: PartKind::Text(text),
            ..
        } = &parts[3]
        {
            assert_eq!(text, "I found the answer!");
        } else {
            panic!("Expected text part at index 3");
        }
    }

    #[test]
    fn test_deserialize_stream_response_with_empty_parts() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 0,
                "totalTokenCount": 10
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 0);
    }

    #[test]
    fn test_partial_usage_token_calculation() {
        let usage = PartialUsage {
            total_token_count: 100,
            cached_content_token_count: Some(20),
            candidates_token_count: Some(30),
            thoughts_token_count: Some(10),
            prompt_token_count: 40,
        };

        let token_usage = usage.token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 40);
        // 20 cached + 30 candidates + 10 thoughts
        assert_eq!(token_usage.output_tokens, 60);
        assert_eq!(token_usage.total_tokens, 100);
    }

    #[test]
    fn test_partial_usage_with_missing_counts() {
        let usage = PartialUsage {
            total_token_count: 50,
            cached_content_token_count: None,
            candidates_token_count: Some(30),
            thoughts_token_count: None,
            prompt_token_count: 20,
        };

        let token_usage = usage.token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 20);
        assert_eq!(token_usage.output_tokens, 30);
        assert_eq!(token_usage.total_tokens, 50);
    }
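
    // A minimal sketch, not part of the original suite: when the stream ends
    // before any usage metadata is seen, `FinalResponse` falls back to
    // `PartialUsage::default()`, which should report zero tokens.
    #[test]
    fn test_partial_usage_default_is_zero() {
        let token_usage = PartialUsage::default().token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 0);
        assert_eq!(token_usage.output_tokens, 0);
        assert_eq!(token_usage.total_tokens, 0);
    }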

    #[test]
    fn test_streaming_completion_response_token_usage() {
        let response = StreamingCompletionResponse {
            usage_metadata: PartialUsage {
                total_token_count: 150,
                cached_content_token_count: None,
                candidates_token_count: Some(75),
                thoughts_token_count: None,
                prompt_token_count: 75,
            },
        };

        let token_usage = response.token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 75);
        assert_eq!(token_usage.output_tokens, 75);
        assert_eq!(token_usage.total_tokens, 150);
    }
}