use std::collections::HashMap;
use std::sync::{Mutex, OnceLock};
use std::time::Duration;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use tracing::warn;
use crate::error::{Error, Result, StorageError};
#[async_trait]
/// Abstract key/value storage backend with named locks.
///
/// Keys are `/`-separated paths such as those produced by the key-building
/// helpers in this module (`certificates/...`, `acme/...`, `ocsp/...`,
/// `locks/...`). Exact semantics of each operation are backend-defined.
pub trait Storage: Send + Sync {
    /// Writes `value` at `key`.
    async fn store(&self, key: &str, value: &[u8]) -> Result<()>;

    /// Reads the bytes stored at `key`.
    async fn load(&self, key: &str) -> Result<Vec<u8>>;

    /// Removes `key`.
    async fn delete(&self, key: &str) -> Result<()>;

    /// Reports whether `key` is present.
    async fn exists(&self, key: &str) -> Result<bool>;

    /// Lists keys under `path`.
    ///
    /// NOTE(review): `recursive` presumably controls whether nested prefixes are
    /// descended into — confirm per backend.
    async fn list(&self, path: &str, recursive: bool) -> Result<Vec<String>>;

    /// Returns metadata about `key`.
    async fn stat(&self, key: &str) -> Result<KeyInfo>;

    /// Acquires the lock `name`.
    ///
    /// NOTE(review): presumably waits (asynchronously) until the lock is
    /// obtained — confirm per backend.
    async fn lock(&self, name: &str) -> Result<()>;

    /// Releases the lock `name`.
    async fn unlock(&self, name: &str) -> Result<()>;

    /// Attempts [`Storage::lock`] but gives up after `timeout`.
    ///
    /// Returns `Ok(true)` if the lock was acquired, `Ok(false)` if the timeout
    /// elapsed first, and propagates any error returned by `lock`. When the
    /// timeout elapses, the in-flight `lock` future is dropped.
    ///
    /// NOTE(review): this default relies on each backend's `lock` being
    /// cancellation-safe (dropping it mid-way must not leave a lock held) —
    /// confirm for every implementation.
    async fn try_lock(&self, name: &str, timeout: Duration) -> Result<bool> {
        match tokio::time::timeout(timeout, self.lock(name)).await {
            Err(_elapsed) => Ok(false),
            Ok(outcome) => outcome.map(|()| true),
        }
    }
}
/// Registry of lock names recorded by `track_lock` and released by
/// `cleanup_own_locks`. Only the keys carry meaning; the `()` values make the
/// map behave as a set.
static OWNED_LOCKS: OnceLock<Mutex<HashMap<String, ()>>> = OnceLock::new();

