Skip to main content

bitrouter_google/generate_content/provider.rs

1use std::collections::HashMap;
2
3use bitrouter_core::{
4    errors::{BitrouterError, Result},
5    models::{
6        language::{
7            call_options::LanguageModelCallOptions,
8            generate_result::LanguageModelGenerateResult,
9            language_model::LanguageModel,
10            stream_result::{
11                LanguageModelStreamResult, LanguageModelStreamResultRequest,
12                LanguageModelStreamResultResponse,
13            },
14        },
15        shared::types::JsonValue,
16    },
17};
18use regex::Regex;
19use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue};
20use tokio::{select, sync::mpsc};
21use tokio_stream::{StreamExt, wrappers::ReceiverStream};
22use tokio_util::sync::CancellationToken;
23
24use super::api::{ByteStream, drive_sse_stream, parse_google_error};
25use super::types::{
26    GOOGLE_PROVIDER_NAME, GoogleGenerateContentRequest, GoogleGenerateContentResponse,
27};
28
/// Origin of the public Google Generative Language API.
///
/// [`GoogleConfig::new`] uses this as the initial `base_url`; callers can
/// point a config elsewhere by overwriting that public field.
const GOOGLE_DEFAULT_BASE_URL: &str = "https://generativelanguage.googleapis.com";
30
31#[derive(Debug, Clone)]
32pub struct GoogleConfig {
33    pub api_key: String,
34    pub base_url: String,
35    pub default_headers: HeaderMap,
36}
37
38impl GoogleConfig {
39    pub fn new(api_key: impl Into<String>) -> Self {
40        Self {
41            api_key: api_key.into(),
42            base_url: GOOGLE_DEFAULT_BASE_URL.to_owned(),
43            default_headers: HeaderMap::new(),
44        }
45    }
46}
47
48#[derive(Clone)]
49pub struct GoogleGenerativeAiModel {
50    model_id: String,
51    client: reqwest::Client,
52    config: GoogleConfig,
53    supported_urls: HashMap<String, Regex>,
54}
55
56impl GoogleGenerativeAiModel {
57    pub fn new(model_id: impl Into<String>, api_key: impl Into<String>) -> Self {
58        Self::with_client(model_id, reqwest::Client::new(), GoogleConfig::new(api_key))
59    }
60
61    pub fn with_client(
62        model_id: impl Into<String>,
63        client: reqwest::Client,
64        config: GoogleConfig,
65    ) -> Self {
66        Self {
67            model_id: model_id.into(),
68            client,
69            config,
70            supported_urls: HashMap::new(),
71        }
72    }
73
74    async fn generate_impl(
75        &self,
76        options: LanguageModelCallOptions,
77    ) -> Result<LanguageModelGenerateResult> {
78        let request = GoogleGenerateContentRequest::from_call_options(&self.model_id, &options)?;
79        let request_body = serde_json::to_value(&request).map_err(|error| {
80            BitrouterError::invalid_request(
81                Some(GOOGLE_PROVIDER_NAME),
82                format!("failed to serialize generateContent request: {error}"),
83                None,
84            )
85        })?;
86        let (builder, request_headers) =
87            self.request_builder(&request_body, &options.headers, false)?;
88        let response = self
89            .send_request(builder, options.abort_signal.clone(), "generateContent")
90            .await?;
91
92        let response_headers = response.headers().clone();
93        if !response.status().is_success() {
94            return Err(self.decode_error_response(response).await);
95        }
96
97        let response_body: JsonValue = self
98            .await_with_cancellation(
99                options.abort_signal.clone(),
100                response.json(),
101                |error| {
102                    BitrouterError::response_decode(
103                        Some(GOOGLE_PROVIDER_NAME),
104                        format!("failed to decode generateContent response body: {error}"),
105                        None,
106                    )
107                },
108                || {
109                    BitrouterError::cancelled(
110                        Some(GOOGLE_PROVIDER_NAME),
111                        "generateContent response decoding was cancelled",
112                    )
113                },
114            )
115            .await?;
116        let gen_response: GoogleGenerateContentResponse =
117            serde_json::from_value(response_body.clone()).map_err(|error| {
118                BitrouterError::response_decode(
119                    Some(GOOGLE_PROVIDER_NAME),
120                    format!("failed to parse generateContent response: {error}"),
121                    Some(response_body.clone()),
122                )
123            })?;
124
125        gen_response.into_generate_result(
126            Some(request_headers),
127            request_body,
128            Some(response_headers),
129            response_body,
130        )
131    }
132
133    async fn stream_impl(
134        &self,
135        options: LanguageModelCallOptions,
136    ) -> Result<LanguageModelStreamResult> {
137        let request = GoogleGenerateContentRequest::from_call_options(&self.model_id, &options)?;
138        let request_body = serde_json::to_value(&request).map_err(|error| {
139            BitrouterError::invalid_request(
140                Some(GOOGLE_PROVIDER_NAME),
141                format!("failed to serialize streaming generateContent request: {error}"),
142                None,
143            )
144        })?;
145        let (builder, request_headers) =
146            self.request_builder(&request_body, &options.headers, true)?;
147        let response = self
148            .send_request(
149                builder,
150                options.abort_signal.clone(),
151                "streaming generateContent",
152            )
153            .await?;
154        let response_headers = response.headers().clone();
155        if !response.status().is_success() {
156            return Err(self.decode_error_response(response).await);
157        }
158
159        let include_raw_chunks = options.include_raw_chunks.unwrap_or(false);
160        let abort_signal = options.abort_signal.clone();
161        let bytes_stream: ByteStream = Box::pin(
162            response
163                .bytes_stream()
164                .map(|r| r.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)),
165        );
166        let (sender, receiver) = mpsc::channel(32);
167        tokio::spawn(drive_sse_stream(
168            bytes_stream,
169            abort_signal,
170            sender,
171            include_raw_chunks,
172        ));
173        let stream = Box::pin(ReceiverStream::new(receiver));
174
175        Ok(LanguageModelStreamResult {
176            stream,
177            request: Some(LanguageModelStreamResultRequest {
178                headers: Some(request_headers),
179                body: Some(request_body),
180            }),
181            response: Some(LanguageModelStreamResultResponse {
182                headers: Some(response_headers),
183            }),
184        })
185    }
186
187    fn request_builder(
188        &self,
189        request_body: &JsonValue,
190        extra_headers: &Option<HeaderMap>,
191        stream: bool,
192    ) -> Result<(reqwest::RequestBuilder, HeaderMap)> {
193        let action = if stream {
194            "streamGenerateContent?alt=sse"
195        } else {
196            "generateContent"
197        };
198        let endpoint = format!(
199            "{}/v1beta/models/{}:{}",
200            self.config.base_url.trim_end_matches('/'),
201            self.model_id,
202            action,
203        );
204        let headers = self.build_headers(extra_headers)?;
205        let request_headers = headers.clone();
206        let builder = self
207            .client
208            .post(endpoint)
209            .query(&[("key", &self.config.api_key)])
210            .headers(headers)
211            .json(request_body);
212
213        Ok((builder, request_headers))
214    }
215
216    fn build_headers(&self, extra_headers: &Option<HeaderMap>) -> Result<HeaderMap> {
217        let mut headers = HeaderMap::new();
218        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
219
220        for (name, value) in &self.config.default_headers {
221            headers.insert(name, value.clone());
222        }
223
224        if let Some(extra_headers) = extra_headers {
225            for (name, value) in extra_headers {
226                headers.insert(name, value.clone());
227            }
228        }
229
230        Ok(headers)
231    }
232
233    async fn decode_error_response(&self, response: reqwest::Response) -> BitrouterError {
234        let status = response.status();
235        let request_id = response
236            .headers()
237            .get("x-request-id")
238            .and_then(|value| value.to_str().ok())
239            .map(str::to_owned);
240        let body = match response.text().await {
241            Ok(text) if text.trim().is_empty() => None,
242            Ok(text) => serde_json::from_str::<JsonValue>(&text)
243                .ok()
244                .or(Some(JsonValue::String(text))),
245            Err(_) => None,
246        };
247
248        parse_google_error(status.as_u16(), request_id, body)
249    }
250
251    async fn send_request(
252        &self,
253        builder: reqwest::RequestBuilder,
254        abort_signal: Option<CancellationToken>,
255        operation: &str,
256    ) -> Result<reqwest::Response> {
257        self.await_with_cancellation(
258            abort_signal,
259            builder.send(),
260            |error| {
261                BitrouterError::transport(
262                    Some(GOOGLE_PROVIDER_NAME),
263                    format!("failed to send {operation} request: {error}"),
264                )
265            },
266            || {
267                BitrouterError::cancelled(
268                    Some(GOOGLE_PROVIDER_NAME),
269                    format!("{operation} request was cancelled"),
270                )
271            },
272        )
273        .await
274    }
275
276    async fn await_with_cancellation<F, T, E, M, C>(
277        &self,
278        abort_signal: Option<CancellationToken>,
279        future: F,
280        map_error: M,
281        cancelled: C,
282    ) -> Result<T>
283    where
284        F: std::future::Future<Output = std::result::Result<T, E>>,
285        M: FnOnce(E) -> BitrouterError,
286        C: FnOnce() -> BitrouterError,
287    {
288        if let Some(token) = abort_signal {
289            select! {
290                _ = token.cancelled() => Err(cancelled()),
291                result = future => result.map_err(map_error),
292            }
293        } else {
294            future.await.map_err(map_error)
295        }
296    }
297}
298
299impl LanguageModel for GoogleGenerativeAiModel {
300    fn provider_name(&self) -> &str {
301        GOOGLE_PROVIDER_NAME
302    }
303
304    fn model_id(&self) -> &str {
305        &self.model_id
306    }
307
308    fn supported_urls(&self) -> impl Future<Output = HashMap<String, Regex>> {
309        let supported_urls = self.supported_urls.clone();
310        async move { supported_urls }
311    }
312
313    async fn generate(
314        &self,
315        options: LanguageModelCallOptions,
316    ) -> Result<LanguageModelGenerateResult> {
317        self.generate_impl(options).await
318    }
319
320    async fn stream(&self, options: LanguageModelCallOptions) -> Result<LanguageModelStreamResult> {
321        self.stream_impl(options).await
322    }
323}