rs_firebase_admin_sdk/auth/token/cache/mod.rs

1//! Public key caching for use in efficient token verification
2
3#[cfg(test)]
4mod test;
5
6pub mod error;
7
8use super::JwtRsaPubKey;
9use bytes::Bytes;
10use error::{CacheError, ClientError};
11use error_stack::{Report, ResultExt};
12use headers::{CacheControl, HeaderMapExt};
13use reqwest::Client;
14use serde::de::DeserializeOwned;
15use serde_json::from_slice;
16use std::collections::BTreeMap;
17use std::future::Future;
18use std::sync::Arc;
19use std::time::{Duration, SystemTime};
20use tokio::sync::{Mutex, RwLock};
21
/// A cached value together with the wall-clock time after which it is stale.
#[derive(Clone, Debug)]
struct Cache<ContentT> {
    /// Wall-clock time at (or after) which `content` is considered expired.
    expires_at: SystemTime,
    /// The cached value.
    content: ContentT,
}

impl<ContentT> Cache<ContentT> {
    /// Creates an entry holding `content` that stays fresh for `max_age` from now.
    ///
    /// A `max_age` too large to be added to the current time yields an entry
    /// that is already expired (see [`Cache::expiry_from_now`]).
    pub fn new(max_age: Duration, content: ContentT) -> Self {
        Self {
            expires_at: Self::expiry_from_now(max_age),
            content,
        }
    }

    /// Returns `true` once the current wall-clock time has reached `expires_at`.
    pub fn is_expired(&self) -> bool {
        self.expires_at <= SystemTime::now()
    }

    /// Replaces the cached value and restarts its lifetime with `max_age` from now.
    pub fn update(&mut self, max_age: Duration, content: ContentT) {
        self.expires_at = Self::expiry_from_now(max_age);
        self.content = content;
    }

    /// Computes `now + max_age` without panicking.
    ///
    /// `max_age` is derived from a remote `Cache-Control` header, so an absurdly
    /// large value must not be able to crash the process through the panicking
    /// `SystemTime + Duration` operator. If the sum is not representable, the
    /// entry is treated as already expired. The next read then refetches instead
    /// of pinning the content indefinitely.
    fn expiry_from_now(max_age: Duration) -> SystemTime {
        let now = SystemTime::now();
        now.checked_add(max_age).unwrap_or(now)
    }
}
45
46#[derive(Clone, Debug)]
47pub struct Resource {
48    pub data: Bytes,
49    pub max_age: Duration,
50}
51
52pub trait CacheClient: Sized + Send + Sync
53where
54    Self::Error: std::error::Error + Send + Sync + 'static,
55{
56    type Error;
57
58    /// Simple async interface to fetch data and its TTL for an URI
59    fn fetch(
60        &self,
61        uri: &str,
62    ) -> impl Future<Output = Result<Resource, Report<Self::Error>>> + Send;
63}
64
65impl CacheClient for Client {
66    type Error = ClientError;
67
68    async fn fetch(&self, uri: &str) -> Result<Resource, Report<Self::Error>> {
69        let response = self
70            .get(uri)
71            .send()
72            .await
73            .change_context(ClientError::FailedToFetch)?;
74
75        let status = response.status();
76
77        if !status.is_success() {
78            return Err(Report::new(ClientError::BadHttpResponse(status)));
79        }
80
81        let cache_header: Option<CacheControl> = response.headers().typed_get();
82        let body = response
83            .bytes()
84            .await
85            .change_context(ClientError::FailedToFetch)?;
86
87        if let Some(cache_header) = cache_header {
88            let ttl = cache_header
89                .s_max_age()
90                .unwrap_or_else(|| cache_header.max_age().unwrap_or_default());
91
92            return Ok(Resource {
93                data: body,
94                max_age: ttl,
95            });
96        }
97
98        Ok(Resource {
99            data: body,
100            max_age: Duration::default(),
101        })
102    }
103}
104
105pub struct HttpCache<CacheClientT, ContentT> {
106    client: CacheClientT,
107    path: String,
108    cache: Arc<RwLock<Cache<ContentT>>>,
109    refresh: Mutex<()>,
110}
111
112impl<CacheClientT, ContentT> HttpCache<CacheClientT, ContentT>
113where
114    CacheClientT: CacheClient,
115    ContentT: DeserializeOwned + Clone + Send + Sync,
116{
117    pub async fn new(client: CacheClientT, path: String) -> Result<Self, Report<CacheError>> {
118        let resource = client.fetch(&path).await.change_context(CacheError)?;
119
120        let initial_cache: Cache<ContentT> = Cache::new(
121            resource.max_age,
122            from_slice(&resource.data).change_context(CacheError)?,
123        );
124
125        Ok(Self {
126            client,
127            path,
128            cache: Arc::new(RwLock::new(initial_cache)),
129            refresh: Mutex::new(()),
130        })
131    }
132
133    pub async fn get(&self) -> Result<ContentT, Report<CacheError>> {
134        let cache = self.cache.read().await.clone();
135        if cache.is_expired() {
136            // to make sure only a single connection is being established to refresh the resource
137            let _refresh_guard = self.refresh.lock().await;
138
139            // check if the cache has been refreshed by another co-routine
140            let cache = self.cache.read().await.clone();
141            if !cache.is_expired() {
142                return Ok(cache.content);
143            }
144
145            // refresh resource
146            let resource = self
147                .client
148                .fetch(&self.path)
149                .await
150                .change_context(CacheError)?;
151
152            let content: ContentT = from_slice(&resource.data).change_context(CacheError)?;
153
154            self.cache
155                .write()
156                .await
157                .update(resource.max_age, content.clone());
158
159            return Ok(content);
160        }
161
162        Ok(cache.content)
163    }
164}
165
166pub type PubKeys = BTreeMap<String, JwtRsaPubKey>;
167
168pub trait KeyCache {
169    fn get_keys(&self) -> impl Future<Output = Result<PubKeys, Report<CacheError>>> + Send;
170}
171
172impl<ClientT: CacheClient> KeyCache for HttpCache<ClientT, PubKeys> {
173    fn get_keys(&self) -> impl Future<Output = Result<PubKeys, Report<CacheError>>> + Send {
174        self.get()
175    }
176}