firebase_rs_sdk/ai/
api.rs

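//! Public entry points for the Firebase AI component: component registration,
//! `get_ai` lookup with per-backend instance caching, and helpers for preparing
//! Firebase AI REST requests.
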
use std::collections::HashMap;
use std::sync::{Arc, LazyLock, Mutex};

use serde_json::{json, Value};

use crate::ai::backend::{Backend, BackendType};
use crate::ai::constants::AI_COMPONENT_NAME;
use crate::ai::error::{internal_error, invalid_argument, AiError, AiErrorCode, AiResult};
use crate::ai::helpers::{decode_instance_identifier, encode_instance_identifier};
use crate::ai::public_types::{AiOptions, AiRuntimeOptions};
use crate::ai::requests::{ApiSettings, PreparedRequest, RequestFactory, RequestOptions, Task};
use crate::app;
use crate::app::{FirebaseApp, FirebaseOptions};
use crate::component::types::{
    ComponentError, DynService, InstanceFactoryOptions, InstantiationMode,
};
use crate::component::{Component, ComponentType};

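/// Firebase AI service handle, bound to a [`FirebaseApp`] and a specific
/// [`Backend`] configuration.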
#[derive(Clone, Debug)]
pub struct AiService {
    inner: Arc<AiInner>,
}

#[derive(Debug)]
struct AiInner {
    app: FirebaseApp,
    backend: Backend,
    options: Mutex<AiRuntimeOptions>,
    default_model: Option<String>,
}

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct CacheKey {
    app_name: String,
    identifier: String,
}

impl CacheKey {
    fn new(app_name: &str, identifier: &str) -> Self {
        Self {
            app_name: app_name.to_string(),
            identifier: identifier.to_string(),
        }
    }
}

static AI_OVERRIDES: LazyLock<Mutex<HashMap<CacheKey, Arc<AiService>>>> =
    LazyLock::new(|| Mutex::new(HashMap::new()));

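/// Input to [`AiService::generate_text`].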
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct GenerateTextRequest {
    pub prompt: String,
    pub model: Option<String>,
}

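/// Output of [`AiService::generate_text`].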
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct GenerateTextResponse {
    pub text: String,
    pub model: String,
}

impl AiService {
    fn new(
        app: FirebaseApp,
        backend: Backend,
        options: AiRuntimeOptions,
        default_model: Option<String>,
    ) -> Self {
        Self {
            inner: Arc::new(AiInner {
                app,
                backend,
                options: Mutex::new(options),
                default_model,
            }),
        }
    }

    /// Returns the Firebase app associated with this AI service.
    pub fn app(&self) -> &FirebaseApp {
        &self.inner.app
    }

    /// Returns the backend configuration used by this AI service.
    pub fn backend(&self) -> &Backend {
        &self.inner.backend
    }

    /// Returns the backend type tag.
    pub fn backend_type(&self) -> BackendType {
        self.inner.backend.backend_type()
    }

    /// Returns the Vertex AI location when using that backend.
    pub fn location(&self) -> Option<&str> {
        self.inner
            .backend
            .as_vertex_ai()
            .map(|backend| backend.location())
    }

    /// Returns the runtime options currently applied to this AI service.
    pub fn options(&self) -> AiRuntimeOptions {
        self.inner.options.lock().unwrap().clone()
    }

    fn set_options(&self, options: AiRuntimeOptions) {
        *self.inner.options.lock().unwrap() = options;
    }

    pub(crate) fn api_settings(&self) -> AiResult<ApiSettings> {
        let options = self.inner.app.options();
        let FirebaseOptions {
            api_key,
            project_id,
            app_id,
            ..
        } = options;

        let api_key = api_key.ok_or_else(|| {
            AiError::new(
                AiErrorCode::NoApiKey,
                "Firebase options must include `api_key` to use Firebase AI endpoints",
                None,
            )
        })?;
        let project_id = project_id.ok_or_else(|| {
            AiError::new(
                AiErrorCode::NoProjectId,
                "Firebase options must include `project_id` to use Firebase AI endpoints",
                None,
            )
        })?;
        let app_id = app_id.ok_or_else(|| {
            AiError::new(
                AiErrorCode::NoAppId,
                "Firebase options must include `app_id` to use Firebase AI endpoints",
                None,
            )
        })?;

        let automatic = self.inner.app.automatic_data_collection_enabled();
        Ok(ApiSettings::new(
            api_key,
            project_id,
            app_id,
            self.inner.backend.clone(),
            automatic,
            None,
            None,
        ))
    }

    pub(crate) fn request_factory(&self) -> AiResult<RequestFactory> {
        Ok(RequestFactory::new(self.api_settings()?))
    }

    /// Prepares a REST request for a `generateContent` call without executing it.
    ///
    /// This mirrors the behaviour of `constructRequest` in the TypeScript SDK and allows advanced
    /// callers to integrate with custom HTTP stacks while the SDK handles URL/header generation.
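    ///
    /// # Examples
    ///
    /// A minimal sketch (compiled but not run) using placeholder Firebase options and
    /// mirroring the unit test in this module; the model name and request body are
    /// illustrative only:
    ///
    /// ```no_run
    /// # use firebase_rs_sdk::ai::get_ai;
    /// # use firebase_rs_sdk::app::api::initialize_app;
    /// # use firebase_rs_sdk::app::{FirebaseAppSettings, FirebaseOptions};
    /// # use serde_json::json;
    /// let options = FirebaseOptions {
    ///     api_key: Some("test".into()),
    ///     project_id: Some("project".into()),
    ///     app_id: Some("app".into()),
    ///     ..Default::default()
    /// };
    /// let app = initialize_app(options, Some(FirebaseAppSettings::default())).unwrap();
    /// let ai = get_ai(Some(app), None).unwrap();
    /// let prepared = ai
    ///     .prepare_generate_content_request("models/gemini-1.5-flash", json!({ "contents": [] }), None)
    ///     .unwrap();
    /// // `prepared` now carries the resolved URL and headers for a custom HTTP client.
    /// ```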
    pub fn prepare_generate_content_request(
        &self,
        model: &str,
        body: Value,
        request_options: Option<RequestOptions>,
    ) -> AiResult<PreparedRequest> {
        let factory = self.request_factory()?;
        factory.construct_request(model, Task::GenerateContent, false, body, request_options)
    }

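    /// Generates a synthetic text response for `request` without calling a backend.
    ///
    /// The current implementation validates the prompt, resolves the model name
    /// (request override, then the configured default, then `text-bison-001`), and
    /// returns a locally formatted string describing the backend and prompt length.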
    pub fn generate_text(&self, request: GenerateTextRequest) -> AiResult<GenerateTextResponse> {
        if request.prompt.trim().is_empty() {
            return Err(invalid_argument("Prompt must not be empty"));
        }
        let model = request
            .model
            .or_else(|| self.inner.default_model.clone())
            .unwrap_or_else(|| "text-bison-001".to_string());

        let backend_label = self.backend_type().to_string();
        let location_suffix = self
            .location()
            .map(|loc| format!(" @{}", loc))
            .unwrap_or_default();
        let synthetic = format!(
            "[backend:{}{}] generated {} chars",
            backend_label,
            location_suffix,
            request.prompt.len()
        );
        Ok(GenerateTextResponse {
            text: synthetic,
            model,
        })
    }
}

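/// Process-wide cache of `AiService` instances keyed by app name and encoded
/// backend identifier, so repeated `get_ai` calls return the same instance.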
#[derive(Debug)]
struct Cache;

impl Cache {
    fn get(key: &CacheKey) -> Option<Arc<AiService>> {
        AI_OVERRIDES.lock().unwrap().get(key).cloned()
    }

    fn insert(key: CacheKey, service: Arc<AiService>) {
        AI_OVERRIDES.lock().unwrap().insert(key, service);
    }
}

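/// One-time registration of the AI component with the global registry; forced
/// through `ensure_registered` before any lookup.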
static AI_COMPONENT: LazyLock<()> = LazyLock::new(|| {
    let component = Component::new(
        AI_COMPONENT_NAME,
        Arc::new(ai_factory),
        ComponentType::Public,
    )
    .with_instantiation_mode(InstantiationMode::Lazy)
    .with_multiple_instances(true);
    let _ = app::registry::register_component(component);
});

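/// Component factory invoked by the registry when an AI instance is requested.
///
/// The backend is resolved from the instance identifier when present, otherwise
/// from a `"backend"` entry in the factory options, and finally falls back to
/// `Backend::default()`.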
fn ai_factory(
    container: &crate::component::ComponentContainer,
    options: InstanceFactoryOptions,
) -> Result<DynService, ComponentError> {
    let app = container.root_service::<FirebaseApp>().ok_or_else(|| {
        ComponentError::InitializationFailed {
            name: AI_COMPONENT_NAME.to_string(),
            reason: "Firebase app not attached to component container".to_string(),
        }
    })?;

    let identifier_backend = options
        .instance_identifier
        .as_deref()
        .map(|identifier| decode_instance_identifier(identifier));

    let backend = match identifier_backend {
        Some(Ok(backend)) => backend,
        Some(Err(err)) => {
            return Err(ComponentError::InitializationFailed {
                name: AI_COMPONENT_NAME.to_string(),
                reason: err.to_string(),
            })
        }
        None => {
            if let Some(encoded) = options
                .options
                .get("backend")
                .and_then(|value| value.as_str())
            {
                decode_instance_identifier(encoded).map_err(|err| {
                    ComponentError::InitializationFailed {
                        name: AI_COMPONENT_NAME.to_string(),
                        reason: err.to_string(),
                    }
                })?
            } else {
                Backend::default()
            }
        }
    };

    let use_limited_tokens = options
        .options
        .get("useLimitedUseAppCheckTokens")
        .and_then(|value| value.as_bool())
        .unwrap_or(false);

    let runtime_options = AiRuntimeOptions {
        use_limited_use_app_check_tokens: use_limited_tokens,
    };

    let default_model = options
        .options
        .get("defaultModel")
        .and_then(|value| value.as_str().map(|s| s.to_string()));

    let service = AiService::new((*app).clone(), backend, runtime_options, default_model);
    Ok(Arc::new(service) as DynService)
}

fn ensure_registered() {
    LazyLock::force(&AI_COMPONENT);
}

/// Registers the AI component in the global registry.
pub fn register_ai_component() {
    ensure_registered();
}

/// Returns an AI service instance, mirroring the JavaScript `getAI()` API.
///
/// When `options` is provided the backend identifier is encoded using the same
/// rules as `encodeInstanceIdentifier` from the JavaScript SDK so that separate
/// backend configurations create independent service instances.
///
/// # Examples
///
/// ```
/// # use firebase_rs_sdk::ai::backend::Backend;
/// # use firebase_rs_sdk::ai::public_types::AiOptions;
/// # use firebase_rs_sdk::ai::get_ai;
/// # use firebase_rs_sdk::app::api::initialize_app;
/// # use firebase_rs_sdk::app::{FirebaseAppSettings, FirebaseOptions};
/// let options = FirebaseOptions {
///     project_id: Some("project".into()),
///     api_key: Some("test".into()),
///     ..Default::default()
/// };
/// let app = initialize_app(options, Some(FirebaseAppSettings::default())).unwrap();
/// let ai = get_ai(Some(app), Some(AiOptions {
///     backend: Some(Backend::vertex_ai("us-central1")),
///     use_limited_use_app_check_tokens: Some(false),
/// }));
/// assert!(ai.is_ok());
/// ```
pub fn get_ai(app: Option<FirebaseApp>, options: Option<AiOptions>) -> AiResult<Arc<AiService>> {
    ensure_registered();
    let app = match app {
        Some(app) => app,
        None => crate::app::api::get_app(None).map_err(|err| internal_error(err.to_string()))?,
    };

    let options = options.unwrap_or_default();
    let backend = options.backend_or_default();
    let identifier = encode_instance_identifier(&backend);
    let runtime_options = AiRuntimeOptions {
        use_limited_use_app_check_tokens: options.limited_use_app_check(),
    };

    let cache_key = CacheKey::new(app.name(), &identifier);
    if let Some(service) = Cache::get(&cache_key) {
        service.set_options(runtime_options.clone());
        return Ok(service);
    }

    let provider = app::registry::get_provider(&app, AI_COMPONENT_NAME);

    if let Some(service) = provider
        .get_immediate_with_options::<AiService>(Some(&identifier), true)
        .map_err(|err| internal_error(err.to_string()))?
    {
        service.set_options(runtime_options.clone());
        Cache::insert(cache_key.clone(), service.clone());
        return Ok(service);
    }

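    // Nothing cached or pre-initialized for this identifier: ask the provider to
    // initialize a new instance, falling back to a locally constructed service if
    // the component system cannot supply one.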
    match provider.initialize::<AiService>(
        json!({
            "backend": identifier,
            "useLimitedUseAppCheckTokens": runtime_options.use_limited_use_app_check_tokens,
        }),
        Some(&cache_key.identifier),
    ) {
        Ok(service) => {
            service.set_options(runtime_options.clone());
            Cache::insert(cache_key.clone(), service.clone());
            Ok(service)
        }
        Err(ComponentError::InstanceUnavailable { .. }) => {
            if let Some(service) = provider
                .get_immediate_with_options::<AiService>(Some(&cache_key.identifier), true)
                .map_err(|err| internal_error(err.to_string()))?
            {
                service.set_options(runtime_options.clone());
                Cache::insert(cache_key.clone(), service.clone());
                Ok(service)
            } else {
                let fallback =
                    Arc::new(AiService::new(app.clone(), backend, runtime_options, None));
                Cache::insert(cache_key.clone(), fallback.clone());
                Ok(fallback)
            }
        }
        Err(err) => Err(internal_error(err.to_string())),
    }
}

/// Convenience wrapper that mirrors the original Rust stub signature.
pub fn get_ai_service(app: Option<FirebaseApp>) -> AiResult<Arc<AiService>> {
    get_ai(app, None)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::ai::backend::Backend;
    use crate::ai::error::AiErrorCode;
    use crate::ai::public_types::AiOptions;
    use crate::app::api::initialize_app;
    use crate::app::{FirebaseAppSettings, FirebaseOptions};
    use serde_json::json;

    fn unique_settings() -> FirebaseAppSettings {
        use std::sync::atomic::{AtomicUsize, Ordering};
        static COUNTER: AtomicUsize = AtomicUsize::new(0);
        FirebaseAppSettings {
            name: Some(format!("ai-{}", COUNTER.fetch_add(1, Ordering::SeqCst))),
            ..Default::default()
        }
    }

    #[test]
    fn generate_text_includes_backend_info() {
        let options = FirebaseOptions {
            project_id: Some("project".into()),
            ..Default::default()
        };
        let app = initialize_app(options, Some(unique_settings())).unwrap();
        let ai = get_ai_service(Some(app)).unwrap();
        let response = ai
            .generate_text(GenerateTextRequest {
                prompt: "Hello AI".to_string(),
                model: Some("text-test".to_string()),
            })
            .unwrap();
        assert_eq!(response.model, "text-test");
        assert!(response.text.contains("generated 8 chars"));
        assert!(response.text.contains("backend:GOOGLE_AI"));
    }

    #[test]
    fn empty_prompt_errors() {
        let options = FirebaseOptions {
            project_id: Some("project".into()),
            api_key: Some("api".into()),
            app_id: Some("app".into()),
            ..Default::default()
        };
        let app = initialize_app(options, Some(unique_settings())).unwrap();
        let ai = get_ai_service(Some(app)).unwrap();
        let err = ai
            .generate_text(GenerateTextRequest {
                prompt: "  ".to_string(),
                model: None,
            })
            .unwrap_err();
        assert_eq!(err.code_str(), "AI/invalid-argument");
    }

    #[test]
    fn backend_identifier_creates_unique_instances() {
        let options = FirebaseOptions {
            project_id: Some("project".into()),
            ..Default::default()
        };
        let app = initialize_app(options, Some(unique_settings())).unwrap();

        let google = get_ai(
            Some(app.clone()),
            Some(AiOptions {
                backend: Some(Backend::google_ai()),
                use_limited_use_app_check_tokens: None,
            }),
        )
        .unwrap();

        let vertex = get_ai(
            Some(app.clone()),
            Some(AiOptions {
                backend: Some(Backend::vertex_ai("europe-west4")),
                use_limited_use_app_check_tokens: Some(true),
            }),
        )
        .unwrap();

        assert_ne!(Arc::as_ptr(&google), Arc::as_ptr(&vertex));
        assert_eq!(vertex.location(), Some("europe-west4"));
        assert!(vertex.options().use_limited_use_app_check_tokens);
    }

    #[test]
    fn get_ai_reuses_cached_instance() {
        let options = FirebaseOptions {
            project_id: Some("project".into()),
            api_key: Some("api".into()),
            app_id: Some("app".into()),
            ..Default::default()
        };
        let app = initialize_app(options, Some(unique_settings())).unwrap();

        let first = get_ai_service(Some(app.clone())).unwrap();
        first
            .generate_text(GenerateTextRequest {
                prompt: "ping".to_string(),
                model: None,
            })
            .unwrap();

        let second = get_ai(Some(app.clone()), None).unwrap();
        assert_eq!(Arc::as_ptr(&first), Arc::as_ptr(&second));
    }

    #[test]
    fn api_settings_require_project_id() {
        let options = FirebaseOptions {
            api_key: Some("api".into()),
            app_id: Some("app".into()),
            ..Default::default()
        };
        let app = initialize_app(options, Some(unique_settings())).unwrap();
        let ai = get_ai_service(Some(app)).unwrap();
        let err = ai.api_settings().unwrap_err();
        assert_eq!(err.code(), AiErrorCode::NoProjectId);
    }

    #[test]
    fn prepare_generate_content_request_builds_expected_url() {
        let options = FirebaseOptions {
            api_key: Some("api".into()),
            project_id: Some("project".into()),
            app_id: Some("app".into()),
            ..Default::default()
        };
        let app = initialize_app(options, Some(unique_settings())).unwrap();
        let ai = get_ai_service(Some(app)).unwrap();
        let prepared = ai
            .prepare_generate_content_request(
                "models/gemini-1.5-flash",
                json!({ "contents": [] }),
                None,
            )
            .unwrap();
        assert_eq!(
            prepared.url.as_str(),
            "https://firebasevertexai.googleapis.com/v1beta/projects/project/models/gemini-1.5-flash:generateContent"
        );
        assert_eq!(prepared.header("x-goog-api-key"), Some("api"));
    }
}