use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use std::time::{Duration, Instant};
use dashmap::DashMap;
use dynamo_runtime::component::Component;
use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::transports::event_plane::EventSubscriber;
use crate::kv_router::ACTIVE_SEQUENCES_SUBJECT;
use crate::kv_router::protocols::{ActiveSequenceEvent, ActiveSequenceEventData};
use crate::kv_router::scheduler::KvScheduler;
/// A single observation of how many requests were active for one LoRA adapter.
#[derive(Clone, Debug)]
pub struct LoadSample {
    /// When the observation was taken.
    pub timestamp: Instant,
    /// Active-request count recorded at `timestamp`.
    pub active_count: usize,
}

/// Per-adapter state: the latest active count plus its recorded history.
#[derive(Clone, Debug, Default)]
struct LoraLoadData {
    /// Most recently recorded active-request count.
    active_count: usize,
    /// Recorded samples, oldest at the front.
    samples: VecDeque<LoadSample>,
}
/// Tuning knobs for the LoRA load estimator.
#[derive(Clone, Debug)]
pub struct LoadEstimatorConfig {
    /// How often the polling task snapshots per-LoRA counts from the scheduler.
    pub poll_interval: Duration,
    /// Maximum number of samples retained per LoRA adapter.
    pub max_samples: usize,
}

impl Default for LoadEstimatorConfig {
    /// Polls every five seconds and keeps at most 1000 samples per adapter.
    fn default() -> Self {
        let poll_interval = Duration::from_secs(5);
        let max_samples = 1000;
        LoadEstimatorConfig {
            poll_interval,
            max_samples,
        }
    }
}
/// Tracks per-LoRA active-request counts together with a capped time series of them.
///
/// The state lives in a concurrent map so the polling task and the event
/// subscription task can both update it through a shared `&self`.
pub struct LoadEstimator {
    /// Polling cadence and per-adapter history limit.
    config: LoadEstimatorConfig,
    /// LoRA name -> current count and sample history.
    data: DashMap<String, LoraLoadData>,
}
impl LoadEstimator {
pub fn new() -> Self {
Self::with_config(LoadEstimatorConfig::default())
}
pub fn with_config(config: LoadEstimatorConfig) -> Self {
Self {
data: DashMap::new(),
config,
}
}
pub fn start_polling(
self: Arc<Self>,
scheduler: Arc<KvScheduler>,
component: Component,
) -> tokio::task::JoinHandle<()> {
let cancel_token = component.drt().child_token();
tokio::spawn(async move {
let mut interval = tokio::time::interval(self.config.poll_interval);
tracing::info!("Started LORA load polling");
loop {
tokio::select! {
_ = cancel_token.cancelled() => {
tracing::debug!("LORA load polling task cancelled");
break;
}
_ = interval.tick() => {
let lora_counts = scheduler.get_active_lora_counts();
self.update_from_counts(lora_counts);
}
}
}
})
}
pub fn start_event_subscription(
self: Arc<Self>,
component: Component,
) -> tokio::task::JoinHandle<()> {
tokio::spawn(async move {
if let Err(e) = self.subscribe_to_events(component).await {
tracing::error!("Error in LORA load event subscription: {}", e);
}
})
}
async fn subscribe_to_events(&self, component: Component) -> anyhow::Result<()> {
let cancel_token = component.drt().child_token();
let mut subscriber = EventSubscriber::for_component(&component, ACTIVE_SEQUENCES_SUBJECT)
.await?
.typed::<ActiveSequenceEvent>();
tracing::info!("Started LORA load event subscription");
loop {
tokio::select! {
_ = cancel_token.cancelled() => {
tracing::debug!("LORA load event subscription cancelled");
break;
}
result = subscriber.next() => {
match result {
Some(Ok((_envelope, event))) => {
self.handle_event(event);
}
Some(Err(e)) => {
tracing::warn!("Error receiving LORA load event: {}", e);
}
None => {
tracing::warn!("LORA load event stream ended");
break;
}
}
}
}
}
Ok(())
}
fn handle_event(&self, event: ActiveSequenceEvent) {
if let Some(lora_name) = event.lora_name {
match event.data {
ActiveSequenceEventData::AddRequest { .. } => {
self.increment_load(&lora_name);
}
ActiveSequenceEventData::Free => {
self.decrement_load(&lora_name);
}
ActiveSequenceEventData::MarkPrefillCompleted => {
}
}
}
}
fn increment_load(&self, lora_name: &str) {
let now = Instant::now();
let max_samples = self.config.max_samples;
self.data
.entry(lora_name.to_string())
.and_modify(|data| {
data.active_count += 1;
data.samples.push_back(LoadSample {
timestamp: now,
active_count: data.active_count,
});
while data.samples.len() > max_samples {
data.samples.pop_front();
}
})
.or_insert_with(|| {
let mut data = LoraLoadData {
active_count: 1,
samples: VecDeque::new(),
};
data.samples.push_back(LoadSample {
timestamp: now,
active_count: 1,
});
data
});
}
fn decrement_load(&self, lora_name: &str) {
let now = Instant::now();
let max_samples = self.config.max_samples;
if let Some(mut entry) = self.data.get_mut(lora_name) {
let data = entry.value_mut();
data.active_count = data.active_count.saturating_sub(1);
data.samples.push_back(LoadSample {
timestamp: now,
active_count: data.active_count,
});
while data.samples.len() > max_samples {
data.samples.pop_front();
}
}
}
fn update_from_counts(&self, lora_counts: HashMap<String, usize>) {
let now = Instant::now();
let max_samples = self.config.max_samples;
for (lora_name, count) in &lora_counts {
self.data
.entry(lora_name.clone())
.and_modify(|data| {
data.active_count = *count;
data.samples.push_back(LoadSample {
timestamp: now,
active_count: *count,
});
while data.samples.len() > max_samples {
data.samples.pop_front();
}
})
.or_insert_with(|| {
let mut data = LoraLoadData {
active_count: *count,
samples: VecDeque::new(),
};
data.samples.push_back(LoadSample {
timestamp: now,
active_count: *count,
});
data
});
}
for mut entry in self.data.iter_mut() {
if !lora_counts.contains_key(entry.key()) {
let data = entry.value_mut();
if data.active_count > 0 {
data.active_count = 0;
data.samples.push_back(LoadSample {
timestamp: now,
active_count: 0,
});
while data.samples.len() > max_samples {
data.samples.pop_front();
}
}
}
}
}
pub fn get_current_load(&self) -> HashMap<String, usize> {
self.data
.iter()
.filter(|entry| entry.value().active_count > 0)
.map(|entry| (entry.key().clone(), entry.value().active_count))
.collect()
}
pub fn get_time_series(&self) -> HashMap<String, Vec<LoadSample>> {
self.data
.iter()
.map(|entry| {
(
entry.key().clone(),
entry.value().samples.iter().cloned().collect(),
)
})
.collect()
}
}
impl Default for LoadEstimator {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a LoRA-name -> count map from `(name, count)` pairs.
    fn counts_of(pairs: &[(&str, usize)]) -> HashMap<String, usize> {
        pairs
            .iter()
            .map(|&(name, count)| (name.to_string(), count))
            .collect()
    }

