1use async_stream::stream;
2use futures::StreamExt;
3use serde::{Deserialize, Serialize};
4use tracing::{Level, enabled, info_span};
5use tracing_futures::Instrument;
6
7use super::completion::gemini_api_types::{ContentCandidate, Part, PartKind};
8use super::completion::{
9 CompletionModel, create_request_body, resolve_request_model, streaming_endpoint,
10};
11use crate::completion::message::ReasoningContent;
12use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
13use crate::http_client::HttpClientExt;
14use crate::http_client::sse::{Event, GenericEventSource};
15use crate::streaming;
16use crate::telemetry::SpanCombinator;
17
/// Token-usage metadata carried by Gemini streaming chunks (`usageMetadata`
/// on the wire — hence the camelCase rename).
///
/// Optional counters are omitted from serialization when absent so that a
/// re-serialized value round-trips the sparse wire form.
#[derive(Debug, Deserialize, Serialize, Default, Clone)]
#[serde(rename_all = "camelCase")]
pub struct PartialUsage {
    /// Total token count reported by the API for the whole exchange.
    pub total_token_count: i32,
    /// Tokens served from cached content, when present.
    /// NOTE(review): presumably these are *prompt* tokens that hit the cache —
    /// confirm against the Gemini API docs before relying on the split.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_content_token_count: Option<i32>,
    /// Tokens generated in the response candidates, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidates_token_count: Option<i32>,
    /// Tokens spent on "thought" (reasoning) output, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thoughts_token_count: Option<i32>,
    /// Tokens in the input prompt.
    pub prompt_token_count: i32,
}
30
31impl GetTokenUsage for PartialUsage {
32 fn token_usage(&self) -> Option<crate::completion::Usage> {
33 let mut usage = crate::completion::Usage::new();
34
35 usage.input_tokens = self.prompt_token_count as u64;
36 usage.output_tokens = (self.cached_content_token_count.unwrap_or_default()
37 + self.candidates_token_count.unwrap_or_default()
38 + self.thoughts_token_count.unwrap_or_default()) as u64;
39 usage.total_tokens = usage.input_tokens + usage.output_tokens;
40
41 Some(usage)
42 }
43}
44
/// One deserialized chunk of a Gemini `streamGenerateContent` SSE stream.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct StreamGenerateContentResponse {
    /// Response candidates for this chunk; the streaming loop below consumes
    /// only the first one.
    pub candidates: Vec<ContentCandidate>,
    /// Model version echoed by the API, when present.
    pub model_version: Option<String>,
    /// Usage counters; read by the streaming loop when a finish reason is set
    /// (i.e. expected on the terminal chunk).
    pub usage_metadata: Option<PartialUsage>,
}
53
/// Final payload yielded at the end of a Gemini streaming completion,
/// carrying the usage metadata captured from the terminal chunk (or a
/// zeroed default if the stream ended without reporting usage).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct StreamingCompletionResponse {
    /// Usage counters from the last streamed chunk.
    pub usage_metadata: PartialUsage,
}
58
59impl GetTokenUsage for StreamingCompletionResponse {
60 fn token_usage(&self) -> Option<crate::completion::Usage> {
61 let mut usage = crate::completion::Usage::new();
62 usage.total_tokens = self.usage_metadata.total_token_count as u64;
63 usage.output_tokens = self
64 .usage_metadata
65 .candidates_token_count
66 .map(|x| x as u64)
67 .unwrap_or(0);
68 usage.input_tokens = self.usage_metadata.prompt_token_count as u64;
69 Some(usage)
70 }
71}
72
impl<T> CompletionModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
    /// Sends `completion_request` to Gemini's streaming endpoint and returns a
    /// stream of raw completion choices (text, reasoning, tool calls) followed
    /// by a [`StreamingCompletionResponse`] carrying the final usage counters.
    ///
    /// # Errors
    /// Returns [`CompletionError`] if the request body cannot be built or
    /// serialized, or if constructing the HTTP request fails. Transport/SSE
    /// errors that occur mid-stream are yielded as `Err` items on the stream
    /// itself rather than returned here.
    pub(crate) async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
    {
        let request_model = resolve_request_model(&self.model, &completion_request);
        // Open a fresh telemetry span only when no caller-provided span is
        // active; otherwise reuse the current one so usage gets recorded on it.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "gcp.gemini",
                gen_ai.request.model = &request_model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = &request_model,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cached_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };
        let request = create_request_body(completion_request)?;

        // Pretty-printing the request is not free, so gate it on TRACE.
        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::streaming",
                "Gemini streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post_sse(streaming_endpoint(&request_model))?
            .header("Content-Type", "application/json")
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        let mut event_source = GenericEventSource::new(self.client.clone(), req);

        let stream = stream! {
            // Usage metadata from the terminal chunk, if the API reported any.
            let mut final_usage = None;
            while let Some(event_result) = event_source.next().await {
                match event_result {
                    Ok(Event::Open) => {
                        tracing::debug!("SSE connection opened");
                        continue;
                    }
                    Ok(Event::Message(message)) => {
                        // Keep-alive / empty frames carry no payload.
                        if message.data.trim().is_empty() {
                            continue;
                        }

                        // A malformed chunk is logged and skipped rather than
                        // terminating the whole stream.
                        let data = match serde_json::from_str::<StreamGenerateContentResponse>(&message.data) {
                            Ok(d) => d,
                            Err(error) => {
                                tracing::error!(?error, message = message.data, "Failed to parse SSE message");
                                continue;
                            }
                        };

                        // Only the first candidate of each chunk is consumed.
                        let Some(choice) = data.candidates.into_iter().next() else {
                            tracing::debug!("There is no content candidate");
                            continue;
                        };

                        let Some(content) = choice.content else {
                            tracing::debug!(finish_reason = ?choice.finish_reason, "Streaming candidate missing content");
                            continue;
                        };

                        if content.parts.is_empty() {
                            tracing::trace!(reason = ?choice.finish_reason, "There is no part in the streaming content");
                        }

                        for part in content.parts {
                            match part {
                                // Thought (reasoning) text: with a signature it
                                // is emitted as a complete Reasoning item,
                                // without one as an incremental ReasoningDelta.
                                Part {
                                    part: PartKind::Text(text),
                                    thought: Some(true),
                                    thought_signature,
                                    ..
                                } => {
                                    if !text.is_empty() {
                                        if thought_signature.is_some() {
                                            yield Ok(streaming::RawStreamingChoice::Reasoning {
                                                id: None,
                                                content: ReasoningContent::Text {
                                                    text,
                                                    signature: thought_signature,
                                                },
                                            });
                                        } else {
                                            yield Ok(streaming::RawStreamingChoice::ReasoningDelta {
                                                id: None,
                                                reasoning: text,
                                            });
                                        }
                                    }
                                },
                                // Plain (non-thought) text delta.
                                Part {
                                    part: PartKind::Text(text),
                                    ..
                                } => {
                                    if !text.is_empty() {
                                        yield Ok(streaming::RawStreamingChoice::Message(text));
                                    }
                                },
                                // Tool invocation requested by the model.
                                // NOTE(review): the function name is used for
                                // both the call id and the name — presumably
                                // because Gemini supplies no call id; confirm.
                                Part {
                                    part: PartKind::FunctionCall(function_call),
                                    thought_signature,
                                    ..
                                } => {
                                    yield Ok(streaming::RawStreamingChoice::ToolCall(
                                        streaming::RawStreamingToolCall::new(function_call.name.clone(), function_call.name.clone(), function_call.args.clone())
                                            .with_signature(thought_signature)
                                    ));
                                },
                                part => {
                                    tracing::warn!(?part, "Unsupported response type with streaming");
                                }
                            }
                        }

                        // A finish reason marks the terminal chunk: record the
                        // usage on the active span and stop reading.
                        if choice.finish_reason.is_some() {
                            let span = tracing::Span::current();
                            span.record_token_usage(&data.usage_metadata);
                            final_usage = data.usage_metadata;
                            break;
                        }
                    }
                    // Normal end of stream — not an error.
                    Err(crate::http_client::Error::StreamEnded) => {
                        break;
                    }
                    // Any other transport error is surfaced to the consumer
                    // and terminates the stream.
                    Err(error) => {
                        tracing::error!(?error, "SSE error");
                        yield Err(CompletionError::ProviderError(error.to_string()));
                        break;
                    }
                }
            }

            event_source.close();

            // Always finish with a FinalResponse; usage defaults to zeros if
            // the stream ended before reporting any.
            yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
                usage_metadata: final_usage.unwrap_or_default()
            }));
        }.instrument(span);

        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
            stream,
        )))
    }
}
243
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Deserializes a JSON chunk into a [`StreamGenerateContentResponse`].
    fn parse(value: serde_json::Value) -> StreamGenerateContentResponse {
        serde_json::from_value(value).unwrap()
    }

    #[test]
    fn test_deserialize_stream_response_with_single_text_part() {
        let response = parse(json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "Hello, world!"}
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 5,
                "totalTokenCount": 15
            }
        }));

        assert_eq!(response.candidates.len(), 1);
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 1);

        match &content.parts[0] {
            Part {
                part: PartKind::Text(text),
                ..
            } => assert_eq!(text, "Hello, world!"),
            _ => panic!("Expected text part"),
        }
    }

    #[test]
    fn test_deserialize_stream_response_with_multiple_text_parts() {
        let response = parse(json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "Hello, "},
                        {"text": "world!"},
                        {"text": " How are you?"}
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 8,
                "totalTokenCount": 18
            }
        }));

        assert_eq!(response.candidates.len(), 1);
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 3);

        let expected = ["Hello, ", "world!", " How are you?"];
        for (i, expected_text) in expected.iter().enumerate() {
            match &content.parts[i] {
                Part {
                    part: PartKind::Text(text),
                    ..
                } => assert_eq!(text, expected_text),
                _ => panic!("Expected text part at index {}", i),
            }
        }
    }

    #[test]
    fn test_deserialize_stream_response_with_multiple_tool_calls() {
        let response = parse(json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {
                            "functionCall": {
                                "name": "get_weather",
                                "args": {"city": "San Francisco"}
                            }
                        },
                        {
                            "functionCall": {
                                "name": "get_temperature",
                                "args": {"location": "New York"}
                            }
                        }
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 50,
                "candidatesTokenCount": 20,
                "totalTokenCount": 70
            }
        }));

        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 2);

        match &content.parts[0] {
            Part {
                part: PartKind::FunctionCall(call),
                ..
            } => assert_eq!(call.name, "get_weather"),
            _ => panic!("Expected function call at index 0"),
        }

        match &content.parts[1] {
            Part {
                part: PartKind::FunctionCall(call),
                ..
            } => assert_eq!(call.name, "get_temperature"),
            _ => panic!("Expected function call at index 1"),
        }
    }

    #[test]
    fn test_deserialize_stream_response_with_mixed_parts() {
        let response = parse(json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {
                            "text": "Let me think about this...",
                            "thought": true
                        },
                        {
                            "text": "Here's my response: "
                        },
                        {
                            "functionCall": {
                                "name": "search",
                                "args": {"query": "rust async"}
                            }
                        },
                        {
                            "text": "I found the answer!"
                        }
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 100,
                "candidatesTokenCount": 50,
                "thoughtsTokenCount": 15,
                "totalTokenCount": 165
            }
        }));

        let parts = &response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content")
            .parts;
        assert_eq!(parts.len(), 4);

        match &parts[0] {
            Part {
                part: PartKind::Text(text),
                thought: Some(true),
                ..
            } => assert_eq!(text, "Let me think about this..."),
            _ => panic!("Expected thought part at index 0"),
        }

        match &parts[1] {
            Part {
                part: PartKind::Text(text),
                thought,
                ..
            } => {
                assert_eq!(text, "Here's my response: ");
                assert!(thought.is_none() || thought == &Some(false));
            }
            _ => panic!("Expected text part at index 1"),
        }

        match &parts[2] {
            Part {
                part: PartKind::FunctionCall(call),
                ..
            } => assert_eq!(call.name, "search"),
            _ => panic!("Expected function call at index 2"),
        }

        match &parts[3] {
            Part {
                part: PartKind::Text(text),
                ..
            } => assert_eq!(text, "I found the answer!"),
            _ => panic!("Expected text part at index 3"),
        }
    }

    #[test]
    fn test_deserialize_stream_response_with_empty_parts() {
        let response = parse(json!({
            "candidates": [{
                "content": {
                    "parts": [],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 0,
                "totalTokenCount": 10
            }
        }));

        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 0);
    }

    #[test]
    fn test_partial_usage_token_calculation() {
        // Output = cached (20) + candidates (30) + thoughts (10) = 60.
        let usage = PartialUsage {
            total_token_count: 100,
            cached_content_token_count: Some(20),
            candidates_token_count: Some(30),
            thoughts_token_count: Some(10),
            prompt_token_count: 40,
        };

        let token_usage = usage.token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 40);
        assert_eq!(token_usage.output_tokens, 60);
        assert_eq!(token_usage.total_tokens, 100);
    }

    #[test]
    fn test_partial_usage_with_missing_counts() {
        // Missing counters are treated as zero.
        let usage = PartialUsage {
            total_token_count: 50,
            cached_content_token_count: None,
            candidates_token_count: Some(30),
            thoughts_token_count: None,
            prompt_token_count: 20,
        };

        let token_usage = usage.token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 20);
        assert_eq!(token_usage.output_tokens, 30);
        assert_eq!(token_usage.total_tokens, 50);
    }

    #[test]
    fn test_streaming_completion_response_token_usage() {
        let response = StreamingCompletionResponse {
            usage_metadata: PartialUsage {
                total_token_count: 150,
                cached_content_token_count: None,
                candidates_token_count: Some(75),
                thoughts_token_count: None,
                prompt_token_count: 75,
            },
        };

        let token_usage = response.token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 75);
        assert_eq!(token_usage.output_tokens, 75);
        assert_eq!(token_usage.total_tokens, 150);
    }
}