use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use async_trait::async_trait;
/// Lifecycle stage of a single join token.
///
/// In this file, `begin_inflight` moves `Issued` to `InFlight` (or to `Expired` once the
/// deadline has passed). `mark_consumed` moves `InFlight` to `Consumed`, and
/// `revert_inflight` moves `InFlight` back to `Issued`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum JoinTokenLifecycle {
    /// Created and not currently claimed by any node.
    Issued,
    /// Claimed by the node at `node_addr`; a join using this token is presumably in progress.
    InFlight {
        /// Address of the node holding the claim.
        node_addr: SocketAddr,
    },
    /// Used up by the node at `node_addr`; `ts_ms` is the timestamp passed to `mark_consumed`.
    Consumed {
        /// Address of the node that consumed the token.
        node_addr: SocketAddr,
        /// Consumption timestamp (the tests supply `epoch_ms()`, i.e. Unix-epoch milliseconds).
        ts_ms: u64,
    },
    /// The token's deadline passed before it was claimed.
    Expired,
    /// Cancelled. NOTE(review): no transition in this file sets this; presumably set externally.
    Aborted,
}
/// Full tracked state for one join token, keyed by its hash.
#[derive(Clone, Debug)]
pub struct JoinTokenState {
    /// 32-byte digest identifying the token (presumably SHA-256 of the secret — confirm).
    pub token_hash: [u8; 32],
    /// Current lifecycle stage.
    pub lifecycle: JoinTokenLifecycle,
    /// Deadline in Unix-epoch milliseconds; `begin_inflight` compares it against `epoch_ms()`.
    pub expires_at_ms: u64,
    /// Number of `Issued` -> `InFlight` transitions performed so far.
    pub attempt: u32,
}
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
pub enum TokenStateError {
#[error("join token already consumed")]
AlreadyConsumed,
#[error("join token expired")]
Expired,
#[error("join token aborted")]
Aborted,
#[error("join token is already in-flight from a different address")]
InFlightConflict,
#[error("join token not found")]
NotFound,
#[error("unexpected lifecycle state for this transition")]
InvalidTransition,
#[error("raft proposer error: {detail}")]
ProposerError { detail: String },
}
/// Thread-safe, shareable map from token hash to its tracked state.
///
/// NOTE(review): not referenced in this chunk; presumably a local mirror of replicated state.
pub type SharedTokenStateMirror = Arc<Mutex<HashMap<[u8; 32], JoinTokenState>>>;
/// Storage backend driving the join-token lifecycle state machine.
///
/// Implementors must be shareable across tasks (`Send + Sync + 'static`). Transition methods
/// are `async` via `async_trait`, presumably so a replicated backend (cf.
/// `TokenStateError::ProposerError`) can await consensus before answering.
#[async_trait]
pub trait TokenStateBackend: Send + Sync + 'static {
    /// Records `state` keyed by `state.token_hash`.
    async fn register(&self, state: JoinTokenState);

    /// Returns an owned copy of the token's state, or `None` if it is unknown.
    fn get(&self, token_hash: &[u8; 32]) -> Option<JoinTokenState>;

    /// Claims the token for `node_addr`, moving it into `InFlight`.
    async fn begin_inflight(
        &self,
        token_hash: &[u8; 32],
        node_addr: SocketAddr,
    ) -> Result<(), TokenStateError>;

    /// Finalises an in-flight token as `Consumed` by `node_addr` at `ts_ms`.
    async fn mark_consumed(
        &self,
        token_hash: &[u8; 32],
        node_addr: SocketAddr,
        ts_ms: u64,
    ) -> Result<(), TokenStateError>;

    /// Releases an in-flight claim, returning the token to `Issued`.
    async fn revert_inflight(&self, token_hash: &[u8; 32]) -> Result<(), TokenStateError>;
}
#[derive(Default, Clone)]
pub struct InMemoryTokenStore {
inner: Arc<Mutex<HashMap<[u8; 32], JoinTokenState>>>,
}
impl InMemoryTokenStore {
    /// Creates a store with no registered tokens.
    pub fn new() -> Self {
        Default::default()
    }
}
#[async_trait]
impl TokenStateBackend for InMemoryTokenStore {
    /// Inserts `state` keyed by `state.token_hash`, replacing any existing entry.
    async fn register(&self, state: JoinTokenState) {
        let mut map = self.inner.lock().expect("token store lock poisoned");
        map.insert(state.token_hash, state);
    }

