firebase_rs_sdk/ai/
api.rs

1use std::collections::HashMap;
2use std::fmt;
3use std::sync::{Arc, LazyLock, Mutex};
4
5use async_trait::async_trait;
6use serde_json::{json, Value};
7
8use crate::ai::backend::{Backend, BackendType};
9use crate::ai::constants::AI_COMPONENT_NAME;
10use crate::ai::error::{
11    internal_error, invalid_argument, AiError, AiErrorCode, AiResult, CustomErrorData,
12};
13use crate::ai::helpers::{decode_instance_identifier, encode_instance_identifier};
14use crate::ai::public_types::{AiOptions, AiRuntimeOptions};
15use crate::ai::requests::{ApiSettings, PreparedRequest, RequestFactory, RequestOptions, Task};
16use crate::app;
17use crate::app::{FirebaseApp, FirebaseOptions};
18use crate::app_check::FirebaseAppCheckInternal;
19use crate::auth::Auth;
20use crate::component::types::{
21    ComponentError, DynService, InstanceFactoryOptions, InstantiationMode,
22};
23use crate::component::{Component, ComponentType, Provider};
24
/// Handle to the Firebase AI service for one (app, backend) pair.
///
/// Cloning is cheap: all state is shared behind an `Arc`.
#[derive(Clone)]
pub struct AiService {
    inner: Arc<AiInner>,
}
29
30impl fmt::Debug for AiService {
31    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
32        f.debug_struct("AiService")
33            .field("app", &self.inner.app.name())
34            .field("backend", &self.inner.backend.backend_type())
35            .finish()
36    }
37}
38
/// Shared state behind an `AiService` handle.
struct AiInner {
    // The Firebase app this service belongs to.
    app: FirebaseApp,
    // Backend configuration (Google AI or Vertex AI) for this instance.
    backend: Backend,
    // Mutable runtime options; refreshed on each `get_ai` cache hit.
    options: Mutex<AiRuntimeOptions>,
    // Model used by `generate_text` when the request does not name one.
    default_model: Option<String>,
    // Providers for the optional auth / App Check components.
    auth_provider: Provider,
    app_check_provider: Provider,
    // HTTP transport; swappable in tests via `set_transport_for_tests`.
    transport: Mutex<Arc<dyn AiHttpTransport>>,
    #[cfg(test)]
    test_tokens: Mutex<TestTokenOverrides>,
}
50
/// Key for the global service cache: one entry per (app name, backend identifier).
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct CacheKey {
    app_name: String,
    identifier: String,
}

impl CacheKey {
    /// Builds a key from a borrowed app name and encoded backend identifier.
    fn new(app_name: &str, identifier: &str) -> Self {
        CacheKey {
            app_name: app_name.to_owned(),
            identifier: identifier.to_owned(),
        }
    }
}
65
// Process-wide cache of AI service instances, keyed by (app name, backend id).
static AI_OVERRIDES: LazyLock<Mutex<HashMap<CacheKey, Arc<AiService>>>> =
    LazyLock::new(|| Mutex::new(HashMap::new()));
68
/// Abstraction over the HTTP layer so tests can stub network traffic.
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
trait AiHttpTransport: Send + Sync {
    /// Executes a prepared request and returns the decoded JSON body.
    async fn send(&self, request: PreparedRequest) -> AiResult<Value>;
}
74
/// Production transport backed by a shared `reqwest` client.
struct ReqwestTransport {
    client: reqwest::Client,
}

