1use std::collections::HashMap;
26use std::net::SocketAddr;
27use std::sync::{Arc, Mutex};
28use std::time::Duration;
29
30use async_trait::async_trait;
31
/// Where a join token currently sits in its lifecycle.
///
/// As driven by `InMemoryTokenStore`:
/// - `Issued -> InFlight` via `begin_inflight` (bumps `JoinTokenState::attempt`),
/// - `Issued -> Expired` via `begin_inflight` once `expires_at_ms` has passed,
/// - `InFlight -> Consumed` via `mark_consumed`,
/// - `InFlight -> Issued` via `revert_inflight`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum JoinTokenLifecycle {
    /// Available for a join attempt: freshly registered, or handed back by `revert_inflight`.
    Issued,
    /// A join using this token is currently in progress.
    InFlight {
        /// Address of the node that claimed the token.
        node_addr: SocketAddr,
    },
    /// The token has been used; it can never start another join.
    Consumed {
        /// Address of the node that completed the join.
        node_addr: SocketAddr,
        /// Timestamp supplied by the caller of `mark_consumed` (milliseconds, per its name).
        ts_ms: u64,
    },
    /// The token's deadline passed before a join could start.
    Expired,
    /// The token was withdrawn. Nothing in this chunk sets this state; presumably done
    /// elsewhere — verify.
    Aborted,
}
48
49#[derive(Debug, Clone)]
51pub struct JoinTokenState {
52 pub token_hash: [u8; 32],
54 pub lifecycle: JoinTokenLifecycle,
55 pub expires_at_ms: u64,
57 pub attempt: u32,
60}
61
62#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
64pub enum TokenStateError {
65 #[error("join token already consumed")]
66 AlreadyConsumed,
67 #[error("join token expired")]
68 Expired,
69 #[error("join token aborted")]
70 Aborted,
71 #[error("join token is already in-flight from a different address")]
72 InFlightConflict,
73 #[error("join token not found")]
74 NotFound,
75 #[error("unexpected lifecycle state for this transition")]
76 InvalidTransition,
77 #[error("raft proposer error: {detail}")]
80 ProposerError { detail: String },
81}
82
83pub type SharedTokenStateMirror = Arc<Mutex<HashMap<[u8; 32], JoinTokenState>>>;
88
89#[async_trait]
101pub trait TokenStateBackend: Send + Sync + 'static {
102 async fn register(&self, state: JoinTokenState);
104 async fn begin_inflight(
108 &self,
109 token_hash: &[u8; 32],
110 node_addr: SocketAddr,
111 ) -> Result<(), TokenStateError>;
112 async fn mark_consumed(
115 &self,
116 token_hash: &[u8; 32],
117 node_addr: SocketAddr,
118 ts_ms: u64,
119 ) -> Result<(), TokenStateError>;
120 async fn revert_inflight(&self, token_hash: &[u8; 32]) -> Result<(), TokenStateError>;
122 fn get(&self, token_hash: &[u8; 32]) -> Option<JoinTokenState>;
124}
125
126#[derive(Default, Clone)]
133pub struct InMemoryTokenStore {
134 inner: Arc<Mutex<HashMap<[u8; 32], JoinTokenState>>>,
135}
136
137impl InMemoryTokenStore {
138 pub fn new() -> Self {
139 Self::default()
140 }
141}
142
143#[async_trait]
144impl TokenStateBackend for InMemoryTokenStore {
145 async fn register(&self, state: JoinTokenState) {
146 let mut map = self.inner.lock().expect("token store lock poisoned");
147 map.insert(state.token_hash, state);
148 }
149
150 async fn begin_inflight(
151 &self,
152 token_hash: &[u8; 32],
153 node_addr: SocketAddr,
154 ) -> Result<(), TokenStateError> {
155 let mut map = self.inner.lock().expect("token store lock poisoned");
156 let entry = map.get_mut(token_hash).ok_or(TokenStateError::NotFound)?;
157 match &entry.lifecycle {
158 JoinTokenLifecycle::Issued => {
159 let now_ms = epoch_ms();
160 if now_ms > entry.expires_at_ms {
161 entry.lifecycle = JoinTokenLifecycle::Expired;
162 return Err(TokenStateError::Expired);
163 }
164 entry.lifecycle = JoinTokenLifecycle::InFlight { node_addr };
165 entry.attempt += 1;
166 Ok(())
167 }
168 JoinTokenLifecycle::InFlight {
169 node_addr: existing,
170 } => {
171 if *existing == node_addr {
172 Ok(())
174 } else {
175 Err(TokenStateError::InFlightConflict)
176 }
177 }
178 JoinTokenLifecycle::Consumed { .. } => Err(TokenStateError::AlreadyConsumed),
179 JoinTokenLifecycle::Expired => Err(TokenStateError::Expired),
180 JoinTokenLifecycle::Aborted => Err(TokenStateError::Aborted),
181 }
182 }
183
184 async fn mark_consumed(
185 &self,
186 token_hash: &[u8; 32],
187 node_addr: SocketAddr,
188 ts_ms: u64,
189 ) -> Result<(), TokenStateError> {
190 let mut map = self.inner.lock().expect("token store lock poisoned");
191 let entry = map.get_mut(token_hash).ok_or(TokenStateError::NotFound)?;
192 match &entry.lifecycle {
193 JoinTokenLifecycle::InFlight { .. } => {
194 entry.lifecycle = JoinTokenLifecycle::Consumed { node_addr, ts_ms };
195 Ok(())
196 }
197 JoinTokenLifecycle::Consumed { .. } => Err(TokenStateError::AlreadyConsumed),
198 _ => Err(TokenStateError::InvalidTransition),
199 }
200 }
201
202 async fn revert_inflight(&self, token_hash: &[u8; 32]) -> Result<(), TokenStateError> {
203 let mut map = self.inner.lock().expect("token store lock poisoned");
204 let entry = map.get_mut(token_hash).ok_or(TokenStateError::NotFound)?;
205 match &entry.lifecycle {
206 JoinTokenLifecycle::InFlight { .. } => {
207 entry.lifecycle = JoinTokenLifecycle::Issued;
208 Ok(())
209 }
210 _ => Err(TokenStateError::InvalidTransition),
211 }
212 }
213
214 fn get(&self, token_hash: &[u8; 32]) -> Option<JoinTokenState> {
215 let map = self.inner.lock().expect("token store lock poisoned");
216 map.get(token_hash).cloned()
217 }
218}
219
220pub fn spawn_inflight_timeout<B: TokenStateBackend>(
224 backend: Arc<B>,
225 token_hash: [u8; 32],
226 timeout: Duration,
227) {
228 tokio::spawn(async move {
229 tokio::time::sleep(timeout).await;
230 let _ = backend.revert_inflight(&token_hash).await;
232 });
233}
234
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// The value is computed without a truncating `u128 -> u64` cast and saturates at
/// `u64::MAX` instead. If the system clock reports a time before the epoch, `0` is
/// returned.
///
/// NOTE(review): `begin_inflight` expires a token when this value exceeds
/// `expires_at_ms`. A `0` from a misconfigured clock therefore makes every token look
/// unexpired (fail-open). Confirm that is acceptable.
fn epoch_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| {
            // Equivalent to `as_millis()` in range, but saturating rather than truncating.
            elapsed
                .as_secs()
                .saturating_mul(1000)
                .saturating_add(u64::from(elapsed.subsec_millis()))
        })
        .unwrap_or(0)
}
242
#[cfg(test)]
mod tests {
    use super::*;

