use reqwest::tls::Identity;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::sync::{Mutex, RwLock};
use tracing::{error, info, warn};
/// An issued workload identity: an X.509 certificate carrying a SPIFFE ID,
/// its private key, and the validity window tracked locally.
///
/// `Debug` is implemented by hand so that formatting an `Svid` (for example in
/// a log line) never prints the private key bytes.
#[derive(Clone)]
pub struct Svid {
    /// SPIFFE ID carried in the certificate's URI SAN when issued by
    /// `CertificateAuthority::issue_svid`, e.g. `spiffe://cluster.local/ns/default/sa/router`.
    pub spiffe_id: String,
    /// PEM-encoded certificate.
    pub cert_pem: Vec<u8>,
    /// PEM-encoded private key matching `cert_pem`. Secret.
    pub key_pem: Vec<u8>,
    /// Local time the SVID was issued; start of the renewal window.
    pub issued_at: SystemTime,
    /// Locally tracked expiry used by the validity/renewal checks.
    pub expires_at: SystemTime,
    /// Certificate serial number as upper-case hex.
    pub serial: String,
}

impl std::fmt::Debug for Svid {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Svid")
            .field("spiffe_id", &self.spiffe_id)
            .field("cert_pem", &self.cert_pem)
            // Never render key material; only its size is useful for debugging.
            .field(
                "key_pem",
                &format_args!("<redacted {} bytes>", self.key_pem.len()),
            )
            .field("issued_at", &self.issued_at)
            .field("expires_at", &self.expires_at)
            .field("serial", &self.serial)
            .finish()
    }
}
impl Svid {
    /// Returns `true` while the current wall-clock time is before `expires_at`.
    pub fn is_valid(&self) -> bool {
        self.expires_at > SystemTime::now()
    }

    /// Returns `true` once at least half of the SVID's lifetime has elapsed.
    ///
    /// Times are compared in whole seconds since the Unix epoch. Pre-epoch
    /// times count as zero, and differences saturate at zero, so a clock that
    /// moved backwards never triggers renewal.
    pub fn should_renew(&self) -> bool {
        let epoch_secs = |t: SystemTime| t.duration_since(UNIX_EPOCH).unwrap_or_default().as_secs();
        let start = epoch_secs(self.issued_at);
        let total = epoch_secs(self.expires_at).saturating_sub(start);
        let used = epoch_secs(SystemTime::now()).saturating_sub(start);
        used >= total / 2
    }

    /// Time left until `expires_at`, or `Duration::ZERO` if it has passed.
    pub fn remaining(&self) -> Duration {
        match self.expires_at.duration_since(SystemTime::now()) {
            Ok(left) => left,
            Err(_) => Duration::ZERO,
        }
    }

    /// Builds a reqwest client identity from this SVID.
    ///
    /// The private key and the certificate are concatenated (key first) into
    /// one PEM buffer for `Identity::from_pem`.
    ///
    /// # Errors
    /// Returns [`MtlsError::Identity`] if reqwest rejects the PEM bundle.
    pub fn to_identity(&self) -> Result<Identity, MtlsError> {
        let bundle = [self.key_pem.as_slice(), self.cert_pem.as_slice()].concat();
        Identity::from_pem(&bundle).map_err(|err| {
            MtlsError::Identity(format!("Failed to create TLS identity: {}", err))
        })
    }
}
/// Configuration for `MtlsProvider`.
///
/// Every field has a serde default, so a partial config deserializes; the
/// defaults match `MtlsConfig::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MtlsConfig {
    /// Reported through `MtlsStatus::enabled`.
    /// NOTE(review): `MtlsProvider::new` does not check this flag — confirm
    /// callers consult it before constructing a provider.
    #[serde(default)]
    pub enabled: bool,
    /// Trust domain used in SPIFFE IDs (`spiffe://<trust_domain>/...`) and in
    /// the ephemeral CA's subject. Default: `cluster.local`.
    #[serde(default = "default_trust_domain")]
    pub trust_domain: String,
    /// Service-account segment (`/sa/<service_name>`) of this provider's
    /// SPIFFE ID; also the CSR common name. Default: `router`.
    #[serde(default = "default_service_name")]
    pub service_name: String,
    /// Namespace segment (`/ns/<namespace>`) of this provider's SPIFFE ID.
    /// Default: `default`.
    #[serde(default = "default_namespace")]
    pub namespace: String,
    /// Requested SVID lifetime in seconds; renewal is due after half of it.
    /// Default: 3600.
    #[serde(default = "default_cert_ttl")]
    pub cert_ttl_secs: u64,
    /// PEM CA certificate to load. Used only together with `ca_key_path`;
    /// if either is missing an ephemeral CA is generated instead.
    #[serde(default)]
    pub ca_cert_path: Option<String>,
    /// PEM CA private key to load (see `ca_cert_path`).
    #[serde(default)]
    pub ca_key_path: Option<String>,
    /// If set, the CA certificate is written to this path when the provider
    /// is created.
    #[serde(default)]
    pub trust_bundle_path: Option<String>,
    /// NOTE(review): not read anywhere in this file — presumably consulted by
    /// callers deciding whether to fall back to non-mTLS; confirm.
    #[serde(default)]
    pub allow_fallback: bool,
    /// When `false`, clients built by the provider accept certificates whose
    /// hostname does not match. Default: `true`.
    #[serde(default = "default_true")]
    pub verify_hostname: bool,
}
/// Serde default for `MtlsConfig::trust_domain`.
fn default_trust_domain() -> String {
    String::from("cluster.local")
}