impl Default for ReqwestTransport {
    fn default() -> Self {
        Self {
            // Fresh client with default settings; reqwest pools connections internally.
            client: reqwest::Client::new(),
        }
    }
}
86
/// Test-only token overrides that bypass the real auth/App Check providers.
#[cfg(test)]
#[derive(Default)]
struct TestTokenOverrides {
    // Injected auth token, if any.
    auth: Option<String>,
    // Injected standard App Check token.
    app_check: Option<String>,
    // Injected limited-use App Check token.
    limited_app_check: Option<String>,
}
94
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
impl AiHttpTransport for ReqwestTransport {
    /// Sends the prepared request over HTTP and decodes the body as JSON.
    ///
    /// Error mapping:
    /// - request encoding failure -> internal error
    /// - network send / body read failure -> `FetchError`
    /// - non-success status -> `FetchError` carrying status metadata plus the
    ///   response payload (parsed JSON when possible, raw text otherwise)
    /// - success status with malformed JSON -> `ParseFailed`
    async fn send(&self, request: PreparedRequest) -> AiResult<Value> {
        let builder = request
            .into_reqwest(&self.client)
            .map_err(|err| internal_error(format!("failed to encode AI request: {err}")))?;

        let response = builder.send().await.map_err(|err| {
            AiError::new(
                AiErrorCode::FetchError,
                format!("failed to send AI request: {err}"),
                None,
            )
        })?;

        // Capture the status before consuming the response body.
        let status = response.status();
        let bytes = response.bytes().await.map_err(|err| {
            AiError::new(
                AiErrorCode::FetchError,
                format!("failed to read AI response body: {err}"),
                None,
            )
        })?;

        // Parse eagerly; error responses also try to surface their JSON payload.
        let parsed = serde_json::from_slice::<Value>(&bytes);

        if !status.is_success() {
            let mut data = CustomErrorData::default().with_status(status.as_u16());
            if let Some(reason) = status.canonical_reason() {
                data = data.with_status_text(reason);
            }

            return match parsed {
                Ok(json) => {
                    // Prefer the server-provided error message when present.
                    let message = AiService::extract_error_message(&json)
                        .unwrap_or_else(|| format!("AI endpoint returned HTTP {status}"));
                    Err(AiError::new(
                        AiErrorCode::FetchError,
                        message,
                        Some(data.with_response(json)),
                    ))
                }
                Err(_) => {
                    // Non-JSON error body: attach it verbatim under a `raw` key.
                    let raw = String::from_utf8_lossy(&bytes).to_string();
                    Err(AiError::new(
                        AiErrorCode::FetchError,
                        format!("AI endpoint returned HTTP {status}"),
                        Some(data.with_response(json!({ "raw": raw }))),
                    ))
                }
            };
        }

        parsed.map_err(|err| {
            AiError::new(
                AiErrorCode::ParseFailed,
                format!("failed to parse AI response JSON: {err}"),
                None,
            )
        })
    }
}
158
/// Parameters for a single [`AiService::generate_text`] call.
#[derive(Clone, Debug, PartialEq)]
pub struct GenerateTextRequest {
    /// Prompt text; must not be empty or whitespace-only.
    pub prompt: String,
    /// Model override; falls back to the service default, then a built-in name.
    pub model: Option<String>,
    /// Optional per-request overrides (base URL, timeout), mainly for tests.
    pub request_options: Option<RequestOptions>,
}
165
/// Result of a successful [`AiService::generate_text`] call.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct GenerateTextResponse {
    /// First non-blank text extracted from the backend response.
    pub text: String,
    /// Model name the request was actually issued against.
    pub model: String,
}
171
172impl AiService {
    /// Builds a service handle; external callers obtain instances via `get_ai`.
    fn new(
        app: FirebaseApp,
        backend: Backend,
        options: AiRuntimeOptions,
        default_model: Option<String>,
        auth_provider: Provider,
        app_check_provider: Provider,
    ) -> Self {
        Self {
            inner: Arc::new(AiInner {
                app,
                backend,
                options: Mutex::new(options),
                default_model,
                auth_provider,
                app_check_provider,
                // Production transport by default; tests may swap it out.
                transport: Mutex::new(Arc::new(ReqwestTransport::default())),
                #[cfg(test)]
                test_tokens: Mutex::new(TestTokenOverrides::default()),
            }),
        }
    }
195
    /// Returns the Firebase app associated with this AI service.
    pub fn app(&self) -> &FirebaseApp {
        &self.inner.app
    }

    /// Returns the backend configuration used by this AI service.
    pub fn backend(&self) -> &Backend {
        &self.inner.backend
    }

    /// Returns the backend type tag.
    pub fn backend_type(&self) -> BackendType {
        self.inner.backend.backend_type()
    }

    /// Returns the Vertex AI location when using that backend.
    ///
    /// Returns `None` for non-Vertex backends.
    pub fn location(&self) -> Option<&str> {
        self.inner
            .backend
            .as_vertex_ai()
            .map(|backend| backend.location())
    }

    /// Returns the runtime options currently applied to this AI service.
    pub fn options(&self) -> AiRuntimeOptions {
        self.inner.options.lock().unwrap().clone()
    }

    // Replaces the runtime options; invoked by `get_ai` on every lookup so a
    // cached instance picks up the caller's latest settings.
    fn set_options(&self, options: AiRuntimeOptions) {
        *self.inner.options.lock().unwrap() = options;
    }
227
    // Test-only: swap the HTTP transport for a stub.
    #[cfg(test)]
    fn set_transport_for_tests(&self, transport: Arc<dyn AiHttpTransport>) {
        *self.inner.transport.lock().unwrap() = transport;
    }

    // Test-only: inject auth/App Check tokens, bypassing the real providers.
    // `None` clears a previously injected token.
    #[cfg(test)]
    fn override_tokens_for_tests(
        &self,
        auth: Option<String>,
        app_check: Option<String>,
        limited_app_check: Option<String>,
    ) {
        let mut overrides = self.inner.test_tokens.lock().unwrap();
        overrides.auth = auth;
        overrides.app_check = app_check;
        overrides.limited_app_check = limited_app_check;
    }
245
    /// Fetches a current auth token, or `None` when auth is unavailable.
    ///
    /// A missing auth component and an empty token both yield `Ok(None)`; only
    /// provider-resolution or token-fetch failures surface as errors.
    async fn fetch_auth_token(&self) -> AiResult<Option<String>> {
        // Test hook: short-circuit with an injected token.
        #[cfg(test)]
        if let Some(token) = self.inner.test_tokens.lock().unwrap().auth.clone() {
            return Ok(Some(token));
        }

        let auth = match self
            .inner
            .auth_provider
            .get_immediate_with_options::<Auth>(None, true)
        {
            Ok(Some(auth)) => auth,
            // Auth component not installed: proceed without a token.
            Ok(None) => return Ok(None),
            Err(err) => {
                return Err(internal_error(format!(
                    "failed to resolve auth provider: {err}"
                )))
            }
        };

        match auth.get_token(false).await {
            // Treat an empty token the same as no token at all.
            Ok(Some(token)) if token.is_empty() => Ok(None),
            Ok(Some(token)) => Ok(Some(token)),
            Ok(None) => Ok(None),
            Err(err) => Err(internal_error(format!(
                "failed to obtain auth token: {err}"
            ))),
        }
    }
275
    /// Fetches an App Check token (standard or limited-use) and the optional
    /// heartbeat header value.
    ///
    /// Returns `(None, None)` when App Check is not installed or the token is
    /// empty. A failed token fetch falls back to a cached token when one exists.
    async fn fetch_app_check_credentials(
        &self,
        limited_use: bool,
    ) -> AiResult<(Option<String>, Option<String>)> {
        // Test hook: injected tokens bypass the real provider entirely.
        #[cfg(test)]
        {
            let overrides = self.inner.test_tokens.lock().unwrap();
            if limited_use {
                if let Some(token) = overrides.limited_app_check.clone() {
                    return Ok((Some(token), None));
                }
            } else if let Some(token) = overrides.app_check.clone() {
                return Ok((Some(token), None));
            }
        }

        let app_check = match self
            .inner
            .app_check_provider
            .get_immediate_with_options::<FirebaseAppCheckInternal>(None, true)
        {
            Ok(Some(app_check)) => app_check,
            // App Check not installed: proceed without credentials.
            Ok(None) => return Ok((None, None)),
            Err(err) => {
                return Err(internal_error(format!(
                    "failed to resolve App Check provider: {err}"
                )))
            }
        };

        let token = match if limited_use {
            app_check.get_limited_use_token().await
        } else {
            app_check.get_token(false).await
        } {
            Ok(result) => Ok(result.token),
            // On failure, fall back to a previously cached token if available.
            Err(err) => err
                .cached_token()
                .map(|cached| cached.token.clone())
                .ok_or_else(|| internal_error(format!("failed to obtain App Check token: {err}"))),
        }?;

        if token.is_empty() {
            return Ok((None, None));
        }

        let heartbeat = app_check.heartbeat_header().await.map_err(|err| {
            internal_error(format!(
                "failed to obtain App Check heartbeat header: {err}"
            ))
        })?;

        Ok((Some(token), heartbeat))
    }
330
    /// Assembles the `ApiSettings` used to build REST requests.
    ///
    /// Requires `api_key`, `project_id`, and `app_id` in the Firebase options;
    /// each missing field maps to its dedicated error code. Auth and App Check
    /// credentials are fetched fresh on every call.
    pub(crate) async fn api_settings(&self) -> AiResult<ApiSettings> {
        let options = self.inner.app.options();
        let FirebaseOptions {
            api_key,
            project_id,
            app_id,
            ..
        } = options;

        let api_key = api_key.ok_or_else(|| {
            AiError::new(
                AiErrorCode::NoApiKey,
                "Firebase options must include `api_key` to use Firebase AI endpoints",
                None,
            )
        })?;
        let project_id = project_id.ok_or_else(|| {
            AiError::new(
                AiErrorCode::NoProjectId,
                "Firebase options must include `project_id` to use Firebase AI endpoints",
                None,
            )
        })?;
        let app_id = app_id.ok_or_else(|| {
            AiError::new(
                AiErrorCode::NoAppId,
                "Firebase options must include `app_id` to use Firebase AI endpoints",
                None,
            )
        })?;

        let runtime_options = self.options();
        let automatic = self.inner.app.automatic_data_collection_enabled();
        // Limited-use App Check tokens are selected via the runtime options.
        let (app_check_token, app_check_heartbeat) = self
            .fetch_app_check_credentials(runtime_options.use_limited_use_app_check_tokens)
            .await?;
        let auth_token = self.fetch_auth_token().await?;

        Ok(ApiSettings::new(
            api_key,
            project_id,
            app_id,
            self.inner.backend.clone(),
            automatic,
            app_check_token,
            app_check_heartbeat,
            auth_token,
        ))
    }
380
    /// Builds a `RequestFactory` bound to freshly resolved API settings/tokens.
    pub(crate) async fn request_factory(&self) -> AiResult<RequestFactory> {
        Ok(RequestFactory::new(self.api_settings().await?))
    }
384
    /// Prepares a REST request for a `generateContent` call without executing it.
    ///
    /// This mirrors the behaviour of `constructRequest` in the TypeScript SDK and allows advanced
    /// callers to integrate with custom HTTP stacks while the SDK handles URL/header generation.
    ///
    /// `model` is the model name, `body` the JSON request payload, and
    /// `request_options` optional base-URL/timeout overrides.
    pub async fn prepare_generate_content_request(
        &self,
        model: &str,
        body: Value,
        request_options: Option<RequestOptions>,
    ) -> AiResult<PreparedRequest> {
        let factory = self.request_factory().await?;
        factory.construct_request(model, Task::GenerateContent, false, body, request_options)
    }
398
399    /// Generates text using the configured backend.
400    ///
401    /// This issues a `generateContent` REST call against the active backend, attaching
402    /// auth and App Check credentials when available. The optional
403    /// [`RequestOptions`] can override the base URL or timeout, which is primarily
404    /// intended for tests and emulator scenarios.
405    ///
406    /// # Examples
407    ///
408    /// ```no_run
409    /// # use firebase_rs_sdk::ai::{AiService, GenerateTextRequest};
410    /// # async fn example(ai: AiService) -> firebase_rs_sdk::ai::error::AiResult<()> {
411    /// let response = ai
412    ///     .generate_text(GenerateTextRequest {
413    ///         prompt: "Hello Gemini".to_owned(),
414    ///         model: None,
415    ///         request_options: None,
416    ///     })
417    ///     .await?;
418    /// # Ok(())
419    /// # }
420    /// ```
421    pub async fn generate_text(
422        &self,
423        request: GenerateTextRequest,
424    ) -> AiResult<GenerateTextResponse> {
425        if request.prompt.trim().is_empty() {
426            return Err(invalid_argument("Prompt must not be empty"));
427        }
428        let model = request
429            .model
430            .or_else(|| self.inner.default_model.clone())
431            .unwrap_or_else(|| "text-bison-001".to_string());
432
433        let body = Self::build_generate_text_body(&request.prompt);
434        let prepared = self
435            .prepare_generate_content_request(&model, body, request.request_options.clone())
436            .await?;
437        let response = self.execute_prepared_request(prepared).await?;
438        let text = match Self::extract_text_from_response(&response) {
439            Some(text) => text,
440            None => {
441                return Err(AiError::new(
442                    AiErrorCode::ResponseError,
443                    "AI response did not contain textual content",
444                    Some(CustomErrorData::default().with_response(response)),
445                ))
446            }
447        };
448
449        Ok(GenerateTextResponse { text, model })
450    }
451
    /// Runs a prepared request on the current transport.
    ///
    /// The transport `Arc` is cloned out of the mutex first so the lock guard
    /// is dropped before the `.await` point.
    async fn execute_prepared_request(&self, prepared: PreparedRequest) -> AiResult<Value> {
        let transport = self.inner.transport.lock().unwrap().clone();
        transport.send(prepared).await
    }
456
    /// Builds the minimal `generateContent` body: one user turn containing a
    /// single text part.
    fn build_generate_text_body(prompt: &str) -> Value {
        json!({
            "contents": [
                {
                    "role": "user",
                    "parts": [
                        {
                            "text": prompt,
                        }
                    ]
                }
            ]
        })
    }
471
472    fn extract_text_from_response(response: &Value) -> Option<String> {
473        if let Some(candidates) = response
474            .get("candidates")
475            .and_then(|value| value.as_array())
476        {
477            for candidate in candidates {
478                if let Some(text) = Self::extract_text_from_candidate(candidate) {
479                    if !text.trim().is_empty() {
480                        return Some(text);
481                    }
482                }
483            }
484        }
485
486        response
487            .get("output")
488            .and_then(|value| value.as_str())
489            .map(|value| value.to_string())
490    }
491
492    fn extract_text_from_candidate(candidate: &Value) -> Option<String> {
493        if let Some(content) = candidate.get("content") {
494            if let Some(parts) = content.get("parts").and_then(|value| value.as_array()) {
495                for part in parts {
496                    if let Some(text) = part.get("text").and_then(|value| value.as_str()) {
497                        if !text.is_empty() {
498                            return Some(text.to_string());
499                        }
500                    }
501                }
502            }
503        }
504
505        candidate
506            .get("output")
507            .and_then(|value| value.as_str())
508            .map(|value| value.to_string())
509    }
510
511    fn extract_error_message(value: &Value) -> Option<String> {
512        if let Some(error) = value.get("error") {
513            if let Some(message) = error.get("message").and_then(|v| v.as_str()) {
514                return Some(message.to_string());
515            }
516        }
517
518        value
519            .get("message")
520            .and_then(|v| v.as_str())
521            .map(|message| message.to_string())
522    }
523}
524
/// Thin facade over the global `AI_OVERRIDES` service cache.
#[derive(Debug)]
struct Cache;

