rig/providers/gemini/
streaming.rs1use crate::telemetry::SpanCombinator;
2use async_stream::stream;
3use futures::StreamExt;
4use reqwest_eventsource::{Event, RequestBuilderExt};
5use serde::{Deserialize, Serialize};
6use tracing::info_span;
7
8use super::completion::{
9 CompletionModel, create_request_body,
10 gemini_api_types::{ContentCandidate, Part, PartKind},
11};
12use crate::{
13 completion::{CompletionError, CompletionRequest, GetTokenUsage},
14 streaming::{self},
15};
16
17#[derive(Debug, Deserialize, Serialize, Default, Clone)]
18#[serde(rename_all = "camelCase")]
19pub struct PartialUsage {
20 pub total_token_count: i32,
21 #[serde(skip_serializing_if = "Option::is_none")]
22 pub cached_content_token_count: Option<i32>,
23 #[serde(skip_serializing_if = "Option::is_none")]
24 pub candidates_token_count: Option<i32>,
25 #[serde(skip_serializing_if = "Option::is_none")]
26 pub thoughts_token_count: Option<i32>,
27 pub prompt_token_count: i32,
28}
29
30impl GetTokenUsage for PartialUsage {
31 fn token_usage(&self) -> Option<crate::completion::Usage> {
32 let mut usage = crate::completion::Usage::new();
33
34 usage.input_tokens = self.prompt_token_count as u64;
35 usage.output_tokens = (self.cached_content_token_count.unwrap_or_default()
36 + self.candidates_token_count.unwrap_or_default()
37 + self.thoughts_token_count.unwrap_or_default()) as u64;
38 usage.total_tokens = usage.input_tokens + usage.output_tokens;
39
40 Some(usage)
41 }
42}
43
44#[derive(Debug, Deserialize)]
45#[serde(rename_all = "camelCase")]
46pub struct StreamGenerateContentResponse {
47 pub candidates: Vec<ContentCandidate>,
49 pub model_version: Option<String>,
50 pub usage_metadata: Option<PartialUsage>,
51}
52
53#[derive(Clone, Debug, Serialize, Deserialize)]
54pub struct StreamingCompletionResponse {
55 pub usage_metadata: PartialUsage,
56}
57
58impl GetTokenUsage for StreamingCompletionResponse {
59 fn token_usage(&self) -> Option<crate::completion::Usage> {
60 let mut usage = crate::completion::Usage::new();
61 usage.total_tokens = self.usage_metadata.total_token_count as u64;
62 usage.output_tokens = self
63 .usage_metadata
64 .candidates_token_count
65 .map(|x| x as u64)
66 .unwrap_or(0);
67 usage.input_tokens = self.usage_metadata.prompt_token_count as u64;
68 Some(usage)
69 }
70}
71
72impl CompletionModel<reqwest::Client> {
73 pub(crate) async fn stream(
74 &self,
75 completion_request: CompletionRequest,
76 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
77 {
78 let span = if tracing::Span::current().is_disabled() {
79 info_span!(
80 target: "rig::completions",
81 "chat_streaming",
82 gen_ai.operation.name = "chat_streaming",
83 gen_ai.provider.name = "gcp.gemini",
84 gen_ai.request.model = self.model,
85 gen_ai.system_instructions = &completion_request.preamble,
86 gen_ai.response.id = tracing::field::Empty,
87 gen_ai.response.model = self.model,
88 gen_ai.usage.output_tokens = tracing::field::Empty,
89 gen_ai.usage.input_tokens = tracing::field::Empty,
90 gen_ai.input.messages = tracing::field::Empty,
91 gen_ai.output.messages = tracing::field::Empty,
92 )
93 } else {
94 tracing::Span::current()
95 };
96 let request = create_request_body(completion_request)?;
97
98 span.record_model_input(&request.contents);
99
100 tracing::debug!(
101 "Sending completion request to Gemini API {}",
102 serde_json::to_string_pretty(&request)?
103 );
104
105 let mut event_source = self
107 .client
108 .post_sse(&format!(
109 "/v1beta/models/{}:streamGenerateContent",
110 self.model
111 ))
112 .json(&request)
113 .eventsource()
114 .expect("Cloning request must always succeed");
115
116 let stream = stream! {
117 let mut text_response = String::new();
118 let mut model_outputs: Vec<Part> = Vec::new();
119 while let Some(event_result) = event_source.next().await {
120 match event_result {
121 Ok(Event::Open) => {
122 tracing::trace!("SSE connection opened");
123 continue;
124 }
125 Ok(Event::Message(message)) => {
126 if message.data.trim().is_empty() {
128 continue;
129 }
130
131 let data = match serde_json::from_str::<StreamGenerateContentResponse>(&message.data) {
132 Ok(d) => d,
133 Err(error) => {
134 tracing::error!(?error, message = message.data, "Failed to parse SSE message");
135 continue;
136 }
137 };
138
139 let Some(choice) = data.candidates.first() else {
141 tracing::debug!("There is no content candidate");
142 continue;
143 };
144
145 match choice.content.parts.first() {
146 Some(Part {
147 part: PartKind::Text(text),
148 thought: Some(true),
149 ..
150 }) => {
151 yield Ok(streaming::RawStreamingChoice::Reasoning { reasoning: text.clone(), id: None });
152 },
153 Some(Part {
154 part: PartKind::Text(text),
155 ..
156 }) => {
157 text_response += text;
158 yield Ok(streaming::RawStreamingChoice::Message(text.clone()));
159 },
160 Some(Part {
161 part: PartKind::FunctionCall(function_call),
162 ..
163 }) => {
164 model_outputs.push(choice.content.parts.first().cloned().expect("This should never fail"));
165 yield Ok(streaming::RawStreamingChoice::ToolCall {
166 name: function_call.name.clone(),
167 id: function_call.name.clone(),
168 arguments: function_call.args.clone(),
169 call_id: None
170 });
171 },
172 Some(part) => {
173 tracing::warn!(?part, "Unsupported response type with streaming");
174 }
175 None => tracing::trace!(reason = ?choice.finish_reason, "There is no part in the streaming content"),
176 }
177
178 if choice.finish_reason.is_some() {
180 if !text_response.is_empty() {
181 model_outputs.push(Part { thought: None, thought_signature: None, part: PartKind::Text(text_response), additional_params: None });
182 }
183 let span = tracing::Span::current();
184 span.record_model_output(&model_outputs);
185 span.record_token_usage(&data.usage_metadata);
186 yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
187 usage_metadata: data.usage_metadata.unwrap_or_default()
188 }));
189 break;
190 }
191 }
192 Err(reqwest_eventsource::Error::StreamEnded) => {
193 break;
194 }
195 Err(error) => {
196 tracing::error!(?error, "SSE error");
197 yield Err(CompletionError::ResponseError(error.to_string()));
198 break;
199 }
200 }
201 }
202
203 event_source.close();
205 };
206
207 Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
208 stream,
209 )))
210 }
211}