/// Serde default for `MtlsConfig::service_name`.
fn default_service_name() -> String {
    String::from("router")
}

/// Serde default for `MtlsConfig::namespace`.
fn default_namespace() -> String {
    String::from("default")
}

/// Serde default for `MtlsConfig::cert_ttl_secs`: one hour.
fn default_cert_ttl() -> u64 {
    60 * 60
}

/// Serde default for boolean options that are on unless disabled (e.g. `verify_hostname`).
fn default_true() -> bool {
    true
}
impl Default for MtlsConfig {
fn default() -> Self {
Self {
enabled: false,
trust_domain: default_trust_domain(),
service_name: default_service_name(),
namespace: default_namespace(),
cert_ttl_secs: default_cert_ttl(),
ca_cert_path: None,
ca_key_path: None,
trust_bundle_path: None,
allow_fallback: false,
verify_hostname: true,
}
}
}
/// Errors produced by the mTLS CA and provider.
#[derive(Debug, thiserror::Error)]
pub enum MtlsError {
    /// An `openssl` invocation or temp-file step failed while creating a CA or SVID.
    #[error("Certificate generation failed: {0}")]
    CertGeneration(String),
    /// Reading the CA certificate or key file failed.
    #[error("Failed to load CA: {0}")]
    CaLoad(String),
    /// NOTE(review): not constructed anywhere in this file.
    #[error("Certificate expired: {0}")]
    Expired(String),
    /// reqwest rejected the PEM key+certificate bundle of an SVID.
    #[error("Identity creation failed: {0}")]
    Identity(String),
    /// Loading the CA root into reqwest or building the client failed.
    #[error("TLS configuration error: {0}")]
    TlsConfig(String),
    /// I/O failure converted via `?` (e.g. exporting a trust bundle or SVID).
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// NOTE(review): not constructed anywhere in this file.
    #[error("SPIFFE ID validation error: {0}")]
    SpiffeId(String),
}
/// Returns the system temp directory, canonicalized when possible.
///
/// Falls back to the raw `std::env::temp_dir()` if canonicalization fails. On
/// Windows the `\\?\` verbatim prefix added by `canonicalize` is stripped,
/// so the path stays in plain form when passed as a command-line argument.
fn resolve_temp_dir() -> std::path::PathBuf {
    let raw = std::env::temp_dir();
    std::fs::canonicalize(&raw)
        .map(strip_verbatim_prefix)
        .unwrap_or(raw)
}

