1use crate::kv_router::indexer::OverlapScores;
26use crate::kv_router::indexer::WorkerId;
27use crate::tokens::SequenceHash;
28use anyhow::Result;
29use dashmap::DashMap;
30use derive_getters::Getters;
31use dynamo_runtime::component::Component;
32use dynamo_runtime::traits::DistributedRuntimeProvider;
33use dynamo_runtime::traits::events::{EventPublisher, EventSubscriber};
34use futures::StreamExt;
35use std::collections::{HashMap, HashSet};
36use std::sync::Arc;
37use std::time::Duration;
38use tokio::time::Instant;
39use uuid::Uuid;
40
41use super::protocols::{ActiveSequenceEvent, ActiveSequenceEventData};
42use crate::kv_router::ACTIVE_SEQUENCES_SUBJECT;
43use dynamo_runtime::CancellationToken;
44
/// Interval between forced-expiry passes in `ActiveSequences::force_expiry`.
///
/// A request that is already tracked at one pass and is still tracked when the
/// next pass runs is freed by that next pass.
const EXPIRY_DURATION: Duration = Duration::from_secs(300);

/// Identifier of a request tracked by the router.
pub type RequestId = String;
50
/// Bookkeeping for the requests currently tracked on a single worker.
///
/// Records which block hashes are referenced by which requests and keeps running
/// totals of distinct blocks in use and of prefill tokens not yet completed.
#[derive(Debug, Getters)]
pub struct ActiveSequences {
    /// Block-hash sequence registered for each tracked request (empty when the
    /// request was added without a sequence).
    active_seqs: HashMap<RequestId, Vec<SequenceHash>>,

    /// Prefill tokens still counted in `active_tokens` per request; an entry is
    /// removed when prefill is marked complete or the request is freed.
    prefill_tokens: HashMap<RequestId, usize>,

    /// For every block hash in use, the set of requests that reference it.
    unique_blocks: HashMap<SequenceHash, HashSet<RequestId>>,

    /// Block size in tokens; multiplies an overlap (counted in blocks) into tokens.
    #[getter(copy)]
    block_size: usize,

    /// Number of distinct block hashes currently present in `unique_blocks`.
    #[getter(copy)]
    active_blocks: usize,

    /// Sum of the values currently stored in `prefill_tokens`.
    #[getter(copy)]
    active_tokens: usize,

    /// Instant at which the next forced-expiry pass becomes due.
    expiry_timer: Instant,

