1use async_stream::stream;
2use futures::StreamExt;
3use serde::{Deserialize, Serialize};
4use tracing::{Level, enabled, info_span};
5use tracing_futures::Instrument;
6
7use super::completion::gemini_api_types::{ContentCandidate, Part, PartKind};
8use super::completion::{CompletionModel, create_request_body};
9use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
10use crate::http_client::HttpClientExt;
11use crate::http_client::sse::{Event, GenericEventSource};
12use crate::streaming;
13use crate::telemetry::SpanCombinator;
14
/// Token-usage metadata as reported by Gemini's `streamGenerateContent` SSE chunks.
///
/// Field names map to the API's camelCase `usageMetadata` object; optional
/// counts are omitted on serialization when absent.
#[derive(Debug, Deserialize, Serialize, Default, Clone)]
#[serde(rename_all = "camelCase")]
pub struct PartialUsage {
    /// Total token count reported by the provider (prompt + response).
    pub total_token_count: i32,
    /// Tokens served from cached content, when the API reports them.
    /// NOTE(review): Gemini documents this as part of the prompt; confirm
    /// whether summing it into output tokens (as `token_usage` does) is intended.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_content_token_count: Option<i32>,
    /// Tokens generated across the response candidates, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidates_token_count: Option<i32>,
    /// Tokens spent on model "thought" (reasoning) parts, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thoughts_token_count: Option<i32>,
    /// Tokens in the prompt sent to the model.
    pub prompt_token_count: i32,
}
27
28impl GetTokenUsage for PartialUsage {
29 fn token_usage(&self) -> Option<crate::completion::Usage> {
30 let mut usage = crate::completion::Usage::new();
31
32 usage.input_tokens = self.prompt_token_count as u64;
33 usage.output_tokens = (self.cached_content_token_count.unwrap_or_default()
34 + self.candidates_token_count.unwrap_or_default()
35 + self.thoughts_token_count.unwrap_or_default()) as u64;
36 usage.total_tokens = usage.input_tokens + usage.output_tokens;
37
38 Some(usage)
39 }
40}
41
/// One deserialized chunk of a Gemini `streamGenerateContent` SSE response.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct StreamGenerateContentResponse {
    /// Candidate completions in this chunk; `stream` consumes only the first.
    pub candidates: Vec<ContentCandidate>,
    /// Model version echoed by the API, when present.
    pub model_version: Option<String>,
    /// Usage metadata for the chunk; `stream` records it once a candidate
    /// carries a finish reason.
    pub usage_metadata: Option<PartialUsage>,
}
50
/// Final payload yielded when a Gemini completion stream finishes.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct StreamingCompletionResponse {
    // Usage reported for the whole streamed completion; defaulted to zeros
    // when the provider never sent usage metadata before the stream ended.
    pub usage_metadata: PartialUsage,
}
55
56impl GetTokenUsage for StreamingCompletionResponse {
57 fn token_usage(&self) -> Option<crate::completion::Usage> {
58 let mut usage = crate::completion::Usage::new();
59 usage.total_tokens = self.usage_metadata.total_token_count as u64;
60 usage.output_tokens = self
61 .usage_metadata
62 .candidates_token_count
63 .map(|x| x as u64)
64 .unwrap_or(0);
65 usage.input_tokens = self.usage_metadata.prompt_token_count as u64;
66 Some(usage)
67 }
68}
69
impl<T> CompletionModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
    /// Sends `completion_request` to Gemini's `streamGenerateContent` endpoint
    /// and adapts the resulting SSE stream into rig streaming choices.
    ///
    /// Text parts flagged as thoughts are yielded as `ReasoningDelta`, plain
    /// text as `Message`, and function calls as `ToolCall`. When a candidate
    /// carries a finish reason, usage is recorded on the current span and the
    /// read loop ends; a `FinalResponse` (with defaulted usage if none was
    /// ever received) is always yielded last.
    ///
    /// # Errors
    /// Returns `CompletionError` if the request body cannot be built or the
    /// HTTP request cannot be constructed. Transport errors that occur while
    /// streaming are yielded as `Err` items inside the stream instead.
    pub(crate) async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
    {
        // Reuse an already-active span when one exists; otherwise open a new
        // GenAI-convention span for this streaming chat call. Usage fields are
        // left Empty and recorded once the final chunk arrives.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "gcp.gemini",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = self.model,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };
        let request = create_request_body(completion_request)?;

        // Pretty-printing the request body is not free, so gate it on TRACE.
        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::streaming",
                "Gemini streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post_sse(format!(
                "/v1beta/models/{}:streamGenerateContent",
                self.model
            ))?
            .header("Content-Type", "application/json")
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        let mut event_source = GenericEventSource::new(self.client.clone(), req);

        let stream = stream! {
            // Populated from the chunk whose candidate carries a finish reason.
            let mut final_usage = None;
            while let Some(event_result) = event_source.next().await {
                match event_result {
                    Ok(Event::Open) => {
                        tracing::debug!("SSE connection opened");
                        continue;
                    }
                    Ok(Event::Message(message)) => {
                        // Skip keep-alive / empty data lines.
                        if message.data.trim().is_empty() {
                            continue;
                        }

                        // A malformed chunk is logged and skipped rather than
                        // terminating the whole stream.
                        let data = match serde_json::from_str::<StreamGenerateContentResponse>(&message.data) {
                            Ok(d) => d,
                            Err(error) => {
                                tracing::error!(?error, message = message.data, "Failed to parse SSE message");
                                continue;
                            }
                        };

                        // Only the first candidate is consumed; others are dropped.
                        let Some(choice) = data.candidates.into_iter().next() else {
                            tracing::debug!("There is no content candidate");
                            continue;
                        };

                        // A candidate may arrive without content (e.g. only a
                        // finish reason); nothing to emit in that case.
                        let Some(content) = choice.content else {
                            tracing::debug!(finish_reason = ?choice.finish_reason, "Streaming candidate missing content");
                            continue;
                        };

                        if content.parts.is_empty() {
                            tracing::trace!(reason = ?choice.finish_reason, "There is no part in the streaming content");
                        }

                        for part in content.parts {
                            match part {
                                // Thought-flagged text -> reasoning delta.
                                Part {
                                    part: PartKind::Text(text),
                                    thought: Some(true),
                                    ..
                                } => {
                                    if !text.is_empty() {
                                        yield Ok(streaming::RawStreamingChoice::ReasoningDelta {
                                            id: None,
                                            reasoning: text,
                                        });
                                    }
                                },
                                // Plain text -> message delta.
                                Part {
                                    part: PartKind::Text(text),
                                    ..
                                } => {
                                    if !text.is_empty() {
                                        yield Ok(streaming::RawStreamingChoice::Message(text));
                                    }
                                },
                                // Function call -> tool call. NOTE(review): the
                                // function name is passed twice — presumably as
                                // both call id and name because Gemini supplies
                                // no call id; confirm against
                                // `RawStreamingToolCall::new`'s signature.
                                Part {
                                    part: PartKind::FunctionCall(function_call),
                                    thought_signature,
                                    ..
                                } => {
                                    yield Ok(streaming::RawStreamingChoice::ToolCall(
                                        streaming::RawStreamingToolCall::new(function_call.name.clone(), function_call.name.clone(), function_call.args.clone())
                                            .with_signature(thought_signature)
                                    ));
                                },
                                part => {
                                    tracing::warn!(?part, "Unsupported response type with streaming");
                                }
                            }
                        }

                        // A finish reason marks the last meaningful chunk:
                        // record usage on the span and stop reading.
                        if choice.finish_reason.is_some() {
                            let span = tracing::Span::current();
                            span.record_token_usage(&data.usage_metadata);
                            final_usage = data.usage_metadata;
                            break;
                        }
                    }
                    // Normal end of stream.
                    Err(crate::http_client::Error::StreamEnded) => {
                        break;
                    }
                    // Transport failure: surface it to the consumer, then stop.
                    Err(error) => {
                        tracing::error!(?error, "SSE error");
                        yield Err(CompletionError::ProviderError(error.to_string()));
                        break;
                    }
                }
            }

            event_source.close();

            // Always terminate with a final response; usage defaults to zeros
            // if the provider never reported it.
            yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
                usage_metadata: final_usage.unwrap_or_default()
            }));
        }.instrument(span);

        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
            stream,
        )))
    }
}
226
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Deserializes a JSON value into a streaming response, panicking on failure.
    fn parse(value: serde_json::Value) -> StreamGenerateContentResponse {
        serde_json::from_value(value).expect("stream response should deserialize")
    }

    /// Borrows the parts of the first candidate's content, panicking when absent.
    fn first_parts(response: &StreamGenerateContentResponse) -> &[Part] {
        let candidate = response
            .candidates
            .first()
            .expect("response should contain a candidate");
        let content = candidate
            .content
            .as_ref()
            .expect("candidate should contain content");
        &content.parts
    }

    /// Asserts that `part` is a text part carrying exactly `expected`.
    fn assert_text(part: &Part, expected: &str) {
        match part {
            Part {
                part: PartKind::Text(text),
                ..
            } => assert_eq!(text, expected),
            other => panic!("expected text part, got {:?}", other),
        }
    }

    /// Asserts that `part` is a function call with the given name.
    fn assert_function_call(part: &Part, expected_name: &str) {
        match part {
            Part {
                part: PartKind::FunctionCall(call),
                ..
            } => assert_eq!(call.name, expected_name),
            other => panic!("expected function call part, got {:?}", other),
        }
    }

    #[test]
    fn test_deserialize_stream_response_with_single_text_part() {
        let response = parse(json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "Hello, world!"}
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 5,
                "totalTokenCount": 15
            }
        }));

        assert_eq!(response.candidates.len(), 1);
        let parts = first_parts(&response);
        assert_eq!(parts.len(), 1);
        assert_text(&parts[0], "Hello, world!");
    }

    #[test]
    fn test_deserialize_stream_response_with_multiple_text_parts() {
        let response = parse(json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "Hello, "},
                        {"text": "world!"},
                        {"text": " How are you?"}
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 8,
                "totalTokenCount": 18
            }
        }));

        assert_eq!(response.candidates.len(), 1);
        let parts = first_parts(&response);
        assert_eq!(parts.len(), 3);
        for (part, expected) in parts.iter().zip(["Hello, ", "world!", " How are you?"]) {
            assert_text(part, expected);
        }
    }

    #[test]
    fn test_deserialize_stream_response_with_multiple_tool_calls() {
        let response = parse(json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {
                            "functionCall": {
                                "name": "get_weather",
                                "args": {"city": "San Francisco"}
                            }
                        },
                        {
                            "functionCall": {
                                "name": "get_temperature",
                                "args": {"location": "New York"}
                            }
                        }
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 50,
                "candidatesTokenCount": 20,
                "totalTokenCount": 70
            }
        }));

        let parts = first_parts(&response);
        assert_eq!(parts.len(), 2);
        assert_function_call(&parts[0], "get_weather");
        assert_function_call(&parts[1], "get_temperature");
    }

    #[test]
    fn test_deserialize_stream_response_with_mixed_parts() {
        let response = parse(json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {
                            "text": "Let me think about this...",
                            "thought": true
                        },
                        {
                            "text": "Here's my response: "
                        },
                        {
                            "functionCall": {
                                "name": "search",
                                "args": {"query": "rust async"}
                            }
                        },
                        {
                            "text": "I found the answer!"
                        }
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 100,
                "candidatesTokenCount": 50,
                "thoughtsTokenCount": 15,
                "totalTokenCount": 165
            }
        }));

        let parts = first_parts(&response);
        assert_eq!(parts.len(), 4);

        // Index 0 must be a thought-flagged text part.
        match &parts[0] {
            Part {
                part: PartKind::Text(text),
                thought: Some(true),
                ..
            } => assert_eq!(text, "Let me think about this..."),
            other => panic!("expected thought part at index 0, got {:?}", other),
        }

        // Index 1 must be plain text that is not flagged as a thought.
        match &parts[1] {
            Part {
                part: PartKind::Text(text),
                thought,
                ..
            } => {
                assert_eq!(text, "Here's my response: ");
                assert!(thought.is_none() || thought == &Some(false));
            }
            other => panic!("expected text part at index 1, got {:?}", other),
        }

        assert_function_call(&parts[2], "search");
        assert_text(&parts[3], "I found the answer!");
    }

    #[test]
    fn test_deserialize_stream_response_with_empty_parts() {
        let response = parse(json!({
            "candidates": [{
                "content": {
                    "parts": [],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 0,
                "totalTokenCount": 10
            }
        }));

        assert_eq!(first_parts(&response).len(), 0);
    }

    #[test]
    fn test_partial_usage_token_calculation() {
        let usage = PartialUsage {
            total_token_count: 100,
            cached_content_token_count: Some(20),
            candidates_token_count: Some(30),
            thoughts_token_count: Some(10),
            prompt_token_count: 40,
        };

        let token_usage = usage.token_usage().expect("usage should be reported");
        assert_eq!(token_usage.input_tokens, 40);
        // Output is the sum of cached + candidate + thought counts.
        assert_eq!(token_usage.output_tokens, 60);
        assert_eq!(token_usage.total_tokens, 100);
    }

    #[test]
    fn test_partial_usage_with_missing_counts() {
        let usage = PartialUsage {
            total_token_count: 50,
            cached_content_token_count: None,
            candidates_token_count: Some(30),
            thoughts_token_count: None,
            prompt_token_count: 20,
        };

        let token_usage = usage.token_usage().expect("usage should be reported");
        assert_eq!(token_usage.input_tokens, 20);
        // Missing counts contribute zero to the output total.
        assert_eq!(token_usage.output_tokens, 30);
        assert_eq!(token_usage.total_tokens, 50);
    }

    #[test]
    fn test_streaming_completion_response_token_usage() {
        let response = StreamingCompletionResponse {
            usage_metadata: PartialUsage {
                total_token_count: 150,
                cached_content_token_count: None,
                candidates_token_count: Some(75),
                thoughts_token_count: None,
                prompt_token_count: 75,
            },
        };

        let token_usage = response.token_usage().expect("usage should be reported");
        assert_eq!(token_usage.input_tokens, 75);
        assert_eq!(token_usage.output_tokens, 75);
        assert_eq!(token_usage.total_tokens, 150);
    }
}