Skip to main content

bitrouter_google/generate_content/
provider.rs

1use std::collections::HashMap;
2
3use bitrouter_core::{
4    errors::{BitrouterError, Result},
5    models::{
6        language::{
7            call_options::LanguageModelCallOptions,
8            generate_result::LanguageModelGenerateResult,
9            language_model::LanguageModel,
10            stream_result::{
11                LanguageModelStreamResult, LanguageModelStreamResultRequest,
12                LanguageModelStreamResultResponse,
13            },
14        },
15        shared::types::JsonValue,
16    },
17};
18use regex::Regex;
19use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue};
20use tokio::{select, sync::mpsc};
21use tokio_stream::{StreamExt, wrappers::ReceiverStream};
22use tokio_util::sync::CancellationToken;
23
24use super::api::{ByteStream, drive_sse_stream, parse_google_error};
25use super::types::{
26    GOOGLE_PROVIDER_NAME, GoogleGenerateContentRequest, GoogleGenerateContentResponse,
27};
28
29const GOOGLE_DEFAULT_BASE_URL: &str = "https://generativelanguage.googleapis.com";
30
31#[derive(Debug, Clone)]
32pub struct GoogleConfig {
33    pub api_key: String,
34    pub base_url: String,
35    pub default_headers: HeaderMap,
36}
37
38impl GoogleConfig {
39    pub fn new(api_key: impl Into<String>) -> Self {
40        Self {
41            api_key: api_key.into(),
42            base_url: GOOGLE_DEFAULT_BASE_URL.to_owned(),
43            default_headers: HeaderMap::new(),
44        }
45    }
46}
47
48#[derive(Clone)]
49pub struct GoogleGenerativeAiModel {
50    model_id: String,
51    client: reqwest_middleware::ClientWithMiddleware,
52    config: GoogleConfig,
53    supported_urls: HashMap<String, Regex>,
54}
55
56impl GoogleGenerativeAiModel {
57    pub fn new(model_id: impl Into<String>, api_key: impl Into<String>) -> Self {
58        let client = reqwest_middleware::ClientBuilder::new(reqwest::Client::new()).build();
59        Self::with_client(model_id, client, GoogleConfig::new(api_key))
60    }
61
62    pub fn with_client(
63        model_id: impl Into<String>,
64        client: reqwest_middleware::ClientWithMiddleware,
65        config: GoogleConfig,
66    ) -> Self {
67        Self {
68            model_id: model_id.into(),
69            client,
70            config,
71            supported_urls: HashMap::new(),
72        }
73    }
74
75    async fn generate_impl(
76        &self,
77        options: LanguageModelCallOptions,
78    ) -> Result<LanguageModelGenerateResult> {
79        let request = GoogleGenerateContentRequest::from_call_options(&self.model_id, &options)?;
80        let request_body = serde_json::to_value(&request).map_err(|error| {
81            BitrouterError::invalid_request(
82                Some(GOOGLE_PROVIDER_NAME),
83                format!("failed to serialize generateContent request: {error}"),
84                None,
85            )
86        })?;
87        let (builder, request_headers) =
88            self.request_builder(&request_body, &options.headers, false)?;
89        let response = self
90            .send_request(builder, options.abort_signal.clone(), "generateContent")
91            .await?;
92
93        let response_headers = response.headers().clone();
94        if !response.status().is_success() {
95            return Err(self.decode_error_response(response).await);
96        }
97
98        let response_body: JsonValue = self
99            .await_with_cancellation(
100                options.abort_signal.clone(),
101                response.json(),
102                |error| {
103                    BitrouterError::response_decode(
104                        Some(GOOGLE_PROVIDER_NAME),
105                        format!("failed to decode generateContent response body: {error}"),
106                        None,
107                    )
108                },
109                || {
110                    BitrouterError::cancelled(
111                        Some(GOOGLE_PROVIDER_NAME),
112                        "generateContent response decoding was cancelled",
113                    )
114                },
115            )
116            .await?;
117        let gen_response: GoogleGenerateContentResponse =
118            serde_json::from_value(response_body.clone()).map_err(|error| {
119                BitrouterError::response_decode(
120                    Some(GOOGLE_PROVIDER_NAME),
121                    format!("failed to parse generateContent response: {error}"),
122                    Some(response_body.clone()),
123                )
124            })?;
125
126        gen_response.into_generate_result(
127            Some(request_headers),
128            request_body,
129            Some(response_headers),
130            response_body,
131        )
132    }
133
134    async fn stream_impl(
135        &self,
136        options: LanguageModelCallOptions,
137    ) -> Result<LanguageModelStreamResult> {
138        let request = GoogleGenerateContentRequest::from_call_options(&self.model_id, &options)?;
139        let request_body = serde_json::to_value(&request).map_err(|error| {
140            BitrouterError::invalid_request(
141                Some(GOOGLE_PROVIDER_NAME),
142                format!("failed to serialize streaming generateContent request: {error}"),
143                None,
144            )
145        })?;
146        let (builder, request_headers) =
147            self.request_builder(&request_body, &options.headers, true)?;
148        let response = self
149            .send_request(
150                builder,
151                options.abort_signal.clone(),
152                "streaming generateContent",
153            )
154            .await?;
155        let response_headers = response.headers().clone();
156        if !response.status().is_success() {
157            return Err(self.decode_error_response(response).await);
158        }
159
160        let include_raw_chunks = options.include_raw_chunks.unwrap_or(false);
161        let abort_signal = options.abort_signal.clone();
162        let bytes_stream: ByteStream = Box::pin(
163            response
164                .bytes_stream()
165                .map(|r| r.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)),
166        );
167        let (sender, receiver) = mpsc::channel(32);
168        tokio::spawn(drive_sse_stream(
169            bytes_stream,
170            abort_signal,
171            sender,
172            include_raw_chunks,
173        ));
174        let stream = Box::pin(ReceiverStream::new(receiver));
175
176        Ok(LanguageModelStreamResult {
177            stream,
178            request: Some(LanguageModelStreamResultRequest {
179                headers: Some(request_headers),
180                body: Some(request_body),
181            }),
182            response: Some(LanguageModelStreamResultResponse {
183                headers: Some(response_headers),
184            }),
185        })
186    }
187
188    fn request_builder(
189        &self,
190        request_body: &JsonValue,
191        extra_headers: &Option<HeaderMap>,
192        stream: bool,
193    ) -> Result<(reqwest_middleware::RequestBuilder, HeaderMap)> {
194        let action = if stream {
195            "streamGenerateContent?alt=sse"
196        } else {
197            "generateContent"
198        };
199        let endpoint = format!(
200            "{}/v1beta/models/{}:{}",
201            self.config.base_url.trim_end_matches('/'),
202            self.model_id,
203            action,
204        );
205        let headers = self.build_headers(extra_headers)?;
206        let request_headers = headers.clone();
207        let builder = self
208            .client
209            .post(endpoint)
210            .query(&[("key", &self.config.api_key)])
211            .headers(headers)
212            .json(request_body);
213
214        Ok((builder, request_headers))
215    }
216
217    fn build_headers(&self, extra_headers: &Option<HeaderMap>) -> Result<HeaderMap> {
218        let mut headers = HeaderMap::new();
219        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
220
221        for (name, value) in &self.config.default_headers {
222            headers.insert(name, value.clone());
223        }
224
225        if let Some(extra_headers) = extra_headers {
226            for (name, value) in extra_headers {
227                headers.insert(name, value.clone());
228            }
229        }
230
231        Ok(headers)
232    }
233
234    async fn decode_error_response(&self, response: reqwest::Response) -> BitrouterError {
235        let status = response.status();
236        let request_id = response
237            .headers()
238            .get("x-request-id")
239            .and_then(|value| value.to_str().ok())
240            .map(str::to_owned);
241        let body = match response.text().await {
242            Ok(text) if text.trim().is_empty() => None,
243            Ok(text) => serde_json::from_str::<JsonValue>(&text)
244                .ok()
245                .or(Some(JsonValue::String(text))),
246            Err(_) => None,
247        };
248
249        parse_google_error(status.as_u16(), request_id, body)
250    }
251
252    async fn send_request(
253        &self,
254        builder: reqwest_middleware::RequestBuilder,
255        abort_signal: Option<CancellationToken>,
256        operation: &str,
257    ) -> Result<reqwest::Response> {
258        self.await_with_cancellation(
259            abort_signal,
260            builder.send(),
261            |error| {
262                BitrouterError::transport(
263                    Some(GOOGLE_PROVIDER_NAME),
264                    format!("failed to send {operation} request: {error}"),
265                )
266            },
267            || {
268                BitrouterError::cancelled(
269                    Some(GOOGLE_PROVIDER_NAME),
270                    format!("{operation} request was cancelled"),
271                )
272            },
273        )
274        .await
275    }
276
277    async fn await_with_cancellation<F, T, E, M, C>(
278        &self,
279        abort_signal: Option<CancellationToken>,
280        future: F,
281        map_error: M,
282        cancelled: C,
283    ) -> Result<T>
284    where
285        F: std::future::Future<Output = std::result::Result<T, E>>,
286        M: FnOnce(E) -> BitrouterError,
287        C: FnOnce() -> BitrouterError,
288    {
289        if let Some(token) = abort_signal {
290            select! {
291                _ = token.cancelled() => Err(cancelled()),
292                result = future => result.map_err(map_error),
293            }
294        } else {
295            future.await.map_err(map_error)
296        }
297    }
298}
299
300impl LanguageModel for GoogleGenerativeAiModel {
301    fn provider_name(&self) -> &str {
302        GOOGLE_PROVIDER_NAME
303    }
304
305    fn model_id(&self) -> &str {
306        &self.model_id
307    }
308
309    fn supported_urls(&self) -> impl Future<Output = HashMap<String, Regex>> {
310        let supported_urls = self.supported_urls.clone();
311        async move { supported_urls }
312    }
313
314    async fn generate(
315        &self,
316        options: LanguageModelCallOptions,
317    ) -> Result<LanguageModelGenerateResult> {
318        self.generate_impl(options).await
319    }
320
321    async fn stream(&self, options: LanguageModelCallOptions) -> Result<LanguageModelStreamResult> {
322        self.stream_impl(options).await
323    }
324}