use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use base64::Engine;
use chrono::Utc;
use k2db_api_contract::{ApiKeyIdentity, AuthHeader};
use scrypt::{Params, scrypt};
use subtle::ConstantTimeEq;
use tokio::sync::RwLock;
use tokio::task;
use crate::api_error::ApiError;
use crate::config::ApiKeyConfig;
/// Cache capacity that `AuthProvider::new` passes to `with_cache_capacity`.
const AUTH_CACHE_CAPACITY: usize = 1024;
/// Value stored in the positive authentication cache for one verified header.
#[derive(Clone, Debug)]
struct CachedIdentity {
    /// Identity handed back to the caller on a cache hit.
    identity: ApiKeyIdentity,
    /// Expiry of the underlying key, compared against
    /// `Utc::now().timestamp_millis()`; `None` means the entry never expires.
    key_expires_at: Option<i64>,
}
/// Decoded form of a `scrypt$N$r$p$salt$hash` secret hash string.
#[derive(Clone, Debug)]
struct ParsedSecretHash {
    /// Base-2 logarithm of the scrypt cost parameter `N`.
    log_n: u8,
    /// scrypt `r` parameter.
    r: u32,
    /// scrypt `p` parameter.
    p: u32,
    /// Decoded salt bytes.
    salt: Vec<u8>,
    /// Decoded derived-key bytes the presented secret must reproduce.
    expected: Vec<u8>,
}
/// A configured API key after its secret hash has been parsed.
#[derive(Clone, Debug)]
struct VerifiedKey {
    /// Identity returned to callers once the secret has been verified.
    identity: ApiKeyIdentity,
    /// Inactive keys are rejected before any secret verification happens.
    active: bool,
    /// Optional expiry, compared against `Utc::now().timestamp_millis()`.
    expires_at: Option<i64>,
    /// Parsed scrypt parameters, salt and expected hash for this key.
    verifier: ParsedSecretHash,
}
/// Bounded cache of successfully verified headers with first-in-first-out
/// eviction.
#[derive(Debug)]
struct PositiveAuthCache {
    /// Cached identities keyed by the cache key passed to `insert`.
    entries: HashMap<String, CachedIdentity>,
    /// Cache keys in insertion order; the front is evicted first.
    order: VecDeque<String>,
    /// Maximum number of entries; zero disables caching.
    capacity: usize,
}
/// Verifies API-key authorization headers against the configured keys.
///
/// Both fields sit behind `Arc`, so clones share the same key table and the
/// same cache.
#[derive(Clone, Debug)]
pub struct AuthProvider {
    /// Configured keys indexed by key id.
    keys: Arc<HashMap<String, VerifiedKey>>,
    /// Cache of headers that have already passed verification.
    cache: Arc<RwLock<PositiveAuthCache>>,
}
impl AuthProvider {
pub fn new(keys: HashMap<String, ApiKeyConfig>) -> Result<Self, ApiError> {
Self::with_cache_capacity(keys, AUTH_CACHE_CAPACITY)
}
fn with_cache_capacity(
keys: HashMap<String, ApiKeyConfig>,
capacity: usize,
) -> Result<Self, ApiError> {
let mut parsed_keys = HashMap::with_capacity(keys.len());
for (name, config) in keys {
let verifier = ParsedSecretHash::parse(&config.secret_hash)?;
let identity = ApiKeyIdentity {
kind: "apikey".to_owned(),
key_id: config.key_id.clone(),
name,
database: config.database.clone(),
permissions: config.permissions.clone(),
};
parsed_keys.insert(
config.key_id.clone(),
VerifiedKey {
identity,
active: config.active,
expires_at: config.expires_at,
verifier,
},
);
}
Ok(Self {
keys: Arc::new(parsed_keys),
cache: Arc::new(RwLock::new(PositiveAuthCache::new(capacity))),
})
}
pub async fn verify(&self, header: Option<&str>) -> Result<ApiKeyIdentity, ApiError> {
let header = header.ok_or_else(|| {
ApiError::unauthorized("Missing API key", "t-auth-apikey-missing-001")
})?;
let cache_key = header.trim().to_owned();
if let Some(identity) = self.cache_get(&cache_key).await {
return Ok(identity);
}
let parsed = AuthHeader::parse(header).ok_or_else(|| {
if header.trim().starts_with("ApiKey ") {
ApiError::unauthorized(
"Malformed API key (expected \"<keyId>.<secret>\")",
"t-auth-apikey-format-003",
)
} else {
ApiError::unauthorized(
"Invalid authorization scheme",
"t-auth-apikey-scheme-002",
)
}
})?;
let Some(entry) = self.keys.get(&parsed.key_id) else {
return Err(ApiError::unauthorized(
"Unknown or inactive API key",
"t-auth-apikey-unknown-004",
));
};
if !entry.active {
return Err(ApiError::unauthorized(
"Unknown or inactive API key",
"t-auth-apikey-unknown-004",
));
}
if let Some(expires_at) = &entry.expires_at {
if Utc::now().timestamp_millis() >= *expires_at {
return Err(ApiError::unauthorized(
"API key expired",
"t-auth-apikey-expired-020",
));
}
}
verify_secret_async(parsed.secret, entry.verifier.clone()).await?;
self.cache_set(cache_key, entry.identity.clone(), entry.expires_at)
.await;
Ok(entry.identity.clone())
}
async fn cache_get(&self, key: &str) -> Option<ApiKeyIdentity> {
let mut cache = self.cache.write().await;
cache.get(key)
}
async fn cache_set(&self, key: String, identity: ApiKeyIdentity, key_expires_at: Option<i64>) {
let mut cache = self.cache.write().await;
cache.insert(
key,
CachedIdentity {
identity,
key_expires_at,
},
);
}
}
impl PositiveAuthCache {
    /// Creates an empty cache holding at most `capacity` entries.
    fn new(capacity: usize) -> Self {
        Self {
            entries: HashMap::new(),
            order: VecDeque::new(),
            capacity,
        }
    }

