1use std::time::Duration;
2
3use serde_json::Value;
4use url::Url;
5
6use crate::ai::backend::Backend;
7use crate::ai::constants::{
8 DEFAULT_API_VERSION, DEFAULT_DOMAIN, DEFAULT_FETCH_TIMEOUT_MS, LANGUAGE_TAG, PACKAGE_VERSION,
9};
10use crate::ai::error::{AiError, AiErrorCode, AiResult};
11
12#[derive(Clone, Debug)]
16pub(crate) struct ApiSettings {
17 pub api_key: String,
18 pub project: String,
19 pub app_id: String,
20 pub backend: Backend,
21 pub automatic_data_collection_enabled: bool,
22 pub app_check_token: Option<String>,
23 pub auth_token: Option<String>,
24}
25
26impl ApiSettings {
27 pub fn new(
28 api_key: String,
29 project: String,
30 app_id: String,
31 backend: Backend,
32 automatic_data_collection_enabled: bool,
33 app_check_token: Option<String>,
34 auth_token: Option<String>,
35 ) -> Self {
36 Self {
37 api_key,
38 project,
39 app_id,
40 backend,
41 automatic_data_collection_enabled,
42 app_check_token,
43 auth_token,
44 }
45 }
46}
47
/// Optional per-request overrides.
///
/// A field left as `None` makes the request factory fall back to its built-in default.
#[derive(Debug, Default, PartialEq, Clone)]
pub struct RequestOptions {
    /// Maximum duration of the request; `None` selects the default fetch timeout.
    pub timeout: Option<Duration>,
    /// Alternative endpoint; a value without an `http://` / `https://` prefix is
    /// treated as an HTTPS host.
    pub base_url: Option<String>,
}
58
/// HTTP verbs a [`PreparedRequest`] can carry.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum HttpMethod {
    /// `POST` — the verb used for every request built by the request factory.
    Post,
}
63
64#[derive(Clone, Debug, PartialEq)]
66pub struct PreparedRequest {
67 pub method: HttpMethod,
68 pub url: Url,
69 pub headers: Vec<(String, String)>,
70 pub body: Value,
71 pub timeout: Duration,
72}
73
74impl PreparedRequest {
75 pub fn header(&self, name: &str) -> Option<&str> {
77 self.headers
78 .iter()
79 .find(|(header_name, _)| header_name.eq_ignore_ascii_case(name))
80 .map(|(_, value)| value.as_str())
81 }
82
83 #[cfg(feature = "ai-http")]
88 pub fn into_reqwest(
89 self,
90 client: &reqwest::Client,
91 ) -> Result<reqwest::RequestBuilder, AiError> {
92 use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
93
94 let mut headers = HeaderMap::new();
95 for (name, value) in &self.headers {
96 let header_name = HeaderName::from_bytes(name.as_bytes()).map_err(|err| {
97 AiError::new(
98 AiErrorCode::InvalidArgument,
99 format!("Invalid header name '{name}': {err}"),
100 None,
101 )
102 })?;
103 let header_value = HeaderValue::from_str(value).map_err(|err| {
104 AiError::new(
105 AiErrorCode::InvalidArgument,
106 format!("Invalid header value for '{name}': {err}"),
107 None,
108 )
109 })?;
110 headers.insert(header_name, header_value);
111 }
112
113 let builder = match self.method {
114 HttpMethod::Post => client.post(self.url.clone()),
115 }
116 .headers(headers)
117 .timeout(self.timeout)
118 .body(self.body.to_string());
119 Ok(builder)
120 }
121}
122
123#[derive(Clone, Debug)]
124pub(crate) struct RequestFactory {
125 settings: ApiSettings,
126}
127
128impl RequestFactory {
129 pub fn new(settings: ApiSettings) -> Self {
130 Self { settings }
131 }
132
133 pub fn construct_request(
134 &self,
135 model: &str,
136 task: Task,
137 stream: bool,
138 body: Value,
139 request_options: Option<RequestOptions>,
140 ) -> AiResult<PreparedRequest> {
141 let options = request_options.unwrap_or_default();
142 let mut url = self.compose_base_url(&options)?;
143 let trimmed_model = model.trim_start_matches('/');
144 let model_path = match &self.settings.backend {
145 Backend::GoogleAi(_) => format!("projects/{}/{}", self.settings.project, trimmed_model),
146 Backend::VertexAi(inner) => format!(
147 "projects/{}/locations/{}/{}",
148 self.settings.project,
149 inner.location(),
150 trimmed_model
151 ),
152 };
153 let path = format!(
154 "/{}/{model_path}:{}",
155 DEFAULT_API_VERSION,
156 task.as_operation()
157 );
158 url.set_path(&path);
159 if stream {
160 url.query_pairs_mut().append_pair("alt", "sse");
161 } else {
162 url.set_query(None);
163 }
164
165 let timeout = options
166 .timeout
167 .unwrap_or_else(|| Duration::from_millis(DEFAULT_FETCH_TIMEOUT_MS));
168 let headers = self.build_headers();
169
170 Ok(PreparedRequest {
171 method: HttpMethod::Post,
172 url,
173 headers,
174 body,
175 timeout,
176 })
177 }
178
179 fn compose_base_url(&self, options: &RequestOptions) -> AiResult<Url> {
180 let base = options
181 .base_url
182 .as_ref()
183 .map(|value| value.as_str())
184 .unwrap_or(DEFAULT_DOMAIN);
185 let url = if base.starts_with("http://") || base.starts_with("https://") {
186 base.to_string()
187 } else {
188 format!("https://{base}")
189 };
190 Url::parse(&url).map_err(|err| {
191 AiError::new(
192 AiErrorCode::InvalidArgument,
193 format!("Invalid base URL '{url}': {err}"),
194 None,
195 )
196 })
197 }
198
199 fn build_headers(&self) -> Vec<(String, String)> {
200 let mut headers = Vec::with_capacity(5);
201 headers.push(("Content-Type".into(), "application/json".into()));
202 let client_header = format!(
203 "{}/{} fire/{}",
204 LANGUAGE_TAG, PACKAGE_VERSION, PACKAGE_VERSION
205 );
206 headers.push(("x-goog-api-client".into(), client_header));
207 headers.push(("x-goog-api-key".into(), self.settings.api_key.clone()));
208
209 if self.settings.automatic_data_collection_enabled && !self.settings.app_id.is_empty() {
210 headers.push(("X-Firebase-AppId".into(), self.settings.app_id.clone()));
211 }
212
213 if let Some(token) = &self.settings.app_check_token {
214 headers.push(("X-Firebase-AppCheck".into(), token.clone()));
215 }
216
217 if let Some(token) = &self.settings.auth_token {
218 headers.push(("Authorization".into(), format!("Firebase {token}")));
219 }
220
221 headers
222 }
223}
224
/// Model operations a request can target.
///
/// Each variant maps to the operation name placed after `:` in the request path.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum Task {
    /// Operation name `generateContent`.
    GenerateContent,
    /// Operation name `countTokens`.
    CountTokens,
    /// Operation name `predict`.
    Predict,
}

