firebase_rs_sdk/ai/requests.rs

1use std::time::Duration;
2
3use serde_json::Value;
4use url::Url;
5
6use crate::ai::backend::Backend;
7use crate::ai::constants::{
8    DEFAULT_API_VERSION, DEFAULT_DOMAIN, DEFAULT_FETCH_TIMEOUT_MS, LANGUAGE_TAG, PACKAGE_VERSION,
9};
10use crate::ai::error::{AiError, AiErrorCode, AiResult};
11
12/// Internal settings required to build REST requests.
13///
14/// Mirrors `ApiSettings` from `packages/ai/src/types/internal.ts`.
15#[derive(Clone, Debug)]
16pub(crate) struct ApiSettings {
17    pub api_key: String,
18    pub project: String,
19    pub app_id: String,
20    pub backend: Backend,
21    pub automatic_data_collection_enabled: bool,
22    pub app_check_token: Option<String>,
23    pub auth_token: Option<String>,
24}
25
26impl ApiSettings {
27    pub fn new(
28        api_key: String,
29        project: String,
30        app_id: String,
31        backend: Backend,
32        automatic_data_collection_enabled: bool,
33        app_check_token: Option<String>,
34        auth_token: Option<String>,
35    ) -> Self {
36        Self {
37            api_key,
38            project,
39            app_id,
40            backend,
41            automatic_data_collection_enabled,
42            app_check_token,
43            auth_token,
44        }
45    }
46}
47
/// Additional per-request options.
///
/// Ported from `RequestOptions` in `packages/ai/src/types/requests.ts`.
#[derive(Default, Debug, Clone, PartialEq)]
pub struct RequestOptions {
    /// Request timeout; `None` falls back to `DEFAULT_FETCH_TIMEOUT_MS`.
    pub timeout: Option<Duration>,
    /// Base URL that replaces the default Firebase AI endpoint. A value without a scheme
    /// (for example `example.com`) is treated as `https://example.com`.
    pub base_url: Option<String>,
}
58
/// HTTP verbs that a [`PreparedRequest`] can carry.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HttpMethod {
    /// `POST`, the verb used for every request built by the request factory.
    Post,
}
63
64/// Prepared HTTP request ready to be executed by an HTTP client.
65#[derive(Clone, Debug, PartialEq)]
66pub struct PreparedRequest {
67    pub method: HttpMethod,
68    pub url: Url,
69    pub headers: Vec<(String, String)>,
70    pub body: Value,
71    pub timeout: Duration,
72}
73
74impl PreparedRequest {
75    /// Returns the value of a header if it exists.
76    pub fn header(&self, name: &str) -> Option<&str> {
77        self.headers
78            .iter()
79            .find(|(header_name, _)| header_name.eq_ignore_ascii_case(name))
80            .map(|(_, value)| value.as_str())
81    }
82
83    /// Converts the prepared request into a `reqwest::RequestBuilder`.
84    ///
85    /// This helper is only compiled when the `ai-http` feature is enabled so the core library
86    /// remains network agnostic.
87    #[cfg(feature = "ai-http")]
88    pub fn into_reqwest(
89        self,
90        client: &reqwest::Client,
91    ) -> Result<reqwest::RequestBuilder, AiError> {
92        use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
93
94        let mut headers = HeaderMap::new();
95        for (name, value) in &self.headers {
96            let header_name = HeaderName::from_bytes(name.as_bytes()).map_err(|err| {
97                AiError::new(
98                    AiErrorCode::InvalidArgument,
99                    format!("Invalid header name '{name}': {err}"),
100                    None,
101                )
102            })?;
103            let header_value = HeaderValue::from_str(value).map_err(|err| {
104                AiError::new(
105                    AiErrorCode::InvalidArgument,
106                    format!("Invalid header value for '{name}': {err}"),
107                    None,
108                )
109            })?;
110            headers.insert(header_name, header_value);
111        }
112
113        let builder = match self.method {
114            HttpMethod::Post => client.post(self.url.clone()),
115        }
116        .headers(headers)
117        .timeout(self.timeout)
118        .body(self.body.to_string());
119        Ok(builder)
120    }
121}
122
123#[derive(Clone, Debug)]
124pub(crate) struct RequestFactory {
125    settings: ApiSettings,
126}
127
128impl RequestFactory {
129    pub fn new(settings: ApiSettings) -> Self {
130        Self { settings }
131    }
132
133    pub fn construct_request(
134        &self,
135        model: &str,
136        task: Task,
137        stream: bool,
138        body: Value,
139        request_options: Option<RequestOptions>,
140    ) -> AiResult<PreparedRequest> {
141        let options = request_options.unwrap_or_default();
142        let mut url = self.compose_base_url(&options)?;
143        let trimmed_model = model.trim_start_matches('/');
144        let model_path = match &self.settings.backend {
145            Backend::GoogleAi(_) => format!("projects/{}/{}", self.settings.project, trimmed_model),
146            Backend::VertexAi(inner) => format!(
147                "projects/{}/locations/{}/{}",
148                self.settings.project,
149                inner.location(),
150                trimmed_model
151            ),
152        };
153        let path = format!(
154            "/{}/{model_path}:{}",
155            DEFAULT_API_VERSION,
156            task.as_operation()
157        );
158        url.set_path(&path);
159        if stream {
160            url.query_pairs_mut().append_pair("alt", "sse");
161        } else {
162            url.set_query(None);
163        }
164
165        let timeout = options
166            .timeout
167            .unwrap_or_else(|| Duration::from_millis(DEFAULT_FETCH_TIMEOUT_MS));
168        let headers = self.build_headers();
169
170        Ok(PreparedRequest {
171            method: HttpMethod::Post,
172            url,
173            headers,
174            body,
175            timeout,
176        })
177    }
178
179    fn compose_base_url(&self, options: &RequestOptions) -> AiResult<Url> {
180        let base = options
181            .base_url
182            .as_ref()
183            .map(|value| value.as_str())
184            .unwrap_or(DEFAULT_DOMAIN);
185        let url = if base.starts_with("http://") || base.starts_with("https://") {
186            base.to_string()
187        } else {
188            format!("https://{base}")
189        };
190        Url::parse(&url).map_err(|err| {
191            AiError::new(
192                AiErrorCode::InvalidArgument,
193                format!("Invalid base URL '{url}': {err}"),
194                None,
195            )
196        })
197    }
198
199    fn build_headers(&self) -> Vec<(String, String)> {
200        let mut headers = Vec::with_capacity(5);
201        headers.push(("Content-Type".into(), "application/json".into()));
202        let client_header = format!(
203            "{}/{} fire/{}",
204            LANGUAGE_TAG, PACKAGE_VERSION, PACKAGE_VERSION
205        );
206        headers.push(("x-goog-api-client".into(), client_header));
207        headers.push(("x-goog-api-key".into(), self.settings.api_key.clone()));
208
209        if self.settings.automatic_data_collection_enabled && !self.settings.app_id.is_empty() {
210            headers.push(("X-Firebase-AppId".into(), self.settings.app_id.clone()));
211        }
212
213        if let Some(token) = &self.settings.app_check_token {
214            headers.push(("X-Firebase-AppCheck".into(), token.clone()));
215        }
216
217        if let Some(token) = &self.settings.auth_token {
218            headers.push(("Authorization".into(), format!("Firebase {token}")));
219        }
220
221        headers
222    }
223}
224
/// High-level tasks supported by the request factory.
///
/// Mirrors the `Task` enum in `packages/ai/src/requests/request.ts`. Each variant maps to the
/// REST operation name that follows the `:` after the model path.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Task {
    /// The `generateContent` operation.
    GenerateContent,
    /// The `countTokens` operation.
    CountTokens,
    /// The `predict` operation.
    Predict,
}

