use std::collections::HashSet;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use serde::{Deserialize, Serialize};
use crate::filter_ir::{AuthCapabilities, AuthScope};
/// A bearer token granting a set of capabilities over a list of namespaces
/// until `expires_at`.
///
/// `signature` holds an HMAC-SHA256 computed by [`TokenSigner::sign`] over a
/// serialization of all the other fields; [`TokenSigner::verify`] recomputes
/// it to detect tampering.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CapabilityToken {
    /// Token format version; `TokenSigner::verify` rejects anything other than
    /// `CapabilityToken::CURRENT_VERSION`.
    pub version: u8,
    /// Identifier looked up in the `RevocationList` by `TokenValidator::validate`.
    pub token_id: String,
    /// Namespaces this token may access (compared by exact string equality).
    pub allowed_namespaces: Vec<String>,
    /// Optional tenant scope, copied into the derived `AuthScope`.
    pub tenant_id: Option<String>,
    /// Optional project scope, copied into the derived `AuthScope`.
    pub project_id: Option<String>,
    /// Operations the bearer may perform.
    pub capabilities: TokenCapabilities,
    /// Issue time in whole seconds since the Unix epoch.
    pub issued_at: u64,
    /// Expiry time in whole seconds since the Unix epoch.
    pub expires_at: u64,
    /// ACL tags, copied into the derived `AuthScope`.
    pub acl_tags: Vec<String>,
    /// HMAC-SHA256 signature; empty on a freshly built, unsigned token.
    pub signature: Vec<u8>,
}
/// Individual permission flags carried by a [`CapabilityToken`].
///
/// Every flag defaults to `false`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TokenCapabilities {
    /// Permission to read.
    pub can_read: bool,
    /// Permission to write.
    pub can_write: bool,
    /// Permission to delete.
    pub can_delete: bool,
    /// Administrative permission.
    pub can_admin: bool,
    /// Permission to delegate. Covered by the token signature, but not
    /// carried into `AuthCapabilities` by `CapabilityToken::to_auth_scope`.
    pub can_delegate: bool,
}
impl CapabilityToken {
    /// Token format version stamped by `TokenBuilder::build_unsigned` and the
    /// only version accepted by `TokenSigner::verify`.
    pub const CURRENT_VERSION: u8 = 1;

    /// Current wall-clock time in whole seconds since the Unix epoch.
    ///
    /// Returns `0` if the system clock reads earlier than the epoch.
    /// NOTE(review): that fallback makes any token with a non-zero
    /// `expires_at` look unexpired on such a misconfigured host.
    fn now_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_secs())
            .unwrap_or(0)
    }

    /// Returns `true` once the current time has reached `expires_at`.
    ///
    /// The expiry second itself counts as expired ("on or after", as with the
    /// JWT `exp` claim). This matches [`remaining_validity`](Self::remaining_validity),
    /// which returns `None` at that same second.
    pub fn is_expired(&self) -> bool {
        Self::now_secs() >= self.expires_at
    }

    /// Returns `true` if `namespace` exactly equals one of `allowed_namespaces`.
    pub fn is_namespace_allowed(&self, namespace: &str) -> bool {
        self.allowed_namespaces.iter().any(|ns| ns == namespace)
    }

    /// Converts this token into a `crate::filter_ir::AuthScope`.
    ///
    /// Namespaces, tenant, project, expiry and ACL tags are copied over, along
    /// with the read/write/delete/admin flags. `can_delegate` has no
    /// counterpart in `AuthCapabilities` here and is dropped.
    pub fn to_auth_scope(&self) -> AuthScope {
        AuthScope {
            allowed_namespaces: self.allowed_namespaces.clone(),
            tenant_id: self.tenant_id.clone(),
            project_id: self.project_id.clone(),
            expires_at: Some(self.expires_at),
            capabilities: AuthCapabilities {
                can_read: self.capabilities.can_read,
                can_write: self.capabilities.can_write,
                can_delete: self.capabilities.can_delete,
                can_admin: self.capabilities.can_admin,
            },
            acl_tags: self.acl_tags.clone(),
        }
    }

    /// Time left until expiry, or `None` if the token has already expired.
    ///
    /// Returns `None` exactly when the current second is at or past
    /// `expires_at`, the same condition used by [`is_expired`](Self::is_expired).
    pub fn remaining_validity(&self) -> Option<Duration> {
        let now = Self::now_secs();
        if now >= self.expires_at {
            None
        } else {
            Some(Duration::from_secs(self.expires_at - now))
        }
    }
}
/// Fluent builder for [`CapabilityToken`]s.
///
/// A fresh builder grants read access to a single namespace for one hour.
/// Chain the `with_*`, `can_*` and `valid_for` methods to adjust it, then call
/// [`build_unsigned`](Self::build_unsigned) and sign the result (for example
/// via `TokenValidator::issue`).
pub struct TokenBuilder {
    namespaces: Vec<String>,
    tenant_id: Option<String>,
    project_id: Option<String>,
    capabilities: TokenCapabilities,
    validity: Duration,
    acl_tags: Vec<String>,
}

