Skip to main content

docker_registry_client/docker/
token_cache.rs

1use std::{
2    collections::HashMap,
3    sync::Arc,
4};
5
6use chrono::Utc;
7use tokio::sync::RwLock;
8use tracing::{
9    Instrument,
10    info_span,
11};
12
13use crate::docker::token::{
14    CacheKey,
15    Token,
16};
17
18#[cfg(feature = "redis_cache")]
19use redis::AsyncCommands;
20
#[cfg(feature = "redis_cache")]
/// Namespace prepended to every Redis key this cache reads or writes; the full key is
/// formatted as `"{REDIS_PREFIX}:{key}"`.
const REDIS_PREFIX: &str = "docker-registry-client:token";
23
24#[derive(Debug)]
25pub enum FetchError {
26    CheckExists(redis::RedisError),
27    DeserializeToken(serde_json::Error),
28    GetConnection(redis::RedisError),
29    GetValue(redis::RedisError),
30}
31
32#[derive(Debug)]
33pub enum StoreError {
34    GetConnection(redis::RedisError),
35    SerializeToken(serde_json::Error),
36    SetExpiration(redis::RedisError),
37    SetValue(redis::RedisError),
38}
39
40#[async_trait::async_trait]
41pub(super) trait Cache: std::fmt::Debug + Send + Sync + dyn_clone::DynClone {
42    async fn fetch(&self, key: &CacheKey) -> Result<Option<Token>, FetchError>;
43    async fn store(&self, key: CacheKey, token: Token) -> Result<(), StoreError>;
44}
45
46dyn_clone::clone_trait_object!(Cache);
47
48/// `NoCache` is a token cache that does not cache tokens.
49#[derive(Debug, Default, Clone)]
50pub(super) struct NoCache;
51
52/// `MemoryTokenCache` is a token cache that caches tokens in memory.
53#[derive(Debug, Default, Clone)]
54pub(super) struct MemoryTokenCache {
55    cache: Arc<RwLock<HashMap<CacheKey, Token>>>,
56}
57
/// A token cache that keeps tokens, encoded as JSON strings, in Redis.
#[cfg(feature = "redis_cache")]
#[derive(Clone, Debug)]
pub(super) struct RedisCache {
    client: redis::Client,
}
64
65impl std::fmt::Display for FetchError {
66    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67        match self {
68            Self::CheckExists(e) => write!(f, "failed to check if key exists: {e}"),
69            Self::DeserializeToken(e) => write!(f, "failed to deserialize token: {e}"),
70            Self::GetConnection(e) => write!(f, "failed to get redis connection: {e}"),
71            Self::GetValue(e) => write!(f, "failed to get value from redis: {e}"),
72        }
73    }
74}
75
76impl std::error::Error for FetchError {}
77
78impl std::fmt::Display for StoreError {
79    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80        match self {
81            Self::GetConnection(e) => write!(f, "failed to get redis connection: {e}"),
82            Self::SerializeToken(e) => write!(f, "failed to serialize token: {e}"),
83            Self::SetExpiration(e) => write!(f, "failed to set expiration: {e}"),
84            Self::SetValue(e) => write!(f, "failed to set value in redis: {e}"),
85        }
86    }
87}
88
89impl std::error::Error for StoreError {}
90
91#[async_trait::async_trait]
92impl Cache for NoCache {
93    async fn fetch(&self, _key: &CacheKey) -> Result<Option<Token>, FetchError> {
94        Ok(None)
95    }
96
97    async fn store(&self, _key: CacheKey, _token: Token) -> Result<(), StoreError> {
98        Ok(())
99    }
100}
101
102#[async_trait::async_trait]
103impl Cache for MemoryTokenCache {
104    #[tracing::instrument]
105    async fn fetch(&self, key: &CacheKey) -> Result<Option<Token>, FetchError> {
106        let result = self.cache.read().await.get(key).cloned().and_then(|token| {
107            if let Some(expires_in) = token.expires_in {
108                token
109                    .issued_at
110                    .map(|issued_at| issued_at + chrono::Duration::seconds(expires_in))
111                    .and_then(|expires_at| {
112                        if expires_at < Utc::now() {
113                            None
114                        } else {
115                            Some(token)
116                        }
117                    })
118            } else {
119                Some(token)
120            }
121        });
122
123        Ok(result)
124    }
125
126    #[tracing::instrument]
127    async fn store(&self, key: CacheKey, token: Token) -> Result<(), StoreError> {
128        self.cache.write().await.insert(key, token);
129
130        Ok(())
131    }
132}
133
#[cfg(feature = "redis_cache")]
impl RedisCache {
    /// Wraps an existing Redis `client`; no connection is opened here.
    ///
    /// A connection is requested from the client on each `fetch`/`store` call.
    #[must_use]
    pub fn new(client: redis::Client) -> Self {
        RedisCache { client }
    }
}
141
#[cfg(feature = "redis_cache")]
#[async_trait::async_trait]
impl Cache for RedisCache {
    /// Looks up the JSON-encoded token stored under `"{REDIS_PREFIX}:{key}"`.
    ///
    /// Returns `Ok(None)` when the key is absent. Expiry is delegated to Redis (the TTL
    /// is set by `store`); no `issued_at`/`expires_in` check is performed here.
    // `self` is not recorded: `redis::Client`'s `Debug` output includes its connection
    // info, which may contain credentials.
    #[tracing::instrument(skip(self))]
    async fn fetch(&self, key: &CacheKey) -> Result<Option<Token>, FetchError> {
        let mut connection = self
            .client
            .get_multiplexed_async_connection()
            .instrument(info_span!("get redis connection"))
            .await
            .map_err(FetchError::GetConnection)?;

        let key = format!("{REDIS_PREFIX}:{key}");

        let exists: bool = connection
            .exists(&key)
            .instrument(info_span!("check if key exists"))
            .await
            .map_err(FetchError::CheckExists)?;

        if !exists {
            return Ok(None);
        }

        // The key can still expire between EXISTS and GET. Decoding into `Option` turns
        // the resulting nil reply into a cache miss instead of a type-conversion error.
        let value: Option<String> = connection
            .get(&key)
            .instrument(info_span!("get value"))
            .await
            .map_err(FetchError::GetValue)?;

        value
            .map(|json| serde_json::from_str(&json))
            .transpose()
            .map_err(FetchError::DeserializeToken)
    }