/// Returns the lock registry, creating an empty one on first access.
fn owned_locks() -> &'static Mutex<HashMap<String, ()>> {
    OWNED_LOCKS.get_or_init(Default::default)
}
/// Records `name` in the process-wide lock registry so that
/// [`cleanup_own_locks`] will release it later.
///
/// A poisoned registry mutex is recovered with `PoisonError::into_inner`
/// instead of silently skipping the insert. Skipping it would mean the lock is
/// never released by cleanup. The registry only stores names, so reusing it
/// after a panic elsewhere is harmless.
pub fn track_lock(name: &str) {
    owned_locks()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .insert(name.to_owned(), ());
}
/// Removes `name` from the process-wide lock registry (e.g. after it has been
/// unlocked normally), so [`cleanup_own_locks`] will not try to release it.
///
/// A poisoned registry mutex is recovered with `PoisonError::into_inner`
/// rather than silently leaving a stale entry behind.
pub fn untrack_lock(name: &str) {
    owned_locks()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .remove(name);
}
/// Releases, via `storage`, every lock currently recorded with [`track_lock`].
///
/// The registry is snapshotted first because a `std::sync::Mutex` guard must
/// not be held across `.await`. Afterwards only the names whose `unlock`
/// succeeded are removed from the registry. As a result:
/// - names tracked concurrently while cleanup is running stay in the registry
///   (they were never unlocked, so they must not be forgotten), and
/// - names whose `unlock` failed stay tracked, so a later cleanup can retry.
///
/// Unlock failures are logged with `warn!` and otherwise ignored; this function
/// never fails. A poisoned registry mutex is recovered rather than aborting.
pub async fn cleanup_own_locks(storage: &dyn Storage) {
    let names: Vec<String> = {
        let map = owned_locks()
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        map.keys().cloned().collect()
    }; // guard dropped here, before any await

    let mut released = Vec::with_capacity(names.len());
    for name in names {
        let outcome = storage.unlock(&name).await;
        match outcome {
            Ok(()) => released.push(name),
            Err(e) => {
                warn!(lock = %name, error = %e, "failed to release lock during cleanup");
            }
        }
    }

    let mut map = owned_locks()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    for name in &released {
        map.remove(name);
    }
}
/// Metadata about a stored key, as returned by [`Storage::stat`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KeyInfo {
    /// The key this information describes.
    pub key: String,
    /// Last-modified timestamp, in UTC.
    pub modified: DateTime<Utc>,
    /// Size of the stored value — presumably in bytes; TODO confirm per backend.
    pub size: u64,
    /// NOTE(review): presumably `true` when the key holds a value rather than
    /// being a prefix/"directory" of other keys — confirm against backends.
    pub is_terminal: bool,
}
// Top-level namespaces used as the first path component by the key-building
// helpers in this module.
/// Prefix for per-issuer, per-site certificate material.
const PREFIX_CERTS: &str = "certificates";
/// Prefix for cached OCSP staples.
const PREFIX_OCSP: &str = "ocsp";
/// Prefix for per-CA ACME data (e.g. user/account entries).
const PREFIX_ACME: &str = "acme";
/// Prefix for named lock keys.
const PREFIX_LOCKS: &str = "locks";
/// Normalizes `s` into a string that is safe to use as one storage-key path
/// component.
///
/// Steps:
/// 1. Lowercase and trim surrounding whitespace.
/// 2. Rewrite a few characters to readable substitutes: ` ` → `_`,
///    `+` → `_plus_`, `*` → `wildcard_`, `:` → `-`.
/// 3. Drop every character that is not alphanumeric (`char::is_alphanumeric`,
///    so Unicode letters/digits are kept), `_`, `@`, `.` or `-`. Path separators
///    like `/` and `\` are therefore removed.
/// 4. Remove every `..`, repeatedly, so the result never contains a
///    parent-directory reference.
///
/// Step 4 runs after step 3 because dropping characters can bring dots back
/// together: for example `"./."` would otherwise become `".."`, which is a
/// directory traversal when used as a whole path segment.
pub fn safe_key(s: &str) -> String {
    let s = s.to_lowercase();
    let s = s
        .trim()
        .replace(' ', "_")
        .replace('+', "_plus_")
        .replace('*', "wildcard_")
        .replace(':', "-")
        .replace("..", "");
    let mut out: String = s
        .chars()
        .filter(|&ch| ch.is_alphanumeric() || matches!(ch, '_' | '@' | '.' | '-'))
        .collect();
    // Each pass strictly shortens `out`, so this loop terminates.
    while out.contains("..") {
        out = out.replace("..", "");
    }
    out
}
/// Joins `parts` with `/`, trimming leading/trailing slashes from each part and
/// skipping parts that end up empty.
///
/// Trimming happens before the emptiness check. Otherwise a part made only of
/// slashes (e.g. `"/"`) would survive the filter as `""` and produce a doubled
/// separator (`a//b`).
fn path_join(parts: &[&str]) -> String {
    parts
        .iter()
        .map(|p| p.trim_matches('/'))
        .filter(|p| !p.is_empty())
        .collect::<Vec<_>>()
        .join("/")
}
/// Storage prefix for all certificates from one issuer:
/// `certificates/<safe_key(issuer_key)>`.
pub fn certs_prefix(issuer_key: &str) -> String {
    let issuer_dir = safe_key(issuer_key);
    path_join(&[PREFIX_CERTS, issuer_dir.as_str()])
}
/// Storage "directory" for one site's certificate files under an issuer:
/// `certificates/<safe issuer>/<safe domain>`.
pub fn certs_site_prefix(issuer_key: &str, domain: &str) -> String {
    let issuer_prefix = certs_prefix(issuer_key);
    let site_dir = safe_key(domain);
    path_join(&[issuer_prefix.as_str(), site_dir.as_str()])
}
/// Storage key of a site's certificate file:
/// `<certs_site_prefix>/<safe domain>.crt`.
pub fn site_cert_key(issuer_key: &str, domain: &str) -> String {
    let dir = certs_site_prefix(issuer_key, domain);
    let file = format!("{}.crt", safe_key(domain));
    path_join(&[dir.as_str(), file.as_str()])
}
/// Storage key of a site's private-key file:
/// `<certs_site_prefix>/<safe domain>.key`.
pub fn site_private_key(issuer_key: &str, domain: &str) -> String {
    let dir = certs_site_prefix(issuer_key, domain);
    let file = format!("{}.key", safe_key(domain));
    path_join(&[dir.as_str(), file.as_str()])
}
/// Storage key of a site's JSON metadata file:
/// `<certs_site_prefix>/<safe domain>.json`.
pub fn site_meta_key(issuer_key: &str, domain: &str) -> String {
    let dir = certs_site_prefix(issuer_key, domain);
    let file = format!("{}.json", safe_key(domain));
    path_join(&[dir.as_str(), file.as_str()])
}
/// Storage key for an OCSP staple: `ocsp/<safe domain>-<hash>`, or
/// `ocsp/<hash>` when `domain` is empty.
///
/// The emptiness check is on the raw `domain`, before sanitizing.
///
/// NOTE(review): `hash` is appended verbatim (it is not passed through
/// `safe_key`), so callers must supply a path-safe value such as a hex
/// digest — TODO confirm at call sites.
pub fn ocsp_key(domain: &str, hash: &str) -> String {
    let filename = if domain.is_empty() {
        hash.to_owned()
    } else {
        format!("{}-{hash}", safe_key(domain))
    };
    path_join(&[PREFIX_OCSP, filename.as_str()])
}
/// Derives a single-component issuer identifier from a CA directory URL.
///
/// For a parseable URL the result is the host, followed by `-` and the URL
/// path when that path is non-trivial. Inside the path, every `/` and `\`
/// becomes `-`, and leading/trailing `-` are trimmed. For example,
/// `https://example.com/acme-ca/directory` becomes
/// `example.com-acme-ca-directory`. Input that fails to parse as a URL is
/// returned unchanged.
///
/// NOTE(review): `Url::host_str` does not include the port, so CA URLs that
/// differ only by port map to the same key — confirm this is intended.
pub fn issuer_key(ca_url: &str) -> String {
    let Ok(parsed) = url::Url::parse(ca_url) else {
        return ca_url.to_owned();
    };
    let host = parsed.host_str().unwrap_or(ca_url);
    let path = parsed.path().replace(['/', '\\'], "-");
    let path = path.trim_matches('-');
    if path.is_empty() {
        host.to_owned()
    } else {
        format!("{host}-{path}")
    }
}
/// Storage prefix for ACME data belonging to one CA:
/// `acme/<safe_key(issuer_key)>`.
pub fn acme_ca_prefix(issuer_key: &str) -> String {
    let ca_dir = safe_key(issuer_key);
    path_join(&[PREFIX_ACME, ca_dir.as_str()])
}
/// Storage prefix for an ACME account at a CA:
/// `acme/<safe issuer>/users/<safe email>`.
///
/// An empty `email` is stored under the name `default`.
pub fn account_key_prefix(issuer_key: &str, email: &str) -> String {
    let user = safe_key(match email {
        "" => "default",
        other => other,
    });
    let ca_prefix = acme_ca_prefix(issuer_key);
    path_join(&[ca_prefix.as_str(), "users", user.as_str()])
}
/// Storage key for the named lock: `locks/<safe_key(name)>`.
pub fn locks_key(name: &str) -> String {
    let lock_file = safe_key(name);
    path_join(&[PREFIX_LOCKS, lock_file.as_str()])
}
/// The three storage keys holding one site's certificate material under a
/// given issuer (built by [`StorageKeys::new`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StorageKeys {
    /// Key of the certificate file (`….crt`, from `site_cert_key`).
    pub cert: String,
    /// Key of the private-key file (`….key`, from `site_private_key`).
    pub key: String,
    /// Key of the JSON metadata file (`….json`, from `site_meta_key`).
    pub meta: String,
}
impl StorageKeys {
    /// Computes the certificate, private-key and metadata keys for `domain`
    /// under `issuer_key`.
    pub fn new(issuer_key: &str, domain: &str) -> Self {
        let cert = site_cert_key(issuer_key, domain);
        let key = site_private_key(issuer_key, domain);
        let meta = site_meta_key(issuer_key, domain);
        Self { cert, key, meta }
    }
}
/// A certificate together with its private key and issuer metadata.
///
/// Only `sans`, `issuer_data` and `issuer_key` are (de)serialized. They form
/// the JSON metadata document written by `store_certificate`. The PEM fields
/// are `#[serde(skip)]` and are persisted as separate storage items.
///
/// `Debug` is implemented by hand (not derived) so that formatting a resource,
/// for example in a log line, never prints the private key bytes.
#[derive(Clone, Serialize, Deserialize)]
pub struct CertificateResource {
    /// Names on the certificate; `names_key` derives the storage name from
    /// them. Omitted from the JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub sans: Vec<String>,
    /// Certificate PEM bytes; not serialized (stored as a separate `.crt` item).
    #[serde(skip)]
    pub certificate_pem: Vec<u8>,
    /// Private key PEM bytes; not serialized (stored as a separate `.key` item).
    #[serde(skip)]
    pub private_key_pem: Vec<u8>,
    /// Issuer-specific JSON data; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub issuer_data: Option<serde_json::Value>,
    /// Key of the issuer this resource belongs to; omitted when empty.
    #[serde(default, skip_serializing_if = "String::is_empty")]
    pub issuer_key: String,
}

