1use std::net::SocketAddr;
37use std::sync::Arc;
38use std::time::{SystemTime, UNIX_EPOCH};
39
40use async_trait::async_trait;
41use tracing::{error, warn};
42
43use crate::auth::token_state::{
44 JoinTokenLifecycle, JoinTokenState, SharedTokenStateMirror, TokenStateBackend, TokenStateError,
45};
46use crate::decommission::coordinator::MetadataProposer;
47use crate::metadata_group::entry::{JoinTokenTransitionKind, MetadataEntry};
48
/// Milliseconds elapsed since the Unix epoch according to the system wall clock.
///
/// Returns `0` if the clock reports a time before the epoch. The `u128 -> u64` cast
/// truncates, but only for instants hundreds of millions of years away.
fn epoch_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
55
56pub struct RaftBackedTokenStore {
64 proposer: Arc<dyn MetadataProposer>,
65 mirror: SharedTokenStateMirror,
68}
69
70impl RaftBackedTokenStore {
71 pub fn new(proposer: Arc<dyn MetadataProposer>, mirror: SharedTokenStateMirror) -> Self {
75 Self { proposer, mirror }
76 }
77
78 async fn propose(
80 &self,
81 token_hash: [u8; 32],
82 transition: JoinTokenTransitionKind,
83 ) -> Result<(), TokenStateError> {
84 let entry = MetadataEntry::JoinTokenTransition {
85 token_hash,
86 transition,
87 ts_ms: epoch_ms(),
88 };
89 self.proposer
90 .propose_and_wait(entry)
91 .await
92 .map(|_| ())
93 .map_err(|e| {
94 error!(error = %e, "RaftBackedTokenStore: proposer error");
95 TokenStateError::ProposerError {
96 detail: e.to_string(),
97 }
98 })
99 }
100
101 fn read_state(&self, token_hash: &[u8; 32]) -> Option<JoinTokenState> {
102 let map = self.mirror.lock().unwrap_or_else(|p| p.into_inner());
103 map.get(token_hash).cloned()
104 }
105}
106
107#[async_trait]
108impl TokenStateBackend for RaftBackedTokenStore {
109 async fn register(&self, state: JoinTokenState) {
112 let hash = state.token_hash;
113 let expires_at_ms = state.expires_at_ms;
114 if let Err(e) = self
115 .propose(hash, JoinTokenTransitionKind::Register { expires_at_ms })
116 .await
117 {
118 error!(
121 error = %e,
122 token_hash = ?hash,
123 "RaftBackedTokenStore: register failed — token not replicated"
124 );
125 }
126 }
127
128 async fn begin_inflight(
131 &self,
132 token_hash: &[u8; 32],
133 node_addr: SocketAddr,
134 ) -> Result<(), TokenStateError> {
135 if let Some(current) = self.read_state(token_hash) {
137 match ¤t.lifecycle {
138 JoinTokenLifecycle::Consumed { .. } => {
139 return Err(TokenStateError::AlreadyConsumed);
140 }
141 JoinTokenLifecycle::Expired => return Err(TokenStateError::Expired),
142 JoinTokenLifecycle::Aborted => return Err(TokenStateError::Aborted),
143 _ => {}
144 }
145 }
146
147 let addr_str = node_addr.to_string();
148 self.propose(
149 *token_hash,
150 JoinTokenTransitionKind::BeginInFlight {
151 node_addr: addr_str.clone(),
152 },
153 )
154 .await?;
155
156 match self.read_state(token_hash) {
160 None => Err(TokenStateError::NotFound),
161 Some(s) => match &s.lifecycle {
162 JoinTokenLifecycle::InFlight { node_addr: winner } => {
163 if *winner == node_addr {
164 Ok(())
165 } else {
166 Err(TokenStateError::InFlightConflict)
167 }
168 }
169 JoinTokenLifecycle::Consumed { .. } => Err(TokenStateError::AlreadyConsumed),
170 JoinTokenLifecycle::Expired => Err(TokenStateError::Expired),
171 JoinTokenLifecycle::Aborted => Err(TokenStateError::Aborted),
172 JoinTokenLifecycle::Issued => {
173 warn!(
176 token_hash = ?token_hash,
177 "RaftBackedTokenStore: begin_inflight propose succeeded but state remained Issued"
178 );
179 Err(TokenStateError::InvalidTransition)
180 }
181 },
182 }
183 }
184
185 async fn mark_consumed(
186 &self,
187 token_hash: &[u8; 32],
188 node_addr: SocketAddr,
189 _ts_ms: u64,
190 ) -> Result<(), TokenStateError> {
191 let addr_str = node_addr.to_string();
192 self.propose(
193 *token_hash,
194 JoinTokenTransitionKind::MarkConsumed {
195 node_addr: addr_str,
196 },
197 )
198 .await?;
199
200 match self.read_state(token_hash) {
201 None => Err(TokenStateError::NotFound),
202 Some(s) => match &s.lifecycle {
203 JoinTokenLifecycle::Consumed { .. } => Ok(()),
204 _ => Err(TokenStateError::InvalidTransition),
205 },
206 }
207 }
208
209 async fn revert_inflight(&self, token_hash: &[u8; 32]) -> Result<(), TokenStateError> {
210 self.propose(*token_hash, JoinTokenTransitionKind::RevertInFlight)
211 .await?;
212
213 match self.read_state(token_hash) {
214 None => Err(TokenStateError::NotFound),
215 Some(s) => match &s.lifecycle {
216 JoinTokenLifecycle::Issued => Ok(()),
217 _ => Err(TokenStateError::InvalidTransition),
218 },
219 }
220 }
221
222 fn get(&self, token_hash: &[u8; 32]) -> Option<JoinTokenState> {
223 self.read_state(token_hash)
224 }
225}
226pub fn apply_token_transition_to_mirror(
239 mirror: &SharedTokenStateMirror,
240 token_hash: [u8; 32],
241 transition: &JoinTokenTransitionKind,
242 ts_ms: u64,
243) {
244 let mut map = mirror.lock().unwrap_or_else(|p| p.into_inner());
245
246 match transition {
247 JoinTokenTransitionKind::Register { expires_at_ms } => {
248 map.entry(token_hash).or_insert_with(|| JoinTokenState {
250 token_hash,
251 lifecycle: JoinTokenLifecycle::Issued,
252 expires_at_ms: *expires_at_ms,
253 attempt: 0,
254 });
255 }
256
257 JoinTokenTransitionKind::BeginInFlight { node_addr } => {
258 let Ok(parsed_addr) = node_addr.parse::<SocketAddr>() else {
259 error!(
260 token_hash = ?token_hash,
261 addr = %node_addr,
262 "apply BeginInFlight: invalid address — skipping"
263 );
264 return;
265 };
266 if let Some(entry) = map.get_mut(&token_hash) {
267 match &entry.lifecycle {
268 JoinTokenLifecycle::Issued => {
269 entry.lifecycle = JoinTokenLifecycle::InFlight {
270 node_addr: parsed_addr,
271 };
272 entry.attempt += 1;
273 }
274 JoinTokenLifecycle::InFlight {
275 node_addr: existing,
276 } if *existing == parsed_addr => {
277 }
279 other => {
280 error!(
281 token_hash = ?token_hash,
282 ?other,
283 "apply BeginInFlight: unexpected lifecycle — log corruption?"
284 );
285 }
286 }
287 } else {
288 error!(
289 token_hash = ?token_hash,
290 "apply BeginInFlight: token not in mirror (Register must precede this)"
291 );
292 }
293 }
294
295 JoinTokenTransitionKind::MarkConsumed { node_addr } => {
296 let Ok(parsed_addr) = node_addr.parse::<SocketAddr>() else {
297 error!(
298 token_hash = ?token_hash,
299 addr = %node_addr,
300 "apply MarkConsumed: invalid address — skipping"
301 );
302 return;
303 };
304 if let Some(entry) = map.get_mut(&token_hash) {
305 match &entry.lifecycle {
306 JoinTokenLifecycle::InFlight { .. } => {
307 entry.lifecycle = JoinTokenLifecycle::Consumed {
308 node_addr: parsed_addr,
309 ts_ms,
310 };
311 }
312 JoinTokenLifecycle::Consumed { .. } => {
313 }
315 other => {
316 error!(
317 token_hash = ?token_hash,
318 ?other,
319 "apply MarkConsumed: unexpected lifecycle"
320 );
321 }
322 }
323 } else {
324 error!(token_hash = ?token_hash, "apply MarkConsumed: token not in mirror");
325 }
326 }
327
328 JoinTokenTransitionKind::RevertInFlight => {
329 if let Some(entry) = map.get_mut(&token_hash) {
330 match &entry.lifecycle {
331 JoinTokenLifecycle::InFlight { .. } => {
332 entry.lifecycle = JoinTokenLifecycle::Issued;
333 }
334 JoinTokenLifecycle::Issued => {
335 }
337 other => {
338 error!(
339 token_hash = ?token_hash,
340 ?other,
341 "apply RevertInFlight: unexpected lifecycle"
342 );
343 }
344 }
345 } else {
346 error!(token_hash = ?token_hash, "apply RevertInFlight: token not in mirror");
347 }
348 }
349
350 JoinTokenTransitionKind::MarkExpired => {
351 if let Some(entry) = map.get_mut(&token_hash) {
352 match &entry.lifecycle {
353 JoinTokenLifecycle::Issued | JoinTokenLifecycle::InFlight { .. } => {
354 entry.lifecycle = JoinTokenLifecycle::Expired;
355 }
356 JoinTokenLifecycle::Expired => {} other => {
358 error!(
359 token_hash = ?token_hash,
360 ?other,
361 "apply MarkExpired: unexpected lifecycle"
362 );
363 }
364 }
365 } else {
366 error!(token_hash = ?token_hash, "apply MarkExpired: token not in mirror");
367 }
368 }
369
370 JoinTokenTransitionKind::MarkAborted => {
371 if let Some(entry) = map.get_mut(&token_hash) {
372 match &entry.lifecycle {
373 JoinTokenLifecycle::Aborted => {} _ => {
375 entry.lifecycle = JoinTokenLifecycle::Aborted;
376 }
377 }
378 } else {
379 error!(token_hash = ?token_hash, "apply MarkAborted: token not in mirror");
380 }
381 }
382 }
383}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::token_state::InMemoryTokenStore;
    use std::collections::HashMap;
    use std::sync::Mutex;
    use std::sync::atomic::{AtomicU64, Ordering};

    /// In-process stand-in for the Raft proposer. Every proposal is "committed" at once by
    /// applying join-token transitions straight to `mirror`. Returns fake log indices
    /// 1, 2, 3, ...
    struct MirroringProposer {
        counter: AtomicU64,
        mirror: SharedTokenStateMirror,
    }

    impl MirroringProposer {
        fn new(mirror: SharedTokenStateMirror) -> Arc<Self> {
            let counter = AtomicU64::new(0);
            Arc::new(Self { counter, mirror })
        }
    }

    #[async_trait]
    impl MetadataProposer for MirroringProposer {
        async fn propose_and_wait(&self, entry: MetadataEntry) -> crate::error::Result<u64> {
            let index = 1 + self.counter.fetch_add(1, Ordering::SeqCst);
            if let MetadataEntry::JoinTokenTransition {
                token_hash,
                transition,
                ts_ms,
            } = entry
            {
                apply_token_transition_to_mirror(&self.mirror, token_hash, &transition, ts_ms);
            }
            Ok(index)
        }
    }

    /// A fresh, empty mirror.
    fn empty_mirror() -> SharedTokenStateMirror {
        Arc::new(Mutex::new(HashMap::new()))
    }

    /// A store whose proposer writes to, and whose reads come from, the same `mirror`.
    fn store_over(mirror: &SharedTokenStateMirror) -> RaftBackedTokenStore {
        let proposer = MirroringProposer::new(Arc::clone(mirror));
        RaftBackedTokenStore::new(proposer, Arc::clone(mirror))
    }

    fn make_store() -> (RaftBackedTokenStore, SharedTokenStateMirror) {
        let mirror = empty_mirror();
        let store = store_over(&mirror);
        (store, mirror)
    }

    fn issued_state(hash: [u8; 32], expires_at_ms: u64) -> JoinTokenState {
        JoinTokenState {
            token_hash: hash,
            lifecycle: JoinTokenLifecycle::Issued,
            expires_at_ms,
            attempt: 0,
        }
    }

    /// One hour from now.
    fn far_future_ms() -> u64 {
        epoch_ms() + 3_600_000
    }

    fn sock(text: &str) -> SocketAddr {
        text.parse().expect("test address literal must parse")
    }

    #[tokio::test]
    async fn register_and_begin_inflight_consume() {
        let (store, _mirror) = make_store();
        let hash = [0xA1u8; 32];
        let node = sock("127.0.0.1:9100");

        store.register(issued_state(hash, far_future_ms())).await;
        store.begin_inflight(&hash, node).await.unwrap();
        store.mark_consumed(&hash, node, epoch_ms()).await.unwrap();

        let state = store.get(&hash).unwrap();
        assert!(matches!(state.lifecycle, JoinTokenLifecycle::Consumed { .. }));
    }

    #[tokio::test]
    async fn replay_after_consume_rejected() {
        let (store, _mirror) = make_store();
        let hash = [0xA2u8; 32];
        let node = sock("127.0.0.1:9101");

        store.register(issued_state(hash, far_future_ms())).await;
        store.begin_inflight(&hash, node).await.unwrap();
        store.mark_consumed(&hash, node, epoch_ms()).await.unwrap();

        let replay = store.begin_inflight(&hash, node).await;
        assert_eq!(replay.unwrap_err(), TokenStateError::AlreadyConsumed);
    }

    #[tokio::test]
    async fn revert_inflight_allows_retry() {
        let (store, _mirror) = make_store();
        let hash = [0xA3u8; 32];
        let node = sock("127.0.0.1:9102");

        store.register(issued_state(hash, far_future_ms())).await;
        store.begin_inflight(&hash, node).await.unwrap();
        store.revert_inflight(&hash).await.unwrap();

        let state = store.get(&hash).unwrap();
        assert_eq!(state.lifecycle, JoinTokenLifecycle::Issued);

        // The token can be reserved again once the reservation is released.
        store.begin_inflight(&hash, node).await.unwrap();
    }

    /// Two stores sharing one mirror model two nodes applying the same replicated log.
    #[tokio::test]
    async fn cross_node_replay_rejected_via_shared_mirror() {
        let mirror = empty_mirror();
        let hash = [0xA4u8; 32];
        let node_a = sock("10.0.0.1:9000");
        let node_b = sock("10.0.0.2:9000");

        let store_a = store_over(&mirror);
        store_a.register(issued_state(hash, far_future_ms())).await;
        store_a.begin_inflight(&hash, node_a).await.unwrap();
        store_a
            .mark_consumed(&hash, node_a, epoch_ms())
            .await
            .unwrap();

        let store_b = store_over(&mirror);
        assert_eq!(
            store_b.begin_inflight(&hash, node_b).await.unwrap_err(),
            TokenStateError::AlreadyConsumed,
            "node B must reject replayed consumed token"
        );
    }

    /// A mirror rebuilt from scratch by replaying the log must still reject reuse.
    #[tokio::test]
    async fn crash_recovery_replay_rejected() {
        let hash = [0xA5u8; 32];
        let node = sock("10.0.0.1:9001");

        let before_crash = empty_mirror();
        let store = store_over(&before_crash);
        store.register(issued_state(hash, far_future_ms())).await;
        store.begin_inflight(&hash, node).await.unwrap();
        store.mark_consumed(&hash, node, epoch_ms()).await.unwrap();

        let rebuilt = empty_mirror();
        let ts = epoch_ms();
        let log = [
            JoinTokenTransitionKind::Register {
                expires_at_ms: far_future_ms(),
            },
            JoinTokenTransitionKind::BeginInFlight {
                node_addr: node.to_string(),
            },
            JoinTokenTransitionKind::MarkConsumed {
                node_addr: node.to_string(),
            },
        ];
        for transition in &log {
            apply_token_transition_to_mirror(&rebuilt, hash, transition, ts);
        }

        let restarted = store_over(&rebuilt);
        assert_eq!(
            restarted.begin_inflight(&hash, node).await.unwrap_err(),
            TokenStateError::AlreadyConsumed,
            "post-crash replay must be rejected"
        );
    }

    /// The in-memory backend still satisfies the async trait end to end.
    #[tokio::test]
    async fn in_memory_store_still_works_async() {
        let store = InMemoryTokenStore::new();
        let hash = [0xA6u8; 32];
        let node = sock("127.0.0.1:9200");

        store.register(issued_state(hash, far_future_ms())).await;
        store.begin_inflight(&hash, node).await.unwrap();
        store.mark_consumed(&hash, node, epoch_ms()).await.unwrap();

        let state = store.get(&hash).unwrap();
        assert!(matches!(state.lifecycle, JoinTokenLifecycle::Consumed { .. }));
    }
}
595
596