#[cfg(feature = "vault")]
mod inner {
use crate::config::VaultConfig;
use crate::election::{ElectionError, ElectionResult, LeaderElection};
use serde::{Deserialize, Serialize};
use vaultrs::api::kv2::requests::SetSecretRequestOptions;
use vaultrs::client::{VaultClient, VaultClientSettingsBuilder};
use vaultrs::error::ClientError;
/// Lock record stored (serialized via serde) at the election's secret path.
///
/// The record identifies the instance holding the lock and the window during
/// which the lock is considered valid. Timestamps are produced with
/// `chrono::Utc::now().timestamp()` elsewhere in this module.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct LockPayload {
    /// Identifier of the instance that wrote this record.
    instance_id: String,
    /// Timestamp at which the lock was first acquired.
    acquired_at: i64,
    /// Timestamp after which the lock is treated as expired.
    expires_at: i64,
}
impl LockPayload {
    /// Reports whether this lock has lapsed.
    ///
    /// The lock counts as expired only once the current UTC timestamp is
    /// strictly greater than `expires_at`; at exactly `expires_at` it is
    /// still considered live.
    fn is_expired(&self) -> bool {
        let now = chrono::Utc::now().timestamp();
        self.expires_at < now
    }
}
/// Leader election backed by a lock record stored through the `vaultrs` KV v2 API.
///
/// A fresh Vault client is built from `config` for every operation; the
/// struct itself holds no connection state.
#[derive(Debug, Clone)]
pub struct VaultElection {
    /// Vault connection and election settings (URL, token, KV path, lease duration).
    config: VaultConfig,
    /// Name of this election; forms the last segment of the secret path.
    election_key: String,
    /// Identifier this instance writes into the lock record when it holds the lock.
    instance_id: String,
}
impl VaultElection {
pub fn new(
config: VaultConfig,
election_key: impl Into<String>,
instance_id: impl Into<String>,
) -> Self {
Self {
config,
election_key: election_key.into(),
instance_id: instance_id.into(),
}
}
fn build_client(&self) -> ElectionResult<VaultClient> {
VaultClient::new(
VaultClientSettingsBuilder::default()
.address(&self.config.url)
.token(&self.config.token)
.build()
.map_err(|e| ElectionError::Failed(format!("vault client settings: {e}")))?,
)
.map_err(|e| ElectionError::Failed(format!("vault client build: {e}")))
}
fn mount_and_prefix(&self) -> (&str, &str) {
let path = self.config.kv_path.trim_matches('/');
match path.find('/') {
Some(idx) => (&path[..idx], &path[idx + 1..]),
None => (path, ""),
}
}
fn secret_path(&self) -> String {
let (_, prefix) = self.mount_and_prefix();
if prefix.is_empty() {
self.election_key.clone()
} else {
format!("{}/{}", prefix, self.election_key)
}
}
fn fresh_payload(&self) -> LockPayload {
let now = chrono::Utc::now().timestamp();
LockPayload {
instance_id: self.instance_id.clone(),
acquired_at: now,
expires_at: now + self.config.lease_duration as i64,
}
}
async fn read_lock(
&self,
client: &VaultClient,
) -> ElectionResult<Option<LockPayload>> {
let (mount, _) = self.mount_and_prefix();
let path = self.secret_path();
match vaultrs::kv2::read::<LockPayload>(client, mount, &path).await {
Ok(payload) if payload.is_expired() => Ok(None),
Ok(payload) => Ok(Some(payload)),
Err(ClientError::APIError { code: 404, .. }) => Ok(None),
Err(ClientError::ResponseEmptyError) => Ok(None),
Err(ClientError::ResponseDataEmptyError) => Ok(None),
Err(e) => Err(ElectionError::Failed(format!("vault read: {e}"))),
}
}
async fn read_lock_raw(
&self,
client: &VaultClient,
) -> ElectionResult<Option<LockPayload>> {
let (mount, _) = self.mount_and_prefix();
let path = self.secret_path();
match vaultrs::kv2::read::<LockPayload>(client, mount, &path).await {
Ok(payload) => Ok(Some(payload)),
Err(ClientError::APIError { code: 404, .. }) => Ok(None),
Err(ClientError::ResponseEmptyError) => Ok(None),
Err(ClientError::ResponseDataEmptyError) => Ok(None),
Err(e) => Err(ElectionError::Failed(format!("vault read raw: {e}"))),
}
}
async fn force_delete(&self, client: &VaultClient) -> ElectionResult<()> {
let (mount, _) = self.mount_and_prefix();
let path = self.secret_path();
vaultrs::kv2::delete_metadata(client, mount, &path)
.await
.map_err(|e| ElectionError::Failed(format!("vault force delete: {e}")))
}
}
#[async_trait::async_trait]
impl LeaderElection for VaultElection {
async fn become_leader(&mut self) -> ElectionResult<()> {
let client = self.build_client()?;
let (mount, _) = self.mount_and_prefix();
let path = self.secret_path();
let payload = self.fresh_payload();
let options = SetSecretRequestOptions { cas: 0 };
match vaultrs::kv2::set_with_options(&client, mount, &path, &payload, options).await {
Ok(_) => return Ok(()),
Err(ClientError::APIError { code: 400, .. }) => {
}
Err(e) => return Err(ElectionError::Failed(format!("vault become_leader: {e}"))),
}
match self.read_lock_raw(&client).await? {
Some(existing) if existing.is_expired() => {
self.force_delete(&client).await?;
let retry_options = SetSecretRequestOptions { cas: 0 };
vaultrs::kv2::set_with_options(
&client,
mount,
&path,
&self.fresh_payload(),
retry_options,
)
.await
.map(|_| ())
.map_err(|e| ElectionError::Failed(format!("vault become_leader retry: {e}")))
}
_ => Err(ElectionError::AlreadyLeader),
}
}
async fn is_leader(&self) -> bool {
let Ok(client) = self.build_client() else {
return false;
};
match self.read_lock(&client).await {
Ok(Some(payload)) => payload.instance_id == self.instance_id,
_ => false,
}
}
async fn resign(&mut self) -> ElectionResult<()> {
let client = self.build_client()?;
match self.read_lock(&client).await? {
Some(payload) if payload.instance_id == self.instance_id => {}
Some(_) => return Err(ElectionError::NotLeader),
None => return Err(ElectionError::NotLeader),
}
self.force_delete(&client).await
}
async fn get_leader(&self) -> Option<String> {
let client = self.build_client().ok()?;
self.read_lock(&client)
.await
.ok()
.flatten()
.map(|p| p.instance_id)
}
async fn renew(&mut self) -> ElectionResult<()> {
let client = self.build_client()?;
let (mount, _) = self.mount_and_prefix();
let path = self.secret_path();
let existing = self
.read_lock_raw(&client)
.await?
.ok_or(ElectionError::NotLeader)?;
if existing.instance_id != self.instance_id {
return Err(ElectionError::NotLeader);
}
let meta = vaultrs::kv2::read_metadata(&client, mount, &path)
.await
.map_err(|e| ElectionError::Failed(format!("vault renew read_metadata: {e}")))?;
let current_version = meta.current_version;
let renewed = LockPayload {
instance_id: self.instance_id.clone(),
acquired_at: existing.acquired_at,
expires_at: chrono::Utc::now().timestamp() + self.config.lease_duration as i64,
};
let options = SetSecretRequestOptions { cas: current_version as u32 };
match vaultrs::kv2::set_with_options(&client, mount, &path, &renewed, options).await {
Ok(_) => Ok(()),
Err(ClientError::APIError { code: 400, .. }) => Err(ElectionError::NotLeader),
Err(e) => Err(ElectionError::Failed(format!("vault renew write: {e}"))),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an election named `key` (instance `"id"`) over a default
    /// config whose `kv_path` is replaced by `kv_path`.
    fn election_at(kv_path: &str, key: &str) -> VaultElection {
        let config = VaultConfig {
            kv_path: kv_path.to_string(),
            ..VaultConfig::default()
        };
        VaultElection::new(config, key, "id")
    }

    /// The constructor stores the key and instance id it was given.
    #[test]
    fn test_vault_election_new() {
        let election = VaultElection::new(VaultConfig::default(), "my_election", "instance-1");
        assert_eq!(election.election_key, "my_election");
        assert_eq!(election.instance_id, "instance-1");
    }

    /// A multi-segment path splits at its first slash.
    #[test]
    fn test_mount_and_prefix_with_slash() {
        let election = election_at("secret/sunbeam/elections", "key");
        let (mount, prefix) = election.mount_and_prefix();
        assert_eq!(mount, "secret");
        assert_eq!(prefix, "sunbeam/elections");
    }

    /// A single-segment path is all mount and no prefix.
    #[test]
    fn test_mount_and_prefix_no_slash() {
        let election = election_at("secret", "key");
        let (mount, prefix) = election.mount_and_prefix();
        assert_eq!(mount, "secret");
        assert_eq!(prefix, "");
    }

    /// With a prefix, the secret path is `prefix/key`.
    #[test]
    fn test_secret_path_with_prefix() {
        let election = election_at("secret/elections", "my-election");
        assert_eq!(election.secret_path(), "elections/my-election");
    }

    /// Without a prefix, the secret path is just the key.
    #[test]
    fn test_secret_path_no_prefix() {
        let election = election_at("secret", "my-election");
        assert_eq!(election.secret_path(), "my-election");
    }

    /// A lock expiring a minute from now is still live.
    #[test]
    fn test_lock_payload_not_expired() {
        let now = chrono::Utc::now().timestamp();
        let payload = LockPayload {
            instance_id: "inst-1".to_string(),
            acquired_at: now,
            expires_at: now + 60,
        };
        assert!(!payload.is_expired());
    }

    /// A lock whose expiry lies far in the past is expired.
    #[test]
    fn test_lock_payload_expired() {
        let payload = LockPayload {
            instance_id: "inst-1".to_string(),
            acquired_at: 1_000_000,
            expires_at: 1_000_030,
        };
        assert!(payload.is_expired());
    }

    /// A fresh payload carries this instance's id and the configured lease.
    #[test]
    fn test_fresh_payload_has_ttl() {
        let config = VaultConfig {
            lease_duration: 45,
            ..VaultConfig::default()
        };
        let payload = VaultElection::new(config, "key", "inst-1").fresh_payload();
        assert_eq!(payload.instance_id, "inst-1");
        assert_eq!(payload.expires_at - payload.acquired_at, 45);
    }

    /// The lock payload survives a JSON serialize/deserialize round trip.
    #[test]
    fn test_lock_payload_roundtrip() {
        let original = LockPayload {
            instance_id: "node-7".to_string(),
            acquired_at: 1_700_000_000,
            expires_at: 1_700_000_030,
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: LockPayload = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.instance_id, "node-7");
        assert_eq!(decoded.acquired_at, 1_700_000_000);
        assert_eq!(decoded.expires_at, 1_700_000_030);
    }
}
}
#[cfg(feature = "vault")]
pub use inner::VaultElection;