dynamo_llm/kv_router/
metrics_aggregator.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::sync::Once;
5
6use crate::kv_router::KV_METRICS_ENDPOINT;
7pub use crate::kv_router::protocols::{ForwardPassMetrics, LoadMetrics, PredictiveLoadMetrics};
8
9use crate::kv_router::ProcessedEndpoints;
10use crate::kv_router::scoring::Endpoint;
11use dynamo_runtime::component::Component;
12use dynamo_runtime::{Result, service::EndpointInfo, utils::Duration};
13use tokio::sync::watch;
14use tokio_util::sync::CancellationToken;
15
/// Guards the "Waiting for metrics endpoint.." debug log in `collect_endpoints` so it is
/// emitted at most once. The guard is process-wide: it is shared by every caller of
/// `collect_endpoints`, regardless of which component or subject is being scraped.
static METRICS_WAITING_MESSAGE: Once = Once::new();
/// Guards the "Found metrics endpoint" debug log in `collect_endpoints` so it is emitted at
/// most once. Process-wide, like `METRICS_WAITING_MESSAGE`.
static METRICS_FOUND_MESSAGE: Once = Once::new();
18
19pub struct EndpointCollector {
20    pub service_name: String,
21    pub endpoints_rx: watch::Receiver<ProcessedEndpoints>,
22}
23
24impl EndpointCollector {
25    pub async fn new(component: Component, cancellation_token: CancellationToken) -> Self {
26        let (watch_tx, watch_rx) = watch::channel(ProcessedEndpoints::default());
27
28        tokio::spawn(collect_endpoints_task(
29            component.clone(),
30            watch_tx,
31            cancellation_token.clone(),
32            "generate".to_string(),
33        ));
34
35        Self {
36            service_name: component.service_name(),
37            endpoints_rx: watch_rx,
38        }
39    }
40
41    pub fn get_endpoints(&self) -> ProcessedEndpoints {
42        self.endpoints_rx.borrow().clone()
43    }
44
45    pub fn endpoints_watcher(&self) -> watch::Receiver<ProcessedEndpoints> {
46        self.endpoints_rx.clone()
47    }
48}
49
50pub struct KvMetricsAggregator {
51    pub service_name: String,
52    pub endpoints_rx: watch::Receiver<ProcessedEndpoints>,
53}
54
55impl KvMetricsAggregator {
56    pub async fn new(component: Component, cancellation_token: CancellationToken) -> Self {
57        let (watch_tx, watch_rx) = watch::channel(ProcessedEndpoints::default());
58
59        tokio::spawn(collect_endpoints_task(
60            component.clone(),
61            watch_tx,
62            cancellation_token.clone(),
63            KV_METRICS_ENDPOINT.to_string(),
64        ));
65
66        Self {
67            service_name: component.service_name(),
68            endpoints_rx: watch_rx,
69        }
70    }
71
72    pub fn get_endpoints(&self) -> ProcessedEndpoints {
73        self.endpoints_rx.borrow().clone()
74    }
75
76    pub fn endpoints_watcher(&self) -> watch::Receiver<ProcessedEndpoints> {
77        self.endpoints_rx.clone()
78    }
79}
80
81/// [gluo TODO] 'collect_endpoints' is from component/metrics,
82/// should consolidate these functions into generic metrics aggregator
83/// functions and shared by KvMetricsAggregator and component/metrics.
84/// Collect endpoints from a component
85pub async fn collect_endpoints(
86    component: &Component,
87    subject: &str,
88    timeout: Duration,
89) -> Result<Vec<EndpointInfo>> {
90    // Collect stats from each backend
91    let stream = component.scrape_stats(timeout).await?;
92
93    // Filter the stats by the service subject
94    let endpoints = stream
95        .into_endpoints()
96        .filter(|e| e.subject.starts_with(subject))
97        .collect::<Vec<_>>();
98    if endpoints.is_empty() {
99        // Only print it once, we poll while the worker starts
100        METRICS_WAITING_MESSAGE.call_once(|| {
101            tracing::debug!("Waiting for metrics endpoint..");
102        });
103    } else {
104        METRICS_FOUND_MESSAGE.call_once(|| {
105            tracing::debug!("Found metrics endpoint");
106        });
107    }
108
109    Ok(endpoints)
110}
111
112pub async fn collect_endpoints_task(
113    component: Component,
114    watch_tx: watch::Sender<ProcessedEndpoints>,
115    cancel: CancellationToken,
116    subject: String,
117) {
118    let backoff_delay = Duration::from_millis(100);
119    let scrape_timeout = Duration::from_millis(300);
120    let endpoint = component.endpoint(&subject);
121    let service_subject = endpoint.subject();
122
123    // Keep track of the last sent value to avoid unnecessary updates
124    let mut last_sent: Option<ProcessedEndpoints> = None;
125
126    loop {
127        tokio::select! {
128            _ = cancel.cancelled() => {
129                break;
130            }
131            _ = tokio::time::sleep(backoff_delay) => {
132                tracing::trace!("collecting endpoints for service: {}", service_subject);
133                let unfiltered_endpoints =
134                    match collect_endpoints(&component, &service_subject, scrape_timeout).await
135                    {
136                        Ok(v) => v,
137                        Err(e) => {
138                            tracing::warn!("Failed to retrieve endpoints for {}: {:?}", service_subject, e);
139                            continue;
140                        }
141                    };
142
143                let endpoints: Vec<Endpoint> = if subject == KV_METRICS_ENDPOINT {
144                    // Original filtering behavior
145                    unfiltered_endpoints
146                        .into_iter()
147                        .filter_map(|s| {
148                            s.data?
149                                .decode::<ForwardPassMetrics>()
150                                .map(|data| Endpoint {
151                                    name: s.name,
152                                    subject: s.subject,
153                                    data: LoadMetrics::EngineLoadMetrics(data),
154                                })
155                                .inspect_err(|e| {
156                                    tracing::warn!("skip endpoint data that can't be parsed as ForwardPassMetrics: {:?}", e);
157                                })
158                                .ok()
159                        })
160                        .collect()
161                } else {
162                    // No filtering - just use default LoadMetrics
163                    unfiltered_endpoints
164                        .into_iter()
165                        .map(|s| Endpoint {
166                            name: s.name,
167                            subject: s.subject,
168                            data: LoadMetrics::default(),
169                        })
170                        .collect()
171                };
172
173                tracing::trace!("Found {} endpoints for service: {service_subject}", endpoints.len());
174
175                let processed = ProcessedEndpoints::new(endpoints);
176
177                // Only send if different from last sent value
178                // This is necessary because the watch channel does not track changes
179                // https://docs.rs/tokio/latest/tokio/sync/watch/struct.Receiver.html#method.has_changed
180                let should_send = match &last_sent {
181                    Some(last) => last != &processed,
182                    None => true,
183                };
184
185                if should_send {
186                    tracing::trace!("Endpoints changed, sending update for service: {service_subject}");
187                    if watch_tx.send(processed.clone()).is_err() {
188                        tracing::error!("failed to send processed endpoints; shutting down");
189                        break;
190                    }
191                    last_sent = Some(processed);
192                } else {
193                    tracing::trace!("Endpoints unchanged, skipping update for service: {service_subject}");
194                }
195            }
196        }
197    }
198}