1use crate::{
2 http_client::{
3 HttpClientExt,
4 sse::{Event, GenericEventSource},
5 },
6 telemetry::SpanCombinator,
7};
8use async_stream::stream;
9use futures::StreamExt;
10use serde::{Deserialize, Serialize};
11use tracing::info_span;
12
13use super::completion::{
14 CompletionModel, create_request_body,
15 gemini_api_types::{ContentCandidate, Part, PartKind},
16};
17use crate::{
18 completion::{CompletionError, CompletionRequest, GetTokenUsage},
19 streaming::{self},
20};
21
/// Token-usage metadata as delivered on Gemini streaming chunks
/// (`usageMetadata`). Counts are the raw `i32` values reported by the
/// provider; optional fields are omitted from serialized output when absent.
#[derive(Debug, Deserialize, Serialize, Default, Clone)]
#[serde(rename_all = "camelCase")]
pub struct PartialUsage {
    // Provider-reported total across prompt and response.
    pub total_token_count: i32,
    // Tokens served from cached content, when the provider reports them.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_content_token_count: Option<i32>,
    // Tokens in the generated candidates, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidates_token_count: Option<i32>,
    // Tokens spent on model "thinking", when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thoughts_token_count: Option<i32>,
    // Tokens in the prompt.
    pub prompt_token_count: i32,
}
34
35impl GetTokenUsage for PartialUsage {
36 fn token_usage(&self) -> Option<crate::completion::Usage> {
37 let mut usage = crate::completion::Usage::new();
38
39 usage.input_tokens = self.prompt_token_count as u64;
40 usage.output_tokens = (self.cached_content_token_count.unwrap_or_default()
41 + self.candidates_token_count.unwrap_or_default()
42 + self.thoughts_token_count.unwrap_or_default()) as u64;
43 usage.total_tokens = usage.input_tokens + usage.output_tokens;
44
45 Some(usage)
46 }
47}
48
/// One decoded SSE chunk from Gemini's `streamGenerateContent` endpoint.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct StreamGenerateContentResponse {
    // Candidates for this chunk; the streaming loop consumes only the first.
    pub candidates: Vec<ContentCandidate>,
    // Model version echoed by the API, when present.
    pub model_version: Option<String>,
    // Usage metadata; typically populated on the final chunk.
    pub usage_metadata: Option<PartialUsage>,
}
57
/// Final payload yielded once a Gemini stream finishes, carrying the
/// usage metadata from the last chunk (defaulted when the provider sent none).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct StreamingCompletionResponse {
    pub usage_metadata: PartialUsage,
}
62
63impl GetTokenUsage for StreamingCompletionResponse {
64 fn token_usage(&self) -> Option<crate::completion::Usage> {
65 let mut usage = crate::completion::Usage::new();
66 usage.total_tokens = self.usage_metadata.total_token_count as u64;
67 usage.output_tokens = self
68 .usage_metadata
69 .candidates_token_count
70 .map(|x| x as u64)
71 .unwrap_or(0);
72 usage.input_tokens = self.usage_metadata.prompt_token_count as u64;
73 Some(usage)
74 }
75}
76
impl<T> CompletionModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
    /// Streams a completion from Gemini's `streamGenerateContent` endpoint,
    /// translating each SSE chunk into rig's provider-agnostic streaming
    /// events: text messages, reasoning ("thought") parts, tool calls, and a
    /// final usage summary.
    ///
    /// # Errors
    ///
    /// Returns a [`CompletionError`] if the request body cannot be built or
    /// serialized, or if the SSE request cannot be constructed. Errors after
    /// the stream has started are yielded as `Err` items inside the returned
    /// stream instead of failing this function.
    pub(crate) async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
    {
        // Reuse an already-active span if the caller set one up; otherwise
        // open a fresh span following the OpenTelemetry GenAI conventions.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "gcp.gemini",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = self.model,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = tracing::field::Empty,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };
        let request = create_request_body(completion_request)?;

        span.record_model_input(&request.contents);

        tracing::debug!(
            "Sending completion request to Gemini API {}",
            serde_json::to_string_pretty(&request)?
        );
        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post_sse(&format!(
                "/v1beta/models/{}:streamGenerateContent",
                self.model
            ))
            .header("Content-Type", "application/json")
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        let mut event_source = GenericEventSource::new(self.client.http_client.clone(), req);

        let stream = stream! {
            // Accumulates plain text across chunks so the complete message can
            // be recorded as a single model output at end of stream.
            let mut text_response = String::new();
            // Collected outputs (tool calls + final text) for span recording.
            let mut model_outputs: Vec<Part> = Vec::new();
            while let Some(event_result) = event_source.next().await {
                match event_result {
                    Ok(Event::Open) => {
                        tracing::debug!("SSE connection opened");
                        continue;
                    }
                    Ok(Event::Message(message)) => {
                        // Skip keep-alive / empty frames.
                        if message.data.trim().is_empty() {
                            continue;
                        }

                        // A malformed chunk is logged and skipped rather than
                        // aborting the whole stream.
                        let data = match serde_json::from_str::<StreamGenerateContentResponse>(&message.data) {
                            Ok(d) => d,
                            Err(error) => {
                                tracing::error!(?error, message = message.data, "Failed to parse SSE message");
                                continue;
                            }
                        };

                        // Only the first candidate is consumed; any additional
                        // candidates in the chunk are ignored.
                        let Some(choice) = data.candidates.first() else {
                            tracing::debug!("There is no content candidate");
                            continue;
                        };

                        for part in &choice.content.parts {
                            match part {
                                // Text flagged as a "thought" is surfaced as
                                // reasoning, not as message content.
                                Part {
                                    part: PartKind::Text(text),
                                    thought: Some(true),
                                    ..
                                } => {
                                    yield Ok(streaming::RawStreamingChoice::Reasoning { reasoning: text.clone(), id: None, signature: None });
                                },
                                Part {
                                    part: PartKind::Text(text),
                                    ..
                                } => {
                                    text_response += text;
                                    yield Ok(streaming::RawStreamingChoice::Message(text.clone()));
                                },
                                Part {
                                    part: PartKind::FunctionCall(function_call),
                                    ..
                                } => {
                                    model_outputs.push(part.clone());
                                    // Gemini supplies no tool-call id, so the
                                    // function name doubles as the id.
                                    yield Ok(streaming::RawStreamingChoice::ToolCall {
                                        name: function_call.name.clone(),
                                        id: function_call.name.clone(),
                                        arguments: function_call.args.clone(),
                                        call_id: None
                                    });
                                },
                                part => {
                                    tracing::warn!(?part, "Unsupported response type with streaming");
                                }
                            }
                        }

                        if choice.content.parts.is_empty() {
                            tracing::trace!(reason = ?choice.finish_reason, "There is no part in the streaming content");
                        }

                        // A finish reason marks the last chunk: record span
                        // telemetry and emit the final usage response.
                        if choice.finish_reason.is_some() {
                            if !text_response.is_empty() {
                                model_outputs.push(Part { thought: None, thought_signature: None, part: PartKind::Text(text_response), additional_params: None });
                            }
                            // NOTE(review): this records on `Span::current()`,
                            // not the `span` captured before the stream was
                            // built — confirm the stream is polled inside that
                            // span, otherwise these fields are lost.
                            let span = tracing::Span::current();
                            span.record_model_output(&model_outputs);
                            span.record_token_usage(&data.usage_metadata);
                            yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
                                usage_metadata: data.usage_metadata.unwrap_or_default()
                            }));
                            break;
                        }
                    }
                    // Normal termination of the SSE connection.
                    Err(crate::http_client::Error::StreamEnded) => {
                        break;
                    }
                    Err(error) => {
                        tracing::error!(?error, "SSE error");
                        yield Err(CompletionError::ResponseError(error.to_string()));
                        break;
                    }
                }
            }

            event_source.close();
        };

        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
            stream,
        )))
    }
}
227
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Deserialization of a single-chunk response containing one text part.
    #[test]
    fn test_deserialize_stream_response_with_single_text_part() {
        let payload = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "Hello, world!"}
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 5,
                "totalTokenCount": 15
            }
        });

        let parsed: StreamGenerateContentResponse = serde_json::from_value(payload).unwrap();
        assert_eq!(parsed.candidates.len(), 1);
        assert_eq!(parsed.candidates[0].content.parts.len(), 1);

        match &parsed.candidates[0].content.parts[0] {
            Part {
                part: PartKind::Text(text),
                ..
            } => assert_eq!(text, "Hello, world!"),
            _ => panic!("Expected text part"),
        }
    }

    // Multiple text parts must survive deserialization in order.
    #[test]
    fn test_deserialize_stream_response_with_multiple_text_parts() {
        let payload = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "Hello, "},
                        {"text": "world!"},
                        {"text": " How are you?"}
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 8,
                "totalTokenCount": 18
            }
        });

        let parsed: StreamGenerateContentResponse = serde_json::from_value(payload).unwrap();
        assert_eq!(parsed.candidates.len(), 1);
        let parts = &parsed.candidates[0].content.parts;
        assert_eq!(parts.len(), 3);

        let expected = ["Hello, ", "world!", " How are you?"];
        for (i, part) in parts.iter().enumerate() {
            match part {
                Part {
                    part: PartKind::Text(text),
                    ..
                } => assert_eq!(text, expected[i]),
                _ => panic!("Expected text part at index {}", i),
            }
        }
    }

    // Two function-call parts in one chunk, each with its own name/args.
    #[test]
    fn test_deserialize_stream_response_with_multiple_tool_calls() {
        let payload = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {
                            "functionCall": {
                                "name": "get_weather",
                                "args": {"city": "San Francisco"}
                            }
                        },
                        {
                            "functionCall": {
                                "name": "get_temperature",
                                "args": {"location": "New York"}
                            }
                        }
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 50,
                "candidatesTokenCount": 20,
                "totalTokenCount": 70
            }
        });

        let parsed: StreamGenerateContentResponse = serde_json::from_value(payload).unwrap();
        let parts = &parsed.candidates[0].content.parts;
        assert_eq!(parts.len(), 2);

        match &parts[0] {
            Part {
                part: PartKind::FunctionCall(call),
                ..
            } => assert_eq!(call.name, "get_weather"),
            _ => panic!("Expected function call at index 0"),
        }

        match &parts[1] {
            Part {
                part: PartKind::FunctionCall(call),
                ..
            } => assert_eq!(call.name, "get_temperature"),
            _ => panic!("Expected function call at index 1"),
        }
    }

    // A chunk mixing thought text, plain text, and a function call.
    #[test]
    fn test_deserialize_stream_response_with_mixed_parts() {
        let payload = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {
                            "text": "Let me think about this...",
                            "thought": true
                        },
                        {
                            "text": "Here's my response: "
                        },
                        {
                            "functionCall": {
                                "name": "search",
                                "args": {"query": "rust async"}
                            }
                        },
                        {
                            "text": "I found the answer!"
                        }
                    ],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 100,
                "candidatesTokenCount": 50,
                "thoughtsTokenCount": 15,
                "totalTokenCount": 165
            }
        });

        let parsed: StreamGenerateContentResponse = serde_json::from_value(payload).unwrap();
        let parts = &parsed.candidates[0].content.parts;
        assert_eq!(parts.len(), 4);

        match &parts[0] {
            Part {
                part: PartKind::Text(text),
                thought: Some(true),
                ..
            } => assert_eq!(text, "Let me think about this..."),
            _ => panic!("Expected thought part at index 0"),
        }

        match &parts[1] {
            Part {
                part: PartKind::Text(text),
                thought,
                ..
            } => {
                assert_eq!(text, "Here's my response: ");
                assert!(thought.is_none() || thought == &Some(false));
            }
            _ => panic!("Expected text part at index 1"),
        }

        match &parts[2] {
            Part {
                part: PartKind::FunctionCall(call),
                ..
            } => assert_eq!(call.name, "search"),
            _ => panic!("Expected function call at index 2"),
        }

        match &parts[3] {
            Part {
                part: PartKind::Text(text),
                ..
            } => assert_eq!(text, "I found the answer!"),
            _ => panic!("Expected text part at index 3"),
        }
    }

    // An empty `parts` array is valid and must deserialize cleanly.
    #[test]
    fn test_deserialize_stream_response_with_empty_parts() {
        let payload = json!({
            "candidates": [{
                "content": {
                    "parts": [],
                    "role": "model"
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 0,
                "totalTokenCount": 10
            }
        });

        let parsed: StreamGenerateContentResponse = serde_json::from_value(payload).unwrap();
        assert_eq!(parsed.candidates[0].content.parts.len(), 0);
    }

    // Output tokens = cached (20) + candidates (30) + thoughts (10).
    #[test]
    fn test_partial_usage_token_calculation() {
        let usage = PartialUsage {
            total_token_count: 100,
            cached_content_token_count: Some(20),
            candidates_token_count: Some(30),
            thoughts_token_count: Some(10),
            prompt_token_count: 40,
        };

        let computed = usage.token_usage().unwrap();
        assert_eq!(computed.input_tokens, 40);
        assert_eq!(computed.output_tokens, 60);
        assert_eq!(computed.total_tokens, 100);
    }

    // Missing optional counts are treated as zero.
    #[test]
    fn test_partial_usage_with_missing_counts() {
        let usage = PartialUsage {
            total_token_count: 50,
            cached_content_token_count: None,
            candidates_token_count: Some(30),
            thoughts_token_count: None,
            prompt_token_count: 20,
        };

        let computed = usage.token_usage().unwrap();
        assert_eq!(computed.input_tokens, 20);
        assert_eq!(computed.output_tokens, 30);
        assert_eq!(computed.total_tokens, 50);
    }

    // Final-response usage is taken straight from the metadata fields.
    #[test]
    fn test_streaming_completion_response_token_usage() {
        let response = StreamingCompletionResponse {
            usage_metadata: PartialUsage {
                total_token_count: 150,
                cached_content_token_count: None,
                candidates_token_count: Some(75),
                thoughts_token_count: None,
                prompt_token_count: 75,
            },
        };

        let computed = response.token_usage().unwrap();
        assert_eq!(computed.input_tokens, 75);
        assert_eq!(computed.output_tokens, 75);
        assert_eq!(computed.total_tokens, 150);
    }
}