use oci_spec::distribution::Reference;
use serde::Deserialize;
use std::collections::BTreeMap;
use std::fmt;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::sync::RwLock;
use tracing::{debug, warn};
/// Bearer token as deserialized from a registry authentication response.
///
/// The secret may arrive under either a `token` or an `access_token` key.
/// With `#[serde(untagged)]`, serde tries the variants in declaration order,
/// so `Token` is chosen when both keys are present.
///
/// `Debug` is implemented by hand so the secret is never printed.
#[derive(Deserialize, Clone)]
#[serde(untagged)]
// NOTE(review): container-level `rename_all` renames enum *variants*, which an
// untagged enum never matches on; this attribute looks like a no-op — confirm
// before removing.
#[serde(rename_all = "snake_case")]
pub(crate) enum RegistryToken {
    /// Secret delivered under the `token` key.
    Token { token: String },
    /// Secret delivered under the `access_token` key.
    AccessToken { access_token: String },
}
impl fmt::Debug for RegistryToken {
    /// Formats the variant with its secret replaced by `"<redacted>"`, so a
    /// token can be logged without leaking it.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const REDACTED: &str = "<redacted>";
        let (variant_name, field_name) = match self {
            RegistryToken::Token { .. } => ("Token", "token"),
            RegistryToken::AccessToken { .. } => ("AccessToken", "access_token"),
        };
        f.debug_struct(variant_name)
            .field(field_name, &REDACTED)
            .finish()
    }
}
#[derive(Debug, Clone)]
pub(crate) enum RegistryTokenType {
Bearer(RegistryToken),
Basic(String, String),
}
impl RegistryToken {
    /// Returns the token prefixed with the `Bearer ` scheme, the form used in
    /// an HTTP `Authorization` header value.
    pub fn bearer_token(&self) -> String {
        ["Bearer ", self.token()].concat()
    }

    /// Returns the raw secret, regardless of which key it was delivered under.
    pub fn token(&self) -> &str {
        match self {
            RegistryToken::Token { token: secret }
            | RegistryToken::AccessToken {
                access_token: secret,
            } => secret,
        }
    }
}
/// Registry operation a cached token was obtained for.
///
/// It is part of the cache key, so push and pull tokens for the same
/// repository are cached independently. `Ord` is needed because the key is
/// stored in a `BTreeMap`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum RegistryOperation {
    Push,
    Pull,
}
/// JWT claims read from a bearer token.
///
/// Only `exp` is needed; it is compared against the current time in seconds
/// since the Unix epoch. A token without an `exp` claim yields `None`.
#[derive(Debug, Deserialize)]
struct BearerTokenClaims {
    exp: Option<u64>,
}
/// Key identifying a cached token: resolved registry, repository and operation.
///
/// `Ord` is derived because entries live in a `BTreeMap`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
struct TokenCacheKey {
    registry: String,
    repository: String,
    operation: RegistryOperation,
}
/// A cached credential together with when it stops being served.
struct TokenCacheValue {
    token: RegistryTokenType,
    // Seconds since the Unix epoch; `u64::MAX` for basic credentials, which
    // `TokenCache::insert` never expires.
    expiration: u64,
}
/// Async cache of registry credentials, keyed by registry, repository and
/// operation.
///
/// The map sits behind an `Arc`, so clones of the cache share the same entries.
#[derive(Clone)]
pub(crate) struct TokenCache {
    // tokio `RwLock`: several readers may look up tokens concurrently.
    tokens: Arc<RwLock<BTreeMap<TokenCacheKey, TokenCacheValue>>>,
    /// Validity, in seconds, assumed for bearer tokens whose expiration cannot
    /// be read from JWT claims (no `exp` claim, or not a JWT at all).
    pub default_expiration_secs: usize,
}
impl TokenCache {
    /// Creates an empty cache.
    ///
    /// `default_expiration_secs` is the validity assumed for bearer tokens
    /// whose expiration cannot be read from JWT claims.
    pub(crate) fn new(default_expiration_secs: usize) -> Self {
        TokenCache {
            tokens: Arc::new(RwLock::new(BTreeMap::new())),
            default_expiration_secs,
        }
    }