impl Cache {
    // Looks up a cached service for the given (app, backend) key.
    fn get(key: &CacheKey) -> Option<Arc<AiService>> {
        AI_OVERRIDES.lock().unwrap().get(key).cloned()
    }

    // Stores (or replaces) the cached service for the key.
    fn insert(key: CacheKey, service: Arc<AiService>) {
        AI_OVERRIDES.lock().unwrap().insert(key, service);
    }
}
537
// One-time registration of the AI component with the global app registry;
// forced via `ensure_registered` before any lookup. Multiple instances are
// enabled so each backend identifier gets its own service.
static AI_COMPONENT: LazyLock<()> = LazyLock::new(|| {
    let component = Component::new(
        AI_COMPONENT_NAME,
        Arc::new(ai_factory),
        ComponentType::Public,
    )
    .with_instantiation_mode(InstantiationMode::Lazy)
    .with_multiple_instances(true);
    let _ = app::register_component(component);
});
548
/// Component factory invoked by the framework to build an `AiService`.
///
/// The backend is resolved in priority order: the encoded instance identifier,
/// then an encoded `backend` entry in the factory options, then the default.
fn ai_factory(
    container: &crate::component::ComponentContainer,
    options: InstanceFactoryOptions,
) -> Result<DynService, ComponentError> {
    let app = container.root_service::<FirebaseApp>().ok_or_else(|| {
        ComponentError::InitializationFailed {
            name: AI_COMPONENT_NAME.to_string(),
            reason: "Firebase app not attached to component container".to_string(),
        }
    })?;

    // Identifier supplied by the component framework, if any.
    let identifier_backend = options
        .instance_identifier
        .as_deref()
        .map(|identifier| decode_instance_identifier(identifier));

    let auth_provider = container.get_provider("auth-internal");
    let app_check_provider = container.get_provider("app-check-internal");

    let backend = match identifier_backend {
        Some(Ok(backend)) => backend,
        Some(Err(err)) => {
            return Err(ComponentError::InitializationFailed {
                name: AI_COMPONENT_NAME.to_string(),
                reason: err.to_string(),
            })
        }
        None => {
            // Fall back to an encoded `backend` entry in the options payload.
            if let Some(encoded) = options
                .options
                .get("backend")
                .and_then(|value| value.as_str())
            {
                decode_instance_identifier(encoded).map_err(|err| {
                    ComponentError::InitializationFailed {
                        name: AI_COMPONENT_NAME.to_string(),
                        reason: err.to_string(),
                    }
                })?
            } else {
                Backend::default()
            }
        }
    };

    let use_limited_tokens = options
        .options
        .get("useLimitedUseAppCheckTokens")
        .and_then(|value| value.as_bool())
        .unwrap_or(false);

    let runtime_options = AiRuntimeOptions {
        use_limited_use_app_check_tokens: use_limited_tokens,
    };

    let default_model = options
        .options
        .get("defaultModel")
        .and_then(|value| value.as_str().map(|s| s.to_string()));

    let service = AiService::new(
        (*app).clone(),
        backend,
        runtime_options,
        default_model,
        auth_provider,
        app_check_provider,
    );
    Ok(Arc::new(service) as DynService)
}
619
// Forces the one-time component registration.
fn ensure_registered() {
    LazyLock::force(&AI_COMPONENT);
}

