1use async_stream::stream;
2use futures::StreamExt;
3use serde::{Deserialize, Serialize};
4use tracing::info_span;
5
6use super::completion::gemini_api_types::{ContentCandidate, Part, PartKind};
7use super::completion::{CompletionModel, create_request_body};
8use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
9use crate::http_client::HttpClientExt;
10use crate::http_client::sse::{Event, GenericEventSource};
11use crate::streaming;
12use crate::telemetry::SpanCombinator;
13
/// Token usage metadata as reported by a Gemini `streamGenerateContent` chunk.
///
/// Mirrors the wire-format `usageMetadata` object (camelCase on the wire).
/// Optional counters are omitted from serialized output when absent.
#[derive(Debug, Deserialize, Serialize, Default, Clone)]
#[serde(rename_all = "camelCase")]
pub struct PartialUsage {
    // Total token count reported by the API for the request so far.
    pub total_token_count: i32,
    // Tokens served from cached content, when the API reports them.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_content_token_count: Option<i32>,
    // Tokens generated across response candidates, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidates_token_count: Option<i32>,
    // Tokens spent on model "thought" output, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thoughts_token_count: Option<i32>,
    // Tokens consumed by the prompt (input side).
    pub prompt_token_count: i32,
}
26
27impl GetTokenUsage for PartialUsage {
28 fn token_usage(&self) -> Option<crate::completion::Usage> {
29 let mut usage = crate::completion::Usage::new();
30
31 usage.input_tokens = self.prompt_token_count as u64;
32 usage.output_tokens = (self.cached_content_token_count.unwrap_or_default()
33 + self.candidates_token_count.unwrap_or_default()
34 + self.thoughts_token_count.unwrap_or_default()) as u64;
35 usage.total_tokens = usage.input_tokens + usage.output_tokens;
36
37 Some(usage)
38 }
39}
40
/// One deserialized chunk of Gemini's `streamGenerateContent` SSE response.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct StreamGenerateContentResponse {
    // Response candidates; the streaming loop only consumes the first one.
    pub candidates: Vec<ContentCandidate>,
    // Model version string reported by the API, when present.
    pub model_version: Option<String>,
    // Token usage reported with this chunk, when present.
    pub usage_metadata: Option<PartialUsage>,
}
49
/// Final aggregated payload yielded at the end of a Gemini completion stream.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct StreamingCompletionResponse {
    // Usage from the chunk that carried the finish reason; a zeroed default
    // is substituted when the stream ended without reporting any usage.
    pub usage_metadata: PartialUsage,
}
54
55impl GetTokenUsage for StreamingCompletionResponse {
56 fn token_usage(&self) -> Option<crate::completion::Usage> {
57 let mut usage = crate::completion::Usage::new();
58 usage.total_tokens = self.usage_metadata.total_token_count as u64;
59 usage.output_tokens = self
60 .usage_metadata
61 .candidates_token_count
62 .map(|x| x as u64)
63 .unwrap_or(0);
64 usage.input_tokens = self.usage_metadata.prompt_token_count as u64;
65 Some(usage)
66 }
67}
68
impl<T> CompletionModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
    /// Streams a completion from Gemini's `streamGenerateContent` SSE endpoint.
    ///
    /// Text parts are yielded as `Message` chunks (parts flagged as thoughts
    /// become `Reasoning`), function calls become `ToolCall`s, and once the
    /// stream ends a `FinalResponse` carrying the last-seen usage metadata is
    /// emitted.
    ///
    /// # Errors
    /// Fails up front if the request body cannot be built or serialized, or if
    /// the HTTP request cannot be constructed. Transport errors encountered
    /// mid-stream are yielded as `Err` items on the returned stream instead.
    pub(crate) async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
    {
        // Reuse the caller's span when one is already active; otherwise open a
        // fresh `chat_streaming` span with GenAI semantic-convention fields
        // (the `Empty` fields are meant to be recorded later as data arrives).
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "gcp.gemini",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = self.model,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = tracing::field::Empty,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };
        let request = create_request_body(completion_request)?;

        span.record_model_input(&request.contents);

        tracing::trace!(
            target: "rig::streaming",
            "Sending completion request to Gemini API {}",
            serde_json::to_string_pretty(&request)?
        );
        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post_sse(format!(
                "/v1beta/models/{}:streamGenerateContent",
                self.model
            ))?
            .header("Content-Type", "application/json")
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        let mut event_source = GenericEventSource::new(self.client.http_client().clone(), req);

        let stream = stream! {
            // Plain text accumulated across chunks; flushed as one output part
            // for telemetry when the finish reason arrives.
            let mut text_response = String::new();
            // Non-text outputs (tool calls) collected for telemetry.
            let mut model_outputs: Vec<Part> = Vec::new();
            // Usage metadata taken from the chunk carrying the finish reason.
            let mut final_usage = None;
            while let Some(event_result) = event_source.next().await {
                match event_result {
                    Ok(Event::Open) => {
                        tracing::debug!("SSE connection opened");
                        continue;
                    }
                    Ok(Event::Message(message)) => {
                        // Skip empty/keep-alive frames.
                        if message.data.trim().is_empty() {
                            continue;
                        }

                        // Malformed chunks are logged and skipped rather than
                        // aborting the whole stream.
                        let data = match serde_json::from_str::<StreamGenerateContentResponse>(&message.data) {
                            Ok(d) => d,
                            Err(error) => {
                                tracing::error!(?error, message = message.data, "Failed to parse SSE message");
                                continue;
                            }
                        };

                        // Only the first candidate is consumed; any additional
                        // candidates in the chunk are ignored.
                        let Some(choice) = data.candidates.first() else {
                            tracing::debug!("There is no content candidate");
                            continue;
                        };

                        let Some(content) = choice.content.as_ref() else {
                            tracing::debug!(finish_reason = ?choice.finish_reason, "Streaming candidate missing content");
                            continue;
                        };

                        for part in &content.parts {
                            match part {
                                // Thought-flagged text is surfaced as reasoning,
                                // not as message text.
                                Part {
                                    part: PartKind::Text(text),
                                    thought: Some(true),
                                    ..
                                } => {
                                    yield Ok(streaming::RawStreamingChoice::Reasoning { reasoning: text.clone(), id: None, signature: None });
                                },
                                Part {
                                    part: PartKind::Text(text),
                                    ..
                                } => {
                                    text_response.push_str(text);
                                    yield Ok(streaming::RawStreamingChoice::Message(text.clone()));
                                },
                                Part {
                                    part: PartKind::FunctionCall(function_call),
                                    ..
                                } => {
                                    model_outputs.push(part.clone());
                                    // The function name doubles as the tool-call
                                    // id here — presumably because Gemini supplies
                                    // no separate call id; TODO confirm.
                                    yield Ok(streaming::RawStreamingChoice::ToolCall {
                                        name: function_call.name.clone(),
                                        id: function_call.name.clone(),
                                        arguments: function_call.args.clone(),
                                        call_id: None
                                    });
                                },
                                part => {
                                    tracing::warn!(?part, "Unsupported response type with streaming");
                                }
                            }
                        }

                        if content.parts.is_empty() {
                            tracing::trace!(reason = ?choice.finish_reason, "There is no part in the streaming content");
                        }

                        // A finish reason marks the final meaningful chunk:
                        // flush telemetry and stop reading the event source.
                        if choice.finish_reason.is_some() {
                            if !text_response.is_empty() {
                                model_outputs.push(Part { thought: None, thought_signature: None, part: PartKind::Text(text_response), additional_params: None });
                            }
                            // NOTE(review): this records on whatever span is
                            // current when the stream is polled, which may differ
                            // from the `span` built above — confirm the consumer
                            // polls the stream within that span.
                            let span = tracing::Span::current();
                            span.record_model_output(&model_outputs);
                            span.record_token_usage(&data.usage_metadata);
                            final_usage = data.usage_metadata;
                            break;
                        }
                    }
                    Err(crate::http_client::Error::StreamEnded) => {
                        break;
                    }
                    Err(error) => {
                        tracing::error!(?error, "SSE error");
                        yield Err(CompletionError::ProviderError(error.to_string()));
                        break;
                    }
                }
            }

            event_source.close();

            // Always emit a final response; defaults to zeroed usage when the
            // stream ended without ever reporting usage metadata.
            yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
                usage_metadata: final_usage.unwrap_or_default()
            }));
        };

        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
            stream,
        )))
    }
}
228
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Deserializes a raw chunk payload into the streaming response type.
    fn parse_chunk(value: serde_json::Value) -> StreamGenerateContentResponse {
        serde_json::from_value(value).unwrap()
    }

    #[test]
    fn test_deserialize_stream_response_with_single_text_part() {
        let response = parse_chunk(json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "Hello, world!"}
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 5,
                "totalTokenCount": 15
            }
        }));

        assert_eq!(response.candidates.len(), 1);
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 1);

        match &content.parts[0] {
            Part {
                part: PartKind::Text(text),
                ..
            } => assert_eq!(text, "Hello, world!"),
            _ => panic!("Expected text part"),
        }
    }

    #[test]
    fn test_deserialize_stream_response_with_multiple_text_parts() {
        let response = parse_chunk(json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "Hello, "},
                        {"text": "world!"},
                        {"text": " How are you?"}
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 8,
                "totalTokenCount": 18
            }
        }));

        assert_eq!(response.candidates.len(), 1);
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 3);

        let expected = ["Hello, ", "world!", " How are you?"];
        for (i, expected_text) in expected.iter().enumerate() {
            match &content.parts[i] {
                Part {
                    part: PartKind::Text(text),
                    ..
                } => assert_eq!(text, expected_text),
                _ => panic!("Expected text part at index {}", i),
            }
        }
    }

    #[test]
    fn test_deserialize_stream_response_with_multiple_tool_calls() {
        let response = parse_chunk(json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {
                            "functionCall": {
                                "name": "get_weather",
                                "args": {"city": "San Francisco"}
                            }
                        },
                        {
                            "functionCall": {
                                "name": "get_temperature",
                                "args": {"location": "New York"}
                            }
                        }
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 50,
                "candidatesTokenCount": 20,
                "totalTokenCount": 70
            }
        }));

        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 2);

        let expected_names = ["get_weather", "get_temperature"];
        for (i, expected_name) in expected_names.iter().enumerate() {
            match &content.parts[i] {
                Part {
                    part: PartKind::FunctionCall(call),
                    ..
                } => assert_eq!(&call.name, expected_name),
                _ => panic!("Expected function call at index {}", i),
            }
        }
    }

    #[test]
    fn test_deserialize_stream_response_with_mixed_parts() {
        let response = parse_chunk(json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {
                            "text": "Let me think about this...",
                            "thought": true
                        },
                        {
                            "text": "Here's my response: "
                        },
                        {
                            "functionCall": {
                                "name": "search",
                                "args": {"query": "rust async"}
                            }
                        },
                        {
                            "text": "I found the answer!"
                        }
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 100,
                "candidatesTokenCount": 50,
                "thoughtsTokenCount": 15,
                "totalTokenCount": 165
            }
        }));

        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        let parts = &content.parts;
        assert_eq!(parts.len(), 4);

        match &parts[0] {
            Part {
                part: PartKind::Text(text),
                thought: Some(true),
                ..
            } => assert_eq!(text, "Let me think about this..."),
            _ => panic!("Expected thought part at index 0"),
        }

        match &parts[1] {
            Part {
                part: PartKind::Text(text),
                thought,
                ..
            } => {
                assert_eq!(text, "Here's my response: ");
                assert!(thought.is_none() || thought == &Some(false));
            }
            _ => panic!("Expected text part at index 1"),
        }

        match &parts[2] {
            Part {
                part: PartKind::FunctionCall(call),
                ..
            } => assert_eq!(call.name, "search"),
            _ => panic!("Expected function call at index 2"),
        }

        match &parts[3] {
            Part {
                part: PartKind::Text(text),
                ..
            } => assert_eq!(text, "I found the answer!"),
            _ => panic!("Expected text part at index 3"),
        }
    }

    #[test]
    fn test_deserialize_stream_response_with_empty_parts() {
        let response = parse_chunk(json!({
            "candidates": [{
                "content": {
                    "parts": [],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 0,
                "totalTokenCount": 10
            }
        }));

        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 0);
    }

    #[test]
    fn test_partial_usage_token_calculation() {
        // All optional counters present: output = 20 + 30 + 10.
        let usage = PartialUsage {
            total_token_count: 100,
            cached_content_token_count: Some(20),
            candidates_token_count: Some(30),
            thoughts_token_count: Some(10),
            prompt_token_count: 40,
        };

        let token_usage = usage.token_usage().unwrap();
        assert_eq!(
            (
                token_usage.input_tokens,
                token_usage.output_tokens,
                token_usage.total_tokens
            ),
            (40, 60, 100)
        );
    }

    #[test]
    fn test_partial_usage_with_missing_counts() {
        // Missing optional counters are treated as zero.
        let usage = PartialUsage {
            total_token_count: 50,
            cached_content_token_count: None,
            candidates_token_count: Some(30),
            thoughts_token_count: None,
            prompt_token_count: 20,
        };

        let token_usage = usage.token_usage().unwrap();
        assert_eq!(
            (
                token_usage.input_tokens,
                token_usage.output_tokens,
                token_usage.total_tokens
            ),
            (20, 30, 50)
        );
    }

    #[test]
    fn test_streaming_completion_response_token_usage() {
        // The final-response wrapper reads the total straight from metadata.
        let response = StreamingCompletionResponse {
            usage_metadata: PartialUsage {
                total_token_count: 150,
                cached_content_token_count: None,
                candidates_token_count: Some(75),
                thoughts_token_count: None,
                prompt_token_count: 75,
            },
        };

        let token_usage = response.token_usage().unwrap();
        assert_eq!(
            (
                token_usage.input_tokens,
                token_usage.output_tokens,
                token_usage.total_tokens
            ),
            (75, 75, 150)
        );
    }
}