firebase_rs_sdk/ai/
requests.rs

1use std::time::Duration;
2
3use serde_json::Value;
4use url::Url;
5
6use crate::ai::backend::Backend;
7use crate::ai::constants::{
8    DEFAULT_API_VERSION, DEFAULT_DOMAIN, DEFAULT_FETCH_TIMEOUT_MS, LANGUAGE_TAG, PACKAGE_VERSION,
9};
10use crate::ai::error::{AiError, AiErrorCode, AiResult};
11
12/// Internal settings required to build REST requests.
13///
14/// Mirrors `ApiSettings` from `packages/ai/src/types/internal.ts`.
15#[derive(Clone, Debug)]
16pub(crate) struct ApiSettings {
17    pub api_key: String,
18    pub project: String,
19    pub app_id: String,
20    pub backend: Backend,
21    pub automatic_data_collection_enabled: bool,
22    pub app_check_token: Option<String>,
23    pub app_check_heartbeat: Option<String>,
24    pub auth_token: Option<String>,
25}
26
27impl ApiSettings {
28    pub fn new(
29        api_key: String,
30        project: String,
31        app_id: String,
32        backend: Backend,
33        automatic_data_collection_enabled: bool,
34        app_check_token: Option<String>,
35        app_check_heartbeat: Option<String>,
36        auth_token: Option<String>,
37    ) -> Self {
38        Self {
39            api_key,
40            project,
41            app_id,
42            backend,
43            automatic_data_collection_enabled,
44            app_check_token,
45            app_check_heartbeat,
46            auth_token,
47        }
48    }
49}
50
/// Additional per-request options.
///
/// Ported from `packages/ai/src/types/requests.ts` (`RequestOptions`).
/// Every field is optional; `RequestOptions::default()` leaves all of them unset.
#[derive(Debug, Default, Clone, PartialEq)]
pub struct RequestOptions {
    /// Optional request timeout. When omitted, the default fetch timeout
    /// (`DEFAULT_FETCH_TIMEOUT_MS`, 180 seconds) is used instead.
    pub timeout: Option<Duration>,
    /// Optional base URL that overrides the default Firebase AI endpoint.
    pub base_url: Option<String>,
}
61
/// HTTP verb used by a [`PreparedRequest`].
///
/// Currently the only variant; every request built by this module is a `POST`.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum HttpMethod {
    /// The HTTP `POST` method.
    Post,
}
66
67/// Prepared HTTP request ready to be executed by an HTTP client.
68#[derive(Clone, Debug, PartialEq)]
69pub struct PreparedRequest {
70    pub method: HttpMethod,
71    pub url: Url,
72    pub headers: Vec<(String, String)>,
73    pub body: Value,
74    pub timeout: Duration,
75}
76
77impl PreparedRequest {
78    /// Returns the value of a header if it exists.
79    pub fn header(&self, name: &str) -> Option<&str> {
80        self.headers
81            .iter()
82            .find(|(header_name, _)| header_name.eq_ignore_ascii_case(name))
83            .map(|(_, value)| value.as_str())
84    }
85
86    /// Converts the prepared request into a `reqwest::RequestBuilder`.
87    pub fn into_reqwest(
88        self,
89        client: &reqwest::Client,
90    ) -> Result<reqwest::RequestBuilder, AiError> {
91        use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
92
93        let mut headers = HeaderMap::new();
94        for (name, value) in &self.headers {
95            let header_name = HeaderName::from_bytes(name.as_bytes()).map_err(|err| {
96                AiError::new(
97                    AiErrorCode::InvalidArgument,
98                    format!("Invalid header name '{name}': {err}"),
99                    None,
100                )
101            })?;
102            let header_value = HeaderValue::from_str(value).map_err(|err| {
103                AiError::new(
104                    AiErrorCode::InvalidArgument,
105                    format!("Invalid header value for '{name}': {err}"),
106                    None,
107                )
108            })?;
109            headers.insert(header_name, header_value);
110        }
111
112        let builder = match self.method {
113            HttpMethod::Post => client.post(self.url.clone()),
114        }
115        .headers(headers)
116        .body(self.body.to_string());
117
118        #[cfg(not(target_arch = "wasm32"))]
119        let builder = builder.timeout(self.timeout);
120
121        Ok(builder)
122    }
123}
124
125#[derive(Clone, Debug)]
126pub(crate) struct RequestFactory {
127    settings: ApiSettings,
128}
129
130impl RequestFactory {
131    pub fn new(settings: ApiSettings) -> Self {
132        Self { settings }
133    }
134
135    pub fn construct_request(
136        &self,
137        model: &str,
138        task: Task,
139        stream: bool,
140        body: Value,
141        request_options: Option<RequestOptions>,
142    ) -> AiResult<PreparedRequest> {
143        let options = request_options.unwrap_or_default();
144        let mut url = self.compose_base_url(&options)?;
145        let trimmed_model = model.trim_start_matches('/');
146        let model_path = match &self.settings.backend {
147            Backend::GoogleAi(_) => format!("projects/{}/{}", self.settings.project, trimmed_model),
148            Backend::VertexAi(inner) => format!(
149                "projects/{}/locations/{}/{}",
150                self.settings.project,
151                inner.location(),
152                trimmed_model
153            ),
154        };
155        let path = format!(
156            "/{}/{model_path}:{}",
157            DEFAULT_API_VERSION,
158            task.as_operation()
159        );
160        url.set_path(&path);
161        if stream {
162            url.query_pairs_mut().append_pair("alt", "sse");
163        } else {
164            url.set_query(None);
165        }
166
167        let timeout = options
168            .timeout
169            .unwrap_or_else(|| Duration::from_millis(DEFAULT_FETCH_TIMEOUT_MS));
170        let headers = self.build_headers();
171
172        Ok(PreparedRequest {
173            method: HttpMethod::Post,
174            url,
175            headers,
176            body,
177            timeout,
178        })
179    }
180
181    fn compose_base_url(&self, options: &RequestOptions) -> AiResult<Url> {
182        let base = options
183            .base_url
184            .as_ref()
185            .map(|value| value.as_str())
186            .unwrap_or(DEFAULT_DOMAIN);
187        let url = if base.starts_with("http://") || base.starts_with("https://") {
188            base.to_string()
189        } else {
190            format!("https://{base}")
191        };
192        Url::parse(&url).map_err(|err| {
193            AiError::new(
194                AiErrorCode::InvalidArgument,
195                format!("Invalid base URL '{url}': {err}"),
196                None,
197            )
198        })
199    }
200
201    fn build_headers(&self) -> Vec<(String, String)> {
202        let mut headers = Vec::with_capacity(5);
203        headers.push(("Content-Type".into(), "application/json".into()));
204        let client_header = format!(
205            "{}/{} fire/{}",
206            LANGUAGE_TAG, PACKAGE_VERSION, PACKAGE_VERSION
207        );
208        headers.push(("x-goog-api-client".into(), client_header));
209        headers.push(("x-goog-api-key".into(), self.settings.api_key.clone()));
210
211        if self.settings.automatic_data_collection_enabled && !self.settings.app_id.is_empty() {
212            headers.push(("X-Firebase-AppId".into(), self.settings.app_id.clone()));
213        }
214
215        if let Some(token) = &self.settings.app_check_token {
216            headers.push(("X-Firebase-AppCheck".into(), token.clone()));
217        }
218
219        if let Some(header) = &self.settings.app_check_heartbeat {
220            headers.push(("X-Firebase-Client".into(), header.clone()));
221        }
222
223        if let Some(token) = &self.settings.auth_token {
224            headers.push(("Authorization".into(), format!("Firebase {token}")));
225        }
226
227        headers
228    }
229}
230
/// High-level tasks supported by the request factory.
///
/// Mirrors the `Task` enum in `packages/ai/src/requests/request.ts`.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum Task {
    /// The `generateContent` REST operation.
    GenerateContent,
    /// The `countTokens` REST operation.
    // Kept for parity with the JS SDK. The lint allowance suggests this variant is
    // not yet constructed outside tests — remove it once a caller uses it.
    #[allow(dead_code)]
    CountTokens,
    /// The `predict` REST operation.
    #[allow(dead_code)] // same rationale as `CountTokens`
    Predict,
}
242
243impl Task {
244    pub fn as_operation(&self) -> &'static str {
245        match self {
246            Task::GenerateContent => "generateContent",
247            Task::CountTokens => "countTokens",
248            Task::Predict => "predict",
249        }
250    }
251}
252
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds settings with fixed test credentials for the given backend.
    fn test_settings(backend: Backend) -> ApiSettings {
        let (api_key, project, app_id) = ("test-key", "test-project", "1:123:web:abc");
        ApiSettings::new(
            api_key.to_string(),
            project.to_string(),
            app_id.to_string(),
            backend,
            true,
            None,
            None,
            None,
        )
    }

    /// Minimal `generateContent`-style request body.
    fn empty_contents() -> serde_json::Value {
        serde_json::json!({"contents": []})
    }

    #[test]
    fn constructs_google_ai_url() {
        let factory = RequestFactory::new(test_settings(Backend::google_ai()));
        let expected_url = "https://firebasevertexai.googleapis.com/v1beta/projects/test-project/models/gemini-1.5-flash:generateContent";

        let request = factory
            .construct_request(
                "models/gemini-1.5-flash",
                Task::GenerateContent,
                false,
                empty_contents(),
                None,
            )
            .expect("default base URL should be valid");

        assert_eq!(request.url.as_str(), expected_url);
        assert_eq!(request.header("x-goog-api-key"), Some("test-key"));
        assert_eq!(
            request.timeout,
            Duration::from_millis(DEFAULT_FETCH_TIMEOUT_MS)
        );
    }

    #[test]
    fn constructs_vertex_ai_url_with_base_override() {
        let factory = RequestFactory::new(test_settings(Backend::vertex_ai("europe-west4")));
        let expected_url = "https://example.com/v1beta/projects/test-project/locations/europe-west4/models/gemini-pro:countTokens";
        let options = RequestOptions {
            timeout: Some(Duration::from_secs(10)),
            base_url: Some("https://example.com".into()),
        };

        let request = factory
            .construct_request(
                "models/gemini-pro",
                Task::CountTokens,
                false,
                empty_contents(),
                Some(options),
            )
            .expect("override base URL should be valid");

        assert_eq!(request.url.as_str(), expected_url);
        assert_eq!(request.timeout, Duration::from_secs(10));
    }

    #[test]
    fn invalid_base_url_returns_error() {
        let factory = RequestFactory::new(test_settings(Backend::google_ai()));
        let options = RequestOptions {
            timeout: None,
            base_url: Some("://bad".into()),
        };

        let outcome = factory.construct_request(
            "models/test",
            Task::Predict,
            false,
            serde_json::json!({}),
            Some(options),
        );

        match outcome {
            Err(err) => assert_eq!(err.code(), AiErrorCode::InvalidArgument),
            Ok(_) => panic!("expected an invalid base URL to be rejected"),
        }
    }
}