1use async_stream::stream;
2use futures::StreamExt;
3use serde::{Deserialize, Serialize};
4use tracing::{Level, enabled, info_span};
5use tracing_futures::Instrument;
6
7use super::completion::gemini_api_types::{
8 ContentCandidate, ModalityTokenCount, Part, PartKind, TrafficType,
9};
10use super::completion::{
11 CompletionModel, create_request_body, resolve_request_model, streaming_endpoint,
12};
13use crate::completion::message::ReasoningContent;
14use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
15use crate::http_client::HttpClientExt;
16use crate::http_client::sse::{Event, GenericEventSource};
17use crate::streaming;
18use crate::telemetry::SpanCombinator;
19
/// Token-usage metadata as reported incrementally by the Gemini streaming API.
///
/// Mirrors the `usageMetadata` object of a `streamGenerateContent` chunk
/// (camelCase on the wire). Counts the API may omit are `Option`s;
/// `prompt_token_count` falls back to 0 when absent.
#[derive(Debug, Deserialize, Serialize, Default, Clone)]
#[serde(rename_all = "camelCase")]
pub struct PartialUsage {
    /// Total token count across prompt, candidates and thoughts.
    pub total_token_count: i32,
    /// Tokens served from cached content, when present.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_content_token_count: Option<i32>,
    /// Tokens produced in the response candidates (output tokens).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidates_token_count: Option<i32>,
    /// Tokens spent on model "thoughts" (reasoning), when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thoughts_token_count: Option<i32>,
    /// Tokens in the prompt; defaults to 0 if the chunk omits it.
    #[serde(default)]
    pub prompt_token_count: i32,
    /// Per-modality breakdown of prompt tokens.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<Vec<ModalityTokenCount>>,
    /// Per-modality breakdown of cached tokens.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cache_tokens_details: Option<Vec<ModalityTokenCount>>,
    /// Per-modality breakdown of candidate (output) tokens.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub candidates_tokens_details: Option<Vec<ModalityTokenCount>>,
    /// Prompt tokens attributable to tool-use content.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_use_prompt_token_count: Option<i32>,
    /// Per-modality breakdown of tool-use prompt tokens.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_use_prompt_tokens_details: Option<Vec<ModalityTokenCount>>,
    /// Traffic classification reported by the API (e.g. provisioned throughput).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub traffic_type: Option<TrafficType>,
}
45
46impl GetTokenUsage for PartialUsage {
47 fn token_usage(&self) -> Option<crate::completion::Usage> {
48 let mut usage = crate::completion::Usage::new();
49
50 usage.input_tokens = self.prompt_token_count as u64;
51 usage.output_tokens = self.candidates_token_count.unwrap_or_default() as u64;
52 usage.cached_input_tokens = self.cached_content_token_count.unwrap_or_default() as u64;
53 usage.reasoning_tokens = self.thoughts_token_count.unwrap_or_default() as u64;
54 usage.total_tokens = self.total_token_count as u64;
55
56 Some(usage)
57 }
58}
59
/// One deserialized chunk of a Gemini `streamGenerateContent` SSE response.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct StreamGenerateContentResponse {
    /// Server-assigned id for the overall response, if present on this chunk.
    pub response_id: Option<String>,
    /// Candidate completions carried by this chunk; may be empty
    /// (e.g. a trailing usage-only chunk).
    #[serde(default)]
    pub candidates: Vec<ContentCandidate>,
    /// Concrete model version that served the request, if reported.
    pub model_version: Option<String>,
    /// Cumulative token usage; typically present on the final chunk.
    pub usage_metadata: Option<PartialUsage>,
}
70
/// Final summary yielded at the end of a Gemini streaming completion.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct StreamingCompletionResponse {
    /// Usage from the last chunk that reported it; an all-zero default is
    /// used when no chunk carried usage metadata.
    pub usage_metadata: PartialUsage,
}
75
76impl GetTokenUsage for StreamingCompletionResponse {
77 fn token_usage(&self) -> Option<crate::completion::Usage> {
78 self.usage_metadata.token_usage()
79 }
80}
81
impl<T> CompletionModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
    /// Sends `completion_request` to the Gemini streaming endpoint and adapts
    /// the resulting SSE stream into rig's streaming response type.
    ///
    /// Each SSE chunk is parsed as a [`StreamGenerateContentResponse`]; its
    /// parts are forwarded as `Message`, `Reasoning`/`ReasoningDelta` or
    /// `ToolCall` items, and the last observed usage metadata is emitted in a
    /// trailing `FinalResponse`.
    ///
    /// # Errors
    /// Returns [`CompletionError`] if the request body cannot be built or
    /// serialized, or if the HTTP request cannot be constructed. Transport
    /// errors during streaming are yielded as `Err` items on the stream.
    pub(crate) async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
    {
        let request_model = resolve_request_model(&self.model, &completion_request);
        // Reuse an ambient span if one is already active; otherwise open a
        // fresh GenAI-semantic-convention span for this streaming call.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "gcp.gemini",
                gen_ai.request.model = &request_model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = &request_model,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
                gen_ai.usage.reasoning_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };
        let request = create_request_body(completion_request)?;

        // Only pay for pretty-printing the request when TRACE is enabled.
        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::streaming",
                "Gemini streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post_sse(streaming_endpoint(&request_model))?
            .header("Content-Type", "application/json")
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        let mut event_source = GenericEventSource::new(self.client.clone(), req);

        let stream = stream! {
            // Last usage metadata observed; reported in the FinalResponse.
            let mut final_usage = None;
            while let Some(event_result) = event_source.next().await {
                match event_result {
                    Ok(Event::Open) => {
                        tracing::debug!("SSE connection opened");
                        continue;
                    }
                    Ok(Event::Message(message)) => {
                        // Skip keep-alive / empty data frames.
                        if message.data.trim().is_empty() {
                            continue;
                        }

                        // Tolerate unparseable chunks: log and continue rather
                        // than aborting the whole stream.
                        let data = match serde_json::from_str::<StreamGenerateContentResponse>(&message.data) {
                            Ok(d) => d,
                            Err(error) => {
                                tracing::error!(?error, message = message.data, "Failed to parse SSE message");
                                continue;
                            }
                        };

                        // Record response metadata and usage on the active span.
                        let span = tracing::Span::current();
                        if let Some(response_id) = data.response_id.as_deref() {
                            span.record("gen_ai.response.id", response_id);
                        }
                        if let Some(model_version) = data.model_version.as_deref() {
                            span.record("gen_ai.response.model", model_version);
                        }
                        if let Some(usage) = data.usage_metadata.as_ref() {
                            span.record_token_usage(usage);
                            final_usage = Some(usage.clone());
                        }

                        // Only the first candidate is consumed; chunks without
                        // candidates (e.g. usage-only chunks) are skipped.
                        let Some(choice) = data.candidates.into_iter().next() else {
                            tracing::debug!("There is no content candidate");
                            continue;
                        };

                        let Some(content) = choice.content else {
                            tracing::debug!(finish_reason = ?choice.finish_reason, "Streaming candidate missing content");
                            continue;
                        };

                        if content.parts.is_empty() {
                            tracing::trace!(reason = ?choice.finish_reason, "There is no part in the streaming content");
                        }

                        for part in content.parts {
                            match part {
                                // Thought (reasoning) text: signed thoughts are
                                // yielded whole, unsigned ones as deltas.
                                Part {
                                    part: PartKind::Text(text),
                                    thought: Some(true),
                                    thought_signature,
                                    ..
                                } => {
                                    if !text.is_empty() {
                                        if thought_signature.is_some() {
                                            yield Ok(streaming::RawStreamingChoice::Reasoning {
                                                id: None,
                                                content: ReasoningContent::Text {
                                                    text,
                                                    signature: thought_signature,
                                                },
                                            });
                                        } else {
                                            yield Ok(streaming::RawStreamingChoice::ReasoningDelta {
                                                id: None,
                                                reasoning: text,
                                            });
                                        }
                                    }
                                },
                                // Plain assistant text.
                                Part {
                                    part: PartKind::Text(text),
                                    ..
                                } => {
                                    if !text.is_empty() {
                                        yield Ok(streaming::RawStreamingChoice::Message(text));
                                    }
                                },
                                // Tool invocation requested by the model. The
                                // function name is reused as the call id (no
                                // separate id is visible on the part here).
                                Part {
                                    part: PartKind::FunctionCall(function_call),
                                    thought_signature,
                                    ..
                                } => {
                                    yield Ok(streaming::RawStreamingChoice::ToolCall(
                                        streaming::RawStreamingToolCall::new(function_call.name.clone(), function_call.name.clone(), function_call.args.clone())
                                            .with_signature(thought_signature)
                                    ));
                                },
                                part => {
                                    tracing::warn!(?part, "Unsupported response type with streaming");
                                }
                            }
                        }

                        // A finish reason marks the end of generation.
                        if choice.finish_reason.is_some() {
                            break;
                        }
                    }
                    Err(crate::http_client::Error::StreamEnded) => {
                        break;
                    }
                    Err(error) => {
                        tracing::error!(?error, "SSE error");
                        yield Err(CompletionError::ProviderError(error.to_string()));
                        break;
                    }
                }
            }

            event_source.close();

            // Always terminate with a FinalResponse carrying the usage totals
            // (all-zero default if no chunk reported usage).
            yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
                usage_metadata: final_usage.unwrap_or_default()
            }));
        }.instrument(span);

        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
            stream,
        )))
    }
}
263
// Unit tests for streaming-chunk deserialization and token accounting.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // A single text part plus usage metadata deserializes correctly.
    #[test]
    fn test_deserialize_stream_response_with_single_text_part() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "Hello, world!"}
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 5,
                "totalTokenCount": 15
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        assert_eq!(response.candidates.len(), 1);
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 1);

        if let Part {
            part: PartKind::Text(text),
            ..
        } = &content.parts[0]
        {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected text part");
        }
    }

    // Usage-only chunks (no candidates) still deserialize, and their
    // metadata converts into Usage counters.
    #[test]
    fn test_deserialize_stream_response_with_usage_only_chunk() {
        let json_data = json!({
            "responseId": "response-123",
            "modelVersion": "gemini-2.0-flash-001",
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 5,
                "totalTokenCount": 15
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        assert_eq!(response.response_id.as_deref(), Some("response-123"));
        assert_eq!(
            response.model_version.as_deref(),
            Some("gemini-2.0-flash-001")
        );
        assert!(response.candidates.is_empty());

        let usage = response
            .usage_metadata
            .as_ref()
            .and_then(GetTokenUsage::token_usage)
            .unwrap();
        assert_eq!(usage.input_tokens, 10);
        assert_eq!(usage.output_tokens, 5);
        assert_eq!(usage.total_tokens, 15);
    }

    // Multiple text parts in one chunk are preserved in order.
    #[test]
    fn test_deserialize_stream_response_with_multiple_text_parts() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "Hello, "},
                        {"text": "world!"},
                        {"text": " How are you?"}
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 8,
                "totalTokenCount": 18
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        assert_eq!(response.candidates.len(), 1);
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 3);

        for (i, expected_text) in ["Hello, ", "world!", " How are you?"].iter().enumerate() {
            if let Part {
                part: PartKind::Text(text),
                ..
            } = &content.parts[i]
            {
                assert_eq!(text, expected_text);
            } else {
                panic!("Expected text part at index {}", i);
            }
        }
    }

    // Several functionCall parts in one chunk each parse as FunctionCall.
    #[test]
    fn test_deserialize_stream_response_with_multiple_tool_calls() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {
                            "functionCall": {
                                "name": "get_weather",
                                "args": {"city": "San Francisco"}
                            }
                        },
                        {
                            "functionCall": {
                                "name": "get_temperature",
                                "args": {"location": "New York"}
                            }
                        }
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 50,
                "candidatesTokenCount": 20,
                "totalTokenCount": 70
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 2);

        if let Part {
            part: PartKind::FunctionCall(call),
            ..
        } = &content.parts[0]
        {
            assert_eq!(call.name, "get_weather");
        } else {
            panic!("Expected function call at index 0");
        }

        if let Part {
            part: PartKind::FunctionCall(call),
            ..
        } = &content.parts[1]
        {
            assert_eq!(call.name, "get_temperature");
        } else {
            panic!("Expected function call at index 1");
        }
    }

    // Mixed thought / text / functionCall parts keep their kinds and the
    // `thought` flag distinguishes reasoning text from plain text.
    #[test]
    fn test_deserialize_stream_response_with_mixed_parts() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {
                            "text": "Let me think about this...",
                            "thought": true
                        },
                        {
                            "text": "Here's my response: "
                        },
                        {
                            "functionCall": {
                                "name": "search",
                                "args": {"query": "rust async"}
                            }
                        },
                        {
                            "text": "I found the answer!"
                        }
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 100,
                "candidatesTokenCount": 50,
                "thoughtsTokenCount": 15,
                "totalTokenCount": 165
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        let parts = &content.parts;
        assert_eq!(parts.len(), 4);

        if let Part {
            part: PartKind::Text(text),
            thought: Some(true),
            ..
        } = &parts[0]
        {
            assert_eq!(text, "Let me think about this...");
        } else {
            panic!("Expected thought part at index 0");
        }

        if let Part {
            part: PartKind::Text(text),
            thought,
            ..
        } = &parts[1]
        {
            assert_eq!(text, "Here's my response: ");
            assert!(thought.is_none() || thought == &Some(false));
        } else {
            panic!("Expected text part at index 1");
        }

        if let Part {
            part: PartKind::FunctionCall(call),
            ..
        } = &parts[2]
        {
            assert_eq!(call.name, "search");
        } else {
            panic!("Expected function call at index 2");
        }

        if let Part {
            part: PartKind::Text(text),
            ..
        } = &parts[3]
        {
            assert_eq!(text, "I found the answer!");
        } else {
            panic!("Expected text part at index 3");
        }
    }

    // An empty parts array is legal and deserializes to zero parts.
    #[test]
    fn test_deserialize_stream_response_with_empty_parts() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 0,
                "totalTokenCount": 10
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 0);
    }

    // All populated counts map onto the matching Usage fields.
    #[test]
    fn test_partial_usage_token_calculation() {
        let usage = PartialUsage {
            total_token_count: 100,
            cached_content_token_count: Some(20),
            candidates_token_count: Some(30),
            thoughts_token_count: Some(10),
            prompt_token_count: 40,
            prompt_tokens_details: None,
            cache_tokens_details: None,
            candidates_tokens_details: None,
            tool_use_prompt_token_count: None,
            tool_use_prompt_tokens_details: None,
            traffic_type: None,
        };

        let token_usage = usage.token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 40);
        assert_eq!(token_usage.cached_input_tokens, 20);
        assert_eq!(token_usage.output_tokens, 30);
        assert_eq!(token_usage.reasoning_tokens, 10);
        assert_eq!(token_usage.total_tokens, 100);
    }

    // Missing optional counts resolve to 0 in the converted Usage.
    #[test]
    fn test_partial_usage_with_missing_counts() {
        let usage = PartialUsage {
            total_token_count: 50,
            cached_content_token_count: None,
            candidates_token_count: Some(30),
            thoughts_token_count: None,
            prompt_token_count: 20,
            prompt_tokens_details: None,
            cache_tokens_details: None,
            candidates_tokens_details: None,
            tool_use_prompt_token_count: None,
            tool_use_prompt_tokens_details: None,
            traffic_type: None,
        };

        let token_usage = usage.token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 20);
        assert_eq!(token_usage.cached_input_tokens, 0);
        assert_eq!(token_usage.output_tokens, 30);
        assert_eq!(token_usage.reasoning_tokens, 0);
        assert_eq!(token_usage.total_tokens, 50);
    }

    // StreamingCompletionResponse delegates token_usage to its PartialUsage.
    #[test]
    fn test_streaming_completion_response_token_usage() {
        let response = StreamingCompletionResponse {
            usage_metadata: PartialUsage {
                total_token_count: 150,
                cached_content_token_count: None,
                candidates_token_count: Some(75),
                thoughts_token_count: None,
                prompt_token_count: 75,
                prompt_tokens_details: None,
                cache_tokens_details: None,
                candidates_tokens_details: None,
                tool_use_prompt_token_count: None,
                tool_use_prompt_tokens_details: None,
                traffic_type: None,
            },
        };

        let token_usage = response.token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 75);
        assert_eq!(token_usage.output_tokens, 75);
        assert_eq!(token_usage.reasoning_tokens, 0);
        assert_eq!(token_usage.cached_input_tokens, 0);
        assert_eq!(token_usage.total_tokens, 150);
    }

    // Every optional field, including modality details and traffic type,
    // round-trips through serde and converts correctly.
    #[test]
    fn test_partial_usage_serde_roundtrip_with_all_optional_fields() {
        let json_data = serde_json::json!({
            "promptTokenCount": 100,
            "cachedContentTokenCount": 25,
            "candidatesTokenCount": 50,
            "thoughtsTokenCount": 15,
            "totalTokenCount": 190,
            "promptTokensDetails": [
                { "modality": "TEXT", "tokenCount": 80 },
                { "modality": "IMAGE", "tokenCount": 20 }
            ],
            "cacheTokensDetails": [
                { "modality": "TEXT", "tokenCount": 25 }
            ],
            "candidatesTokensDetails": [
                { "modality": "TEXT", "tokenCount": 50 }
            ],
            "toolUsePromptTokenCount": 12,
            "toolUsePromptTokensDetails": [
                { "modality": "TEXT", "tokenCount": 12 }
            ],
            "trafficType": "PROVISIONED_THROUGHPUT"
        });

        let usage: PartialUsage = serde_json::from_value(json_data).unwrap();
        assert_eq!(usage.prompt_token_count, 100);
        assert_eq!(usage.cached_content_token_count, Some(25));
        assert_eq!(usage.candidates_token_count, Some(50));
        assert_eq!(usage.thoughts_token_count, Some(15));
        assert_eq!(usage.total_token_count, 190);
        assert!(usage.prompt_tokens_details.is_some());
        assert_eq!(usage.prompt_tokens_details.as_ref().unwrap().len(), 2);
        assert!(usage.cache_tokens_details.is_some());
        assert!(usage.candidates_tokens_details.is_some());
        assert_eq!(usage.tool_use_prompt_token_count, Some(12));
        assert!(usage.tool_use_prompt_tokens_details.is_some());
        assert!(matches!(
            usage.traffic_type,
            Some(TrafficType::ProvisionedThroughput)
        ));

        let token_usage = usage.token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 100);
        assert_eq!(token_usage.cached_input_tokens, 25);
        assert_eq!(token_usage.output_tokens, 50);
        assert_eq!(token_usage.reasoning_tokens, 15);
        assert_eq!(token_usage.total_tokens, 190);
    }
}