1use crate::{
2 http_client::{
3 HttpClientExt,
4 sse::{Event, GenericEventSource},
5 },
6 telemetry::SpanCombinator,
7};
8use async_stream::stream;
9use futures::StreamExt;
10use serde::{Deserialize, Serialize};
11use tracing::info_span;
12
13use super::completion::{
14 CompletionModel, create_request_body,
15 gemini_api_types::{ContentCandidate, Part, PartKind},
16};
17use crate::{
18 completion::{CompletionError, CompletionRequest, GetTokenUsage},
19 streaming::{self},
20};
21
/// Incremental token-usage metadata (`usageMetadata`) attached to Gemini
/// `streamGenerateContent` chunks.
///
/// Only the prompt and total counts are always present; the remaining
/// counters are `Option` because they may be absent on a given chunk.
#[derive(Debug, Deserialize, Serialize, Default, Clone)]
#[serde(rename_all = "camelCase")]
pub struct PartialUsage {
    /// Total token count reported for the request (prompt + generated).
    pub total_token_count: i32,
    /// Tokens served from cached content, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_content_token_count: Option<i32>,
    /// Tokens in the generated candidate(s), when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidates_token_count: Option<i32>,
    /// Tokens spent on model "thought" output, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thoughts_token_count: Option<i32>,
    /// Tokens in the prompt.
    pub prompt_token_count: i32,
}
34
35impl GetTokenUsage for PartialUsage {
36 fn token_usage(&self) -> Option<crate::completion::Usage> {
37 let mut usage = crate::completion::Usage::new();
38
39 usage.input_tokens = self.prompt_token_count as u64;
40 usage.output_tokens = (self.cached_content_token_count.unwrap_or_default()
41 + self.candidates_token_count.unwrap_or_default()
42 + self.thoughts_token_count.unwrap_or_default()) as u64;
43 usage.total_tokens = usage.input_tokens + usage.output_tokens;
44
45 Some(usage)
46 }
47}
48
/// One deserialized SSE chunk of a Gemini `streamGenerateContent` response.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct StreamGenerateContentResponse {
    /// Candidate completions; the streaming loop consumes only the first.
    pub candidates: Vec<ContentCandidate>,
    /// Model version echoed by the API, when present.
    pub model_version: Option<String>,
    /// Usage counters for the stream so far, when present on this chunk.
    pub usage_metadata: Option<PartialUsage>,
}
57
/// Final payload yielded when a Gemini stream finishes (carried inside
/// `RawStreamingChoice::FinalResponse`); holds the usage metadata observed
/// on the finishing chunk, or a zeroed default if none was reported.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct StreamingCompletionResponse {
    /// Token counters taken from the last chunk's `usageMetadata`.
    pub usage_metadata: PartialUsage,
}
62
63impl GetTokenUsage for StreamingCompletionResponse {
64 fn token_usage(&self) -> Option<crate::completion::Usage> {
65 let mut usage = crate::completion::Usage::new();
66 usage.total_tokens = self.usage_metadata.total_token_count as u64;
67 usage.output_tokens = self
68 .usage_metadata
69 .candidates_token_count
70 .map(|x| x as u64)
71 .unwrap_or(0);
72 usage.input_tokens = self.usage_metadata.prompt_token_count as u64;
73 Some(usage)
74 }
75}
76
/// Streaming support for the Gemini `streamGenerateContent` endpoint.
impl<T> CompletionModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
    /// Sends `completion_request` to Gemini's streaming endpoint and adapts
    /// the SSE stream into rig's streaming-completion interface.
    ///
    /// Each SSE message is parsed as a [`StreamGenerateContentResponse`]; the
    /// text, thought, and function-call parts of the *first* candidate are
    /// forwarded as `RawStreamingChoice` items. When a chunk carries a
    /// `finish_reason`, the accumulated outputs and token usage are recorded
    /// on the current tracing span and a `FinalResponse` is yielded.
    ///
    /// # Errors
    /// Returns [`CompletionError`] if the request body cannot be built or
    /// serialized. Transport/SSE failures after that point are yielded as
    /// `Err` items *inside* the stream rather than returned here.
    pub(crate) async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
    {
        // Reuse an already-active span when one exists; otherwise open a
        // fresh gen_ai-style span with response fields filled in lazily.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "gcp.gemini",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = self.model,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = tracing::field::Empty,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };
        let request = create_request_body(completion_request)?;

        span.record_model_input(&request.contents);

        tracing::debug!(
            "Sending completion request to Gemini API {}",
            serde_json::to_string_pretty(&request)?
        );
        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post_sse(&format!(
                "/v1beta/models/{}:streamGenerateContent",
                self.model
            ))
            .header("Content-Type", "application/json")
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        let mut event_source = GenericEventSource::new(self.client.http_client.clone(), req);

        let stream = stream! {
            // Plain text is accumulated here and flushed as a single Part
            // into `model_outputs` when the stream finishes.
            let mut text_response = String::new();
            let mut model_outputs: Vec<Part> = Vec::new();
            while let Some(event_result) = event_source.next().await {
                match event_result {
                    Ok(Event::Open) => {
                        tracing::debug!("SSE connection opened");
                        continue;
                    }
                    Ok(Event::Message(message)) => {
                        // Skip messages with empty data payloads.
                        if message.data.trim().is_empty() {
                            continue;
                        }

                        // A malformed chunk is logged and skipped rather than
                        // aborting the whole stream.
                        let data = match serde_json::from_str::<StreamGenerateContentResponse>(&message.data) {
                            Ok(d) => d,
                            Err(error) => {
                                tracing::error!(?error, message = message.data, "Failed to parse SSE message");
                                continue;
                            }
                        };

                        // Only the first candidate is consumed; chunks without
                        // candidates or content are skipped.
                        let Some(choice) = data.candidates.first() else {
                            tracing::debug!("There is no content candidate");
                            continue;
                        };

                        let Some(content) = choice.content.as_ref() else {
                            tracing::debug!(finish_reason = ?choice.finish_reason, "Streaming candidate missing content");
                            continue;
                        };

                        for part in &content.parts {
                            match part {
                                // Pattern order matters: thought-marked text
                                // must be matched before plain text.
                                Part {
                                    part: PartKind::Text(text),
                                    thought: Some(true),
                                    ..
                                } => {
                                    yield Ok(streaming::RawStreamingChoice::Reasoning { reasoning: text.clone(), id: None, signature: None });
                                },
                                Part {
                                    part: PartKind::Text(text),
                                    ..
                                } => {
                                    text_response.push_str(text);
                                    yield Ok(streaming::RawStreamingChoice::Message(text.clone()));
                                },
                                Part {
                                    part: PartKind::FunctionCall(function_call),
                                    ..
                                } => {
                                    // The function name is reused as the tool-call
                                    // id (no separate id is read from the part).
                                    model_outputs.push(part.clone());
                                    yield Ok(streaming::RawStreamingChoice::ToolCall {
                                        name: function_call.name.clone(),
                                        id: function_call.name.clone(),
                                        arguments: function_call.args.clone(),
                                        call_id: None
                                    });
                                },
                                part => {
                                    tracing::warn!(?part, "Unsupported response type with streaming");
                                }
                            }
                        }

                        if content.parts.is_empty() {
                            tracing::trace!(reason = ?choice.finish_reason, "There is no part in the streaming content");
                        }

                        // A finish_reason marks the last meaningful chunk:
                        // flush accumulated text, record telemetry, and yield
                        // the final response.
                        if choice.finish_reason.is_some() {
                            if !text_response.is_empty() {
                                model_outputs.push(Part { thought: None, thought_signature: None, part: PartKind::Text(text_response), additional_params: None });
                            }
                            // NOTE(review): this records onto Span::current()
                            // inside the stream body, not the `span` built
                            // above — confirm the consumer instruments the
                            // stream with that span, otherwise these fields
                            // land on a different (or disabled) span.
                            let span = tracing::Span::current();
                            span.record_model_output(&model_outputs);
                            span.record_token_usage(&data.usage_metadata);
                            yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
                                usage_metadata: data.usage_metadata.unwrap_or_default()
                            }));
                            break;
                        }
                    }
                    // Normal end-of-stream from the transport.
                    Err(crate::http_client::Error::StreamEnded) => {
                        break;
                    }
                    Err(error) => {
                        tracing::error!(?error, "SSE error");
                        yield Err(CompletionError::ResponseError(error.to_string()));
                        break;
                    }
                }
            }

            event_source.close();
        };

        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
            stream,
        )))
    }
}
232
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // A chunk with a single plain text part deserializes into one Text part.
    #[test]
    fn test_deserialize_stream_response_with_single_text_part() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "Hello, world!"}
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 5,
                "totalTokenCount": 15
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        assert_eq!(response.candidates.len(), 1);
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 1);

        if let Part {
            part: PartKind::Text(text),
            ..
        } = &content.parts[0]
        {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected text part");
        }
    }

    // Multiple text parts in one chunk keep their order and content.
    #[test]
    fn test_deserialize_stream_response_with_multiple_text_parts() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "Hello, "},
                        {"text": "world!"},
                        {"text": " How are you?"}
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 8,
                "totalTokenCount": 18
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        assert_eq!(response.candidates.len(), 1);
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 3);

        for (i, expected_text) in ["Hello, ", "world!", " How are you?"].iter().enumerate() {
            if let Part {
                part: PartKind::Text(text),
                ..
            } = &content.parts[i]
            {
                assert_eq!(text, expected_text);
            } else {
                panic!("Expected text part at index {}", i);
            }
        }
    }

    // Two functionCall parts deserialize into FunctionCall parts in order.
    #[test]
    fn test_deserialize_stream_response_with_multiple_tool_calls() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {
                            "functionCall": {
                                "name": "get_weather",
                                "args": {"city": "San Francisco"}
                            }
                        },
                        {
                            "functionCall": {
                                "name": "get_temperature",
                                "args": {"location": "New York"}
                            }
                        }
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 50,
                "candidatesTokenCount": 20,
                "totalTokenCount": 70
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 2);

        if let Part {
            part: PartKind::FunctionCall(call),
            ..
        } = &content.parts[0]
        {
            assert_eq!(call.name, "get_weather");
        } else {
            panic!("Expected function call at index 0");
        }

        if let Part {
            part: PartKind::FunctionCall(call),
            ..
        } = &content.parts[1]
        {
            assert_eq!(call.name, "get_temperature");
        } else {
            panic!("Expected function call at index 1");
        }
    }

    // A chunk mixing thought text, plain text, and a function call keeps the
    // `thought` flag distinct from plain text parts.
    #[test]
    fn test_deserialize_stream_response_with_mixed_parts() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {
                            "text": "Let me think about this...",
                            "thought": true
                        },
                        {
                            "text": "Here's my response: "
                        },
                        {
                            "functionCall": {
                                "name": "search",
                                "args": {"query": "rust async"}
                            }
                        },
                        {
                            "text": "I found the answer!"
                        }
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 100,
                "candidatesTokenCount": 50,
                "thoughtsTokenCount": 15,
                "totalTokenCount": 165
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        let parts = &content.parts;
        assert_eq!(parts.len(), 4);

        if let Part {
            part: PartKind::Text(text),
            thought: Some(true),
            ..
        } = &parts[0]
        {
            assert_eq!(text, "Let me think about this...");
        } else {
            panic!("Expected thought part at index 0");
        }

        if let Part {
            part: PartKind::Text(text),
            thought,
            ..
        } = &parts[1]
        {
            assert_eq!(text, "Here's my response: ");
            assert!(thought.is_none() || thought == &Some(false));
        } else {
            panic!("Expected text part at index 1");
        }

        if let Part {
            part: PartKind::FunctionCall(call),
            ..
        } = &parts[2]
        {
            assert_eq!(call.name, "search");
        } else {
            panic!("Expected function call at index 2");
        }

        if let Part {
            part: PartKind::Text(text),
            ..
        } = &parts[3]
        {
            assert_eq!(text, "I found the answer!");
        } else {
            panic!("Expected text part at index 3");
        }
    }

    // An empty `parts` array is valid and deserializes to zero parts.
    #[test]
    fn test_deserialize_stream_response_with_empty_parts() {
        let json_data = json!({
            "candidates": [{
                "content": {
                    "parts": [],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 0,
                "totalTokenCount": 10
            }
        });

        let response: StreamGenerateContentResponse = serde_json::from_value(json_data).unwrap();
        let content = response.candidates[0]
            .content
            .as_ref()
            .expect("candidate should contain content");
        assert_eq!(content.parts.len(), 0);
    }

    // PartialUsage::token_usage sums cached + candidates + thoughts as output
    // and recomputes the total from input + output.
    #[test]
    fn test_partial_usage_token_calculation() {
        let usage = PartialUsage {
            total_token_count: 100,
            cached_content_token_count: Some(20),
            candidates_token_count: Some(30),
            thoughts_token_count: Some(10),
            prompt_token_count: 40,
        };

        let token_usage = usage.token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 40);
        assert_eq!(token_usage.output_tokens, 60);
        assert_eq!(token_usage.total_tokens, 100);
    }

    // Missing optional counters are treated as zero.
    #[test]
    fn test_partial_usage_with_missing_counts() {
        let usage = PartialUsage {
            total_token_count: 50,
            cached_content_token_count: None,
            candidates_token_count: Some(30),
            thoughts_token_count: None,
            prompt_token_count: 20,
        };

        let token_usage = usage.token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 20);
        assert_eq!(token_usage.output_tokens, 30);
        assert_eq!(token_usage.total_tokens, 50);
    }

    // StreamingCompletionResponse::token_usage takes the total directly from
    // total_token_count and only candidates_token_count as output.
    #[test]
    fn test_streaming_completion_response_token_usage() {
        let response = StreamingCompletionResponse {
            usage_metadata: PartialUsage {
                total_token_count: 150,
                cached_content_token_count: None,
                candidates_token_count: Some(75),
                thoughts_token_count: None,
                prompt_token_count: 75,
            },
        };

        let token_usage = response.token_usage().unwrap();
        assert_eq!(token_usage.input_tokens, 75);
        assert_eq!(token_usage.output_tokens, 75);
        assert_eq!(token_usage.total_tokens, 150);
    }
}