Skip to main content

nemo_flow_adaptive/learner/
latency.rs

1// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Latency-sensitivity learner implementation.
5
6use std::pin::Pin;
7use std::sync::{Arc, RwLock};
8
9use crate::error::{AdaptiveError, Result};
10use crate::learner::traits::Learner;
11use crate::storage::traits::StorageBackendDyn;
12use crate::trie::builder::{PredictionTrieBuilder, SensitivityConfig};
13use crate::trie::data_models::PredictionTrieNode;
14use crate::trie::serialization::TrieEnvelope;
15use crate::types::cache::HotCache;
16use crate::types::metadata::AgentHints;
17use crate::types::records::RunRecord;
18
19/// Learner that derives default latency sensitivity hints from run history.
20pub struct LatencySensitivityLearner {
21    config: SensitivityConfig,
22    agent_id: String,
23}
24
25impl LatencySensitivityLearner {
26    /// Create a new latency-sensitivity learner.
27    ///
28    /// # Parameters
29    /// - `agent_id`: Agent identifier whose trie state should be updated.
30    /// - `config`: Sensitivity-derivation configuration for the trie builder.
31    ///
32    /// # Returns
33    /// A configured [`LatencySensitivityLearner`].
34    pub fn new(agent_id: impl Into<String>, config: SensitivityConfig) -> Self {
35        Self {
36            config,
37            agent_id: agent_id.into(),
38        }
39    }
40}
41
42impl Learner for LatencySensitivityLearner {
43    fn process_run<'a>(
44        &'a self,
45        run: &'a RunRecord,
46        backend: &'a dyn StorageBackendDyn,
47        hot_cache: &'a Arc<RwLock<HotCache>>,
48    ) -> Pin<Box<dyn std::future::Future<Output = Result<()>> + Send + 'a>> {
49        Box::pin(async move {
50            let existing = backend.load_accumulators(&self.agent_id).await?;
51            let mut builder = PredictionTrieBuilder::with_accumulators(
52                existing.unwrap_or_default(),
53                Some(self.config.clone()),
54            );
55            builder.add_run(run);
56            let trie_root = builder.build();
57
58            backend
59                .store_accumulators(&self.agent_id, builder.accumulators())
60                .await?;
61            let envelope = TrieEnvelope::new(trie_root.clone(), &self.agent_id);
62            backend.store_trie(&self.agent_id, &envelope).await?;
63
64            {
65                let mut guard = hot_cache.write().map_err(|error| {
66                    AdaptiveError::Internal(format!("hot cache lock poisoned: {error}"))
67                })?;
68                guard.agent_hints_default =
69                    compute_default_hints(&trie_root, self.config.sensitivity_scale);
70                guard.trie = Some(trie_root);
71            }
72
73            Ok(())
74        })
75    }
76}
77
78/// Compute default agent hints from the root trie prediction.
79///
80/// # Parameters
81/// - `trie_root`: Root node of the learned prediction trie.
82/// - `sensitivity_scale`: Scheduling scale used to derive the priority hint.
83///
84/// # Returns
85/// `Some(AgentHints)` when the trie contains an any-index prediction at the
86/// root and `None` otherwise.
87pub fn compute_default_hints(
88    trie_root: &PredictionTrieNode,
89    sensitivity_scale: u32,
90) -> Option<AgentHints> {
91    let prediction = trie_root.predictions_any_index.as_ref()?;
92
93    let latency_sensitivity = prediction.latency_sensitivity.unwrap_or(1);
94    let priority = (sensitivity_scale as i32 - latency_sensitivity as i32).max(0);
95
96    Some(AgentHints {
97        osl: prediction.output_tokens.p90.round() as u32,
98        iat: prediction.interarrival_ms.mean.round() as u32,
99        priority,
100        latency_sensitivity: if prediction.latency_sensitivity.is_some() {
101            latency_sensitivity as f64
102        } else {
103            0.0
104        },
105        prefix_id: "default".to_string(),
106        total_requests: prediction.remaining_calls.mean.round() as u32 + 1,
107    })
108}