nemo_flow_adaptive/learner/
latency.rs1use std::pin::Pin;
7use std::sync::{Arc, RwLock};
8
9use crate::error::{AdaptiveError, Result};
10use crate::learner::traits::Learner;
11use crate::storage::traits::StorageBackendDyn;
12use crate::trie::builder::{PredictionTrieBuilder, SensitivityConfig};
13use crate::trie::data_models::PredictionTrieNode;
14use crate::trie::serialization::TrieEnvelope;
15use crate::types::cache::HotCache;
16use crate::types::metadata::AgentHints;
17use crate::types::records::RunRecord;
18
19pub struct LatencySensitivityLearner {
21 config: SensitivityConfig,
22 agent_id: String,
23}
24
25impl LatencySensitivityLearner {
26 pub fn new(agent_id: impl Into<String>, config: SensitivityConfig) -> Self {
35 Self {
36 config,
37 agent_id: agent_id.into(),
38 }
39 }
40}
41
42impl Learner for LatencySensitivityLearner {
43 fn process_run<'a>(
44 &'a self,
45 run: &'a RunRecord,
46 backend: &'a dyn StorageBackendDyn,
47 hot_cache: &'a Arc<RwLock<HotCache>>,
48 ) -> Pin<Box<dyn std::future::Future<Output = Result<()>> + Send + 'a>> {
49 Box::pin(async move {
50 let existing = backend.load_accumulators(&self.agent_id).await?;
51 let mut builder = PredictionTrieBuilder::with_accumulators(
52 existing.unwrap_or_default(),
53 Some(self.config.clone()),
54 );
55 builder.add_run(run);
56 let trie_root = builder.build();
57
58 backend
59 .store_accumulators(&self.agent_id, builder.accumulators())
60 .await?;
61 let envelope = TrieEnvelope::new(trie_root.clone(), &self.agent_id);
62 backend.store_trie(&self.agent_id, &envelope).await?;
63
64 {
65 let mut guard = hot_cache.write().map_err(|error| {
66 AdaptiveError::Internal(format!("hot cache lock poisoned: {error}"))
67 })?;
68 guard.agent_hints_default =
69 compute_default_hints(&trie_root, self.config.sensitivity_scale);
70 guard.trie = Some(trie_root);
71 }
72
73 Ok(())
74 })
75 }
76}
77
78pub fn compute_default_hints(
88 trie_root: &PredictionTrieNode,
89 sensitivity_scale: u32,
90) -> Option<AgentHints> {
91 let prediction = trie_root.predictions_any_index.as_ref()?;
92
93 let latency_sensitivity = prediction.latency_sensitivity.unwrap_or(1);
94 let priority = (sensitivity_scale as i32 - latency_sensitivity as i32).max(0);
95
96 Some(AgentHints {
97 osl: prediction.output_tokens.p90.round() as u32,
98 iat: prediction.interarrival_ms.mean.round() as u32,
99 priority,
100 latency_sensitivity: if prediction.latency_sensitivity.is_some() {
101 latency_sensitivity as f64
102 } else {
103 0.0
104 },
105 prefix_id: "default".to_string(),
106 total_requests: prediction.remaining_calls.mean.round() as u32 + 1,
107 })
108}