use base64::{Engine as _, engine::general_purpose};
use eyre::Report;
use hmac::{Hmac, Mac};
use serde::{Deserialize, Serialize};
use sha2::Sha256;
use std::time::{SystemTime, UNIX_EPOCH};
use tracing::{trace, warn};
/// HMAC keyed with SHA-256; the MAC used to sign and verify token payloads.
type HmacSha256 = Hmac<Sha256>;
/// Claims carried inside a [`Token`].
///
/// The token signature is an HMAC-SHA256 over the bincode encoding of this
/// struct, and verification re-serializes it the same way. Field order and
/// field types are therefore part of the signed wire format. Reordering,
/// inserting, removing, or retyping fields means previously issued tokens
/// will no longer decode or verify.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
pub struct TokenPayload {
    /// Tenant name supplied to the issuing function.
    pub tenant: String,
    /// Cluster name supplied to the issuing function.
    pub cluster: String,
    /// Worker identifier supplied to the issuing function.
    pub worker_id: String,
    /// Host; `Some` only for tokens from `issue_bootstrap_token`, `None` for
    /// tokens from `issue_token`.
    /// NOTE(review): presumably the address the worker should dial — confirm with callers.
    pub host: Option<String>,
    /// QUIC port; set together with `host` for bootstrap tokens only.
    pub quic_port: Option<u16>,
    /// TCP port; set together with `host` for bootstrap tokens only.
    pub tcp_port: Option<u16>,
    /// Raw certificate bytes for bootstrap tokens; `None` otherwise.
    /// NOTE(review): the encoding (DER vs PEM) is not established in this file — confirm.
    pub cert: Option<Vec<u8>>,
    /// Issued-at time, in whole seconds since the Unix epoch.
    pub iat: u64,
    /// Expiry time, in whole seconds since the Unix epoch (`iat + ttl_secs` at
    /// issue). `verify_token` rejects the token once the current time is past it.
    pub exp: u64,
}
/// A signed token: the claims plus their HMAC-SHA256 tag.
///
/// For transport it is bincode-serialized and base64-encoded via
/// [`encode_token`] / [`decode_token`]. Check authenticity with
/// [`verify_token`] or [`verify_token_signature`] rather than comparing
/// `sig` with `==`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
pub struct Token {
    /// The signed claims.
    pub payload: TokenPayload,
    /// HMAC-SHA256 of the bincode-encoded `payload` under the issuing key.
    pub sig: Vec<u8>,
}
pub fn issue_token(
key: &[u8; 32],
tenant: &str,
cluster: &str,
worker_id: &str,
ttl_secs: u64,
) -> Result<Token, Report> {
let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
let payload = TokenPayload {
tenant: tenant.to_string(),
cluster: cluster.to_string(),
worker_id: worker_id.to_string(),
host: None,
quic_port: None,
tcp_port: None,
cert: None,
iat: now,
exp: now + ttl_secs,
};
let encoded = bincode::serialize(&payload)?;
let mut mac = HmacSha256::new_from_slice(key)?;
mac.update(&encoded);
let sig = mac.finalize().into_bytes().to_vec();
Ok(Token { payload, sig })
}
#[allow(clippy::too_many_arguments)]
pub fn issue_bootstrap_token(
key: &[u8; 32],
tenant: &str,
cluster: &str,
worker_id: &str,
ttl_secs: u64,
host: &str,
quic_port: u16,
tcp_port: u16,
cert: Vec<u8>,
) -> Result<Token, Report> {
let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
let payload = TokenPayload {
tenant: tenant.to_string(),
cluster: cluster.to_string(),
worker_id: worker_id.to_string(),
host: Some(host.to_string()),
quic_port: Some(quic_port),
tcp_port: Some(tcp_port),
cert: Some(cert),
iat: now,
exp: now + ttl_secs,
};
let encoded = bincode::serialize(&payload)?;
let mut mac = HmacSha256::new_from_slice(key)?;
mac.update(&encoded);
let sig = mac.finalize().into_bytes().to_vec();
Ok(Token { payload, sig })
}
/// Verifies that `token` was signed with `key` and has not expired.
///
/// The MAC is checked *before* the `exp` claim. Claims from an
/// unauthenticated payload are therefore never acted on or logged. A forged
/// token carrying an old `exp` is rejected as a bad signature, not reported
/// (and `warn!`-logged) as a legitimately expired token.
///
/// # Errors
/// Returns an error in any of these cases:
/// - the HMAC-SHA256 over the bincode-encoded payload does not match `token.sig`;
/// - `token.payload.exp` is earlier than the current Unix time in seconds
///   (a token whose `exp` equals the current second is still accepted);
/// - the system clock reads earlier than the Unix epoch;
/// - the payload cannot be serialized.
pub fn verify_token(token: &Token, key: &[u8; 32]) -> Result<(), Report> {
    let encoded = bincode::serialize(&token.payload)?;
    trace!("Encoded token: {:?}", encoded);
    let mut mac = HmacSha256::new_from_slice(key)?;
    mac.update(&encoded);
    // `verify_slice` compares against the expected tag; an Err means the
    // payload or signature was not produced with `key`.
    mac.verify_slice(&token.sig)?;

    // Only now is `exp` known to be authentic.
    let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
    if token.payload.exp < now {
        warn!("Token expired");
        return Err(eyre::eyre!("token expired"));
    }
    Ok(())
}
pub fn verify_token_signature(token: &Token, key: &[u8; 32]) -> Result<(), Report> {
let encoded = bincode::serialize(&token.payload)?;
let mut mac = HmacSha256::new_from_slice(key)?;
mac.update(&encoded);
mac.verify_slice(&token.sig)?;
Ok(())
}
pub fn refresh_token(token: &Token, key: &[u8; 32], ttl_secs: u64) -> Result<Token, Report> {
verify_token(token, key)?;
issue_token(
key,
&token.payload.tenant,
&token.payload.cluster,
&token.payload.worker_id,
ttl_secs,
)
}
pub fn encode_token(token: &Token) -> Result<String, Report> {
let bytes = bincode::serialize(token)?;
Ok(general_purpose::STANDARD_NO_PAD.encode(bytes))
}
/// Inverse of [`encode_token`]: decodes unpadded standard-alphabet base64,
/// then bincode-deserializes a [`Token`].
///
/// This only parses. It does not check the signature or expiry, so call
/// [`verify_token`] before trusting the result.
///
/// # Errors
/// Returns an error if the input is not valid base64 for this engine, or if
/// the decoded bytes are not a bincode-encoded `Token`.
pub fn decode_token(encoded: &str) -> Result<Token, Report> {
    let raw = general_purpose::STANDARD_NO_PAD.decode(encoded)?;
    let parsed: Result<Token, _> = bincode::deserialize(&raw);
    parsed.map_err(|cause| eyre::eyre!("failed to deserialize token: {}", cause))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Signs an arbitrary payload exactly as the issuers do, so tests can
    /// build validly-signed tokens with hand-picked claims.
    fn sign_claims(key: &[u8; 32], payload: TokenPayload) -> Token {
        let encoded = bincode::serialize(&payload).unwrap();
        let mut mac = HmacSha256::new_from_slice(key).unwrap();
        mac.update(&encoded);
        let sig = mac.finalize().into_bytes().to_vec();
        Token { payload, sig }
    }

    /// Claims whose `exp` lies far in the past (1s after the Unix epoch).
    fn expired_claims() -> TokenPayload {
        TokenPayload {
            tenant: "t".to_string(),
            cluster: "c".to_string(),
            worker_id: "aid".to_string(),
            host: None,
            quic_port: None,
            tcp_port: None,
            cert: None,
            iat: 0,
            exp: 1,
        }
    }

    #[test]
    fn token_issue_verify_roundtrip() {
        let key = [1u8; 32];
        let token = issue_token(&key, "t", "c", "aid", 60).unwrap();
        verify_token(&token, &key).unwrap();
    }

    #[test]
    fn token_refresh_extends_expiry() {
        let key = [2u8; 32];
        let token = issue_token(&key, "t", "c", "a", 2).unwrap();
        verify_token(&token, &key).unwrap();
        let new_token = refresh_token(&token, &key, 5).unwrap();
        verify_token(&new_token, &key).unwrap();
        assert!(new_token.payload.exp > token.payload.exp);
    }

    #[test]
    fn encode_decode_roundtrip() {
        let key = [3u8; 32];
        let token = issue_token(&key, "t", "c", "aid", 30).unwrap();
        let encoded = encode_token(&token).unwrap();
        let decoded = decode_token(&encoded).unwrap();
        assert_eq!(decoded, token);
        verify_token(&decoded, &key).unwrap();
    }

    #[test]
    fn bootstrap_token_roundtrip() {
        let key = [4u8; 32];
        let cert = vec![1, 2, 3];
        let token =
            issue_bootstrap_token(&key, "t", "c", "aid", 60, "h", 1, 2, cert.clone()).unwrap();
        verify_token(&token, &key).unwrap();
        assert_eq!(token.payload.host.as_deref(), Some("h"));
        assert_eq!(token.payload.quic_port, Some(1));
        assert_eq!(token.payload.tcp_port, Some(2));
        assert_eq!(token.payload.cert.as_deref(), Some(cert.as_slice()));
        let encoded = encode_token(&token).unwrap();
        let decoded = decode_token(&encoded).unwrap();
        assert_eq!(decoded, token);
    }

    #[test]
    fn wrong_key_is_rejected() {
        let token = issue_token(&[5u8; 32], "t", "c", "aid", 60).unwrap();
        let other_key = [6u8; 32];
        assert!(verify_token(&token, &other_key).is_err());
        assert!(verify_token_signature(&token, &other_key).is_err());
    }

    #[test]
    fn tampered_payload_is_rejected() {
        let key = [7u8; 32];
        let mut token = issue_token(&key, "t", "c", "aid", 60).unwrap();
        token.payload.tenant = "other".to_string();
        assert!(verify_token(&token, &key).is_err());
        assert!(verify_token_signature(&token, &key).is_err());
    }

    #[test]
    fn tampered_signature_is_rejected() {
        let key = [8u8; 32];
        let mut token = issue_token(&key, "t", "c", "aid", 60).unwrap();
        token.sig[0] ^= 0x01;
        assert!(verify_token(&token, &key).is_err());
    }

    #[test]
    fn expired_token_is_rejected_but_signature_still_valid() {
        let key = [9u8; 32];
        let token = sign_claims(&key, expired_claims());
        let err = verify_token(&token, &key).unwrap_err();
        assert!(err.to_string().contains("expired"));
        // The signature-only check ignores `exp`.
        verify_token_signature(&token, &key).unwrap();
        // An expired token cannot be refreshed.
        assert!(refresh_token(&token, &key, 60).is_err());
    }

    #[test]
    fn decode_rejects_garbage() {
        // Not base64 at all.
        assert!(decode_token("!!!not base64!!!").is_err());
        // Valid unpadded base64 (one zero byte) but far too short for a Token.
        assert!(decode_token("AA").is_err());
    }
}