Skip to main content

nodedb_cluster/auth/
token_state.rs

// SPDX-License-Identifier: BUSL-1.1

//! In-process token state machine for single-use join-token enforcement.
//!
//! The token store (any [`TokenStateBackend`] implementation) tracks every
//! issued token's lifecycle from `Issued` through `InFlight` to `Consumed`
//! (or `Expired`/`Aborted`). In a full distributed deployment this state is
//! proposed through the metadata Raft group (as
//! `MetadataEntry::JoinTokenTransition`) so that all nodes reject a replayed
//! token, even after a crash-restart.
//!
//! The trait [`TokenStateBackend`] abstracts the storage so the
//! bootstrap-listener handler can be tested with the in-memory backend
//! and wired to the Raft-backed backend in production.
//!
//! # Async trait
//!
//! `TokenStateBackend` is async (via `async_trait`) because the
//! production [`crate::auth::raft_backed_store::RaftBackedTokenStore`]
//! must call `MetadataProposer::propose_and_wait`, which is inherently
//! async. Using `block_in_place` to bridge async→sync at the trait
//! boundary would couple the trait to Tokio internals and make it
//! untestable without a runtime. An async trait is the clean solution;
//! `InMemoryTokenStore` simply uses immediate `async { }` bodies.
24
25use std::collections::HashMap;
26use std::net::SocketAddr;
27use std::sync::{Arc, Mutex};
28use std::time::Duration;
29
30use async_trait::async_trait;
31
/// Where a join token currently sits in its single-use lifecycle.
///
/// The happy path is `Issued` → `InFlight` → `Consumed`. An in-flight token
/// whose joiner times out drops back to `Issued`. `Expired` and `Aborted`
/// are the alternative end states to `Consumed`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum JoinTokenLifecycle {
    /// Issued, but no joiner has presented it yet.
    Issued,
    /// A joiner has presented the token and is receiving its bundle.
    /// If that joiner times out, the state goes back to `Issued`.
    InFlight {
        /// Address of the joiner currently receiving the bundle.
        node_addr: SocketAddr,
    },
    /// The bundle was delivered successfully. Any later presentation of the
    /// same token is treated as a replay and rejected.
    Consumed {
        /// Address the bundle was delivered to.
        node_addr: SocketAddr,
        /// Time of consumption in milliseconds (callers in this module pass
        /// unix-epoch milliseconds).
        ts_ms: u64,
    },
    /// The expiry timestamp passed before the token was consumed.
    Expired,
    /// Invalidated on purpose (for example, an operator revoke).
    Aborted,
}
48
49/// Complete state record for one token.
50#[derive(Debug, Clone)]
51pub struct JoinTokenState {
52    /// SHA-256 of the token hex string. Never stores the raw token.
53    pub token_hash: [u8; 32],
54    pub lifecycle: JoinTokenLifecycle,
55    /// Absolute unix-ms expiry derived from the token's `expiry_unix_secs`.
56    pub expires_at_ms: u64,
57    /// Number of times this token moved from `Issued` to `InFlight`
58    /// (increments on retry after an in-flight timeout).
59    pub attempt: u32,
60}
61
62/// Error from token state transitions.
63#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
64pub enum TokenStateError {
65    #[error("join token already consumed")]
66    AlreadyConsumed,
67    #[error("join token expired")]
68    Expired,
69    #[error("join token aborted")]
70    Aborted,
71    #[error("join token is already in-flight from a different address")]
72    InFlightConflict,
73    #[error("join token not found")]
74    NotFound,
75    #[error("unexpected lifecycle state for this transition")]
76    InvalidTransition,
77    /// The Raft proposer returned an error. The transition was not
78    /// replicated; single-use enforcement may be incomplete.
79    #[error("raft proposer error: {detail}")]
80    ProposerError { detail: String },
81}
82
83/// Shared read-side mirror of token lifecycle state, updated only by the
84/// Raft apply path. The `RaftBackedTokenStore` and `CacheApplier` share
85/// this handle so the applier can write apply-path updates and the store
86/// can read post-apply state without a second round-trip.
87pub type SharedTokenStateMirror = Arc<Mutex<HashMap<[u8; 32], JoinTokenState>>>;
88
89/// Abstraction over where token state is persisted.
90///
91/// The in-memory implementation ([`InMemoryTokenStore`]) is used in
92/// tests and single-node deployments that don't need cross-crash
93/// single-use guarantees. Production deployments wire a Raft-backed
94/// implementation that proposes each transition through the metadata
95/// group.
96///
97/// The trait is async so the Raft-backed implementation can call
98/// `MetadataProposer::propose_and_wait` directly. `InMemoryTokenStore`
99/// uses immediate `async { }` bodies — zero overhead.
100#[async_trait]
101pub trait TokenStateBackend: Send + Sync + 'static {
102    /// Register a freshly issued token with `Issued` state.
103    async fn register(&self, state: JoinTokenState);
104    /// Attempt to transition from `Issued` → `InFlight`.
105    /// Returns `Err(AlreadyConsumed)` / `Err(Expired)` / `Err(Aborted)`
106    /// if the token is already in a terminal or conflicting state.
107    async fn begin_inflight(
108        &self,
109        token_hash: &[u8; 32],
110        node_addr: SocketAddr,
111    ) -> Result<(), TokenStateError>;
112    /// Transition from `InFlight` → `Consumed`. Called after the bundle
113    /// has been sent and the peer has acknowledged receipt.
114    async fn mark_consumed(
115        &self,
116        token_hash: &[u8; 32],
117        node_addr: SocketAddr,
118        ts_ms: u64,
119    ) -> Result<(), TokenStateError>;
120    /// Revert `InFlight` → `Issued` (joiner timed out before ACK).
121    async fn revert_inflight(&self, token_hash: &[u8; 32]) -> Result<(), TokenStateError>;
122    /// Look up the current state.
123    fn get(&self, token_hash: &[u8; 32]) -> Option<JoinTokenState>;
124}
125
126/// Simple in-memory token store. Thread-safe via `Mutex`.
127///
128/// Suitable for tests and single-node deployments. Does not survive
129/// process restart — a previously-consumed token is invisible after
130/// restart, allowing replay on that edge case. Production deployments
131/// must use a Raft-backed backend.
132#[derive(Default, Clone)]
133pub struct InMemoryTokenStore {
134    inner: Arc<Mutex<HashMap<[u8; 32], JoinTokenState>>>,
135}
136
137impl InMemoryTokenStore {
138    pub fn new() -> Self {
139        Self::default()
140    }
141}
142
143#[async_trait]
144impl TokenStateBackend for InMemoryTokenStore {
145    async fn register(&self, state: JoinTokenState) {
146        let mut map = self.inner.lock().expect("token store lock poisoned");
147        map.insert(state.token_hash, state);
148    }
149
150    async fn begin_inflight(
151        &self,
152        token_hash: &[u8; 32],
153        node_addr: SocketAddr,
154    ) -> Result<(), TokenStateError> {
155        let mut map = self.inner.lock().expect("token store lock poisoned");
156        let entry = map.get_mut(token_hash).ok_or(TokenStateError::NotFound)?;
157        match &entry.lifecycle {
158            JoinTokenLifecycle::Issued => {
159                let now_ms = epoch_ms();
160                if now_ms > entry.expires_at_ms {
161                    entry.lifecycle = JoinTokenLifecycle::Expired;
162                    return Err(TokenStateError::Expired);
163                }
164                entry.lifecycle = JoinTokenLifecycle::InFlight { node_addr };
165                entry.attempt += 1;
166                Ok(())
167            }
168            JoinTokenLifecycle::InFlight {
169                node_addr: existing,
170            } => {
171                if *existing == node_addr {
172                    // Idempotent: same joiner re-presenting (e.g. reconnect).
173                    Ok(())
174                } else {
175                    Err(TokenStateError::InFlightConflict)
176                }
177            }
178            JoinTokenLifecycle::Consumed { .. } => Err(TokenStateError::AlreadyConsumed),
179            JoinTokenLifecycle::Expired => Err(TokenStateError::Expired),
180            JoinTokenLifecycle::Aborted => Err(TokenStateError::Aborted),
181        }
182    }
183
184    async fn mark_consumed(
185        &self,
186        token_hash: &[u8; 32],
187        node_addr: SocketAddr,
188        ts_ms: u64,
189    ) -> Result<(), TokenStateError> {
190        let mut map = self.inner.lock().expect("token store lock poisoned");
191        let entry = map.get_mut(token_hash).ok_or(TokenStateError::NotFound)?;
192        match &entry.lifecycle {
193            JoinTokenLifecycle::InFlight { .. } => {
194                entry.lifecycle = JoinTokenLifecycle::Consumed { node_addr, ts_ms };
195                Ok(())
196            }
197            JoinTokenLifecycle::Consumed { .. } => Err(TokenStateError::AlreadyConsumed),
198            _ => Err(TokenStateError::InvalidTransition),
199        }
200    }
201
202    async fn revert_inflight(&self, token_hash: &[u8; 32]) -> Result<(), TokenStateError> {
203        let mut map = self.inner.lock().expect("token store lock poisoned");
204        let entry = map.get_mut(token_hash).ok_or(TokenStateError::NotFound)?;
205        match &entry.lifecycle {
206            JoinTokenLifecycle::InFlight { .. } => {
207                entry.lifecycle = JoinTokenLifecycle::Issued;
208                Ok(())
209            }
210            _ => Err(TokenStateError::InvalidTransition),
211        }
212    }
213
214    fn get(&self, token_hash: &[u8; 32]) -> Option<JoinTokenState> {
215        let map = self.inner.lock().expect("token store lock poisoned");
216        map.get(token_hash).cloned()
217    }
218}
219
220/// Spawn an async dead-man timer that reverts an `InFlight` token to
221/// `Issued` after `timeout` if it has not been consumed. This runs as a
222/// detached task; callers call it immediately after `begin_inflight`.
223pub fn spawn_inflight_timeout<B: TokenStateBackend>(
224    backend: Arc<B>,
225    token_hash: [u8; 32],
226    timeout: Duration,
227) {
228    tokio::spawn(async move {
229        tokio::time::sleep(timeout).await;
230        // If still InFlight, revert — the joiner timed out.
231        let _ = backend.revert_inflight(&token_hash).await;
232    });
233}
234
/// Current wall-clock time in milliseconds since the unix epoch.
///
/// Returns 0 if the system clock reads earlier than the epoch.
fn epoch_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
242
#[cfg(test)]
mod tests {
    use super::*;

