rig/providers/gemini/
streaming.rs1use async_stream::stream;
2use futures::StreamExt;
3use reqwest_eventsource::{Event, RequestBuilderExt};
4use serde::{Deserialize, Serialize};
5
6use super::completion::{
7 CompletionModel, create_request_body,
8 gemini_api_types::{ContentCandidate, Part, PartKind},
9};
10use crate::{
11 completion::{CompletionError, CompletionRequest, GetTokenUsage},
12 streaming::{self},
13};
14
15#[derive(Debug, Deserialize, Serialize, Default, Clone)]
16#[serde(rename_all = "camelCase")]
17pub struct PartialUsage {
18 pub total_token_count: i32,
19 #[serde(skip_serializing_if = "Option::is_none")]
20 pub cached_content_token_count: Option<i32>,
21 #[serde(skip_serializing_if = "Option::is_none")]
22 pub candidates_token_count: Option<i32>,
23 #[serde(skip_serializing_if = "Option::is_none")]
24 pub thoughts_token_count: Option<i32>,
25 pub prompt_token_count: i32,
26}
27
28#[derive(Debug, Deserialize)]
29#[serde(rename_all = "camelCase")]
30pub struct StreamGenerateContentResponse {
31 pub candidates: Vec<ContentCandidate>,
33 pub model_version: Option<String>,
34 pub usage_metadata: Option<PartialUsage>,
35}
36
37#[derive(Clone, Debug, Serialize, Deserialize)]
38pub struct StreamingCompletionResponse {
39 pub usage_metadata: PartialUsage,
40}
41
42impl GetTokenUsage for StreamingCompletionResponse {
43 fn token_usage(&self) -> Option<crate::completion::Usage> {
44 let mut usage = crate::completion::Usage::new();
45 usage.total_tokens = self.usage_metadata.total_token_count as u64;
46 usage.output_tokens = self
47 .usage_metadata
48 .candidates_token_count
49 .map(|x| x as u64)
50 .unwrap_or(0);
51 usage.input_tokens = self.usage_metadata.prompt_token_count as u64;
52 Some(usage)
53 }
54}
55
56impl CompletionModel {
57 pub(crate) async fn stream(
58 &self,
59 completion_request: CompletionRequest,
60 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
61 {
62 let request = create_request_body(completion_request)?;
63
64 tracing::debug!(
65 "Sending completion request to Gemini API {}",
66 serde_json::to_string_pretty(&request)?
67 );
68
69 let mut event_source = self
71 .client
72 .post_sse(&format!(
73 "/v1beta/models/{}:streamGenerateContent",
74 self.model
75 ))
76 .json(&request)
77 .eventsource()
78 .expect("Cloning request must always succeed");
79
80 let stream = Box::pin(stream! {
81 while let Some(event_result) = event_source.next().await {
82 match event_result {
83 Ok(Event::Open) => {
84 tracing::trace!("SSE connection opened");
85 continue;
86 }
87 Ok(Event::Message(message)) => {
88 if message.data.trim().is_empty() {
90 continue;
91 }
92
93 let data = match serde_json::from_str::<StreamGenerateContentResponse>(&message.data) {
94 Ok(d) => d,
95 Err(error) => {
96 tracing::error!(?error, message = message.data, "Failed to parse SSE message");
97 continue;
98 }
99 };
100
101 let Some(choice) = data.candidates.first() else {
103 tracing::debug!("There is no content candidate");
104 continue;
105 };
106
107 match choice.content.parts.first() {
108 Some(Part {
109 part: PartKind::Text(text),
110 thought: Some(true),
111 ..
112 }) => {
113 yield Ok(streaming::RawStreamingChoice::Reasoning { reasoning: text.clone(), id: None });
114 },
115 Some(Part {
116 part: PartKind::Text(text),
117 ..
118 }) => {
119 yield Ok(streaming::RawStreamingChoice::Message(text.clone()));
120 },
121 Some(Part {
122 part: PartKind::FunctionCall(function_call),
123 ..
124 }) => {
125 yield Ok(streaming::RawStreamingChoice::ToolCall {
126 name: function_call.name.clone(),
127 id: function_call.name.clone(),
128 arguments: function_call.args.clone(),
129 call_id: None
130 });
131 },
132 Some(part) => {
133 tracing::warn!(?part, "Unsupported response type with streaming");
134 }
135 None => tracing::trace!(reason = ?choice.finish_reason, "There is no part in the streaming content"),
136 }
137
138 if choice.finish_reason.is_some() {
140 yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
141 usage_metadata: data.usage_metadata.unwrap_or_default()
142 }));
143 break;
144 }
145 }
146 Err(reqwest_eventsource::Error::StreamEnded) => {
147 break;
148 }
149 Err(error) => {
150 tracing::error!(?error, "SSE error");
151 yield Err(CompletionError::ResponseError(error.to_string()));
152 break;
153 }
154 }
155 }
156
157 event_source.close();
159 });
160
161 Ok(streaming::StreamingCompletionResponse::stream(stream))
162 }
163}