    /// Projects a sample history onto its `active_count` values, in order.
    fn active_counts(samples: &[LoadSample]) -> Vec<usize> {
        samples.iter().map(|sample| sample.active_count).collect()
    }

    #[test]
    fn test_load_estimator_time_series() {
        let estimator = LoadEstimator::new();
        estimator.update_from_counts(counts_of(&[("lora-math", 5), ("lora-code", 3)]));

        let series = estimator.get_time_series();
        assert_eq!(active_counts(&series["lora-math"]), vec![5]);
        assert_eq!(active_counts(&series["lora-code"]), vec![3]);
        assert!(!series.contains_key("lora-xyz"));
    }

    #[test]
    fn test_load_estimator_max_samples() {
        let estimator = LoadEstimator::with_config(LoadEstimatorConfig {
            max_samples: 2,
            ..Default::default()
        });
        (1..=3).for_each(|count| estimator.update_from_counts(counts_of(&[("lora-math", count)])));

        let series = estimator.get_time_series();
        // Only the two most recent samples survive the cap.
        assert_eq!(active_counts(&series["lora-math"]), vec![2, 3]);
    }

    #[test]
    fn test_increment_decrement_atomicity() {
        let estimator = LoadEstimator::new();
        let name = "lora-test";

        for _ in 0..2 {
            estimator.increment_load(name);
        }
        assert_eq!(estimator.get_current_load().get(name).copied(), Some(2));

        estimator.decrement_load(name);
        assert_eq!(estimator.get_current_load().get(name).copied(), Some(1));

        let series = estimator.get_time_series();
        assert_eq!(active_counts(&series[name]), vec![1, 2, 1]);
    }
}