    /// Claims an `Issued` token for `node_addr` and increments its `attempt` counter.
    ///
    /// A repeated call from the node that already holds the claim succeeds without changing
    /// anything.
    ///
    /// # Errors
    /// - `NotFound` if the hash is unknown.
    /// - `Expired` if the deadline has passed. An `Issued` token found past its deadline is
    ///   also persisted as `Expired`.
    /// - `InFlightConflict` if a different node holds the claim.
    /// - `AlreadyConsumed` / `Aborted` for terminal stages.
    async fn begin_inflight(
        &self,
        token_hash: &[u8; 32],
        node_addr: SocketAddr,
    ) -> Result<(), TokenStateError> {
        let mut map = self.inner.lock().expect("token store lock poisoned");
        let entry = map.get_mut(token_hash).ok_or(TokenStateError::NotFound)?;
        match entry.lifecycle {
            JoinTokenLifecycle::Issued => {
                // Expiry is evaluated lazily here; the terminal state is persisted so later
                // calls report `Expired` without consulting the clock again.
                if epoch_ms() > entry.expires_at_ms {
                    entry.lifecycle = JoinTokenLifecycle::Expired;
                    return Err(TokenStateError::Expired);
                }
                entry.lifecycle = JoinTokenLifecycle::InFlight { node_addr };
                // Saturate instead of overflowing: an overflow panic (debug builds) while the
                // guard is held would poison the whole store.
                entry.attempt = entry.attempt.saturating_add(1);
                Ok(())
            }
            // Idempotent retry from the current holder; `attempt` is intentionally unchanged.
            JoinTokenLifecycle::InFlight { node_addr: holder } if holder == node_addr => Ok(()),
            JoinTokenLifecycle::InFlight { .. } => Err(TokenStateError::InFlightConflict),
            JoinTokenLifecycle::Consumed { .. } => Err(TokenStateError::AlreadyConsumed),
            JoinTokenLifecycle::Expired => Err(TokenStateError::Expired),
            JoinTokenLifecycle::Aborted => Err(TokenStateError::Aborted),
        }
    }

    /// Marks an in-flight token as `Consumed` by `node_addr` at `ts_ms`.
    ///
    /// Only the node that holds the in-flight claim may consume the token.
    ///
    /// # Errors
    /// - `NotFound` if the hash is unknown.
    /// - `InFlightConflict` if `node_addr` is not the node that began the flight.
    /// - `AlreadyConsumed` if the token was already consumed.
    /// - `InvalidTransition` from `Issued`, `Expired` or `Aborted`.
    async fn mark_consumed(
        &self,
        token_hash: &[u8; 32],
        node_addr: SocketAddr,
        ts_ms: u64,
    ) -> Result<(), TokenStateError> {
        let mut map = self.inner.lock().expect("token store lock poisoned");
        let entry = map.get_mut(token_hash).ok_or(TokenStateError::NotFound)?;
        match entry.lifecycle {
            JoinTokenLifecycle::InFlight { node_addr: holder } if holder == node_addr => {
                entry.lifecycle = JoinTokenLifecycle::Consumed { node_addr, ts_ms };
                Ok(())
            }
            // Without this check, any node could finalise a token another node is using.
            JoinTokenLifecycle::InFlight { .. } => Err(TokenStateError::InFlightConflict),
            JoinTokenLifecycle::Consumed { .. } => Err(TokenStateError::AlreadyConsumed),
            JoinTokenLifecycle::Issued
            | JoinTokenLifecycle::Expired
            | JoinTokenLifecycle::Aborted => Err(TokenStateError::InvalidTransition),
        }
    }

    /// Returns an in-flight token to `Issued`, releasing the current claim.
    ///
    /// # Errors
    /// - `NotFound` if the hash is unknown.
    /// - `InvalidTransition` if the token is not `InFlight`.
    async fn revert_inflight(&self, token_hash: &[u8; 32]) -> Result<(), TokenStateError> {
        let mut map = self.inner.lock().expect("token store lock poisoned");
        let entry = map.get_mut(token_hash).ok_or(TokenStateError::NotFound)?;
        match entry.lifecycle {
            JoinTokenLifecycle::InFlight { .. } => {
                entry.lifecycle = JoinTokenLifecycle::Issued;
                Ok(())
            }
            JoinTokenLifecycle::Issued
            | JoinTokenLifecycle::Consumed { .. }
            | JoinTokenLifecycle::Expired
            | JoinTokenLifecycle::Aborted => Err(TokenStateError::InvalidTransition),
        }
    }

