1use anyhow::Result;
17use dynamo_runtime::{
18 component::Component,
19 pipeline::{
20 async_trait, AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, ResponseStream,
21 SingleIn,
22 },
23 prelude::*,
24 protocols::annotated::Annotated,
25};
26use futures::stream::{self, StreamExt};
27use std::sync::Arc;
28
29pub mod indexer;
30pub mod metrics_aggregator;
31pub mod protocols;
32pub mod publisher;
33pub mod recorder;
34pub mod scheduler;
35pub mod scoring;
36
37use crate::{
38 kv_router::{
39 indexer::{KvIndexer, KvIndexerInterface, RouterEvent},
40 metrics_aggregator::KvMetricsAggregator,
41 protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult},
42 scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest},
43 scoring::ProcessedEndpoints,
44 },
45 tokens::Tokens,
46};
47
48use dynamo_runtime::traits::events::EventSubscriber;
49
/// Event subject that [`KvRouter::new`] subscribes to on its component in order
/// to receive serialized router events.
pub const KV_EVENT_SUBJECT: &str = "kv_events";
/// Subject name for KV hit-rate messages.
///
/// NOTE(review): not referenced in this part of the file; presumably consumed by
/// the scheduler/publisher modules — confirm before relying on it.
pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate";
/// Endpoint name for load metrics.
///
/// NOTE(review): not referenced in this part of the file; presumably the endpoint
/// scraped by the metrics aggregator — confirm against `metrics_aggregator`.
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";
55
/// Pluggable policy for choosing which worker should serve a request.
///
/// An implementation can be handed to [`KvRouter::new`] (boxed, `Send + Sync`),
/// which forwards it to the scheduler when the scheduler is started.
pub trait WorkerSelector {
    /// Chooses a worker for `request` out of `workers`.
    ///
    /// # Parameters
    /// - `workers`: the processed endpoint set to choose from.
    /// - `request`: the scheduling request being placed.
    /// - `block_size`: presumably the number of tokens per KV block, matching the
    ///   router's block size — TODO confirm against the scheduler's call site.
    ///
    /// # Errors
    /// Returns a [`KvSchedulerError`] when no selection can be produced; the exact
    /// conditions are implementation-defined.
    fn select_worker(
        &self,
        workers: &ProcessedEndpoints,
        request: &SchedulingRequest,
        block_size: usize,
    ) -> Result<WorkerSelectionResult, KvSchedulerError>;
}
65
/// Routes requests to workers by combining an indexer lookup with a scheduler
/// decision.
///
/// Instances are created with [`KvRouter::new`], which returns the router wrapped
/// in an `Arc`.
pub struct KvRouter {
    /// Queried for overlap scores for a request; also receives the router events
    /// forwarded by the background task spawned in [`KvRouter::new`].
    indexer: KvIndexer,
    /// Turns overlap scores plus the request's token count into a worker id.
    scheduler: KvScheduler,
    /// Block size passed to the indexer and scheduler at construction and used
    /// when hashing request tokens into blocks.
    block_size: usize,
}
71
72impl KvRouter {
73 pub async fn new(
74 component: Component,
75 block_size: usize,
76 selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
77 ) -> Result<Arc<Self>> {
78 let cancellation_token = component
79 .drt()
80 .primary_lease()
81 .expect("Cannot KV route static workers")
82 .primary_token();
83
84 let metrics_aggregator =
85 KvMetricsAggregator::new(component.clone(), cancellation_token.clone()).await;
86 let indexer = KvIndexer::new(cancellation_token.clone(), block_size);
87 let scheduler = KvScheduler::start(
88 component.namespace().clone(),
89 block_size,
90 metrics_aggregator.endpoints_watcher(),
91 selector,
92 )
93 .await?;
94
95 let mut kv_events_rx = component.subscribe(KV_EVENT_SUBJECT).await?;
98 let kv_events_tx = indexer.event_sender();
99
100 tokio::spawn(async move {
101 while let Some(event) = kv_events_rx.next().await {
102 let event: RouterEvent = match serde_json::from_slice(&event.payload) {
103 Ok(event) => {
104 tracing::debug!("received kv event: {:?}", event);
105 event
106 }
107 Err(e) => {
108 tracing::warn!("Failed to deserialize RouterEvent: {:?}", e);
109 continue;
112 }
113 };
114 if let Err(e) = kv_events_tx.send(event).await {
115 tracing::trace!("failed to send kv event to indexer; shutting down: {:?}", e);
116 }
117 }
118 });
119
120 Ok(Arc::new(Self {
121 scheduler,
122 indexer,
123 block_size,
124 }))
125 }
126
127 pub async fn schedule(&self, token_ids: &Vec<u32>, _lora_id: u64) -> Result<i64> {
129 let isl_tokens = token_ids.len();
132 let overlap_scores = self
133 .indexer
134 .find_matches_for_request(token_ids.as_slice())
135 .await?;
136 tracing::debug!("KV router overlap_scores: {:?}", overlap_scores);
137 let worker_id = self.scheduler.schedule(overlap_scores, isl_tokens).await?;
138 Ok(worker_id)
139 }
140}
141
142#[async_trait]
143impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter {
144 async fn generate(
145 &self,
146 request: SingleIn<RouterRequest>,
147 ) -> Result<ManyOut<Annotated<RouterResponse>>> {
148 let (request, ctx) = request.into_parts();
149 let isl_tokens = request.tokens.len();
150 let block_size = self.block_size;
151
152 let local_block_hashes: Vec<LocalBlockHash> = tokio::task::spawn_blocking(move || {
154 Tokens::compute_block_hash(&request.tokens, block_size)
155 .into_iter()
156 .map(LocalBlockHash)
157 .collect()
158 })
159 .await?;
160
161 let overlap_scores = self.indexer.find_matches(local_block_hashes).await?;
162 let worker_id = self.scheduler.schedule(overlap_scores, isl_tokens).await?;
163
164 let response = RouterResponse { worker_id };
165 let response = Annotated::from_data(response);
166 let stream = stream::iter(vec![response]);
167 Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
168 }
169}