impl TokenBuilder {
    /// Starts a builder scoped to `namespace`, with `can_read` set and a
    /// validity of one hour.
    pub fn new(namespace: impl Into<String>) -> Self {
        Self {
            namespaces: vec![namespace.into()],
            tenant_id: None,
            project_id: None,
            capabilities: TokenCapabilities {
                can_read: true,
                ..Default::default()
            },
            validity: Duration::from_secs(3600),
            acl_tags: Vec::new(),
        }
    }

    /// Adds another namespace the token may access. Duplicates are kept as-is.
    pub fn with_namespace(mut self, namespace: impl Into<String>) -> Self {
        self.namespaces.push(namespace.into());
        self
    }

    /// Sets the tenant scope, replacing any previous value.
    pub fn with_tenant(mut self, tenant_id: impl Into<String>) -> Self {
        self.tenant_id = Some(tenant_id.into());
        self
    }

    /// Sets the project scope, replacing any previous value.
    pub fn with_project(mut self, project_id: impl Into<String>) -> Self {
        self.project_id = Some(project_id.into());
        self
    }

    /// Grants read access. [`new`](Self::new) already grants it.
    pub fn can_read(mut self) -> Self {
        self.capabilities.can_read = true;
        self
    }

    /// Grants write access.
    pub fn can_write(mut self) -> Self {
        self.capabilities.can_write = true;
        self
    }

    /// Grants delete access.
    pub fn can_delete(mut self) -> Self {
        self.capabilities.can_delete = true;
        self
    }

    /// Grants admin access.
    pub fn can_admin(mut self) -> Self {
        self.capabilities.can_admin = true;
        self
    }

    /// Grants read, write, delete and admin access.
    ///
    /// Replaces all capability flags; `can_delegate` ends up `false`.
    pub fn full_access(mut self) -> Self {
        self.capabilities = TokenCapabilities {
            can_read: true,
            can_write: true,
            can_delete: true,
            can_admin: true,
            can_delegate: false,
        };
        self
    }

    /// Sets how long the token remains valid after it is built.
    ///
    /// Only whole seconds are used, because `expires_at` is stored in seconds.
    pub fn valid_for(mut self, duration: Duration) -> Self {
        self.validity = duration;
        self
    }

    /// Replaces the ACL tags carried by the token.
    pub fn with_acl_tags(mut self, tags: Vec<String>) -> Self {
        self.acl_tags = tags;
        self
    }

