bitrouter_google/generate_content/
provider.rs1use std::collections::HashMap;
2
3use bitrouter_core::{
4 errors::{BitrouterError, Result},
5 models::{
6 language::{
7 call_options::LanguageModelCallOptions,
8 generate_result::LanguageModelGenerateResult,
9 language_model::LanguageModel,
10 stream_result::{
11 LanguageModelStreamResult, LanguageModelStreamResultRequest,
12 LanguageModelStreamResultResponse,
13 },
14 },
15 shared::types::JsonValue,
16 },
17};
18use regex::Regex;
19use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue};
20use tokio::{select, sync::mpsc};
21use tokio_stream::{StreamExt, wrappers::ReceiverStream};
22use tokio_util::sync::CancellationToken;
23
24use super::api::{ByteStream, drive_sse_stream, parse_google_error};
25use super::types::{
26 GOOGLE_PROVIDER_NAME, GoogleGenerateContentRequest, GoogleGenerateContentResponse,
27};
28
29const GOOGLE_DEFAULT_BASE_URL: &str = "https://generativelanguage.googleapis.com";
30
31#[derive(Debug, Clone)]
32pub struct GoogleConfig {
33 pub api_key: String,
34 pub base_url: String,
35 pub default_headers: HeaderMap,
36}
37
38impl GoogleConfig {
39 pub fn new(api_key: impl Into<String>) -> Self {
40 Self {
41 api_key: api_key.into(),
42 base_url: GOOGLE_DEFAULT_BASE_URL.to_owned(),
43 default_headers: HeaderMap::new(),
44 }
45 }
46}
47
48#[derive(Clone)]
49pub struct GoogleGenerativeAiModel {
50 model_id: String,
51 client: reqwest_middleware::ClientWithMiddleware,
52 config: GoogleConfig,
53 supported_urls: HashMap<String, Regex>,
54}
55
56impl GoogleGenerativeAiModel {
57 pub fn new(model_id: impl Into<String>, api_key: impl Into<String>) -> Self {
58 let client = reqwest_middleware::ClientBuilder::new(reqwest::Client::new()).build();
59 Self::with_client(model_id, client, GoogleConfig::new(api_key))
60 }
61
62 pub fn with_client(
63 model_id: impl Into<String>,
64 client: reqwest_middleware::ClientWithMiddleware,
65 config: GoogleConfig,
66 ) -> Self {
67 Self {
68 model_id: model_id.into(),
69 client,
70 config,
71 supported_urls: HashMap::new(),
72 }
73 }
74
75 async fn generate_impl(
76 &self,
77 options: LanguageModelCallOptions,
78 ) -> Result<LanguageModelGenerateResult> {
79 let request = GoogleGenerateContentRequest::from_call_options(&self.model_id, &options)?;
80 let request_body = serde_json::to_value(&request).map_err(|error| {
81 BitrouterError::invalid_request(
82 Some(GOOGLE_PROVIDER_NAME),
83 format!("failed to serialize generateContent request: {error}"),
84 None,
85 )
86 })?;
87 let (builder, request_headers) =
88 self.request_builder(&request_body, &options.headers, false)?;
89 let response = self
90 .send_request(builder, options.abort_signal.clone(), "generateContent")
91 .await?;
92
93 let response_headers = response.headers().clone();
94 if !response.status().is_success() {
95 return Err(self.decode_error_response(response).await);
96 }
97
98 let response_body: JsonValue = self
99 .await_with_cancellation(
100 options.abort_signal.clone(),
101 response.json(),
102 |error| {
103 BitrouterError::response_decode(
104 Some(GOOGLE_PROVIDER_NAME),
105 format!("failed to decode generateContent response body: {error}"),
106 None,
107 )
108 },
109 || {
110 BitrouterError::cancelled(
111 Some(GOOGLE_PROVIDER_NAME),
112 "generateContent response decoding was cancelled",
113 )
114 },
115 )
116 .await?;
117 let gen_response: GoogleGenerateContentResponse =
118 serde_json::from_value(response_body.clone()).map_err(|error| {
119 BitrouterError::response_decode(
120 Some(GOOGLE_PROVIDER_NAME),
121 format!("failed to parse generateContent response: {error}"),
122 Some(response_body.clone()),
123 )
124 })?;
125
126 gen_response.into_generate_result(
127 Some(request_headers),
128 request_body,
129 Some(response_headers),
130 response_body,
131 )
132 }
133
134 async fn stream_impl(
135 &self,
136 options: LanguageModelCallOptions,
137 ) -> Result<LanguageModelStreamResult> {
138 let request = GoogleGenerateContentRequest::from_call_options(&self.model_id, &options)?;
139 let request_body = serde_json::to_value(&request).map_err(|error| {
140 BitrouterError::invalid_request(
141 Some(GOOGLE_PROVIDER_NAME),
142 format!("failed to serialize streaming generateContent request: {error}"),
143 None,
144 )
145 })?;
146 let (builder, request_headers) =
147 self.request_builder(&request_body, &options.headers, true)?;
148 let response = self
149 .send_request(
150 builder,
151 options.abort_signal.clone(),
152 "streaming generateContent",
153 )
154 .await?;
155 let response_headers = response.headers().clone();
156 if !response.status().is_success() {
157 return Err(self.decode_error_response(response).await);
158 }
159
160 let include_raw_chunks = options.include_raw_chunks.unwrap_or(false);
161 let abort_signal = options.abort_signal.clone();
162 let bytes_stream: ByteStream = Box::pin(
163 response
164 .bytes_stream()
165 .map(|r| r.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)),
166 );
167 let (sender, receiver) = mpsc::channel(32);
168 tokio::spawn(drive_sse_stream(
169 bytes_stream,
170 abort_signal,
171 sender,
172 include_raw_chunks,
173 ));
174 let stream = Box::pin(ReceiverStream::new(receiver));
175
176 Ok(LanguageModelStreamResult {
177 stream,
178 request: Some(LanguageModelStreamResultRequest {
179 headers: Some(request_headers),
180 body: Some(request_body),
181 }),
182 response: Some(LanguageModelStreamResultResponse {
183 headers: Some(response_headers),
184 }),
185 })
186 }
187
188 fn request_builder(
189 &self,
190 request_body: &JsonValue,
191 extra_headers: &Option<HeaderMap>,
192 stream: bool,
193 ) -> Result<(reqwest_middleware::RequestBuilder, HeaderMap)> {
194 let action = if stream {
195 "streamGenerateContent?alt=sse"
196 } else {
197 "generateContent"
198 };
199 let endpoint = format!(
200 "{}/v1beta/models/{}:{}",
201 self.config.base_url.trim_end_matches('/'),
202 self.model_id,
203 action,
204 );
205 let headers = self.build_headers(extra_headers)?;
206 let request_headers = headers.clone();
207 let builder = self
208 .client
209 .post(endpoint)
210 .query(&[("key", &self.config.api_key)])
211 .headers(headers)
212 .json(request_body);
213
214 Ok((builder, request_headers))
215 }
216
217 fn build_headers(&self, extra_headers: &Option<HeaderMap>) -> Result<HeaderMap> {
218 let mut headers = HeaderMap::new();
219 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
220
221 for (name, value) in &self.config.default_headers {
222 headers.insert(name, value.clone());
223 }
224
225 if let Some(extra_headers) = extra_headers {
226 for (name, value) in extra_headers {
227 headers.insert(name, value.clone());
228 }
229 }
230
231 Ok(headers)
232 }
233
234 async fn decode_error_response(&self, response: reqwest::Response) -> BitrouterError {
235 let status = response.status();
236 let request_id = response
237 .headers()
238 .get("x-request-id")
239 .and_then(|value| value.to_str().ok())
240 .map(str::to_owned);
241 let body = match response.text().await {
242 Ok(text) if text.trim().is_empty() => None,
243 Ok(text) => serde_json::from_str::<JsonValue>(&text)
244 .ok()
245 .or(Some(JsonValue::String(text))),
246 Err(_) => None,
247 };
248
249 parse_google_error(status.as_u16(), request_id, body)
250 }
251
252 async fn send_request(
253 &self,
254 builder: reqwest_middleware::RequestBuilder,
255 abort_signal: Option<CancellationToken>,
256 operation: &str,
257 ) -> Result<reqwest::Response> {
258 self.await_with_cancellation(
259 abort_signal,
260 builder.send(),
261 |error| {
262 BitrouterError::transport(
263 Some(GOOGLE_PROVIDER_NAME),
264 format!("failed to send {operation} request: {error}"),
265 )
266 },
267 || {
268 BitrouterError::cancelled(
269 Some(GOOGLE_PROVIDER_NAME),
270 format!("{operation} request was cancelled"),
271 )
272 },
273 )
274 .await
275 }
276
277 async fn await_with_cancellation<F, T, E, M, C>(
278 &self,
279 abort_signal: Option<CancellationToken>,
280 future: F,
281 map_error: M,
282 cancelled: C,
283 ) -> Result<T>
284 where
285 F: std::future::Future<Output = std::result::Result<T, E>>,
286 M: FnOnce(E) -> BitrouterError,
287 C: FnOnce() -> BitrouterError,
288 {
289 if let Some(token) = abort_signal {
290 select! {
291 _ = token.cancelled() => Err(cancelled()),
292 result = future => result.map_err(map_error),
293 }
294 } else {
295 future.await.map_err(map_error)
296 }
297 }
298}
299
300impl LanguageModel for GoogleGenerativeAiModel {
301 fn provider_name(&self) -> &str {
302 GOOGLE_PROVIDER_NAME
303 }
304
305 fn model_id(&self) -> &str {
306 &self.model_id
307 }
308
309 fn supported_urls(&self) -> impl Future<Output = HashMap<String, Regex>> {
310 let supported_urls = self.supported_urls.clone();
311 async move { supported_urls }
312 }
313
314 async fn generate(
315 &self,
316 options: LanguageModelCallOptions,
317 ) -> Result<LanguageModelGenerateResult> {
318 self.generate_impl(options).await
319 }
320
321 async fn stream(&self, options: LanguageModelCallOptions) -> Result<LanguageModelStreamResult> {
322 self.stream_impl(options).await
323 }
324}