    /// Stores `token` for the registry and repository of `reference` and the
    /// operation `op`, replacing any previous entry for that key.
    ///
    /// Basic credentials are stored with no expiration (`u64::MAX`). For
    /// bearer tokens the expiration comes from `parse_expiration_from_jwt`;
    /// when that returns `None` the token is not stored and the cache is left
    /// untouched.
    pub(crate) async fn insert(
        &self,
        reference: &Reference,
        op: RegistryOperation,
        token: RegistryTokenType,
    ) {
        let expiration = match &token {
            RegistryTokenType::Basic(_, _) => u64::MAX,
            RegistryTokenType::Bearer(bearer) => {
                match parse_expiration_from_jwt(bearer.token(), self.default_expiration_secs) {
                    Some(value) => value,
                    None => return,
                }
            }
        };
        let registry = reference.resolve_registry().to_string();
        let repository = reference.repository().to_string();
        debug!(%registry, %repository, ?op, %expiration, "Inserting token");
        self.tokens.write().await.insert(
            TokenCacheKey {
                registry,
                repository,
                operation: op,
            },
            TokenCacheValue { token, expiration },
        );
    }

    /// Returns a clone of the token cached for `reference` and `op`.
    ///
    /// Returns `None` when nothing is cached for that key or the cached entry
    /// has expired. RFC 7519 §4.1.4 says a token must not be accepted "on or
    /// after" its expiration time, so an entry counts as expired as soon as
    /// the current second reaches it. Expired entries are not removed; the
    /// next `insert` for the same key overwrites them.
    pub(crate) async fn get(
        &self,
        reference: &Reference,
        op: RegistryOperation,
    ) -> Option<RegistryTokenType> {
        let key = TokenCacheKey {
            registry: reference.resolve_registry().to_string(),
            repository: reference.repository().to_string(),
            operation: op,
        };
        match self.tokens.read().await.get(&key) {
            Some(TokenCacheValue { token, expiration }) => {
                // `>=` rather than `>`: the `exp` second itself is already past validity.
                let expired = Self::now_epoch_secs() >= *expiration;
                debug!(%key.registry, %key.repository, ?key.operation, %expiration, miss = false, expired = expired, "Fetching token");
                if expired {
                    None
                } else {
                    Some(token.clone())
                }
            }
            None => {
                debug!(%key.registry, %key.repository, ?key.operation, miss = true, "Fetching token");
                None
            }
        }
    }

