1use std::time::Duration;
2
3use serde_json::Value;
4use url::Url;
5
6use crate::ai::backend::Backend;
7use crate::ai::constants::{
8 DEFAULT_API_VERSION, DEFAULT_DOMAIN, DEFAULT_FETCH_TIMEOUT_MS, LANGUAGE_TAG, PACKAGE_VERSION,
9};
10use crate::ai::error::{AiError, AiErrorCode, AiResult};
11
/// Identity, credential and backend settings stamped onto every request built by
/// `RequestFactory`.
///
/// NOTE(review): the derived `Debug` prints `api_key`, `app_check_token` and
/// `auth_token` verbatim; avoid logging this value with `{:?}`.
#[derive(Clone, Debug)]
pub(crate) struct ApiSettings {
    /// API key, sent in the `x-goog-api-key` header.
    pub api_key: String,
    /// Project identifier, embedded in the request path as `projects/{project}/...`.
    pub project: String,
    /// App identifier, sent as `X-Firebase-AppId` only when
    /// `automatic_data_collection_enabled` is true and the id is non-empty.
    pub app_id: String,
    /// Selects the URL layout: Google AI, or Vertex AI with an extra
    /// `locations/{location}` path segment.
    pub backend: Backend,
    /// Gates whether `app_id` is attached to outgoing requests.
    pub automatic_data_collection_enabled: bool,
    /// Optional App Check token, sent as `X-Firebase-AppCheck` when present.
    pub app_check_token: Option<String>,
    /// Optional heartbeat payload, sent as `X-Firebase-Client` when present.
    pub app_check_heartbeat: Option<String>,
    /// Optional auth token, sent as `Authorization: Firebase <token>` when present.
    pub auth_token: Option<String>,
}
26
27impl ApiSettings {
28 pub fn new(
29 api_key: String,
30 project: String,
31 app_id: String,
32 backend: Backend,
33 automatic_data_collection_enabled: bool,
34 app_check_token: Option<String>,
35 app_check_heartbeat: Option<String>,
36 auth_token: Option<String>,
37 ) -> Self {
38 Self {
39 api_key,
40 project,
41 app_id,
42 backend,
43 automatic_data_collection_enabled,
44 app_check_token,
45 app_check_heartbeat,
46 auth_token,
47 }
48 }
49}
50
/// Optional per-call overrides accepted by `RequestFactory::construct_request`.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct RequestOptions {
    /// Request timeout; `None` falls back to `DEFAULT_FETCH_TIMEOUT_MS` milliseconds.
    pub timeout: Option<Duration>,
    /// Base URL or bare host; `None` falls back to `DEFAULT_DOMAIN`. A value that
    /// does not start with an `http://` or `https://` scheme is prefixed with
    /// `https://` before parsing.
    pub base_url: Option<String>,
}
61
/// HTTP verb carried by a [`PreparedRequest`].
///
/// Only `Post` exists; `RequestFactory::construct_request` always selects it.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum HttpMethod {
    /// HTTP `POST`.
    Post,
}
66
67#[derive(Clone, Debug, PartialEq)]
69pub struct PreparedRequest {
70 pub method: HttpMethod,
71 pub url: Url,
72 pub headers: Vec<(String, String)>,
73 pub body: Value,
74 pub timeout: Duration,
75}
76
77impl PreparedRequest {
78 pub fn header(&self, name: &str) -> Option<&str> {
80 self.headers
81 .iter()
82 .find(|(header_name, _)| header_name.eq_ignore_ascii_case(name))
83 .map(|(_, value)| value.as_str())
84 }
85
86 pub fn into_reqwest(
88 self,
89 client: &reqwest::Client,
90 ) -> Result<reqwest::RequestBuilder, AiError> {
91 use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
92
93 let mut headers = HeaderMap::new();
94 for (name, value) in &self.headers {
95 let header_name = HeaderName::from_bytes(name.as_bytes()).map_err(|err| {
96 AiError::new(
97 AiErrorCode::InvalidArgument,
98 format!("Invalid header name '{name}': {err}"),
99 None,
100 )
101 })?;
102 let header_value = HeaderValue::from_str(value).map_err(|err| {
103 AiError::new(
104 AiErrorCode::InvalidArgument,
105 format!("Invalid header value for '{name}': {err}"),
106 None,
107 )
108 })?;
109 headers.insert(header_name, header_value);
110 }
111
112 let builder = match self.method {
113 HttpMethod::Post => client.post(self.url.clone()),
114 }
115 .headers(headers)
116 .body(self.body.to_string());
117
118 #[cfg(not(target_arch = "wasm32"))]
119 let builder = builder.timeout(self.timeout);
120
121 Ok(builder)
122 }
123}
124
/// Builds `PreparedRequest`s for the AI REST endpoints from one fixed set of
/// `ApiSettings` (project, backend, credentials).
#[derive(Clone, Debug)]
pub(crate) struct RequestFactory {
    /// Settings used for the request path and headers of every request built.
    settings: ApiSettings,
}
129
130impl RequestFactory {
131 pub fn new(settings: ApiSettings) -> Self {
132 Self { settings }
133 }
134
135 pub fn construct_request(
136 &self,
137 model: &str,
138 task: Task,
139 stream: bool,
140 body: Value,
141 request_options: Option<RequestOptions>,
142 ) -> AiResult<PreparedRequest> {
143 let options = request_options.unwrap_or_default();
144 let mut url = self.compose_base_url(&options)?;
145 let trimmed_model = model.trim_start_matches('/');
146 let model_path = match &self.settings.backend {
147 Backend::GoogleAi(_) => format!("projects/{}/{}", self.settings.project, trimmed_model),
148 Backend::VertexAi(inner) => format!(
149 "projects/{}/locations/{}/{}",
150 self.settings.project,
151 inner.location(),
152 trimmed_model
153 ),
154 };
155 let path = format!(
156 "/{}/{model_path}:{}",
157 DEFAULT_API_VERSION,
158 task.as_operation()
159 );
160 url.set_path(&path);
161 if stream {
162 url.query_pairs_mut().append_pair("alt", "sse");
163 } else {
164 url.set_query(None);
165 }
166
167 let timeout = options
168 .timeout
169 .unwrap_or_else(|| Duration::from_millis(DEFAULT_FETCH_TIMEOUT_MS));
170 let headers = self.build_headers();
171
172 Ok(PreparedRequest {
173 method: HttpMethod::Post,
174 url,
175 headers,
176 body,
177 timeout,
178 })
179 }
180
181 fn compose_base_url(&self, options: &RequestOptions) -> AiResult<Url> {
182 let base = options
183 .base_url
184 .as_ref()
185 .map(|value| value.as_str())
186 .unwrap_or(DEFAULT_DOMAIN);
187 let url = if base.starts_with("http://") || base.starts_with("https://") {
188 base.to_string()
189 } else {
190 format!("https://{base}")
191 };
192 Url::parse(&url).map_err(|err| {
193 AiError::new(
194 AiErrorCode::InvalidArgument,
195 format!("Invalid base URL '{url}': {err}"),
196 None,
197 )
198 })
199 }
200
201 fn build_headers(&self) -> Vec<(String, String)> {
202 let mut headers = Vec::with_capacity(5);
203 headers.push(("Content-Type".into(), "application/json".into()));
204 let client_header = format!(
205 "{}/{} fire/{}",
206 LANGUAGE_TAG, PACKAGE_VERSION, PACKAGE_VERSION
207 );
208 headers.push(("x-goog-api-client".into(), client_header));
209 headers.push(("x-goog-api-key".into(), self.settings.api_key.clone()));
210
211 if self.settings.automatic_data_collection_enabled && !self.settings.app_id.is_empty() {
212 headers.push(("X-Firebase-AppId".into(), self.settings.app_id.clone()));
213 }
214
215 if let Some(token) = &self.settings.app_check_token {
216 headers.push(("X-Firebase-AppCheck".into(), token.clone()));
217 }
218
219 if let Some(header) = &self.settings.app_check_heartbeat {
220 headers.push(("X-Firebase-Client".into(), header.clone()));
221 }
222
223 if let Some(token) = &self.settings.auth_token {
224 headers.push(("Authorization".into(), format!("Firebase {token}")));
225 }
226
227 headers
228 }
229}
230
/// Remote operation a request performs; it selects the `:<operation>` suffix of
/// the request URL.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Task {
    /// Content generation (`generateContent`).
    GenerateContent,
    /// Token counting (`countTokens`).
    // Presumably not constructed outside tests yet, hence the allow.
    #[allow(dead_code)]
    CountTokens,
    /// Prediction (`predict`).
    // Presumably not constructed outside tests yet, hence the allow.
    #[allow(dead_code)]
    Predict,
}

