dynamo_llm/kv_router/
sequence.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! KV Cache Sequence Management for LLM Inference
5//!
6//! This module provides efficient management of token sequences and their associated KV cache blocks
7//! for distributed LLM inference. It implements a shared block system where multiple requests can
8//! reuse the same KV cache blocks for common token prefixes, significantly reducing memory usage.
9//!
10//! # Key Components
11//!
12//! - [`ActiveSequences`]: Single-threaded sequence manager that tracks active requests and their
13//!   token sequences, managing shared KV cache blocks efficiently.
14//!
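//! - [`ActiveSequencesMultiWorker`]: Multi-worker extension that keeps one [`ActiveSequences`]
//!   per inference worker, each owned by its own Tokio task and driven over a command channel,
//!   so requests for different workers are processed concurrently. State changes can optionally
//!   be mirrored across router replicas through published events.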
18//!
19//! # Architecture
20//!
21//! The system uses a block-based approach where token sequences are divided into fixed-size blocks.
22//! Each block is identified by a hash of its contents, allowing for deduplication when multiple
23//! requests share common prefixes (e.g., system prompts, few-shot examples).
24
25use crate::kv_router::indexer::OverlapScores;
26use crate::kv_router::indexer::WorkerId;
27use crate::tokens::SequenceHash;
28use anyhow::Result;
29use dashmap::DashMap;
30use derive_getters::Getters;
31use dynamo_runtime::component::Component;
32use dynamo_runtime::traits::DistributedRuntimeProvider;
33use dynamo_runtime::traits::events::{EventPublisher, EventSubscriber};
34use futures::StreamExt;
35use std::collections::{HashMap, HashSet};
36use std::sync::Arc;
37use std::time::Duration;
38use tokio::time::Instant;
39use uuid::Uuid;
40
41use super::protocols::{ActiveSequenceEvent, ActiveSequenceEventData};
42use crate::kv_router::ACTIVE_SEQUENCES_SUBJECT;
43use dynamo_runtime::CancellationToken;
44
/// How long a request may stay active before it is forcibly expired as stale (5 minutes).
const EXPIRY_DURATION: Duration = Duration::from_secs(5 * 60);

// TODO: use the common request_id if it exists in the repo
/// Identifier of an in-flight request.
pub type RequestId = String;
50
51/// A multi-request sequence manager that handles multiple active sequences with shared KV cache
52#[derive(Debug, Getters)]
53pub struct ActiveSequences {
54    active_seqs: HashMap<RequestId, Vec<SequenceHash>>,
55
56    prefill_tokens: HashMap<RequestId, usize>,
57
58    unique_blocks: HashMap<SequenceHash, HashSet<RequestId>>,
59
60    #[getter(copy)]
61    block_size: usize,
62
63    #[getter(copy)]
64    active_blocks: usize,
65
66    #[getter(copy)]
67    active_tokens: usize,
68
69    /// Timer for when to force expiry of stale requests
70    expiry_timer: Instant,
71
72    /// Set of request IDs to check for expiry
73    expiry_requests: HashSet<RequestId>,
74}
75
76impl ActiveSequences {
77    /// Create a new SharedSequenceManager instance
78    pub fn new(block_size: usize) -> Self {
79        // TODO: make this not a hard req
80        assert!(block_size > 1, "block_size must be greater than 1");
81
82        Self {
83            active_seqs: HashMap::new(),
84            prefill_tokens: HashMap::new(),
85            unique_blocks: HashMap::new(),
86            block_size,
87            active_blocks: 0,
88            active_tokens: 0,
89            expiry_timer: Instant::now() + EXPIRY_DURATION,
90            expiry_requests: HashSet::new(),
91        }
92    }
93
94    fn add_block(&mut self, request_id: RequestId, block: &SequenceHash) {
95        let is_new_block = !self.unique_blocks.contains_key(block);
96
97        self.unique_blocks
98            .entry(*block)
99            .or_default()
100            .insert(request_id.clone());
101
102        if is_new_block {
103            self.active_blocks += 1;
104        }
105    }
106
107    fn remove_block(&mut self, request_id: &RequestId, block: &SequenceHash) {
108        let Some(request_ids) = self.unique_blocks.get_mut(block) else {
109            return;
110        };
111
112        // Remove the unique block if no more requests using it
113        request_ids.retain(|w| w != request_id);
114        if request_ids.is_empty() {
115            self.active_blocks -= 1;
116            self.unique_blocks.remove(block);
117        }
118    }
119
120    /// Add a new request with its initial tokens
121    /// Returns the set of expired request IDs that were removed during cleanup
122    pub fn add_request(
123        &mut self,
124        request_id: RequestId,
125        token_sequence: Option<Vec<SequenceHash>>,
126        isl: usize,
127        overlap: u32,
128    ) -> HashSet<RequestId> {
129        // Check for double-add and panic early
130        if self.active_seqs.contains_key(&request_id) {
131            panic!("Request {request_id} is already active. Cannot accept double-add.");
132        }
133
134        // Lazily check and clean up expired requests, capturing removed IDs
135        let removed_requests = self.force_expiry();
136
137        let prefill_tokens = self.new_tokens(isl, overlap);
138        self.prefill_tokens
139            .insert(request_id.clone(), prefill_tokens);
140        self.active_tokens += prefill_tokens;
141
142        if let Some(sequence) = token_sequence {
143            for block in &sequence {
144                self.add_block(request_id.clone(), block);
145            }
146            self.active_seqs.insert(request_id.clone(), sequence);
147        } else {
148            // dummy empty sequence
149            self.active_seqs.insert(request_id.clone(), Vec::new());
150        }
151
152        removed_requests
153    }
154
155    /// Mark prefill as completed for a request, removing it from prefill_tokens tracking
156    pub fn mark_prefill_completed(&mut self, request_id: &RequestId) {
157        if let Some(tokens) = self.prefill_tokens.remove(request_id) {
158            self.active_tokens = self
159                .active_tokens
160                .checked_sub(tokens)
161                .expect("active_tokens underflow");
162        }
163    }
164
165    pub fn new_tokens(&self, isl: usize, overlap: u32) -> usize {
166        isl.checked_sub((overlap as usize) * self.block_size)
167            .unwrap_or_else(|| panic!("prefill_tokens < 0 with overlap {overlap} and ISL {isl}"))
168    }
169
170    pub fn potential_blocks_and_tokens(
171        &self,
172        token_sequence: Option<&[SequenceHash]>,
173        isl: usize,
174        overlap: u32,
175    ) -> (usize, usize) {
176        let potential_blocks = if let Some(token_seq) = token_sequence {
177            self.new_blocks(token_seq) + self.active_blocks
178        } else {
179            self.active_blocks
180        };
181        let potential_tokens = self.new_tokens(isl, overlap) + self.active_tokens;
182        (potential_blocks, potential_tokens)
183    }
184
185    /// Match a request against existing blocks and return the number of new blocks that would be added
186    pub fn new_blocks(&self, token_sequence: &[SequenceHash]) -> usize {
187        token_sequence
188            .iter()
189            .filter(|block| !self.unique_blocks.contains_key(block))
190            .count()
191    }
192
193    /// Return the total number of blocks that would be used if the token sequence was added
194    /// This is the sum of new blocks that would be added plus the current active blocks
195    pub fn potential_blocks(&self, token_sequence: &[SequenceHash]) -> usize {
196        self.new_blocks(token_sequence) + self.active_blocks
197    }
198
199    /// Free all blocks associated with a request
200    pub fn free(&mut self, request_id: &RequestId) -> usize {
201        self.mark_prefill_completed(request_id);
202
203        self.expiry_requests.remove(request_id);
204
205        // Remove from active_seqs and get the token sequence
206        let token_seq = match self.active_seqs.remove(request_id) {
207            Some(seq) => seq,
208            None => {
209                tracing::warn!("Trying to free non-existent request {request_id}");
210                return self.active_blocks;
211            }
212        };
213
214        for block in token_seq {
215            self.remove_block(request_id, &block)
216        }
217
218        self.active_blocks
219    }
220
221    /// Force expiry of stale requests if the timer has elapsed
222    /// Returns the set of expired request IDs that were removed
223    pub fn force_expiry(&mut self) -> HashSet<RequestId> {
224        let now = Instant::now();
225
226        // Early return if timer hasn't expired yet
227        if now < self.expiry_timer {
228            return HashSet::new();
229        }
230
231        // Process expired requests - drain to avoid clone
232        let expired_requests: HashSet<RequestId> = self.expiry_requests.drain().collect();
233        for request_id in &expired_requests {
234            tracing::warn!("Force expiring stale request: {}", request_id);
235            self.free(request_id);
236        }
237
238        self.expiry_timer = now + EXPIRY_DURATION;
239        self.expiry_requests = self.active_seqs.keys().cloned().collect();
240
241        expired_requests
242    }
243}
244
245enum UpdateSequences {
246    AddRequest {
247        request_id: RequestId,
248        token_sequence: Option<Vec<SequenceHash>>,
249        isl: usize,
250        overlap: u32,
251        resp_tx: tokio::sync::oneshot::Sender<HashSet<RequestId>>,
252    },
253    Free {
254        request_id: RequestId,
255    },
256    MarkPrefillCompleted {
257        request_id: RequestId,
258    },
259    NewBlocks {
260        token_sequence: Arc<Vec<SequenceHash>>,
261        resp_tx: tokio::sync::oneshot::Sender<usize>,
262    },
263    PotentialBlocks {
264        token_sequence: Arc<Vec<SequenceHash>>,
265        resp_tx: tokio::sync::oneshot::Sender<usize>,
266    },
267    PotentialBlocksAndTokens {
268        token_sequence: Option<Arc<Vec<SequenceHash>>>,
269        isl: usize,
270        overlap: u32,
271        resp_tx: tokio::sync::oneshot::Sender<(usize, usize)>,
272    },
273    ActiveBlocks {
274        resp_tx: tokio::sync::oneshot::Sender<usize>,
275    },
276    ActiveTokens {
277        resp_tx: tokio::sync::oneshot::Sender<usize>,
278    },
279    Shutdown,
280}
281
/// Multi-worker extension of [`ActiveSequences`]: keeps one instance per worker, each owned by a
/// dedicated Tokio task that receives [`UpdateSequences`] commands over a channel.
pub struct ActiveSequencesMultiWorker {
    /// Command channel into each worker's task, keyed by worker ID.
    senders: Arc<DashMap<WorkerId, tokio::sync::mpsc::UnboundedSender<UpdateSequences>>>,
    /// Worker each tracked request was assigned to.
    request_to_worker: Arc<DashMap<RequestId, WorkerId>>,
    /// Join handles of the per-worker tasks; aborted when a worker is removed.
    handles: Arc<DashMap<WorkerId, tokio::task::JoinHandle<()>>>,
    /// Block size handed to every worker's `ActiveSequences`, including workers added later.
    block_size: usize,
    /// Used to publish replica events and to derive child cancellation tokens for tasks.
    component: Component,
    /// Identity of this router; received events carrying it are ignored as self-echoes.
    router_id: Uuid,
    /// Whether state changes are published to (and received from) other router replicas.
    replica_sync: bool,
}
292
293impl ActiveSequencesMultiWorker {
294    pub fn new(
295        component: Component,
296        block_size: usize,
297        worker_ids: Vec<WorkerId>,
298        replica_sync: bool,
299        router_uuid: String,
300    ) -> Self {
301        assert!(block_size > 1, "block_size must be greater than 1");
302
303        let senders = Arc::new(DashMap::new());
304        let handles = Arc::new(DashMap::new());
305        let request_to_worker = Arc::new(DashMap::new());
306        let router_id = Uuid::parse_str(&router_uuid).unwrap_or_else(|e| {
307            tracing::warn!(
308                "Failed to parse router UUID '{}': {}, using new UUID",
309                router_uuid,
310                e
311            );
312            Uuid::new_v4()
313        });
314
315        for worker_id in worker_ids {
316            // Create a child cancellation token from the component's runtime
317            let cancel_token = component.drt().runtime().child_token();
318            let (sender, handle) = Self::start_worker(block_size, cancel_token);
319            senders.insert(worker_id, sender);
320            handles.insert(worker_id, handle);
321        }
322
323        let multi_worker = Self {
324            senders: senders.clone(),
325            request_to_worker: request_to_worker.clone(),
326            handles,
327            block_size,
328            component: component.clone(),
329            router_id,
330            replica_sync,
331        };
332
333        // Start the subscription loop only if replica_sync is enabled
334        if replica_sync {
335            let senders_clone = senders.clone();
336            let request_to_worker_clone = request_to_worker.clone();
337            let component_clone = component.clone();
338            let router_id_clone = router_id;
339            let cancel_token = component.drt().runtime().child_token();
340
341            tokio::spawn(async move {
342                // NATS subscription loop
343                if let Err(e) = Self::subscribe_to_events(
344                    senders_clone,
345                    request_to_worker_clone,
346                    component_clone,
347                    router_id_clone,
348                    cancel_token,
349                )
350                .await
351                {
352                    tracing::error!("Error in active sequences events subscription: {}", e);
353                }
354            });
355        }
356
357        multi_worker
358    }
359
360    /// Helper method to start a worker task
361    fn start_worker(
362        block_size: usize,
363        cancel_token: CancellationToken, // Add cancellation token parameter
364    ) -> (
365        tokio::sync::mpsc::UnboundedSender<UpdateSequences>,
366        tokio::task::JoinHandle<()>,
367    ) {
368        let (request_tx, mut request_rx) = tokio::sync::mpsc::unbounded_channel();
369
370        let handle = tokio::spawn(async move {
371            let mut active_sequences = ActiveSequences::new(block_size);
372
373            loop {
374                tokio::select! {
375                    // Handle incoming commands
376                    command = request_rx.recv() => {
377                        match command {
378                            Some(command) => {
379                                match command {
380                                    UpdateSequences::AddRequest {
381                                        request_id,
382                                        token_sequence,
383                                        isl,
384                                        overlap,
385                                        resp_tx,
386                                    } => {
387                                        let removed = active_sequences.add_request(request_id, token_sequence, isl, overlap);
388                                        let _ = resp_tx.send(removed);
389                                    }
390                                    UpdateSequences::Free { request_id } => {
391                                        active_sequences.free(&request_id);
392                                    }
393                                    UpdateSequences::MarkPrefillCompleted { request_id } => {
394                                        active_sequences.mark_prefill_completed(&request_id);
395                                    }
396                                    UpdateSequences::NewBlocks {
397                                        token_sequence,
398                                        resp_tx,
399                                    } => {
400                                        let new_blocks = active_sequences.new_blocks(&token_sequence);
401                                        let _ = resp_tx.send(new_blocks);
402                                    }
403                                    UpdateSequences::PotentialBlocks {
404                                        token_sequence,
405                                        resp_tx,
406                                    } => {
407                                        let potential_blocks = active_sequences.potential_blocks(&token_sequence);
408                                        let _ = resp_tx.send(potential_blocks);
409                                    }
410                                    UpdateSequences::PotentialBlocksAndTokens {
411                                        token_sequence,
412                                        isl,
413                                        overlap,
414                                        resp_tx,
415                                    } => {
416                                        let potential_tokens = active_sequences.potential_blocks_and_tokens(
417                                            token_sequence.as_ref().map(|v| v.as_slice()),
418                                            isl,
419                                            overlap,
420                                        );
421                                        let _ = resp_tx.send(potential_tokens);
422                                    }
423                                    UpdateSequences::ActiveBlocks { resp_tx } => {
424                                        let active_blocks = active_sequences.active_blocks();
425                                        let _ = resp_tx.send(active_blocks);
426                                    }
427                                    UpdateSequences::ActiveTokens { resp_tx } => {
428                                        let active_tokens = active_sequences.active_tokens();
429                                        let _ = resp_tx.send(active_tokens);
430                                    }
431                                    UpdateSequences::Shutdown => {
432                                        break;
433                                    }
434                                }
435                            }
436                            None => {
437                                // Channel closed, exit
438                                break;
439                            }
440                        }
441                    }
442                    // Handle cancellation
443                    _ = cancel_token.cancelled() => {
444                        tracing::debug!("Worker task cancelled");
445                        break;
446                    }
447                }
448            }
449        });
450
451        (request_tx, handle)
452    }
453
454    /// Background task to subscribe to active sequence events and update all workers
455    async fn subscribe_to_events(
456        senders: Arc<DashMap<WorkerId, tokio::sync::mpsc::UnboundedSender<UpdateSequences>>>,
457        request_to_worker: Arc<DashMap<RequestId, WorkerId>>,
458        component: Component,
459        router_id: Uuid,
460        cancel_token: CancellationToken,
461    ) -> Result<()> {
462        let mut subscriber = component
463            .subscribe_with_type::<ActiveSequenceEvent>(ACTIVE_SEQUENCES_SUBJECT)
464            .await?;
465
466        loop {
467            tokio::select! {
468                // Handle incoming events
469                result = subscriber.next() => {
470                    let Some(result) = result else {
471                        // Stream ended
472                        break;
473                    };
474
475                    let Ok(event) = result else {
476                        tracing::error!(
477                            "Error receiving active sequence event: {}",
478                            result.unwrap_err()
479                        );
480                        continue;
481                    };
482
483                    // Skip events emitted by itself
484                    if event.router_id == router_id {
485                        continue;
486                    }
487
488                    match &event.data {
489                        ActiveSequenceEventData::AddRequest {
490                            token_sequence,
491                            isl,
492                            overlap,
493                        } => {
494                            request_to_worker.insert(event.request_id.clone(), event.worker_id);
495
496                            if let Some(sender) = senders.get(&event.worker_id) {
497                                // For replicated events, we create a dummy response channel since we don't need to handle expired requests
498                                let (resp_tx, _) = tokio::sync::oneshot::channel();
499                                let _ = sender.send(UpdateSequences::AddRequest {
500                                    request_id: event.request_id.clone(),
501                                    token_sequence: token_sequence.clone(),
502                                    isl: *isl,
503                                    overlap: *overlap,
504                                    resp_tx,
505                                });
506                            } else {
507                                tracing::warn!(
508                                    "Worker {} not found, cannot process AddRequest",
509                                    event.worker_id
510                                );
511                            }
512                        }
513                        ActiveSequenceEventData::Free => {
514                            if let Some((_, worker_id)) = request_to_worker.remove(&event.request_id)
515                                && let Some(sender) = senders.get(&worker_id)
516                            {
517                                let _ = sender.send(UpdateSequences::Free {
518                                    request_id: event.request_id.clone(),
519                                });
520                            }
521                        }
522                        ActiveSequenceEventData::MarkPrefillCompleted => {
523                            if let Some(worker_id) = request_to_worker.get(&event.request_id)
524                                && let Some(sender) = senders.get(&*worker_id)
525                            {
526                                let _ = sender.send(UpdateSequences::MarkPrefillCompleted {
527                                    request_id: event.request_id.clone(),
528                                });
529                            }
530                        }
531                    }
532                }
533                // Handle cancellation
534                _ = cancel_token.cancelled() => {
535                    tracing::debug!("Subscription task cancelled");
536                    break;
537                }
538            }
539        }
540
541        Ok(())
542    }
543
544    /// Update the set of workers, adding and removing as needed
545    pub fn update_workers(&self, new_worker_ids: Vec<WorkerId>) {
546        let current_workers: HashSet<WorkerId> =
547            self.senders.iter().map(|entry| *entry.key()).collect();
548        let new_workers: HashSet<WorkerId> = new_worker_ids.into_iter().collect();
549
550        let workers_to_remove: Vec<WorkerId> =
551            current_workers.difference(&new_workers).copied().collect();
552        let workers_to_add: Vec<WorkerId> =
553            new_workers.difference(&current_workers).copied().collect();
554
555        // Remove workers
556        for worker_id in &workers_to_remove {
557            tracing::warn!("Removing worker {}", worker_id);
558
559            // Send shutdown command to the worker
560            if let Some((_, sender)) = self.senders.remove(worker_id) {
561                let _ = sender.send(UpdateSequences::Shutdown);
562            }
563            if let Some((_, handle)) = self.handles.remove(worker_id) {
564                handle.abort();
565            }
566
567            // Clean up request_to_worker mappings for this worker
568            self.request_to_worker
569                .retain(|_request_id, mapped_worker_id| *mapped_worker_id != *worker_id);
570        }
571
572        // Add new workers
573        for worker_id in &workers_to_add {
574            tracing::warn!("Adding worker {}", worker_id);
575
576            let (sender, handle) = Self::start_worker(
577                self.block_size,
578                self.component.drt().runtime().child_token(),
579            );
580            self.senders.insert(*worker_id, sender);
581            self.handles.insert(*worker_id, handle);
582        }
583    }
584
585    pub async fn add_request(
586        &self,
587        request_id: RequestId,
588        token_sequence: Option<Vec<SequenceHash>>,
589        isl: usize,
590        overlap: u32,
591        worker_id: WorkerId,
592    ) -> Result<()> {
593        if !self.senders.contains_key(&worker_id) {
594            return Err(anyhow::anyhow!("Worker ID {worker_id} not found"));
595        }
596
597        // Create response channel
598        let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
599
600        // Publish event only if replica_sync is enabled
601        if self.replica_sync {
602            let event = ActiveSequenceEvent {
603                request_id: request_id.clone(),
604                worker_id,
605                data: ActiveSequenceEventData::AddRequest {
606                    token_sequence: token_sequence.clone(),
607                    isl,
608                    overlap,
609                },
610                router_id: self.router_id,
611            };
612            self.component
613                .publish(ACTIVE_SEQUENCES_SUBJECT, &event)
614                .await?;
615        }
616
617        // Update local state
618        self.request_to_worker.insert(request_id.clone(), worker_id);
619
620        self.senders
621            .get(&worker_id)
622            .unwrap()
623            .send(UpdateSequences::AddRequest {
624                request_id,
625                token_sequence,
626                isl,
627                overlap,
628                resp_tx,
629            })
630            .map_err(|_| anyhow::anyhow!("Failed to send add_request command to worker"))?;
631
632        // Wait for response and handle removed requests
633        let removed_requests = resp_rx
634            .await
635            .map_err(|_| anyhow::anyhow!("Failed to receive response from worker"))?;
636
637        // Remove expired requests from request_to_worker mapping
638        for expired_id in &removed_requests {
639            self.request_to_worker.remove(expired_id);
640        }
641
642        Ok(())
643    }
644
645    pub async fn free(&self, request_id: &RequestId) -> Result<()> {
646        let worker_id = self
647            .request_to_worker
648            .get(request_id)
649            .map(|entry| *entry)
650            .ok_or_else(|| anyhow::anyhow!("Request ID not found in request_to_worker mapping"))?;
651
652        // Publish event only if replica_sync is enabled
653        if self.replica_sync {
654            let event = ActiveSequenceEvent {
655                request_id: request_id.clone(),
656                worker_id,
657                data: ActiveSequenceEventData::Free,
658                router_id: self.router_id,
659            };
660            self.component
661                .publish(ACTIVE_SEQUENCES_SUBJECT, &event)
662                .await?;
663        }
664
665        // Update local state
666        self.senders
667            .get(&worker_id)
668            .unwrap()
669            .send(UpdateSequences::Free {
670                request_id: request_id.clone(),
671            })
672            .map_err(|_| anyhow::anyhow!("Failed to send free command to worker"))?;
673
674        self.request_to_worker.remove(request_id);
675
676        Ok(())
677    }
678
679    /// Mark prefill as completed for a request
680    pub async fn mark_prefill_completed(&self, request_id: &RequestId) -> Result<()> {
681        let worker_id = self
682            .request_to_worker
683            .get(request_id)
684            .map(|entry| *entry)
685            .ok_or_else(|| anyhow::anyhow!("Request ID not found in request_to_worker mapping"))?;
686
687        // Publish event only if replica_sync is enabled
688        if self.replica_sync {
689            let event = ActiveSequenceEvent {
690                request_id: request_id.clone(),
691                worker_id,
692                data: ActiveSequenceEventData::MarkPrefillCompleted,
693                router_id: self.router_id,
694            };
695            self.component
696                .publish(ACTIVE_SEQUENCES_SUBJECT, &event)
697                .await?;
698        }
699
700        // Update local state
701        self.senders
702            .get(&worker_id)
703            .unwrap()
704            .send(UpdateSequences::MarkPrefillCompleted {
705                request_id: request_id.clone(),
706            })
707            .map_err(|_| {
708                anyhow::anyhow!("Failed to send mark_prefill_completed command to worker")
709            })?;
710
711        Ok(())
712    }
713
714    /// Get the number of workers
715    pub fn num_workers(&self) -> usize {
716        self.senders.len()
717    }
718
719    /// Generic method to query all workers with a given command
720    async fn query_workers<T: Send + 'static>(
721        &self,
722        token_sequence: Option<Vec<SequenceHash>>,
723        command_fn: impl Fn(
724            Option<Arc<Vec<SequenceHash>>>,
725            tokio::sync::oneshot::Sender<T>,
726        ) -> UpdateSequences,
727    ) -> HashMap<WorkerId, T> {
728        let mut results = HashMap::new();
729        let token_sequence_shared = token_sequence.map(Arc::new);
730        let mut receivers = Vec::new();
731
732        // Send queries to all workers in parallel
733        for entry in self.senders.iter() {
734            let worker_id = *entry.key();
735            let sender = entry.value();
736            let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
737            receivers.push((worker_id, resp_rx));
738            if let Err(e) = sender.send(command_fn(token_sequence_shared.clone(), resp_tx)) {
739                tracing::error!("Failed to send command to worker {}: {}", worker_id, e);
740            }
741        }
742
743        // Collect results from all workers
744        for (worker_id, receiver) in receivers {
745            match tokio::time::timeout(tokio::time::Duration::from_secs(1), receiver).await {
746                Ok(Ok(result)) => {
747                    results.insert(worker_id, result);
748                }
749                Ok(Err(_)) => {
750                    tracing::error!("Worker {} dropped response channel", worker_id);
751                }
752                Err(_) => {
753                    tracing::error!("Timeout waiting for response from worker {}", worker_id);
754                }
755            }
756        }
757
758        results
759    }
760
761    /// Query all workers for the number of new blocks that would be added by a token sequence
762    pub async fn new_blocks(&self, token_sequence: Vec<SequenceHash>) -> HashMap<WorkerId, usize> {
763        self.query_workers(Some(token_sequence), |ts, resp_tx| match ts {
764            Some(ts) => UpdateSequences::NewBlocks {
765                token_sequence: ts,
766                resp_tx,
767            },
768            None => unreachable!("token_sequence should always be Some for new_blocks"),
769        })
770        .await
771    }
772
773    /// Query all workers for the total number of blocks (new + active) that would be used by a token sequence
774    pub async fn potential_blocks(
775        &self,
776        token_sequence: Vec<SequenceHash>,
777    ) -> HashMap<WorkerId, usize> {
778        self.query_workers(Some(token_sequence), |ts, resp_tx| match ts {
779            Some(ts) => UpdateSequences::PotentialBlocks {
780                token_sequence: ts,
781                resp_tx,
782            },
783            None => unreachable!("token_sequence should always be Some for potential_blocks"),
784        })
785        .await
786    }
787
788    /// Query all workers for the potential tokens (new + active) that would be used by a token sequence with overlap
789    pub async fn potential_blocks_and_tokens(
790        &self,
791        token_sequence: Option<Vec<SequenceHash>>,
792        isl: usize,
793        overlaps: OverlapScores,
794    ) -> (HashMap<WorkerId, usize>, HashMap<WorkerId, usize>) {
795        let mut potential_blocks = HashMap::new();
796        let mut potential_tokens = HashMap::new();
797        let token_sequence_shared = token_sequence.map(Arc::new);
798        let mut receivers = Vec::new();
799
800        // Send queries to all workers in parallel
801        for entry in self.senders.iter() {
802            let worker_id = *entry.key();
803            let sender = entry.value();
804            let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
805            receivers.push((worker_id, resp_rx));
806
807            if let Err(e) = sender.send(UpdateSequences::PotentialBlocksAndTokens {
808                token_sequence: token_sequence_shared.clone(),
809                isl,
810                overlap: overlaps.scores.get(&worker_id).copied().unwrap_or(0),
811                resp_tx,
812            }) {
813                tracing::error!(
814                    "Failed to send potential_tokens command to worker {}: {}",
815                    worker_id,
816                    e
817                );
818            }
819        }
820
821        // Collect results from all workers
822        for (worker_id, receiver) in receivers {
823            match tokio::time::timeout(tokio::time::Duration::from_secs(1), receiver).await {
824                Ok(Ok((blocks, tokens))) => {
825                    potential_blocks.insert(worker_id, blocks);
826                    potential_tokens.insert(worker_id, tokens);
827                }
828                Ok(Err(_)) => {
829                    tracing::error!("Worker {} dropped response channel", worker_id);
830                }
831                Err(_) => {
832                    tracing::error!("Timeout waiting for response from worker {}", worker_id);
833                }
834            }
835        }
836
837        (potential_blocks, potential_tokens)
838    }
839
840    /// Query all workers for their current number of active blocks
841    pub async fn active_blocks(&self) -> HashMap<WorkerId, usize> {
842        self.query_workers(None, |_, resp_tx| UpdateSequences::ActiveBlocks { resp_tx })
843            .await
844    }
845
846    /// Query all workers for their current number of active tokens
847    pub async fn active_tokens(&self) -> HashMap<WorkerId, usize> {
848        self.query_workers(None, |_, resp_tx| UpdateSequences::ActiveTokens { resp_tx })
849            .await
850    }
851}
852
853impl Drop for ActiveSequencesMultiWorker {
854    fn drop(&mut self) {
855        // Send shutdown to all workers
856        for entry in self.senders.iter() {
857            let _ = entry.value().send(UpdateSequences::Shutdown);
858        }
859
860        // Abort all tasks
861        for entry in self.handles.iter() {
862            entry.value().abort();
863        }
864    }
865}
866
// Integration tests for cross-instance synchronization of `ActiveSequencesMultiWorker`.
// Both tests are `#[ignore]`d because they build a `DistributedRuntime` from settings, which
// presumably requires the external infrastructure it connects to (TODO confirm); run them with
// `cargo test -- --ignored`. They rely on fixed sleeps for event propagation, so they are
// timing-sensitive.
#[cfg(test)]
mod tests {
    use super::*;
    use dynamo_runtime::{DistributedRuntime, Runtime};
    use std::sync::Arc;

    /// Two managers that share one component must observe each other's requests.
    ///
    /// Phase 1 adds requests with explicit block-hash sequences through both managers and checks
    /// that `seq_manager_1` reports the blocks/tokens added via `seq_manager_2`. Phase 2 frees
    /// each request through the *opposite* manager and checks that `seq_manager_2` sees every
    /// worker drop to zero.
    #[tokio::test]
    #[ignore]
    async fn test_multi_worker_cross_instance_sync() -> Result<()> {
        // Initialize logging (called by each test)
        dynamo_runtime::logging::init();

        let block_size = 4; // arbitrary block size

        // Create runtime and distributed runtime
        let runtime = Runtime::from_current()?;
        let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;

        // Create namespace and shared component for both seq_managers
        let namespace = distributed.namespace("test_cross_instance_sync")?;
        let component = namespace
            .component("sequences")?
            .service_builder()
            .create()
            .await?;

        // Create multi-worker sequence managers with ALL workers [0, 1, 2]
        // Both use the same component to ensure event synchronization works.
        // Each manager gets its own random UUID string; NOTE(review): presumably this identifies
        // the instance so it can tell its own events from its peer's, and `true` presumably
        // enables that sync. Confirm against `ActiveSequencesMultiWorker::new`.
        let worker_ids = vec![0, 1, 2];
        let seq_manager_1 = Arc::new(ActiveSequencesMultiWorker::new(
            component.clone(),
            block_size,
            worker_ids.clone(),
            true,
            Uuid::new_v4().to_string(),
        ));
        let seq_manager_2 = Arc::new(ActiveSequencesMultiWorker::new(
            component,
            block_size,
            worker_ids,
            true,
            Uuid::new_v4().to_string(),
        ));

        // Give some time for the subscription loops to start
        tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;

        // PHASE 1: Add requests using both seq_manager_1 and seq_manager_2

        // Add request_0 to worker 0: sequence [0, 1, 2]
        seq_manager_1
            .add_request(
                "request_0".to_string(),
                Some(vec![0, 1, 2]),
                12, // ISL (3 blocks * 4 block_size)
                0,  // no overlap
                0,  // worker_id
            )
            .await?;

        // Add request_1 to worker 1: sequence [3, 4]
        seq_manager_1
            .add_request(
                "request_1".to_string(),
                Some(vec![3, 4]),
                8, // ISL (2 blocks * 4 block_size)
                0, // no overlap
                1, // worker_id
            )
            .await?;

        // Add request_2 to worker 2: sequence [0, 1, 2, 3] using seq_manager_2
        seq_manager_2
            .add_request(
                "request_2".to_string(),
                Some(vec![0, 1, 2, 3]),
                16, // ISL (4 blocks * 4 block_size)
                0,  // no overlap
                2,  // worker_id
            )
            .await?;

        // Give some time for synchronization
        tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;

        // Query seq_manager_1 to verify it sees all requests including request_2 from seq_manager_2
        let blocks_phase1 = seq_manager_1.active_blocks().await;
        let tokens_phase1 = seq_manager_1.active_tokens().await;

        // Verify that seq_manager_1 sees all requests, including request_2 added via seq_manager_2.
        // Block counts are per worker: request_2 on worker 2 is expected to report all 4 blocks
        // even though its prefix [0, 1, 2] matches request_0 on worker 0.
        assert_eq!(
            blocks_phase1[&0], 3,
            "Worker 0 should have 3 active blocks (from request_0)"
        );
        assert_eq!(
            blocks_phase1[&1], 2,
            "Worker 1 should have 2 active blocks (from request_1)"
        );
        assert_eq!(
            blocks_phase1[&2], 4,
            "Worker 2 should have 4 active blocks (from request_2 added by seq_manager_2)"
        );
        assert_eq!(
            tokens_phase1[&0], 12,
            "Worker 0 should have 12 active tokens"
        );
        assert_eq!(tokens_phase1[&1], 8, "Worker 1 should have 8 active tokens");
        assert_eq!(
            tokens_phase1[&2], 16,
            "Worker 2 should have 16 active tokens (from request_2 added by seq_manager_2)"
        );

        // PHASE 2: Free requests using opposite sequence managers, verify on seq_manager_2

        // Free request_2 (which was added by seq_manager_2) using seq_manager_1
        seq_manager_1.free(&"request_2".to_string()).await?;

        // Free request_0 and request_1 (which were added by seq_manager_1) using seq_manager_2
        seq_manager_2.free(&"request_0".to_string()).await?;
        seq_manager_2.free(&"request_1".to_string()).await?;

        // Give some time for synchronization
        tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;

        // Query seq_manager_2 to verify everything is empty
        let blocks_phase2 = seq_manager_2.active_blocks().await;
        let tokens_phase2 = seq_manager_2.active_tokens().await;

        // Verify phase 2 results - everything should be empty
        for worker_id in 0..=2 {
            assert_eq!(
                blocks_phase2[&worker_id], 0,
                "Worker {} should have 0 active blocks after all requests freed",
                worker_id
            );
            assert_eq!(
                tokens_phase2[&worker_id], 0,
                "Worker {} should have 0 active tokens after all requests freed",
                worker_id
            );
        }

        Ok(())
    }

    /// Same cross-instance flow as above, but requests carry no token sequence (`None`).
    ///
    /// Only active *token* counts are checked, since no block hashes are supplied. Requests are
    /// marked prefill-completed (via the opposite manager) before being freed.
    #[tokio::test]
    #[ignore]
    async fn test_multi_worker_no_token_sequence_sync() -> Result<()> {
        // Initialize logging (called by each test)
        dynamo_runtime::logging::init();

        let block_size = 4; // arbitrary block size

        // Create runtime and distributed runtime
        let runtime = Runtime::from_current()?;
        let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;

        // Create namespace and shared component for both seq_managers
        let namespace = distributed.namespace("test_no_token_seq_sync")?;
        let component = namespace
            .component("sequences")?
            .service_builder()
            .create()
            .await?;

        // Create multi-worker sequence managers with ALL workers [0, 1, 2]
        // Both use the same component to ensure event synchronization works
        let worker_ids = vec![0, 1, 2];
        let seq_manager_1 = Arc::new(ActiveSequencesMultiWorker::new(
            component.clone(),
            block_size,
            worker_ids.clone(),
            true,
            Uuid::new_v4().to_string(),
        ));
        let seq_manager_2 = Arc::new(ActiveSequencesMultiWorker::new(
            component,
            block_size,
            worker_ids,
            true,
            Uuid::new_v4().to_string(),
        ));

        // Give some time for the subscription loops to start
        tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;

        // PHASE 1: Add requests (without token sequences) using both seq_managers

        // Add request_0 to worker 0 with no token sequence
        seq_manager_1
            .add_request(
                "request_0".to_string(),
                None, // No token sequence
                12,   // ISL (12 tokens)
                0,    // no overlap
                0,    // worker_id
            )
            .await?;

        // Add request_1 to worker 1 with no token sequence
        seq_manager_1
            .add_request(
                "request_1".to_string(),
                None, // No token sequence
                8,    // ISL (8 tokens)
                0,    // no overlap
                1,    // worker_id
            )
            .await?;

        // Add request_2 to worker 2 with no token sequence using seq_manager_2
        seq_manager_2
            .add_request(
                "request_2".to_string(),
                None, // No token sequence
                16,   // ISL (16 tokens)
                0,    // no overlap
                2,    // worker_id
            )
            .await?;

        // Give some time for synchronization
        tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;

        // Query seq_manager_1 to verify it sees all requests including request_2 from seq_manager_2
        let tokens_phase1 = seq_manager_1.active_tokens().await;

        // Verify that seq_manager_1 sees all requests, including request_2 added via seq_manager_2
        assert_eq!(
            tokens_phase1[&0], 12,
            "Worker 0 should have 12 active tokens"
        );
        assert_eq!(tokens_phase1[&1], 8, "Worker 1 should have 8 active tokens");
        assert_eq!(
            tokens_phase1[&2], 16,
            "Worker 2 should have 16 active tokens (from request_2 added by seq_manager_2)"
        );

        // PHASE 2: Free requests using opposite sequence managers, verify on seq_manager_2

        // Mark prefill completed and free request_2 (which was added by seq_manager_2) using seq_manager_1
        seq_manager_1
            .mark_prefill_completed(&"request_2".to_string())
            .await?;
        seq_manager_1.free(&"request_2".to_string()).await?;

        // Mark prefill completed and free requests 0 and 1 (which were added by seq_manager_1) using seq_manager_2
        seq_manager_2
            .mark_prefill_completed(&"request_0".to_string())
            .await?;
        seq_manager_2
            .mark_prefill_completed(&"request_1".to_string())
            .await?;
        seq_manager_2.free(&"request_0".to_string()).await?;
        seq_manager_2.free(&"request_1".to_string()).await?;

        // Give some time for synchronization
        tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;

        // Query seq_manager_2 to verify everything is empty
        let tokens_phase2 = seq_manager_2.active_tokens().await;

        // Verify phase 2 results - everything should be empty
        for worker_id in 0..=2 {
            assert_eq!(
                tokens_phase2[&worker_id], 0,
                "Worker {} should have 0 active tokens after all requests freed",
                worker_id
            );
        }

        Ok(())
    }
}