    /// Produces an unsigned token with a freshly generated `token_id`.
    ///
    /// `issued_at` is the current Unix time in seconds (or `0` if the clock
    /// reads before the epoch), and `signature` is empty.
    ///
    /// `expires_at` saturates at `u64::MAX` rather than overflowing. A huge
    /// validity such as `Duration::MAX` therefore yields a token that
    /// effectively never expires. Plain `+` would panic in debug builds or wrap
    /// to an already-past timestamp in release builds.
    pub fn build_unsigned(self) -> CapabilityToken {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_secs())
            .unwrap_or(0);
        CapabilityToken {
            version: CapabilityToken::CURRENT_VERSION,
            token_id: generate_token_id(),
            allowed_namespaces: self.namespaces,
            tenant_id: self.tenant_id,
            project_id: self.project_id,
            capabilities: self.capabilities,
            issued_at: now,
            expires_at: now.saturating_add(self.validity.as_secs()),
            acl_tags: self.acl_tags,
            signature: Vec::new(),
        }
    }
}
/// Generates an opaque token identifier of the form
/// `tok_<unix-nanos-hex>_<sequence-hex>`.
///
/// The nanosecond timestamp alone is not unique. Two tokens built within the
/// same clock tick (clock resolution varies by platform) would share an ID, so
/// revoking one would silently revoke the other. A process-wide counter is
/// appended to make IDs unique within this process.
///
/// IDs are not secret, and uniqueness across separate processes or hosts is
/// not guaranteed.
fn generate_token_id() -> String {
    use std::sync::atomic::{AtomicU64, Ordering};

    static SEQUENCE: AtomicU64 = AtomicU64::new(0);

    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos();
    // `fetch_add` is an atomic read-modify-write, so every caller gets a
    // distinct value even with `Relaxed` ordering.
    let seq = SEQUENCE.fetch_add(1, Ordering::Relaxed);
    format!("tok_{:x}_{:x}", nanos, seq)
}
pub struct TokenSigner {
secret: Vec<u8>,
}
impl TokenSigner {
pub fn new(secret: impl AsRef<[u8]>) -> Self {
Self {
secret: secret.as_ref().to_vec(),
}
}
pub fn sign(&self, token: &mut CapabilityToken) {
let payload = self.compute_payload(token);
token.signature = self.hmac_sha256(&payload);
}
pub fn verify(&self, token: &CapabilityToken) -> Result<(), TokenError> {
if token.version != CapabilityToken::CURRENT_VERSION {
return Err(TokenError::UnsupportedVersion(token.version));
}
if token.is_expired() {
return Err(TokenError::Expired);
}
let payload = self.compute_payload(token);
let expected = self.hmac_sha256(&payload);
if !constant_time_eq(&token.signature, &expected) {
return Err(TokenError::InvalidSignature);
}
Ok(())
}
fn compute_payload(&self, token: &CapabilityToken) -> Vec<u8> {
let mut payload = Vec::new();
payload.push(token.version);
payload.extend(token.token_id.as_bytes());
for ns in &token.allowed_namespaces {
payload.extend(ns.as_bytes());
payload.push(0); }
if let Some(ref tenant) = token.tenant_id {
payload.extend(tenant.as_bytes());
}
payload.push(0);
if let Some(ref project) = token.project_id {
payload.extend(project.as_bytes());
}
payload.push(0);
let caps = (token.capabilities.can_read as u8)
| ((token.capabilities.can_write as u8) << 1)
| ((token.capabilities.can_delete as u8) << 2)
| ((token.capabilities.can_admin as u8) << 3)
| ((token.capabilities.can_delegate as u8) << 4);
payload.push(caps);
payload.extend(&token.issued_at.to_le_bytes());
payload.extend(&token.expires_at.to_le_bytes());
for tag in &token.acl_tags {
payload.extend(tag.as_bytes());
payload.push(0);
}
payload
}
fn hmac_sha256(&self, data: &[u8]) -> Vec<u8> {
use ring::hmac;
let key = hmac::Key::new(hmac::HMAC_SHA256, &self.secret);
let tag = hmac::sign(&key, data);
tag.as_ref().to_vec() }
}
/// Compares two byte slices for equality without stopping at the first
/// differing byte.
///
/// Slices of different lengths compare unequal immediately, so only the
/// length, not the contents, is revealed early. For equal lengths, every byte
/// pair is XOR-ed and OR-ed into an accumulator, and the fold always visits
/// every pair.
///
/// NOTE(review): this avoids data-dependent early exits in the source, but
/// nothing formally prevents the optimizer from introducing one.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len()
        && a
            .iter()
            .zip(b)
            .fold(0u8, |acc, (lhs, rhs)| acc | (lhs ^ rhs))
            == 0
}
#[derive(Debug, Clone, thiserror::Error)]
pub enum TokenError {
#[error("token has expired")]
Expired,
#[error("invalid signature")]
InvalidSignature,
#[error("unsupported token version: {0}")]
UnsupportedVersion(u8),
#[error("token revoked")]
Revoked,
#[error("namespace not allowed: {0}")]
NamespaceNotAllowed(String),
#[error("insufficient capabilities")]
InsufficientCapabilities,
}
/// Thread-safe, in-memory set of revoked token IDs.
///
/// IDs are never removed, so memory grows with every revocation.
/// NOTE(review): consider evicting IDs whose tokens are already past `expires_at`.
pub struct RevocationList {
    revoked: std::sync::RwLock<HashSet<String>>,
}