    /// Loopback address used as the joining node throughout these tests.
    fn dummy_addr() -> SocketAddr {
        "127.0.0.1:9000".parse().unwrap()
    }

    /// An `Issued`, never-attempted token that expires `expires_in_secs` from now.
    fn make_state(hash: [u8; 32], expires_in_secs: u64) -> JoinTokenState {
        JoinTokenState {
            token_hash: hash,
            lifecycle: JoinTokenLifecycle::Issued,
            expires_at_ms: epoch_ms() + expires_in_secs * 1000,
            attempt: 0,
        }
    }

    /// A fresh store holding one 60-second token registered under `hash`.
    async fn store_with_fresh_token(hash: [u8; 32]) -> InMemoryTokenStore {
        let store = InMemoryTokenStore::new();
        store.register(make_state(hash, 60)).await;
        store
    }

    #[tokio::test]
    async fn issued_to_inflight_to_consumed() {
        let hash = [0x01u8; 32];
        let addr = dummy_addr();
        let store = store_with_fresh_token(hash).await;

        store.begin_inflight(&hash, addr).await.unwrap();
        let claimed = store.get(&hash).unwrap();
        assert_eq!(
            claimed.lifecycle,
            JoinTokenLifecycle::InFlight { node_addr: addr }
        );
        assert_eq!(claimed.attempt, 1);

        let consumed_at = epoch_ms();
        store.mark_consumed(&hash, addr, consumed_at).await.unwrap();
        assert_eq!(
            store.get(&hash).unwrap().lifecycle,
            JoinTokenLifecycle::Consumed {
                node_addr: addr,
                ts_ms: consumed_at
            }
        );
    }