/// Removes a leading `\\?\` from a canonical path on Windows; identity elsewhere.
fn strip_verbatim_prefix(path: std::path::PathBuf) -> std::path::PathBuf {
    #[cfg(windows)]
    {
        if let Some(rest) = path.to_string_lossy().strip_prefix(r"\\?\") {
            return std::path::PathBuf::from(rest);
        }
    }
    path
}
/// Writes secret bytes (e.g. a private key) to a new file.
///
/// On Unix the file must not already exist (`create_new`) and is created with
/// mode `0600`. On other platforms this is a plain `std::fs::write`, which
/// creates or overwrites the file.
fn write_secret_file(path: &std::path::Path, contents: &[u8]) -> std::io::Result<()> {
    #[cfg(unix)]
    {
        use std::io::Write;
        use std::os::unix::fs::OpenOptionsExt;
        let mut options = std::fs::OpenOptions::new();
        options.write(true).create_new(true).mode(0o600);
        options.open(path)?.write_all(contents)
    }
    #[cfg(not(unix))]
    {
        std::fs::write(path, contents)
    }
}
/// Certificate authority that signs workload SVIDs by shelling out to the
/// `openssl` command-line tool.
pub struct CertificateAuthority {
    /// PEM CA certificate; also returned as the trust bundle.
    ca_cert_pem: Vec<u8>,
    /// PEM CA private key, held in memory for signing.
    ca_key_pem: Vec<u8>,
    /// Trust domain used as the authority of issued SPIFFE IDs.
    trust_domain: String,
    /// Source of per-SVID serial numbers; incremented atomically on each issuance.
    serial_counter: std::sync::atomic::AtomicU64,
}
impl CertificateAuthority {
    /// Creates a throwaway self-signed CA for this process.
    ///
    /// Generates a P-256 (`prime256v1`) key with `openssl ecparam` and
    /// self-signs a 365-day CA certificate with `openssl req -x509`. Requires
    /// the `openssl` binary on `PATH`. The key is only on disk (inside a
    /// private scratch directory) while `openssl req` runs, and that directory
    /// is removed on every return path.
    ///
    /// # Errors
    /// Returns [`MtlsError::CertGeneration`] if `openssl` cannot be spawned or
    /// fails, or if the scratch directory / key file cannot be created.
    pub fn new_ephemeral(trust_domain: &str) -> Result<Self, MtlsError> {
        use std::process::Command;
        let ca_key_output = Command::new("openssl")
            .args(["ecparam", "-genkey", "-name", "prime256v1", "-noout"])
            .output()
            .map_err(|e| {
                MtlsError::CertGeneration(format!(
                    "Failed to run openssl ecparam: {}. Is openssl installed?",
                    e
                ))
            })?;
        if !ca_key_output.status.success() {
            return Err(MtlsError::CertGeneration(format!(
                "openssl ecparam failed: {}",
                String::from_utf8_lossy(&ca_key_output.stderr)
            )));
        }
        let ca_key_pem = ca_key_output.stdout;

        // Dropping `scratch` (explicitly below, or implicitly on any `?`)
        // deletes the temporary copy of the CA key.
        let scratch = ScratchDir::create("gbp_eca").map_err(|e| {
            MtlsError::CertGeneration(format!("Failed to create temp dir: {}", e))
        })?;
        let key_tmp = scratch.file("ca_key.pem");
        write_secret_file(&key_tmp, &ca_key_pem)
            .map_err(|e| MtlsError::CertGeneration(format!("Failed to write temp key: {}", e)))?;
        let ca_cert_output = Command::new("openssl")
            // Presumably disables MSYS/Git-for-Windows path conversion of the
            // `/O=...` subject argument.
            .env("MSYS_NO_PATHCONV", "1")
            .args([
                "req",
                "-new",
                "-x509",
                "-key",
                &key_tmp.to_string_lossy(),
                "-sha256",
                "-days",
                "365",
                "-subj",
                &format!("/O=GBP Gateway/CN=GBP mTLS CA [{}]", trust_domain),
            ])
            .output()
            .map_err(|e| MtlsError::CertGeneration(format!("Failed to run openssl req: {}", e)))?;
        // openssl has exited; remove the key from disk before anything else.
        drop(scratch);
        if !ca_cert_output.status.success() {
            return Err(MtlsError::CertGeneration(format!(
                "openssl req failed: {}",
                String::from_utf8_lossy(&ca_cert_output.stderr)
            )));
        }
        let ca_cert_pem = ca_cert_output.stdout;
        info!(
            trust_domain = %trust_domain,
            "🔐 Ephemeral mTLS Certificate Authority initialized"
        );
        Ok(Self {
            ca_cert_pem,
            ca_key_pem,
            trust_domain: trust_domain.to_string(),
            serial_counter: std::sync::atomic::AtomicU64::new(1),
        })
    }

    /// Loads a CA certificate and private key (PEM) from disk.
    ///
    /// Because such a CA outlives this process, serial numbers are seeded
    /// from the wall clock rather than starting at 1 (see
    /// [`Self::time_seeded_serial`]).
    ///
    /// # Errors
    /// Returns [`MtlsError::CaLoad`] if either file cannot be read.
    pub fn from_files(
        cert_path: &str,
        key_path: &str,
        trust_domain: &str,
    ) -> Result<Self, MtlsError> {
        let ca_cert_pem = std::fs::read(cert_path).map_err(|e| {
            MtlsError::CaLoad(format!("Failed to read CA cert {}: {}", cert_path, e))
        })?;
        let ca_key_pem = std::fs::read(key_path)
            .map_err(|e| MtlsError::CaLoad(format!("Failed to read CA key {}: {}", key_path, e)))?;
        info!(
            cert_path = %cert_path,
            trust_domain = %trust_domain,
            "🔐 Loaded mTLS CA from files"
        );
        Ok(Self {
            ca_cert_pem,
            ca_key_pem,
            trust_domain: trust_domain.to_string(),
            serial_counter: std::sync::atomic::AtomicU64::new(Self::time_seeded_serial()),
        })
    }

