1use async_trait::async_trait;
18use std::cmp::Reverse;
19use std::collections::{BinaryHeap, HashMap};
20use std::hash::Hash;
21use std::sync::OnceLock;
22use tokio::sync::{mpsc, oneshot};
23use tokio::time::{Duration, Instant};
24use tokio_util::sync::CancellationToken;
25
26use crate::tokens::{SequenceHash, TokenBlockSequence};
27
28use crate::kv_router::indexer::{
29 DumpRequest, KvIndexerInterface, KvRouterError, OverlapScores, RadixTree, RouterEvent,
30 WorkerId, compute_block_hash_for_seq,
31};
32use crate::kv_router::protocols::{
33 ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData,
34 KvCacheStoredBlockData, LocalBlockHash,
35};
36
/// Request sent to the background event loop asking for overlap scores.
#[derive(Debug)]
struct MatchRequest {
    // Block hashes of the request's token prefix, in sequence order.
    sequence: Vec<LocalBlockHash>,
    // One-shot channel on which the per-worker overlap scores are returned.
    resp: oneshot::Sender<OverlapScores>,
}
44
/// A routing decision reported to the event loop so the index can pretend the
/// chosen worker now caches the request's blocks.
#[derive(Debug)]
struct RouterResult {
    // Worker the router selected for this request.
    worker_id: WorkerId,

    // Position-independent hashes of the request's token blocks.
    local_hashes: Vec<LocalBlockHash>,

    // Prefix-dependent sequence hashes for the same blocks, paired 1:1 with
    // `local_hashes` by index. (Stored as raw `u64`; callers pass
    // `SequenceHash` values — presumably a `u64` alias, TODO confirm.)
    sequence_hashes: Vec<u64>,
}
56
/// TTL-tracking key: a cached block is timed per (block hash, worker) pair so
/// the same block can expire independently on different workers.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)]
struct TimerEntry {
    // Sequence-dependent hash identifying the cached block.
    key: ExternalSequenceBlockHash,
    // Worker predicted to hold the block.
    worker: WorkerId,
}
64
/// Per-key TTL tracker: a `HashMap` holds the authoritative expiry for each
/// key, and a min-heap answers "what expires next". Re-inserting a key does
/// NOT remove its old heap entry — stale entries are skipped lazily in
/// `pop_expired`, and the heap is rebuilt once it grows past
/// `timers.len() * threshold` entries.
#[derive(Debug)]
struct TimerManager<K: Clone + Hash + Eq + Ord> {
    // Authoritative mapping: key -> its current expiry instant.
    timers: HashMap<K, Instant>,

    // Min-heap (via `Reverse`) of (expiry, key); may contain stale entries for
    // keys whose TTL was refreshed after the entry was pushed.
    expirations: BinaryHeap<Reverse<(Instant, K)>>,

    // Time-to-live applied to every inserted key.
    ttl: Duration,