impl std::fmt::Debug for CertificateResource {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CertificateResource")
            .field("sans", &self.sans)
            .field("certificate_pem", &self.certificate_pem)
            // Never print secret key material; show only its size.
            .field(
                "private_key_pem",
                &format_args!("<redacted: {} bytes>", self.private_key_pem.len()),
            )
            .field("issuer_data", &self.issuer_data)
            .field("issuer_key", &self.issuer_key)
            .finish()
    }
}
impl CertificateResource {
    /// Builds the storage name for this certificate from its SANs.
    ///
    /// The SANs are sorted and joined with `,`. If the result exceeds 1024
    /// bytes, it is cut and suffixed with `_trunc` so the total stays within
    /// 1024 bytes.
    ///
    /// The cut point is moved back to the nearest UTF-8 character boundary.
    /// `String::truncate` panics if asked to split a multi-byte character,
    /// which could happen with non-ASCII (e.g. Unicode IDN) SANs.
    pub fn names_key(&self) -> String {
        const MAX_LEN: usize = 1024;
        const TRUNC_SUFFIX: &str = "_trunc";

        let mut names = self.sans.clone();
        names.sort_unstable();
        let mut result = names.join(",");
        if result.len() > MAX_LEN {
            let mut cut = MAX_LEN - TRUNC_SUFFIX.len();
            // Index 0 is always a char boundary, so this loop terminates.
            while !result.is_char_boundary(cut) {
                cut -= 1;
            }
            result.truncate(cut);
            result.push_str(TRUNC_SUFFIX);
        }
        result
    }
}
/// One key/value pair to be written as part of a [`store_tx`] batch.
struct KeyValue {
    /// Storage key to write.
    key: String,
    /// Bytes to store under `key`.
    value: Vec<u8>,
}
async fn store_tx(storage: &dyn Storage, items: &[KeyValue]) -> Result<()> {
for (i, kv) in items.iter().enumerate() {
if let Err(e) = storage.store(&kv.key, &kv.value).await {
for prev in items[..i].iter().rev() {
let _ = storage.delete(&prev.key).await;
}
return Err(e);
}
}
Ok(())
}
/// Persists `cert` under `issuer_key`, using [`CertificateResource::names_key`]
/// as the site name.
///
/// Three items are written through `store_tx`, in this order: the private key
/// PEM, the certificate PEM, then the pretty-printed JSON metadata. If a write
/// fails, `store_tx` deletes the items already written and returns the error.
///
/// # Errors
/// Returns `StorageError::Serialize` if the metadata cannot be encoded;
/// otherwise, any error reported by the backend while storing.
pub async fn store_certificate(
    storage: &dyn Storage,
    issuer_key: &str,
    cert: &CertificateResource,
) -> Result<()> {
    let name = cert.names_key();
    let metadata = serde_json::to_vec_pretty(cert).map_err(|e| {
        Error::Storage(StorageError::Serialize(format!(
            "encoding certificate metadata: {e}"
        )))
    })?;

    let private_key = KeyValue {
        key: site_private_key(issuer_key, &name),
        value: cert.private_key_pem.clone(),
    };
    let certificate = KeyValue {
        key: site_cert_key(issuer_key, &name),
        value: cert.certificate_pem.clone(),
    };
    let meta = KeyValue {
        key: site_meta_key(issuer_key, &name),
        value: metadata,
    };

    // Write order matters: store_tx rolls back earlier writes in reverse.
    let batch = [private_key, certificate, meta];
    store_tx(storage, &batch).await
}
pub async fn load_certificate(
storage: &dyn Storage,
issuer_key: &str,
domain: &str,
) -> Result<CertificateResource> {
let key_bytes = storage.load(&site_private_key(issuer_key, domain)).await?;
let cert_bytes = storage.load(&site_cert_key(issuer_key, domain)).await?;
let meta_bytes = storage.load(&site_meta_key(issuer_key, domain)).await?;
let mut cert_res: CertificateResource = serde_json::from_slice(&meta_bytes).map_err(|e| {
Error::Storage(StorageError::Deserialize(format!(
"decoding certificate metadata: {e}"
)))
})?;
cert_res.private_key_pem = key_bytes;
cert_res.certificate_pem = cert_bytes;
cert_res.issuer_key = issuer_key.to_owned();
Ok(cert_res)
}
#[cfg(test)]
mod tests {
    // Unit tests for the key-building helpers and CertificateResource::names_key.
    // Expected keys are spelled out literally so any change to the storage
    // layout is caught.
    use super::*;