impl RevocationList {
    /// Creates an empty revocation list.
    pub fn new() -> Self {
        Self {
            revoked: std::sync::RwLock::new(HashSet::new()),
        }
    }

    /// Marks `token_id` as revoked. Revoking the same ID again has no further effect.
    pub fn revoke(&self, token_id: &str) {
        self.write_guard().insert(token_id.to_string());
    }

    /// Returns `true` if `token_id` has been revoked.
    pub fn is_revoked(&self, token_id: &str) -> bool {
        self.read_guard().contains(token_id)
    }

    /// Returns the number of distinct revoked IDs.
    pub fn count(&self) -> usize {
        self.read_guard().len()
    }

    // A poisoned lock only means some thread panicked while holding it. The only
    // mutation done under this lock is a single `HashSet::insert`, so the set is
    // still a coherent collection of IDs. Recovering it prevents one panic from
    // making every later revocation check panic too, and keeps already-revoked
    // IDs rejected.
    fn read_guard(&self) -> std::sync::RwLockReadGuard<'_, HashSet<String>> {
        self.revoked
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    fn write_guard(&self) -> std::sync::RwLockWriteGuard<'_, HashSet<String>> {
        self.revoked
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}

impl Default for RevocationList {
    fn default() -> Self {
        Self::new()
    }
}
/// Issues signed tokens and validates presented tokens against a revocation
/// list and the signer.
pub struct TokenValidator {
    signer: TokenSigner,
    revocation_list: RevocationList,
}

impl TokenValidator {
    /// Creates a validator that signs with `secret` and starts with an empty
    /// revocation list.
    pub fn new(secret: impl AsRef<[u8]>) -> Self {
        let signer = TokenSigner::new(secret);
        let revocation_list = RevocationList::new();
        Self {
            signer,
            revocation_list,
        }
    }

    /// Builds the token described by `builder` and signs it.
    pub fn issue(&self, builder: TokenBuilder) -> CapabilityToken {
        let mut issued = builder.build_unsigned();
        self.signer.sign(&mut issued);
        issued
    }

    /// Validates `token` and, on success, returns its `AuthScope`.
    ///
    /// # Errors
    ///
    /// Returns `TokenError::Revoked` if the token's ID is revoked; this is
    /// checked before anything else. Otherwise returns whatever
    /// `TokenSigner::verify` reports.
    pub fn validate(&self, token: &CapabilityToken) -> Result<AuthScope, TokenError> {
        if self.revocation_list.is_revoked(&token.token_id) {
            Err(TokenError::Revoked)
        } else {
            self.signer.verify(token).map(|()| token.to_auth_scope())
        }
    }

    /// Adds `token_id` to the revocation list, so later `validate` calls for
    /// tokens with that ID fail with `TokenError::Revoked`.
    pub fn revoke(&self, token_id: &str) {
        self.revocation_list.revoke(token_id)
    }
}
/// A named access-control tag: a thin newtype over `String`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct AclTag(String);

impl AclTag {
    /// Wraps `tag` as an `AclTag`.
    pub fn new(tag: impl Into<String>) -> Self {
        let name: String = tag.into();
        AclTag(name)
    }

    /// Returns the tag's name.
    pub fn name(&self) -> &str {
        self.0.as_str()
    }
}
/// In-memory inverted index from ACL tag name to the IDs of documents carrying it.
#[derive(Debug, Default)]
pub struct AclTagIndex {
    tag_to_docs: std::collections::HashMap<String, Vec<u64>>,
}

impl AclTagIndex {
    /// Creates an empty index.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records that document `doc_id` carries `tag`.
    ///
    /// IDs are appended in call order and are not deduplicated. Adding the same
    /// `(doc_id, tag)` pair twice lists `doc_id` twice in
    /// [`docs_with_tag`](Self::docs_with_tag).
    pub fn add_tag(&mut self, doc_id: u64, tag: &str) {
        // Look up by `&str` first so a key `String` is only allocated for new tags.
        match self.tag_to_docs.get_mut(tag) {
            Some(docs) => docs.push(doc_id),
            None => {
                self.tag_to_docs.insert(tag.to_owned(), vec![doc_id]);
            }
        }
    }

