use anyhow::Result;
use dynamo_runtime::{
component::Component,
pipeline::{
async_trait, AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, ResponseStream,
SingleIn,
},
prelude::*,
protocols::annotated::Annotated,
};
use futures::stream::{self, StreamExt};
use std::sync::Arc;
pub mod indexer;
pub mod metrics_aggregator;
pub mod protocols;
pub mod publisher;
pub mod recorder;
pub mod scheduler;
pub mod scoring;
use crate::{
kv_router::{
indexer::{KvIndexer, KvIndexerInterface, RouterEvent},
metrics_aggregator::KvMetricsAggregator,
protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult},
scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest},
scoring::ProcessedEndpoints,
},
tokens::TokenBlockSequence,
};
use dynamo_runtime::traits::events::EventSubscriber;
/// Event subject that `KvRouter::new` subscribes to on its component. Payloads received on
/// this subject are deserialized as `RouterEvent` and forwarded to the indexer.
pub const KV_EVENT_SUBJECT: &str = "kv_events";
/// Subject name for KV hit-rate messages.
/// NOTE(review): not referenced in the visible part of this module — presumably used by a
/// submodule or an external consumer; confirm before renaming.
pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate";
/// Endpoint name for load metrics.
/// NOTE(review): not referenced in the visible part of this module — presumably consumed by
/// the metrics aggregator or the workers; confirm before renaming.
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";
/// Pluggable strategy for choosing which worker should handle a scheduling request.
///
/// An optional boxed implementation (`dyn WorkerSelector + Send + Sync`) can be supplied to
/// `KvRouter::new`, which passes it on to `KvScheduler::start`.
pub trait WorkerSelector {
/// Chooses a worker for `request` among the candidates described by `workers`.
///
/// `block_size` is the block size the router was configured with.
/// NOTE(review): presumably expressed in tokens per KV block — confirm.
///
/// # Errors
///
/// Returns a `KvSchedulerError` when no selection can be made; the exact conditions are
/// up to the implementation.
fn select_worker(
&self,
workers: &ProcessedEndpoints,
request: &SchedulingRequest,
block_size: usize,
) -> Result<WorkerSelectionResult, KvSchedulerError>;
}
/// Router that combines a `KvIndexer` (queried for per-request overlap scores) with a
/// `KvScheduler` (which turns those scores and the request's token count into a worker id).
pub struct KvRouter {
/// Index queried for overlap scores; it also receives the `RouterEvent`s forwarded by the
/// subscription task spawned in `new`.
indexer: KvIndexer,
/// Picks a worker id from the overlap scores and the number of input tokens.
scheduler: KvScheduler,
/// Block size given at construction; handed to the indexer and scheduler and used when
/// splitting request tokens into blocks.
block_size: usize,
}
impl KvRouter {
    /// Builds a router for `component` and starts the background task that feeds the indexer.
    ///
    /// Construction performs the following steps:
    /// - creates a `KvMetricsAggregator` for the component; its endpoints watcher is handed to
    ///   the scheduler,
    /// - creates a `KvIndexer` configured with `block_size`,
    /// - starts a `KvScheduler` in the component's namespace, using `selector` if provided,
    /// - subscribes to `KV_EVENT_SUBJECT` and spawns a task that deserializes each payload as
    ///   a `RouterEvent` and forwards it to the indexer.
    ///
    /// # Errors
    ///
    /// Returns an error if the scheduler fails to start or the event subscription cannot be
    /// established.
    ///
    /// # Panics
    ///
    /// Panics with "Cannot KV route static workers" when the component's runtime has no
    /// primary lease.
    pub async fn new(
        component: Component,
        block_size: usize,
        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
    ) -> Result<Arc<Self>> {
        let cancellation_token = component
            .drt()
            .primary_lease()
            .expect("Cannot KV route static workers")
            .primary_token();

        let metrics_aggregator =
            KvMetricsAggregator::new(component.clone(), cancellation_token.clone()).await;
        let indexer = KvIndexer::new(cancellation_token.clone(), block_size);
        let scheduler = KvScheduler::start(
            component.namespace().clone(),
            block_size,
            metrics_aggregator.endpoints_watcher(),
            selector,
        )
        .await?;

        let mut kv_events_rx = component.subscribe(KV_EVENT_SUBJECT).await?;
        let kv_events_tx = indexer.event_sender();

        // Detached forwarding task: it ends when the subscription stream ends or when the
        // indexer stops accepting events.
        tokio::spawn(async move {
            while let Some(event) = kv_events_rx.next().await {
                let event: RouterEvent = match serde_json::from_slice(&event.payload) {
                    Ok(event) => {
                        tracing::debug!("received kv event: {:?}", event);
                        event
                    }
                    Err(e) => {
                        // A malformed payload is skipped so that later events are still
                        // processed.
                        tracing::warn!("Failed to deserialize RouterEvent: {:?}", e);
                        continue;
                    }
                };
                if let Err(e) = kv_events_tx.send(event).await {
                    tracing::trace!("failed to send kv event to indexer; shutting down: {:?}", e);
                    // The indexer no longer accepts events, so stop consuming the
                    // subscription instead of draining it into a dead channel.
                    break;
                }
            }
        });

        Ok(Arc::new(Self {
            scheduler,
            indexer,
            block_size,
        }))
    }