    fn dummy_addr() -> SocketAddr {
        "127.0.0.1:9000".parse().unwrap()
    }

    /// A second joiner address, distinct from `dummy_addr()`.
    fn other_addr() -> SocketAddr {
        "127.0.0.1:9001".parse().unwrap()
    }

    fn make_state(hash: [u8; 32], expires_in_secs: u64) -> JoinTokenState {
        let expires_at_ms = epoch_ms() + expires_in_secs * 1000;
        JoinTokenState {
            token_hash: hash,
            lifecycle: JoinTokenLifecycle::Issued,
            expires_at_ms,
            attempt: 0,
        }
    }

    #[tokio::test]
    async fn issued_to_inflight_to_consumed() {
        let store = InMemoryTokenStore::new();
        let hash = [0x01u8; 32];
        store.register(make_state(hash, 60)).await;

        let addr = dummy_addr();
        store.begin_inflight(&hash, addr).await.unwrap();
        {
            let s = store.get(&hash).unwrap();
            assert_eq!(
                s.lifecycle,
                JoinTokenLifecycle::InFlight { node_addr: addr }
            );
            assert_eq!(s.attempt, 1);
        }

        let ts = epoch_ms();
        store.mark_consumed(&hash, addr, ts).await.unwrap();
        let s = store.get(&hash).unwrap();
        assert_eq!(
            s.lifecycle,
            JoinTokenLifecycle::Consumed {
                node_addr: addr,
                ts_ms: ts
            }
        );
    }

