Skip to main content

oci_client/
token_cache.rs

1use oci_spec::distribution::Reference;
2use serde::Deserialize;
3use std::collections::BTreeMap;
4use std::fmt;
5use std::sync::Arc;
6use std::time::{SystemTime, UNIX_EPOCH};
7use tokio::sync::RwLock;
8use tracing::{debug, warn};
9
10/// A token granted during the OAuth2-like workflow for OCI registries.
11#[derive(Deserialize, Clone)]
12#[serde(untagged)]
13#[serde(rename_all = "snake_case")]
14pub(crate) enum RegistryToken {
15    Token { token: String },
16    AccessToken { access_token: String },
17}
18
19impl fmt::Debug for RegistryToken {
20    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
21        let redacted = String::from("<redacted>");
22        match self {
23            RegistryToken::Token { .. } => {
24                f.debug_struct("Token").field("token", &redacted).finish()
25            }
26            RegistryToken::AccessToken { .. } => f
27                .debug_struct("AccessToken")
28                .field("access_token", &redacted)
29                .finish(),
30        }
31    }
32}
33
34#[derive(Debug, Clone)]
35pub(crate) enum RegistryTokenType {
36    Bearer(RegistryToken),
37    Basic(String, String),
38}
39
40impl RegistryToken {
41    pub fn bearer_token(&self) -> String {
42        format!("Bearer {}", self.token())
43    }
44
45    pub fn token(&self) -> &str {
46        match self {
47            RegistryToken::Token { token } => token,
48            RegistryToken::AccessToken { access_token } => access_token,
49        }
50    }
51}
52
/// The kind of access requested when authenticating against a registry.
///
/// Declaration order matters: the derived `Ord` places `Push` before `Pull`.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub enum RegistryOperation {
    /// Request credentials for push operations
    Push,
    /// Request credentials for pull operations
    Pull,
}
61
62#[derive(Debug, Deserialize)]
63struct BearerTokenClaims {
64    exp: Option<u64>,
65}
66
67#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
68struct TokenCacheKey {
69    registry: String,
70    repository: String,
71    operation: RegistryOperation,
72}
73
74struct TokenCacheValue {
75    token: RegistryTokenType,
76    expiration: u64,
77}
78
79#[derive(Clone)]
80pub(crate) struct TokenCache {
81    // (registry, repository, scope) -> (token, expiration)
82    tokens: Arc<RwLock<BTreeMap<TokenCacheKey, TokenCacheValue>>>,
83    /// Default token expiration in seconds, to use when claim doesn't specify a value
84    pub default_expiration_secs: usize,
85}
86
87impl TokenCache {
88    pub(crate) fn new(default_expiration_secs: usize) -> Self {
89        TokenCache {
90            tokens: Arc::new(RwLock::new(BTreeMap::new())),
91            default_expiration_secs,
92        }
93    }
94
95    pub(crate) async fn insert(
96        &self,
97        reference: &Reference,
98        op: RegistryOperation,
99        token: RegistryTokenType,
100    ) {
101        let expiration = match token {
102            RegistryTokenType::Basic(_, _) => u64::MAX,
103            RegistryTokenType::Bearer(ref t) => {
104                match parse_expiration_from_jwt(t.token(), self.default_expiration_secs) {
105                    Some(value) => value,
106                    None => return,
107                }
108            }
109        };
110        let registry = reference.resolve_registry().to_string();
111        let repository = reference.repository().to_string();
112        debug!(%registry, %repository, ?op, %expiration, "Inserting token");
113        self.tokens.write().await.insert(
114            TokenCacheKey {
115                registry,
116                repository,
117                operation: op,
118            },
119            TokenCacheValue { token, expiration },
120        );
121    }
122
123    pub(crate) async fn get(
124        &self,
125        reference: &Reference,
126        op: RegistryOperation,
127    ) -> Option<RegistryTokenType> {
128        let registry = reference.resolve_registry().to_string();
129        let repository = reference.repository().to_string();
130        let key = TokenCacheKey {
131            registry,
132            repository,
133            operation: op,
134        };
135        match self.tokens.read().await.get(&key) {
136            Some(TokenCacheValue {
137                ref token,
138                expiration,
139            }) => {
140                let now = SystemTime::now();
141                let epoch = now
142                    .duration_since(UNIX_EPOCH)
143                    .expect("Time went backwards")
144                    .as_secs();
145                if epoch > *expiration {
146                    debug!(%key.registry, %key.repository, ?key.operation, %expiration, miss=false, expired=true, "Fetching token");
147                    None
148                } else {
149                    debug!(%key.registry, %key.repository, ?key.operation, %expiration, miss=false, expired=false, "Fetching token");
150                    Some(token.clone())
151                }
152            }
153            None => {
154                debug!(%key.registry, %key.repository, ?key.operation, miss = true, "Fetching token");
155                None
156            }
157        }
158    }
159}
160
161fn parse_expiration_from_jwt(token_str: &str, default_expiration_secs: usize) -> Option<u64> {
162    match jsonwebtoken::dangerous::insecure_decode::<BearerTokenClaims>(token_str) {
163        Ok(token) => {
164            let token_exp = match token.claims.exp {
165                Some(exp) => exp,
166                None => {
167                    // the token doesn't have a claim that states a
168                    // value for the expiration. We assume it has a 60
169                    // seconds validity as indicated here:
170                    // https://docs.docker.com/reference/api/registry/auth/#token-response-fields
171                    // > (Optional) The duration in seconds since the token was issued
172                    // > that it will remain valid. When omitted, this defaults to 60 seconds.
173                    // > For compatibility with older clients, a token should never be returned
174                    // > with less than 60 seconds to live.
175                    let now = SystemTime::now();
176                    let epoch = now
177                        .duration_since(UNIX_EPOCH)
178                        .expect("Time went backwards")
179                        .as_secs();
180                    let expiration = epoch + default_expiration_secs as u64;
181                    debug!(?token, "Cannot extract expiration from token's claims, assuming a {} seconds validity", default_expiration_secs);
182                    expiration
183                }
184            };
185
186            Some(token_exp)
187        }
188        Err(error) if error.kind() == &jsonwebtoken::errors::ErrorKind::InvalidToken => {
189            // The token is not a JWT (e.g., an opaque token issued by registries
190            // like GHCR). Use the default expiration as a best-effort assumption,
191            // mirroring the behaviour for JWT tokens that carry no `exp` claim.
192            let epoch = SystemTime::now()
193                .duration_since(UNIX_EPOCH)
194                .expect("Time went backwards")
195                .as_secs();
196            debug!(
197                "Bearer token is not a JWT, assuming a {} seconds validity",
198                default_expiration_secs
199            );
200            Some(epoch + default_expiration_secs as u64)
201        }
202        Err(error) => {
203            warn!(?error, "Invalid bearer token");
204            None
205        }
206    }
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{EncodingKey, Header};
    use oci_spec::distribution::Reference;
    use serde::Serialize;

    // An opaque token as issued by registries like GHCR — not a JWT.
    const OPAQUE_TOKEN: &str = "ghs_exampleOpaqueTokenFromGHCR1234567890";

    #[derive(Serialize)]
    struct ClaimsWithExp {
        exp: u64,
    }

    #[derive(Serialize)]
    struct ClaimsWithoutExp {
        sub: &'static str,
    }

    /// Signs arbitrary claims with the default header and a throwaway secret.
    fn sign<C: Serialize>(claims: &C) -> jsonwebtoken::errors::Result<String> {
        jsonwebtoken::encode(&Header::default(), claims, &EncodingKey::from_secret(b"secret"))
    }

    fn make_jwt_with_exp(exp: u64) -> String {
        sign(&ClaimsWithExp { exp }).expect("failed to encode JWT with exp")
    }

    fn make_jwt_without_exp() -> String {
        sign(&ClaimsWithoutExp { sub: "test" }).expect("failed to encode JWT without exp")
    }

    /// Current wall-clock time in whole seconds since the Unix epoch.
    fn now_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    /// Checks that `exp` equals "now + 60s" for some instant in `[before, after]`.
    fn assert_within_default_window(exp: u64, before: u64, after: u64) {
        assert!(exp >= before + 60);
        assert!(exp <= after + 60);
    }

    #[test]
    fn jwt_with_exp_uses_claims_expiration() {
        let jwt = make_jwt_with_exp(9999999999);
        let exp = parse_expiration_from_jwt(&jwt, 60)
            .expect("should return Some for valid JWT with exp");
        assert_eq!(exp, 9999999999);
    }

    #[test]
    fn jwt_without_exp_uses_default_expiration() {
        let jwt = make_jwt_without_exp();
        let before = now_secs();
        let exp =
            parse_expiration_from_jwt(&jwt, 60).expect("should return Some for JWT without exp");
        let after = now_secs();
        assert_within_default_window(exp, before, after);
    }

    #[test]
    fn opaque_token_uses_default_expiration() {
        let before = now_secs();
        let exp = parse_expiration_from_jwt(OPAQUE_TOKEN, 60)
            .expect("opaque token should return Some with default expiration");
        let after = now_secs();
        assert_within_default_window(exp, before, after);
    }

    #[tokio::test]
    async fn opaque_token_is_cached() {
        let cache = TokenCache::new(60);
        let reference: Reference = "ghcr.io/kubewarden/policies/pod-privileged:v1.0.10"
            .parse()
            .unwrap();
        let bearer = RegistryTokenType::Bearer(RegistryToken::Token {
            token: OPAQUE_TOKEN.to_string(),
        });

        cache
            .insert(&reference, RegistryOperation::Pull, bearer)
            .await;

        let cached = cache.get(&reference, RegistryOperation::Pull).await;
        assert!(cached.is_some(), "opaque bearer token should be cached");
    }
}