1use std::collections::HashMap;
2use std::sync::{Arc, LazyLock, Mutex};
3
4use serde_json::{json, Value};
5
6use crate::ai::backend::{Backend, BackendType};
7use crate::ai::constants::AI_COMPONENT_NAME;
8use crate::ai::error::{internal_error, invalid_argument, AiError, AiErrorCode, AiResult};
9use crate::ai::helpers::{decode_instance_identifier, encode_instance_identifier};
10use crate::ai::public_types::{AiOptions, AiRuntimeOptions};
11use crate::ai::requests::{ApiSettings, PreparedRequest, RequestFactory, RequestOptions, Task};
12use crate::app;
13use crate::app::{FirebaseApp, FirebaseOptions};
14use crate::component::types::{
15 ComponentError, DynService, InstanceFactoryOptions, InstantiationMode,
16};
17use crate::component::{Component, ComponentType};
18
/// Handle to an AI service bound to one Firebase app and one backend.
///
/// Cloning is cheap: every clone shares the same inner state (including the
/// mutable runtime options) through an [`Arc`].
#[derive(Clone, Debug)]
pub struct AiService {
    inner: Arc<AiInner>,
}
23
/// State shared by all clones of an [`AiService`].
#[derive(Debug)]
struct AiInner {
    /// App whose Firebase options supply the API key, project id and app id.
    app: FirebaseApp,
    /// Backend this service targets (for example Google AI or Vertex AI).
    backend: Backend,
    /// Runtime options. Mutex-guarded because they are replaced through a shared
    /// reference each time the service is looked up again.
    options: Mutex<AiRuntimeOptions>,
    /// Model used for text generation when a request does not name one.
    default_model: Option<String>,
}
31
/// Identity of a memoized [`AiService`]: the owning app's name paired with the
/// encoded backend instance identifier.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct CacheKey {
    app_name: String,
    identifier: String,
}