    #[tokio::test]
    async fn replay_on_consumed_token_returns_error() {
        let store = InMemoryTokenStore::new();
        let hash = [0x02u8; 32];
        store.register(make_state(hash, 60)).await;
        let addr = dummy_addr();
        store.begin_inflight(&hash, addr).await.unwrap();
        store.mark_consumed(&hash, addr, epoch_ms()).await.unwrap();

        // Second begin_inflight must be rejected.
        assert_eq!(
            store.begin_inflight(&hash, addr).await.unwrap_err(),
            TokenStateError::AlreadyConsumed
        );
    }

    #[tokio::test]
    async fn inflight_reverts_to_issued_on_timeout() {
        let store = InMemoryTokenStore::new();
        let hash = [0x03u8; 32];
        store.register(make_state(hash, 60)).await;
        let addr = dummy_addr();
        store.begin_inflight(&hash, addr).await.unwrap();
        store.revert_inflight(&hash).await.unwrap();
        let s = store.get(&hash).unwrap();
        assert_eq!(s.lifecycle, JoinTokenLifecycle::Issued);
        // Second attempt is allowed after revert.
        store.begin_inflight(&hash, addr).await.unwrap();
        let s = store.get(&hash).unwrap();
        assert_eq!(s.attempt, 2);
    }