impl Task {
    /// Returns the REST method name appended after `:` in the request path.
    pub fn as_operation(&self) -> &'static str {
        let operation = match *self {
            Self::Predict => "predict",
            Self::CountTokens => "countTokens",
            Self::GenerateContent => "generateContent",
        };
        operation
    }
}
252
#[cfg(test)]
mod tests {
    use super::*;

    /// Settings with data collection enabled, a non-empty app id and no tokens.
    fn settings_with_backend(backend: Backend) -> ApiSettings {
        ApiSettings::new(
            "test-key".into(),
            "test-project".into(),
            "1:123:web:abc".into(),
            backend,
            true,
            None,
            None,
            None,
        )
    }

    #[test]
    fn constructs_google_ai_url() {
        let factory = RequestFactory::new(settings_with_backend(Backend::google_ai()));
        let req = factory
            .construct_request(
                "models/gemini-1.5-flash",
                Task::GenerateContent,
                false,
                serde_json::json!({"contents": []}),
                None,
            )
            .unwrap();
        assert_eq!(
            req.url.as_str(),
            "https://firebasevertexai.googleapis.com/v1beta/projects/test-project/models/gemini-1.5-flash:generateContent"
        );
        assert_eq!(req.header("x-goog-api-key"), Some("test-key"));
        assert_eq!(req.timeout, Duration::from_millis(DEFAULT_FETCH_TIMEOUT_MS));
    }

