dynamo-llm 1.3.0-dev.1

Dynamo LLM Library
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::collections::{HashMap, VecDeque};

use dynamo_kv_router::{
    protocols::{DpRank, KvCacheEventData, RouterEvent, WorkerId},
    recovery::{CursorObservation, CursorState},
};

pub(super) type RecoveryKey = (WorkerId, DpRank);

pub(super) const RECOVERY_PENDING_LIVE_EVENT_LIMIT: usize = 1024;
pub(super) const RECOVERY_PENDING_FAST_PRUNE_MARGIN: usize = 10;

pub(super) enum LiveEventAction {
    Ignore,
    ApplyDirect(RouterEvent),
    ApplyClear(RouterEvent),
    SpawnFullRestore { epoch: u64 },
    SpawnIncremental { epoch: u64, start_event_id: u64 },
}

pub(super) struct DiscoveredRankAction {
    pub(super) reset_rank: bool,
    pub(super) restore_epoch: Option<u64>,
}

pub(super) enum PendingDrainAction {
    Apply(RouterEvent),
    RecoverFrom(u64),
    Complete,
}

#[derive(Debug, Default)]
pub(super) struct RankState {
    pub(super) cursor: CursorState,
    pub(super) max_seen_live_id: Option<u64>,
    pub(super) recovery_inflight: bool,
    pub(super) pending_live_events: VecDeque<RouterEvent>,
}

impl RankState {
    pub(super) fn last_applied_id(&self) -> Option<u64> {
        self.cursor.last_applied_id()
    }

    pub(super) fn observe_live_id(&mut self, event_id: u64) {
        self.max_seen_live_id = Some(self.max_seen_live_id.unwrap_or(0).max(event_id));
    }

    fn clear_max_seen_if_caught_up(&mut self, last_applied_id: u64) {
        if self
            .max_seen_live_id
            .is_some_and(|max_seen| max_seen <= last_applied_id)
        {
            self.max_seen_live_id = None;
        }
    }

    fn observe_and_buffer_live_event(&mut self, event: RouterEvent) {
        self.observe_live_id(event.event.event_id);
        self.push_pending_live_event(event);
    }

    fn push_pending_live_event(&mut self, event: RouterEvent) {
        self.pending_live_events.push_back(event);
        while self.pending_live_events.len() > RECOVERY_PENDING_LIVE_EVENT_LIMIT {
            self.pending_live_events.pop_front();
        }
    }

    fn clear_pending_live_events(&mut self) {
        self.pending_live_events.clear();
    }

    fn fast_prune_stale_pending_prefix(&mut self, last_applied_id: u64) {
        if self.pending_live_events.len() <= RECOVERY_PENDING_FAST_PRUNE_MARGIN {
            return;
        }

        let split_at = self.pending_live_events.len() - RECOVERY_PENDING_FAST_PRUNE_MARGIN;
        let Some(boundary_event) = self.pending_live_events.get(split_at) else {
            return;
        };
        if boundary_event.event.event_id <= last_applied_id {
            self.pending_live_events.drain(..split_at);
        }
    }

    fn next_pending_drain_action(&mut self) -> PendingDrainAction {
        let mut last_applied_id = self.last_applied_id().unwrap_or(0);
        self.fast_prune_stale_pending_prefix(last_applied_id);

        loop {
            let Some(front_event_id) = self
                .pending_live_events
                .front()
                .map(|event| event.event.event_id)
            else {
                self.clear_max_seen_if_caught_up(last_applied_id);
                if self
                    .max_seen_live_id
                    .is_some_and(|max_seen| max_seen > last_applied_id)
                {
                    return PendingDrainAction::RecoverFrom(last_applied_id.saturating_add(1));
                }
                self.recovery_inflight = false;
                return PendingDrainAction::Complete;
            };

            if front_event_id <= last_applied_id {
                self.pending_live_events.pop_front();
                continue;
            }

            let expected = last_applied_id.saturating_add(1);
            if front_event_id != expected {
                return PendingDrainAction::RecoverFrom(expected);
            }

            let event = self
                .pending_live_events
                .pop_front()
                .expect("front event exists while draining pending live events");
            self.cursor = self.cursor.advance_to(front_event_id);
            last_applied_id = front_event_id;
            self.clear_max_seen_if_caught_up(last_applied_id);
            return PendingDrainAction::Apply(event);
        }
    }
}