    // Compaction factor: rebuild the heap when it holds more than
    // `timers.len() * threshold` entries.
    threshold: usize,
}
87
88impl<K: Clone + Hash + Eq + Ord> TimerManager<K> {
89 pub fn new(ttl: Duration, threshold: usize) -> Self {
91 TimerManager {
92 timers: HashMap::new(),
93 expirations: BinaryHeap::new(),
94 ttl,
95 threshold,
96 }
97 }
98
99 fn rebuild_heap(&mut self) {
101 self.expirations = self
102 .timers
103 .iter()
104 .map(|(key, &expiry)| Reverse((expiry, key.clone())))
105 .collect();
106 }
107
108 pub fn insert(&mut self, keys: Vec<K>) {
114 let expiry_time = Instant::now() + self.ttl;
115
116 for key in keys {
117 self.timers.insert(key.clone(), expiry_time);
119
120 self.expirations.push(Reverse((expiry_time, key)));
124 }
125
126 if self.expirations.len() > self.timers.len() * self.threshold {
128 self.rebuild_heap();
129 }
130 }
131
132 pub fn pop_expired(&mut self) -> Vec<K> {
135 let mut expired_keys = Vec::new();
136 let now = Instant::now();
137
138 while let Some(Reverse((expiry_time, _))) = self.expirations.peek() {
139 if *expiry_time > now {
141 break;
142 }
143
144 let Reverse((expiry_time, key)) = self.expirations.pop().unwrap();
146
147 if self.timers.get(&key) == Some(&expiry_time) {
148 self.timers.remove(&key);
150 expired_keys.push(key);
151 }
152 }
153
154 expired_keys
155 }
156
157 pub fn peek_next_expiry(&self) -> Option<Instant> {
159 self.expirations
160 .peek()
161 .map(|Reverse((expiry_time, _))| *expiry_time)
162 }
163}
164
/// Approximate KV-cache indexer: instead of consuming real cache events from
/// workers, it predicts cache contents from routing decisions and expires the
/// predictions after a TTL. All mutable state lives on a dedicated background
/// thread; this struct only holds channel endpoints to it.
pub struct ApproxKvIndexer {
    // Cancels the background event loop (also triggered via `shutdown`).
    cancel: CancellationToken,
    // Sends overlap-score lookup requests.
    match_tx: mpsc::Sender<MatchRequest>,
    // Reports routing decisions so predicted blocks can be recorded.
    route_tx: mpsc::Sender<RouterResult>,
    // Requests removal of all indexed state for a worker.
    remove_worker_tx: mpsc::Sender<WorkerId>,
    // Requests a dump of the radix tree as replayable events.
    dump_tx: mpsc::Sender<DumpRequest>,
    // Join handle of the background thread; `OnceLock` lets `shutdown(&mut
    // self)` take it exactly once.
    task: OnceLock<std::thread::JoinHandle<()>>,
    // Number of tokens per KV block, used when hashing token sequences.
    kv_block_size: u32,
}
181
182impl ApproxKvIndexer {
183 pub fn new(token: CancellationToken, kv_block_size: u32, ttl: Duration) -> Self {
184 let (match_tx, mut match_rx) = mpsc::channel::<MatchRequest>(2048);
185 let (route_tx, mut route_rx) = mpsc::channel::<RouterResult>(2048);
186 let (remove_worker_tx, mut remove_worker_rx) = mpsc::channel::<WorkerId>(16);
187 let (dump_tx, mut dump_rx) = mpsc::channel::<DumpRequest>(16);
188 let cancel_clone = token.clone();
189 let task = std::thread::spawn(move || {
190 let runtime = tokio::runtime::Builder::new_current_thread()
192 .enable_all()
193 .build()
194 .unwrap();
195
196 runtime.block_on(async move {
197 let mut trie = RadixTree::new();
198 let mut timer_manager: TimerManager<TimerEntry> = TimerManager::new(ttl, 50);
200 let mut event_id = 0;
201 loop {
202 let expiry_fut = if let Some(next_expiry) = timer_manager.peek_next_expiry() {
204 tokio::time::sleep_until(next_expiry)
205 } else {
206 tokio::time::sleep(Duration::MAX)
208 };
209
210 tokio::select! {
211 _ = cancel_clone.cancelled() => {
212 tracing::debug!("Approximate Indexer progress loop shutting down");
213 return;
214 }
215
216 Some(worker) = remove_worker_rx.recv() => {
217 trie.remove_worker(worker);
218 }
219
220 Some(result) = route_rx.recv() => {
221 let hashes = result.local_hashes.iter().zip(result.sequence_hashes.iter());
222
223 let stored_event = KvCacheEventData::Stored(KvCacheStoreData {
224 parent_hash: None,
225 blocks: hashes.map(|(local_hash, sequence_hash)| KvCacheStoredBlockData {
226 tokens_hash: *local_hash,
227 block_hash: ExternalSequenceBlockHash(*sequence_hash),
228 }).collect(),
229 });
230 event_id += 1;
231
232 let event = RouterEvent::new(
233 result.worker_id,
234 KvCacheEvent {
235 event_id,
236 data: stored_event,
237 }
238 );
239
240 let _ = trie.apply_event(event);
241
242 timer_manager.insert(result.sequence_hashes.iter().map(|h| TimerEntry {
243 key: ExternalSequenceBlockHash(*h),
244 worker: result.worker_id,
245 }).collect());
246 }
247
248 Some(dump_req) = dump_rx.recv() => {
249 let events = trie.dump_tree_as_events();
250 let _ = dump_req.resp.send(events);
251 }
252
253 Some(request) = match_rx.recv() => {
254 let scores = trie.find_matches(request.sequence, false);
255 request.resp.send(scores).unwrap();
256 }
257
258 _ = expiry_fut => {
259 let expired = timer_manager.pop_expired();
260
261 expired.iter().for_each(|e| {
262 event_id += 1;
263
264 let event = RouterEvent::new(
265 e.worker,
266 KvCacheEvent {
267 event_id,
268 data: KvCacheEventData::Removed(KvCacheRemoveData {
269 block_hashes: vec![e.key],
270 }),
271 }
272 );
273
274 let _ = trie.apply_event(event);
275 });
276 }
277 }
278 }
279 });
280 });
281
282 let once = OnceLock::new();
283 once.set(task).unwrap();
284
285 Self {
286 cancel: token,
287 match_tx,
288 route_tx,
289 remove_worker_tx,
290 dump_tx,
291 task: once,
292 kv_block_size,
293 }
294 }
295
296 pub fn block_size(&self) -> u32 {
297 self.kv_block_size
298 }
299
300 pub async fn process_routing_decision(
302 &self,
303 worker_id: WorkerId,
304 local_hashes: Vec<LocalBlockHash>,
305 sequence_hashes: Vec<SequenceHash>,
306 ) -> Result<(), KvRouterError> {
307 self.route_tx
308 .send(RouterResult {
309 worker_id,
310 local_hashes,
311 sequence_hashes,
312 })
313 .await
314 .map_err(|_| KvRouterError::IndexerDroppedRequest)?;
315
316 Ok(())
317 }
318
319 pub async fn process_routing_decision_for_request(
321 &self,
322 tokens: &[u32],
323 worker_id: WorkerId,
324 ) -> Result<(), KvRouterError> {
325 let local_hashes = compute_block_hash_for_seq(tokens, self.kv_block_size);
326
327 let sequence = TokenBlockSequence::new(tokens.into(), self.kv_block_size, None);
328 let sequence_hashes = sequence
329 .blocks()
330 .iter()
331 .map(|b| b.sequence_hash())
332 .collect::<Vec<_>>();
333
334 self.process_routing_decision(worker_id, local_hashes, sequence_hashes)
335 .await
336 }
337}
338
339#[async_trait]
340impl KvIndexerInterface for ApproxKvIndexer {
341 async fn find_matches(
342 &self,
343 sequence: Vec<LocalBlockHash>,
344 ) -> Result<OverlapScores, KvRouterError> {
345 let (resp_tx, resp_rx) = oneshot::channel();
346 let request = MatchRequest {
347 sequence,
348 resp: resp_tx,
349 };
350
351 if let Err(e) = self.match_tx.send(request).await {
352 tracing::error!(
353 "Failed to send match request: {:?}; the indexer maybe offline",
354 e
355 );
356 return Err(KvRouterError::IndexerOffline);
357 }
358
359 resp_rx
360 .await
361 .map_err(|_| KvRouterError::IndexerDroppedRequest)
362 }
363
364 async fn find_matches_for_request(
365 &self,
366 tokens: &[u32],
367 ) -> Result<OverlapScores, KvRouterError> {
368 let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size);
369 self.find_matches(sequence).await
370 }
371
372 async fn apply_event(&mut self, _event: RouterEvent) {
373 panic!("Approximate Indexer does not support apply_event");
374 }
375
376 async fn remove_worker(&mut self, worker: WorkerId) {
377 self.remove_worker_tx.send(worker).await.unwrap();
378 }
379
380 async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
381 let (resp_tx, resp_rx) = oneshot::channel();
382 let dump_req = DumpRequest { resp: resp_tx };
383
384 if let Err(e) = self.dump_tx.send(dump_req).await {
385 tracing::error!("Failed to send dump request: {:?}", e);
386 return Err(KvRouterError::IndexerOffline);
387 }
388
389 resp_rx
390 .await
391 .map_err(|_| KvRouterError::IndexerDroppedRequest)
392 }
393
394 fn shutdown(&mut self) {
395 self.cancel.cancel();
396 if let Some(task) = self.task.take() {
397 task.join()
398 .expect("Failed to join approximate indexer task");
399 }
400 }
401}
402
impl Drop for ApproxKvIndexer {
    fn drop(&mut self) {
        // Ensure the background thread is cancelled and joined; `shutdown` is
        // idempotent (`OnceLock::take` yields the handle at most once).
        // NOTE(review): the join blocks the dropping thread — presumably the
        // indexer is never dropped on an async runtime worker; confirm.
        self.shutdown();
    }
}
408
409#[cfg(test)]
410mod tests {
411 use super::*;
412
413 use tokio::time::{self, Duration, Instant};
414 use tokio_util::sync::CancellationToken;
415
416 const KV_BLOCK_SIZE: u32 = 4;
417
    // Test-only accessor exposing the expiry currently recorded for a key,
    // so tests can observe TTL refreshes without reaching into private state.
    impl<T: Clone + Hash + Eq + Ord> TimerManager<T> {
        pub fn get_expiry(&self, key: &T) -> Option<&Instant> {
            self.timers.get(key)
        }
    }
423
424 async fn spin_until<F, Fut>(timeout: Duration, mut predicate: F)
426 where
427 F: FnMut() -> Fut,
428 Fut: std::future::Future<Output = bool>,
429 {
430 let start = Instant::now();
431 const POLL: Duration = Duration::from_millis(1);
432 loop {
433 if predicate().await {
434 return;
435 }
436 if Instant::now().duration_since(start) >= timeout {
437 panic!("timeout waiting for condition");
438 }
439 time::sleep(POLL).await;
440 }
441 }
442
443 #[tokio::test]
445 async fn test_timer_manager_expiry() {
446 const TTL: Duration = Duration::from_millis(50);
447 let mut tm: TimerManager<u32> = TimerManager::new(TTL, 50);
448
449 tm.insert(vec![1, 2, 3]);
450 assert!(tm.get_expiry(&1).is_some());
451 assert!(tm.get_expiry(&2).is_some());
452 assert!(tm.get_expiry(&3).is_some());
453
454 time::sleep(TTL + Duration::from_millis(20)).await;
456 let expired = tm.pop_expired();
457 assert_eq!(expired.len(), 3);
458 assert!(tm.get_expiry(&1).is_none());
459 assert!(tm.get_expiry(&2).is_none());
460 assert!(tm.get_expiry(&3).is_none());
461 }
462
463 #[tokio::test]
465 async fn test_timer_manager_update_resets_ttl() {
466 const TTL: Duration = Duration::from_millis(50);
468 let mut tm: TimerManager<u32> = TimerManager::new(TTL, 50);
469
470 tm.insert(vec![42]);
472 let first_expiry = *tm
473 .get_expiry(&42)
474 .expect("expiry missing after first insert");
475
476 time::sleep(Duration::from_millis(25)).await;
478 tm.insert(vec![42]);
479 let second_expiry = *tm
480 .get_expiry(&42)
481 .expect("expiry missing after reinsertion");
482
483 assert!(second_expiry > first_expiry);
485
486 time::sleep(Duration::from_millis(30)).await; let expired = tm.pop_expired();
489 assert!(
490 expired.is_empty(),
491 "key expired prematurely despite TTL refresh"
492 );
493
494 time::sleep(Duration::from_millis(30)).await; let expired_after = tm.pop_expired();
497 assert_eq!(expired_after, vec![42]);
498 }
499
500 #[tokio::test]
505 async fn test_approx_kv_indexer_basic_flow() {
506 const TTL: Duration = Duration::from_millis(200);
507 let cancel = CancellationToken::new();
508 let indexer = ApproxKvIndexer::new(cancel.clone(), KV_BLOCK_SIZE, TTL);
509
510 let tokens: Vec<u32> = vec![1, 2, 3, 4]; let worker_id: WorkerId = 0;
512
513 let pre_scores = indexer
515 .find_matches_for_request(&tokens)
516 .await
517 .expect("indexer offline");
518 assert!(pre_scores.scores.is_empty());
519
520 indexer
522 .process_routing_decision_for_request(&tokens, worker_id)
523 .await
524 .unwrap();
525
526 spin_until(Duration::from_millis(100), || async {
528 let s = indexer.find_matches_for_request(&tokens).await.unwrap();
529 s.scores.get(&worker_id).copied() == Some(1)
530 })
531 .await;
532
533 time::sleep(TTL + Duration::from_millis(50)).await;
535 let post_scores = indexer.find_matches_for_request(&tokens).await.unwrap();
536 assert!(post_scores.scores.is_empty());
537 }
538
539 #[tokio::test]
541 async fn test_remove_worker() {
542 const TTL: Duration = Duration::from_secs(5); let cancel = CancellationToken::new();
544 let mut indexer = ApproxKvIndexer::new(cancel.clone(), KV_BLOCK_SIZE, TTL);
545
546 let tokens: Vec<u32> = vec![10, 11, 12, 13];
547 let worker_id: WorkerId = 7;
548
549 indexer
550 .process_routing_decision_for_request(&tokens, worker_id)
551 .await
552 .unwrap();
553
554 spin_until(Duration::from_millis(100), || async {
556 let s = indexer.find_matches_for_request(&tokens).await.unwrap();
557 s.scores.contains_key(&worker_id)
558 })
559 .await;
560
561 indexer.remove_worker(worker_id).await;
563
564 spin_until(Duration::from_millis(100), || async {
566 let s = indexer.find_matches_for_request(&tokens).await.unwrap();
567 !s.scores.contains_key(&worker_id)
568 })
569 .await;
570 }
571
572 #[tokio::test]
574 async fn test_remove_worker_preserves_other_workers() {
575 const TTL: Duration = Duration::from_secs(5); let cancel = CancellationToken::new();
578 let mut indexer = ApproxKvIndexer::new(cancel.clone(), KV_BLOCK_SIZE, TTL);
579
580 let tokens: Vec<u32> = vec![100, 101, 102, 103];
581 let worker_0: WorkerId = 30;
582 let worker_1: WorkerId = 31;
583
584 indexer
586 .process_routing_decision_for_request(&tokens, worker_0)
587 .await
588 .unwrap();
589 indexer
590 .process_routing_decision_for_request(&tokens, worker_1)
591 .await
592 .unwrap();
593
594 spin_until(Duration::from_millis(100), || async {
596 let s = indexer.find_matches_for_request(&tokens).await.unwrap();
597 s.scores.get(&worker_0).copied() == Some(1)
598 && s.scores.get(&worker_1).copied() == Some(1)
599 })
600 .await;
601
602 indexer.remove_worker(worker_0).await;
604
605 spin_until(Duration::from_millis(100), || async {
607 let s = indexer.find_matches_for_request(&tokens).await.unwrap();
608 !s.scores.contains_key(&worker_0) && s.scores.get(&worker_1).copied() == Some(1)
609 })
610 .await;
611 }
612
613 #[tokio::test]
615 async fn test_common_prefix_overlap() {
616 const TTL: Duration = Duration::from_secs(5);
617
618 let cancel = CancellationToken::new();
619 let indexer = ApproxKvIndexer::new(cancel.clone(), KV_BLOCK_SIZE, TTL);
620
621 let seq_a: Vec<u32> = vec![1, 2, 3, 4];
623 let worker_a: WorkerId = 11;
624
625 indexer
627 .process_routing_decision_for_request(&seq_a, worker_a)
628 .await
629 .unwrap();
630
631 spin_until(Duration::from_millis(100), || async {
633 let s = indexer.find_matches_for_request(&seq_a).await.unwrap();
634 s.scores.get(&worker_a).copied() == Some(1)
635 })
636 .await;
637
638 let seq_b: Vec<u32> = vec![1, 2, 3, 4, 5, 6, 7, 8];
640
641 let overlap = indexer.find_matches_for_request(&seq_b).await.unwrap();
643
644 assert_eq!(overlap.scores.get(&worker_a), Some(&1));
646 }
647
648 #[tokio::test]
650 async fn test_multiple_workers_same_block() {
651 const TTL: Duration = Duration::from_secs(5);
652
653 let cancel = CancellationToken::new();
654 let indexer = ApproxKvIndexer::new(cancel.clone(), KV_BLOCK_SIZE, TTL);
655
656 let tokens: Vec<u32> = vec![9, 8, 7, 6];
657 let worker_0: WorkerId = 21;
658 let worker_1: WorkerId = 22;
659
660 indexer
662 .process_routing_decision_for_request(&tokens, worker_0)
663 .await
664 .unwrap();
665 indexer
666 .process_routing_decision_for_request(&tokens, worker_1)
667 .await
668 .unwrap();
669
670 spin_until(Duration::from_millis(100), || async {
672 let s = indexer.find_matches_for_request(&tokens).await.unwrap();
673 s.scores.get(&worker_0).copied() == Some(1)
674 && s.scores.get(&worker_1).copied() == Some(1)
675 })
676 .await;
677
678 let scores = indexer.find_matches_for_request(&tokens).await.unwrap();
679
680 assert_eq!(scores.scores.get(&worker_0), Some(&1));
681 assert_eq!(scores.scores.get(&worker_1), Some(&1));
682 }
683}