use std::sync::atomic::{AtomicBool, Ordering};
use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::transports::event_plane::EventSubscriber;
use crate::agents::trace::{
AgentRequestMetrics, AgentToolEventRelay, AgentTracePolicy, AgentTraceRecord,
DEFAULT_TOOL_EVENTS_TOPIC, WorkerInfo,
};
use crate::local_model::LocalModel;
use crate::protocols::common::timing::RequestTracker;
pub(crate) fn record_llm_metric_tokens(
tracker: Option<&RequestTracker>,
input_tokens: Option<usize>,
output_tokens: usize,
cached_tokens: Option<usize>,
) {
let Some(tracker) = tracker else {
return;
};
if input_tokens.is_some() || cached_tokens.is_some() {
tracker.record_isl(input_tokens.unwrap_or(0), cached_tokens);
}
tracker.record_osl(output_tokens);
}
static TOOL_EVENT_INGEST_STARTED: AtomicBool = AtomicBool::new(false);
static TOOL_EVENT_RELAY_STARTED: AtomicBool = AtomicBool::new(false);
pub(crate) fn request_metrics(
request_id: String,
x_request_id: Option<String>,
model: String,
tracker: Option<&RequestTracker>,
) -> AgentRequestMetrics {
let timing = tracker.map(RequestTracker::get_timing_info);
let worker = tracker.and_then(|tracker| {
tracker.get_worker_info().map(|worker| WorkerInfo {
prefill_worker_id: worker.prefill_worker_id,
prefill_dp_rank: worker.prefill_dp_rank,
decode_worker_id: worker.decode_worker_id,
decode_dp_rank: worker.decode_dp_rank,
})
});
AgentRequestMetrics {
request_id,
x_request_id,
model,
input_tokens: tracker.and_then(|tracker| tracker.isl_tokens().map(|v| v as u64)),
output_tokens: tracker.map(RequestTracker::osl_tokens),
cached_tokens: tracker.and_then(|tracker| tracker.cached_tokens().map(|v| v as u64)),
request_received_ms: timing.as_ref().map(|timing| timing.request_received_ms),
prefill_wait_time_ms: timing
.as_ref()
.and_then(|timing| timing.prefill_wait_time_ms),
prefill_time_ms: timing.as_ref().and_then(|timing| timing.prefill_time_ms),
ttft_ms: timing.as_ref().and_then(|timing| timing.ttft_ms),
total_time_ms: timing.as_ref().and_then(|timing| timing.total_time_ms),
avg_itl_ms: tracker.and_then(RequestTracker::avg_itl_ms),
kv_hit_rate: timing.as_ref().and_then(|timing| timing.kv_hit_rate),
kv_transfer_estimated_latency_ms: timing
.as_ref()
.and_then(|timing| timing.kv_transfer_estimated_latency_ms),
queue_depth: timing
.as_ref()
.and_then(|timing| timing.router_queue_depth.map(|v| v as u64)),
worker,
replay: None,
}
}
pub(crate) async fn start_tool_event_ingest_from_policy(
drt: DistributedRuntime,
local_model: &LocalModel,
) -> anyhow::Result<()> {
let policy = super::policy();
if policy.tool_events_zmq_endpoint.is_none() {
return Ok(());
}
start_tool_event_relay_from_policy(drt.clone(), local_model, policy).await?;
if TOOL_EVENT_INGEST_STARTED
.compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
.is_err()
{
tracing::debug!("agent tool event ingest already started");
return Ok(());
}
let namespace_name = tool_events_namespace(local_model);
let mut subscriber = match async {
let namespace = drt.namespace(namespace_name.clone())?;
EventSubscriber::for_namespace(&namespace, DEFAULT_TOOL_EVENTS_TOPIC)
.await
.map(|sub| sub.typed::<AgentTraceRecord>())
}
.await
{
Ok(subscriber) => subscriber,
Err(error) => {
TOOL_EVENT_INGEST_STARTED.store(false, Ordering::Release);
return Err(error);
}
};
let shutdown = drt.child_token();
drt.runtime().secondary().spawn(async move {
tracing::info!(
namespace = %namespace_name,
topic = DEFAULT_TOOL_EVENTS_TOPIC,
"Agent tool event ingest started"
);
loop {
tokio::select! {
_ = shutdown.cancelled() => {
tracing::debug!("agent tool event ingest stopping");
break;
}
next = subscriber.next() => {
match next {
Some(Ok((_envelope, record))) => {
super::publish_tool_record(record);
}
Some(Err(error)) => {
tracing::warn!(%error, "agent tool event ingest failed to decode event");
}
None => {
tracing::warn!("agent tool event ingest stream ended");
break;
}
}
}
}
}
TOOL_EVENT_INGEST_STARTED.store(false, Ordering::Release);
tracing::info!("agent tool event ingest stopped");
});
Ok(())
}
async fn start_tool_event_relay_from_policy(
drt: DistributedRuntime,
local_model: &LocalModel,
policy: &AgentTracePolicy,
) -> anyhow::Result<()> {
let Some(zmq_endpoint) = policy.tool_events_zmq_endpoint.clone() else {
return Ok(());
};
if TOOL_EVENT_RELAY_STARTED
.compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
.is_err()
{
tracing::debug!("agent tool event relay already started");
return Ok(());
}
let namespace_name = tool_events_namespace(local_model);
let relay = match async {
let namespace = drt.namespace(namespace_name.clone())?;
let component = namespace.component(local_model.endpoint_id().component.clone())?;
AgentToolEventRelay::start(
component,
zmq_endpoint.clone(),
policy.tool_events_zmq_topic.clone(),
Some(namespace_name.clone()),
Some(DEFAULT_TOOL_EVENTS_TOPIC.to_string()),
)
.await
}
.await
{
Ok(relay) => relay,
Err(error) => {
TOOL_EVENT_RELAY_STARTED.store(false, Ordering::Release);
return Err(error);
}
};
let shutdown = drt.child_token();
drt.runtime().secondary().spawn(async move {
tracing::info!(
namespace = %namespace_name,
topic = DEFAULT_TOOL_EVENTS_TOPIC,
zmq_endpoint = %zmq_endpoint,
"Agent tool event relay started"
);
shutdown.cancelled().await;
relay.shutdown();
TOOL_EVENT_RELAY_STARTED.store(false, Ordering::Release);
tracing::info!("agent tool event relay stopped");
});
Ok(())
}
fn tool_events_namespace(local_model: &LocalModel) -> String {
local_model
.namespace()
.map(str::to_string)
.unwrap_or_else(|| local_model.endpoint_id().namespace.clone())
}
#[cfg(test)]
mod tests {
use std::{thread, time::Duration};
use crate::protocols::common::timing::{RequestTracker, WORKER_TYPE_DECODE};
use super::request_metrics;
#[test]
fn test_request_metrics_from_tracker() {
let tracker = RequestTracker::new();
tracker.record_isl(128, Some(32));
tracker.record_kv_hit(4.0, 8);
tracker.record_osl(5);
tracker.record_router_queue_depth(3);
tracker.record_worker(17, Some(2), WORKER_TYPE_DECODE);
tracker.record_prefill_start();
thread::sleep(Duration::from_millis(1));
tracker.record_first_token();
tracker.record_prefill_complete();
thread::sleep(Duration::from_millis(1));
tracker.record_decode_first_token();
thread::sleep(Duration::from_millis(1));
tracker.record_finish();
let metrics = request_metrics(
"req-1".to_string(),
Some("llm-call-1".to_string()),
"test-model".to_string(),
Some(&tracker),
);
assert_eq!(metrics.request_id, "req-1");
assert_eq!(metrics.x_request_id.as_deref(), Some("llm-call-1"));
assert_eq!(metrics.model, "test-model");
assert_eq!(metrics.input_tokens, Some(128));
assert_eq!(metrics.output_tokens, Some(5));
assert_eq!(metrics.cached_tokens, Some(32));
assert!(metrics.request_received_ms.is_some_and(|ms| ms > 0));
assert!(metrics.prefill_wait_time_ms.is_some());
assert!(metrics.prefill_time_ms.is_some());
assert!(metrics.ttft_ms.is_some());
assert!(metrics.total_time_ms.is_some());
assert!(metrics.avg_itl_ms.is_some());
assert_eq!(metrics.kv_hit_rate, Some(0.5));
assert!(metrics.kv_transfer_estimated_latency_ms.is_some());
assert_eq!(metrics.queue_depth, Some(3));
let worker = metrics.worker.expect("worker info should be set");
assert_eq!(worker.prefill_worker_id, Some(17));
assert_eq!(worker.prefill_dp_rank, Some(2));
assert_eq!(worker.decode_worker_id, Some(17));
assert_eq!(worker.decode_dp_rank, Some(2));
}
#[test]
fn test_request_metrics_without_tracker_is_partial() {
let metrics = request_metrics(
"req-1".to_string(),
Some("llm-call-1".to_string()),
"test-model".to_string(),
None,
);
assert_eq!(metrics.request_id, "req-1");
assert_eq!(metrics.x_request_id.as_deref(), Some("llm-call-1"));
assert_eq!(metrics.model, "test-model");
assert_eq!(metrics.input_tokens, None);
assert_eq!(metrics.output_tokens, None);
assert_eq!(metrics.cached_tokens, None);
assert_eq!(metrics.request_received_ms, None);
assert_eq!(metrics.prefill_wait_time_ms, None);
assert_eq!(metrics.prefill_time_ms, None);
assert_eq!(metrics.ttft_ms, None);
assert_eq!(metrics.total_time_ms, None);
assert_eq!(metrics.avg_itl_ms, None);
assert_eq!(metrics.kv_hit_rate, None);
assert_eq!(metrics.kv_transfer_estimated_latency_ms, None);
assert_eq!(metrics.queue_depth, None);
assert!(metrics.worker.is_none());
}
}