    /// Selects a worker for a request given its token ids.
    ///
    /// The indexer is asked for overlap scores for `token_ids`; those scores and the number
    /// of input tokens are then passed to the scheduler, whose chosen worker id is returned.
    /// `_lora_id` is currently unused.
    ///
    /// # Errors
    ///
    /// Propagates any error from the indexer lookup or from the scheduler.
    pub async fn schedule(&self, token_ids: &[u32], _lora_id: u64) -> Result<i64> {
        let isl_tokens = token_ids.len();
        let overlap_scores = self.indexer.find_matches_for_request(token_ids).await?;
        tracing::debug!("KV router overlap_scores: {:?}", overlap_scores);
        let worker_id = self.scheduler.schedule(overlap_scores, isl_tokens).await?;
        Ok(worker_id)
    }
}
#[async_trait]
impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter {
    /// Routes a single `RouterRequest` and answers with a one-item response stream.
    ///
    /// The request tokens are split into blocks of `self.block_size`; only the complete
    /// blocks are hashed and looked up in the indexer. The resulting overlap scores and the
    /// total token count are given to the scheduler, and the worker id it picks is wrapped
    /// in a `RouterResponse` that is emitted as the only element of the returned stream.
    ///
    /// # Errors
    ///
    /// Propagates any error from the indexer lookup or from the scheduler.
    async fn generate(
        &self,
        request: SingleIn<RouterRequest>,
    ) -> Result<ManyOut<Annotated<RouterResponse>>> {
        let (router_request, request_ctx) = request.into_parts();
        let input_len = router_request.tokens.len();

        // NOTE(review): 1337 is passed as the third argument of `split_tokens` — presumably
        // a hash salt/seed that has to match whatever produced the indexed hashes; confirm.
        let (full_blocks, _trailing_partial) =
            TokenBlockSequence::split_tokens(&router_request.tokens, self.block_size, 1337_u64);

        // The trailing partial block is not part of the lookup.
        let overlap = self
            .indexer
            .find_matches(
                full_blocks
                    .into_iter()
                    .map(|blk| LocalBlockHash(blk.block_hash()))
                    .collect(),
            )
            .await?;

        let chosen = self.scheduler.schedule(overlap, input_len).await?;
        let annotated = Annotated::from_data(RouterResponse { worker_id: chosen });
        let single_item = stream::iter(std::iter::once(annotated));

        Ok(ResponseStream::new(
            Box::pin(single_item),
            request_ctx.context(),
        ))
    }
}