#[derive(Debug, Default)]
pub(super) struct WorkerState {
    pub(super) epoch: u64,
    pub(super) ranks: HashMap<DpRank, RankState>,
}

impl WorkerState {
    pub(super) fn handle_discovered_rank(
        &mut self,
        dp_rank: DpRank,
        reset_rank: bool,
    ) -> DiscoveredRankAction {
        if reset_rank {
            self.epoch += 1;
            self.ranks.insert(dp_rank, RankState::default());
        }

        let rank_state = self.ranks.entry(dp_rank).or_default();
        let restore_epoch =
            if matches!(rank_state.cursor, CursorState::Initial) && !rank_state.recovery_inflight {
                rank_state.recovery_inflight = true;
                Some(self.epoch)
            } else {
                None
            };

        DiscoveredRankAction {
            reset_rank,
            restore_epoch,
        }
    }

    pub(super) fn remove_rank(&mut self, dp_rank: DpRank) -> bool {
        self.ranks.remove(&dp_rank).is_some()
    }

    pub(super) fn is_empty(&self) -> bool {
        self.ranks.is_empty()
    }

    pub(super) fn observe_live_event(&mut self, event: RouterEvent) -> LiveEventAction {
        let dp_rank = event.event.dp_rank;
        let event_id = event.event.event_id;

        if matches!(&event.event.data, KvCacheEventData::Cleared) {
            let rank_state = self.ranks.entry(dp_rank).or_default();
            if rank_state
                .last_applied_id()
                .is_some_and(|last_applied_id| event_id <= last_applied_id)
            {
                return LiveEventAction::Ignore;
            }

            self.apply_worker_clear_barrier(dp_rank, event_id);
            return LiveEventAction::ApplyClear(event);
        }

        let rank_state = self.ranks.entry(dp_rank).or_default();
        match rank_state.cursor.observe(event_id) {
            CursorObservation::Stale { .. } => LiveEventAction::Ignore,
            observation if rank_state.recovery_inflight => {
                match observation {
                    CursorObservation::Initial { .. }
                    | CursorObservation::Contiguous { .. }
                    | CursorObservation::Gap { .. }
                    | CursorObservation::FreshAfterBarrier { .. } => {
                        rank_state.observe_and_buffer_live_event(event);
                    }
                    CursorObservation::Stale { .. } => {}
                }
                LiveEventAction::Ignore
            }
            CursorObservation::Initial { .. } => {
                rank_state.observe_and_buffer_live_event(event);
                rank_state.recovery_inflight = true;
                LiveEventAction::SpawnFullRestore { epoch: self.epoch }
            }
            CursorObservation::Gap { expected, .. } => {
                rank_state.observe_and_buffer_live_event(event);
                rank_state.recovery_inflight = true;
                LiveEventAction::SpawnIncremental {
                    epoch: self.epoch,
                    start_event_id: expected,
                }
            }
            CursorObservation::Contiguous { got }
            | CursorObservation::FreshAfterBarrier { got, .. } => {
                rank_state.cursor = rank_state.cursor.advance_to(got);
                rank_state.clear_max_seen_if_caught_up(got);
                LiveEventAction::ApplyDirect(event)
            }
        }
    }