    /// Returns the cached identity for `key`, or `None` when the key is
    /// absent or its entry has expired.
    ///
    /// An expired entry is removed from both the map and the eviction queue
    /// so the two structures stay in sync.
    fn get(&mut self, key: &str) -> Option<ApiKeyIdentity> {
        let entry = self.entries.get(key)?;
        if is_expired(entry.key_expires_at) {
            self.entries.remove(key);
            // Drop the key from the queue as well; otherwise a stale (or,
            // after a re-insert, duplicated) key would linger in `order`
            // and skew FIFO eviction.
            self.order.retain(|queued| queued.as_str() != key);
            return None;
        }
        Some(entry.identity.clone())
    }

    /// Stores `entry` under `key`, evicting the oldest entry when the cache
    /// is full.
    ///
    /// Replacing an existing key keeps its original queue position. A
    /// capacity of zero makes this a no-op.
    fn insert(&mut self, key: String, entry: CachedIdentity) {
        if self.capacity == 0 {
            return;
        }
        if let Some(existing) = self.entries.get_mut(&key) {
            *existing = entry;
            return;
        }
        while self.entries.len() >= self.capacity {
            let Some(oldest) = self.order.pop_front() else {
                break;
            };
            // Queue keys without a live entry are skipped; stop as soon as
            // one real entry has been evicted.
            if self.entries.remove(&oldest).is_some() {
                break;
            }
        }
        self.order.push_back(key.clone());
        self.entries.insert(key, entry);
    }
}
/// Checks that `identity` has been granted `permission`.
///
/// # Errors
/// Returns a forbidden error naming the missing permission when no entry of
/// `identity.permissions` equals `permission`.
pub fn require_permission(identity: &ApiKeyIdentity, permission: &str) -> Result<(), ApiError> {
    let granted = identity
        .permissions
        .iter()
        .any(|candidate| candidate == permission);
    if !granted {
        return Err(ApiError::forbidden(
            format!("Missing required permission: {permission}"),
            "t-auth-permission-denied-001",
        ));
    }
    Ok(())
}
/// Derives the scrypt hash of `secret` with the verifier's parameters and
/// salt, then compares it to the expected hash.
///
/// # Errors
/// Returns an internal error when the scrypt parameters are rejected or the
/// derivation fails, and an unauthorized error when the derived hash does not
/// match.
fn verify_secret(secret: &str, verifier: &ParsedSecretHash) -> Result<(), ApiError> {
    let params = Params::new(
        verifier.log_n,
        verifier.r,
        verifier.p,
        verifier.expected.len(),
    )
    .map_err(|_| {
        ApiError::internal("Malformed scrypt parameters", "t-auth-apikey-hash-format-005")
    })?;
    // Derive exactly as many bytes as the stored hash holds.
    let mut derived = vec![0_u8; verifier.expected.len()];
    scrypt(secret.as_bytes(), &verifier.salt, &params, &mut derived).map_err(|_| {
        ApiError::internal(
            "Failed to verify API key secret",
            "t-auth-apikey-badsecret-018",
        )
    })?;
    // Compare with `subtle::ConstantTimeEq` rather than `==`.
    if derived.as_slice().ct_eq(verifier.expected.as_slice()).into() {
        Ok(())
    } else {
        Err(ApiError::unauthorized(
            "Invalid API key",
            "t-auth-apikey-badsecret-018",
        ))
    }
}
/// Runs `verify_secret` through `tokio::task::spawn_blocking` so the scrypt
/// derivation does not execute inside the calling async task.
///
/// # Errors
/// Returns the error produced by `verify_secret`, or an internal error when
/// the blocking task itself could not be joined.
async fn verify_secret_async(secret: String, verifier: ParsedSecretHash) -> Result<(), ApiError> {
    let handle = task::spawn_blocking(move || verify_secret(&secret, &verifier));
    match handle.await {
        Ok(outcome) => outcome,
        Err(_) => Err(ApiError::internal(
            "API key verification task failed",
            "t-auth-apikey-verify-022",
        )),
    }
}
impl ParsedSecretHash {
fn parse(secret_hash: &str) -> Result<Self, ApiError> {
let parts = secret_hash.split('$').collect::<Vec<_>>();
if parts.len() != 6 {
return Err(ApiError::internal(
"Malformed API key hash (expected 6 parts: scrypt$N$r$p$salt$hash)",
"t-auth-apikey-hash-format-005",
));
}
if parts[0] != "scrypt" {
return Err(ApiError::internal(
format!("Unsupported key hash algo: {}", parts[0]),
"t-auth-apikey-hash-algo-007",
));
}
let n = parse_u32(parts[1], "t-auth-apikey-hash-N-013")?;
if !n.is_power_of_two() {
return Err(ApiError::internal(
format!("Invalid scrypt N: {}", parts[1]),
"t-auth-apikey-hash-N-013",
));
}
Ok(Self {
log_n: n.ilog2() as u8,
r: parse_u32(parts[2], "t-auth-apikey-hash-r-014")?,
p: parse_u32(parts[3], "t-auth-apikey-hash-p-015")?,
salt: decode_base64url(parts[4], "t-auth-apikey-hash-salt-016")?,
expected: decode_base64url(parts[5], "t-auth-apikey-hash-hash-017")?,
})
}
}
/// Reports whether `expires_at` has been reached.
///
/// The value is compared against `Utc::now().timestamp_millis()`; `None`
/// never expires.
fn is_expired(expires_at: Option<i64>) -> bool {
    match expires_at {
        Some(deadline) => Utc::now().timestamp_millis() >= deadline,
        None => false,
    }
}
/// Parses `value` as a decimal `u32`.
///
/// # Errors
/// Returns an internal error tagged with `trace` when `value` is not a valid
/// `u32`.
fn parse_u32(value: &str, trace: &'static str) -> Result<u32, ApiError> {
    match value.parse::<u32>() {
        Ok(number) => Ok(number),
        Err(_) => Err(ApiError::internal(
            format!("Invalid numeric value: {value}"),
            trace,
        )),
    }
}
/// Decodes a base64url string, with or without padding, into bytes.
///
/// The input is rewritten to the standard alphabet (`-` to `+`, `_` to `/`),
/// padded with `=` to a multiple of four characters, and decoded with the
/// standard engine.
///
/// # Errors
/// Returns an internal error tagged with `trace` when the input is empty,
/// has an impossible length, or is not valid base64.
fn decode_base64url(value: &str, trace: &'static str) -> Result<Vec<u8>, ApiError> {
    if value.is_empty() {
        return Err(ApiError::internal(
            "Invalid base64url input (empty or non-string)",
            trace,
        ));
    }
    let mut standard: String = value
        .chars()
        .map(|symbol| match symbol {
            '-' => '+',
            '_' => '/',
            other => other,
        })
        .collect();
    // A remainder of one cannot be produced by any base64 encoding.
    let padding = match standard.len() % 4 {
        0 => "",
        2 => "==",
        3 => "=",
        _ => return Err(ApiError::internal("Invalid base64url length", trace)),
    };
    standard.push_str(padding);
    base64::engine::general_purpose::STANDARD
        .decode(standard)
        .map_err(|_| ApiError::internal("Invalid base64 encoding", trace))
}
#[cfg(test)]
mod tests {
    use base64::engine::general_purpose::URL_SAFE_NO_PAD;

