use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use breaker_machines::{CircuitBreaker, Config as CircuitConfig};
use chrono_machines::{BackoffStrategy, ExponentialBackoff};
use native_tls::TlsConnector;
use postgres_native_tls::MakeTlsConnector;
use rand::{SeedableRng, rngs::SmallRng};
use tokio::sync::Mutex;
use tokio_postgres::{Client, NoTls};
use tracing::{debug, info, warn};
use super::election::{Election, ElectionError, FlagshipSignal};
/// TLS policy selected by the `sslmode` parameter of the connection string.
///
/// The variants mirror the values recognised by `sslmode_from_url`; `connect`
/// decides from this value whether to dial with TLS, without TLS, or to try
/// one and fall back to the other.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SslMode {
/// `sslmode=disable`: `connect` only attempts a non-TLS connection.
Disable,
/// `sslmode=allow`: `connect` tries non-TLS first and falls back to TLS.
Allow,
/// `sslmode=prefer`: `connect` tries TLS first and falls back to non-TLS.
Prefer,
/// `sslmode=require`: `connect` only attempts a TLS connection. Also the
/// result when the parameter is missing or its value is not recognised.
Require,
/// `sslmode=verify-ca` (or `verify_ca`): handled by `connect` exactly like
/// `Require`.
VerifyCa,
/// `sslmode=verify-full` (or `verify_full`): handled by `connect` exactly
/// like `Require`.
VerifyFull,
}
/// Flagship election backed by a PostgreSQL advisory lock and the
/// `mothership_flagship` status table.
pub struct PostgresElection {
/// Connection string handed unchanged to `tokio_postgres::connect`; also
/// inspected for its `sslmode` query parameter.
connection_string: String,
/// Client that acquired the advisory lock; `None` while this instance is
/// not the flagship. `release` and `signal` run their statements on it.
election_client: Mutex<Option<Client>>,
/// Set once `try_acquire` wins the lock and cleared again by `release`.
is_flagship: AtomicBool,
/// `"<hostname>-<pid>"`, written to the `instance_id` column of the status row.
instance_id: String,
/// Circuit breaker consulted before, and updated after, every `connect` call.
circuit_breaker: Arc<tokio::sync::Mutex<CircuitBreaker>>,
}
impl PostgresElection {
/// Builds an election handle for `connection_string` without touching the network.
///
/// The instance id is `"<host>-<pid>"`, where `<host>` comes from the
/// `HOSTNAME` environment variable, then `HOST`, then the literal `"unknown"`.
/// The handle starts with no stored client, `is_flagship == false`, and a
/// freshly constructed circuit breaker named `postgres_election`.
pub fn new(connection_string: String) -> Self {
    let host_label = match std::env::var("HOSTNAME") {
        Ok(name) => name,
        Err(_) => std::env::var("HOST").unwrap_or_else(|_| String::from("unknown")),
    };

    let breaker = CircuitBreaker::new(
        String::from("postgres_election"),
        CircuitConfig {
            failure_threshold: Some(10),
            failure_rate_threshold: None,
            minimum_calls: 5,
            failure_window_secs: 60.0,
            half_open_timeout_secs: 300.0,
            success_threshold: 1,
            jitter_factor: 0.1,
        },
    );

    Self {
        connection_string,
        election_client: Mutex::new(None),
        is_flagship: AtomicBool::new(false),
        instance_id: format!("{}-{}", host_label, std::process::id()),
        circuit_breaker: Arc::new(tokio::sync::Mutex::new(breaker)),
    }
}
/// Reads the `sslmode` query parameter out of a URL-style connection string.
///
/// Only the text after the first `?` is examined, split on `&` into
/// `key=value` pairs. The first pair whose key equals `sslmode` (ASCII
/// case-insensitive) decides the result; its value is also compared
/// case-insensitively. A missing parameter, a missing value, or an
/// unrecognised value all yield [`SslMode::Require`].
fn sslmode_from_url(connection_string: &str) -> SslMode {
    let query = connection_string
        .split_once('?')
        .map(|(_, tail)| tail)
        .unwrap_or("");

    // An empty pair has the empty key, which never matches, so it is skipped
    // without a dedicated check.
    let requested = query
        .split('&')
        .map(|pair| pair.split_once('=').unwrap_or((pair, "")))
        .find(|(key, _)| key.eq_ignore_ascii_case("sslmode"))
        .map(|(_, value)| value.to_ascii_lowercase());

    match requested.as_deref() {
        Some("disable") => SslMode::Disable,
        Some("allow") => SslMode::Allow,
        Some("prefer") => SslMode::Prefer,
        Some("verify-ca") | Some("verify_ca") => SslMode::VerifyCa,
        Some("verify-full") | Some("verify_full") => SslMode::VerifyFull,
        // Covers "require", unknown values and the absent parameter alike.
        _ => SslMode::Require,
    }
}
/// Reports whether the operator opted in to insecure `sslmode` values.
///
/// Returns `true` only when `MOTHERSHIP_ALLOW_INSECURE_POSTGRES_SSLMODE` is
/// set to `1`, `true` or `yes` (ASCII case-insensitive). An unset or
/// unreadable variable, or any other value, yields `false`.
fn allow_insecure_sslmode() -> bool {
    const TRUTHY: [&str; 3] = ["1", "true", "yes"];

    match std::env::var("MOTHERSHIP_ALLOW_INSECURE_POSTGRES_SSLMODE") {
        Ok(raw) => TRUTHY.iter().any(|accepted| raw.eq_ignore_ascii_case(accepted)),
        Err(_) => false,
    }
}
async fn connect_tls(connection_string: &str) -> Result<Client, ElectionError> {
let connector = TlsConnector::builder()
.build()
.map_err(|e| ElectionError::Connection(format!("TLS setup failed: {}", e)))?;
let tls = MakeTlsConnector::new(connector);
let (client, connection) = tokio_postgres::connect(connection_string, tls)
.await
.map_err(|e| ElectionError::Connection(e.to_string()))?;
tokio::spawn(async move {
if let Err(e) = connection.await {
warn!(error = %e, "postgres TLS connection error");
}
});
Ok(client)
}
async fn connect_no_tls(connection_string: &str) -> Result<Client, ElectionError> {
let (client, connection) = tokio_postgres::connect(connection_string, NoTls)
.await
.map_err(|e| ElectionError::Connection(e.to_string()))?;
tokio::spawn(async move {
if let Err(e) = connection.await {
warn!(error = %e, "postgres connection error");
}
});
Ok(client)
}
/// Opens a new PostgreSQL connection, guarded by the circuit breaker and
/// retried with backoff.
///
/// Steps, in order:
/// 1. Return `ElectionError::CircuitBreakerOpen` immediately if the breaker
///    reports open.
/// 2. Derive the `sslmode` from the connection string and reject
///    `disable`/`allow`/`prefer` with `ElectionError::Config` unless
///    `allow_insecure_sslmode()` is true. This rejection is not recorded on
///    the breaker.
/// 3. Attempt to connect according to the mode; on failure, sleep for the
///    delay supplied by the backoff and try again.
///
/// # Errors
/// When the backoff yields no further delay, one failure is recorded on the
/// breaker for the whole call (not one per attempt) and the error of the last
/// attempt is returned.
async fn connect(&self) -> Result<Client, ElectionError> {
// Scoped so the breaker guard is dropped before any network I/O starts.
{
let breaker = self.circuit_breaker.lock().await;
if breaker.is_open() {
warn!("PostgreSQL circuit breaker is open - skipping connection attempt");
return Err(ElectionError::CircuitBreakerOpen);
}
}
// Measured from here, so the elapsed time reported below includes backoff sleeps.
let start = std::time::Instant::now();
let sslmode = Self::sslmode_from_url(&self.connection_string);
if matches!(sslmode, SslMode::Disable | SslMode::Allow | SslMode::Prefer)
&& !Self::allow_insecure_sslmode()
{
return Err(ElectionError::Config(
"insecure sslmode for flagship postgres election is blocked; use sslmode=require/verify-ca/verify-full or set MOTHERSHIP_ALLOW_INSECURE_POSTGRES_SSLMODE=true to override".to_string(),
));
}
// NOTE(review): presumably 200 ms base delay, 5 s cap, at most 5 attempts —
// confirm against the chrono_machines documentation.
let backoff = ExponentialBackoff::new()
.base_delay_ms(200)
.max_delay_ms(5000)
.max_attempts(5);
let mut rng = SmallRng::from_os_rng();
let mut attempt = 0u8;
loop {
attempt += 1;
// The Disable/Allow/Prefer arms are reachable only when the insecure override is set.
let result = match sslmode {
SslMode::Disable => Self::connect_no_tls(&self.connection_string).await,
SslMode::Allow => {
// Plaintext first; TLS only as a fallback. Both errors are reported if both fail.
match Self::connect_no_tls(&self.connection_string).await {
Ok(client) => Ok(client),
Err(no_tls_err) => Self::connect_tls(&self.connection_string)
.await
.map_err(|tls_err| {
ElectionError::Connection(format!(
"non-TLS failed ({}); TLS failed ({})",
no_tls_err, tls_err
))
}),
}
}
SslMode::Prefer => {
// TLS first; plaintext only as a fallback. Both errors are reported if both fail.
match Self::connect_tls(&self.connection_string).await {
Ok(client) => {
debug!("connected with TLS");
Ok(client)
}
Err(tls_err) => {
debug!(error = %tls_err, "TLS connection failed, trying without TLS");
Self::connect_no_tls(&self.connection_string).await.map_err(
|no_tls_err| {
ElectionError::Connection(format!(
"TLS failed ({}); non-TLS failed ({})",
tls_err, no_tls_err
))
},
)
}
}
}
// All three strict modes take the same path: a single TLS attempt, no fallback.
SslMode::Require | SslMode::VerifyCa | SslMode::VerifyFull => {
Self::connect_tls(&self.connection_string).await
}
};
match result {
Ok(client) => {
let elapsed = start.elapsed().as_secs_f64();
if attempt > 1 {
info!(
attempt = attempt,
elapsed_secs = elapsed,
"PostgreSQL connected after retry"
);
}
self.circuit_breaker.lock().await.record_success(elapsed);
return Ok(client);
}
Err(last_error) => {
// `None` from the backoff is treated as "retries exhausted".
match backoff.delay(attempt, &mut rng) {
Some(delay_ms) => {
warn!(attempt = attempt, delay_ms = delay_ms, error = %last_error, "PostgreSQL connection failed, retrying");
tokio::time::sleep(Duration::from_millis(delay_ms)).await;
}
None => {
let elapsed = start.elapsed().as_secs_f64();
// Exactly one breaker failure per exhausted `connect` call.
let mut breaker = self.circuit_breaker.lock().await;
breaker.record_failure_and_maybe_trip(elapsed);
warn!(
attempts = attempt,
elapsed_secs = elapsed,
circuit_state = breaker.state_name(),
"PostgreSQL connection failed - circuit breaker recorded failure"
);
return Err(last_error);
}
}
}
}
}
}
/// Creates the `mothership_flagship` status table if it does not exist yet.
///
/// # Errors
/// [`ElectionError::Connection`] carrying the driver error text when the
/// statement fails.
async fn ensure_table(&self, client: &Client) -> Result<(), ElectionError> {
    let ddl = r#"
CREATE TABLE IF NOT EXISTS mothership_flagship (
app_name VARCHAR(255) PRIMARY KEY,
status VARCHAR(50) NOT NULL,
instance_id VARCHAR(255),
updated_at TIMESTAMP DEFAULT NOW()
)
"#;

    match client.execute(ddl, &[]).await {
        Ok(_) => Ok(()),
        Err(e) => Err(ElectionError::Connection(e.to_string())),
    }
}
/// Derives the advisory-lock key for `app_name`.
///
/// The key is a base-31 polynomial hash, with wrapping `i64` arithmetic, over
/// the bytes of `"mothership:flagship:<app_name>"`. It is deterministic, so
/// every instance of the same app computes the same key.
fn lock_key(app_name: &str) -> i64 {
    "mothership:flagship:"
        .bytes()
        .chain(app_name.bytes())
        .fold(0i64, |acc, byte| {
            acc.wrapping_mul(31).wrapping_add(i64::from(byte))
        })
}
/// Builds the LISTEN/NOTIFY channel name for `app_name`.
///
/// The result is `"mothership_"` followed by `app_name`, with every character
/// other than an ASCII letter, ASCII digit or `_` replaced by `_`.
///
/// Callers splice the channel into `LISTEN`/`NOTIFY` statement text as a bare
/// identifier, so the name must not be able to carry quotes, whitespace,
/// `;` or any other character that would end the identifier. An allow-list
/// guarantees that for any input. Names made only of ASCII letters, digits,
/// `-` and `_` map to the same channel as before.
fn channel_name(app_name: &str) -> String {
    const PREFIX: &str = "mothership_";

    let mut channel = String::with_capacity(PREFIX.len() + app_name.len());
    channel.push_str(PREFIX);
    channel.extend(app_name.chars().map(|c| {
        if c.is_ascii_alphanumeric() || c == '_' {
            c
        } else {
            '_'
        }
    }));
    channel
}
}
impl Election for PostgresElection {
async fn try_acquire(&self, app_name: &str) -> Result<bool, ElectionError> {
let client = self.connect().await?;
self.ensure_table(&client).await?;
let lock_key = Self::lock_key(app_name);
let row = client
.query_one("SELECT pg_try_advisory_lock($1) as acquired", &[&lock_key])
.await
.map_err(|e| ElectionError::Connection(e.to_string()))?;
let acquired: bool = row.get("acquired");
if acquired {
info!(app = %app_name, instance = %self.instance_id, "acquired flagship lock");
self.is_flagship.store(true, Ordering::SeqCst);
*self.election_client.lock().await = Some(client);
if let Some(ref client) = *self.election_client.lock().await {
let _ = client
.execute(
r#"
INSERT INTO mothership_flagship (app_name, status, instance_id, updated_at)
VALUES ($1, 'running', $2, NOW())
ON CONFLICT (app_name) DO UPDATE SET
status = 'running',
instance_id = $2,
updated_at = NOW()
"#,
&[&app_name, &self.instance_id],
)
.await;
}
} else {
debug!(app = %app_name, "flagship lock held by another instance");
}
Ok(acquired)
}
async fn release(&self, app_name: &str) -> Result<(), ElectionError> {
if !self.is_flagship.load(Ordering::SeqCst) {
return Ok(());
}
let lock_key = Self::lock_key(app_name);
if let Some(ref client) = *self.election_client.lock().await {
let _ = client
.execute("SELECT pg_advisory_unlock($1)", &[&lock_key])
.await;
let _ = client
.execute(
"DELETE FROM mothership_flagship WHERE app_name = $1",
&[&app_name],
)
.await;
}
self.is_flagship.store(false, Ordering::SeqCst);
*self.election_client.lock().await = None;
info!(app = %app_name, "released flagship lock");
Ok(())
}
async fn signal(&self, app_name: &str, status: FlagshipSignal) -> Result<(), ElectionError> {
if !self.is_flagship.load(Ordering::SeqCst) {
return Ok(());
}
let status_str = match status {
FlagshipSignal::Running => "running",
FlagshipSignal::Ready => "ready",
FlagshipSignal::Abort => "abort",
};
if let Some(ref client) = *self.election_client.lock().await {
client
.execute(
r#"
UPDATE mothership_flagship
SET status = $2, updated_at = NOW()
WHERE app_name = $1
"#,
&[&app_name, &status_str],
)
.await
.map_err(|e| ElectionError::Connection(e.to_string()))?;
let channel = Self::channel_name(app_name);
client
.execute(&format!("NOTIFY {}, '{}'", channel, status_str), &[])
.await
.map_err(|e| ElectionError::Connection(e.to_string()))?;
info!(app = %app_name, status = %status_str, "signaled escorts");
}
Ok(())
}
async fn wait_for_signal(
&self,
app_name: &str,
timeout: Duration,
) -> Result<FlagshipSignal, ElectionError> {
let client = self.connect().await?;
let channel = Self::channel_name(app_name);
client
.execute(&format!("LISTEN {}", channel), &[])
.await
.map_err(|e| ElectionError::Connection(e.to_string()))?;
if let Some(signal) = self.get_signal(app_name).await? {
match signal {
FlagshipSignal::Ready | FlagshipSignal::Abort => return Ok(signal),
FlagshipSignal::Running => {} }
}
let deadline = tokio::time::Instant::now() + timeout;
loop {
let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
if remaining.is_zero() {
return Err(ElectionError::Timeout);
}
match tokio::time::timeout(remaining.min(Duration::from_secs(5)), async {
if let Ok(Some(signal)) = self.get_signal(app_name).await {
match signal {
FlagshipSignal::Ready | FlagshipSignal::Abort => return Some(signal),
FlagshipSignal::Running => {}
}
}
None
})
.await
{
Ok(Some(signal)) => return Ok(signal),
Ok(None) => continue,
Err(_) => continue, }
}
}
/// Reads the current flagship status for `app_name` over a new connection.
///
/// Returns `Ok(None)` when there is no row for the app or when the stored
/// status text is not one of `running`, `ready` or `abort`.
///
/// # Errors
/// Propagates errors from `connect`, and returns
/// [`ElectionError::Connection`] when the query fails.
async fn get_signal(&self, app_name: &str) -> Result<Option<FlagshipSignal>, ElectionError> {
    let client = self.connect().await?;

    let lookup = client
        .query_opt(
            "SELECT status FROM mothership_flagship WHERE app_name = $1",
            &[&app_name],
        )
        .await;
    let row = match lookup {
        Ok(Some(row)) => row,
        Ok(None) => return Ok(None),
        Err(e) => return Err(ElectionError::Connection(e.to_string())),
    };

    let status: String = row.get("status");
    let signal = match status.as_str() {
        "abort" => Some(FlagshipSignal::Abort),
        "ready" => Some(FlagshipSignal::Ready),
        "running" => Some(FlagshipSignal::Running),
        _ => None,
    };
    Ok(signal)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The lock key is deterministic per app name and differs between apps.
    #[test]
    fn test_lock_key_generation() {
        let key1 = PostgresElection::lock_key("myapp");
        let key2 = PostgresElection::lock_key("myapp");
        let key3 = PostgresElection::lock_key("otherapp");
        assert_eq!(key1, key2);
        assert_ne!(key1, key3);
    }

    /// Dashes are rewritten to underscores; plain names only gain the prefix.
    #[test]
    fn test_channel_name() {
        assert_eq!(
            PostgresElection::channel_name("my-app"),
            "mothership_my_app"
        );
        assert_eq!(PostgresElection::channel_name("app"), "mothership_app");
    }

    /// `sslmode` parsing: defaults, recognised values, case handling and fallbacks.
    #[test]
    fn test_sslmode_parsing() {
        assert_eq!(
            PostgresElection::sslmode_from_url("postgres://localhost/db"),
            SslMode::Require
        );
        assert_eq!(
            PostgresElection::sslmode_from_url("postgres://localhost/db?sslmode=disable"),
            SslMode::Disable
        );
        assert_eq!(
            PostgresElection::sslmode_from_url("postgres://localhost/db?sslmode=require"),
            SslMode::Require
        );
        assert_eq!(
            PostgresElection::sslmode_from_url("postgres://localhost/db?foo=bar&sslmode=prefer"),
            SslMode::Prefer
        );
        assert_eq!(
            PostgresElection::sslmode_from_url("postgres://localhost/db?sslmode=verify-full"),
            SslMode::VerifyFull
        );
        assert_eq!(
            PostgresElection::sslmode_from_url("postgres://localhost/db?sslmode=unknown"),
            SslMode::Require
        );
        // Key and value are both matched ASCII case-insensitively.
        assert_eq!(
            PostgresElection::sslmode_from_url("postgres://localhost/db?SSLMODE=Verify_CA"),
            SslMode::VerifyCa
        );
        // A key without a value falls back to the secure default.
        assert_eq!(
            PostgresElection::sslmode_from_url("postgres://localhost/db?sslmode"),
            SslMode::Require
        );
    }

    /// The override is off when unset, on for accepted values, off for others.
    ///
    /// The variable is put back to its previous state at the end, so the
    /// override does not stay enabled for the rest of the test process.
    #[test]
    fn test_allow_insecure_sslmode_env() {
        const KEY: &str = "MOTHERSHIP_ALLOW_INSECURE_POSTGRES_SSLMODE";
        let previous = std::env::var_os(KEY);

        // SAFETY: NOTE(review): mutating the environment is only sound while no
        // other thread accesses it concurrently. No other test in this module
        // writes the environment, but the default harness is multi-threaded —
        // confirm, or serialise the tests that touch the environment.
        unsafe {
            std::env::remove_var(KEY);
        }
        assert!(!PostgresElection::allow_insecure_sslmode());

        // SAFETY: see the note on the first block.
        unsafe {
            std::env::set_var(KEY, "true");
        }
        assert!(PostgresElection::allow_insecure_sslmode());

        // SAFETY: see the note on the first block.
        unsafe {
            std::env::set_var(KEY, "YES");
        }
        assert!(PostgresElection::allow_insecure_sslmode());

        // SAFETY: see the note on the first block.
        unsafe {
            std::env::set_var(KEY, "no");
        }
        assert!(!PostgresElection::allow_insecure_sslmode());

        // SAFETY: see the note on the first block.
        unsafe {
            match previous {
                Some(value) => std::env::set_var(KEY, value),
                None => std::env::remove_var(KEY),
            }
        }
    }

    /// A new election starts with a closed breaker.
    #[tokio::test]
    async fn test_circuit_breaker_initialization() {
        let election = PostgresElection::new("postgres://localhost/test".to_string());
        let breaker = election.circuit_breaker.lock().await;
        assert!(!breaker.is_open());
        assert!(breaker.is_closed());
    }

    /// One exhausted `connect` call records a single failure, which is below
    /// the configured threshold of 10, so the breaker must stay closed.
    #[tokio::test]
    async fn test_circuit_breaker_stays_closed_after_single_failed_connect() {
        let election = PostgresElection::new("postgres://invalid:9999/test".to_string());
        let result = election.connect().await;
        assert!(result.is_err());
        let breaker = election.circuit_breaker.lock().await;
        assert!(!breaker.is_open());
    }

    /// With the breaker tripped, `connect` fails fast with `CircuitBreakerOpen`.
    #[tokio::test]
    async fn test_circuit_breaker_prevents_connection_when_open() {
        let election = PostgresElection::new("postgres://invalid:9999/test".to_string());
        {
            let mut breaker = election.circuit_breaker.lock().await;
            for _ in 0..10 {
                breaker.record_failure_and_maybe_trip(1.0);
            }
            assert!(
                breaker.is_open(),
                "Circuit breaker should be open after 10 failures"
            );
        }
        let result = election.connect().await;
        assert!(
            matches!(result, Err(ElectionError::CircuitBreakerOpen)),
            "Expected CircuitBreakerOpen error, got {:?}",
            result
        );
    }
}