    /// Requests that were tracked at the last expiry pass; whichever of them are
    /// still here at the next pass get force-freed.
    expiry_requests: HashSet<RequestId>,
}
75
76impl ActiveSequences {
77 pub fn new(block_size: usize) -> Self {
79 assert!(block_size > 1, "block_size must be greater than 1");
81
82 Self {
83 active_seqs: HashMap::new(),
84 prefill_tokens: HashMap::new(),
85 unique_blocks: HashMap::new(),
86 block_size,
87 active_blocks: 0,
88 active_tokens: 0,
89 expiry_timer: Instant::now() + EXPIRY_DURATION,
90 expiry_requests: HashSet::new(),
91 }
92 }
93
94 fn add_block(&mut self, request_id: RequestId, block: &SequenceHash) {
95 let is_new_block = !self.unique_blocks.contains_key(block);
96
97 self.unique_blocks
98 .entry(*block)
99 .or_default()
100 .insert(request_id.clone());
101
102 if is_new_block {
103 self.active_blocks += 1;
104 }
105 }
106
107 fn remove_block(&mut self, request_id: &RequestId, block: &SequenceHash) {
108 let Some(request_ids) = self.unique_blocks.get_mut(block) else {
109 return;
110 };
111
112 request_ids.retain(|w| w != request_id);
114 if request_ids.is_empty() {
115 self.active_blocks -= 1;
116 self.unique_blocks.remove(block);
117 }
118 }
119
120 pub fn add_request(
123 &mut self,
124 request_id: RequestId,
125 token_sequence: Option<Vec<SequenceHash>>,
126 isl: usize,
127 overlap: u32,
128 ) -> HashSet<RequestId> {
129 if self.active_seqs.contains_key(&request_id) {
131 panic!("Request {request_id} is already active. Cannot accept double-add.");
132 }
133
134 let removed_requests = self.force_expiry();
136
137 let prefill_tokens = self.new_tokens(isl, overlap);
138 self.prefill_tokens
139 .insert(request_id.clone(), prefill_tokens);
140 self.active_tokens += prefill_tokens;
141
142 if let Some(sequence) = token_sequence {
143 for block in &sequence {
144 self.add_block(request_id.clone(), block);
145 }
146 self.active_seqs.insert(request_id.clone(), sequence);
147 } else {
148 self.active_seqs.insert(request_id.clone(), Vec::new());
150 }
151
152 removed_requests
153 }
154
155 pub fn mark_prefill_completed(&mut self, request_id: &RequestId) {
157 if let Some(tokens) = self.prefill_tokens.remove(request_id) {
158 self.active_tokens = self
159 .active_tokens
160 .checked_sub(tokens)
161 .expect("active_tokens underflow");
162 }
163 }
164
165 pub fn new_tokens(&self, isl: usize, overlap: u32) -> usize {
166 isl.checked_sub((overlap as usize) * self.block_size)
167 .unwrap_or_else(|| panic!("prefill_tokens < 0 with overlap {overlap} and ISL {isl}"))
168 }
169
170 pub fn potential_blocks_and_tokens(
171 &self,
172 token_sequence: Option<&[SequenceHash]>,
173 isl: usize,
174 overlap: u32,
175 ) -> (usize, usize) {
176 let potential_blocks = if let Some(token_seq) = token_sequence {
177 self.new_blocks(token_seq) + self.active_blocks
178 } else {
179 self.active_blocks
180 };
181 let potential_tokens = self.new_tokens(isl, overlap) + self.active_tokens;
182 (potential_blocks, potential_tokens)
183 }
184
185 pub fn new_blocks(&self, token_sequence: &[SequenceHash]) -> usize {
187 token_sequence
188 .iter()
189 .filter(|block| !self.unique_blocks.contains_key(block))
190 .count()
191 }
192
193 pub fn potential_blocks(&self, token_sequence: &[SequenceHash]) -> usize {
196 self.new_blocks(token_sequence) + self.active_blocks
197 }
198
199 pub fn free(&mut self, request_id: &RequestId) -> usize {
201 self.mark_prefill_completed(request_id);
202
203 self.expiry_requests.remove(request_id);
204
205 let token_seq = match self.active_seqs.remove(request_id) {
207 Some(seq) => seq,
208 None => {
209 tracing::warn!("Trying to free non-existent request {request_id}");
210 return self.active_blocks;
211 }
212 };
213
214 for block in token_seq {
215 self.remove_block(request_id, &block)
216 }
217
218 self.active_blocks
219 }
220
221 pub fn force_expiry(&mut self) -> HashSet<RequestId> {
224 let now = Instant::now();
225
226 if now < self.expiry_timer {
228 return HashSet::new();
229 }
230
231 let expired_requests: HashSet<RequestId> = self.expiry_requests.drain().collect();
233 for request_id in &expired_requests {
234 tracing::warn!("Force expiring stale request: {}", request_id);
235 self.free(request_id);
236 }
237
238 self.expiry_timer = now + EXPIRY_DURATION;
239 self.expiry_requests = self.active_seqs.keys().cloned().collect();
240
241 expired_requests
242 }
243}
244
/// Commands processed by the per-worker task spawned in
/// `ActiveSequencesMultiWorker::start_worker`.
///
/// Variants that carry a `resp_tx` have their result sent back over that
/// oneshot channel; a failed send of the reply is ignored by the task.
enum UpdateSequences {
    /// Register a request; the reply is the set of requests force-expired
    /// while handling the call.
    AddRequest {
        request_id: RequestId,
        token_sequence: Option<Vec<SequenceHash>>,
        isl: usize,
        overlap: u32,
        resp_tx: tokio::sync::oneshot::Sender<HashSet<RequestId>>,
    },
    /// Stop tracking a request.
    Free {
        request_id: RequestId,
    },
    /// Stop counting a request's prefill tokens as active.
    MarkPrefillCompleted {
        request_id: RequestId,
    },
    /// Query: how many blocks of `token_sequence` the worker is not tracking.
    NewBlocks {
        token_sequence: Arc<Vec<SequenceHash>>,
        resp_tx: tokio::sync::oneshot::Sender<usize>,
    },
    /// Query: the `NewBlocks` count plus the worker's current active blocks.
    PotentialBlocks {
        token_sequence: Arc<Vec<SequenceHash>>,
        resp_tx: tokio::sync::oneshot::Sender<usize>,
    },
    /// Query: `(blocks, tokens)` the worker would hold if such a request were added.
    PotentialBlocksAndTokens {
        token_sequence: Option<Arc<Vec<SequenceHash>>>,
        isl: usize,
        overlap: u32,
        resp_tx: tokio::sync::oneshot::Sender<(usize, usize)>,
    },
    /// Query: the worker's current active block count.
    ActiveBlocks {
        resp_tx: tokio::sync::oneshot::Sender<usize>,
    },
    /// Query: the worker's current active token count.
    ActiveTokens {
        resp_tx: tokio::sync::oneshot::Sender<usize>,
    },
    /// Ask the worker task to leave its loop.
    Shutdown,
}
281
/// Tracks active sequences for many workers, with one background task (and one
/// `ActiveSequences`) per worker, driven through `UpdateSequences` commands.
pub struct ActiveSequencesMultiWorker {
    /// Command channel to each worker's task.
    senders: Arc<DashMap<WorkerId, tokio::sync::mpsc::UnboundedSender<UpdateSequences>>>,
    /// Worker each known request was assigned to.
    request_to_worker: Arc<DashMap<RequestId, WorkerId>>,
    /// Join handles of the worker tasks, used to abort them on removal or drop.
    handles: Arc<DashMap<WorkerId, tokio::task::JoinHandle<()>>>,
    /// Block size handed to every worker's `ActiveSequences`.
    block_size: usize,
    /// Component used to publish active-sequence events and to derive
    /// cancellation tokens for newly started worker tasks.
    component: Component,
    /// Identity stamped on published events; events carrying it are skipped by
    /// this instance's own subscription.
    router_id: Uuid,
    /// When true, mutations are published on `ACTIVE_SEQUENCES_SUBJECT`.
    replica_sync: bool,
}
292
293impl ActiveSequencesMultiWorker {
294 pub fn new(
295 component: Component,
296 block_size: usize,
297 worker_ids: Vec<WorkerId>,
298 replica_sync: bool,
299 router_uuid: String,
300 ) -> Self {
301 assert!(block_size > 1, "block_size must be greater than 1");
302
303 let senders = Arc::new(DashMap::new());
304 let handles = Arc::new(DashMap::new());
305 let request_to_worker = Arc::new(DashMap::new());
306 let router_id = Uuid::parse_str(&router_uuid).unwrap_or_else(|e| {
307 tracing::warn!(
308 "Failed to parse router UUID '{}': {}, using new UUID",
309 router_uuid,
310 e
311 );
312 Uuid::new_v4()
313 });
314
315 for worker_id in worker_ids {
316 let cancel_token = component.drt().runtime().child_token();
318 let (sender, handle) = Self::start_worker(block_size, cancel_token);
319 senders.insert(worker_id, sender);
320 handles.insert(worker_id, handle);
321 }
322
323 let multi_worker = Self {
324 senders: senders.clone(),
325 request_to_worker: request_to_worker.clone(),
326 handles,
327 block_size,
328 component: component.clone(),
329 router_id,
330 replica_sync,
331 };
332
333 if replica_sync {
335 let senders_clone = senders.clone();
336 let request_to_worker_clone = request_to_worker.clone();
337 let component_clone = component.clone();
338 let router_id_clone = router_id;
339 let cancel_token = component.drt().runtime().child_token();
340
341 tokio::spawn(async move {
342 if let Err(e) = Self::subscribe_to_events(
344 senders_clone,
345 request_to_worker_clone,
346 component_clone,
347 router_id_clone,
348 cancel_token,
349 )
350 .await
351 {
352 tracing::error!("Error in active sequences events subscription: {}", e);
353 }
354 });
355 }
356
357 multi_worker
358 }
359
360 fn start_worker(
362 block_size: usize,
363 cancel_token: CancellationToken, ) -> (
365 tokio::sync::mpsc::UnboundedSender<UpdateSequences>,
366 tokio::task::JoinHandle<()>,
367 ) {
368 let (request_tx, mut request_rx) = tokio::sync::mpsc::unbounded_channel();
369
370 let handle = tokio::spawn(async move {
371 let mut active_sequences = ActiveSequences::new(block_size);
372
373 loop {
374 tokio::select! {
375 command = request_rx.recv() => {
377 match command {
378 Some(command) => {
379 match command {
380 UpdateSequences::AddRequest {
381 request_id,
382 token_sequence,
383 isl,
384 overlap,
385 resp_tx,
386 } => {
387 let removed = active_sequences.add_request(request_id, token_sequence, isl, overlap);
388 let _ = resp_tx.send(removed);
389 }
390 UpdateSequences::Free { request_id } => {
391 active_sequences.free(&request_id);
392 }
393 UpdateSequences::MarkPrefillCompleted { request_id } => {
394 active_sequences.mark_prefill_completed(&request_id);
395 }
396 UpdateSequences::NewBlocks {
397 token_sequence,
398 resp_tx,
399 } => {
400 let new_blocks = active_sequences.new_blocks(&token_sequence);
401 let _ = resp_tx.send(new_blocks);
402 }
403 UpdateSequences::PotentialBlocks {
404 token_sequence,
405 resp_tx,
406 } => {
407 let potential_blocks = active_sequences.potential_blocks(&token_sequence);
408 let _ = resp_tx.send(potential_blocks);
409 }
410 UpdateSequences::PotentialBlocksAndTokens {
411 token_sequence,
412 isl,
413 overlap,
414 resp_tx,
415 } => {
416 let potential_tokens = active_sequences.potential_blocks_and_tokens(
417 token_sequence.as_ref().map(|v| v.as_slice()),
418 isl,
419 overlap,
420 );
421 let _ = resp_tx.send(potential_tokens);
422 }
423 UpdateSequences::ActiveBlocks { resp_tx } => {
424 let active_blocks = active_sequences.active_blocks();
425 let _ = resp_tx.send(active_blocks);
426 }
427 UpdateSequences::ActiveTokens { resp_tx } => {
428 let active_tokens = active_sequences.active_tokens();
429 let _ = resp_tx.send(active_tokens);
430 }
431 UpdateSequences::Shutdown => {
432 break;
433 }
434 }
435 }
436 None => {
437 break;
439 }
440 }
441 }
442 _ = cancel_token.cancelled() => {
444 tracing::debug!("Worker task cancelled");
445 break;
446 }
447 }
448 }
449 });
450
451 (request_tx, handle)
452 }
453
454 async fn subscribe_to_events(
456 senders: Arc<DashMap<WorkerId, tokio::sync::mpsc::UnboundedSender<UpdateSequences>>>,
457 request_to_worker: Arc<DashMap<RequestId, WorkerId>>,
458 component: Component,
459 router_id: Uuid,
460 cancel_token: CancellationToken,
461 ) -> Result<()> {
462 let mut subscriber = component
463 .subscribe_with_type::<ActiveSequenceEvent>(ACTIVE_SEQUENCES_SUBJECT)
464 .await?;
465
466 loop {
467 tokio::select! {
468 result = subscriber.next() => {
470 let Some(result) = result else {
471 break;
473 };
474
475 let Ok(event) = result else {
476 tracing::error!(
477 "Error receiving active sequence event: {}",
478 result.unwrap_err()
479 );
480 continue;
481 };
482
483 if event.router_id == router_id {
485 continue;
486 }
487
488 match &event.data {
489 ActiveSequenceEventData::AddRequest {
490 token_sequence,
491 isl,
492 overlap,
493 } => {
494 request_to_worker.insert(event.request_id.clone(), event.worker_id);
495
496 if let Some(sender) = senders.get(&event.worker_id) {
497 let (resp_tx, _) = tokio::sync::oneshot::channel();
499 let _ = sender.send(UpdateSequences::AddRequest {
500 request_id: event.request_id.clone(),
501 token_sequence: token_sequence.clone(),
502 isl: *isl,
503 overlap: *overlap,
504 resp_tx,
505 });
506 } else {
507 tracing::warn!(
508 "Worker {} not found, cannot process AddRequest",
509 event.worker_id
510 );
511 }
512 }
513 ActiveSequenceEventData::Free => {
514 if let Some((_, worker_id)) = request_to_worker.remove(&event.request_id)
515 && let Some(sender) = senders.get(&worker_id)
516 {
517 let _ = sender.send(UpdateSequences::Free {
518 request_id: event.request_id.clone(),
519 });
520 }
521 }
522 ActiveSequenceEventData::MarkPrefillCompleted => {
523 if let Some(worker_id) = request_to_worker.get(&event.request_id)
524 && let Some(sender) = senders.get(&*worker_id)
525 {
526 let _ = sender.send(UpdateSequences::MarkPrefillCompleted {
527 request_id: event.request_id.clone(),
528 });
529 }
530 }
531 }
532 }
533 _ = cancel_token.cancelled() => {
535 tracing::debug!("Subscription task cancelled");
536 break;
537 }
538 }
539 }
540
541 Ok(())
542 }
543
544 pub fn update_workers(&self, new_worker_ids: Vec<WorkerId>) {
546 let current_workers: HashSet<WorkerId> =
547 self.senders.iter().map(|entry| *entry.key()).collect();
548 let new_workers: HashSet<WorkerId> = new_worker_ids.into_iter().collect();
549
550 let workers_to_remove: Vec<WorkerId> =
551 current_workers.difference(&new_workers).copied().collect();
552 let workers_to_add: Vec<WorkerId> =
553 new_workers.difference(¤t_workers).copied().collect();
554
555 for worker_id in &workers_to_remove {
557 tracing::warn!("Removing worker {}", worker_id);
558
559 if let Some((_, sender)) = self.senders.remove(worker_id) {
561 let _ = sender.send(UpdateSequences::Shutdown);
562 }
563 if let Some((_, handle)) = self.handles.remove(worker_id) {
564 handle.abort();
565 }
566
567 self.request_to_worker
569 .retain(|_request_id, mapped_worker_id| *mapped_worker_id != *worker_id);
570 }
571
572 for worker_id in &workers_to_add {
574 tracing::warn!("Adding worker {}", worker_id);
575
576 let (sender, handle) = Self::start_worker(
577 self.block_size,
578 self.component.drt().runtime().child_token(),
579 );
580 self.senders.insert(*worker_id, sender);
581 self.handles.insert(*worker_id, handle);
582 }
583 }
584
585 pub async fn add_request(
586 &self,
587 request_id: RequestId,
588 token_sequence: Option<Vec<SequenceHash>>,
589 isl: usize,
590 overlap: u32,
591 worker_id: WorkerId,
592 ) -> Result<()> {
593 if !self.senders.contains_key(&worker_id) {
594 return Err(anyhow::anyhow!("Worker ID {worker_id} not found"));
595 }
596
597 let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
599
600 if self.replica_sync {
602 let event = ActiveSequenceEvent {
603 request_id: request_id.clone(),
604 worker_id,
605 data: ActiveSequenceEventData::AddRequest {
606 token_sequence: token_sequence.clone(),
607 isl,
608 overlap,
609 },
610 router_id: self.router_id,
611 };
612 self.component
613 .publish(ACTIVE_SEQUENCES_SUBJECT, &event)
614 .await?;
615 }
616
617 self.request_to_worker.insert(request_id.clone(), worker_id);
619
620 self.senders
621 .get(&worker_id)
622 .unwrap()
623 .send(UpdateSequences::AddRequest {
624 request_id,
625 token_sequence,
626 isl,
627 overlap,
628 resp_tx,
629 })
630 .map_err(|_| anyhow::anyhow!("Failed to send add_request command to worker"))?;
631
632 let removed_requests = resp_rx
634 .await
635 .map_err(|_| anyhow::anyhow!("Failed to receive response from worker"))?;
636
637 for expired_id in &removed_requests {
639 self.request_to_worker.remove(expired_id);
640 }
641
642 Ok(())
643 }
644
645 pub async fn free(&self, request_id: &RequestId) -> Result<()> {
646 let worker_id = self
647 .request_to_worker
648 .get(request_id)
649 .map(|entry| *entry)
650 .ok_or_else(|| anyhow::anyhow!("Request ID not found in request_to_worker mapping"))?;
651
652 if self.replica_sync {
654 let event = ActiveSequenceEvent {
655 request_id: request_id.clone(),
656 worker_id,
657 data: ActiveSequenceEventData::Free,
658 router_id: self.router_id,
659 };
660 self.component
661 .publish(ACTIVE_SEQUENCES_SUBJECT, &event)
662 .await?;
663 }
664
665 self.senders
667 .get(&worker_id)
668 .unwrap()
669 .send(UpdateSequences::Free {
670 request_id: request_id.clone(),
671 })
672 .map_err(|_| anyhow::anyhow!("Failed to send free command to worker"))?;
673
674 self.request_to_worker.remove(request_id);
675
676 Ok(())
677 }
678
679 pub async fn mark_prefill_completed(&self, request_id: &RequestId) -> Result<()> {
681 let worker_id = self
682 .request_to_worker
683 .get(request_id)
684 .map(|entry| *entry)
685 .ok_or_else(|| anyhow::anyhow!("Request ID not found in request_to_worker mapping"))?;
686
687 if self.replica_sync {
689 let event = ActiveSequenceEvent {
690 request_id: request_id.clone(),
691 worker_id,
692 data: ActiveSequenceEventData::MarkPrefillCompleted,
693 router_id: self.router_id,
694 };
695 self.component
696 .publish(ACTIVE_SEQUENCES_SUBJECT, &event)
697 .await?;
698 }
699
700 self.senders
702 .get(&worker_id)
703 .unwrap()
704 .send(UpdateSequences::MarkPrefillCompleted {
705 request_id: request_id.clone(),
706 })
707 .map_err(|_| {
708 anyhow::anyhow!("Failed to send mark_prefill_completed command to worker")
709 })?;
710
711 Ok(())
712 }
713
/// Returns the number of workers that currently have a command channel registered.
pub fn num_workers(&self) -> usize {
    self.senders.len()
}
718
719 async fn query_workers<T: Send + 'static>(
721 &self,
722 token_sequence: Option<Vec<SequenceHash>>,
723 command_fn: impl Fn(
724 Option<Arc<Vec<SequenceHash>>>,
725 tokio::sync::oneshot::Sender<T>,
726 ) -> UpdateSequences,
727 ) -> HashMap<WorkerId, T> {
728 let mut results = HashMap::new();
729 let token_sequence_shared = token_sequence.map(Arc::new);
730 let mut receivers = Vec::new();
731
732 for entry in self.senders.iter() {
734 let worker_id = *entry.key();
735 let sender = entry.value();
736 let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
737 receivers.push((worker_id, resp_rx));
738 if let Err(e) = sender.send(command_fn(token_sequence_shared.clone(), resp_tx)) {
739 tracing::error!("Failed to send command to worker {}: {}", worker_id, e);
740 }
741 }
742
743 for (worker_id, receiver) in receivers {
745 match tokio::time::timeout(tokio::time::Duration::from_secs(1), receiver).await {
746 Ok(Ok(result)) => {
747 results.insert(worker_id, result);
748 }
749 Ok(Err(_)) => {
750 tracing::error!("Worker {} dropped response channel", worker_id);
751 }
752 Err(_) => {
753 tracing::error!("Timeout waiting for response from worker {}", worker_id);
754 }
755 }
756 }
757
758 results
759 }
760
761 pub async fn new_blocks(&self, token_sequence: Vec<SequenceHash>) -> HashMap<WorkerId, usize> {
763 self.query_workers(Some(token_sequence), |ts, resp_tx| match ts {
764 Some(ts) => UpdateSequences::NewBlocks {
765 token_sequence: ts,
766 resp_tx,
767 },
768 None => unreachable!("token_sequence should always be Some for new_blocks"),
769 })
770 .await
771 }
772
773 pub async fn potential_blocks(
775 &self,
776 token_sequence: Vec<SequenceHash>,
777 ) -> HashMap<WorkerId, usize> {
778 self.query_workers(Some(token_sequence), |ts, resp_tx| match ts {
779 Some(ts) => UpdateSequences::PotentialBlocks {
780 token_sequence: ts,
781 resp_tx,
782 },
783 None => unreachable!("token_sequence should always be Some for potential_blocks"),
784 })
785 .await
786 }
787
788 pub async fn potential_blocks_and_tokens(
790 &self,
791 token_sequence: Option<Vec<SequenceHash>>,
792 isl: usize,
793 overlaps: OverlapScores,
794 ) -> (HashMap<WorkerId, usize>, HashMap<WorkerId, usize>) {
795 let mut potential_blocks = HashMap::new();
796 let mut potential_tokens = HashMap::new();
797 let token_sequence_shared = token_sequence.map(Arc::new);
798 let mut receivers = Vec::new();
799
800 for entry in self.senders.iter() {
802 let worker_id = *entry.key();
803 let sender = entry.value();
804 let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
805 receivers.push((worker_id, resp_rx));
806
807 if let Err(e) = sender.send(UpdateSequences::PotentialBlocksAndTokens {
808 token_sequence: token_sequence_shared.clone(),
809 isl,
810 overlap: overlaps.scores.get(&worker_id).copied().unwrap_or(0),
811 resp_tx,
812 }) {
813 tracing::error!(
814 "Failed to send potential_tokens command to worker {}: {}",
815 worker_id,
816 e
817 );
818 }
819 }
820
821 for (worker_id, receiver) in receivers {
823 match tokio::time::timeout(tokio::time::Duration::from_secs(1), receiver).await {
824 Ok(Ok((blocks, tokens))) => {
825 potential_blocks.insert(worker_id, blocks);
826 potential_tokens.insert(worker_id, tokens);
827 }
828 Ok(Err(_)) => {
829 tracing::error!("Worker {} dropped response channel", worker_id);
830 }
831 Err(_) => {
832 tracing::error!("Timeout waiting for response from worker {}", worker_id);
833 }
834 }
835 }
836
837 (potential_blocks, potential_tokens)
838 }
839
840 pub async fn active_blocks(&self) -> HashMap<WorkerId, usize> {
842 self.query_workers(None, |_, resp_tx| UpdateSequences::ActiveBlocks { resp_tx })
843 .await
844 }
845
846 pub async fn active_tokens(&self) -> HashMap<WorkerId, usize> {
848 self.query_workers(None, |_, resp_tx| UpdateSequences::ActiveTokens { resp_tx })
849 .await
850 }
851}
852
853impl Drop for ActiveSequencesMultiWorker {
854 fn drop(&mut self) {
855 for entry in self.senders.iter() {
857 let _ = entry.value().send(UpdateSequences::Shutdown);
858 }
859
860 for entry in self.handles.iter() {
862 entry.value().abort();
863 }
864 }
865}
866
867#[cfg(test)]
868mod tests {
869 use super::*;
870 use dynamo_runtime::{DistributedRuntime, Runtime};
871 use std::sync::Arc;
872
/// Two trackers sharing one component must converge on the same per-worker
/// block and token counts through published events.
///
/// Ignored by default; NOTE(review): presumably because it needs a reachable
/// distributed runtime — confirm.
#[tokio::test]
#[ignore]
async fn test_multi_worker_cross_instance_sync() -> Result<()> {
    dynamo_runtime::logging::init();

    let block_size = 4;
    let runtime = Runtime::from_current()?;
    let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;

    let namespace = distributed.namespace("test_cross_instance_sync")?;
    let component = namespace
        .component("sequences")?
        .service_builder()
        .create()
        .await?;

    // Two trackers with replica sync enabled and distinct router ids.
    let worker_ids = vec![0, 1, 2];
    let seq_manager_1 = Arc::new(ActiveSequencesMultiWorker::new(
        component.clone(),
        block_size,
        worker_ids.clone(),
        true,
        Uuid::new_v4().to_string(),
    ));
    let seq_manager_2 = Arc::new(ActiveSequencesMultiWorker::new(
        component,
        block_size,
        worker_ids,
        true,
        Uuid::new_v4().to_string(),
    ));

    // Give the subscription tasks time to start.
    tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;

    // Phase 1: requests 0 and 1 go through manager 1, request 2 through manager 2.
    seq_manager_1
        .add_request(
            "request_0".to_string(),
            Some(vec![0, 1, 2]),
            12, // isl
            0,  // overlap
            0,  // worker_id
        )
        .await?;

    seq_manager_1
        .add_request(
            "request_1".to_string(),
            Some(vec![3, 4]),
            8, // isl
            0, // overlap
            1, // worker_id
        )
        .await?;

    seq_manager_2
        .add_request(
            "request_2".to_string(),
            Some(vec![0, 1, 2, 3]),
            16, // isl
            0,  // overlap
            2,  // worker_id
        )
        .await?;

    // Let the events propagate between the two trackers.
    tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;

    // Manager 1 must also see request_2, which was added through manager 2.
    let blocks_phase1 = seq_manager_1.active_blocks().await;
    let tokens_phase1 = seq_manager_1.active_tokens().await;

    assert_eq!(
        blocks_phase1[&0], 3,
        "Worker 0 should have 3 active blocks (from request_0)"
    );
    assert_eq!(
        blocks_phase1[&1], 2,
        "Worker 1 should have 2 active blocks (from request_1)"
    );
    assert_eq!(
        blocks_phase1[&2], 4,
        "Worker 2 should have 4 active blocks (from request_2 added by seq_manager_2)"
    );
    assert_eq!(
        tokens_phase1[&0], 12,
        "Worker 0 should have 12 active tokens"
    );
    assert_eq!(tokens_phase1[&1], 8, "Worker 1 should have 8 active tokens");
    assert_eq!(
        tokens_phase1[&2], 16,
        "Worker 2 should have 16 active tokens (from request_2 added by seq_manager_2)"
    );

    // Phase 2: each request is freed through the manager that did NOT add it.
    seq_manager_1.free(&"request_2".to_string()).await?;

    seq_manager_2.free(&"request_0".to_string()).await?;
    seq_manager_2.free(&"request_1".to_string()).await?;

    // Let the free events propagate.
    tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;

    let blocks_phase2 = seq_manager_2.active_blocks().await;
    let tokens_phase2 = seq_manager_2.active_tokens().await;

    for worker_id in 0..=2 {
        assert_eq!(
            blocks_phase2[&worker_id], 0,
            "Worker {} should have 0 active blocks after all requests freed",
            worker_id
        );
        assert_eq!(
            tokens_phase2[&worker_id], 0,
            "Worker {} should have 0 active tokens after all requests freed",
            worker_id
        );
    }

    Ok(())
}
1011
/// Same cross-instance scenario, but with requests added without a token
/// sequence, so only token counts are checked.
///
/// Ignored by default; NOTE(review): presumably because it needs a reachable
/// distributed runtime — confirm.
#[tokio::test]
#[ignore]
async fn test_multi_worker_no_token_sequence_sync() -> Result<()> {
    dynamo_runtime::logging::init();

    let block_size = 4;
    let runtime = Runtime::from_current()?;
    let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;

    let namespace = distributed.namespace("test_no_token_seq_sync")?;
    let component = namespace
        .component("sequences")?
        .service_builder()
        .create()
        .await?;

    // Two trackers with replica sync enabled and distinct router ids.
    let worker_ids = vec![0, 1, 2];
    let seq_manager_1 = Arc::new(ActiveSequencesMultiWorker::new(
        component.clone(),
        block_size,
        worker_ids.clone(),
        true,
        Uuid::new_v4().to_string(),
    ));
    let seq_manager_2 = Arc::new(ActiveSequencesMultiWorker::new(
        component,
        block_size,
        worker_ids,
        true,
        Uuid::new_v4().to_string(),
    ));

    // Give the subscription tasks time to start.
    tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;

    // Phase 1: requests 0 and 1 go through manager 1, request 2 through manager 2.
    seq_manager_1
        .add_request(
            "request_0".to_string(),
            None, // token_sequence
            12,   // isl
            0,    // overlap
            0,    // worker_id
        )
        .await?;

    seq_manager_1
        .add_request(
            "request_1".to_string(),
            None, // token_sequence
            8,    // isl
            0,    // overlap
            1,    // worker_id
        )
        .await?;

    seq_manager_2
        .add_request(
            "request_2".to_string(),
            None, // token_sequence
            16,   // isl
            0,    // overlap
            2,    // worker_id
        )
        .await?;

    // Let the events propagate between the two trackers.
    tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;

    // Manager 1 must also see request_2, which was added through manager 2.
    let tokens_phase1 = seq_manager_1.active_tokens().await;

    assert_eq!(
        tokens_phase1[&0], 12,
        "Worker 0 should have 12 active tokens"
    );
    assert_eq!(tokens_phase1[&1], 8, "Worker 1 should have 8 active tokens");
    assert_eq!(
        tokens_phase1[&2], 16,
        "Worker 2 should have 16 active tokens (from request_2 added by seq_manager_2)"
    );

    // Phase 2: each request is completed and freed through the manager that
    // did NOT add it.
    seq_manager_1
        .mark_prefill_completed(&"request_2".to_string())
        .await?;
    seq_manager_1.free(&"request_2".to_string()).await?;

    seq_manager_2
        .mark_prefill_completed(&"request_0".to_string())
        .await?;
    seq_manager_2
        .mark_prefill_completed(&"request_1".to_string())
        .await?;
    seq_manager_2.free(&"request_0".to_string()).await?;
    seq_manager_2.free(&"request_1".to_string()).await?;

    // Let the events propagate.
    tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;

    let tokens_phase2 = seq_manager_2.active_tokens().await;

    for worker_id in 0..=2 {
        assert_eq!(
            tokens_phase2[&worker_id], 0,
            "Worker {} should have 0 active tokens after all requests freed",
            worker_id
        );
    }

    Ok(())
}
1140}