#[cfg(feature = "nats")]
mod inner {
use crate::config::NatsConfig;
use crate::election::{ElectionError, ElectionResult, LeaderElection};
use async_nats::jetstream::kv::{Config as KvConfig, CreateErrorKind, Store};
use serde::{Deserialize, Serialize};
/// Name of the JetStream key-value bucket that holds every election lock.
const ELECTION_BUCKET: &str = "g2v-elections";

/// JSON document stored under an election key while an instance holds the lock.
///
/// Field order is significant for the serialized form (serde emits fields in declaration
/// order), so it is kept stable.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct LockPayload {
    /// Identifier of the instance that wrote (owns) the lock.
    instance_id: String,
    /// Timestamp, in Unix seconds, at which the lock was first taken.
    acquired_at: i64,
    /// Lease deadline in Unix seconds; compared against the local clock by `is_expired`.
    expires_at: i64,
}

impl LockPayload {
    /// Returns `true` once the current wall-clock second is strictly later than `expires_at`.
    ///
    /// The check uses this process's clock (`chrono::Utc::now()`), so instances with skewed
    /// clocks may disagree about expiry near the deadline.
    fn is_expired(&self) -> bool {
        let current_second = chrono::Utc::now().timestamp();
        self.expires_at < current_second
    }
}
/// Leader election backed by a lock entry in a NATS JetStream key-value bucket
/// (`ELECTION_BUCKET`).
///
/// The lock is a JSON [`LockPayload`] stored under `election_key`; ownership is identified by
/// `instance_id` and bounded by a lease of `config.lease_duration` seconds.
#[derive(Debug, Clone)]
pub struct NatsElection {
    /// Connection settings (`url`) and lease length (`lease_duration`, in seconds).
    config: NatsConfig,
    /// Key inside the election bucket that holds this election's lock.
    election_key: String,
    /// Identifier written into the lock when this instance acquires or renews it.
    instance_id: String,
}

impl NatsElection {
    /// Creates an election handle; no network I/O happens until a trait method is called.
    pub fn new(
        config: NatsConfig,
        election_key: impl Into<String>,
        instance_id: impl Into<String>,
    ) -> Self {
        Self {
            config,
            election_key: election_key.into(),
            instance_id: instance_id.into(),
        }
    }

    /// Connects to NATS and returns the election KV bucket, creating it if needed.
    ///
    /// A new client connection is opened on every call.
    /// NOTE(review): if election methods are called frequently, consider caching the client or
    /// store instead of reconnecting each time.
    ///
    /// # Errors
    ///
    /// [`ElectionError::Failed`] if the connection fails, or if bucket creation fails *and*
    /// opening the existing bucket also fails. In the latter case, both underlying errors are
    /// included in the message.
    async fn bucket(&self) -> ElectionResult<Store> {
        let client = async_nats::ConnectOptions::new()
            .connect(&self.config.url)
            .await
            .map_err(|e| ElectionError::Failed(format!("nats connect: {e}")))?;
        let js = async_nats::jetstream::new(client);
        let created = js
            .create_key_value(KvConfig {
                bucket: ELECTION_BUCKET.to_string(),
                description: "sunbeam-g2v leader election locks".to_string(),
                history: 1,
                ..Default::default()
            })
            .await;
        match created {
            Ok(store) => Ok(store),
            // Fall back to opening the bucket, presumably for the case where it already exists
            // with a differing configuration. Keep the create error so it is not silently lost
            // if the fallback fails too.
            Err(create_err) => js.get_key_value(ELECTION_BUCKET).await.map_err(|e| {
                ElectionError::Failed(format!(
                    "nats kv get bucket: {e} (create failed: {create_err})"
                ))
            }),
        }
    }

    /// Serializes a brand-new lock for this instance, valid for `config.lease_duration` seconds
    /// from now.
    ///
    /// # Errors
    ///
    /// [`ElectionError::Failed`] if JSON encoding fails.
    fn encode_fresh_payload(&self) -> ElectionResult<bytes::Bytes> {
        let now = chrono::Utc::now().timestamp();
        let payload = LockPayload {
            instance_id: self.instance_id.clone(),
            acquired_at: now,
            expires_at: now + self.config.lease_duration as i64,
        };
        let json = serde_json::to_vec(&payload)
            .map_err(|e| ElectionError::Failed(format!("nats payload encode: {e}")))?;
        Ok(bytes::Bytes::from(json))
    }

