docker_registry_client/docker/
token_cache.rs1use std::{
2 collections::HashMap,
3 sync::Arc,
4};
5
6use chrono::Utc;
7use tokio::sync::RwLock;
8use tracing::{
9 Instrument,
10 info_span,
11};
12
13use crate::docker::token::{
14 CacheKey,
15 Token,
16};
17
18#[cfg(feature = "redis_cache")]
19use redis::AsyncCommands;
20
/// Namespace prepended (followed by `:`) to every cache key before it is used
/// as a Redis key.
#[cfg(feature = "redis_cache")]
const REDIS_PREFIX: &str = "docker-registry-client:token";
23
24#[derive(Debug)]
25pub enum FetchError {
26 CheckExists(redis::RedisError),
27 DeserializeToken(serde_json::Error),
28 GetConnection(redis::RedisError),
29 GetValue(redis::RedisError),
30}
31
32#[derive(Debug)]
33pub enum StoreError {
34 GetConnection(redis::RedisError),
35 SerializeToken(serde_json::Error),
36 SetExpiration(redis::RedisError),
37 SetValue(redis::RedisError),
38}
39
40#[async_trait::async_trait]
41pub(super) trait Cache: std::fmt::Debug + Send + Sync + dyn_clone::DynClone {
42 async fn fetch(&self, key: &CacheKey) -> Result<Option<Token>, FetchError>;
43 async fn store(&self, key: CacheKey, token: Token) -> Result<(), StoreError>;
44}
45
46dyn_clone::clone_trait_object!(Cache);
47
48#[derive(Debug, Default, Clone)]
50pub(super) struct NoCache;
51
52#[derive(Debug, Default, Clone)]
54pub(super) struct MemoryTokenCache {
55 cache: Arc<RwLock<HashMap<CacheKey, Token>>>,
56}
57
/// Cache backend that keeps JSON-encoded tokens in Redis.
///
/// Only the client handle is stored; a connection is requested from it for
/// every fetch and store.
#[cfg(feature = "redis_cache")]
#[derive(Debug, Clone)]
pub(super) struct RedisCache {
    /// Client used to obtain connections.
    client: redis::Client,
}
64
65impl std::fmt::Display for FetchError {
66 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67 match self {
68 Self::CheckExists(e) => write!(f, "failed to check if key exists: {e}"),
69 Self::DeserializeToken(e) => write!(f, "failed to deserialize token: {e}"),
70 Self::GetConnection(e) => write!(f, "failed to get redis connection: {e}"),
71 Self::GetValue(e) => write!(f, "failed to get value from redis: {e}"),
72 }
73 }
74}
75
76impl std::error::Error for FetchError {}
77
78impl std::fmt::Display for StoreError {
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 match self {
81 Self::GetConnection(e) => write!(f, "failed to get redis connection: {e}"),
82 Self::SerializeToken(e) => write!(f, "failed to serialize token: {e}"),
83 Self::SetExpiration(e) => write!(f, "failed to set expiration: {e}"),
84 Self::SetValue(e) => write!(f, "failed to set value in redis: {e}"),
85 }
86 }
87}
88
89impl std::error::Error for StoreError {}
90
91#[async_trait::async_trait]
92impl Cache for NoCache {
93 async fn fetch(&self, _key: &CacheKey) -> Result<Option<Token>, FetchError> {
94 Ok(None)
95 }
96
97 async fn store(&self, _key: CacheKey, _token: Token) -> Result<(), StoreError> {
98 Ok(())
99 }
100}
101
102#[async_trait::async_trait]
103impl Cache for MemoryTokenCache {
104 #[tracing::instrument]
105 async fn fetch(&self, key: &CacheKey) -> Result<Option<Token>, FetchError> {
106 let result = self.cache.read().await.get(key).cloned().and_then(|token| {
107 if let Some(expires_in) = token.expires_in {
108 token
109 .issued_at
110 .map(|issued_at| issued_at + chrono::Duration::seconds(expires_in))
111 .and_then(|expires_at| {
112 if expires_at < Utc::now() {
113 None
114 } else {
115 Some(token)
116 }
117 })
118 } else {
119 Some(token)
120 }
121 });
122
123 Ok(result)
124 }
125
126 #[tracing::instrument]
127 async fn store(&self, key: CacheKey, token: Token) -> Result<(), StoreError> {
128 self.cache.write().await.insert(key, token);
129
130 Ok(())
131 }
132}
133
#[cfg(feature = "redis_cache")]
impl RedisCache {
    /// Builds a cache on top of an existing `redis::Client`.
    ///
    /// The client is only stored here; connections are requested later, when
    /// a token is fetched or stored.
    #[must_use]
    pub fn new(client: redis::Client) -> Self {
        Self { client }
    }
}
141
#[cfg(feature = "redis_cache")]
#[async_trait::async_trait]
impl Cache for RedisCache {
    /// Fetches and decodes the token stored under `{REDIS_PREFIX}:{key}`.
    ///
    /// A single `GET` is issued and a nil reply is mapped to `Ok(None)`.
    /// Probing with `EXISTS` first would leave a window in which the key can
    /// expire between the two commands, turning an ordinary cache miss into a
    /// `GetValue` error, and would cost an extra round trip.
    ///
    /// # Errors
    ///
    /// * [`FetchError::GetConnection`] — no connection could be obtained.
    /// * [`FetchError::GetValue`] — the `GET` command failed.
    /// * [`FetchError::DeserializeToken`] — the stored value is not a valid
    ///   JSON-encoded token.
    ///
    /// [`FetchError::CheckExists`] is not produced by this implementation.
    #[tracing::instrument]
    async fn fetch(&self, key: &CacheKey) -> Result<Option<Token>, FetchError> {
        let mut connection = self
            .client
            .get_multiplexed_async_connection()
            .instrument(info_span!("get redis connection"))
            .await
            .map_err(FetchError::GetConnection)?;

        let key = format!("{REDIS_PREFIX}:{key}");

        // `Option<String>` makes a missing (or just-expired) key decode as `None`.
        let value: Option<String> = connection
            .get(&key)
            .instrument(info_span!("get value"))
            .await
            .map_err(FetchError::GetValue)?;

        value
            .map(|raw| serde_json::from_str::<Token>(&raw).map_err(FetchError::DeserializeToken))
            .transpose()
    }

