use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{info, warn};
/// Hook for reloading the TLS certificate and private key of an HTTPS server.
///
/// Implementors must be `Send + Sync`; the registry in this file stores them
/// as `Arc<dyn HttpsReloader>`, which is why the method returns a boxed future
/// instead of being declared as an `async fn`.
pub trait HttpsReloader: Send + Sync {
/// Starts a reload using the certificate and key located at the given paths.
///
/// The returned future is `Send` and its lifetime is tied to `&self` only,
/// so an implementation that needs the path strings inside the future has
/// to copy them first.
///
/// The future resolves to `Ok(())` on success and to an error otherwise.
fn reload_https(
&self,
cert_path: &str,
key_path: &str,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = anyhow::Result<()>> + Send + '_>>;
}
/// Hook for reloading the TLS certificate and private key of a SIP TLS
/// transport.
///
/// Implementors must be `Send + Sync`; the registry in this file stores them
/// as `Arc<dyn SipTlsReloader>`.
#[async_trait::async_trait]
pub trait SipTlsReloader: Send + Sync {
/// Reloads the TLS material from in-memory buffers.
///
/// `cert_pem` and `key_pem` are passed by value as raw bytes; judging by the
/// parameter names they are expected to be PEM-encoded (NOTE(review):
/// confirm against the concrete implementation).
///
/// Returns `Ok(())` on success and an error otherwise.
async fn reload_sip_tls(&self, cert_pem: Vec<u8>, key_pem: Vec<u8>) -> anyhow::Result<()>;
}
/// Registry holding at most one HTTPS reloader and at most one SIP TLS
/// reloader.
///
/// Each slot is an `Option` behind an async `RwLock`, so a reloader can be
/// registered after construction through a shared reference.
pub struct TlsReloaderRegistry {
// `None` until an HTTPS reloader has been registered.
https: Arc<RwLock<Option<Arc<dyn HttpsReloader>>>>,
// `None` until a SIP TLS reloader has been registered.
sip_tls: Arc<RwLock<Option<Arc<dyn SipTlsReloader>>>>,
}
impl TlsReloaderRegistry {
    /// Creates a registry with no reloaders registered.
    pub fn new() -> Self {
        Self {
            https: Arc::new(RwLock::new(None)),
            sip_tls: Arc::new(RwLock::new(None)),
        }
    }

    /// Registers the HTTPS reloader, replacing any previously registered one.
    pub async fn register_https(&self, reloader: Arc<dyn HttpsReloader>) {
        *self.https.write().await = Some(reloader);
        info!("HTTPS TLS reloader registered");
    }

    /// Registers the SIP TLS reloader, replacing any previously registered one.
    pub async fn register_sip_tls(&self, reloader: Arc<dyn SipTlsReloader>) {
        *self.sip_tls.write().await = Some(reloader);
        info!("SIP TLS reloader registered");
    }

    /// Returns `true` if an HTTPS reloader is currently registered.
    pub async fn has_https_reloader(&self) -> bool {
        self.https.read().await.is_some()
    }

    /// Returns `true` if a SIP TLS reloader is currently registered.
    pub async fn has_sip_tls_reloader(&self) -> bool {
        self.sip_tls.read().await.is_some()
    }

    /// Asks the registered HTTPS reloader to reload from the given
    /// certificate and key paths.
    ///
    /// # Errors
    ///
    /// Returns an error if no HTTPS reloader is registered, or propagates the
    /// error reported by the reloader itself.
    pub async fn reload_https(&self, cert_path: &str, key_path: &str) -> anyhow::Result<()> {
        // Clone the handle and let the read guard drop at the end of this
        // statement. Holding the guard across the reload `.await` would make a
        // concurrent `register_https` wait for the whole reload and, because
        // the lock is write-preferring, stall every later reader behind it.
        let registered = self.https.read().await.clone();
        let reloader = match registered {
            Some(reloader) => reloader,
            None => {
                warn!("No HTTPS reloader registered, skipping reload");
                return Err(anyhow::anyhow!("No HTTPS reloader registered"));
            }
        };

        info!(
            "Reloading HTTPS certificate from {} and {}",
            cert_path, key_path
        );
        reloader.reload_https(cert_path, key_path).await?;
        info!("HTTPS certificate reloaded successfully");
        Ok(())
    }

    /// Asks the registered SIP TLS reloader to reload from the given
    /// in-memory certificate and key buffers.
    ///
    /// # Errors
    ///
    /// Returns an error if no SIP TLS reloader is registered, or propagates
    /// the error reported by the reloader itself.
    pub async fn reload_sip_tls(&self, cert_pem: Vec<u8>, key_pem: Vec<u8>) -> anyhow::Result<()> {
        // Same reasoning as in `reload_https`: never hold the registry lock
        // while the (potentially slow) reload is in flight.
        let registered = self.sip_tls.read().await.clone();
        let reloader = match registered {
            Some(reloader) => reloader,
            None => {
                warn!("No SIP TLS reloader registered, skipping reload");
                return Err(anyhow::anyhow!("No SIP TLS reloader registered"));
            }
        };

        info!("Reloading SIP TLS certificate");
        reloader.reload_sip_tls(cert_pem, key_pem).await?;
        info!("SIP TLS certificate reloaded successfully");
        Ok(())
    }
}
impl Default for TlsReloaderRegistry {
fn default() -> Self {
Self::new()
}
}
/// [`HttpsReloader`] implementation backed by an `axum_server` `RustlsConfig`.
pub struct AxumRustlsReloader {
// Shared handle to the TLS configuration that reloads are applied to.
config: Arc<axum_server::tls_rustls::RustlsConfig>,
}
impl AxumRustlsReloader {
pub fn new(config: Arc<axum_server::tls_rustls::RustlsConfig>) -> Self {
Self { config }
}
}
impl HttpsReloader for AxumRustlsReloader {
fn reload_https(
&self,
cert_path: &str,
key_path: &str,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = anyhow::Result<()>> + Send + '_>> {
let cert_path = cert_path.to_string();
let key_path = key_path.to_string();
let config = self.config.clone();
Box::pin(async move {
config
.reload_from_pem_file(&cert_path, &key_path)
.await
.map_err(|e| anyhow::anyhow!("Failed to reload HTTPS certificate: {}", e))
})
}
}
/// [`SipTlsReloader`] implementation backed by an `rsipstack` TLS listener
/// connection.
pub struct RsipstackTlsReloader {
// Listener whose TLS configuration is replaced on reload.
listener: rsipstack::transport::TlsListenerConnection,
}
impl RsipstackTlsReloader {
pub fn new(listener: rsipstack::transport::TlsListenerConnection) -> Self {
Self { listener }
}
}
#[async_trait::async_trait]
impl SipTlsReloader for RsipstackTlsReloader {
    /// Forwards the certificate and key buffers to the listener's
    /// `reload_tls_config`.
    ///
    /// Any failure is wrapped in an error prefixed with
    /// "Failed to reload SIP TLS certificate".
    async fn reload_sip_tls(&self, cert_pem: Vec<u8>, key_pem: Vec<u8>) -> anyhow::Result<()> {
        let reloaded = self.listener.reload_tls_config(cert_pem, key_pem).await;
        reloaded.map_err(|err| anyhow::anyhow!("Failed to reload SIP TLS certificate: {}", err))
    }
}