impl Task {
    /// Returns the REST operation name for this task, e.g. `"countTokens"`.
    pub fn as_operation(&self) -> &'static str {
        match *self {
            Self::Predict => "predict",
            Self::CountTokens => "countTokens",
            Self::GenerateContent => "generateContent",
        }
    }
}
244
#[cfg(test)]
mod tests {
    use super::*;

    /// Settings with data collection enabled and no App Check or auth token.
    fn settings_with_backend(backend: Backend) -> ApiSettings {
        ApiSettings::new(
            "test-key".into(),
            "test-project".into(),
            "1:123:web:abc".into(),
            backend,
            true,
            None,
            None,
        )
    }

    #[test]
    fn constructs_google_ai_url() {
        let factory = RequestFactory::new(settings_with_backend(Backend::google_ai()));
        let req = factory
            .construct_request(
                "models/gemini-1.5-flash",
                Task::GenerateContent,
                false,
                serde_json::json!({"contents": []}),
                None,
            )
            .unwrap();
        assert_eq!(
            req.url.as_str(),
            "https://firebasevertexai.googleapis.com/v1beta/projects/test-project/models/gemini-1.5-flash:generateContent"
        );
        assert_eq!(req.header("x-goog-api-key"), Some("test-key"));
        assert_eq!(req.timeout, Duration::from_millis(DEFAULT_FETCH_TIMEOUT_MS));
    }

