1use async_stream::stream;
2use futures::StreamExt;
3use serde::{Deserialize, Serialize};
4use tracing::{Level, enabled, info_span};
5use tracing_futures::Instrument;
6
7use super::completion::gemini_api_types::{ContentCandidate, Part, PartKind};
8use super::completion::{
9 CompletionModel, create_request_body, resolve_request_model, streaming_endpoint,
10};
11use crate::completion::message::ReasoningContent;
12use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
13use crate::http_client::HttpClientExt;
14use crate::http_client::sse::{Event, GenericEventSource};
15use crate::streaming;
16use crate::telemetry::SpanCombinator;
17
/// Token-usage metadata as reported incrementally by the Gemini streaming API.
///
/// Early chunks may omit most counters, so everything except the totals is
/// optional or defaulted; field names map to the API's camelCase JSON keys.
#[derive(Debug, Deserialize, Serialize, Default, Clone)]
#[serde(rename_all = "camelCase")]
pub struct PartialUsage {
    /// Total tokens billed for the request so far (prompt + response).
    pub total_token_count: i32,
    /// Tokens served from cached content, when context caching was used.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_content_token_count: Option<i32>,
    /// Tokens generated across the response candidates.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidates_token_count: Option<i32>,
    /// Tokens spent on internal "thought" (reasoning) output.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thoughts_token_count: Option<i32>,
    /// Tokens in the prompt; defaults to 0 when the chunk omits it.
    #[serde(default)]
    pub prompt_token_count: i32,
}
31
32impl GetTokenUsage for PartialUsage {
33 fn token_usage(&self) -> Option<crate::completion::Usage> {
34 let mut usage = crate::completion::Usage::new();
35
36 usage.input_tokens = self.prompt_token_count as u64;
37 usage.output_tokens = (self.cached_content_token_count.unwrap_or_default()
38 + self.candidates_token_count.unwrap_or_default()
39 + self.thoughts_token_count.unwrap_or_default()) as u64;
40 usage.total_tokens = usage.input_tokens + usage.output_tokens;
41
42 Some(usage)
43 }
44}
45
/// A single decoded chunk of a Gemini `streamGenerateContent` SSE response.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct StreamGenerateContentResponse {
    /// Response candidates carried by this chunk; the stream handler only
    /// surfaces the first one.
    pub candidates: Vec<ContentCandidate>,
    /// Model version echoed back by the API, when present.
    pub model_version: Option<String>,
    /// Usage counters; typically only populated on the final chunk.
    pub usage_metadata: Option<PartialUsage>,
}
54
/// Final summary yielded at the end of a Gemini streaming completion,
/// carrying the usage counters captured from the last chunk (or defaults
/// if the stream ended without a finish reason).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct StreamingCompletionResponse {
    /// Usage counters for the whole streamed exchange.
    pub usage_metadata: PartialUsage,
}
59
60impl GetTokenUsage for StreamingCompletionResponse {
61 fn token_usage(&self) -> Option<crate::completion::Usage> {
62 let mut usage = crate::completion::Usage::new();
63 usage.total_tokens = self.usage_metadata.total_token_count as u64;
64 usage.output_tokens = self
65 .usage_metadata
66 .candidates_token_count
67 .map(|x| x as u64)
68 .unwrap_or(0);
69 usage.input_tokens = self.usage_metadata.prompt_token_count as u64;
70 Some(usage)
71 }
72}
73
impl<T> CompletionModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
    /// Sends the completion request to Gemini's streaming endpoint and adapts
    /// the SSE chunks into a [`streaming::StreamingCompletionResponse`].
    ///
    /// Each chunk's first candidate is translated part-by-part into
    /// `RawStreamingChoice` items (text, reasoning, tool calls); a chunk with
    /// a finish reason ends the stream, and a `FinalResponse` carrying the
    /// captured usage metadata is always yielded last.
    ///
    /// # Errors
    /// Fails up front on request serialization or HTTP request construction;
    /// mid-stream SSE errors are yielded as `CompletionError::ProviderError`
    /// items inside the stream.
    pub(crate) async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
    {
        let request_model = resolve_request_model(&self.model, &completion_request);
        // Reuse the caller's active span if there is one; otherwise open a
        // fresh "chat_streaming" span whose Empty fields are recorded later
        // (e.g. token usage once the final chunk arrives).
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "gcp.gemini",
                gen_ai.request.model = &request_model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = &request_model,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cached_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };
        let request = create_request_body(completion_request)?;

        // Guard the pretty-print behind a level check so the serialization
        // cost is only paid when TRACE logging is actually enabled.
        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::streaming",
                "Gemini streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post_sse(streaming_endpoint(&request_model))?
            .header("Content-Type", "application/json")
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        let mut event_source = GenericEventSource::new(self.client.clone(), req);

        let stream = stream! {
            // Usage metadata from the chunk that carried a finish reason, if any.
            let mut final_usage = None;
            while let Some(event_result) = event_source.next().await {
                match event_result {
                    Ok(Event::Open) => {
                        tracing::debug!("SSE connection opened");
                        continue;
                    }
                    Ok(Event::Message(message)) => {
                        // Skip keep-alive / empty frames.
                        if message.data.trim().is_empty() {
                            continue;
                        }

                        // A malformed chunk is logged and skipped rather than
                        // aborting the whole stream.
                        let data = match serde_json::from_str::<StreamGenerateContentResponse>(&message.data) {
                            Ok(d) => d,
                            Err(error) => {
                                tracing::error!(?error, message = message.data, "Failed to parse SSE message");
                                continue;
                            }
                        };

                        // Only the first candidate is surfaced; chunks without
                        // candidates or without content are skipped.
                        let Some(choice) = data.candidates.into_iter().next() else {
                            tracing::debug!("There is no content candidate");
                            continue;
                        };

                        let Some(content) = choice.content else {
                            tracing::debug!(finish_reason = ?choice.finish_reason, "Streaming candidate missing content");
                            continue;
                        };

                        if content.parts.is_empty() {
                            tracing::trace!(reason = ?choice.finish_reason, "There is no part in the streaming content");
                        }

                        for part in content.parts {
                            match part {
                                // Reasoning ("thought") text: a signed thought is
                                // emitted as a complete Reasoning item so the
                                // signature stays attached to its text; unsigned
                                // thoughts are emitted as incremental deltas.
                                Part {
                                    part: PartKind::Text(text),
                                    thought: Some(true),
                                    thought_signature,
                                    ..
                                } => {
                                    if !text.is_empty() {
                                        if thought_signature.is_some() {
                                            yield Ok(streaming::RawStreamingChoice::Reasoning {
                                                id: None,
                                                content: ReasoningContent::Text {
                                                    text,
                                                    signature: thought_signature,
                                                },
                                            });
                                        } else {
                                            yield Ok(streaming::RawStreamingChoice::ReasoningDelta {
                                                id: None,
                                                reasoning: text,
                                            });
                                        }
                                    }
                                },
                                // Ordinary assistant text (empty fragments dropped).
                                Part {
                                    part: PartKind::Text(text),
                                    ..
                                } => {
                                    if !text.is_empty() {
                                        yield Ok(streaming::RawStreamingChoice::Message(text));
                                    }
                                },
                                // Tool invocation: the function name is passed as
                                // both the call id and the name.
                                Part {
                                    part: PartKind::FunctionCall(function_call),
                                    thought_signature,
                                    ..
                                } => {
                                    yield Ok(streaming::RawStreamingChoice::ToolCall(
                                        streaming::RawStreamingToolCall::new(function_call.name.clone(), function_call.name.clone(), function_call.args.clone())
                                            .with_signature(thought_signature)
                                    ));
                                },
                                part => {
                                    tracing::warn!(?part, "Unsupported response type with streaming");
                                }
                            }
                        }

                        // A finish reason marks the last chunk: record the token
                        // usage on the (instrumented) current span and stop reading.
                        if choice.finish_reason.is_some() {
                            let span = tracing::Span::current();
                            span.record_token_usage(&data.usage_metadata);
                            final_usage = data.usage_metadata;
                            break;
                        }
                    }
                    Err(crate::http_client::Error::StreamEnded) => {
                        break;
                    }
                    Err(error) => {
                        tracing::error!(?error, "SSE error");
                        yield Err(CompletionError::ProviderError(error.to_string()));
                        break;
                    }
                }
            }

            event_source.close();

            // Always terminate with a FinalResponse so consumers get usage
            // figures (defaulted to zeros when no finish chunk arrived).
            yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
                usage_metadata: final_usage.unwrap_or_default()
            }));
        }.instrument(span);

        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
            stream,
        )))
    }
}
244
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_deserialize_stream_response_with_single_text_part() {
        let payload = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "Hello, world!"}
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 5,
                "totalTokenCount": 15
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(payload).unwrap();
        assert_eq!(response.candidates.len(), 1);
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 1);

        match &content.parts[0] {
            Part {
                part: PartKind::Text(text),
                ..
            } => assert_eq!(text, "Hello, world!"),
            _ => panic!("Expected text part"),
        }
    }

    #[test]
    fn test_deserialize_stream_response_with_multiple_text_parts() {
        let payload = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "Hello, "},
                        {"text": "world!"},
                        {"text": " How are you?"}
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 8,
                "totalTokenCount": 18
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(payload).unwrap();
        assert_eq!(response.candidates.len(), 1);
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 3);

        let expected = ["Hello, ", "world!", " How are you?"];
        for (i, part) in content.parts.iter().enumerate() {
            match part {
                Part {
                    part: PartKind::Text(text),
                    ..
                } => assert_eq!(text, expected[i]),
                _ => panic!("Expected text part at index {}", i),
            }
        }
    }

    #[test]
    fn test_deserialize_stream_response_with_multiple_tool_calls() {
        let payload = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {
                            "functionCall": {
                                "name": "get_weather",
                                "args": {"city": "San Francisco"}
                            }
                        },
                        {
                            "functionCall": {
                                "name": "get_temperature",
                                "args": {"location": "New York"}
                            }
                        }
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 50,
                "candidatesTokenCount": 20,
                "totalTokenCount": 70
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(payload).unwrap();
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 2);

        match &content.parts[0] {
            Part {
                part: PartKind::FunctionCall(call),
                ..
            } => assert_eq!(call.name, "get_weather"),
            _ => panic!("Expected function call at index 0"),
        }

        match &content.parts[1] {
            Part {
                part: PartKind::FunctionCall(call),
                ..
            } => assert_eq!(call.name, "get_temperature"),
            _ => panic!("Expected function call at index 1"),
        }
    }

    #[test]
    fn test_deserialize_stream_response_with_mixed_parts() {
        let payload = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {
                            "text": "Let me think about this...",
                            "thought": true
                        },
                        {
                            "text": "Here's my response: "
                        },
                        {
                            "functionCall": {
                                "name": "search",
                                "args": {"query": "rust async"}
                            }
                        },
                        {
                            "text": "I found the answer!"
                        }
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 100,
                "candidatesTokenCount": 50,
                "thoughtsTokenCount": 15,
                "totalTokenCount": 165
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(payload).unwrap();
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        let parts = &content.parts;
        assert_eq!(parts.len(), 4);

        match &parts[0] {
            Part {
                part: PartKind::Text(text),
                thought: Some(true),
                ..
            } => assert_eq!(text, "Let me think about this..."),
            _ => panic!("Expected thought part at index 0"),
        }

        match &parts[1] {
            Part {
                part: PartKind::Text(text),
                thought,
                ..
            } => {
                assert_eq!(text, "Here's my response: ");
                assert!(thought.is_none() || thought == &Some(false));
            }
            _ => panic!("Expected text part at index 1"),
        }

        match &parts[2] {
            Part {
                part: PartKind::FunctionCall(call),
                ..
            } => assert_eq!(call.name, "search"),
            _ => panic!("Expected function call at index 2"),
        }

        match &parts[3] {
            Part {
                part: PartKind::Text(text),
                ..
            } => assert_eq!(text, "I found the answer!"),
            _ => panic!("Expected text part at index 3"),
        }
    }

    #[test]
    fn test_deserialize_stream_response_with_empty_parts() {
        let payload = json!({
            "candidates": [{
                "content": {
                    "parts": [],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 0,
                "totalTokenCount": 10
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(payload).unwrap();
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 0);
    }

    #[test]
    fn test_partial_usage_token_calculation() {
        let partial = PartialUsage {
            total_token_count: 100,
            cached_content_token_count: Some(20),
            candidates_token_count: Some(30),
            thoughts_token_count: Some(10),
            prompt_token_count: 40,
        };

        let usage = partial.token_usage().unwrap();
        assert_eq!(usage.input_tokens, 40);
        // 20 cached + 30 candidates + 10 thoughts
        assert_eq!(usage.output_tokens, 60);
        assert_eq!(usage.total_tokens, 100);
    }

    #[test]
    fn test_partial_usage_with_missing_counts() {
        let partial = PartialUsage {
            total_token_count: 50,
            cached_content_token_count: None,
            candidates_token_count: Some(30),
            thoughts_token_count: None,
            prompt_token_count: 20,
        };

        let usage = partial.token_usage().unwrap();
        assert_eq!(usage.input_tokens, 20);
        assert_eq!(usage.output_tokens, 30);
        assert_eq!(usage.total_tokens, 50);
    }

    #[test]
    fn test_streaming_completion_response_token_usage() {
        let response = StreamingCompletionResponse {
            usage_metadata: PartialUsage {
                total_token_count: 150,
                cached_content_token_count: None,
                candidates_token_count: Some(75),
                thoughts_token_count: None,
                prompt_token_count: 75,
            },
        };

        let usage = response.token_usage().unwrap();
        assert_eq!(usage.input_tokens, 75);
        assert_eq!(usage.output_tokens, 75);
        assert_eq!(usage.total_tokens, 150);
    }
}