    /// Issues a workload SVID for
    /// `spiffe://<trust_domain>/ns/<namespace>/sa/<service_name>`.
    ///
    /// Generates a fresh P-256 key, builds a CSR (`CN=<service_name>`), and
    /// signs it with this CA via `openssl x509 -req`, adding the SPIFFE ID as a
    /// URI SAN and `serverAuth`/`clientAuth` extended key usages. Key material
    /// is staged in a private scratch directory that is removed on every
    /// return path, including errors.
    ///
    /// `-days` only accepts whole days, so the certificate is valid for `ttl`
    /// rounded **up** to whole days (minimum one). The returned `expires_at` is
    /// `issued_at + ttl`, so the locally tracked expiry never outlives the
    /// certificate.
    ///
    /// # Errors
    /// Returns [`MtlsError::CertGeneration`] if any `openssl` step or
    /// temp-file operation fails.
    pub fn issue_svid(
        &self,
        namespace: &str,
        service_name: &str,
        ttl: Duration,
    ) -> Result<Svid, MtlsError> {
        use std::process::Command;
        const SECS_PER_DAY: u64 = 86_400;

        let spiffe_id = format!(
            "spiffe://{}/ns/{}/sa/{}",
            self.trust_domain, namespace, service_name
        );
        let serial = self
            .serial_counter
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        let serial_hex = format!("{:016X}", serial);
        let ttl_secs = ttl.as_secs();

        let key_output = Command::new("openssl")
            .args(["ecparam", "-genkey", "-name", "prime256v1", "-noout"])
            .output()
            .map_err(|e| {
                MtlsError::CertGeneration(format!("Failed to generate workload key: {}", e))
            })?;
        if !key_output.status.success() {
            return Err(MtlsError::CertGeneration(format!(
                "Workload key generation failed: {}",
                String::from_utf8_lossy(&key_output.stderr)
            )));
        }
        let key_pem = key_output.stdout;

        // Everything written to disk lives in `scratch`; dropping it (below or
        // on any early `?` return) deletes the workload and CA private keys.
        let scratch = ScratchDir::create("gbp_svid")
            .map_err(|e| MtlsError::CertGeneration(format!("Create temp dir: {}", e)))?;
        let key_path = scratch.file("key.pem");
        let ca_cert_path = scratch.file("cacert.pem");
        let ca_key_path = scratch.file("cakey.pem");
        let csr_path = scratch.file("csr.pem");
        let ext_path = scratch.file("ext.cnf");

        write_secret_file(&key_path, &key_pem)
            .map_err(|e| MtlsError::CertGeneration(format!("Write key: {}", e)))?;
        std::fs::write(&ca_cert_path, &self.ca_cert_pem)
            .map_err(|e| MtlsError::CertGeneration(format!("Write CA cert: {}", e)))?;
        write_secret_file(&ca_key_path, &self.ca_key_pem)
            .map_err(|e| MtlsError::CertGeneration(format!("Write CA key: {}", e)))?;

        let csr_output = Command::new("openssl")
            // Presumably disables MSYS path conversion of the `/O=...` subject.
            .env("MSYS_NO_PATHCONV", "1")
            .args([
                "req",
                "-new",
                "-key",
                &key_path.to_string_lossy(),
                "-subj",
                &format!("/O=GBP Workload/CN={}", service_name),
                "-out",
                &csr_path.to_string_lossy(),
            ])
            .output()
            .map_err(|e| MtlsError::CertGeneration(format!("CSR generation failed: {}", e)))?;
        if !csr_output.status.success() {
            return Err(MtlsError::CertGeneration(format!(
                "CSR generation failed: {}",
                String::from_utf8_lossy(&csr_output.stderr)
            )));
        }

        let ext_content = format!(
            "[v3_svid]\n\
             basicConstraints = CA:FALSE\n\
             keyUsage = digitalSignature, keyEncipherment\n\
             extendedKeyUsage = serverAuth, clientAuth\n\
             subjectAltName = URI:{}\n",
            spiffe_id
        );
        std::fs::write(&ext_path, &ext_content)
            .map_err(|e| MtlsError::CertGeneration(format!("Write extensions: {}", e)))?;

        // Round up: flooring would make e.g. a 36h TTL produce a 24h certificate
        // that expires while `expires_at` still reports it valid.
        let ttl_days = (ttl_secs / SECS_PER_DAY + u64::from(ttl_secs % SECS_PER_DAY != 0)).max(1);
        // Captured before signing so `issued_at + ttl` cannot land after the
        // certificate's notAfter (openssl starts validity at signing time).
        let issued_at = SystemTime::now();
        let cert_output = Command::new("openssl")
            .args([
                "x509",
                "-req",
                "-in",
                &csr_path.to_string_lossy(),
                "-CA",
                &ca_cert_path.to_string_lossy(),
                "-CAkey",
                &ca_key_path.to_string_lossy(),
                "-set_serial",
                &format!("0x{}", serial_hex),
                "-days",
                &ttl_days.to_string(),
                "-sha256",
                "-extfile",
                &ext_path.to_string_lossy(),
                "-extensions",
                "v3_svid",
            ])
            .output()
            .map_err(|e| MtlsError::CertGeneration(format!("Certificate signing failed: {}", e)))?;
        drop(scratch);
        if !cert_output.status.success() {
            return Err(MtlsError::CertGeneration(format!(
                "Certificate signing failed: {}",
                String::from_utf8_lossy(&cert_output.stderr)
            )));
        }
        let cert_pem = cert_output.stdout;
        let expires_at = issued_at + ttl;
        info!(
            spiffe_id = %spiffe_id,
            serial = %serial_hex,
            ttl_secs = ttl_secs,
            "🔑 Issued SVID certificate"
        );
        Ok(Svid {
            spiffe_id,
            cert_pem,
            key_pem,
            issued_at,
            expires_at,
            serial: serial_hex,
        })
    }