    use super::*;

    /// Builds a `scrypt$16384$8$1$salt$hash` string for `secret` with a
    /// fixed salt.
    fn hashed_secret(secret: &str) -> String {
        let salt = b"0123456789abcdef";
        let params = Params::new(14, 8, 1, 32).expect("params");
        let mut out = vec![0_u8; 32];
        scrypt(secret.as_bytes(), salt, &params, &mut out).expect("scrypt");
        format!(
            "scrypt$16384$8$1${}${}",
            URL_SAFE_NO_PAD.encode(salt),
            URL_SAFE_NO_PAD.encode(out)
        )
    }

    /// Builds an active, non-expiring key configuration for the tests.
    fn active_key(
        key_id: &str,
        secret: &str,
        database: &str,
        permissions: &[&str],
    ) -> ApiKeyConfig {
        ApiKeyConfig {
            key_id: key_id.to_owned(),
            secret_hash: hashed_secret(secret),
            database: database.to_owned(),
            permissions: permissions.iter().map(|item| (*item).to_owned()).collect(),
            active: true,
            expires_at: None,
        }
    }

    #[tokio::test]
    async fn verify_accepts_valid_key() {
        let provider = AuthProvider::new(HashMap::from([(
            "worker".to_owned(),
            active_key("demo", "secret-value", "project_a", &["collections.read"]),
        )]))
        .expect("provider");
        let identity = provider
            .verify(Some("ApiKey rbk_sys_demo.secret-value"))
            .await
            .expect("identity");
        assert_eq!(identity.key_id, "demo");
        assert_eq!(identity.database, "project_a");
    }

