firebase_rs_sdk/app_check/
token_provider.rs

1// #![cfg(feature = "firestore")]
2
3use std::sync::atomic::{AtomicBool, Ordering};
4
5use crate::app_check::errors::AppCheckError;
6use crate::app_check::FirebaseAppCheckInternal;
7use crate::firestore::{
8    internal_error, invalid_argument, unauthenticated, unavailable, FirestoreError, FirestoreResult,
9};
10use crate::firestore::{TokenProvider, TokenProviderArc};
11
12/// Bridges App Check token retrieval into Firestore's [`TokenProvider`] trait.
13pub struct AppCheckTokenProvider {
14    app_check: FirebaseAppCheckInternal,
15    force_refresh: AtomicBool,
16}
17
18impl AppCheckTokenProvider {
19    /// Creates a new provider backed by the given App Check instance.
20    pub fn new(app_check: FirebaseAppCheckInternal) -> Self {
21        Self {
22            app_check,
23            force_refresh: AtomicBool::new(false),
24        }
25    }
26
27    /// Converts the provider into a reference-counted [`TokenProviderArc`].
28    pub fn into_arc(self) -> TokenProviderArc {
29        std::sync::Arc::new(self)
30    }
31}
32
33/// Convenience helper to expose an App Check instance as a [`TokenProviderArc`].
34pub fn app_check_token_provider_arc(app_check: FirebaseAppCheckInternal) -> TokenProviderArc {
35    AppCheckTokenProvider::new(app_check).into_arc()
36}
37
38impl Clone for AppCheckTokenProvider {
39    fn clone(&self) -> Self {
40        Self {
41            app_check: self.app_check.clone(),
42            force_refresh: AtomicBool::new(self.force_refresh.load(Ordering::SeqCst)),
43        }
44    }
45}
46
47#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
48#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
49impl TokenProvider for AppCheckTokenProvider {
50    async fn get_token(&self) -> FirestoreResult<Option<String>> {
51        let force_refresh = self.force_refresh.swap(false, Ordering::SeqCst);
52        match self.app_check.get_token(force_refresh).await {
53            Ok(result) => {
54                if result.token.is_empty() {
55                    Ok(None)
56                } else {
57                    Ok(Some(result.token))
58                }
59            }
60            Err(err) => {
61                if let Some(cached) = err.cached_token() {
62                    if cached.token.is_empty() {
63                        Ok(None)
64                    } else {
65                        Ok(Some(cached.token.clone()))
66                    }
67                } else {
68                    Err(map_app_check_error(err.cause))
69                }
70            }
71        }
72    }
73
74    fn invalidate_token(&self) {
75        self.force_refresh.store(true, Ordering::SeqCst);
76    }
77
78    async fn heartbeat_header(&self) -> FirestoreResult<Option<String>> {
79        self.app_check
80            .heartbeat_header()
81            .await
82            .map_err(map_app_check_error)
83    }
84}
85
86fn map_app_check_error(error: AppCheckError) -> FirestoreError {
87    match error.clone() {
88        AppCheckError::AlreadyInitialized { .. }
89        | AppCheckError::UseBeforeActivation { .. }
90        | AppCheckError::InvalidConfiguration { .. } => invalid_argument(error.to_string()),
91        AppCheckError::TokenExpired => unauthenticated(error.to_string()),
92        AppCheckError::Internal(message) => internal_error(message),
93        AppCheckError::TokenFetchFailed { .. }
94        | AppCheckError::ProviderError { .. }
95        | AppCheckError::FetchNetworkError { .. }
96        | AppCheckError::FetchParseError { .. }
97        | AppCheckError::FetchStatusError { .. }
98        | AppCheckError::RecaptchaError { .. }
99        | AppCheckError::InitialThrottle { .. }
100        | AppCheckError::Throttled { .. } => unavailable(error.to_string()),
101    }
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use crate::app::{FirebaseApp, FirebaseAppConfig, FirebaseOptions};
    use crate::app_check::api::{initialize_app_check, token_with_ttl};
    use crate::app_check::types::{
        box_app_check_future, AppCheckOptions, AppCheckProvider, AppCheckProviderFuture,
        AppCheckToken,
    };
    use crate::component::ComponentContainer;
    use std::sync::Arc;
    use std::time::Duration;

    /// App Check provider that always yields the same token with a 60-second TTL.
    #[derive(Clone)]
    struct StaticTokenProvider {
        token: String,
    }

    impl AppCheckProvider for StaticTokenProvider {
        fn get_token(
            &self,
        ) -> AppCheckProviderFuture<'_, crate::app_check::AppCheckResult<AppCheckToken>> {
            let token = self.token.clone();
            box_app_check_future(async move { token_with_ttl(token, Duration::from_secs(60)) })
        }
    }

    /// App Check provider whose every fetch fails with `TokenFetchFailed`.
    #[derive(Clone)]
    struct ErrorProvider;

    impl AppCheckProvider for ErrorProvider {
        fn get_token(
            &self,
        ) -> AppCheckProviderFuture<'_, crate::app_check::AppCheckResult<AppCheckToken>> {
            box_app_check_future(async move {
                Err(AppCheckError::TokenFetchFailed {
                    message: "network".into(),
                })
            })
        }
    }

    /// Builds a standalone app with default options under the given name.
    fn test_app(name: &str) -> FirebaseApp {
        FirebaseApp::new(
            FirebaseOptions::default(),
            FirebaseAppConfig::new(name.to_owned(), false),
            ComponentContainer::new(name.to_owned()),
        )
    }

    #[tokio::test(flavor = "current_thread")]
    async fn returns_token_string() {
        let static_provider = Arc::new(StaticTokenProvider {
            token: "app-check-123".into(),
        });
        let options = AppCheckOptions::new(static_provider);
        let app_check = initialize_app_check(Some(test_app("app-check-ok")), options)
            .await
            .expect("initialize app check");
        let bridge = AppCheckTokenProvider::new(FirebaseAppCheckInternal::new(app_check));

        let token = bridge.get_token().await.unwrap();
        assert_eq!(token.as_deref(), Some("app-check-123"));

        let heartbeat = bridge.heartbeat_header().await.unwrap();
        assert!(heartbeat.is_none());
    }

    #[tokio::test(flavor = "current_thread")]
    async fn propagates_errors() {
        let failing_provider = Arc::new(ErrorProvider);
        let options = AppCheckOptions::new(failing_provider);
        let app_check = initialize_app_check(Some(test_app("app-check-err")), options)
            .await
            .expect("initialize app check");
        let bridge = AppCheckTokenProvider::new(FirebaseAppCheckInternal::new(app_check));

        let error = bridge.get_token().await.unwrap_err();
        assert_eq!(
            error.code,
            crate::firestore::FirestoreErrorCode::Unavailable
        );
    }

    #[tokio::test(flavor = "current_thread")]
    async fn invalidate_token_arms_single_forced_refresh() {
        let static_provider = Arc::new(StaticTokenProvider {
            token: "app-check-456".into(),
        });
        let options = AppCheckOptions::new(static_provider);
        let app_check = initialize_app_check(Some(test_app("app-check-invalidate")), options)
            .await
            .expect("initialize app check");
        let bridge = AppCheckTokenProvider::new(FirebaseAppCheckInternal::new(app_check));
        assert!(!bridge.force_refresh.load(Ordering::SeqCst));

        bridge.invalidate_token();
        assert!(bridge.force_refresh.load(Ordering::SeqCst));

        // A clone snapshots the pending flag into its own AtomicBool.
        let snapshot = bridge.clone();
        assert!(snapshot.force_refresh.load(Ordering::SeqCst));

        // The next fetch consumes the flag and still yields the provider's token.
        let token = bridge.get_token().await.unwrap();
        assert_eq!(token.as_deref(), Some("app-check-456"));
        assert!(!bridge.force_refresh.load(Ordering::SeqCst));

        // Consuming the original's flag leaves the clone's flag untouched.
        assert!(snapshot.force_refresh.load(Ordering::SeqCst));
    }
}