    #[tokio::test]
    async fn replay_on_consumed_token_returns_error() {
        let hash = [0x02u8; 32];
        let addr = dummy_addr();
        let store = store_with_fresh_token(hash).await;
        store.begin_inflight(&hash, addr).await.unwrap();
        store.mark_consumed(&hash, addr, epoch_ms()).await.unwrap();

        // Presenting a spent token again must be refused.
        let replay = store.begin_inflight(&hash, addr).await;
        assert_eq!(replay.unwrap_err(), TokenStateError::AlreadyConsumed);
    }

    #[tokio::test]
    async fn inflight_reverts_to_issued_on_timeout() {
        let hash = [0x03u8; 32];
        let addr = dummy_addr();
        let store = store_with_fresh_token(hash).await;
        store.begin_inflight(&hash, addr).await.unwrap();

        store.revert_inflight(&hash).await.unwrap();
        assert_eq!(
            store.get(&hash).unwrap().lifecycle,
            JoinTokenLifecycle::Issued
        );

        // Claiming again after a revert counts as a fresh attempt.
        store.begin_inflight(&hash, addr).await.unwrap();
        assert_eq!(store.get(&hash).unwrap().attempt, 2);
    }

    #[tokio::test]
    async fn expired_token_rejected() {
        let hash = [0x04u8; 32];
        let store = InMemoryTokenStore::new();
        // A deadline of 1 ms after the epoch is long gone.
        store
            .register(JoinTokenState {
                token_hash: hash,
                lifecycle: JoinTokenLifecycle::Issued,
                expires_at_ms: 1,
                attempt: 0,
            })
            .await;

        let err = store.begin_inflight(&hash, dummy_addr()).await.unwrap_err();
        assert_eq!(err, TokenStateError::Expired);
        assert_eq!(
            store.get(&hash).unwrap().lifecycle,
            JoinTokenLifecycle::Expired
        );
    }

    #[tokio::test]
    async fn aborted_token_rejected() {
        let hash = [0x05u8; 32];
        let store = InMemoryTokenStore::new();
        store
            .register(JoinTokenState {
                lifecycle: JoinTokenLifecycle::Aborted,
                ..make_state(hash, 60)
            })
            .await;

        let err = store.begin_inflight(&hash, dummy_addr()).await.unwrap_err();
        assert_eq!(err, TokenStateError::Aborted);
    }

    #[tokio::test]
    async fn inflight_same_addr_is_idempotent() {
        let hash = [0x06u8; 32];
        let addr = dummy_addr();
        let store = store_with_fresh_token(hash).await;

        for _ in 0..2 {
            store.begin_inflight(&hash, addr).await.unwrap();
        }
        // The repeated claim from the same address must not count as a new attempt.
        assert_eq!(store.get(&hash).unwrap().attempt, 1);
    }
}