use anyhow::{Context, Result, bail};
use base64::Engine;
use sccache::dist::http::{ClientAuthCheck, ClientVisibleMsg};
use sccache::util::{BASE64_URL_SAFE_ENGINE, new_reqwest_blocking_client};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::result::Result as StdResult;
use std::sync::Mutex;
use std::time::{Duration, Instant};
#[derive(Debug, Serialize, Deserialize)]
pub struct Jwks {
pub keys: Vec<Jwk>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Jwk {
pub kid: String,
kty: String,
n: String,
e: String,
}
impl Jwk {
pub fn to_der_pkcs1(&self) -> Result<Vec<u8>> {
if self.kty != "RSA" {
bail!("Cannot handle non-RSA JWK")
}
let n = BASE64_URL_SAFE_ENGINE
.decode(&self.n)
.context("Failed to base64 decode n")?;
let e = BASE64_URL_SAFE_ENGINE
.decode(&self.e)
.context("Failed to base64 decode e")?;
let n_bn = openssl::bn::BigNum::from_slice(&n)
.context("Failed to create openssl bignum from n")?;
let e_bn = openssl::bn::BigNum::from_slice(&e)
.context("Failed to create openssl bignum from e")?;
let pubkey = openssl::rsa::Rsa::from_public_components(n_bn, e_bn)
.context("Failed to create pubkey from n and e")?;
let der: Vec<u8> = pubkey
.public_key_to_der_pkcs1()
.context("Failed to convert public key to der pkcs1")?;
Ok(der)
}
}
pub struct EqCheck {
s: String,
}
impl ClientAuthCheck for EqCheck {
fn check(&self, token: &str) -> StdResult<(), ClientVisibleMsg> {
if self.s == token {
Ok(())
} else {
warn!("User token {} != expected token {}", token, self.s);
Err(ClientVisibleMsg::from_nonsensitive(
"Fixed token mismatch".to_owned(),
))
}
}
}
impl EqCheck {
pub fn new(s: String) -> Self {
Self { s }
}
}
pub struct ProxyTokenCheck {
client: reqwest::blocking::Client,
maybe_auth_cache: Option<Mutex<(HashMap<String, Instant>, Duration)>>,
url: String,
}
impl ClientAuthCheck for ProxyTokenCheck {
fn check(&self, token: &str) -> StdResult<(), ClientVisibleMsg> {
match self.check_token_with_forwarding(token) {
Ok(()) => Ok(()),
Err(e) => {
warn!("Proxying token validation failed: {}", e);
Err(ClientVisibleMsg::from_nonsensitive(
"Validation with token forwarding failed".to_owned(),
))
}
}
}
}
impl ProxyTokenCheck {
pub fn new(url: String, cache_secs: Option<u64>) -> Self {
let maybe_auth_cache: Option<Mutex<(HashMap<String, Instant>, Duration)>> =
cache_secs.map(|secs| Mutex::new((HashMap::new(), Duration::from_secs(secs))));
Self {
client: new_reqwest_blocking_client(),
maybe_auth_cache,
url,
}
}
fn check_token_with_forwarding(&self, token: &str) -> Result<()> {
trace!("Validating token by forwarding to {}", self.url);
if let Some(ref auth_cache) = self.maybe_auth_cache {
let mut auth_cache = auth_cache.lock().unwrap();
let (ref mut auth_cache, cache_duration) = *auth_cache;
if let Some(cached_at) = auth_cache.get(token) {
if cached_at.elapsed() < cache_duration {
return Ok(());
}
}
auth_cache.remove(token);
}
let res = self
.client
.get(&self.url)
.bearer_auth(token)
.send()
.context("Failed to make request to proxying url")?;
if !res.status().is_success() {
bail!("Token forwarded to {} returned {}", self.url, res.status());
}
if let Some(ref auth_cache) = self.maybe_auth_cache {
let mut auth_cache = auth_cache.lock().unwrap();
let (ref mut auth_cache, _) = *auth_cache;
auth_cache.insert(token.to_owned(), Instant::now());
}
Ok(())
}
}
pub struct ValidJWTCheck {
audience: String,
issuer: String,
kid_to_pkcs1: HashMap<String, Vec<u8>>,
}
impl ClientAuthCheck for ValidJWTCheck {
fn check(&self, token: &str) -> StdResult<(), ClientVisibleMsg> {
match self.check_jwt_validity(token) {
Ok(()) => Ok(()),
Err(e) => {
warn!("JWT validation failed: {}", e);
Err(ClientVisibleMsg::from_nonsensitive(
"JWT could not be validated".to_owned(),
))
}
}
}
}
impl ValidJWTCheck {
pub fn new(audience: String, issuer: String, jwks_url: &str) -> Result<Self> {
let res = reqwest::blocking::get(jwks_url).context("Failed to make request to JWKs url")?;
if !res.status().is_success() {
bail!("Could not retrieve JWKs, HTTP error: {}", res.status())
}
let jwks: Jwks = res.json().context("Failed to parse JWKs json")?;
let kid_to_pkcs1 = jwks
.keys
.into_iter()
.map(|k| k.to_der_pkcs1().map(|pkcs1| (k.kid, pkcs1)))
.collect::<Result<_>>()
.context("Failed to convert JWKs into pkcs1")?;
Ok(Self {
audience,
issuer,
kid_to_pkcs1,
})
}
fn check_jwt_validity(&self, token: &str) -> Result<()> {
let header = jwt::decode_header(token).context("Could not decode jwt header")?;
trace!("Validating JWT in scheduler");
let kid = header.kid.context("No kid found")?;
let pkcs1 = jwt::DecodingKey::from_rsa_der(
self.kid_to_pkcs1
.get(&kid)
.context("kid not found in jwks")?,
);
let mut validation = jwt::Validation::new(header.alg);
validation.set_audience(&[&self.audience]);
validation.set_issuer(&[&self.issuer]);
#[derive(Deserialize)]
struct Claims {}
let _tokendata = jwt::decode::<Claims>(token, &pkcs1, &validation)
.context("Unable to validate and decode jwt")?;
Ok(())
}
}