    /// The CA certificate (PEM) that peers should trust.
    pub fn trust_bundle(&self) -> &[u8] {
        &self.ca_cert_pem
    }

    /// Writes the CA certificate (PEM) to `path`, overwriting any existing file.
    ///
    /// # Errors
    /// Returns [`MtlsError::Io`] if the file cannot be written.
    pub fn export_trust_bundle(&self, path: &str) -> Result<(), MtlsError> {
        std::fs::write(path, &self.ca_cert_pem)?;
        info!(path = %path, "📦 Exported mTLS trust bundle");
        Ok(())
    }

    /// Initial serial for a CA whose key outlives this process.
    ///
    /// RFC 5280 §4.1.2.2 requires serials to be unique per issuing CA. Starting
    /// every run at 1 would reuse serials under a CA loaded from disk. Seeding
    /// from microseconds since the Unix epoch avoids that, as long as fewer
    /// certificates are issued per run than microseconds elapse between
    /// restarts.
    fn time_seeded_serial() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .ok()
            .and_then(|d| u64::try_from(d.as_micros()).ok())
            .unwrap_or(1)
            .max(1)
    }
}

/// A private, uniquely named scratch directory under the system temp dir.
///
/// Created with fail-if-exists semantics, and with mode `0700` on Unix, so
/// other local users can neither pre-create (or symlink) the paths we write
/// nor read the key files placed inside. The directory and everything in it
/// is removed when the value is dropped, which covers early `?` returns as
/// well as the success path.
struct ScratchDir {
    path: PathBuf,
}

impl ScratchDir {
    /// Creates `<temp>/<prefix>_<pid>_<counter>`.
    fn create(prefix: &str) -> std::io::Result<Self> {
        static COUNTER: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
        let unique_id = COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        let path = resolve_temp_dir().join(format!(
            "{}_{}_{}",
            prefix,
            std::process::id(),
            unique_id
        ));
        #[cfg(unix)]
        let builder = {
            use std::os::unix::fs::DirBuilderExt;
            let mut builder = std::fs::DirBuilder::new();
            builder.mode(0o700);
            builder
        };
        #[cfg(not(unix))]
        let builder = std::fs::DirBuilder::new();
        // Non-recursive create fails if the path already exists.
        builder.create(&path)?;
        Ok(Self { path })
    }

    /// Path of `name` inside the scratch directory.
    fn file(&self, name: &str) -> PathBuf {
        self.path.join(name)
    }
}

