rok-core 0.6.1

Core primitives for the rok ecosystem — errors, crypto, i18n, config, DI, and more
Documentation
mod config;
mod error;
mod signer;

pub use config::EncryptConfig;
pub use error::EncryptError;
pub use signer::Signer;

use std::time::Duration;

use aes_gcm::{
    aead::{Aead, KeyInit},
    Aes256Gcm, Key, Nonce,
};
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};

/// Plaintext record that is JSON-serialised and then encrypted into a token.
///
/// The single-letter field names are the serialised JSON keys and therefore part
/// of the token wire format: renaming them would make existing tokens unreadable.
#[derive(Serialize, Deserialize)]
struct Payload {
    /// The sealed value handed back to the caller when the token is opened.
    v: String,
    /// Purpose the token is bound to (set by the `seal_for*` methods); omitted from
    /// the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    p: Option<String>,
    /// Expiry as a Unix timestamp in seconds, compared against
    /// `chrono::Utc::now().timestamp()`; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    e: Option<i64>,
}

/// Turns an arbitrary secret string into a 32-byte AES-256 key.
///
/// The key is the SHA-256 digest of the secret's UTF-8 bytes. This is a single
/// unsalted hash, not a password KDF, so it assumes `secret` is already
/// high-entropy (e.g. a randomly generated app key).
fn derive_key(secret: &str) -> [u8; 32] {
    let mut hasher = Sha256::new();
    hasher.update(secret.as_bytes());
    hasher.finalize().into()
}

/// Authenticated symmetric encrypter for string values, using AES-256-GCM.
///
/// New tokens are always sealed with `primary_key`. When opening, `primary_key` is
/// tried first and then each of `old_keys`, which lets keys be rotated without
/// invalidating tokens issued under a previous key. `from_config` fills both
/// fields with `derive_key` output.
// NOTE(review): there is intentionally no `Debug` derive — a derived one would
// print raw key bytes. Keep it that way unless a redacting impl is added.
#[derive(Clone)]
pub struct Encrypter {
    /// Key used to encrypt every new token; also the first key tried on decrypt.
    primary_key: [u8; 32],
    /// Previous keys, used only as decryption fallbacks, in this order.
    old_keys: Vec<[u8; 32]>,
}

impl Encrypter {
    pub fn from_config(config: EncryptConfig) -> Self {
        Self {
            primary_key: derive_key(&config.key),
            old_keys: config.old_keys.iter().map(|k| derive_key(k)).collect(),
        }
    }

    fn cipher(key: &[u8; 32]) -> Aes256Gcm {
        Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(key))
    }

    fn encrypt_payload(&self, payload: &Payload) -> String {
        let json = serde_json::to_vec(payload).expect("Payload is always serialisable");
        let nonce_bytes: [u8; 12] = rand::random();
        let nonce = Nonce::from_slice(&nonce_bytes);
        let mut ct = Self::cipher(&self.primary_key)
            .encrypt(nonce, json.as_slice())
            .expect("AES-256-GCM encryption is infallible for valid keys");
        let mut out = nonce_bytes.to_vec();
        out.append(&mut ct);
        URL_SAFE_NO_PAD.encode(out)
    }

    fn decrypt_token(&self, token: &str) -> Result<Payload, EncryptError> {
        let bytes = URL_SAFE_NO_PAD
            .decode(token)
            .map_err(|_| EncryptError::InvalidFormat)?;
        if bytes.len() <= 12 {
            return Err(EncryptError::InvalidFormat);
        }
        let (nonce_bytes, ciphertext) = bytes.split_at(12);
        let nonce = Nonce::from_slice(nonce_bytes);

        let keys = std::iter::once(&self.primary_key).chain(self.old_keys.iter());
        for key in keys {
            if let Ok(plaintext) = Self::cipher(key).decrypt(nonce, ciphertext) {
                return serde_json::from_slice(&plaintext).map_err(|_| EncryptError::InvalidFormat);
            }
        }
        Err(EncryptError::DecryptionFailed)
    }

    fn check_expiry(payload: &Payload) -> Result<(), EncryptError> {
        if let Some(exp) = payload.e {
            if chrono::Utc::now().timestamp() > exp {
                return Err(EncryptError::Expired);
            }
        }
        Ok(())
    }

    pub fn seal(&self, value: &str) -> String {
        self.encrypt_payload(&Payload {
            v: value.to_string(),
            p: None,
            e: None,
        })
    }

    pub fn open(&self, token: &str) -> Result<String, EncryptError> {
        let payload = self.decrypt_token(token)?;
        Self::check_expiry(&payload)?;
        Ok(payload.v)
    }

    pub fn try_open(&self, token: &str) -> Option<String> {
        self.open(token).ok()
    }

    pub fn seal_for(&self, purpose: &str, value: &str) -> String {
        self.encrypt_payload(&Payload {
            v: value.to_string(),
            p: Some(purpose.to_string()),
            e: None,
        })
    }

    pub fn open_for(&self, expected_purpose: &str, token: &str) -> Result<String, EncryptError> {
        let payload = self.decrypt_token(token)?;
        let actual = payload.p.as_deref().unwrap_or("(none)");
        if actual != expected_purpose {
            return Err(EncryptError::WrongPurpose {
                expected: expected_purpose.to_string(),
                actual: actual.to_string(),
            });
        }
        Self::check_expiry(&payload)?;
        Ok(payload.v)
    }

    pub fn seal_expiring(&self, value: &str, ttl: Duration) -> String {
        let expires_at = chrono::Utc::now().timestamp() + ttl.as_secs() as i64;
        self.encrypt_payload(&Payload {
            v: value.to_string(),
            p: None,
            e: Some(expires_at),
        })
    }

    pub fn seal_for_expiring(&self, purpose: &str, value: &str, ttl: Duration) -> String {
        let expires_at = chrono::Utc::now().timestamp() + ttl.as_secs() as i64;
        self.encrypt_payload(&Payload {
            v: value.to_string(),
            p: Some(purpose.to_string()),
            e: Some(expires_at),
        })
    }
}