    /// Returns the documents carrying `tag` in insertion order, or an empty
    /// slice for an unknown tag.
    pub fn docs_with_tag(&self, tag: &str) -> &[u64] {
        self.tag_to_docs
            .get(tag)
            .map(|docs| docs.as_slice())
            .unwrap_or(&[])
    }

    /// Returns every document carrying at least one of `allowed_tags`.
    ///
    /// The result is sorted ascending with duplicates removed, so it is
    /// deterministic regardless of tag order. Unknown tags are ignored, and an
    /// empty `allowed_tags` yields an empty vector.
    pub fn accessible_docs(&self, allowed_tags: &[String]) -> Vec<u64> {
        let mut docs: Vec<u64> = allowed_tags
            .iter()
            .filter_map(|tag| self.tag_to_docs.get(tag))
            .flatten()
            .copied()
            .collect();
        // A `HashSet` would also deduplicate, but its iteration order is random
        // per process. Sorting groups duplicates so `dedup` removes them and
        // gives a stable order.
        docs.sort_unstable();
        docs.dedup();
        docs
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Validity used by tests that need a live (unexpired) token.
    const ONE_HOUR: Duration = Duration::from_secs(3600);

    #[test]
    fn test_token_builder() {
        let built = TokenBuilder::new("production")
            .with_namespace("staging")
            .with_tenant("acme")
            .can_read()
            .can_write()
            .valid_for(ONE_HOUR)
            .build_unsigned();

        assert_eq!(built.allowed_namespaces.len(), 2);
        assert_eq!(built.tenant_id.as_deref(), Some("acme"));

        let caps = &built.capabilities;
        assert!(caps.can_read && caps.can_write);
        assert!(!caps.can_delete);
    }

    #[test]
    fn test_token_signing_and_verification() {
        let signer = TokenSigner::new("super_secret_key");
        let mut signed = TokenBuilder::new("production")
            .can_read()
            .valid_for(ONE_HOUR)
            .build_unsigned();
        signer.sign(&mut signed);

        assert!(!signed.signature.is_empty());
        assert!(signer.verify(&signed).is_ok());

        // Altering any signed field must break verification.
        signed.allowed_namespaces.push(String::from("hacked"));
        assert!(signer.verify(&signed).is_err());
    }

    #[test]
    fn test_token_expiry() {
        let mut stale = TokenBuilder::new("production")
            .valid_for(ONE_HOUR)
            .build_unsigned();
        // Move the expiry to the very start of the epoch.
        stale.expires_at = 0;
        assert!(stale.is_expired());
    }

    #[test]
    fn test_token_to_auth_scope() {
        let tags = vec![String::from("public"), String::from("internal")];
        let scope = TokenBuilder::new("production")
            .with_tenant("acme")
            .can_read()
            .can_write()
            .with_acl_tags(tags)
            .build_unsigned()
            .to_auth_scope();

        assert!(scope.is_namespace_allowed("production"));
        assert!(!scope.is_namespace_allowed("staging"));
        assert_eq!(scope.tenant_id.as_deref(), Some("acme"));
        assert!(scope.capabilities.can_read);
        assert!(scope.capabilities.can_write);
        assert_eq!(scope.acl_tags.len(), 2);
    }

    #[test]
    fn test_revocation() {
        let validator = TokenValidator::new("secret");
        let builder = TokenBuilder::new("production")
            .can_read()
            .valid_for(ONE_HOUR);
        let token = validator.issue(builder);

        assert!(validator.validate(&token).is_ok());

        validator.revoke(&token.token_id);
        let outcome = validator.validate(&token);
        assert!(matches!(outcome, Err(TokenError::Revoked)));
    }

    #[test]
    fn test_acl_tag_index() {
        let mut index = AclTagIndex::new();
        let assignments = [
            (1u64, "public"),
            (2, "public"),
            (3, "internal"),
            (4, "confidential"),
        ];
        for &(doc_id, tag) in assignments.iter() {
            index.add_tag(doc_id, tag);
        }

        assert_eq!(index.docs_with_tag("public").len(), 2);
        assert_eq!(index.docs_with_tag("internal").len(), 1);

        let wanted = [String::from("public"), String::from("internal")];
        assert_eq!(index.accessible_docs(&wanted).len(), 3);
    }
}