    /// Parses a stored lock; returns `None` if the bytes are not a valid [`LockPayload`].
    fn decode_payload(bytes: bytes::Bytes) -> Option<LockPayload> {
        serde_json::from_slice(&bytes).ok()
    }
}
#[async_trait::async_trait]
impl LeaderElection for NatsElection {
async fn become_leader(&mut self) -> ElectionResult<()> {
let kv = self.bucket().await?;
let payload = self.encode_fresh_payload()?;
match kv.create(&self.election_key, payload).await {
Ok(_) => return Ok(()),
Err(e) if e.kind() == CreateErrorKind::AlreadyExists => {
}
Err(e) => return Err(ElectionError::Failed(format!("nats become_leader: {e}"))),
}
match kv.get(&self.election_key).await {
Ok(Some(bytes)) => {
let existing = Self::decode_payload(bytes)
.ok_or_else(|| ElectionError::Failed("nats: corrupt lock payload".into()))?;
if !existing.is_expired() {
return Err(ElectionError::AlreadyLeader);
}
kv.delete(&self.election_key).await.map_err(|e| {
ElectionError::Failed(format!("nats preempt expired delete: {e}"))
})?;
let retry_payload = self.encode_fresh_payload()?;
kv.create(&self.election_key, retry_payload)
.await
.map(|_| ())
.map_err(|e| {
ElectionError::Failed(format!("nats become_leader retry: {e}"))
})
}
Ok(None) => {
let retry_payload = self.encode_fresh_payload()?;
kv.create(&self.election_key, retry_payload)
.await
.map(|_| ())
.map_err(|e| {
ElectionError::Failed(format!("nats become_leader retry: {e}"))
})
}
Err(e) => Err(ElectionError::Failed(format!("nats become_leader read: {e}"))),
}
}
async fn is_leader(&self) -> bool {
let Ok(kv) = self.bucket().await else {
return false;
};
match kv.get(&self.election_key).await {
Ok(Some(bytes)) => Self::decode_payload(bytes)
.map(|p| p.instance_id == self.instance_id && !p.is_expired())
.unwrap_or(false),
_ => false,
}
}
async fn resign(&mut self) -> ElectionResult<()> {
let kv = self.bucket().await?;
match kv.get(&self.election_key).await {
Ok(Some(bytes)) => {
let payload = Self::decode_payload(bytes)
.ok_or_else(|| ElectionError::Failed("nats: corrupt lock payload".into()))?;
if payload.instance_id != self.instance_id || payload.is_expired() {
return Err(ElectionError::NotLeader);
}
}
Ok(None) => return Err(ElectionError::NotLeader),
Err(e) => return Err(ElectionError::Failed(format!("nats resign read: {e}"))),
}
kv.delete(&self.election_key)
.await
.map_err(|e| ElectionError::Failed(format!("nats resign delete: {e}")))?;
Ok(())
}
async fn get_leader(&self) -> Option<String> {
let kv = self.bucket().await.ok()?;
let bytes = kv.get(&self.election_key).await.ok()??;
Self::decode_payload(bytes)
.filter(|p| !p.is_expired())
.map(|p| p.instance_id)
}
async fn renew(&mut self) -> ElectionResult<()> {
let kv = self.bucket().await?;
let existing = match kv.get(&self.election_key).await {
Ok(Some(bytes)) => Self::decode_payload(bytes)
.ok_or_else(|| ElectionError::Failed("nats renew: corrupt payload".into()))?,
Ok(None) => return Err(ElectionError::NotLeader),
Err(e) => return Err(ElectionError::Failed(format!("nats renew read: {e}"))),
};
if existing.instance_id != self.instance_id || existing.is_expired() {
return Err(ElectionError::NotLeader);
}
let renewed = LockPayload {
instance_id: self.instance_id.clone(),
acquired_at: existing.acquired_at,
expires_at: chrono::Utc::now().timestamp() + self.config.lease_duration as i64,
};
let json = serde_json::to_vec(&renewed)
.map_err(|e| ElectionError::Failed(format!("nats renew encode: {e}")))?;
kv.put(&self.election_key, bytes::Bytes::from(json))
.await
.map(|_| ())
.map_err(|e| ElectionError::Failed(format!("nats renew put: {e}")))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a [`LockPayload`] from its three fields.
    fn lock(instance_id: &str, acquired_at: i64, expires_at: i64) -> LockPayload {
        LockPayload {
            instance_id: instance_id.to_owned(),
            acquired_at,
            expires_at,
        }
    }

    #[test]
    fn test_nats_election_new() {
        let election = NatsElection::new(NatsConfig::default(), "my_election", "instance-1");
        assert_eq!(election.election_key, "my_election");
        assert_eq!(election.instance_id, "instance-1");
    }

    #[test]
    fn test_decode_payload_roundtrip() {
        let original = lock("node-42", 1_700_000_000, 1_700_000_030);
        let raw = bytes::Bytes::from(serde_json::to_vec(&original).unwrap());
        let LockPayload {
            instance_id,
            acquired_at,
            expires_at,
        } = NatsElection::decode_payload(raw).expect("valid JSON payload must decode");
        assert_eq!(instance_id, "node-42");
        assert_eq!(acquired_at, 1_700_000_000);
        assert_eq!(expires_at, 1_700_000_030);
    }

    #[test]
    fn test_decode_payload_bad_bytes() {
        let garbage = bytes::Bytes::from_static(b"not-json");
        assert!(NatsElection::decode_payload(garbage).is_none());
    }

    #[test]
    fn test_lock_payload_not_expired() {
        let now = chrono::Utc::now().timestamp();
        assert!(!lock("inst-1", now, now + 60).is_expired());
    }

    #[test]
    fn test_lock_payload_expired() {
        assert!(lock("inst-1", 1_000_000, 1_000_030).is_expired());
    }
}
}
#[cfg(feature = "nats")]
pub use inner::NatsElection;