/// Registers the AI component in the global registry.
///
/// Safe to call multiple times; registration happens at most once.
pub fn register_ai_component() {
    ensure_registered();
}
628
629/// Returns an AI service instance, mirroring the JavaScript `getAI()` API.
630///
631/// When `options` is provided the backend identifier is encoded using the same
632/// rules as `encodeInstanceIdentifier` from the JavaScript SDK so that separate
633/// backend configurations create independent service instances.
634///
635/// # Examples
636///
637/// ```
638/// # use firebase_rs_sdk::ai::backend::Backend;
639/// # use firebase_rs_sdk::ai::public_types::AiOptions;
640/// # use firebase_rs_sdk::ai::get_ai;
641/// # use firebase_rs_sdk::app::initialize_app;
642/// # use firebase_rs_sdk::app::{FirebaseAppSettings, FirebaseOptions};
643/// # async fn example() {
644/// let options = FirebaseOptions {
645///     project_id: Some("project".into()),
646///     api_key: Some("test".into()),
647///     ..Default::default()
648/// };
649/// let app = initialize_app(options, Some(FirebaseAppSettings::default())).await.unwrap();
650/// let ai = get_ai(
651///     Some(app),
652///     Some(AiOptions {
653///         backend: Some(Backend::vertex_ai("us-central1")),
654///         use_limited_use_app_check_tokens: Some(false),
655///     }),
656/// )
657/// .await
658/// .unwrap();
659/// # }
660/// ```
661pub async fn get_ai(
662    app: Option<FirebaseApp>,
663    options: Option<AiOptions>,
664) -> AiResult<Arc<AiService>> {
665    ensure_registered();
666    let app = match app {
667        Some(app) => app,
668        None => crate::app::get_app(None)
669            .await
670            .map_err(|err| internal_error(err.to_string()))?,
671    };
672
673    let options = options.unwrap_or_default();
674    let backend = options.backend_or_default();
675    let identifier = encode_instance_identifier(&backend);
676    let runtime_options = AiRuntimeOptions {
677        use_limited_use_app_check_tokens: options.limited_use_app_check(),
678    };
679
680    let cache_key = CacheKey::new(app.name(), &identifier);
681    if let Some(service) = Cache::get(&cache_key) {
682        service.set_options(runtime_options.clone());
683        return Ok(service);
684    }
685
686    let provider = app::get_provider(&app, AI_COMPONENT_NAME);
687
688    if let Some(service) = provider
689        .get_immediate_with_options::<AiService>(Some(&identifier), true)
690        .map_err(|err| internal_error(err.to_string()))?
691    {
692        service.set_options(runtime_options.clone());
693        Cache::insert(cache_key.clone(), service.clone());
694        return Ok(service);
695    }
696
697    match provider.initialize::<AiService>(
698        json!({
699            "backend": identifier,
700            "useLimitedUseAppCheckTokens": runtime_options.use_limited_use_app_check_tokens,
701        }),
702        Some(&cache_key.identifier),
703    ) {
704        Ok(service) => {
705            service.set_options(runtime_options.clone());
706            Cache::insert(cache_key.clone(), service.clone());
707            Ok(service)
708        }
709        Err(ComponentError::InstanceUnavailable { .. }) => {
710            if let Some(service) = provider
711                .get_immediate_with_options::<AiService>(Some(&cache_key.identifier), true)
712                .map_err(|err| internal_error(err.to_string()))?
713            {
714                service.set_options(runtime_options.clone());
715                Cache::insert(cache_key.clone(), service.clone());
716                Ok(service)
717            } else {
718                let container = app.container();
719                let fallback = Arc::new(AiService::new(
720                    app.clone(),
721                    backend,
722                    runtime_options,
723                    None,
724                    container.get_provider("auth-internal"),
725                    container.get_provider("app-check-internal"),
726                ));
727                Cache::insert(cache_key.clone(), fallback.clone());
728                Ok(fallback)
729            }
730        }
731        Err(err) => Err(internal_error(err.to_string())),
732    }
733}
734
/// Convenience wrapper that mirrors the original Rust stub signature.
///
/// Equivalent to `get_ai(app, None)`: default backend and runtime options.
pub async fn get_ai_service(app: Option<FirebaseApp>) -> AiResult<Arc<AiService>> {
    get_ai(app, None).await
}
739
740#[cfg(test)]
741mod tests {
742    use super::*;
743    use crate::ai::backend::Backend;
744    use crate::ai::error::AiErrorCode;
745    use crate::ai::public_types::AiOptions;
746    use crate::app::initialize_app;
747    use crate::app::{FirebaseAppSettings, FirebaseOptions};
748    use async_trait::async_trait;
749    use serde_json::json;
750    use std::collections::VecDeque;
751    use std::sync::atomic::{AtomicUsize, Ordering};
752    use std::sync::Arc;
753
754    fn unique_settings() -> FirebaseAppSettings {
755        static COUNTER: AtomicUsize = AtomicUsize::new(0);
756        FirebaseAppSettings {
757            name: Some(format!("ai-{}", COUNTER.fetch_add(1, Ordering::SeqCst))),
758            ..Default::default()
759        }
760    }
761
762    #[derive(Clone, Default)]
763    struct TestTransport {
764        responses: Arc<Mutex<VecDeque<AiResult<Value>>>>,
765        requests: Arc<Mutex<Vec<PreparedRequest>>>,
766    }
767
768    impl TestTransport {
769        fn new() -> Self {
770            Self::default()
771        }
772
773        fn push_response(&self, response: AiResult<Value>) {
774            self.responses.lock().unwrap().push_back(response);
775        }
776
777        fn take_requests(&self) -> Vec<PreparedRequest> {
778            self.requests.lock().unwrap().clone()
779        }
780    }
781
782    #[cfg_attr(not(target_arch = "wasm32"), async_trait)]
783    #[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
784    impl AiHttpTransport for TestTransport {
785        async fn send(&self, request: PreparedRequest) -> AiResult<Value> {
786            self.requests.lock().unwrap().push(request);
787            self.responses
788                .lock()
789                .unwrap()
790                .pop_front()
791                .unwrap_or_else(|| Err(internal_error("no stub response configured")))
792        }
793    }
794
795    #[tokio::test(flavor = "current_thread")]
796    async fn generate_text_includes_backend_info() {
797        let transport = TestTransport::new();
798        transport.push_response(Ok(json!({
799            "candidates": [
800                {
801                    "content": {
802                        "parts": [
803                            { "text": "Hello from mock" }
804                        ]
805                    }
806                }
807            ]
808        })));
809
810        let options = FirebaseOptions {
811            project_id: Some("project".into()),
812            api_key: Some("api".into()),
813            app_id: Some("app".into()),
814            ..Default::default()
815        };
816        let app = initialize_app(options, Some(unique_settings()))
817            .await
818            .unwrap();
819        let ai = get_ai_service(Some(app)).await.unwrap();
820        ai.set_transport_for_tests(Arc::new(transport.clone()));
821        let response = ai
822            .generate_text(GenerateTextRequest {
823                prompt: "Hello AI".to_string(),
824                model: Some("models/gemini-pro".to_string()),
825                request_options: None,
826            })
827            .await
828            .unwrap();
829
830        assert_eq!(response.model, "models/gemini-pro");
831        assert_eq!(response.text, "Hello from mock");
832
833        let requests = transport.take_requests();
834        assert_eq!(requests.len(), 1);
835        assert!(requests[0]
836            .url
837            .as_str()
838            .ends_with("models/gemini-pro:generateContent"));
839        assert_eq!(requests[0].header("x-goog-api-key"), Some("api"));
840    }
841
842    #[tokio::test(flavor = "current_thread")]
843    async fn limited_use_app_check_token_attached_to_requests() {
844        let transport = TestTransport::new();
845        transport.push_response(Ok(json!({
846            "candidates": [
847                {
848                    "content": {
849                        "parts": [
850                            { "text": "Limited token response" }
851                        ]
852                    }
853                }
854            ]
855        })));
856
857        let options = FirebaseOptions {
858            project_id: Some("project".into()),
859            api_key: Some("api".into()),
860            app_id: Some("app".into()),
861            ..Default::default()
862        };
863        let app = initialize_app(options, Some(unique_settings()))
864            .await
865            .unwrap();
866
867        let ai = get_ai(
868            Some(app),
869            Some(AiOptions {
870                backend: Some(Backend::google_ai()),
871                use_limited_use_app_check_tokens: Some(true),
872            }),
873        )
874        .await
875        .unwrap();
876
877        ai.set_transport_for_tests(Arc::new(transport.clone()));
878        ai.override_tokens_for_tests(
879            None,
880            Some("standard-token".into()),
881            Some("limited-token".into()),
882        );
883
884        let response = ai
885            .generate_text(GenerateTextRequest {
886                prompt: "token test".to_string(),
887                model: Some("models/gemini-pro".to_string()),
888                request_options: None,
889            })
890            .await
891            .unwrap();
892
893        assert_eq!(response.text, "Limited token response");
894
895        let requests = transport.take_requests();
896        assert_eq!(requests.len(), 1);
897        assert_eq!(
898            requests[0].header("x-firebase-appcheck"),
899            Some("limited-token")
900        );
901    }
902
903    #[tokio::test(flavor = "current_thread")]
904    async fn empty_prompt_errors() {
905        let options = FirebaseOptions {
906            project_id: Some("project".into()),
907            api_key: Some("api".into()),
908            app_id: Some("app".into()),
909            ..Default::default()
910        };
911        let app = initialize_app(options, Some(unique_settings()))
912            .await
913            .unwrap();
914        let ai = get_ai_service(Some(app)).await.unwrap();
915        let err = ai
916            .generate_text(GenerateTextRequest {
917                prompt: "  ".to_string(),
918                model: None,
919                request_options: None,
920            })
921            .await
922            .unwrap_err();
923        assert_eq!(err.code_str(), "AI/invalid-argument");
924    }
925
926    #[tokio::test(flavor = "current_thread")]
927    async fn backend_identifier_creates_unique_instances() {
928        let options = FirebaseOptions {
929            project_id: Some("project".into()),
930            ..Default::default()
931        };
932        let app = initialize_app(options, Some(unique_settings()))
933            .await
934            .unwrap();
935
936        let google = get_ai(
937            Some(app.clone()),
938            Some(AiOptions {
939                backend: Some(Backend::google_ai()),
940                use_limited_use_app_check_tokens: None,
941            }),
942        )
943        .await
944        .unwrap();
945
946        let vertex = get_ai(
947            Some(app.clone()),
948            Some(AiOptions {
949                backend: Some(Backend::vertex_ai("europe-west4")),
950                use_limited_use_app_check_tokens: Some(true),
951            }),
952        )
953        .await
954        .unwrap();
955
956        assert_ne!(Arc::as_ptr(&google), Arc::as_ptr(&vertex));
957        assert_eq!(vertex.location(), Some("europe-west4"));
958        assert!(vertex.options().use_limited_use_app_check_tokens);
959    }
960
961    #[tokio::test(flavor = "current_thread")]
962    async fn get_ai_reuses_cached_instance() {
963        let options = FirebaseOptions {
964            project_id: Some("project".into()),
965            api_key: Some("api".into()),
966            app_id: Some("app".into()),
967            ..Default::default()
968        };
969        let app = initialize_app(options, Some(unique_settings()))
970            .await
971            .unwrap();
972
973        let first = get_ai_service(Some(app.clone())).await.unwrap();
974        first
975            .prepare_generate_content_request("models/test-model", json!({ "contents": [] }), None)
976            .await
977            .unwrap();
978
979        let second = get_ai(Some(app.clone()), None).await.unwrap();
980        assert_eq!(Arc::as_ptr(&first), Arc::as_ptr(&second));
981    }
982
983    #[tokio::test(flavor = "current_thread")]
984    async fn api_settings_require_project_id() {
985        let options = FirebaseOptions {
986            api_key: Some("api".into()),
987            app_id: Some("app".into()),
988            ..Default::default()
989        };
990        let app = initialize_app(options, Some(unique_settings()))
991            .await
992            .unwrap();
993        let ai = get_ai_service(Some(app)).await.unwrap();
994        let err = ai.api_settings().await.unwrap_err();
995        assert_eq!(err.code(), AiErrorCode::NoProjectId);
996    }
997
998    #[tokio::test(flavor = "current_thread")]
999    async fn prepare_generate_content_request_builds_expected_url() {
1000        let options = FirebaseOptions {
1001            api_key: Some("api".into()),
1002            project_id: Some("project".into()),
1003            app_id: Some("app".into()),
1004            ..Default::default()
1005        };
1006        let app = initialize_app(options, Some(unique_settings()))
1007            .await
1008            .unwrap();
1009        let ai = get_ai_service(Some(app)).await.unwrap();
1010        let prepared = ai
1011            .prepare_generate_content_request(
1012                "models/gemini-1.5-flash",
1013                json!({ "contents": [] }),
1014                None,
1015            )
1016            .await
1017            .unwrap();
1018        assert_eq!(
1019            prepared.url.as_str(),
1020            "https://firebasevertexai.googleapis.com/v1beta/projects/project/models/gemini-1.5-flash:generateContent"
1021        );
1022        assert_eq!(prepared.header("x-goog-api-key"), Some("api"));
1023    }
1024}