    /// Returns a clone of the token's current state, if registered.
    fn get(&self, token_hash: &[u8; 32]) -> Option<JoinTokenState> {
        let map = self.inner.lock().expect("token store lock poisoned");
        map.get(token_hash).cloned()
    }
}
pub fn spawn_inflight_timeout<B: TokenStateBackend>(
backend: Arc<B>,
token_hash: [u8; 32],
timeout: Duration,
) {
tokio::spawn(async move {
tokio::time::sleep(timeout).await;
let _ = backend.revert_inflight(&token_hash).await;
});
}
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch. A millisecond count that
/// does not fit in `u64` saturates to `u64::MAX`; an `as` cast would silently truncate it.
fn epoch_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| u64::try_from(d.as_millis()).unwrap_or(u64::MAX))
        .unwrap_or(0)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Loopback address used as the joining node throughout these tests.
    fn dummy_addr() -> SocketAddr {
        SocketAddr::from(([127, 0, 0, 1], 9000))
    }

    /// Builds an `Issued` token for `hash` that expires `expires_in_secs` from now.
    fn make_state(hash: [u8; 32], expires_in_secs: u64) -> JoinTokenState {
        JoinTokenState {
            token_hash: hash,
            lifecycle: JoinTokenLifecycle::Issued,
            expires_at_ms: epoch_ms() + expires_in_secs * 1000,
            attempt: 0,
        }
    }

    /// Returns a fresh store holding an `Issued` token for `hash`, valid for 60 seconds.
    async fn store_with(hash: [u8; 32]) -> InMemoryTokenStore {
        let store = InMemoryTokenStore::new();
        store.register(make_state(hash, 60)).await;
        store
    }

    #[tokio::test]
    async fn issued_to_inflight_to_consumed() {
        let hash = [0x01u8; 32];
        let addr = dummy_addr();
        let store = store_with(hash).await;

        store.begin_inflight(&hash, addr).await.unwrap();
        let after_begin = store.get(&hash).unwrap();
        assert_eq!(
            after_begin.lifecycle,
            JoinTokenLifecycle::InFlight { node_addr: addr }
        );
        assert_eq!(after_begin.attempt, 1);

        let ts = epoch_ms();
        store.mark_consumed(&hash, addr, ts).await.unwrap();
        assert_eq!(
            store.get(&hash).unwrap().lifecycle,
            JoinTokenLifecycle::Consumed {
                node_addr: addr,
                ts_ms: ts
            }
        );
    }

    #[tokio::test]
    async fn replay_on_consumed_token_returns_error() {
        let hash = [0x02u8; 32];
        let addr = dummy_addr();
        let store = store_with(hash).await;

        store.begin_inflight(&hash, addr).await.unwrap();
        store.mark_consumed(&hash, addr, epoch_ms()).await.unwrap();

        let replay = store.begin_inflight(&hash, addr).await;
        assert_eq!(replay, Err(TokenStateError::AlreadyConsumed));
    }

    #[tokio::test]
    async fn inflight_reverts_to_issued_on_timeout() {
        let hash = [0x03u8; 32];
        let addr = dummy_addr();
        let store = store_with(hash).await;

        store.begin_inflight(&hash, addr).await.unwrap();
        store.revert_inflight(&hash).await.unwrap();
        assert_eq!(store.get(&hash).unwrap().lifecycle, JoinTokenLifecycle::Issued);

        // A second claim counts as a new attempt.
        store.begin_inflight(&hash, addr).await.unwrap();
        assert_eq!(store.get(&hash).unwrap().attempt, 2);
    }

    #[tokio::test]
    async fn expired_token_rejected() {
        let hash = [0x04u8; 32];
        let store = InMemoryTokenStore::new();
        // A deadline 1 ms after the epoch is always in the past.
        store
            .register(JoinTokenState {
                expires_at_ms: 1,
                ..make_state(hash, 0)
            })
            .await;

        let outcome = store.begin_inflight(&hash, dummy_addr()).await;
        assert_eq!(outcome, Err(TokenStateError::Expired));
        assert_eq!(store.get(&hash).unwrap().lifecycle, JoinTokenLifecycle::Expired);
    }

    #[tokio::test]
    async fn aborted_token_rejected() {
        let hash = [0x05u8; 32];
        let store = InMemoryTokenStore::new();
        store
            .register(JoinTokenState {
                lifecycle: JoinTokenLifecycle::Aborted,
                ..make_state(hash, 60)
            })
            .await;

        let outcome = store.begin_inflight(&hash, dummy_addr()).await;
        assert_eq!(outcome, Err(TokenStateError::Aborted));
    }

    #[tokio::test]
    async fn inflight_same_addr_is_idempotent() {
        let hash = [0x06u8; 32];
        let addr = dummy_addr();
        let store = store_with(hash).await;

        for _ in 0..2 {
            store.begin_inflight(&hash, addr).await.unwrap();
        }
        assert_eq!(store.get(&hash).unwrap().attempt, 1);
    }
}