//! volli-core 0.1.11
//!
//! Shared types for volli.
use base64::{Engine as _, engine::general_purpose};
use eyre::Report;
use hmac::{Hmac, Mac};
use serde::{Deserialize, Serialize};
use sha2::Sha256;
use std::time::{SystemTime, UNIX_EPOCH};
use tracing::{trace, warn};

/// HMAC instantiated with SHA-256; used in this file to sign and verify token payloads.
type HmacSha256 = Hmac<Sha256>;

/// The signed portion of a [`Token`].
///
/// The signature is computed over `bincode::serialize(&payload)`.
/// NOTE(review): reordering or retyping fields presumably changes those bytes and
/// would invalidate previously issued tokens — confirm before changing the layout.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
pub struct TokenPayload {
    /// Tenant identifier supplied by the issuer.
    pub tenant: String,
    /// Cluster identifier supplied by the issuer.
    pub cluster: String,
    /// Worker identifier supplied by the issuer.
    pub worker_id: String,
    /// Host; `None` for tokens from `issue_token`, set by `issue_bootstrap_token`.
    pub host: Option<String>,
    /// QUIC port; `None` for tokens from `issue_token`, set by `issue_bootstrap_token`.
    pub quic_port: Option<u16>,
    /// TCP port; `None` for tokens from `issue_token`, set by `issue_bootstrap_token`.
    pub tcp_port: Option<u16>,
    /// Opaque bytes passed to `issue_bootstrap_token` as `cert`
    /// (presumably an encoded certificate — encoding not visible here).
    pub cert: Option<Vec<u8>>,
    /// Issue time, in whole seconds since the Unix epoch.
    pub iat: u64,
    /// Expiry time, in whole seconds since the Unix epoch; `verify_token` rejects
    /// the token once this is less than the current time.
    pub exp: u64,
}

/// A payload together with its authentication tag.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
pub struct Token {
    /// The claims covered by `sig`.
    pub payload: TokenPayload,
    /// HMAC-SHA256 tag computed over the bincode-serialized `payload`.
    pub sig: Vec<u8>,
}

pub fn issue_token(
    key: &[u8; 32],
    tenant: &str,
    cluster: &str,
    worker_id: &str,
    ttl_secs: u64,
) -> Result<Token, Report> {
    let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
    let payload = TokenPayload {
        tenant: tenant.to_string(),
        cluster: cluster.to_string(),
        worker_id: worker_id.to_string(),
        host: None,
        quic_port: None,
        tcp_port: None,
        cert: None,
        iat: now,
        exp: now + ttl_secs,
    };
    let encoded = bincode::serialize(&payload)?;
    let mut mac = HmacSha256::new_from_slice(key)?;
    mac.update(&encoded);
    let sig = mac.finalize().into_bytes().to_vec();
    Ok(Token { payload, sig })
}

#[allow(clippy::too_many_arguments)]
pub fn issue_bootstrap_token(
    key: &[u8; 32],
    tenant: &str,
    cluster: &str,
    worker_id: &str,
    ttl_secs: u64,
    host: &str,
    quic_port: u16,
    tcp_port: u16,
    cert: Vec<u8>,
) -> Result<Token, Report> {
    let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
    let payload = TokenPayload {
        tenant: tenant.to_string(),
        cluster: cluster.to_string(),
        worker_id: worker_id.to_string(),
        host: Some(host.to_string()),
        quic_port: Some(quic_port),
        tcp_port: Some(tcp_port),
        cert: Some(cert),
        iat: now,
        exp: now + ttl_secs,
    };
    let encoded = bincode::serialize(&payload)?;
    let mut mac = HmacSha256::new_from_slice(key)?;
    mac.update(&encoded);
    let sig = mac.finalize().into_bytes().to_vec();
    Ok(Token { payload, sig })
}