impl Task {
    /// Returns the operation name appended to the model path, e.g. `countTokens`.
    pub fn as_operation(&self) -> &'static str {
        match *self {
            Self::GenerateContent => "generateContent",
            Self::CountTokens => "countTokens",
            Self::Predict => "predict",
        }
    }
}
244
#[cfg(test)]
mod tests {
    // Unit tests for URL, header and timeout construction.
    use super::*;

    /// Settings for `test-project` with data collection enabled and no tokens.
    fn settings_with_backend(backend: Backend) -> ApiSettings {
        ApiSettings::new(
            "test-key".into(),
            "test-project".into(),
            "1:123:web:abc".into(),
            backend,
            true,
            None,
            None,
        )
    }

    #[test]
    fn constructs_google_ai_url() {
        let factory = RequestFactory::new(settings_with_backend(Backend::google_ai()));
        let req = factory
            .construct_request(
                "models/gemini-1.5-flash",
                Task::GenerateContent,
                false,
                serde_json::json!({"contents": []}),
                None,
            )
            .unwrap();
        assert_eq!(
            req.url.as_str(),
            "https://firebasevertexai.googleapis.com/v1beta/projects/test-project/models/gemini-1.5-flash:generateContent"
        );
        assert_eq!(req.header("x-goog-api-key"), Some("test-key"));
        assert_eq!(req.timeout, Duration::from_millis(DEFAULT_FETCH_TIMEOUT_MS));
    }

