1use async_stream::stream;
2use futures::StreamExt;
3use serde::{Deserialize, Serialize};
4use tracing::{Level, enabled, info_span};
5use tracing_futures::Instrument;
6
7use super::completion::gemini_api_types::{ContentCandidate, Part, PartKind};
8use super::completion::{
9 CompletionModel, create_request_body, resolve_request_model, streaming_endpoint,
10};
11use crate::completion::message::ReasoningContent;
12use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
13use crate::http_client::HttpClientExt;
14use crate::http_client::sse::{Event, GenericEventSource};
15use crate::streaming;
16use crate::telemetry::SpanCombinator;
17
/// Token-usage block (`usageMetadata`) attached to Gemini streaming chunks.
///
/// Field names are camelCased on the wire via `serde(rename_all)`; the three
/// optional counts are omitted from serialized output when absent.
#[derive(Debug, Deserialize, Serialize, Default, Clone)]
#[serde(rename_all = "camelCase")]
pub struct PartialUsage {
    /// Total token count reported by the provider for the request.
    pub total_token_count: i32,
    /// Cached-content token count, when reported.
    /// NOTE(review): presumably tokens served from context caching (prompt-side
    /// in Gemini's accounting) — confirm against the Gemini usage-metadata docs.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_content_token_count: Option<i32>,
    /// Tokens in the generated candidates, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidates_token_count: Option<i32>,
    /// Tokens spent on "thought" (reasoning) output, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thoughts_token_count: Option<i32>,
    /// Tokens in the prompt.
    pub prompt_token_count: i32,
}
30
31impl GetTokenUsage for PartialUsage {
32 fn token_usage(&self) -> Option<crate::completion::Usage> {
33 let mut usage = crate::completion::Usage::new();
34
35 usage.input_tokens = self.prompt_token_count as u64;
36 usage.output_tokens = (self.cached_content_token_count.unwrap_or_default()
37 + self.candidates_token_count.unwrap_or_default()
38 + self.thoughts_token_count.unwrap_or_default()) as u64;
39 usage.total_tokens = usage.input_tokens + usage.output_tokens;
40
41 Some(usage)
42 }
43}
44
/// One chunk of a Gemini `streamGenerateContent` response, parsed from a
/// single SSE message.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct StreamGenerateContentResponse {
    /// Candidates carried by this chunk; the stream consumer reads only the
    /// first one.
    pub candidates: Vec<ContentCandidate>,
    /// Model version echoed by the API, when present.
    pub model_version: Option<String>,
    /// Token accounting; optional on every chunk — the consumer captures it
    /// from the chunk that carries a finish reason.
    pub usage_metadata: Option<PartialUsage>,
}
53
/// Final payload yielded once the Gemini stream completes.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct StreamingCompletionResponse {
    /// Usage metadata captured from the finishing chunk, or
    /// `PartialUsage::default()` when the stream ended without reporting any.
    pub usage_metadata: PartialUsage,
}
58
59impl GetTokenUsage for StreamingCompletionResponse {
60 fn token_usage(&self) -> Option<crate::completion::Usage> {
61 let mut usage = crate::completion::Usage::new();
62 usage.total_tokens = self.usage_metadata.total_token_count as u64;
63 usage.output_tokens = self
64 .usage_metadata
65 .candidates_token_count
66 .map(|x| x as u64)
67 .unwrap_or(0);
68 usage.input_tokens = self.usage_metadata.prompt_token_count as u64;
69 Some(usage)
70 }
71}
72
impl<T> CompletionModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
    /// Sends `completion_request` to Gemini's streaming endpoint and adapts the
    /// resulting SSE stream into rig's [`streaming::StreamingCompletionResponse`].
    ///
    /// Each SSE message is parsed as a [`StreamGenerateContentResponse`]; the
    /// parts of its first candidate are translated into `RawStreamingChoice`
    /// items (reasoning, reasoning deltas, plain text, tool calls). When a chunk
    /// carries a finish reason, its usage metadata is recorded on the current
    /// span and the loop ends; a final response with that usage (or defaults if
    /// none was seen) is always yielded last.
    ///
    /// # Errors
    /// Returns [`CompletionError`] if the request body cannot be built or
    /// serialized, or if the HTTP request cannot be constructed. Transport
    /// errors during streaming are yielded as `Err` items on the stream itself.
    pub(crate) async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
    {
        let request_model = resolve_request_model(&self.model, &completion_request);
        // Reuse an already-active span (e.g. from an outer instrumented call);
        // otherwise open a fresh gen_ai "chat_streaming" telemetry span.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "gcp.gemini",
                gen_ai.request.model = &request_model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = &request_model,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };
        let request = create_request_body(completion_request)?;

        // Pretty-printing the request is only worth the cost when TRACE is on.
        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::streaming",
                "Gemini streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post_sse(streaming_endpoint(&request_model))?
            .header("Content-Type", "application/json")
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        let mut event_source = GenericEventSource::new(self.client.clone(), req);

        let stream = stream! {
            // Usage metadata captured from the finishing chunk, if any.
            let mut final_usage = None;
            while let Some(event_result) = event_source.next().await {
                match event_result {
                    Ok(Event::Open) => {
                        tracing::debug!("SSE connection opened");
                        continue;
                    }
                    Ok(Event::Message(message)) => {
                        // Skip keep-alive / empty messages.
                        if message.data.trim().is_empty() {
                            continue;
                        }

                        // A malformed chunk is logged and skipped rather than
                        // terminating the whole stream.
                        let data = match serde_json::from_str::<StreamGenerateContentResponse>(&message.data) {
                            Ok(d) => d,
                            Err(error) => {
                                tracing::error!(?error, message = message.data, "Failed to parse SSE message");
                                continue;
                            }
                        };

                        // Only the first candidate is consumed.
                        let Some(choice) = data.candidates.into_iter().next() else {
                            tracing::debug!("There is no content candidate");
                            continue;
                        };

                        let Some(content) = choice.content else {
                            tracing::debug!(finish_reason = ?choice.finish_reason, "Streaming candidate missing content");
                            continue;
                        };

                        if content.parts.is_empty() {
                            tracing::trace!(reason = ?choice.finish_reason, "There is no part in the streaming content");
                        }

                        for part in content.parts {
                            match part {
                                // Thought text: with a signature it is emitted as a
                                // complete reasoning item; without one, as an
                                // incremental reasoning delta.
                                Part {
                                    part: PartKind::Text(text),
                                    thought: Some(true),
                                    thought_signature,
                                    ..
                                } => {
                                    if !text.is_empty() {
                                        if thought_signature.is_some() {
                                            yield Ok(streaming::RawStreamingChoice::Reasoning {
                                                id: None,
                                                content: ReasoningContent::Text {
                                                    text,
                                                    signature: thought_signature,
                                                },
                                            });
                                        } else {
                                            yield Ok(streaming::RawStreamingChoice::ReasoningDelta {
                                                id: None,
                                                reasoning: text,
                                            });
                                        }
                                    }
                                },
                                // Non-thought text becomes a plain message chunk;
                                // empty strings are suppressed.
                                Part {
                                    part: PartKind::Text(text),
                                    ..
                                } => {
                                    if !text.is_empty() {
                                        yield Ok(streaming::RawStreamingChoice::Message(text));
                                    }
                                },
                                // Function calls: the name doubles as the call id.
                                // NOTE(review): Gemini does not appear to supply a
                                // separate call id here — confirm.
                                Part {
                                    part: PartKind::FunctionCall(function_call),
                                    thought_signature,
                                    ..
                                } => {
                                    yield Ok(streaming::RawStreamingChoice::ToolCall(
                                        streaming::RawStreamingToolCall::new(function_call.name.clone(), function_call.name.clone(), function_call.args.clone())
                                            .with_signature(thought_signature)
                                    ));
                                },
                                part => {
                                    tracing::warn!(?part, "Unsupported response type with streaming");
                                }
                            }
                        }

                        // A finish reason marks the last meaningful chunk: record
                        // token usage on the span and stop reading.
                        if choice.finish_reason.is_some() {
                            let span = tracing::Span::current();
                            span.record_token_usage(&data.usage_metadata);
                            final_usage = data.usage_metadata;
                            break;
                        }
                    }
                    // Normal end-of-stream from the transport.
                    Err(crate::http_client::Error::StreamEnded) => {
                        break;
                    }
                    // Any other transport error is surfaced to the consumer and
                    // ends the stream.
                    Err(error) => {
                        tracing::error!(?error, "SSE error");
                        yield Err(CompletionError::ProviderError(error.to_string()));
                        break;
                    }
                }
            }

            event_source.close();

            // Always terminate with a FinalResponse so consumers get usage data
            // (defaulted if the provider never reported any).
            yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
                usage_metadata: final_usage.unwrap_or_default()
            }));
        }.instrument(span);

        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
            stream,
        )))
    }
}
242
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A single text part deserializes and round-trips the text intact.
    #[test]
    fn test_deserialize_stream_response_with_single_text_part() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "Hello, world!"}
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 5,
                "totalTokenCount": 15
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        assert_eq!(response.candidates.len(), 1);
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 1);

        if let Part {
            part: PartKind::Text(text),
            ..
        } = &content.parts[0]
        {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected text part");
        }
    }

    /// Multiple text parts in one chunk keep their order and content.
    #[test]
    fn test_deserialize_stream_response_with_multiple_text_parts() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "Hello, "},
                        {"text": "world!"},
                        {"text": " How are you?"}
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 8,
                "totalTokenCount": 18
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        assert_eq!(response.candidates.len(), 1);
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 3);

        for (i, expected_text) in ["Hello, ", "world!", " How are you?"].iter().enumerate() {
            if let Part {
                part: PartKind::Text(text),
                ..
            } = &content.parts[i]
            {
                assert_eq!(text, expected_text);
            } else {
                panic!("Expected text part at index {}", i);
            }
        }
    }

    /// Two functionCall parts deserialize as distinct function calls.
    #[test]
    fn test_deserialize_stream_response_with_multiple_tool_calls() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {
                            "functionCall": {
                                "name": "get_weather",
                                "args": {"city": "San Francisco"}
                            }
                        },
                        {
                            "functionCall": {
                                "name": "get_temperature",
                                "args": {"location": "New York"}
                            }
                        }
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 50,
                "candidatesTokenCount": 20,
                "totalTokenCount": 70
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 2);

        if let Part {
            part: PartKind::FunctionCall(call),
            ..
        } = &content.parts[0]
        {
            assert_eq!(call.name, "get_weather");
        } else {
            panic!("Expected function call at index 0");
        }

        if let Part {
            part: PartKind::FunctionCall(call),
            ..
        } = &content.parts[1]
        {
            assert_eq!(call.name, "get_temperature");
        } else {
            panic!("Expected function call at index 1");
        }
    }

    /// Thought text, plain text, and function calls can coexist in one chunk;
    /// the `thought` flag distinguishes reasoning from regular text.
    #[test]
    fn test_deserialize_stream_response_with_mixed_parts() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {
                            "text": "Let me think about this...",
                            "thought": true
                        },
                        {
                            "text": "Here's my response: "
                        },
                        {
                            "functionCall": {
                                "name": "search",
                                "args": {"query": "rust async"}
                            }
                        },
                        {
                            "text": "I found the answer!"
                        }
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 100,
                "candidatesTokenCount": 50,
                "thoughtsTokenCount": 15,
                "totalTokenCount": 165
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        let parts = &content.parts;
        assert_eq!(parts.len(), 4);

        if let Part {
            part: PartKind::Text(text),
            thought: Some(true),
            ..
        } = &parts[0]
        {
            assert_eq!(text, "Let me think about this...");
        } else {
            panic!("Expected thought part at index 0");
        }

        if let Part {
            part: PartKind::Text(text),
            thought,
            ..
        } = &parts[1]
        {
            assert_eq!(text, "Here's my response: ");
            assert!(thought.is_none() || thought == &Some(false));
        } else {
            panic!("Expected text part at index 1");
        }

        if let Part {
            part: PartKind::FunctionCall(call),
            ..
        } = &parts[2]
        {
            assert_eq!(call.name, "search");
        } else {
            panic!("Expected function call at index 2");
        }

        if let Part {
            part: PartKind::Text(text),
            ..
        } = &parts[3]
        {
            assert_eq!(text, "I found the answer!");
        } else {
            panic!("Expected text part at index 3");
        }
    }

    /// An empty `parts` array deserializes successfully (the stream loop only
    /// logs a trace message for it).
    #[test]
    fn test_deserialize_stream_response_with_empty_parts() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 0,
                "totalTokenCount": 10
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 0);
    }

    /// PartialUsage sums cached + candidates + thoughts into output tokens.
    #[test]
    fn test_partial_usage_token_calculation() {
        let usage = PartialUsage {
            total_token_count: 100,
            cached_content_token_count: Some(20),
            candidates_token_count: Some(30),
            thoughts_token_count: Some(10),
            prompt_token_count: 40,
        };

        let token_usage = usage.token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 40);
        assert_eq!(token_usage.output_tokens, 60);
        assert_eq!(token_usage.total_tokens, 100);
    }

    /// Missing optional counts are treated as zero.
    #[test]
    fn test_partial_usage_with_missing_counts() {
        let usage = PartialUsage {
            total_token_count: 50,
            cached_content_token_count: None,
            candidates_token_count: Some(30),
            thoughts_token_count: None,
            prompt_token_count: 20,
        };

        let token_usage = usage.token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 20);
        assert_eq!(token_usage.output_tokens, 30);
        assert_eq!(token_usage.total_tokens, 50);
    }

    /// StreamingCompletionResponse reports candidates as output and takes the
    /// provider total as-is.
    #[test]
    fn test_streaming_completion_response_token_usage() {
        let response = StreamingCompletionResponse {
            usage_metadata: PartialUsage {
                total_token_count: 150,
                cached_content_token_count: None,
                candidates_token_count: Some(75),
                thoughts_token_count: None,
                prompt_token_count: 75,
            },
        };

        let token_usage = response.token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 75);
        assert_eq!(token_usage.output_tokens, 75);
        assert_eq!(token_usage.total_tokens, 150);
    }
}