firebase_rs_sdk/app_check/
token_provider.rs1use std::sync::atomic::{AtomicBool, Ordering};
4
5use crate::app_check::errors::AppCheckError;
6use crate::app_check::FirebaseAppCheckInternal;
7use crate::firestore::{
8 internal_error, invalid_argument, unauthenticated, unavailable, FirestoreError, FirestoreResult,
9};
10use crate::firestore::{TokenProvider, TokenProviderArc};
11
12pub struct AppCheckTokenProvider {
14 app_check: FirebaseAppCheckInternal,
15 force_refresh: AtomicBool,
16}
17
18impl AppCheckTokenProvider {
19 pub fn new(app_check: FirebaseAppCheckInternal) -> Self {
21 Self {
22 app_check,
23 force_refresh: AtomicBool::new(false),
24 }
25 }
26
27 pub fn into_arc(self) -> TokenProviderArc {
29 std::sync::Arc::new(self)
30 }
31}
32
33pub fn app_check_token_provider_arc(app_check: FirebaseAppCheckInternal) -> TokenProviderArc {
35 AppCheckTokenProvider::new(app_check).into_arc()
36}
37
38impl Clone for AppCheckTokenProvider {
39 fn clone(&self) -> Self {
40 Self {
41 app_check: self.app_check.clone(),
42 force_refresh: AtomicBool::new(self.force_refresh.load(Ordering::SeqCst)),
43 }
44 }
45}
46
47#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
48#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
49impl TokenProvider for AppCheckTokenProvider {
50 async fn get_token(&self) -> FirestoreResult<Option<String>> {
51 let force_refresh = self.force_refresh.swap(false, Ordering::SeqCst);
52 match self.app_check.get_token(force_refresh).await {
53 Ok(result) => {
54 if result.token.is_empty() {
55 Ok(None)
56 } else {
57 Ok(Some(result.token))
58 }
59 }
60 Err(err) => {
61 if let Some(cached) = err.cached_token() {
62 if cached.token.is_empty() {
63 Ok(None)
64 } else {
65 Ok(Some(cached.token.clone()))
66 }
67 } else {
68 Err(map_app_check_error(err.cause))
69 }
70 }
71 }
72 }
73
74 fn invalidate_token(&self) {
75 self.force_refresh.store(true, Ordering::SeqCst);
76 }
77
78 async fn heartbeat_header(&self) -> FirestoreResult<Option<String>> {
79 self.app_check
80 .heartbeat_header()
81 .await
82 .map_err(map_app_check_error)
83 }
84}
85
86fn map_app_check_error(error: AppCheckError) -> FirestoreError {
87 match error.clone() {
88 AppCheckError::AlreadyInitialized { .. }
89 | AppCheckError::UseBeforeActivation { .. }
90 | AppCheckError::InvalidConfiguration { .. } => invalid_argument(error.to_string()),
91 AppCheckError::TokenExpired => unauthenticated(error.to_string()),
92 AppCheckError::Internal(message) => internal_error(message),
93 AppCheckError::TokenFetchFailed { .. }
94 | AppCheckError::ProviderError { .. }
95 | AppCheckError::FetchNetworkError { .. }
96 | AppCheckError::FetchParseError { .. }
97 | AppCheckError::FetchStatusError { .. }
98 | AppCheckError::RecaptchaError { .. }
99 | AppCheckError::InitialThrottle { .. }
100 | AppCheckError::Throttled { .. } => unavailable(error.to_string()),
101 }
102}
103
#[cfg(test)]
mod tests {
    //! Unit tests for `AppCheckTokenProvider` backed by in-memory App Check providers.

    use super::*;
    use crate::app::{FirebaseApp, FirebaseAppConfig, FirebaseOptions};
    use crate::app_check::api::{initialize_app_check, token_with_ttl};
    use crate::app_check::types::{
        box_app_check_future, AppCheckOptions, AppCheckProvider, AppCheckProviderFuture,
        AppCheckToken,
    };
    use crate::component::ComponentContainer;
    use std::sync::Arc;
    use std::time::Duration;

    /// Provider that always yields the same token with a 60 second TTL.
    #[derive(Clone)]
    struct StaticTokenProvider {
        token: String,
    }

    impl AppCheckProvider for StaticTokenProvider {
        fn get_token(
            &self,
        ) -> AppCheckProviderFuture<'_, crate::app_check::AppCheckResult<AppCheckToken>> {
            let value = self.token.clone();
            box_app_check_future(async move { token_with_ttl(value, Duration::from_secs(60)) })
        }
    }

    /// Provider whose token fetch always fails with `TokenFetchFailed`.
    #[derive(Clone)]
    struct ErrorProvider;

    impl AppCheckProvider for ErrorProvider {
        fn get_token(
            &self,
        ) -> AppCheckProviderFuture<'_, crate::app_check::AppCheckResult<AppCheckToken>> {
            box_app_check_future(async move {
                Err(AppCheckError::TokenFetchFailed {
                    message: "network".into(),
                })
            })
        }
    }

    /// Builds a standalone app named `name` with default options.
    fn test_app(name: &str) -> FirebaseApp {
        let owned_name = name.to_owned();
        FirebaseApp::new(
            FirebaseOptions::default(),
            FirebaseAppConfig::new(owned_name.clone(), false),
            ComponentContainer::new(owned_name),
        )
    }

    #[tokio::test(flavor = "current_thread")]
    async fn returns_token_string() {
        let static_provider = Arc::new(StaticTokenProvider {
            token: "app-check-123".into(),
        });
        let options = AppCheckOptions::new(static_provider);
        let app_check = initialize_app_check(Some(test_app("app-check-ok")), options)
            .await
            .expect("initialize app check");
        let token_provider = AppCheckTokenProvider::new(FirebaseAppCheckInternal::new(app_check));

        let token = token_provider.get_token().await.unwrap();
        assert_eq!(token.as_deref(), Some("app-check-123"));

        // NOTE(review): presumably no heartbeat service is registered for this bare app,
        // hence no header is expected.
        let heartbeat = token_provider.heartbeat_header().await.unwrap();
        assert!(heartbeat.is_none());
    }

    #[tokio::test(flavor = "current_thread")]
    async fn propagates_errors() {
        let failing_provider = Arc::new(ErrorProvider);
        let options = AppCheckOptions::new(failing_provider);
        let app_check = initialize_app_check(Some(test_app("app-check-err")), options)
            .await
            .expect("initialize app check");
        let token_provider = AppCheckTokenProvider::new(FirebaseAppCheckInternal::new(app_check));

        // A fetch failure with no cached token must surface as `Unavailable`.
        let failure = token_provider.get_token().await.unwrap_err();
        assert_eq!(
            failure.code,
            crate::firestore::FirestoreErrorCode::Unavailable
        );
    }
}