    #[test]
    fn constructs_vertex_ai_url_with_base_override() {
        let factory =
            RequestFactory::new(settings_with_backend(Backend::vertex_ai("europe-west4")));
        let options = RequestOptions {
            timeout: Some(Duration::from_secs(10)),
            base_url: Some("https://example.com".into()),
        };
        let req = factory
            .construct_request(
                "models/gemini-pro",
                Task::CountTokens,
                false,
                serde_json::json!({"contents": []}),
                Some(options),
            )
            .unwrap();
        assert_eq!(
            req.url.as_str(),
            "https://example.com/v1beta/projects/test-project/locations/europe-west4/models/gemini-pro:countTokens"
        );
        assert_eq!(req.timeout, Duration::from_secs(10));
    }

    #[test]
    fn streaming_request_asks_for_server_sent_events() {
        let factory = RequestFactory::new(settings_with_backend(Backend::google_ai()));
        let req = factory
            .construct_request(
                "models/gemini-1.5-flash",
                Task::GenerateContent,
                true,
                serde_json::json!({"contents": []}),
                None,
            )
            .unwrap();
        assert_eq!(
            req.url.as_str(),
            "https://firebasevertexai.googleapis.com/v1beta/projects/test-project/models/gemini-1.5-flash:generateContent?alt=sse"
        );
        assert_eq!(req.method, HttpMethod::Post);
    }

    #[test]
    fn attaches_identity_and_token_headers() {
        let settings = ApiSettings::new(
            "test-key".into(),
            "test-project".into(),
            "1:123:web:abc".into(),
            Backend::google_ai(),
            true,
            Some("app-check-token".into()),
            Some("auth-token".into()),
        );
        let req = RequestFactory::new(settings)
            .construct_request(
                "models/test",
                Task::CountTokens,
                false,
                serde_json::json!({}),
                None,
            )
            .unwrap();
        assert_eq!(req.header("content-type"), Some("application/json"));
        assert_eq!(req.header("x-firebase-appid"), Some("1:123:web:abc"));
        assert_eq!(req.header("x-firebase-appcheck"), Some("app-check-token"));
        assert_eq!(req.header("authorization"), Some("Firebase auth-token"));
    }

    #[test]
    fn omits_app_id_when_data_collection_disabled() {
        let mut settings = settings_with_backend(Backend::google_ai());
        settings.automatic_data_collection_enabled = false;
        let req = RequestFactory::new(settings)
            .construct_request(
                "models/test",
                Task::Predict,
                false,
                serde_json::json!({}),
                None,
            )
            .unwrap();
        assert_eq!(req.header("X-Firebase-AppId"), None);
        assert_eq!(req.header("X-Firebase-AppCheck"), None);
        assert_eq!(req.header("Authorization"), None);
    }

    #[test]
    fn invalid_base_url_returns_error() {
        let factory = RequestFactory::new(settings_with_backend(Backend::google_ai()));
        let err = factory
            .construct_request(
                "models/test",
                Task::Predict,
                false,
                serde_json::json!({}),
                Some(RequestOptions {
                    timeout: None,
                    base_url: Some("://bad".into()),
                }),
            )
            .unwrap_err();
        assert_eq!(err.code(), AiErrorCode::InvalidArgument);
    }
}