bitrouter_google/generate_content/
provider.rs1use std::collections::HashMap;
2
3use bitrouter_core::{
4 errors::{BitrouterError, Result},
5 models::{
6 language::{
7 call_options::LanguageModelCallOptions,
8 generate_result::LanguageModelGenerateResult,
9 language_model::LanguageModel,
10 stream_result::{
11 LanguageModelStreamResult, LanguageModelStreamResultRequest,
12 LanguageModelStreamResultResponse,
13 },
14 },
15 shared::types::JsonValue,
16 },
17};
18use regex::Regex;
19use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue};
20use tokio::{select, sync::mpsc};
21use tokio_stream::{StreamExt, wrappers::ReceiverStream};
22use tokio_util::sync::CancellationToken;
23
24use super::api::{ByteStream, drive_sse_stream, parse_google_error};
25use super::types::{
26 GOOGLE_PROVIDER_NAME, GoogleGenerateContentRequest, GoogleGenerateContentResponse,
27};
28
29const GOOGLE_DEFAULT_BASE_URL: &str = "https://generativelanguage.googleapis.com";
30
31#[derive(Debug, Clone)]
32pub struct GoogleConfig {
33 pub api_key: String,
34 pub base_url: String,
35 pub default_headers: HeaderMap,
36}
37
38impl GoogleConfig {
39 pub fn new(api_key: impl Into<String>) -> Self {
40 Self {
41 api_key: api_key.into(),
42 base_url: GOOGLE_DEFAULT_BASE_URL.to_owned(),
43 default_headers: HeaderMap::new(),
44 }
45 }
46}
47
48#[derive(Clone)]
49pub struct GoogleGenerativeAiModel {
50 model_id: String,
51 client: reqwest::Client,
52 config: GoogleConfig,
53 supported_urls: HashMap<String, Regex>,
54}
55
56impl GoogleGenerativeAiModel {
57 pub fn new(model_id: impl Into<String>, api_key: impl Into<String>) -> Self {
58 Self::with_client(model_id, reqwest::Client::new(), GoogleConfig::new(api_key))
59 }
60
61 pub fn with_client(
62 model_id: impl Into<String>,
63 client: reqwest::Client,
64 config: GoogleConfig,
65 ) -> Self {
66 Self {
67 model_id: model_id.into(),
68 client,
69 config,
70 supported_urls: HashMap::new(),
71 }
72 }
73
74 async fn generate_impl(
75 &self,
76 options: LanguageModelCallOptions,
77 ) -> Result<LanguageModelGenerateResult> {
78 let request = GoogleGenerateContentRequest::from_call_options(&self.model_id, &options)?;
79 let request_body = serde_json::to_value(&request).map_err(|error| {
80 BitrouterError::invalid_request(
81 Some(GOOGLE_PROVIDER_NAME),
82 format!("failed to serialize generateContent request: {error}"),
83 None,
84 )
85 })?;
86 let (builder, request_headers) =
87 self.request_builder(&request_body, &options.headers, false)?;
88 let response = self
89 .send_request(builder, options.abort_signal.clone(), "generateContent")
90 .await?;
91
92 let response_headers = response.headers().clone();
93 if !response.status().is_success() {
94 return Err(self.decode_error_response(response).await);
95 }
96
97 let response_body: JsonValue = self
98 .await_with_cancellation(
99 options.abort_signal.clone(),
100 response.json(),
101 |error| {
102 BitrouterError::response_decode(
103 Some(GOOGLE_PROVIDER_NAME),
104 format!("failed to decode generateContent response body: {error}"),
105 None,
106 )
107 },
108 || {
109 BitrouterError::cancelled(
110 Some(GOOGLE_PROVIDER_NAME),
111 "generateContent response decoding was cancelled",
112 )
113 },
114 )
115 .await?;
116 let gen_response: GoogleGenerateContentResponse =
117 serde_json::from_value(response_body.clone()).map_err(|error| {
118 BitrouterError::response_decode(
119 Some(GOOGLE_PROVIDER_NAME),
120 format!("failed to parse generateContent response: {error}"),
121 Some(response_body.clone()),
122 )
123 })?;
124
125 gen_response.into_generate_result(
126 Some(request_headers),
127 request_body,
128 Some(response_headers),
129 response_body,
130 )
131 }
132
133 async fn stream_impl(
134 &self,
135 options: LanguageModelCallOptions,
136 ) -> Result<LanguageModelStreamResult> {
137 let request = GoogleGenerateContentRequest::from_call_options(&self.model_id, &options)?;
138 let request_body = serde_json::to_value(&request).map_err(|error| {
139 BitrouterError::invalid_request(
140 Some(GOOGLE_PROVIDER_NAME),
141 format!("failed to serialize streaming generateContent request: {error}"),
142 None,
143 )
144 })?;
145 let (builder, request_headers) =
146 self.request_builder(&request_body, &options.headers, true)?;
147 let response = self
148 .send_request(
149 builder,
150 options.abort_signal.clone(),
151 "streaming generateContent",
152 )
153 .await?;
154 let response_headers = response.headers().clone();
155 if !response.status().is_success() {
156 return Err(self.decode_error_response(response).await);
157 }
158
159 let include_raw_chunks = options.include_raw_chunks.unwrap_or(false);
160 let abort_signal = options.abort_signal.clone();
161 let bytes_stream: ByteStream = Box::pin(
162 response
163 .bytes_stream()
164 .map(|r| r.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)),
165 );
166 let (sender, receiver) = mpsc::channel(32);
167 tokio::spawn(drive_sse_stream(
168 bytes_stream,
169 abort_signal,
170 sender,
171 include_raw_chunks,
172 ));
173 let stream = Box::pin(ReceiverStream::new(receiver));
174
175 Ok(LanguageModelStreamResult {
176 stream,
177 request: Some(LanguageModelStreamResultRequest {
178 headers: Some(request_headers),
179 body: Some(request_body),
180 }),
181 response: Some(LanguageModelStreamResultResponse {
182 headers: Some(response_headers),
183 }),
184 })
185 }
186
187 fn request_builder(
188 &self,
189 request_body: &JsonValue,
190 extra_headers: &Option<HeaderMap>,
191 stream: bool,
192 ) -> Result<(reqwest::RequestBuilder, HeaderMap)> {
193 let action = if stream {
194 "streamGenerateContent?alt=sse"
195 } else {
196 "generateContent"
197 };
198 let endpoint = format!(
199 "{}/v1beta/models/{}:{}",
200 self.config.base_url.trim_end_matches('/'),
201 self.model_id,
202 action,
203 );
204 let headers = self.build_headers(extra_headers)?;
205 let request_headers = headers.clone();
206 let builder = self
207 .client
208 .post(endpoint)
209 .query(&[("key", &self.config.api_key)])
210 .headers(headers)
211 .json(request_body);
212
213 Ok((builder, request_headers))
214 }
215
216 fn build_headers(&self, extra_headers: &Option<HeaderMap>) -> Result<HeaderMap> {
217 let mut headers = HeaderMap::new();
218 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
219
220 for (name, value) in &self.config.default_headers {
221 headers.insert(name, value.clone());
222 }
223
224 if let Some(extra_headers) = extra_headers {
225 for (name, value) in extra_headers {
226 headers.insert(name, value.clone());
227 }
228 }
229
230 Ok(headers)
231 }
232
233 async fn decode_error_response(&self, response: reqwest::Response) -> BitrouterError {
234 let status = response.status();
235 let request_id = response
236 .headers()
237 .get("x-request-id")
238 .and_then(|value| value.to_str().ok())
239 .map(str::to_owned);
240 let body = match response.text().await {
241 Ok(text) if text.trim().is_empty() => None,
242 Ok(text) => serde_json::from_str::<JsonValue>(&text)
243 .ok()
244 .or(Some(JsonValue::String(text))),
245 Err(_) => None,
246 };
247
248 parse_google_error(status.as_u16(), request_id, body)
249 }
250
251 async fn send_request(
252 &self,
253 builder: reqwest::RequestBuilder,
254 abort_signal: Option<CancellationToken>,
255 operation: &str,
256 ) -> Result<reqwest::Response> {
257 self.await_with_cancellation(
258 abort_signal,
259 builder.send(),
260 |error| {
261 BitrouterError::transport(
262 Some(GOOGLE_PROVIDER_NAME),
263 format!("failed to send {operation} request: {error}"),
264 )
265 },
266 || {
267 BitrouterError::cancelled(
268 Some(GOOGLE_PROVIDER_NAME),
269 format!("{operation} request was cancelled"),
270 )
271 },
272 )
273 .await
274 }
275
276 async fn await_with_cancellation<F, T, E, M, C>(
277 &self,
278 abort_signal: Option<CancellationToken>,
279 future: F,
280 map_error: M,
281 cancelled: C,
282 ) -> Result<T>
283 where
284 F: std::future::Future<Output = std::result::Result<T, E>>,
285 M: FnOnce(E) -> BitrouterError,
286 C: FnOnce() -> BitrouterError,
287 {
288 if let Some(token) = abort_signal {
289 select! {
290 _ = token.cancelled() => Err(cancelled()),
291 result = future => result.map_err(map_error),
292 }
293 } else {
294 future.await.map_err(map_error)
295 }
296 }
297}
298
299impl LanguageModel for GoogleGenerativeAiModel {
300 fn provider_name(&self) -> &str {
301 GOOGLE_PROVIDER_NAME
302 }
303
304 fn model_id(&self) -> &str {
305 &self.model_id
306 }
307
308 fn supported_urls(&self) -> impl Future<Output = HashMap<String, Regex>> {
309 let supported_urls = self.supported_urls.clone();
310 async move { supported_urls }
311 }
312
313 async fn generate(
314 &self,
315 options: LanguageModelCallOptions,
316 ) -> Result<LanguageModelGenerateResult> {
317 self.generate_impl(options).await
318 }
319
320 async fn stream(&self, options: LanguageModelCallOptions) -> Result<LanguageModelStreamResult> {
321 self.stream_impl(options).await
322 }
323}