    pub(super) fn apply_worker_clear_barrier(
        &mut self,
        clear_dp_rank: DpRank,
        clear_event_id: u64,
    ) {
        self.epoch += 1;
        for rank_state in self.ranks.values_mut() {
            let post_clear_max_seen = rank_state
                .max_seen_live_id
                .filter(|max_seen| *max_seen > clear_event_id);
            rank_state.cursor = rank_state.cursor.invalidate_by_barrier();
            rank_state.max_seen_live_id = post_clear_max_seen;
            rank_state.recovery_inflight = false;
            rank_state.clear_pending_live_events();
        }

        let rank_state = self.ranks.entry(clear_dp_rank).or_default();
        rank_state.cursor = rank_state.cursor.apply_barrier(clear_event_id);
    }

    pub(super) fn rank_cursor(&self, dp_rank: DpRank) -> Option<CursorState> {
        self.ranks.get(&dp_rank).map(|rank_state| rank_state.cursor)
    }

    pub(super) fn begin_successful_recovery_drain(&mut self, dp_rank: DpRank, cursor: CursorState) {
        let rank_state = self
            .ranks
            .get_mut(&dp_rank)
            .expect("rank state should exist while finishing recovery");
        rank_state.cursor = cursor;
        rank_state.recovery_inflight = true;
    }

    pub(super) fn next_pending_drain_action(&mut self, dp_rank: DpRank) -> PendingDrainAction {
        let Some(rank_state) = self.ranks.get_mut(&dp_rank) else {
            return PendingDrainAction::Complete;
        };
        rank_state.next_pending_drain_action()
    }

    pub(super) fn finish_failed_recovery(&mut self, dp_rank: DpRank) {
        let Some(rank_state) = self.ranks.get_mut(&dp_rank) else {
            return;
        };
        rank_state.recovery_inflight = false;
        rank_state.clear_pending_live_events();
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use dynamo_kv_router::protocols::{
        ExternalSequenceBlockHash, KvCacheEvent, KvCacheStoreData, KvCacheStoredBlockData,
        LocalBlockHash,
    };

    fn make_store_event(event_id: u64) -> RouterEvent {
        RouterEvent::new(
            1,
            KvCacheEvent {
                event_id,
                data: KvCacheEventData::Stored(KvCacheStoreData {
                    parent_hash: None,
                    start_position: None,
                    blocks: vec![KvCacheStoredBlockData {
                        block_hash: ExternalSequenceBlockHash(event_id),
                        tokens_hash: LocalBlockHash(event_id),
                        mm_extra_info: None,
                    }],
                }),
                dp_rank: 0,
            },
        )
    }

    fn pending_ids(rank_state: &RankState) -> Vec<u64> {
        rank_state
            .pending_live_events
            .iter()
            .map(|event| event.event.event_id)
            .collect()
    }

    #[test]
    fn test_pending_fast_prune_keeps_margin_before_normal_drain() {
        let mut rank_state = RankState {
            cursor: CursorState::Live(35),
            recovery_inflight: true,
            ..Default::default()
        };
        for event_id in 1..=40 {
            rank_state.push_pending_live_event(make_store_event(event_id));
        }

        match rank_state.next_pending_drain_action() {
            PendingDrainAction::Apply(event) => assert_eq!(event.event.event_id, 36),
            _ => panic!("expected first non-stale contiguous pending event to apply"),
        }
        assert_eq!(pending_ids(&rank_state), vec![37, 38, 39, 40]);
    }

    #[test]
    fn test_pending_stale_gappy_queue_after_recovery_completes_without_follow_up() {
        let mut rank_state = RankState {
            cursor: CursorState::Live(20),
            recovery_inflight: true,
            ..Default::default()
        };
        for event_id in [13, 14, 18, 19] {
            rank_state.push_pending_live_event(make_store_event(event_id));
        }

        match rank_state.next_pending_drain_action() {
            PendingDrainAction::Complete => {}
            _ => panic!("expected stale pending queue to complete without follow-up"),
        }
        assert!(rank_state.pending_live_events.is_empty());
        assert!(!rank_state.recovery_inflight);
    }
}