    #[tokio::test]
    async fn verify_rejects_bad_secret() {
        let provider = AuthProvider::new(HashMap::from([(
            "worker".to_owned(),
            active_key("demo", "secret-value", "project_a", &[]),
        )]))
        .expect("provider");
        assert!(provider.verify(Some("ApiKey demo.nope")).await.is_err());
    }

    #[tokio::test]
    async fn verify_cache_eviction_is_fifo() {
        let provider = AuthProvider::with_cache_capacity(
            HashMap::from([
                (
                    "worker-a".to_owned(),
                    active_key("demo-a", "secret-a", "project_a", &[]),
                ),
                (
                    "worker-b".to_owned(),
                    active_key("demo-b", "secret-b", "project_b", &[]),
                ),
                (
                    "worker-c".to_owned(),
                    active_key("demo-c", "secret-c", "project_c", &[]),
                ),
            ]),
            2,
        )
        .expect("provider");
        provider.verify(Some("ApiKey demo-a.secret-a")).await.expect("a");
        provider.verify(Some("ApiKey demo-b.secret-b")).await.expect("b");
        provider.verify(Some("ApiKey demo-c.secret-c")).await.expect("c");
        // With capacity 2, the first verified header must have been evicted.
        let cache = provider.cache.read().await;
        assert!(!cache.entries.contains_key("ApiKey demo-a.secret-a"));
        assert!(cache.entries.contains_key("ApiKey demo-b.secret-b"));
        assert!(cache.entries.contains_key("ApiKey demo-c.secret-c"));
    }
}