dynamo_llm/kv_router/
scoring.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Scoring functions for the KV router.
5
6use super::protocols::{ForwardPassMetrics, LoadMetrics};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
11pub struct LoadEvent {
12    pub worker_id: i64,
13    pub data: ForwardPassMetrics,
14}
15
16/// [gluo FIXME] exactly the same as EndpointInfo except that 'data'
17/// is cleaned (not optional)
18#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
19pub struct Endpoint {
20    pub name: String,
21    pub subject: String,
22    pub data: LoadMetrics,
23}
24
25impl Endpoint {
26    pub fn worker_id(&self) -> i64 {
27        i64::from_str_radix(
28            self.subject
29                .split("-")
30                .last()
31                .expect("invalid subject")
32                .to_string()
33                .as_str(),
34            16,
35        )
36        .expect("invalid worker id")
37    }
38}
39
40#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq)]
41pub struct ProcessedEndpoints {
42    pub endpoints: HashMap<i64, Endpoint>,
43    pub load_avg: f64,
44    pub load_std: f64,
45}
46
47impl ProcessedEndpoints {
48    pub fn new(endpoints: Vec<Endpoint>) -> Self {
49        // compute some basic statistics
50        let load_values: Vec<f64> = endpoints
51            .iter()
52            .map(|endpoint| endpoint.data.kv_active_blocks() as f64)
53            .collect();
54        let load_avg = load_values.iter().copied().sum::<f64>() / load_values.len() as f64;
55        let variance = load_values
56            .iter()
57            .map(|&x| (x - load_avg).powi(2))
58            .sum::<f64>()
59            / load_values.len() as f64;
60        let load_std = variance.sqrt();
61
62        let endpoints = endpoints.into_iter().map(|e| (e.worker_id(), e)).collect();
63
64        ProcessedEndpoints {
65            endpoints,
66            load_avg,
67            load_std,
68        }
69    }
70
71    pub fn worker_ids(&self) -> Vec<i64> {
72        self.endpoints.keys().copied().collect()
73    }
74
75    pub fn active_blocks(&self) -> HashMap<i64, usize> {
76        self.endpoints
77            .iter()
78            .map(|(&worker_id, endpoint)| (worker_id, endpoint.data.kv_active_blocks() as usize))
79            .collect()
80    }
81}