use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use chrono::Utc;
use forge_core::cluster::{LeaderInfo, LeaderRole, NodeId};
use tokio::sync::{Mutex, watch};
/// Timing parameters for [`LeaderElection`].
#[derive(Debug, Clone)]
pub struct LeaderConfig {
    /// Delay between iterations of the election loop.
    pub check_interval: Duration,
    /// Length of the lease written to `forge_leaders` when leadership is taken or renewed.
    pub lease_duration: Duration,
    /// Desired spacing between lease renewals while leading.
    /// NOTE(review): confirm how callers combine this with `check_interval`.
    pub refresh_interval: Duration,
}

impl Default for LeaderConfig {
    /// Wakes every 5 s, grants 60 s leases and renews them every 30 s.
    fn default() -> Self {
        let secs = Duration::from_secs;
        Self {
            check_interval: secs(5),
            lease_duration: secs(60),
            refresh_interval: secs(30),
        }
    }
}
pub struct LeaderElection {
pool: sqlx::PgPool,
node_id: NodeId,
role: LeaderRole,
config: LeaderConfig,
is_leader: Arc<AtomicBool>,
lock_connection: Arc<Mutex<Option<sqlx::pool::PoolConnection<sqlx::Postgres>>>>,
shutdown_tx: watch::Sender<bool>,
shutdown_rx: watch::Receiver<bool>,
}
impl LeaderElection {
pub fn new(
pool: sqlx::PgPool,
node_id: NodeId,
role: LeaderRole,
config: LeaderConfig,
) -> Self {
let (shutdown_tx, shutdown_rx) = watch::channel(false);
Self {
pool,
node_id,
role,
config,
is_leader: Arc::new(AtomicBool::new(false)),
lock_connection: Arc::new(Mutex::new(None)),
shutdown_tx,
shutdown_rx,
}
}
pub fn is_leader(&self) -> bool {
self.is_leader.load(Ordering::SeqCst)
}
pub fn shutdown_receiver(&self) -> watch::Receiver<bool> {
self.shutdown_rx.clone()
}
pub fn stop(&self) {
let _ = self.shutdown_tx.send(true);
}
pub async fn try_become_leader(&self) -> forge_core::Result<bool> {
if self.is_leader() {
return Ok(true);
}
let mut conn = self
.pool
.acquire()
.await
.map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
let acquired = sqlx::query_scalar!(
r#"SELECT pg_try_advisory_lock($1) as "acquired!""#,
self.role.lock_id()
)
.fetch_one(&mut *conn)
.await
.map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
super::metrics::record_leader_election_attempt(self.role.as_str(), acquired);
if acquired {
let lease_until =
Utc::now() + chrono::Duration::seconds(self.config.lease_duration.as_secs() as i64);
sqlx::query!(
r#"
INSERT INTO forge_leaders (role, node_id, acquired_at, lease_until)
VALUES ($1, $2, NOW(), $3)
ON CONFLICT (role) DO UPDATE SET
node_id = EXCLUDED.node_id,
acquired_at = NOW(),
lease_until = EXCLUDED.lease_until
"#,
self.role.as_str(),
self.node_id.as_uuid(),
lease_until,
)
.execute(&self.pool)
.await
.map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
self.is_leader.store(true, Ordering::SeqCst);
super::metrics::set_is_leader(self.role.as_str(), true);
*self.lock_connection.lock().await = Some(conn);
tracing::info!(role = self.role.as_str(), "Acquired leadership");
}
Ok(acquired)
}
pub async fn refresh_lease(&self) -> forge_core::Result<()> {
if !self.is_leader() {
return Ok(());
}
let lease_until =
Utc::now() + chrono::Duration::seconds(self.config.lease_duration.as_secs() as i64);
sqlx::query!(
r#"
UPDATE forge_leaders
SET lease_until = $3
WHERE role = $1 AND node_id = $2
"#,
self.role.as_str(),
self.node_id.as_uuid(),
lease_until,
)
.execute(&self.pool)
.await
.map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
Ok(())
}
pub async fn release_leadership(&self) -> forge_core::Result<()> {
if !self.is_leader() {
return Ok(());
}
let mut lock_connection = self.lock_connection.lock().await;
if let Some(mut conn) = lock_connection.take() {
sqlx::query_scalar!("SELECT pg_advisory_unlock($1)", self.role.lock_id())
.fetch_one(&mut *conn)
.await
.map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
} else {
tracing::warn!(
role = self.role.as_str(),
"Leader lock connection missing during release"
);
}
drop(lock_connection);
sqlx::query!(
r#"
DELETE FROM forge_leaders
WHERE role = $1 AND node_id = $2
"#,
self.role.as_str(),
self.node_id.as_uuid(),
)
.execute(&self.pool)
.await
.map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
self.is_leader.store(false, Ordering::SeqCst);
super::metrics::set_is_leader(self.role.as_str(), false);
tracing::info!(role = self.role.as_str(), "Released leadership");
Ok(())
}
pub async fn check_leader_health(&self) -> forge_core::Result<bool> {
let result = sqlx::query_scalar!(
"SELECT lease_until FROM forge_leaders WHERE role = $1",
self.role.as_str()
)
.fetch_optional(&self.pool)
.await
.map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
match result {
Some(lease_until) => Ok(lease_until > Utc::now()),
None => Ok(false), }
}
pub async fn get_leader(&self) -> forge_core::Result<Option<LeaderInfo>> {
let row = sqlx::query!(
r#"
SELECT role, node_id, acquired_at, lease_until
FROM forge_leaders
WHERE role = $1
"#,
self.role.as_str(),
)
.fetch_optional(&self.pool)
.await
.map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
match row {
Some(row) => {
let role = row.role.parse().unwrap_or(LeaderRole::Scheduler);
Ok(Some(LeaderInfo {
role,
node_id: NodeId::from_uuid(row.node_id),
acquired_at: row.acquired_at,
lease_until: row.lease_until,
}))
}
None => Ok(None),
}
}
pub async fn run(&self) {
let mut shutdown_rx = self.shutdown_rx.clone();
loop {
tokio::select! {
_ = tokio::time::sleep(self.config.check_interval) => {
if self.is_leader() {
if let Err(e) = self.refresh_lease().await {
tracing::debug!(error = %e, "Failed to refresh lease");
}
} else {
match self.check_leader_health().await {
Ok(false) => {
if let Err(e) = self.try_become_leader().await {
tracing::debug!(error = %e, "Failed to acquire leadership");
}
}
Ok(true) => {
}
Err(e) => {
tracing::debug!(error = %e, "Failed to check leader health");
}
}
}
}
_ = shutdown_rx.changed() => {
if *shutdown_rx.borrow() {
tracing::debug!("Leader election shutting down");
if let Err(e) = self.release_leadership().await {
tracing::debug!(error = %e, "Failed to release leadership");
}
break;
}
}
}
}
}
}
/// A borrowed handle that can only be created while an election reports leadership.
///
/// The guard does not own the lock itself. [`LeaderGuard::is_leader`] re-reads the
/// election's current state on every call.
pub struct LeaderGuard<'a> {
    election: &'a LeaderElection,
}

impl<'a> LeaderGuard<'a> {
    /// Returns `Some(guard)` only when `election` currently reports itself as leader.
    pub fn try_new(election: &'a LeaderElection) -> Option<Self> {
        election.is_leader().then_some(Self { election })
    }

    /// Re-checks the underlying election's leadership flag.
    pub fn is_leader(&self) -> bool {
        let LeaderGuard { election } = self;
        election.is_leader()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the default timings of `LeaderConfig`.
    #[test]
    fn test_leader_config_default() {
        let LeaderConfig {
            check_interval,
            lease_duration,
            refresh_interval,
        } = LeaderConfig::default();
        assert_eq!(check_interval, Duration::from_secs(5));
        assert_eq!(lease_duration, Duration::from_secs(60));
        assert_eq!(refresh_interval, Duration::from_secs(30));
    }
}