1use async_stream::stream;
2use futures::StreamExt;
3use serde::{Deserialize, Serialize};
4use tracing::{Level, enabled, info_span};
5use tracing_futures::Instrument;
6
7use super::completion::gemini_api_types::{ContentCandidate, Part, PartKind};
8use super::completion::{CompletionModel, create_request_body};
9use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
10use crate::http_client::HttpClientExt;
11use crate::http_client::sse::{Event, GenericEventSource};
12use crate::streaming;
13use crate::telemetry::SpanCombinator;
14
/// Token usage metadata as carried on chunks of a Gemini
/// `streamGenerateContent` response (the `usageMetadata` JSON field).
#[derive(Debug, Deserialize, Serialize, Default, Clone)]
#[serde(rename_all = "camelCase")]
pub struct PartialUsage {
    // Total tokens billed for the request so far (prompt + generated).
    pub total_token_count: i32,
    // Tokens served from cached content when context caching is in use;
    // absent when the feature is not active.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_content_token_count: Option<i32>,
    // Tokens in the generated candidate(s), i.e. model output.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidates_token_count: Option<i32>,
    // Tokens spent on internal "thinking" for thinking-capable models.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thoughts_token_count: Option<i32>,
    // Tokens in the prompt (input side).
    pub prompt_token_count: i32,
}
27
28impl GetTokenUsage for PartialUsage {
29 fn token_usage(&self) -> Option<crate::completion::Usage> {
30 let mut usage = crate::completion::Usage::new();
31
32 usage.input_tokens = self.prompt_token_count as u64;
33 usage.output_tokens = (self.cached_content_token_count.unwrap_or_default()
34 + self.candidates_token_count.unwrap_or_default()
35 + self.thoughts_token_count.unwrap_or_default()) as u64;
36 usage.total_tokens = usage.input_tokens + usage.output_tokens;
37
38 Some(usage)
39 }
40}
41
/// One SSE chunk of a Gemini `streamGenerateContent` response.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct StreamGenerateContentResponse {
    // Candidate completions; the stream loop only inspects the first entry.
    pub candidates: Vec<ContentCandidate>,
    // Model version echoed back by the API, when present.
    pub model_version: Option<String>,
    // Usage metadata; the stream loop reads it once a finish reason arrives.
    pub usage_metadata: Option<PartialUsage>,
}
50
/// Final aggregated payload yielded once the Gemini stream completes.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct StreamingCompletionResponse {
    // Usage from the last observed chunk; defaulted (all zeros) if the stream
    // ended without ever reporting usage metadata.
    pub usage_metadata: PartialUsage,
}
55
56impl GetTokenUsage for StreamingCompletionResponse {
57 fn token_usage(&self) -> Option<crate::completion::Usage> {
58 let mut usage = crate::completion::Usage::new();
59 usage.total_tokens = self.usage_metadata.total_token_count as u64;
60 usage.output_tokens = self
61 .usage_metadata
62 .candidates_token_count
63 .map(|x| x as u64)
64 .unwrap_or(0);
65 usage.input_tokens = self.usage_metadata.prompt_token_count as u64;
66 Some(usage)
67 }
68}
69
impl<T> CompletionModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
    /// Sends `completion_request` to Gemini's `streamGenerateContent` endpoint
    /// and adapts the resulting SSE stream into a
    /// [`streaming::StreamingCompletionResponse`].
    ///
    /// Per chunk, thought-flagged text parts are yielded as `Reasoning`, plain
    /// text as `Message`, and function calls as `ToolCall`; once a finish
    /// reason is seen (or the stream ends) a `FinalResponse` carrying the last
    /// usage metadata is emitted.
    ///
    /// # Errors
    /// Returns [`CompletionError`] if the request body cannot be built or
    /// serialized, or if the SSE request cannot be constructed; transport
    /// errors during streaming are yielded inside the stream as
    /// `Err(CompletionError::ProviderError)`.
    pub(crate) async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
    {
        // Reuse the caller's span when one is active; otherwise open a fresh
        // span following the OTel GenAI semantic-convention field names.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "gcp.gemini",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = self.model,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };
        let request = create_request_body(completion_request)?;

        // Guard the pretty-print so the serialization cost is only paid when
        // TRACE logging is actually enabled.
        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::streaming",
                "Gemini streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post_sse(format!(
                "/v1beta/models/{}:streamGenerateContent",
                self.model
            ))?
            .header("Content-Type", "application/json")
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        let mut event_source = GenericEventSource::new(self.client.clone(), req);

        let stream = stream! {
            // Running concatenation of plain-text output, folded into a single
            // text Part when a finish reason arrives.
            let mut text_response = String::new();
            // Parts produced by the model (tool calls + final text).
            // NOTE(review): this buffer is written but never read after the
            // loop — confirm whether it was meant to feed telemetry or can be
            // removed.
            let mut model_outputs: Vec<Part> = Vec::new();
            let mut final_usage = None;
            while let Some(event_result) = event_source.next().await {
                match event_result {
                    Ok(Event::Open) => {
                        tracing::debug!("SSE connection opened");
                        continue;
                    }
                    Ok(Event::Message(message)) => {
                        // Keep-alive / empty frames carry no payload.
                        if message.data.trim().is_empty() {
                            continue;
                        }

                        // Malformed chunks are logged and skipped rather than
                        // terminating the whole stream.
                        let data = match serde_json::from_str::<StreamGenerateContentResponse>(&message.data) {
                            Ok(d) => d,
                            Err(error) => {
                                tracing::error!(?error, message = message.data, "Failed to parse SSE message");
                                continue;
                            }
                        };

                        // Only the first candidate is consumed; any additional
                        // candidates in the chunk are ignored.
                        let Some(choice) = data.candidates.first() else {
                            tracing::debug!("There is no content candidate");
                            continue;
                        };

                        let Some(content) = choice.content.as_ref() else {
                            tracing::debug!(finish_reason = ?choice.finish_reason, "Streaming candidate missing content");
                            continue;
                        };

                        for part in &content.parts {
                            match part {
                                // Thought-flagged text => reasoning chunk.
                                Part {
                                    part: PartKind::Text(text),
                                    thought: Some(true),
                                    ..
                                } => {
                                    yield Ok(streaming::RawStreamingChoice::Reasoning { reasoning: text.clone(), id: None, signature: None });
                                },
                                // Ordinary text => message chunk, also
                                // accumulated for the final text Part.
                                Part {
                                    part: PartKind::Text(text),
                                    ..
                                } => {
                                    text_response.push_str(text);
                                    yield Ok(streaming::RawStreamingChoice::Message(text.clone()));
                                },
                                Part {
                                    part: PartKind::FunctionCall(function_call),
                                    ..
                                } => {
                                    model_outputs.push(part.clone());
                                    // NOTE(review): Gemini function calls carry
                                    // no id, so the function name doubles as
                                    // the tool-call id here — confirm callers
                                    // tolerate duplicate ids for repeated calls.
                                    yield Ok(streaming::RawStreamingChoice::ToolCall {
                                        name: function_call.name.clone(),
                                        id: function_call.name.clone(),
                                        arguments: function_call.args.clone(),
                                        call_id: None
                                    });
                                },
                                part => {
                                    tracing::warn!(?part, "Unsupported response type with streaming");
                                }
                            }
                        }

                        if content.parts.is_empty() {
                            tracing::trace!(reason = ?choice.finish_reason, "There is no part in the streaming content");
                        }

                        // A finish reason marks the last meaningful chunk:
                        // record usage on the span and stop reading.
                        if choice.finish_reason.is_some() {
                            if !text_response.is_empty() {
                                model_outputs.push(Part { thought: None, thought_signature: None, part: PartKind::Text(text_response), additional_params: None });
                            }
                            let span = tracing::Span::current();
                            span.record_token_usage(&data.usage_metadata);
                            final_usage = data.usage_metadata;
                            break;
                        }
                    }
                    // Normal end-of-stream from the transport layer.
                    Err(crate::http_client::Error::StreamEnded) => {
                        break;
                    }
                    Err(error) => {
                        tracing::error!(?error, "SSE error");
                        yield Err(CompletionError::ProviderError(error.to_string()));
                        break;
                    }
                }
            }

            event_source.close();

            // Always emit a final response; usage defaults to zeros when the
            // stream ended before any usage metadata was observed.
            yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
                usage_metadata: final_usage.unwrap_or_default()
            }));
        }.instrument(span);

        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
            stream,
        )))
    }
}
227
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A single text part deserializes and round-trips its content.
    #[test]
    fn test_deserialize_stream_response_with_single_text_part() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "Hello, world!"}
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 5,
                "totalTokenCount": 15
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        assert_eq!(response.candidates.len(), 1);
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 1);

        if let Part {
            part: PartKind::Text(text),
            ..
        } = &content.parts[0]
        {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected text part");
        }
    }

    /// Multiple text parts in one chunk keep their order.
    #[test]
    fn test_deserialize_stream_response_with_multiple_text_parts() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "Hello, "},
                        {"text": "world!"},
                        {"text": " How are you?"}
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 8,
                "totalTokenCount": 18
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        assert_eq!(response.candidates.len(), 1);
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 3);

        for (i, expected_text) in ["Hello, ", "world!", " How are you?"].iter().enumerate() {
            if let Part {
                part: PartKind::Text(text),
                ..
            } = &content.parts[i]
            {
                assert_eq!(text, expected_text);
            } else {
                panic!("Expected text part at index {}", i);
            }
        }
    }

    /// Several function calls in one chunk each deserialize with their name.
    #[test]
    fn test_deserialize_stream_response_with_multiple_tool_calls() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {
                            "functionCall": {
                                "name": "get_weather",
                                "args": {"city": "San Francisco"}
                            }
                        },
                        {
                            "functionCall": {
                                "name": "get_temperature",
                                "args": {"location": "New York"}
                            }
                        }
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 50,
                "candidatesTokenCount": 20,
                "totalTokenCount": 70
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 2);

        if let Part {
            part: PartKind::FunctionCall(call),
            ..
        } = &content.parts[0]
        {
            assert_eq!(call.name, "get_weather");
        } else {
            panic!("Expected function call at index 0");
        }

        if let Part {
            part: PartKind::FunctionCall(call),
            ..
        } = &content.parts[1]
        {
            assert_eq!(call.name, "get_temperature");
        } else {
            panic!("Expected function call at index 1");
        }
    }

    /// Mixed thought / text / function-call parts deserialize with the
    /// `thought` flag only on the thought part.
    #[test]
    fn test_deserialize_stream_response_with_mixed_parts() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {
                            "text": "Let me think about this...",
                            "thought": true
                        },
                        {
                            "text": "Here's my response: "
                        },
                        {
                            "functionCall": {
                                "name": "search",
                                "args": {"query": "rust async"}
                            }
                        },
                        {
                            "text": "I found the answer!"
                        }
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 100,
                "candidatesTokenCount": 50,
                "thoughtsTokenCount": 15,
                "totalTokenCount": 165
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        let parts = &content.parts;
        assert_eq!(parts.len(), 4);

        if let Part {
            part: PartKind::Text(text),
            thought: Some(true),
            ..
        } = &parts[0]
        {
            assert_eq!(text, "Let me think about this...");
        } else {
            panic!("Expected thought part at index 0");
        }

        if let Part {
            part: PartKind::Text(text),
            thought,
            ..
        } = &parts[1]
        {
            assert_eq!(text, "Here's my response: ");
            assert!(thought.is_none() || thought == &Some(false));
        } else {
            panic!("Expected text part at index 1");
        }

        if let Part {
            part: PartKind::FunctionCall(call),
            ..
        } = &parts[2]
        {
            assert_eq!(call.name, "search");
        } else {
            panic!("Expected function call at index 2");
        }

        if let Part {
            part: PartKind::Text(text),
            ..
        } = &parts[3]
        {
            assert_eq!(text, "I found the answer!");
        } else {
            panic!("Expected text part at index 3");
        }
    }

    /// An empty `parts` array is valid and yields zero parts.
    #[test]
    fn test_deserialize_stream_response_with_empty_parts() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 0,
                "totalTokenCount": 10
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 0);
    }

    /// PartialUsage sums cached + candidate + thought counts into output.
    #[test]
    fn test_partial_usage_token_calculation() {
        let usage = PartialUsage {
            total_token_count: 100,
            cached_content_token_count: Some(20),
            candidates_token_count: Some(30),
            thoughts_token_count: Some(10),
            prompt_token_count: 40,
        };

        let token_usage = usage.token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 40);
        assert_eq!(token_usage.output_tokens, 60);
        assert_eq!(token_usage.total_tokens, 100);
    }

    /// Missing optional counts contribute zero to the output total.
    #[test]
    fn test_partial_usage_with_missing_counts() {
        let usage = PartialUsage {
            total_token_count: 50,
            cached_content_token_count: None,
            candidates_token_count: Some(30),
            thoughts_token_count: None,
            prompt_token_count: 20,
        };

        let token_usage = usage.token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 20);
        assert_eq!(token_usage.output_tokens, 30);
        assert_eq!(token_usage.total_tokens, 50);
    }

    /// StreamingCompletionResponse takes the total verbatim and only counts
    /// candidate tokens as output.
    #[test]
    fn test_streaming_completion_response_token_usage() {
        let response = StreamingCompletionResponse {
            usage_metadata: PartialUsage {
                total_token_count: 150,
                cached_content_token_count: None,
                candidates_token_count: Some(75),
                thoughts_token_count: None,
                prompt_token_count: 75,
            },
        };

        let token_usage = response.token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 75);
        assert_eq!(token_usage.output_tokens, 75);
        assert_eq!(token_usage.total_tokens, 150);
    }
}