    /// Stores `token` as JSON under `"{REDIS_PREFIX}:{key}"` and, when the token carries
    /// `expires_in`, sets the key's TTL to that many seconds from now.
    ///
    /// NOTE(review): SET and EXPIRE are separate commands, so a failure between them
    /// leaves a key without a TTL. An atomic form (SET with EX, or a MULTI/EXEC
    /// pipeline) is worth considering.
    // `self` (client connection info) and `token` (presumably a credential) are not
    // recorded on the span.
    #[tracing::instrument(skip(self, token))]
    async fn store(&self, key: CacheKey, token: Token) -> Result<(), StoreError> {
        let mut connection = self
            .client
            .get_multiplexed_async_connection()
            .instrument(info_span!("get redis connection"))
            .await
            .map_err(StoreError::GetConnection)?;

        let key = format!("{REDIS_PREFIX}:{key}");

        let value = serde_json::to_string(&token).map_err(StoreError::SerializeToken)?;

        // Replies are decoded as `()`, which accepts any reply type. EXPIRE answers with
        // an integer, which is not guaranteed to decode into `String` in every redis-rs
        // version.
        connection
            .set::<&String, String, ()>(&key, value)
            .instrument(info_span!("set value"))
            .await
            .map_err(StoreError::SetValue)?;

        if let Some(expires_in) = token.expires_in {
            connection
                .expire::<&String, ()>(&key, expires_in)
                .instrument(info_span!("set expire"))
                .await
                .map_err(StoreError::SetExpiration)?;
        }

        Ok(())
    }
}