dynamo_llm/kv_router/
scoring.rs1use super::protocols::{ForwardPassMetrics, LoadMetrics};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
11pub struct LoadEvent {
12 pub worker_id: i64,
13 pub data: ForwardPassMetrics,
14}
15
16#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
19pub struct Endpoint {
20 pub name: String,
21 pub subject: String,
22 pub data: LoadMetrics,
23}
24
25impl Endpoint {
26 pub fn worker_id(&self) -> i64 {
27 i64::from_str_radix(
28 self.subject
29 .split("-")
30 .last()
31 .expect("invalid subject")
32 .to_string()
33 .as_str(),
34 16,
35 )
36 .expect("invalid worker id")
37 }
38}
39
40#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq)]
41pub struct ProcessedEndpoints {
42 pub endpoints: HashMap<i64, Endpoint>,
43 pub load_avg: f64,
44 pub load_std: f64,
45}
46
47impl ProcessedEndpoints {
48 pub fn new(endpoints: Vec<Endpoint>) -> Self {
49 let load_values: Vec<f64> = endpoints
51 .iter()
52 .map(|endpoint| endpoint.data.kv_active_blocks() as f64)
53 .collect();
54 let load_avg = load_values.iter().copied().sum::<f64>() / load_values.len() as f64;
55 let variance = load_values
56 .iter()
57 .map(|&x| (x - load_avg).powi(2))
58 .sum::<f64>()
59 / load_values.len() as f64;
60 let load_std = variance.sqrt();
61
62 let endpoints = endpoints.into_iter().map(|e| (e.worker_id(), e)).collect();
63
64 ProcessedEndpoints {
65 endpoints,
66 load_avg,
67 load_std,
68 }
69 }
70
71 pub fn worker_ids(&self) -> Vec<i64> {
72 self.endpoints.keys().copied().collect()
73 }
74
75 pub fn active_blocks(&self) -> HashMap<i64, usize> {
76 self.endpoints
77 .iter()
78 .map(|(&worker_id, endpoint)| (worker_id, endpoint.data.kv_active_blocks() as usize))
79 .collect()
80 }
81}