    // --- safe_key: normalization and sanitization of one key component ---

    #[test]
    fn safe_key_lowercase_and_trim() {
        assert_eq!(safe_key(" Hello World "), "hello_world");
    }

    #[test]
    fn safe_key_replaces_special_chars() {
        assert_eq!(safe_key("a+b"), "a_plus_b");
        assert_eq!(safe_key("*.example.com"), "wildcard_.example.com");
        assert_eq!(safe_key("host:port"), "host-port");
    }

    #[test]
    fn safe_key_prevents_directory_traversal() {
        assert_eq!(safe_key("a/../../../foo"), "afoo");
        assert_eq!(safe_key("b\\..\\..\\..\\foo"), "bfoo");
    }

    #[test]
    fn safe_key_strips_slashes() {
        assert_eq!(safe_key("c/foo"), "cfoo");
    }

    #[test]
    fn safe_key_idempotent() {
        let once = safe_key("*.Example.COM");
        let twice = safe_key(&once);
        assert_eq!(once, twice);
    }

    // --- issuer_key: CA directory URL -> issuer identifier ---

    #[test]
    fn issuer_key_from_url() {
        assert_eq!(
            issuer_key("https://example.com/acme-ca/directory"),
            "example.com-acme-ca-directory"
        );
    }

    #[test]
    fn issuer_key_no_path() {
        assert_eq!(issuer_key("https://acme.example.com"), "acme.example.com");
    }

