1use std::collections::HashMap;
2use std::fmt;
3use std::sync::{Arc, LazyLock, Mutex};
4
5use async_trait::async_trait;
6use serde_json::{json, Value};
7
8use crate::ai::backend::{Backend, BackendType};
9use crate::ai::constants::AI_COMPONENT_NAME;
10use crate::ai::error::{
11 internal_error, invalid_argument, AiError, AiErrorCode, AiResult, CustomErrorData,
12};
13use crate::ai::helpers::{decode_instance_identifier, encode_instance_identifier};
14use crate::ai::public_types::{AiOptions, AiRuntimeOptions};
15use crate::ai::requests::{ApiSettings, PreparedRequest, RequestFactory, RequestOptions, Task};
16use crate::app;
17use crate::app::{FirebaseApp, FirebaseOptions};
18use crate::app_check::FirebaseAppCheckInternal;
19use crate::auth::Auth;
20use crate::component::types::{
21 ComponentError, DynService, InstanceFactoryOptions, InstantiationMode,
22};
23use crate::component::{Component, ComponentType, Provider};
24
25#[derive(Clone)]
26pub struct AiService {
27 inner: Arc<AiInner>,
28}
29
30impl fmt::Debug for AiService {
31 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
32 f.debug_struct("AiService")
33 .field("app", &self.inner.app.name())
34 .field("backend", &self.inner.backend.backend_type())
35 .finish()
36 }
37}
38
39struct AiInner {
40 app: FirebaseApp,
41 backend: Backend,
42 options: Mutex<AiRuntimeOptions>,
43 default_model: Option<String>,
44 auth_provider: Provider,
45 app_check_provider: Provider,
46 transport: Mutex<Arc<dyn AiHttpTransport>>,
47 #[cfg(test)]
48 test_tokens: Mutex<TestTokenOverrides>,
49}
50
/// Key of the module-level service cache: one entry per (app name, encoded backend) pair.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct CacheKey {
    /// Name of the Firebase app.
    app_name: String,
    /// Encoded backend instance identifier.
    identifier: String,
}

