rs_firebase_admin_sdk/auth/token/cache/
mod.rs1#[cfg(test)]
4mod test;
5
6pub mod error;
7
8use super::JwtRsaPubKey;
9use bytes::Bytes;
10use error::{CacheError, ClientError};
11use error_stack::{Report, ResultExt};
12use headers::{CacheControl, HeaderMapExt};
13use reqwest::Client;
14use serde::de::DeserializeOwned;
15use serde_json::from_slice;
16use std::collections::BTreeMap;
17use std::future::Future;
18use std::sync::Arc;
19use std::time::{Duration, SystemTime};
20use tokio::sync::{Mutex, RwLock};
21
/// A cached value paired with the moment it stops being fresh.
#[derive(Clone, Debug)]
struct Cache<ContentT> {
    /// Wall-clock time at (or after) which the entry is considered expired.
    expires_at: SystemTime,
    /// The cached value.
    content: ContentT,
}

impl<ContentT> Cache<ContentT> {
    /// Creates an entry holding `content` that expires `max_age` from now.
    ///
    /// If `now + max_age` cannot be represented as a [`SystemTime`] (for example an
    /// absurdly large `max-age` reported by a server), the entry is created already
    /// expired instead of panicking.
    pub fn new(max_age: Duration, content: ContentT) -> Self {
        Self {
            expires_at: Self::expiry_after(max_age),
            content,
        }
    }

    /// Returns `true` once the expiry time has been reached.
    pub fn is_expired(&self) -> bool {
        self.expires_at <= SystemTime::now()
    }

    /// Replaces the content and restarts the expiry clock with `max_age`.
    ///
    /// An unrepresentable expiry time is handled the same way as in [`Cache::new`]:
    /// the entry is marked as already expired.
    pub fn update(&mut self, max_age: Duration, content: ContentT) {
        self.expires_at = Self::expiry_after(max_age);
        self.content = content;
    }

    /// Computes `now + max_age`, falling back to `now` (i.e. already expired) when the
    /// sum overflows. Plain `SystemTime + Duration` would panic in that case.
    fn expiry_after(max_age: Duration) -> SystemTime {
        let now = SystemTime::now();
        now.checked_add(max_age).unwrap_or(now)
    }
}
45
46#[derive(Clone, Debug)]
47pub struct Resource {
48 pub data: Bytes,
49 pub max_age: Duration,
50}
51
52pub trait CacheClient: Sized + Send + Sync
53where
54 Self::Error: std::error::Error + Send + Sync + 'static,
55{
56 type Error;
57
58 fn fetch(
60 &self,
61 uri: &str,
62 ) -> impl Future<Output = Result<Resource, Report<Self::Error>>> + Send;
63}
64
65impl CacheClient for Client {
66 type Error = ClientError;
67
68 async fn fetch(&self, uri: &str) -> Result<Resource, Report<Self::Error>> {
69 let response = self
70 .get(uri)
71 .send()
72 .await
73 .change_context(ClientError::FailedToFetch)?;
74
75 let status = response.status();
76
77 if !status.is_success() {
78 return Err(Report::new(ClientError::BadHttpResponse(status)));
79 }
80
81 let cache_header: Option<CacheControl> = response.headers().typed_get();
82 let body = response
83 .bytes()
84 .await
85 .change_context(ClientError::FailedToFetch)?;
86
87 if let Some(cache_header) = cache_header {
88 let ttl = cache_header
89 .s_max_age()
90 .unwrap_or_else(|| cache_header.max_age().unwrap_or_default());
91
92 return Ok(Resource {
93 data: body,
94 max_age: ttl,
95 });
96 }
97
98 Ok(Resource {
99 data: body,
100 max_age: Duration::default(),
101 })
102 }
103}
104
105pub struct HttpCache<CacheClientT, ContentT> {
106 client: CacheClientT,
107 path: String,
108 cache: Arc<RwLock<Cache<ContentT>>>,
109 refresh: Mutex<()>,
110}
111
112impl<CacheClientT, ContentT> HttpCache<CacheClientT, ContentT>
113where
114 CacheClientT: CacheClient,
115 ContentT: DeserializeOwned + Clone + Send + Sync,
116{
117 pub async fn new(client: CacheClientT, path: String) -> Result<Self, Report<CacheError>> {
118 let resource = client.fetch(&path).await.change_context(CacheError)?;
119
120 let initial_cache: Cache<ContentT> = Cache::new(
121 resource.max_age,
122 from_slice(&resource.data).change_context(CacheError)?,
123 );
124
125 Ok(Self {
126 client,
127 path,
128 cache: Arc::new(RwLock::new(initial_cache)),
129 refresh: Mutex::new(()),
130 })
131 }
132
133 pub async fn get(&self) -> Result<ContentT, Report<CacheError>> {
134 let cache = self.cache.read().await.clone();
135 if cache.is_expired() {
136 let _refresh_guard = self.refresh.lock().await;
138
139 let cache = self.cache.read().await.clone();
141 if !cache.is_expired() {
142 return Ok(cache.content);
143 }
144
145 let resource = self
147 .client
148 .fetch(&self.path)
149 .await
150 .change_context(CacheError)?;
151
152 let content: ContentT = from_slice(&resource.data).change_context(CacheError)?;
153
154 self.cache
155 .write()
156 .await
157 .update(resource.max_age, content.clone());
158
159 return Ok(content);
160 }
161
162 Ok(cache.content)
163 }
164}
165
166pub type PubKeys = BTreeMap<String, JwtRsaPubKey>;
167
168pub trait KeyCache {
169 fn get_keys(&self) -> impl Future<Output = Result<PubKeys, Report<CacheError>>> + Send;
170}
171
172impl<ClientT: CacheClient> KeyCache for HttpCache<ClientT, PubKeys> {
173 fn get_keys(&self) -> impl Future<Output = Result<PubKeys, Report<CacheError>>> + Send {
174 self.get()
175 }
176}