    #[test]
    fn issuer_key_non_url() {
        assert_eq!(issuer_key("not-a-url"), "not-a-url");
    }

    // --- per-site certificate / private key / metadata key layout ---

    #[test]
    fn site_cert_key_format() {
        let ik = issuer_key("https://example.com/acme-ca/directory");
        assert_eq!(
            site_cert_key(&ik, "example.com"),
            "certificates/example.com-acme-ca-directory/example.com/example.com.crt"
        );
    }

    #[test]
    fn site_key_key_format() {
        let ik = issuer_key("https://example.com/acme-ca/directory");
        assert_eq!(
            site_private_key(&ik, "example.com"),
            "certificates/example.com-acme-ca-directory/example.com/example.com.key"
        );
    }

    #[test]
    fn site_meta_key_format() {
        let ik = issuer_key("https://example.com/acme-ca/directory");
        assert_eq!(
            site_meta_key(&ik, "example.com"),
            "certificates/example.com-acme-ca-directory/example.com/example.com.json"
        );
    }

    #[test]
    fn wildcard_key_format() {
        let ik = issuer_key("https://example.com/acme-ca/directory");
        let base = "certificates/example.com-acme-ca-directory";
        assert_eq!(
            site_cert_key(&ik, "*.example.com"),
            format!("{base}/wildcard_.example.com/wildcard_.example.com.crt")
        );
        assert_eq!(
            site_private_key(&ik, "*.example.com"),
            format!("{base}/wildcard_.example.com/wildcard_.example.com.key")
        );
        assert_eq!(
            site_meta_key(&ik, "*.example.com"),
            format!("{base}/wildcard_.example.com/wildcard_.example.com.json")
        );
    }