    /// Current wall-clock time in whole seconds since the Unix epoch.
    ///
    /// # Panics
    ///
    /// Panics if the system clock reports a time before the Unix epoch.
    fn now_epoch_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("Time went backwards")
            .as_secs()
    }
}
/// Works out when a bearer token stops being valid, in seconds since the Unix
/// epoch.
///
/// The token is decoded with `jsonwebtoken::dangerous::insecure_decode`, i.e.
/// without verifying its signature; the claims are only used to pick a cache
/// expiration.
///
/// Returns:
/// * the token's `exp` claim, when it is a JWT carrying one;
/// * now + `default_expiration_secs`, when it is a JWT without `exp`, or when
///   decoding fails with `ErrorKind::InvalidToken` (e.g. an opaque, non-JWT
///   token);
/// * `None` for any other decoding error, after logging a warning.
///
/// The fallback saturates at `u64::MAX` instead of overflowing. A huge
/// `default_expiration_secs` therefore means "effectively never expires",
/// rather than panicking (debug builds) or wrapping to a time in the past
/// (release builds).
///
/// # Panics
///
/// Panics if the system clock reports a time before the Unix epoch.
fn parse_expiration_from_jwt(token_str: &str, default_expiration_secs: usize) -> Option<u64> {
    // Evaluated only on the fallback paths, the only ones that need "now".
    let fallback_expiration = || {
        let epoch = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("Time went backwards")
            .as_secs();
        // `usize` is at most 64 bits on every supported target, so the cast is lossless.
        epoch.saturating_add(default_expiration_secs as u64)
    };

    match jsonwebtoken::dangerous::insecure_decode::<BearerTokenClaims>(token_str) {
        Ok(token) => match token.claims.exp {
            Some(exp) => Some(exp),
            None => {
                debug!(?token, "Cannot extract expiration from token's claims, assuming a {} seconds validity", default_expiration_secs);
                Some(fallback_expiration())
            }
        },
        Err(error) if error.kind() == &jsonwebtoken::errors::ErrorKind::InvalidToken => {
            debug!(
                "Bearer token is not a JWT, assuming a {} seconds validity",
                default_expiration_secs
            );
            Some(fallback_expiration())
        }
        Err(error) => {
            warn!(?error, "Invalid bearer token");
            None
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{EncodingKey, Header};
    use oci_spec::distribution::Reference;
    use serde::Serialize;

    /// Opaque (non-JWT) bearer token in the style GHCR hands out.
    const OPAQUE_TOKEN: &str = "ghs_exampleOpaqueTokenFromGHCR1234567890";

    /// JWT payload that carries an `exp` claim.
    #[derive(Serialize)]
    struct ClaimsWithExp {
        exp: u64,
    }

    /// JWT payload without any `exp` claim.
    #[derive(Serialize)]
    struct ClaimsWithoutExp {
        sub: &'static str,
    }

    /// Encodes `claims` as a JWT with the default header and a fixed test
    /// secret, panicking with `failure_msg` if encoding fails.
    fn encode_claims<C: Serialize>(claims: &C, failure_msg: &str) -> String {
        let signing_key = EncodingKey::from_secret(b"secret");
        jsonwebtoken::encode(&Header::default(), claims, &signing_key).expect(failure_msg)
    }

    fn make_jwt_with_exp(exp: u64) -> String {
        encode_claims(&ClaimsWithExp { exp }, "failed to encode JWT with exp")
    }

    fn make_jwt_without_exp() -> String {
        encode_claims(
            &ClaimsWithoutExp { sub: "test" },
            "failed to encode JWT without exp",
        )
    }

    /// Wall-clock time in whole seconds since the Unix epoch.
    fn unix_now() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    #[test]
    fn jwt_with_exp_uses_claims_expiration() {
        const FAR_FUTURE: u64 = 9_999_999_999;
        let jwt = make_jwt_with_exp(FAR_FUTURE);
        let exp = parse_expiration_from_jwt(&jwt, 60)
            .expect("should return Some for valid JWT with exp");
        assert_eq!(exp, FAR_FUTURE);
    }

    #[test]
    fn jwt_without_exp_uses_default_expiration() {
        let jwt = make_jwt_without_exp();
        let before = unix_now();
        let exp =
            parse_expiration_from_jwt(&jwt, 60).expect("should return Some for JWT without exp");
        let after = unix_now();
        assert!(exp >= before + 60);
        assert!(exp <= after + 60);
    }

    #[test]
    fn opaque_token_uses_default_expiration() {
        let before = unix_now();
        let exp = parse_expiration_from_jwt(OPAQUE_TOKEN, 60)
            .expect("opaque token should return Some with default expiration");
        let after = unix_now();
        assert!(exp >= before + 60);
        assert!(exp <= after + 60);
    }

    #[tokio::test]
    async fn opaque_token_is_cached() {
        let cache = TokenCache::new(60);
        let reference: Reference = "ghcr.io/kubewarden/policies/pod-privileged:v1.0.10"
            .parse()
            .unwrap();
        let bearer = RegistryTokenType::Bearer(RegistryToken::Token {
            token: OPAQUE_TOKEN.to_string(),
        });
        cache
            .insert(&reference, RegistryOperation::Pull, bearer)
            .await;
        let cached = cache.get(&reference, RegistryOperation::Pull).await;
        assert!(cached.is_some(), "opaque bearer token should be cached");
    }
}