1use oci_spec::distribution::Reference;
2use serde::Deserialize;
3use std::collections::BTreeMap;
4use std::fmt;
5use std::sync::Arc;
6use std::time::{SystemTime, UNIX_EPOCH};
7use tokio::sync::RwLock;
8use tracing::{debug, warn};
9
10#[derive(Deserialize, Clone)]
12#[serde(untagged)]
13#[serde(rename_all = "snake_case")]
14pub(crate) enum RegistryToken {
15 Token { token: String },
16 AccessToken { access_token: String },
17}
18
19impl fmt::Debug for RegistryToken {
20 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
21 let redacted = String::from("<redacted>");
22 match self {
23 RegistryToken::Token { .. } => {
24 f.debug_struct("Token").field("token", &redacted).finish()
25 }
26 RegistryToken::AccessToken { .. } => f
27 .debug_struct("AccessToken")
28 .field("access_token", &redacted)
29 .finish(),
30 }
31 }
32}
33
34#[derive(Debug, Clone)]
35pub(crate) enum RegistryTokenType {
36 Bearer(RegistryToken),
37 Basic(String, String),
38}
39
40impl RegistryToken {
41 pub fn bearer_token(&self) -> String {
42 format!("Bearer {}", self.token())
43 }
44
45 pub fn token(&self) -> &str {
46 match self {
47 RegistryToken::Token { token } => token,
48 RegistryToken::AccessToken { access_token } => access_token,
49 }
50 }
51}
52
/// Kind of registry access a cached token is associated with.
///
/// Part of the cache key, so tokens stored for `Push` and for `Pull` are kept
/// apart. The derived ordering follows declaration order (`Push < Pull`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum RegistryOperation {
    /// Push access.
    Push,
    /// Pull access.
    Pull,
}
61
62#[derive(Debug, Deserialize)]
63struct BearerTokenClaims {
64 exp: Option<u64>,
65}
66
67#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
68struct TokenCacheKey {
69 registry: String,
70 repository: String,
71 operation: RegistryOperation,
72}
73
74struct TokenCacheValue {
75 token: RegistryTokenType,
76 expiration: u64,
77}
78
79#[derive(Clone)]
80pub(crate) struct TokenCache {
81 tokens: Arc<RwLock<BTreeMap<TokenCacheKey, TokenCacheValue>>>,
83 pub default_expiration_secs: usize,
85}
86
87impl TokenCache {
88 pub(crate) fn new(default_expiration_secs: usize) -> Self {
89 TokenCache {
90 tokens: Arc::new(RwLock::new(BTreeMap::new())),
91 default_expiration_secs,
92 }
93 }
94
95 pub(crate) async fn insert(
96 &self,
97 reference: &Reference,
98 op: RegistryOperation,
99 token: RegistryTokenType,
100 ) {
101 let expiration = match token {
102 RegistryTokenType::Basic(_, _) => u64::MAX,
103 RegistryTokenType::Bearer(ref t) => {
104 match parse_expiration_from_jwt(t.token(), self.default_expiration_secs) {
105 Some(value) => value,
106 None => return,
107 }
108 }
109 };
110 let registry = reference.resolve_registry().to_string();
111 let repository = reference.repository().to_string();
112 debug!(%registry, %repository, ?op, %expiration, "Inserting token");
113 self.tokens.write().await.insert(
114 TokenCacheKey {
115 registry,
116 repository,
117 operation: op,
118 },
119 TokenCacheValue { token, expiration },
120 );
121 }
122
123 pub(crate) async fn get(
124 &self,
125 reference: &Reference,
126 op: RegistryOperation,
127 ) -> Option<RegistryTokenType> {
128 let registry = reference.resolve_registry().to_string();
129 let repository = reference.repository().to_string();
130 let key = TokenCacheKey {
131 registry,
132 repository,
133 operation: op,
134 };
135 match self.tokens.read().await.get(&key) {
136 Some(TokenCacheValue {
137 ref token,
138 expiration,
139 }) => {
140 let now = SystemTime::now();
141 let epoch = now
142 .duration_since(UNIX_EPOCH)
143 .expect("Time went backwards")
144 .as_secs();
145 if epoch > *expiration {
146 debug!(%key.registry, %key.repository, ?key.operation, %expiration, miss=false, expired=true, "Fetching token");
147 None
148 } else {
149 debug!(%key.registry, %key.repository, ?key.operation, %expiration, miss=false, expired=false, "Fetching token");
150 Some(token.clone())
151 }
152 }
153 None => {
154 debug!(%key.registry, %key.repository, ?key.operation, miss = true, "Fetching token");
155 None
156 }
157 }
158 }
159}
160
161fn parse_expiration_from_jwt(token_str: &str, default_expiration_secs: usize) -> Option<u64> {
162 match jsonwebtoken::dangerous::insecure_decode::<BearerTokenClaims>(token_str) {
163 Ok(token) => {
164 let token_exp = match token.claims.exp {
165 Some(exp) => exp,
166 None => {
167 let now = SystemTime::now();
176 let epoch = now
177 .duration_since(UNIX_EPOCH)
178 .expect("Time went backwards")
179 .as_secs();
180 let expiration = epoch + default_expiration_secs as u64;
181 debug!(?token, "Cannot extract expiration from token's claims, assuming a {} seconds validity", default_expiration_secs);
182 expiration
183 }
184 };
185
186 Some(token_exp)
187 }
188 Err(error) if error.kind() == &jsonwebtoken::errors::ErrorKind::InvalidToken => {
189 let epoch = SystemTime::now()
193 .duration_since(UNIX_EPOCH)
194 .expect("Time went backwards")
195 .as_secs();
196 debug!(
197 "Bearer token is not a JWT, assuming a {} seconds validity",
198 default_expiration_secs
199 );
200 Some(epoch + default_expiration_secs as u64)
201 }
202 Err(error) => {
203 warn!(?error, "Invalid bearer token");
204 None
205 }
206 }
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{EncodingKey, Header};
    use oci_spec::distribution::Reference;
    use serde::Serialize;

    /// A token that is not a JWT.
    const OPAQUE_TOKEN: &str = "ghs_exampleOpaqueTokenFromGHCR1234567890";

    /// Claims carrying only an expiration.
    #[derive(Serialize)]
    struct ClaimsWithExp {
        exp: u64,
    }

    /// Claims carrying no expiration at all.
    #[derive(Serialize)]
    struct ClaimsWithoutExp {
        sub: &'static str,
    }

    /// Signs `claims` with the default header and a fixed secret, panicking
    /// with `failure` if encoding fails.
    fn sign<C: Serialize>(claims: &C, failure: &str) -> String {
        let key = EncodingKey::from_secret(b"secret");
        jsonwebtoken::encode(&Header::default(), claims, &key).expect(failure)
    }

    fn make_jwt_with_exp(exp: u64) -> String {
        sign(&ClaimsWithExp { exp }, "failed to encode JWT with exp")
    }

    fn make_jwt_without_exp() -> String {
        sign(
            &ClaimsWithoutExp { sub: "test" },
            "failed to encode JWT without exp",
        )
    }

    /// Current time in whole seconds since the Unix epoch.
    fn now_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    /// Parses `token` with a 60 second default and checks that the result lies
    /// in the window "time before the call + 60" to "time after the call + 60".
    fn assert_default_expiration(token: &str, failure: &str) {
        let before = now_secs();
        let exp = parse_expiration_from_jwt(token, 60).expect(failure);
        let after = now_secs();
        assert!(exp >= before + 60);
        assert!(exp <= after + 60);
    }

    #[test]
    fn jwt_with_exp_uses_claims_expiration() {
        let jwt = make_jwt_with_exp(9999999999);
        let parsed = parse_expiration_from_jwt(&jwt, 60)
            .expect("should return Some for valid JWT with exp");
        assert_eq!(parsed, 9999999999);
    }

    #[test]
    fn jwt_without_exp_uses_default_expiration() {
        let jwt = make_jwt_without_exp();
        assert_default_expiration(&jwt, "should return Some for JWT without exp");
    }

    #[test]
    fn opaque_token_uses_default_expiration() {
        assert_default_expiration(
            OPAQUE_TOKEN,
            "opaque token should return Some with default expiration",
        );
    }

    #[tokio::test]
    async fn opaque_token_is_cached() {
        let reference: Reference = "ghcr.io/kubewarden/policies/pod-privileged:v1.0.10"
            .parse()
            .unwrap();
        let cache = TokenCache::new(60);
        let bearer = RegistryTokenType::Bearer(RegistryToken::Token {
            token: OPAQUE_TOKEN.to_string(),
        });

        cache
            .insert(&reference, RegistryOperation::Pull, bearer)
            .await;

        let cached = cache.get(&reference, RegistryOperation::Pull).await;
        assert!(cached.is_some(), "opaque bearer token should be cached");
    }
}