impl CacheKey {
    /// Builds a key by copying both borrowed parts into owned strings.
    fn new(app_name: &str, identifier: &str) -> Self {
        let owned_app = String::from(app_name);
        let owned_identifier = identifier.to_owned();
        CacheKey {
            app_name: owned_app,
            identifier: owned_identifier,
        }
    }
}
46
47static AI_OVERRIDES: LazyLock<Mutex<HashMap<CacheKey, Arc<AiService>>>> =
48 LazyLock::new(|| Mutex::new(HashMap::new()));
49
/// Input for [`AiService::generate_text`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct GenerateTextRequest {
    /// Prompt text. It must contain at least one non-whitespace character.
    pub prompt: String,
    /// Model to use. When `None`, the service's default model is used, or a
    /// built-in fallback if the service has none.
    pub model: Option<String>,
}
55
/// Output of [`AiService::generate_text`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct GenerateTextResponse {
    /// Text produced for the request.
    pub text: String,
    /// Name of the model that was selected for the request.
    pub model: String,
}
61
62impl AiService {
63 fn new(
64 app: FirebaseApp,
65 backend: Backend,
66 options: AiRuntimeOptions,
67 default_model: Option<String>,
68 ) -> Self {
69 Self {
70 inner: Arc::new(AiInner {
71 app,
72 backend,
73 options: Mutex::new(options),
74 default_model,
75 }),
76 }
77 }
78
79 pub fn app(&self) -> &FirebaseApp {
81 &self.inner.app
82 }
83
84 pub fn backend(&self) -> &Backend {
86 &self.inner.backend
87 }
88
89 pub fn backend_type(&self) -> BackendType {
91 self.inner.backend.backend_type()
92 }
93
94 pub fn location(&self) -> Option<&str> {
96 self.inner
97 .backend
98 .as_vertex_ai()
99 .map(|backend| backend.location())
100 }
101
102 pub fn options(&self) -> AiRuntimeOptions {
104 self.inner.options.lock().unwrap().clone()
105 }
106
107 fn set_options(&self, options: AiRuntimeOptions) {
108 *self.inner.options.lock().unwrap() = options;
109 }
110
111 pub(crate) fn api_settings(&self) -> AiResult<ApiSettings> {
112 let options = self.inner.app.options();
113 let FirebaseOptions {
114 api_key,
115 project_id,
116 app_id,
117 ..
118 } = options;
119
120 let api_key = api_key.ok_or_else(|| {
121 AiError::new(
122 AiErrorCode::NoApiKey,
123 "Firebase options must include `api_key` to use Firebase AI endpoints",
124 None,
125 )
126 })?;
127 let project_id = project_id.ok_or_else(|| {
128 AiError::new(
129 AiErrorCode::NoProjectId,
130 "Firebase options must include `project_id` to use Firebase AI endpoints",
131 None,
132 )
133 })?;
134 let app_id = app_id.ok_or_else(|| {
135 AiError::new(
136 AiErrorCode::NoAppId,
137 "Firebase options must include `app_id` to use Firebase AI endpoints",
138 None,
139 )
140 })?;
141
142 let automatic = self.inner.app.automatic_data_collection_enabled();
143 Ok(ApiSettings::new(
144 api_key,
145 project_id,
146 app_id,
147 self.inner.backend.clone(),
148 automatic,
149 None,
150 None,
151 ))
152 }
153
154 pub(crate) fn request_factory(&self) -> AiResult<RequestFactory> {
155 Ok(RequestFactory::new(self.api_settings()?))
156 }
157
158 pub fn prepare_generate_content_request(
163 &self,
164 model: &str,
165 body: Value,
166 request_options: Option<RequestOptions>,
167 ) -> AiResult<PreparedRequest> {
168 let factory = self.request_factory()?;
169 factory.construct_request(model, Task::GenerateContent, false, body, request_options)
170 }
171
172 pub fn generate_text(&self, request: GenerateTextRequest) -> AiResult<GenerateTextResponse> {
173 if request.prompt.trim().is_empty() {
174 return Err(invalid_argument("Prompt must not be empty"));
175 }
176 let model = request
177 .model
178 .or_else(|| self.inner.default_model.clone())
179 .unwrap_or_else(|| "text-bison-001".to_string());
180
181 let backend_label = self.backend_type().to_string();
182 let location_suffix = self
183 .location()
184 .map(|loc| format!(" @{}", loc))
185 .unwrap_or_default();
186 let synthetic = format!(
187 "[backend:{}{}] generated {} chars",
188 backend_label,
189 location_suffix,
190 request.prompt.len()
191 );
192 Ok(GenerateTextResponse {
193 text: synthetic,
194 model,
195 })
196 }
197}
198
199#[derive(Debug)]
200struct Cache;
201
202impl Cache {
203 fn get(key: &CacheKey) -> Option<Arc<AiService>> {
204 AI_OVERRIDES.lock().unwrap().get(key).cloned()
205 }
206
207 fn insert(key: CacheKey, service: Arc<AiService>) {
208 AI_OVERRIDES.lock().unwrap().insert(key, service);
209 }
210}
211
212static AI_COMPONENT: LazyLock<()> = LazyLock::new(|| {
213 let component = Component::new(
214 AI_COMPONENT_NAME,
215 Arc::new(ai_factory),
216 ComponentType::Public,
217 )
218 .with_instantiation_mode(InstantiationMode::Lazy)
219 .with_multiple_instances(true);
220 let _ = app::registry::register_component(component);
221});
222
223fn ai_factory(
224 container: &crate::component::ComponentContainer,
225 options: InstanceFactoryOptions,
226) -> Result<DynService, ComponentError> {
227 let app = container.root_service::<FirebaseApp>().ok_or_else(|| {
228 ComponentError::InitializationFailed {
229 name: AI_COMPONENT_NAME.to_string(),
230 reason: "Firebase app not attached to component container".to_string(),
231 }
232 })?;
233
234 let identifier_backend = options
235 .instance_identifier
236 .as_deref()
237 .map(|identifier| decode_instance_identifier(identifier));
238
239 let backend = match identifier_backend {
240 Some(Ok(backend)) => backend,
241 Some(Err(err)) => {
242 return Err(ComponentError::InitializationFailed {
243 name: AI_COMPONENT_NAME.to_string(),
244 reason: err.to_string(),
245 })
246 }
247 None => {
248 if let Some(encoded) = options
249 .options
250 .get("backend")
251 .and_then(|value| value.as_str())
252 {
253 decode_instance_identifier(encoded).map_err(|err| {
254 ComponentError::InitializationFailed {
255 name: AI_COMPONENT_NAME.to_string(),
256 reason: err.to_string(),
257 }
258 })?
259 } else {
260 Backend::default()
261 }
262 }
263 };
264
265 let use_limited_tokens = options
266 .options
267 .get("useLimitedUseAppCheckTokens")
268 .and_then(|value| value.as_bool())
269 .unwrap_or(false);
270
271 let runtime_options = AiRuntimeOptions {
272 use_limited_use_app_check_tokens: use_limited_tokens,
273 };
274
275 let default_model = options
276 .options
277 .get("defaultModel")
278 .and_then(|value| value.as_str().map(|s| s.to_string()));
279
280 let service = AiService::new((*app).clone(), backend, runtime_options, default_model);
281 Ok(Arc::new(service) as DynService)
282}
283
284fn ensure_registered() {
285 LazyLock::force(&AI_COMPONENT);
286}
287
288pub fn register_ai_component() {
290 ensure_registered();
291}
292
293pub fn get_ai(app: Option<FirebaseApp>, options: Option<AiOptions>) -> AiResult<Arc<AiService>> {
320 ensure_registered();
321 let app = match app {
322 Some(app) => app,
323 None => crate::app::api::get_app(None).map_err(|err| internal_error(err.to_string()))?,
324 };
325
326 let options = options.unwrap_or_default();
327 let backend = options.backend_or_default();
328 let identifier = encode_instance_identifier(&backend);
329 let runtime_options = AiRuntimeOptions {
330 use_limited_use_app_check_tokens: options.limited_use_app_check(),
331 };
332
333 let cache_key = CacheKey::new(app.name(), &identifier);
334 if let Some(service) = Cache::get(&cache_key) {
335 service.set_options(runtime_options.clone());
336 return Ok(service);
337 }
338
339 let provider = app::registry::get_provider(&app, AI_COMPONENT_NAME);
340
341 if let Some(service) = provider
342 .get_immediate_with_options::<AiService>(Some(&identifier), true)
343 .map_err(|err| internal_error(err.to_string()))?
344 {
345 service.set_options(runtime_options.clone());
346 Cache::insert(cache_key.clone(), service.clone());
347 return Ok(service);
348 }
349
350 match provider.initialize::<AiService>(
351 json!({
352 "backend": identifier,
353 "useLimitedUseAppCheckTokens": runtime_options.use_limited_use_app_check_tokens,
354 }),
355 Some(&cache_key.identifier),
356 ) {
357 Ok(service) => {
358 service.set_options(runtime_options.clone());
359 Cache::insert(cache_key.clone(), service.clone());
360 Ok(service)
361 }
362 Err(ComponentError::InstanceUnavailable { .. }) => {
363 if let Some(service) = provider
364 .get_immediate_with_options::<AiService>(Some(&cache_key.identifier), true)
365 .map_err(|err| internal_error(err.to_string()))?
366 {
367 service.set_options(runtime_options.clone());
368 Cache::insert(cache_key.clone(), service.clone());
369 Ok(service)
370 } else {
371 let fallback =
372 Arc::new(AiService::new(app.clone(), backend, runtime_options, None));
373 Cache::insert(cache_key.clone(), fallback.clone());
374 Ok(fallback)
375 }
376 }
377 Err(err) => Err(internal_error(err.to_string())),
378 }
379}
380
381pub fn get_ai_service(app: Option<FirebaseApp>) -> AiResult<Arc<AiService>> {
383 get_ai(app, None)
384}
385
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ai::backend::Backend;
    use crate::ai::error::AiErrorCode;
    use crate::ai::public_types::AiOptions;
    use crate::app::api::initialize_app;
    use crate::app::{FirebaseAppSettings, FirebaseOptions};
    use serde_json::json;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Source of distinct app names, so each test gets its own app and therefore
    /// its own cache entries.
    static NEXT_APP_ID: AtomicUsize = AtomicUsize::new(0);

    /// Settings for a freshly named app (`ai-0`, `ai-1`, ...).
    fn unique_settings() -> FirebaseAppSettings {
        let id = NEXT_APP_ID.fetch_add(1, Ordering::SeqCst);
        FirebaseAppSettings {
            name: Some(format!("ai-{id}")),
            ..Default::default()
        }
    }

    /// Initializes a uniquely named app with `options`.
    fn new_app(options: FirebaseOptions) -> FirebaseApp {
        initialize_app(options, Some(unique_settings())).unwrap()
    }

    /// Options that carry only a project id.
    fn project_only_options() -> FirebaseOptions {
        FirebaseOptions {
            project_id: Some("project".into()),
            ..Default::default()
        }
    }

    /// Options that carry all three values the AI endpoints require.
    fn complete_options() -> FirebaseOptions {
        FirebaseOptions {
            project_id: Some("project".into()),
            api_key: Some("api".into()),
            app_id: Some("app".into()),
            ..Default::default()
        }
    }

    #[test]
    fn generate_text_includes_backend_info() {
        let ai = get_ai_service(Some(new_app(project_only_options()))).unwrap();
        let response = ai
            .generate_text(GenerateTextRequest {
                prompt: "Hello AI".to_string(),
                model: Some("text-test".to_string()),
            })
            .unwrap();
        assert_eq!(response.model, "text-test");
        assert!(response.text.contains("generated 8 chars"));
        assert!(response.text.contains("backend:GOOGLE_AI"));
    }

    #[test]
    fn empty_prompt_errors() {
        let ai = get_ai_service(Some(new_app(complete_options()))).unwrap();
        let err = ai
            .generate_text(GenerateTextRequest {
                prompt: " ".to_string(),
                model: None,
            })
            .unwrap_err();
        assert_eq!(err.code_str(), "AI/invalid-argument");
    }

    #[test]
    fn backend_identifier_creates_unique_instances() {
        let app = new_app(project_only_options());

        let google_options = AiOptions {
            backend: Some(Backend::google_ai()),
            use_limited_use_app_check_tokens: None,
        };
        let google = get_ai(Some(app.clone()), Some(google_options)).unwrap();

        let vertex_options = AiOptions {
            backend: Some(Backend::vertex_ai("europe-west4")),
            use_limited_use_app_check_tokens: Some(true),
        };
        let vertex = get_ai(Some(app.clone()), Some(vertex_options)).unwrap();

        assert_ne!(Arc::as_ptr(&google), Arc::as_ptr(&vertex));
        assert_eq!(vertex.location(), Some("europe-west4"));
        assert!(vertex.options().use_limited_use_app_check_tokens);
    }

    #[test]
    fn get_ai_reuses_cached_instance() {
        let app = new_app(complete_options());

        let first = get_ai_service(Some(app.clone())).unwrap();
        first
            .generate_text(GenerateTextRequest {
                prompt: "ping".to_string(),
                model: None,
            })
            .unwrap();

        let second = get_ai(Some(app.clone()), None).unwrap();
        assert_eq!(Arc::as_ptr(&first), Arc::as_ptr(&second));
    }

    #[test]
    fn api_settings_require_project_id() {
        let app = new_app(FirebaseOptions {
            api_key: Some("api".into()),
            app_id: Some("app".into()),
            ..Default::default()
        });
        let ai = get_ai_service(Some(app)).unwrap();
        let err = ai.api_settings().unwrap_err();
        assert_eq!(err.code(), AiErrorCode::NoProjectId);
    }

    #[test]
    fn prepare_generate_content_request_builds_expected_url() {
        let ai = get_ai_service(Some(new_app(complete_options()))).unwrap();
        let prepared = ai
            .prepare_generate_content_request(
                "models/gemini-1.5-flash",
                json!({ "contents": [] }),
                None,
            )
            .unwrap();
        assert_eq!(
            prepared.url.as_str(),
            "https://firebasevertexai.googleapis.com/v1beta/projects/project/models/gemini-1.5-flash:generateContent"
        );
        assert_eq!(prepared.header("x-goog-api-key"), Some("api"));
    }
}