    #[tokio::test]
    async fn expired_token_rejected() {
        let store = InMemoryTokenStore::new();
        let hash = [0x04u8; 32];
        // expires_at_ms in the past
        let state = JoinTokenState {
            token_hash: hash,
            lifecycle: JoinTokenLifecycle::Issued,
            expires_at_ms: 1, // Jan 1, 1970 — long expired
            attempt: 0,
        };
        store.register(state).await;
        assert_eq!(
            store.begin_inflight(&hash, dummy_addr()).await.unwrap_err(),
            TokenStateError::Expired
        );
        // State machine must have transitioned to Expired.
        let s = store.get(&hash).unwrap();
        assert_eq!(s.lifecycle, JoinTokenLifecycle::Expired);
    }

    #[tokio::test]
    async fn aborted_token_rejected() {
        let store = InMemoryTokenStore::new();
        let hash = [0x05u8; 32];
        let mut state = make_state(hash, 60);
        state.lifecycle = JoinTokenLifecycle::Aborted;
        store.register(state).await;
        assert_eq!(
            store.begin_inflight(&hash, dummy_addr()).await.unwrap_err(),
            TokenStateError::Aborted
        );
    }

    #[tokio::test]
    async fn inflight_same_addr_is_idempotent() {
        let store = InMemoryTokenStore::new();
        let hash = [0x06u8; 32];
        store.register(make_state(hash, 60)).await;
        let addr = dummy_addr();
        store.begin_inflight(&hash, addr).await.unwrap();
        // Same addr: idempotent
        store.begin_inflight(&hash, addr).await.unwrap();
        let s = store.get(&hash).unwrap();
        // attempt incremented only on the first begin_inflight
        assert_eq!(s.attempt, 1);
    }

    #[tokio::test]
    async fn unknown_token_is_not_found() {
        let store = InMemoryTokenStore::new();
        let hash = [0x07u8; 32];
        assert_eq!(
            store.begin_inflight(&hash, dummy_addr()).await.unwrap_err(),
            TokenStateError::NotFound
        );
        assert_eq!(
            store
                .mark_consumed(&hash, dummy_addr(), epoch_ms())
                .await
                .unwrap_err(),
            TokenStateError::NotFound
        );
        assert_eq!(
            store.revert_inflight(&hash).await.unwrap_err(),
            TokenStateError::NotFound
        );
        assert!(store.get(&hash).is_none());
    }

    #[tokio::test]
    async fn inflight_from_other_addr_conflicts() {
        let store = InMemoryTokenStore::new();
        let hash = [0x08u8; 32];
        store.register(make_state(hash, 60)).await;
        store.begin_inflight(&hash, dummy_addr()).await.unwrap();
        assert_eq!(
            store.begin_inflight(&hash, other_addr()).await.unwrap_err(),
            TokenStateError::InFlightConflict
        );
        // The first joiner keeps the in-flight slot and attempt is unchanged.
        let s = store.get(&hash).unwrap();
        assert_eq!(
            s.lifecycle,
            JoinTokenLifecycle::InFlight {
                node_addr: dummy_addr()
            }
        );
        assert_eq!(s.attempt, 1);
    }

    #[tokio::test]
    async fn consume_or_revert_without_inflight_is_invalid() {
        let store = InMemoryTokenStore::new();
        let hash = [0x09u8; 32];
        store.register(make_state(hash, 60)).await;
        assert_eq!(
            store
                .mark_consumed(&hash, dummy_addr(), epoch_ms())
                .await
                .unwrap_err(),
            TokenStateError::InvalidTransition
        );
        assert_eq!(
            store.revert_inflight(&hash).await.unwrap_err(),
            TokenStateError::InvalidTransition
        );
        // A refused transition must leave the token untouched.
        let s = store.get(&hash).unwrap();
        assert_eq!(s.lifecycle, JoinTokenLifecycle::Issued);
        assert_eq!(s.attempt, 0);
    }
}
365        let s = store.get(&hash).unwrap();
366        // attempt incremented only on the first begin_inflight
367        assert_eq!(s.attempt, 1);
368    }
369}