    #[test]
    fn constructs_vertex_ai_url_with_base_override() {
        let factory =
            RequestFactory::new(settings_with_backend(Backend::vertex_ai("europe-west4")));
        let options = RequestOptions {
            timeout: Some(Duration::from_secs(10)),
            base_url: Some("https://example.com".into()),
        };
        let req = factory
            .construct_request(
                "models/gemini-pro",
                Task::CountTokens,
                false,
                serde_json::json!({"contents": []}),
                Some(options),
            )
            .unwrap();
        assert_eq!(
            req.url.as_str(),
            "https://example.com/v1beta/projects/test-project/locations/europe-west4/models/gemini-pro:countTokens"
        );
        assert_eq!(req.timeout, Duration::from_secs(10));
    }

    #[test]
    fn invalid_base_url_returns_error() {
        let factory = RequestFactory::new(settings_with_backend(Backend::google_ai()));
        let err = factory
            .construct_request(
                "models/test",
                Task::Predict,
                false,
                serde_json::json!({}),
                Some(RequestOptions {
                    timeout: None,
                    base_url: Some("://bad".into()),
                }),
            )
            .unwrap_err();
        assert_eq!(err.code(), AiErrorCode::InvalidArgument);
    }

    #[test]
    fn streaming_request_appends_sse_query() {
        let factory = RequestFactory::new(settings_with_backend(Backend::google_ai()));
        let body = serde_json::json!({"contents": []});
        let req = factory
            .construct_request(
                "models/gemini-1.5-flash",
                Task::GenerateContent,
                true,
                body.clone(),
                None,
            )
            .unwrap();
        assert_eq!(
            req.url.as_str(),
            "https://firebasevertexai.googleapis.com/v1beta/projects/test-project/models/gemini-1.5-flash:generateContent?alt=sse"
        );
        assert_eq!(req.method, HttpMethod::Post);
        assert_eq!(req.body, body);
    }

    #[test]
    fn bare_host_base_url_defaults_to_https_and_model_slash_is_trimmed() {
        let factory = RequestFactory::new(settings_with_backend(Backend::google_ai()));
        let req = factory
            .construct_request(
                "/models/gemini-pro",
                Task::GenerateContent,
                false,
                serde_json::json!({}),
                Some(RequestOptions {
                    timeout: None,
                    base_url: Some("example.com".into()),
                }),
            )
            .unwrap();
        assert_eq!(
            req.url.as_str(),
            "https://example.com/v1beta/projects/test-project/models/gemini-pro:generateContent"
        );
    }

    #[test]
    fn includes_optional_headers_when_configured() {
        let mut settings = settings_with_backend(Backend::google_ai());
        settings.app_check_token = Some("check-token".into());
        settings.app_check_heartbeat = Some("heartbeat".into());
        settings.auth_token = Some("auth-token".into());
        let req = RequestFactory::new(settings)
            .construct_request(
                "models/test",
                Task::GenerateContent,
                false,
                serde_json::json!({}),
                None,
            )
            .unwrap();
        assert_eq!(req.header("content-type"), Some("application/json"));
        assert_eq!(req.header("x-firebase-appid"), Some("1:123:web:abc"));
        assert_eq!(req.header("x-firebase-appcheck"), Some("check-token"));
        assert_eq!(req.header("x-firebase-client"), Some("heartbeat"));
        assert_eq!(req.header("authorization"), Some("Firebase auth-token"));
    }

    #[test]
    fn omits_app_id_when_data_collection_disabled() {
        let mut settings = settings_with_backend(Backend::google_ai());
        settings.automatic_data_collection_enabled = false;
        let req = RequestFactory::new(settings)
            .construct_request(
                "models/test",
                Task::GenerateContent,
                false,
                serde_json::json!({}),
                None,
            )
            .unwrap();
        assert_eq!(req.header("X-Firebase-AppId"), None);
        assert_eq!(req.header("Authorization"), None);
        assert_eq!(req.header("x-goog-api-key"), Some("test-key"));
    }
}