impl Drop for ScratchDir {
    fn drop(&mut self) {
        // Best effort: a cleanup failure must not mask the caller's result.
        let _ = std::fs::remove_dir_all(&self.path);
    }
}
/// Holds this service's current SVID and rotates it from a shared CA.
pub struct MtlsProvider {
    /// Provider configuration (trust domain, identity, TTL, client options).
    config: MtlsConfig,
    /// CA used to issue (and re-issue) SVIDs.
    ca: Arc<CertificateAuthority>,
    /// Most recently issued SVID; readers take the read lock, rotation the write lock.
    current_svid: Arc<RwLock<Svid>>,
    /// Held for the duration of a rotation so that only one issuance runs at a time.
    rotation_lock: Arc<Mutex<()>>,
}
impl MtlsProvider {
pub fn new(config: MtlsConfig) -> Result<Self, MtlsError> {
let ca = if let (Some(cert_path), Some(key_path)) =
(&config.ca_cert_path, &config.ca_key_path)
{
CertificateAuthority::from_files(cert_path, key_path, &config.trust_domain)?
} else {
CertificateAuthority::new_ephemeral(&config.trust_domain)?
};
let ca = Arc::new(ca);
if let Some(ref bundle_path) = config.trust_bundle_path {
ca.export_trust_bundle(bundle_path)?;
}
let ttl = Duration::from_secs(config.cert_ttl_secs);
let svid = ca.issue_svid(&config.namespace, &config.service_name, ttl)?;
let current_svid = Arc::new(RwLock::new(svid));
info!(
trust_domain = %config.trust_domain,
service = %config.service_name,
namespace = %config.namespace,
cert_ttl_secs = config.cert_ttl_secs,
"🔒 mTLS Provider initialized (Zero-Trust mode)"
);
Ok(Self {
config,
ca,
current_svid,
rotation_lock: Arc::new(Mutex::new(())),
})
}
pub async fn get_svid(&self) -> Result<Svid, MtlsError> {
{
let svid = self.current_svid.read().await;
if svid.is_valid() && !svid.should_renew() {
return Ok(svid.clone());
}
}
let _rotation_guard = self.rotation_lock.lock().await;
{
let svid = self.current_svid.read().await;
if svid.is_valid() && !svid.should_renew() {
return Ok(svid.clone());
}
}
self.rotate_locked().await
}
async fn rotate_locked(&self) -> Result<Svid, MtlsError> {
let ttl = Duration::from_secs(self.config.cert_ttl_secs);
let new_svid =
self.ca
.issue_svid(&self.config.namespace, &self.config.service_name, ttl)?;
info!(
serial = %new_svid.serial,
remaining_secs = new_svid.remaining().as_secs(),
"🔄 SVID rotated"
);
let mut current = self.current_svid.write().await;
*current = new_svid.clone();
Ok(new_svid)
}
pub async fn build_client(&self) -> Result<reqwest::Client, MtlsError> {
let svid = self.get_svid().await?;
self.build_client_with_svid(&svid)
}
pub async fn rotate(&self) -> Result<Svid, MtlsError> {
let _guard = self.rotation_lock.lock().await;
self.rotate_locked().await
}
pub fn build_client_with_svid(&self, svid: &Svid) -> Result<reqwest::Client, MtlsError> {
let identity = svid.to_identity()?;
let ca_cert = reqwest::tls::Certificate::from_pem(self.ca.trust_bundle())
.map_err(|e| MtlsError::TlsConfig(format!("Failed to load CA certificate: {}", e)))?;
let mut builder = reqwest::Client::builder()
.identity(identity)
.add_root_certificate(ca_cert)
.min_tls_version(reqwest::tls::Version::TLS_1_2)
.timeout(Duration::from_secs(30));
if !self.config.verify_hostname {
builder = builder.danger_accept_invalid_hostnames(true);
}
builder
.build()
.map_err(|e| MtlsError::TlsConfig(format!("Failed to build TLS client: {}", e)))
}
pub fn ca(&self) -> &CertificateAuthority {
&self.ca
}
pub fn trust_bundle(&self) -> &[u8] {
self.ca.trust_bundle()
}
pub fn config(&self) -> &MtlsConfig {
&self.config
}
pub fn start_rotation_task(self: &Arc<Self>) -> tokio::task::JoinHandle<()> {
let provider = Arc::clone(self);
tokio::spawn(async move {
loop {
tokio::time::sleep(Duration::from_secs(30)).await;
let svid = provider.current_svid.read().await;
if svid.should_renew() {
let serial = svid.serial.clone();
let remaining = svid.remaining();
drop(svid);
info!(
old_serial = %serial,
remaining_secs = remaining.as_secs(),
"🔄 SVID approaching expiry, rotating..."
);
match provider.rotate().await {
Ok(new_svid) => {
info!(
new_serial = %new_svid.serial,
expires_in_secs = new_svid.remaining().as_secs(),
"✅ SVID rotation successful"
);
}
Err(e) => {
error!(
error = %e,
"❌ SVID rotation failed! Current certificate may expire soon."
);
}
}
} else if !svid.is_valid() {
drop(svid);
warn!("⚠️ Current SVID has expired! Attempting emergency rotation...");
match provider.rotate().await {
Ok(new_svid) => {
info!(
new_serial = %new_svid.serial,
"✅ Emergency SVID rotation successful"
);
}
Err(e) => {
error!(
error = %e,
"❌ Emergency SVID rotation failed! mTLS connections will fail."
);
}
}
}
}
})
}
pub async fn status(&self) -> MtlsStatus {
let svid = self.current_svid.read().await;
MtlsStatus {
enabled: self.config.enabled,
trust_domain: self.config.trust_domain.clone(),
spiffe_id: svid.spiffe_id.clone(),
cert_serial: svid.serial.clone(),
cert_valid: svid.is_valid(),
remaining_secs: svid.remaining().as_secs(),
needs_renewal: svid.should_renew(),
}
}
}
/// Serializable snapshot returned by `MtlsProvider::status`.
#[derive(Debug, Clone, Serialize)]
pub struct MtlsStatus {
    /// Copy of `MtlsConfig::enabled`.
    pub enabled: bool,
    /// Configured trust domain.
    pub trust_domain: String,
    /// SPIFFE ID of the current SVID.
    pub spiffe_id: String,
    /// Serial (upper-case hex) of the current SVID.
    pub cert_serial: String,
    /// Whether the current SVID's tracked expiry is still in the future.
    pub cert_valid: bool,
    /// Whole seconds until the tracked expiry (0 once expired).
    pub remaining_secs: u64,
    /// Whether at least half of the SVID's lifetime has elapsed.
    pub needs_renewal: bool,
}
pub fn issue_subgraph_svid(
ca: &CertificateAuthority,
subgraph_name: &str,
namespace: &str,
ttl: Duration,
) -> Result<Svid, MtlsError> {
ca.issue_svid(namespace, subgraph_name, ttl)
}
/// Writes an SVID's certificate and private key to disk as PEM files,
/// overwriting existing files.
///
/// The certificate gets default permissions. On Unix the key file is created
/// (or truncated) with mode `0600`, and a pre-existing file is tightened to
/// `0600` before any key bytes are written. The private key is therefore
/// never left readable by other users.
///
/// # Errors
/// Returns [`MtlsError::Io`] if either file cannot be written.
pub fn export_svid(svid: &Svid, cert_path: &str, key_path: &str) -> Result<(), MtlsError> {
    std::fs::write(cert_path, &svid.cert_pem)?;
    #[cfg(unix)]
    {
        use std::io::Write;
        use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};
        let mut key_file = std::fs::OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(true)
            .mode(0o600)
            .open(key_path)?;
        // `mode` only applies when the file is newly created; tighten an existing one too.
        key_file.set_permissions(std::fs::Permissions::from_mode(0o600))?;
        key_file.write_all(&svid.key_pem)?;
    }
    #[cfg(not(unix))]
    {
        std::fs::write(key_path, &svid.key_pem)?;
    }
    info!(
        spiffe_id = %svid.spiffe_id,
        cert_path = %cert_path,
        key_path = %key_path,
        "📦 Exported subgraph SVID"
    );
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `Svid` with the given validity window and no key material,
    /// so the timing helpers can be tested without running `openssl`.
    fn svid_with_window(issued_at: SystemTime, expires_at: SystemTime) -> Svid {
        Svid {
            spiffe_id: "spiffe://test.local/ns/default/sa/window".to_string(),
            cert_pem: Vec::new(),
            key_pem: Vec::new(),
            issued_at,
            expires_at,
            serial: "0000000000000001".to_string(),
        }
    }

    #[test]
    fn test_default_config() {
        let config = MtlsConfig::default();
        assert!(!config.enabled);
        assert_eq!(config.trust_domain, "cluster.local");
        assert_eq!(config.cert_ttl_secs, 3600);
        assert_eq!(config.service_name, "router");
        assert_eq!(config.namespace, "default");
        assert!(!config.allow_fallback);
        assert!(config.verify_hostname);
    }

    #[test]
    fn test_svid_validity_and_renewal_window() {
        let now = SystemTime::now();

        let fresh = svid_with_window(now, now + Duration::from_secs(3600));
        assert!(fresh.is_valid());
        assert!(!fresh.should_renew());

        let past_half_life = svid_with_window(
            now - Duration::from_secs(2000),
            now + Duration::from_secs(1600),
        );
        assert!(past_half_life.is_valid());
        assert!(past_half_life.should_renew());

        let expired = svid_with_window(
            now - Duration::from_secs(3600),
            now - Duration::from_secs(1),
        );
        assert!(!expired.is_valid());
        assert!(expired.should_renew());
        assert_eq!(expired.remaining(), Duration::ZERO);
    }

    #[test]
    fn test_ephemeral_ca_creation() {
        let ca = CertificateAuthority::new_ephemeral("test.local");
        assert!(ca.is_ok(), "Ephemeral CA should be created successfully");
        let ca = ca.unwrap();
        assert!(!ca.trust_bundle().is_empty());
    }

    #[test]
    fn test_svid_issuance() {
        let ca = CertificateAuthority::new_ephemeral("test.local").unwrap();
        let svid = ca.issue_svid("default", "my-service", Duration::from_secs(3600));
        assert!(svid.is_ok(), "SVID should be issued successfully");
        let svid = svid.unwrap();
        assert_eq!(
            svid.spiffe_id,
            "spiffe://test.local/ns/default/sa/my-service"
        );
        assert!(svid.is_valid());
        assert!(!svid.should_renew());
        assert!(!svid.cert_pem.is_empty());
        assert!(!svid.key_pem.is_empty());
    }

    #[test]
    fn test_svid_identity_creation() {
        let ca = CertificateAuthority::new_ephemeral("test.local").unwrap();
        let svid = ca
            .issue_svid("default", "test-svc", Duration::from_secs(3600))
            .unwrap();
        let identity = svid.to_identity();
        assert!(
            identity.is_ok(),
            "Identity should be created from valid SVID"
        );
    }

    #[test]
    fn test_multiple_svid_serials() {
        let ca = CertificateAuthority::new_ephemeral("test.local").unwrap();
        let svid1 = ca
            .issue_svid("default", "svc1", Duration::from_secs(3600))
            .unwrap();
        let svid2 = ca
            .issue_svid("default", "svc2", Duration::from_secs(3600))
            .unwrap();
        assert_ne!(
            svid1.serial, svid2.serial,
            "Each SVID should have a unique serial number"
        );
    }

    #[test]
    fn test_trust_bundle_export() {
        let ca = CertificateAuthority::new_ephemeral("test.local").unwrap();
        // Per-process file name so concurrent test runs sharing a temp dir don't collide.
        let tmp_path = std::env::temp_dir()
            .join(format!("test_trust_bundle_{}.pem", std::process::id()))
            .to_str()
            .unwrap()
            .to_string();
        let result = ca.export_trust_bundle(&tmp_path);
        assert!(result.is_ok());
        let content = std::fs::read(&tmp_path).unwrap();
        assert!(!content.is_empty());
        assert!(String::from_utf8_lossy(&content).contains("BEGIN CERTIFICATE"));
        let _ = std::fs::remove_file(&tmp_path);
    }

    #[tokio::test]
    async fn test_mtls_provider() {
        let config = MtlsConfig {
            enabled: true,
            trust_domain: "test.local".to_string(),
            service_name: "test-router".to_string(),
            cert_ttl_secs: 3600,
            ..Default::default()
        };
        let provider = MtlsProvider::new(config);
        assert!(provider.is_ok(), "Provider should initialize successfully");
        let provider = provider.unwrap();
        let svid = provider.get_svid().await;
        assert!(svid.is_ok());
        let status = provider.status().await;
        assert!(status.enabled);
        assert_eq!(status.trust_domain, "test.local");
        assert!(status.cert_valid);
    }

    #[tokio::test]
    async fn test_svid_rotation() {
        let config = MtlsConfig {
            enabled: true,
            cert_ttl_secs: 3600,
            ..Default::default()
        };
        let provider = MtlsProvider::new(config).unwrap();
        let svid1 = provider.get_svid().await.unwrap();
        let svid2 = provider.rotate().await.unwrap();
        assert_ne!(
            svid1.serial, svid2.serial,
            "Rotated SVID should have new serial"
        );
    }
}