    /// Serializes `token` to JSON and writes it under `{REDIS_PREFIX}:{key}`.
    ///
    /// When the token carries `expires_in`, that value is applied to the key
    /// as its time-to-live in seconds.
    ///
    /// # Errors
    ///
    /// * [`StoreError::GetConnection`] — no connection could be obtained.
    /// * [`StoreError::SerializeToken`] — the token could not be encoded.
    /// * [`StoreError::SetValue`] — the `SET` command failed.
    /// * [`StoreError::SetExpiration`] — the `EXPIRE` command failed.
    // NOTE(review): `instrument` records every argument via `Debug`, including
    // `token` — confirm that `Token`'s `Debug` output does not expose the secret.
    #[tracing::instrument]
    async fn store(&self, key: CacheKey, token: Token) -> Result<(), StoreError> {
        let mut connection = self
            .client
            .get_multiplexed_async_connection()
            .instrument(info_span!("get redis connection"))
            .await
            .map_err(StoreError::GetConnection)?;

        let key = format!("{REDIS_PREFIX}:{key}");

        let value = serde_json::to_string(&token).map_err(StoreError::SerializeToken)?;

        connection
            .set::<&String, String, String>(&key, value)
            .instrument(info_span!("set value"))
            .await
            .map_err(StoreError::SetValue)?;

        // NOTE(review): SET and EXPIRE are separate commands; if EXPIRE fails
        // the key is left without a TTL. A single `SET … EX` would close that
        // gap — check the redis crate version in use before switching.
        if let Some(expires_in) = token.expires_in {
            connection
                .expire::<&String, String>(&key, expires_in)
                .instrument(info_span!("set expire"))
                .await
                .map_err(StoreError::SetExpiration)?;
        }

        Ok(())
    }
}