    #[test]
    fn constructs_vertex_ai_url_with_base_override() {
        let factory =
            RequestFactory::new(settings_with_backend(Backend::vertex_ai("europe-west4")));
        let options = RequestOptions {
            timeout: Some(Duration::from_secs(10)),
            base_url: Some("https://example.com".into()),
        };
        let req = factory
            .construct_request(
                "models/gemini-pro",
                Task::CountTokens,
                false,
                serde_json::json!({"contents": []}),
                Some(options),
            )
            .unwrap();
        assert_eq!(
            req.url.as_str(),
            "https://example.com/v1beta/projects/test-project/locations/europe-west4/models/gemini-pro:countTokens"
        );
        assert_eq!(req.timeout, Duration::from_secs(10));
    }

    #[test]
    fn invalid_base_url_returns_error() {
        let factory = RequestFactory::new(settings_with_backend(Backend::google_ai()));
        let err = factory
            .construct_request(
                "models/test",
                Task::Predict,
                false,
                serde_json::json!({}),
                Some(RequestOptions {
                    timeout: None,
                    base_url: Some("://bad".into()),
                }),
            )
            .unwrap_err();
        assert_eq!(err.code(), AiErrorCode::InvalidArgument);
    }

    #[test]
    fn streaming_request_sets_alt_sse_query() {
        let factory = RequestFactory::new(settings_with_backend(Backend::google_ai()));
        let req = factory
            .construct_request(
                "models/gemini-1.5-flash",
                Task::GenerateContent,
                true,
                serde_json::json!({"contents": []}),
                None,
            )
            .unwrap();
        assert_eq!(req.method, HttpMethod::Post);
        assert_eq!(req.url.query(), Some("alt=sse"));
        assert_eq!(
            req.url.path(),
            "/v1beta/projects/test-project/models/gemini-1.5-flash:generateContent"
        );
    }

    #[test]
    fn non_streaming_request_has_no_query() {
        let factory = RequestFactory::new(settings_with_backend(Backend::google_ai()));
        let req = factory
            .construct_request(
                "models/gemini-1.5-flash",
                Task::GenerateContent,
                false,
                serde_json::json!({"contents": []}),
                None,
            )
            .unwrap();
        assert_eq!(req.url.query(), None);
    }

    #[test]
    fn scheme_less_base_url_uses_https_and_leading_slash_is_ignored() {
        let factory = RequestFactory::new(settings_with_backend(Backend::google_ai()));
        let req = factory
            .construct_request(
                "/models/gemini-pro",
                Task::CountTokens,
                false,
                serde_json::json!({}),
                Some(RequestOptions {
                    timeout: None,
                    base_url: Some("example.com".into()),
                }),
            )
            .unwrap();
        assert_eq!(
            req.url.as_str(),
            "https://example.com/v1beta/projects/test-project/models/gemini-pro:countTokens"
        );
    }

    #[test]
    fn includes_identity_and_token_headers() {
        let settings = ApiSettings::new(
            "test-key".into(),
            "test-project".into(),
            "1:123:web:abc".into(),
            Backend::google_ai(),
            true,
            Some("app-check-token".into()),
            Some("auth-token".into()),
        );
        let req = RequestFactory::new(settings)
            .construct_request(
                "models/test",
                Task::Predict,
                false,
                serde_json::json!({}),
                None,
            )
            .unwrap();
        assert_eq!(req.header("content-type"), Some("application/json"));
        assert_eq!(req.header("X-Firebase-AppId"), Some("1:123:web:abc"));
        assert_eq!(req.header("x-firebase-appcheck"), Some("app-check-token"));
        assert_eq!(req.header("authorization"), Some("Firebase auth-token"));
        assert_eq!(req.body, serde_json::json!({}));
    }

    #[test]
    fn omits_optional_headers_when_not_configured() {
        let mut settings = settings_with_backend(Backend::google_ai());
        settings.automatic_data_collection_enabled = false;
        let req = RequestFactory::new(settings)
            .construct_request(
                "models/test",
                Task::Predict,
                false,
                serde_json::json!({}),
                None,
            )
            .unwrap();
        assert_eq!(req.header("X-Firebase-AppId"), None);
        assert_eq!(req.header("X-Firebase-AppCheck"), None);
        assert_eq!(req.header("Authorization"), None);
    }
}