    #[test]
    fn traversal_key_sanitized() {
        let ik = issuer_key("https://example.com/acme-ca/directory");
        let base = "certificates/example.com-acme-ca-directory";
        assert_eq!(
            site_cert_key(&ik, "a/../../../foo"),
            format!("{base}/afoo/afoo.crt")
        );
        assert_eq!(site_cert_key(&ik, "c/foo"), format!("{base}/cfoo/cfoo.crt"));
    }

    #[test]
    fn storage_keys_new() {
        let ik = "example.com-acme-ca-directory";
        let sk = StorageKeys::new(ik, "example.com");
        assert!(sk.cert.ends_with(".crt"));
        assert!(sk.key.ends_with(".key"));
        assert!(sk.meta.ends_with(".json"));
    }

    // --- CertificateResource::names_key: sorting and length cap ---

    #[test]
    fn names_key_basic() {
        let cr = CertificateResource {
            sans: vec!["b.example.com".into(), "a.example.com".into()],
            certificate_pem: vec![],
            private_key_pem: vec![],
            issuer_data: None,
            issuer_key: String::new(),
        };
        assert_eq!(cr.names_key(), "a.example.com,b.example.com");
    }

    #[test]
    fn names_key_truncation() {
        // 10 names of ~201 bytes each comfortably exceed the 1024-byte cap.
        let long_name = "x".repeat(200);
        let sans: Vec<String> = (0..10).map(|i| format!("{long_name}{i}")).collect();
        let cr = CertificateResource {
            sans,
            certificate_pem: vec![],
            private_key_pem: vec![],
            issuer_data: None,
            issuer_key: String::new(),
        };
        let key = cr.names_key();
        assert!(key.len() <= 1024);
        assert!(key.ends_with("_trunc"));
    }

    // --- remaining key builders: OCSP, locks, ACME accounts ---

    #[test]
    fn ocsp_key_with_domain() {
        assert_eq!(ocsp_key("example.com", "abc123"), "ocsp/example.com-abc123");
    }

    #[test]
    fn ocsp_key_without_domain() {
        assert_eq!(ocsp_key("", "abc123"), "ocsp/abc123");
    }

    #[test]
    fn locks_key_basic() {
        assert_eq!(locks_key("my-lock"), "locks/my-lock");
    }

    #[test]
    fn account_key_prefix_with_email() {
        let ak = account_key_prefix("example.com-directory", "user@example.com");
        assert_eq!(ak, "acme/example.com-directory/users/user@example.com");
    }

    #[test]
    fn account_key_prefix_empty_email() {
        let ak = account_key_prefix("example.com-directory", "");
        assert_eq!(ak, "acme/example.com-directory/users/default");
    }
}