impl CacheKey {
    /// Builds a key from borrowed parts, copying both strings.
    fn new(app_name: &str, identifier: &str) -> Self {
        let (app_name, identifier) = (app_name.to_owned(), identifier.to_owned());
        Self {
            app_name,
            identifier,
        }
    }
}
65
66static AI_OVERRIDES: LazyLock<Mutex<HashMap<CacheKey, Arc<AiService>>>> =
67 LazyLock::new(|| Mutex::new(HashMap::new()));
68
69#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
70#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
71trait AiHttpTransport: Send + Sync {
72 async fn send(&self, request: PreparedRequest) -> AiResult<Value>;
73}
74
75struct ReqwestTransport {
76 client: reqwest::Client,
77}
78
79impl Default for ReqwestTransport {
80 fn default() -> Self {
81 Self {
82 client: reqwest::Client::new(),
83 }
84 }
85}
86
/// Test-only replacements for tokens normally fetched from Auth / App Check.
///
/// A `Some` value short-circuits the corresponding provider lookup.
#[cfg(test)]
#[derive(Default)]
struct TestTokenOverrides {
    /// Returned by `fetch_auth_token` when set.
    auth: Option<String>,
    /// Returned for standard (non-limited-use) App Check requests when set.
    app_check: Option<String>,
    /// Returned when limited-use App Check tokens are requested, when set.
    limited_app_check: Option<String>,
}
94
95#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
96#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
97impl AiHttpTransport for ReqwestTransport {
98 async fn send(&self, request: PreparedRequest) -> AiResult<Value> {
99 let builder = request
100 .into_reqwest(&self.client)
101 .map_err(|err| internal_error(format!("failed to encode AI request: {err}")))?;
102
103 let response = builder.send().await.map_err(|err| {
104 AiError::new(
105 AiErrorCode::FetchError,
106 format!("failed to send AI request: {err}"),
107 None,
108 )
109 })?;
110
111 let status = response.status();
112 let bytes = response.bytes().await.map_err(|err| {
113 AiError::new(
114 AiErrorCode::FetchError,
115 format!("failed to read AI response body: {err}"),
116 None,
117 )
118 })?;
119
120 let parsed = serde_json::from_slice::<Value>(&bytes);
121
122 if !status.is_success() {
123 let mut data = CustomErrorData::default().with_status(status.as_u16());
124 if let Some(reason) = status.canonical_reason() {
125 data = data.with_status_text(reason);
126 }
127
128 return match parsed {
129 Ok(json) => {
130 let message = AiService::extract_error_message(&json)
131 .unwrap_or_else(|| format!("AI endpoint returned HTTP {status}"));
132 Err(AiError::new(
133 AiErrorCode::FetchError,
134 message,
135 Some(data.with_response(json)),
136 ))
137 }
138 Err(_) => {
139 let raw = String::from_utf8_lossy(&bytes).to_string();
140 Err(AiError::new(
141 AiErrorCode::FetchError,
142 format!("AI endpoint returned HTTP {status}"),
143 Some(data.with_response(json!({ "raw": raw }))),
144 ))
145 }
146 };
147 }
148
149 parsed.map_err(|err| {
150 AiError::new(
151 AiErrorCode::ParseFailed,
152 format!("failed to parse AI response JSON: {err}"),
153 None,
154 )
155 })
156 }
157}
158
159#[derive(Clone, Debug, PartialEq)]
160pub struct GenerateTextRequest {
161 pub prompt: String,
162 pub model: Option<String>,
163 pub request_options: Option<RequestOptions>,
164}
165
/// Result of [`AiService::generate_text`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct GenerateTextResponse {
    /// Text extracted from the model response.
    pub text: String,
    /// Name of the model the request was sent to.
    pub model: String,
}
171
172impl AiService {
173 fn new(
174 app: FirebaseApp,
175 backend: Backend,
176 options: AiRuntimeOptions,
177 default_model: Option<String>,
178 auth_provider: Provider,
179 app_check_provider: Provider,
180 ) -> Self {
181 Self {
182 inner: Arc::new(AiInner {
183 app,
184 backend,
185 options: Mutex::new(options),
186 default_model,
187 auth_provider,
188 app_check_provider,
189 transport: Mutex::new(Arc::new(ReqwestTransport::default())),
190 #[cfg(test)]
191 test_tokens: Mutex::new(TestTokenOverrides::default()),
192 }),
193 }
194 }
195
196 pub fn app(&self) -> &FirebaseApp {
198 &self.inner.app
199 }
200
201 pub fn backend(&self) -> &Backend {
203 &self.inner.backend
204 }
205
206 pub fn backend_type(&self) -> BackendType {
208 self.inner.backend.backend_type()
209 }
210
211 pub fn location(&self) -> Option<&str> {
213 self.inner
214 .backend
215 .as_vertex_ai()
216 .map(|backend| backend.location())
217 }
218
219 pub fn options(&self) -> AiRuntimeOptions {
221 self.inner.options.lock().unwrap().clone()
222 }
223
224 fn set_options(&self, options: AiRuntimeOptions) {
225 *self.inner.options.lock().unwrap() = options;
226 }
227
228 #[cfg(test)]
229 fn set_transport_for_tests(&self, transport: Arc<dyn AiHttpTransport>) {
230 *self.inner.transport.lock().unwrap() = transport;
231 }
232
233 #[cfg(test)]
234 fn override_tokens_for_tests(
235 &self,
236 auth: Option<String>,
237 app_check: Option<String>,
238 limited_app_check: Option<String>,
239 ) {
240 let mut overrides = self.inner.test_tokens.lock().unwrap();
241 overrides.auth = auth;
242 overrides.app_check = app_check;
243 overrides.limited_app_check = limited_app_check;
244 }
245
246 async fn fetch_auth_token(&self) -> AiResult<Option<String>> {
247 #[cfg(test)]
248 if let Some(token) = self.inner.test_tokens.lock().unwrap().auth.clone() {
249 return Ok(Some(token));
250 }
251
252 let auth = match self
253 .inner
254 .auth_provider
255 .get_immediate_with_options::<Auth>(None, true)
256 {
257 Ok(Some(auth)) => auth,
258 Ok(None) => return Ok(None),
259 Err(err) => {
260 return Err(internal_error(format!(
261 "failed to resolve auth provider: {err}"
262 )))
263 }
264 };
265
266 match auth.get_token(false).await {
267 Ok(Some(token)) if token.is_empty() => Ok(None),
268 Ok(Some(token)) => Ok(Some(token)),
269 Ok(None) => Ok(None),
270 Err(err) => Err(internal_error(format!(
271 "failed to obtain auth token: {err}"
272 ))),
273 }
274 }
275
276 async fn fetch_app_check_credentials(
277 &self,
278 limited_use: bool,
279 ) -> AiResult<(Option<String>, Option<String>)> {
280 #[cfg(test)]
281 {
282 let overrides = self.inner.test_tokens.lock().unwrap();
283 if limited_use {
284 if let Some(token) = overrides.limited_app_check.clone() {
285 return Ok((Some(token), None));
286 }
287 } else if let Some(token) = overrides.app_check.clone() {
288 return Ok((Some(token), None));
289 }
290 }
291
292 let app_check = match self
293 .inner
294 .app_check_provider
295 .get_immediate_with_options::<FirebaseAppCheckInternal>(None, true)
296 {
297 Ok(Some(app_check)) => app_check,
298 Ok(None) => return Ok((None, None)),
299 Err(err) => {
300 return Err(internal_error(format!(
301 "failed to resolve App Check provider: {err}"
302 )))
303 }
304 };
305
306 let token = match if limited_use {
307 app_check.get_limited_use_token().await
308 } else {
309 app_check.get_token(false).await
310 } {
311 Ok(result) => Ok(result.token),
312 Err(err) => err
313 .cached_token()
314 .map(|cached| cached.token.clone())
315 .ok_or_else(|| internal_error(format!("failed to obtain App Check token: {err}"))),
316 }?;
317
318 if token.is_empty() {
319 return Ok((None, None));
320 }
321
322 let heartbeat = app_check.heartbeat_header().await.map_err(|err| {
323 internal_error(format!(
324 "failed to obtain App Check heartbeat header: {err}"
325 ))
326 })?;
327
328 Ok((Some(token), heartbeat))
329 }
330
331 pub(crate) async fn api_settings(&self) -> AiResult<ApiSettings> {
332 let options = self.inner.app.options();
333 let FirebaseOptions {
334 api_key,
335 project_id,
336 app_id,
337 ..
338 } = options;
339
340 let api_key = api_key.ok_or_else(|| {
341 AiError::new(
342 AiErrorCode::NoApiKey,
343 "Firebase options must include `api_key` to use Firebase AI endpoints",
344 None,
345 )
346 })?;
347 let project_id = project_id.ok_or_else(|| {
348 AiError::new(
349 AiErrorCode::NoProjectId,
350 "Firebase options must include `project_id` to use Firebase AI endpoints",
351 None,
352 )
353 })?;
354 let app_id = app_id.ok_or_else(|| {
355 AiError::new(
356 AiErrorCode::NoAppId,
357 "Firebase options must include `app_id` to use Firebase AI endpoints",
358 None,
359 )
360 })?;
361
362 let runtime_options = self.options();
363 let automatic = self.inner.app.automatic_data_collection_enabled();
364 let (app_check_token, app_check_heartbeat) = self
365 .fetch_app_check_credentials(runtime_options.use_limited_use_app_check_tokens)
366 .await?;
367 let auth_token = self.fetch_auth_token().await?;
368
369 Ok(ApiSettings::new(
370 api_key,
371 project_id,
372 app_id,
373 self.inner.backend.clone(),
374 automatic,
375 app_check_token,
376 app_check_heartbeat,
377 auth_token,
378 ))
379 }
380
381 pub(crate) async fn request_factory(&self) -> AiResult<RequestFactory> {
382 Ok(RequestFactory::new(self.api_settings().await?))
383 }
384
385 pub async fn prepare_generate_content_request(
390 &self,
391 model: &str,
392 body: Value,
393 request_options: Option<RequestOptions>,
394 ) -> AiResult<PreparedRequest> {
395 let factory = self.request_factory().await?;
396 factory.construct_request(model, Task::GenerateContent, false, body, request_options)
397 }
398
399 pub async fn generate_text(
422 &self,
423 request: GenerateTextRequest,
424 ) -> AiResult<GenerateTextResponse> {
425 if request.prompt.trim().is_empty() {
426 return Err(invalid_argument("Prompt must not be empty"));
427 }
428 let model = request
429 .model
430 .or_else(|| self.inner.default_model.clone())
431 .unwrap_or_else(|| "text-bison-001".to_string());
432
433 let body = Self::build_generate_text_body(&request.prompt);
434 let prepared = self
435 .prepare_generate_content_request(&model, body, request.request_options.clone())
436 .await?;
437 let response = self.execute_prepared_request(prepared).await?;
438 let text = match Self::extract_text_from_response(&response) {
439 Some(text) => text,
440 None => {
441 return Err(AiError::new(
442 AiErrorCode::ResponseError,
443 "AI response did not contain textual content",
444 Some(CustomErrorData::default().with_response(response)),
445 ))
446 }
447 };
448
449 Ok(GenerateTextResponse { text, model })
450 }
451
452 async fn execute_prepared_request(&self, prepared: PreparedRequest) -> AiResult<Value> {
453 let transport = self.inner.transport.lock().unwrap().clone();
454 transport.send(prepared).await
455 }
456
457 fn build_generate_text_body(prompt: &str) -> Value {
458 json!({
459 "contents": [
460 {
461 "role": "user",
462 "parts": [
463 {
464 "text": prompt,
465 }
466 ]
467 }
468 ]
469 })
470 }
471
472 fn extract_text_from_response(response: &Value) -> Option<String> {
473 if let Some(candidates) = response
474 .get("candidates")
475 .and_then(|value| value.as_array())
476 {
477 for candidate in candidates {
478 if let Some(text) = Self::extract_text_from_candidate(candidate) {
479 if !text.trim().is_empty() {
480 return Some(text);
481 }
482 }
483 }
484 }
485
486 response
487 .get("output")
488 .and_then(|value| value.as_str())
489 .map(|value| value.to_string())
490 }
491
492 fn extract_text_from_candidate(candidate: &Value) -> Option<String> {
493 if let Some(content) = candidate.get("content") {
494 if let Some(parts) = content.get("parts").and_then(|value| value.as_array()) {
495 for part in parts {
496 if let Some(text) = part.get("text").and_then(|value| value.as_str()) {
497 if !text.is_empty() {
498 return Some(text.to_string());
499 }
500 }
501 }
502 }
503 }
504
505 candidate
506 .get("output")
507 .and_then(|value| value.as_str())
508 .map(|value| value.to_string())
509 }
510
511 fn extract_error_message(value: &Value) -> Option<String> {
512 if let Some(error) = value.get("error") {
513 if let Some(message) = error.get("message").and_then(|v| v.as_str()) {
514 return Some(message.to_string());
515 }
516 }
517
518 value
519 .get("message")
520 .and_then(|v| v.as_str())
521 .map(|message| message.to_string())
522 }
523}
524
525#[derive(Debug)]
526struct Cache;
527
528impl Cache {
529 fn get(key: &CacheKey) -> Option<Arc<AiService>> {
530 AI_OVERRIDES.lock().unwrap().get(key).cloned()
531 }
532
533 fn insert(key: CacheKey, service: Arc<AiService>) {
534 AI_OVERRIDES.lock().unwrap().insert(key, service);
535 }
536}
537
538static AI_COMPONENT: LazyLock<()> = LazyLock::new(|| {
539 let component = Component::new(
540 AI_COMPONENT_NAME,
541 Arc::new(ai_factory),
542 ComponentType::Public,
543 )
544 .with_instantiation_mode(InstantiationMode::Lazy)
545 .with_multiple_instances(true);
546 let _ = app::register_component(component);
547});
548
549fn ai_factory(
550 container: &crate::component::ComponentContainer,
551 options: InstanceFactoryOptions,
552) -> Result<DynService, ComponentError> {
553 let app = container.root_service::<FirebaseApp>().ok_or_else(|| {
554 ComponentError::InitializationFailed {
555 name: AI_COMPONENT_NAME.to_string(),
556 reason: "Firebase app not attached to component container".to_string(),
557 }
558 })?;
559
560 let identifier_backend = options
561 .instance_identifier
562 .as_deref()
563 .map(|identifier| decode_instance_identifier(identifier));
564
565 let auth_provider = container.get_provider("auth-internal");
566 let app_check_provider = container.get_provider("app-check-internal");
567
568 let backend = match identifier_backend {
569 Some(Ok(backend)) => backend,
570 Some(Err(err)) => {
571 return Err(ComponentError::InitializationFailed {
572 name: AI_COMPONENT_NAME.to_string(),
573 reason: err.to_string(),
574 })
575 }
576 None => {
577 if let Some(encoded) = options
578 .options
579 .get("backend")
580 .and_then(|value| value.as_str())
581 {
582 decode_instance_identifier(encoded).map_err(|err| {
583 ComponentError::InitializationFailed {
584 name: AI_COMPONENT_NAME.to_string(),
585 reason: err.to_string(),
586 }
587 })?
588 } else {
589 Backend::default()
590 }
591 }
592 };
593
594 let use_limited_tokens = options
595 .options
596 .get("useLimitedUseAppCheckTokens")
597 .and_then(|value| value.as_bool())
598 .unwrap_or(false);
599
600 let runtime_options = AiRuntimeOptions {
601 use_limited_use_app_check_tokens: use_limited_tokens,
602 };
603
604 let default_model = options
605 .options
606 .get("defaultModel")
607 .and_then(|value| value.as_str().map(|s| s.to_string()));
608
609 let service = AiService::new(
610 (*app).clone(),
611 backend,
612 runtime_options,
613 default_model,
614 auth_provider,
615 app_check_provider,
616 );
617 Ok(Arc::new(service) as DynService)
618}
619
620fn ensure_registered() {
621 LazyLock::force(&AI_COMPONENT);
622}
623
624pub fn register_ai_component() {
626 ensure_registered();
627}
628
629pub async fn get_ai(
662 app: Option<FirebaseApp>,
663 options: Option<AiOptions>,
664) -> AiResult<Arc<AiService>> {
665 ensure_registered();
666 let app = match app {
667 Some(app) => app,
668 None => crate::app::get_app(None)
669 .await
670 .map_err(|err| internal_error(err.to_string()))?,
671 };
672
673 let options = options.unwrap_or_default();
674 let backend = options.backend_or_default();
675 let identifier = encode_instance_identifier(&backend);
676 let runtime_options = AiRuntimeOptions {
677 use_limited_use_app_check_tokens: options.limited_use_app_check(),
678 };
679
680 let cache_key = CacheKey::new(app.name(), &identifier);
681 if let Some(service) = Cache::get(&cache_key) {
682 service.set_options(runtime_options.clone());
683 return Ok(service);
684 }
685
686 let provider = app::get_provider(&app, AI_COMPONENT_NAME);
687
688 if let Some(service) = provider
689 .get_immediate_with_options::<AiService>(Some(&identifier), true)
690 .map_err(|err| internal_error(err.to_string()))?
691 {
692 service.set_options(runtime_options.clone());
693 Cache::insert(cache_key.clone(), service.clone());
694 return Ok(service);
695 }
696
697 match provider.initialize::<AiService>(
698 json!({
699 "backend": identifier,
700 "useLimitedUseAppCheckTokens": runtime_options.use_limited_use_app_check_tokens,
701 }),
702 Some(&cache_key.identifier),
703 ) {
704 Ok(service) => {
705 service.set_options(runtime_options.clone());
706 Cache::insert(cache_key.clone(), service.clone());
707 Ok(service)
708 }
709 Err(ComponentError::InstanceUnavailable { .. }) => {
710 if let Some(service) = provider
711 .get_immediate_with_options::<AiService>(Some(&cache_key.identifier), true)
712 .map_err(|err| internal_error(err.to_string()))?
713 {
714 service.set_options(runtime_options.clone());
715 Cache::insert(cache_key.clone(), service.clone());
716 Ok(service)
717 } else {
718 let container = app.container();
719 let fallback = Arc::new(AiService::new(
720 app.clone(),
721 backend,
722 runtime_options,
723 None,
724 container.get_provider("auth-internal"),
725 container.get_provider("app-check-internal"),
726 ));
727 Cache::insert(cache_key.clone(), fallback.clone());
728 Ok(fallback)
729 }
730 }
731 Err(err) => Err(internal_error(err.to_string())),
732 }
733}
734
735pub async fn get_ai_service(app: Option<FirebaseApp>) -> AiResult<Arc<AiService>> {
737 get_ai(app, None).await
738}
739
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ai::backend::Backend;
    use crate::ai::error::AiErrorCode;
    use crate::ai::public_types::AiOptions;
    use crate::app::initialize_app;
    use crate::app::{FirebaseAppSettings, FirebaseOptions};
    use async_trait::async_trait;
    use serde_json::json;
    use std::collections::VecDeque;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// App settings with a process-unique name, so tests never share apps or cache entries.
    fn unique_settings() -> FirebaseAppSettings {
        static COUNTER: AtomicUsize = AtomicUsize::new(0);
        FirebaseAppSettings {
            name: Some(format!("ai-{}", COUNTER.fetch_add(1, Ordering::SeqCst))),
            ..Default::default()
        }
    }

    /// Options containing every field `AiService::api_settings` requires.
    fn complete_options() -> FirebaseOptions {
        FirebaseOptions {
            project_id: Some("project".into()),
            api_key: Some("api".into()),
            app_id: Some("app".into()),
            ..Default::default()
        }
    }

    /// A `generateContent`-shaped payload whose only candidate has one text part.
    fn single_text_response(text: &str) -> Value {
        json!({
            "candidates": [
                {
                    "content": {
                        "parts": [
                            { "text": text }
                        ]
                    }
                }
            ]
        })
    }

    /// In-memory transport: replays queued responses and records every request it sees.
    #[derive(Clone, Default)]
    struct TestTransport {
        responses: Arc<Mutex<VecDeque<AiResult<Value>>>>,
        requests: Arc<Mutex<Vec<PreparedRequest>>>,
    }

    impl TestTransport {
        fn new() -> Self {
            Self::default()
        }

        /// Queues the result returned by the next `send`.
        fn push_response(&self, response: AiResult<Value>) {
            self.responses.lock().unwrap().push_back(response);
        }

        /// Drains and returns the requests recorded so far (later calls see only new ones).
        fn take_requests(&self) -> Vec<PreparedRequest> {
            std::mem::take(&mut *self.requests.lock().unwrap())
        }
    }

    #[cfg_attr(not(target_arch = "wasm32"), async_trait)]
    #[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
    impl AiHttpTransport for TestTransport {
        async fn send(&self, request: PreparedRequest) -> AiResult<Value> {
            self.requests.lock().unwrap().push(request);
            self.responses
                .lock()
                .unwrap()
                .pop_front()
                .unwrap_or_else(|| Err(internal_error("no stub response configured")))
        }
    }

    #[tokio::test(flavor = "current_thread")]
    async fn generate_text_includes_backend_info() {
        let transport = TestTransport::new();
        transport.push_response(Ok(single_text_response("Hello from mock")));

        let app = initialize_app(complete_options(), Some(unique_settings()))
            .await
            .unwrap();
        let ai = get_ai_service(Some(app)).await.unwrap();
        ai.set_transport_for_tests(Arc::new(transport.clone()));
        let response = ai
            .generate_text(GenerateTextRequest {
                prompt: "Hello AI".to_string(),
                model: Some("models/gemini-pro".to_string()),
                request_options: None,
            })
            .await
            .unwrap();

        assert_eq!(response.model, "models/gemini-pro");
        assert_eq!(response.text, "Hello from mock");

        let requests = transport.take_requests();
        assert_eq!(requests.len(), 1);
        assert!(requests[0]
            .url
            .as_str()
            .ends_with("models/gemini-pro:generateContent"));
        assert_eq!(requests[0].header("x-goog-api-key"), Some("api"));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn limited_use_app_check_token_attached_to_requests() {
        let transport = TestTransport::new();
        transport.push_response(Ok(single_text_response("Limited token response")));

        let app = initialize_app(complete_options(), Some(unique_settings()))
            .await
            .unwrap();
        let ai = get_ai(
            Some(app),
            Some(AiOptions {
                backend: Some(Backend::google_ai()),
                use_limited_use_app_check_tokens: Some(true),
            }),
        )
        .await
        .unwrap();

        ai.set_transport_for_tests(Arc::new(transport.clone()));
        ai.override_tokens_for_tests(
            None,
            Some("standard-token".into()),
            Some("limited-token".into()),
        );

        let response = ai
            .generate_text(GenerateTextRequest {
                prompt: "token test".to_string(),
                model: Some("models/gemini-pro".to_string()),
                request_options: None,
            })
            .await
            .unwrap();

        assert_eq!(response.text, "Limited token response");

        let requests = transport.take_requests();
        assert_eq!(requests.len(), 1);
        assert_eq!(
            requests[0].header("x-firebase-appcheck"),
            Some("limited-token")
        );
    }

    #[tokio::test(flavor = "current_thread")]
    async fn empty_prompt_errors() {
        let app = initialize_app(complete_options(), Some(unique_settings()))
            .await
            .unwrap();
        let ai = get_ai_service(Some(app)).await.unwrap();
        let err = ai
            .generate_text(GenerateTextRequest {
                prompt: "   ".to_string(),
                model: None,
                request_options: None,
            })
            .await
            .unwrap_err();
        assert_eq!(err.code_str(), "AI/invalid-argument");
    }

    #[tokio::test(flavor = "current_thread")]
    async fn backend_identifier_creates_unique_instances() {
        let options = FirebaseOptions {
            project_id: Some("project".into()),
            ..Default::default()
        };
        let app = initialize_app(options, Some(unique_settings()))
            .await
            .unwrap();

        let google = get_ai(
            Some(app.clone()),
            Some(AiOptions {
                backend: Some(Backend::google_ai()),
                use_limited_use_app_check_tokens: None,
            }),
        )
        .await
        .unwrap();

        let vertex = get_ai(
            Some(app.clone()),
            Some(AiOptions {
                backend: Some(Backend::vertex_ai("europe-west4")),
                use_limited_use_app_check_tokens: Some(true),
            }),
        )
        .await
        .unwrap();

        assert_ne!(Arc::as_ptr(&google), Arc::as_ptr(&vertex));
        assert_eq!(vertex.location(), Some("europe-west4"));
        assert!(vertex.options().use_limited_use_app_check_tokens);
    }

    #[tokio::test(flavor = "current_thread")]
    async fn get_ai_reuses_cached_instance() {
        let app = initialize_app(complete_options(), Some(unique_settings()))
            .await
            .unwrap();

        let first = get_ai_service(Some(app.clone())).await.unwrap();
        first
            .prepare_generate_content_request("models/test-model", json!({ "contents": [] }), None)
            .await
            .unwrap();

        let second = get_ai(Some(app.clone()), None).await.unwrap();
        assert_eq!(Arc::as_ptr(&first), Arc::as_ptr(&second));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn api_settings_require_project_id() {
        let options = FirebaseOptions {
            project_id: None,
            ..complete_options()
        };
        let app = initialize_app(options, Some(unique_settings()))
            .await
            .unwrap();
        let ai = get_ai_service(Some(app)).await.unwrap();
        let err = ai.api_settings().await.unwrap_err();
        assert_eq!(err.code(), AiErrorCode::NoProjectId);
    }

    #[tokio::test(flavor = "current_thread")]
    async fn prepare_generate_content_request_builds_expected_url() {
        let app = initialize_app(complete_options(), Some(unique_settings()))
            .await
            .unwrap();
        let ai = get_ai_service(Some(app)).await.unwrap();
        let prepared = ai
            .prepare_generate_content_request(
                "models/gemini-1.5-flash",
                json!({ "contents": [] }),
                None,
            )
            .await
            .unwrap();
        assert_eq!(
            prepared.url.as_str(),
            "https://firebasevertexai.googleapis.com/v1beta/projects/project/models/gemini-1.5-flash:generateContent"
        );
        assert_eq!(prepared.header("x-goog-api-key"), Some("api"));
    }
}
1024}