use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use async_trait::async_trait;
use tracing::{error, warn};
use crate::auth::token_state::{
JoinTokenLifecycle, JoinTokenState, SharedTokenStateMirror, TokenStateBackend, TokenStateError,
};
use crate::decommission::coordinator::MetadataProposer;
use crate::metadata_group::entry::{JoinTokenTransitionKind, MetadataEntry};
/// Returns the current wall-clock time in milliseconds since the Unix epoch.
///
/// If the system clock reports a time earlier than the epoch, `0` is returned
/// instead of an error.
fn epoch_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
/// Join-token state backend that submits lifecycle transitions through a
/// [`MetadataProposer`] and serves reads from a shared in-memory mirror.
pub struct RaftBackedTokenStore {
/// Proposer used to submit `MetadataEntry::JoinTokenTransition` entries.
proposer: Arc<dyn MetadataProposer>,
/// Shared, lock-guarded map of per-token state that all reads are served from.
mirror: SharedTokenStateMirror,
}
impl RaftBackedTokenStore {
pub fn new(proposer: Arc<dyn MetadataProposer>, mirror: SharedTokenStateMirror) -> Self {
Self { proposer, mirror }
}
async fn propose(
&self,
token_hash: [u8; 32],
transition: JoinTokenTransitionKind,
) -> Result<(), TokenStateError> {
let entry = MetadataEntry::JoinTokenTransition {
token_hash,
transition,
ts_ms: epoch_ms(),
};
self.proposer
.propose_and_wait(entry)
.await
.map(|_| ())
.map_err(|e| {
error!(error = %e, "RaftBackedTokenStore: proposer error");
TokenStateError::ProposerError {
detail: e.to_string(),
}
})
}
fn read_state(&self, token_hash: &[u8; 32]) -> Option<JoinTokenState> {
let map = self.mirror.lock().unwrap_or_else(|p| p.into_inner());
map.get(token_hash).cloned()
}
}
#[async_trait]
impl TokenStateBackend for RaftBackedTokenStore {
async fn register(&self, state: JoinTokenState) {
let hash = state.token_hash;
let expires_at_ms = state.expires_at_ms;
if let Err(e) = self
.propose(hash, JoinTokenTransitionKind::Register { expires_at_ms })
.await
{
error!(
error = %e,
token_hash = ?hash,
"RaftBackedTokenStore: register failed — token not replicated"
);
}
}
async fn begin_inflight(
&self,
token_hash: &[u8; 32],
node_addr: SocketAddr,
) -> Result<(), TokenStateError> {
if let Some(current) = self.read_state(token_hash) {
match ¤t.lifecycle {
JoinTokenLifecycle::Consumed { .. } => {
return Err(TokenStateError::AlreadyConsumed);
}
JoinTokenLifecycle::Expired => return Err(TokenStateError::Expired),
JoinTokenLifecycle::Aborted => return Err(TokenStateError::Aborted),
_ => {}
}
}
let addr_str = node_addr.to_string();
self.propose(
*token_hash,
JoinTokenTransitionKind::BeginInFlight {
node_addr: addr_str.clone(),
},
)
.await?;
match self.read_state(token_hash) {
None => Err(TokenStateError::NotFound),
Some(s) => match &s.lifecycle {
JoinTokenLifecycle::InFlight { node_addr: winner } => {
if *winner == node_addr {
Ok(())
} else {
Err(TokenStateError::InFlightConflict)
}
}
JoinTokenLifecycle::Consumed { .. } => Err(TokenStateError::AlreadyConsumed),
JoinTokenLifecycle::Expired => Err(TokenStateError::Expired),
JoinTokenLifecycle::Aborted => Err(TokenStateError::Aborted),
JoinTokenLifecycle::Issued => {
warn!(
token_hash = ?token_hash,
"RaftBackedTokenStore: begin_inflight propose succeeded but state remained Issued"
);
Err(TokenStateError::InvalidTransition)
}
},
}
}
async fn mark_consumed(
&self,
token_hash: &[u8; 32],
node_addr: SocketAddr,
_ts_ms: u64,
) -> Result<(), TokenStateError> {
let addr_str = node_addr.to_string();
self.propose(
*token_hash,
JoinTokenTransitionKind::MarkConsumed {
node_addr: addr_str,
},
)
.await?;
match self.read_state(token_hash) {
None => Err(TokenStateError::NotFound),
Some(s) => match &s.lifecycle {
JoinTokenLifecycle::Consumed { .. } => Ok(()),
_ => Err(TokenStateError::InvalidTransition),
},
}
}
async fn revert_inflight(&self, token_hash: &[u8; 32]) -> Result<(), TokenStateError> {
self.propose(*token_hash, JoinTokenTransitionKind::RevertInFlight)
.await?;
match self.read_state(token_hash) {
None => Err(TokenStateError::NotFound),
Some(s) => match &s.lifecycle {
JoinTokenLifecycle::Issued => Ok(()),
_ => Err(TokenStateError::InvalidTransition),
},
}
}
fn get(&self, token_hash: &[u8; 32]) -> Option<JoinTokenState> {
self.read_state(token_hash)
}
}
/// Applies one join-token transition to the shared in-memory mirror.
///
/// Behaviour per transition:
///
/// * `Register` — inserts an `Issued` entry with `attempt == 0` unless the
///   token already has an entry, in which case nothing changes.
/// * `BeginInFlight` — `Issued` becomes `InFlight` for the parsed address and
///   `attempt` is incremented; a repeat for the same address is a no-op.
/// * `MarkConsumed` — `InFlight` becomes `Consumed` with the parsed address
///   and `ts_ms`; an already consumed token is left untouched.
/// * `RevertInFlight` — `InFlight` becomes `Issued`; `Issued` is a no-op.
/// * `MarkExpired` — `Issued` or `InFlight` becomes `Expired`; `Expired` is a
///   no-op.
/// * `MarkAborted` — any state becomes `Aborted`.
///
/// Nothing is returned: an unparsable address, a missing entry, or an
/// unexpected current state is logged at error level and the mirror is left
/// unchanged. A poisoned lock is recovered rather than propagated.
pub fn apply_token_transition_to_mirror(
    mirror: &SharedTokenStateMirror,
    token_hash: [u8; 32],
    transition: &JoinTokenTransitionKind,
    ts_ms: u64,
) {
    let mut states = match mirror.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    };

    match transition {
        JoinTokenTransitionKind::Register { expires_at_ms } => {
            // An existing entry wins; a repeated Register never resets state.
            states.entry(token_hash).or_insert_with(|| JoinTokenState {
                token_hash,
                lifecycle: JoinTokenLifecycle::Issued,
                expires_at_ms: *expires_at_ms,
                attempt: 0,
            });
        }

        JoinTokenTransitionKind::BeginInFlight { node_addr } => {
            // The address is validated before the token is looked up.
            let claimant = match node_addr.parse::<SocketAddr>() {
                Ok(addr) => addr,
                Err(_) => {
                    error!(
                        token_hash = ?token_hash,
                        addr = %node_addr,
                        "apply BeginInFlight: invalid address — skipping"
                    );
                    return;
                }
            };
            let Some(entry) = states.get_mut(&token_hash) else {
                error!(
                    token_hash = ?token_hash,
                    "apply BeginInFlight: token not in mirror (Register must precede this)"
                );
                return;
            };
            match &entry.lifecycle {
                JoinTokenLifecycle::Issued => {
                    entry.lifecycle = JoinTokenLifecycle::InFlight {
                        node_addr: claimant,
                    };
                    entry.attempt += 1;
                }
                // Same node claiming again: nothing to do.
                JoinTokenLifecycle::InFlight { node_addr: holder } if *holder == claimant => {}
                other => {
                    error!(
                        token_hash = ?token_hash,
                        ?other,
                        "apply BeginInFlight: unexpected lifecycle — log corruption?"
                    );
                }
            }
        }

        JoinTokenTransitionKind::MarkConsumed { node_addr } => {
            let consumer = match node_addr.parse::<SocketAddr>() {
                Ok(addr) => addr,
                Err(_) => {
                    error!(
                        token_hash = ?token_hash,
                        addr = %node_addr,
                        "apply MarkConsumed: invalid address — skipping"
                    );
                    return;
                }
            };
            let Some(entry) = states.get_mut(&token_hash) else {
                error!(token_hash = ?token_hash, "apply MarkConsumed: token not in mirror");
                return;
            };
            match &entry.lifecycle {
                JoinTokenLifecycle::InFlight { .. } => {
                    entry.lifecycle = JoinTokenLifecycle::Consumed {
                        node_addr: consumer,
                        ts_ms,
                    };
                }
                // Already consumed: keep the original record.
                JoinTokenLifecycle::Consumed { .. } => {}
                other => {
                    error!(
                        token_hash = ?token_hash,
                        ?other,
                        "apply MarkConsumed: unexpected lifecycle"
                    );
                }
            }
        }

        JoinTokenTransitionKind::RevertInFlight => {
            let Some(entry) = states.get_mut(&token_hash) else {
                error!(token_hash = ?token_hash, "apply RevertInFlight: token not in mirror");
                return;
            };
            match &entry.lifecycle {
                JoinTokenLifecycle::InFlight { .. } => {
                    entry.lifecycle = JoinTokenLifecycle::Issued;
                }
                JoinTokenLifecycle::Issued => {}
                other => {
                    error!(
                        token_hash = ?token_hash,
                        ?other,
                        "apply RevertInFlight: unexpected lifecycle"
                    );
                }
            }
        }

        JoinTokenTransitionKind::MarkExpired => {
            let Some(entry) = states.get_mut(&token_hash) else {
                error!(token_hash = ?token_hash, "apply MarkExpired: token not in mirror");
                return;
            };
            match &entry.lifecycle {
                JoinTokenLifecycle::Issued | JoinTokenLifecycle::InFlight { .. } => {
                    entry.lifecycle = JoinTokenLifecycle::Expired;
                }
                JoinTokenLifecycle::Expired => {}
                other => {
                    error!(
                        token_hash = ?token_hash,
                        ?other,
                        "apply MarkExpired: unexpected lifecycle"
                    );
                }
            }
        }

        JoinTokenTransitionKind::MarkAborted => {
            let Some(entry) = states.get_mut(&token_hash) else {
                error!(token_hash = ?token_hash, "apply MarkAborted: token not in mirror");
                return;
            };
            // Abort overrides every other state.
            if !matches!(entry.lifecycle, JoinTokenLifecycle::Aborted) {
                entry.lifecycle = JoinTokenLifecycle::Aborted;
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::token_state::InMemoryTokenStore;
    use std::collections::HashMap;
    use std::sync::Mutex;
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Test proposer that applies each join-token transition straight to a
    /// mirror and hands back a monotonically increasing index.
    struct MirroringProposer {
        counter: AtomicU64,
        mirror: SharedTokenStateMirror,
    }

    impl MirroringProposer {
        fn new(mirror: SharedTokenStateMirror) -> Arc<Self> {
            let proposer = Self {
                counter: AtomicU64::new(0),
                mirror,
            };
            Arc::new(proposer)
        }
    }

    #[async_trait]
    impl MetadataProposer for MirroringProposer {
        async fn propose_and_wait(&self, entry: MetadataEntry) -> crate::error::Result<u64> {
            let previous = self.counter.fetch_add(1, Ordering::SeqCst);
            if let MetadataEntry::JoinTokenTransition {
                token_hash,
                transition,
                ts_ms,
            } = entry
            {
                apply_token_transition_to_mirror(&self.mirror, token_hash, &transition, ts_ms);
            }
            Ok(previous + 1)
        }
    }

    /// Creates an empty shared mirror.
    fn empty_mirror() -> SharedTokenStateMirror {
        Arc::new(Mutex::new(HashMap::new()))
    }

    /// Creates a store with its own proposer, both bound to `mirror`.
    fn store_over(mirror: &SharedTokenStateMirror) -> RaftBackedTokenStore {
        RaftBackedTokenStore::new(MirroringProposer::new(mirror.clone()), mirror.clone())
    }

    /// Parses a literal socket address used by a test.
    fn sock(text: &str) -> SocketAddr {
        text.parse().unwrap()
    }

    /// Builds a freshly issued token state with no attempts recorded.
    fn issued_state(hash: [u8; 32], expires_at_ms: u64) -> JoinTokenState {
        JoinTokenState {
            token_hash: hash,
            lifecycle: JoinTokenLifecycle::Issued,
            expires_at_ms,
            attempt: 0,
        }
    }

    /// An expiry one hour past the current time.
    fn far_future_ms() -> u64 {
        epoch_ms() + 3_600_000
    }

    #[tokio::test]
    async fn register_and_begin_inflight_consume() {
        let mirror = empty_mirror();
        let store = store_over(&mirror);
        let hash = [0xA1u8; 32];
        let addr = sock("127.0.0.1:9100");

        store.register(issued_state(hash, far_future_ms())).await;
        store.begin_inflight(&hash, addr).await.unwrap();
        store.mark_consumed(&hash, addr, epoch_ms()).await.unwrap();

        let state = store.get(&hash).unwrap();
        assert!(matches!(
            state.lifecycle,
            JoinTokenLifecycle::Consumed { .. }
        ));
    }

    #[tokio::test]
    async fn replay_after_consume_rejected() {
        let mirror = empty_mirror();
        let store = store_over(&mirror);
        let hash = [0xA2u8; 32];
        let addr = sock("127.0.0.1:9101");

        store.register(issued_state(hash, far_future_ms())).await;
        store.begin_inflight(&hash, addr).await.unwrap();
        store.mark_consumed(&hash, addr, epoch_ms()).await.unwrap();

        let replay_err = store.begin_inflight(&hash, addr).await.unwrap_err();
        assert_eq!(replay_err, TokenStateError::AlreadyConsumed);
    }

    #[tokio::test]
    async fn revert_inflight_allows_retry() {
        let mirror = empty_mirror();
        let store = store_over(&mirror);
        let hash = [0xA3u8; 32];
        let addr = sock("127.0.0.1:9102");

        store.register(issued_state(hash, far_future_ms())).await;
        store.begin_inflight(&hash, addr).await.unwrap();
        store.revert_inflight(&hash).await.unwrap();

        let state = store.get(&hash).unwrap();
        assert_eq!(state.lifecycle, JoinTokenLifecycle::Issued);

        // After the revert the same node can claim the token again.
        store.begin_inflight(&hash, addr).await.unwrap();
    }

    #[tokio::test]
    async fn cross_node_replay_rejected_via_shared_mirror() {
        let mirror = empty_mirror();
        let hash = [0xA4u8; 32];
        let addr_a = sock("10.0.0.1:9000");
        let addr_b = sock("10.0.0.2:9000");

        // Node A takes the token through its full lifecycle.
        let store_a = store_over(&mirror);
        store_a.register(issued_state(hash, far_future_ms())).await;
        store_a.begin_inflight(&hash, addr_a).await.unwrap();
        store_a
            .mark_consumed(&hash, addr_a, epoch_ms())
            .await
            .unwrap();

        // Node B has its own proposer but reads the same mirror.
        let store_b = store_over(&mirror);
        let replay_err = store_b.begin_inflight(&hash, addr_b).await.unwrap_err();
        assert_eq!(
            replay_err,
            TokenStateError::AlreadyConsumed,
            "node B must reject replayed consumed token"
        );
    }

    #[tokio::test]
    async fn crash_recovery_replay_rejected() {
        let mirror = empty_mirror();
        let hash = [0xA5u8; 32];
        let addr = sock("10.0.0.1:9001");

        let store = store_over(&mirror);
        store.register(issued_state(hash, far_future_ms())).await;
        store.begin_inflight(&hash, addr).await.unwrap();
        store.mark_consumed(&hash, addr, epoch_ms()).await.unwrap();

        // Rebuild a brand-new mirror by replaying the same transitions in order.
        let fresh_mirror = empty_mirror();
        let ts = epoch_ms();
        let replayed_log = [
            JoinTokenTransitionKind::Register {
                expires_at_ms: far_future_ms(),
            },
            JoinTokenTransitionKind::BeginInFlight {
                node_addr: addr.to_string(),
            },
            JoinTokenTransitionKind::MarkConsumed {
                node_addr: addr.to_string(),
            },
        ];
        for transition in &replayed_log {
            apply_token_transition_to_mirror(&fresh_mirror, hash, transition, ts);
        }

        let store2 = store_over(&fresh_mirror);
        let replay_err = store2.begin_inflight(&hash, addr).await.unwrap_err();
        assert_eq!(
            replay_err,
            TokenStateError::AlreadyConsumed,
            "post-crash replay must be rejected"
        );
    }

    #[tokio::test]
    async fn in_memory_store_still_works_async() {
        let store = InMemoryTokenStore::new();
        let hash = [0xA6u8; 32];
        let addr = sock("127.0.0.1:9200");

        let initial = JoinTokenState {
            token_hash: hash,
            lifecycle: JoinTokenLifecycle::Issued,
            expires_at_ms: far_future_ms(),
            attempt: 0,
        };
        store.register(initial).await;
        store.begin_inflight(&hash, addr).await.unwrap();
        store.mark_consumed(&hash, addr, epoch_ms()).await.unwrap();

        let state = store.get(&hash).unwrap();
        assert!(matches!(
            state.lifecycle,
            JoinTokenLifecycle::Consumed { .. }
        ));
    }
}