/// Verifies that `token` was signed with `key` and has not expired.
///
/// The signature is checked first, so the expiry decision (and the
/// "Token expired" warning) is only ever based on an authenticated payload.
/// A token is accepted while `exp >= now`, with `now` in whole seconds since
/// the Unix epoch.
///
/// # Errors
///
/// Returns an error if the payload cannot be serialized, if the HMAC-SHA256
/// tag does not match, if the system clock is before the Unix epoch, or if the
/// token has expired (`"token expired"`).
pub fn verify_token(token: &Token, key: &[u8; 32]) -> Result<(), Report> {
    // Authenticate before interpreting any claim: `exp` in a forged token is
    // attacker-controlled and must not drive logging or control flow.
    let encoded = bincode::serialize(&token.payload)?;
    trace!("Encoded token: {:?}", encoded);
    let mut mac = HmacSha256::new_from_slice(key)?;
    mac.update(&encoded);
    mac.verify_slice(&token.sig)?;

    let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
    if token.payload.exp < now {
        warn!("Token expired");
        return Err(eyre::eyre!("token expired"));
    }
    Ok(())
}

pub fn verify_token_signature(token: &Token, key: &[u8; 32]) -> Result<(), Report> {
    let encoded = bincode::serialize(&token.payload)?;
    let mut mac = HmacSha256::new_from_slice(key)?;
    mac.update(&encoded);
    mac.verify_slice(&token.sig)?;
    Ok(())
}

pub fn refresh_token(token: &Token, key: &[u8; 32], ttl_secs: u64) -> Result<Token, Report> {
    verify_token(token, key)?;
    issue_token(
        key,
        &token.payload.tenant,
        &token.payload.cluster,
        &token.payload.worker_id,
        ttl_secs,
    )
}

pub fn encode_token(token: &Token) -> Result<String, Report> {
    let bytes = bincode::serialize(token)?;
    Ok(general_purpose::STANDARD_NO_PAD.encode(bytes))
}

/// Parses a token from unpadded standard base64 wrapping a bincode-encoded [`Token`].
///
/// This only decodes; it does not verify the signature or the expiry.
///
/// # Errors
///
/// Returns an error if `encoded` is not valid base64 or if the decoded bytes
/// do not deserialize into a [`Token`].
pub fn decode_token(encoded: &str) -> Result<Token, Report> {
    let raw = general_purpose::STANDARD_NO_PAD.decode(encoded)?;
    match bincode::deserialize::<Token>(&raw) {
        Ok(token) => Ok(token),
        Err(err) => Err(eyre::eyre!("failed to deserialize token: {}", err)),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly issued token verifies under the key that signed it.
    #[test]
    fn token_issue_verify_roundtrip() {
        let secret = [1u8; 32];
        let issued = issue_token(&secret, "t", "c", "aid", 60).unwrap();
        verify_token(&issued, &secret).unwrap();
    }

    /// Refreshing with a longer TTL yields a valid token with a later expiry.
    #[test]
    fn token_refresh_extends_expiry() {
        let secret = [2u8; 32];

        let original = issue_token(&secret, "t", "c", "a", 2).unwrap();
        verify_token(&original, &secret).unwrap();

        let refreshed = refresh_token(&original, &secret, 5).unwrap();
        verify_token(&refreshed, &secret).unwrap();

        assert!(refreshed.payload.exp > original.payload.exp);
    }

    /// Encoding then decoding returns an equal token that still verifies.
    #[test]
    fn encode_decode_roundtrip() {
        let secret = [3u8; 32];
        let original = issue_token(&secret, "t", "c", "aid", 30).unwrap();

        let text = encode_token(&original).unwrap();
        let parsed = decode_token(&text).unwrap();

        assert_eq!(parsed, original);
        verify_token(&parsed, &secret).unwrap();
    }

    /// Bootstrap tokens verify, expose their host, and survive an encode/decode cycle.
    #[test]
    fn bootstrap_token_roundtrip() {
        let secret = [4u8; 32];
        let cert_bytes = vec![1, 2, 3];

        let issued =
            issue_bootstrap_token(&secret, "t", "c", "aid", 60, "h", 1, 2, cert_bytes.clone())
                .unwrap();
        verify_token(&issued, &secret).unwrap();
        assert_eq!(issued.payload.host.as_deref(), Some("h"));

        let text = encode_token(&issued).unwrap();
        let parsed = decode_token(&text).unwrap();
        assert_eq!(parsed, issued);
    }
}