Skip to main content

nemo_flow_adaptive/
acg_learner.rs

// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! Adaptive Cache Governor (ACG) learner for the adaptive telemetry pipeline.

6use std::collections::{HashMap, VecDeque};
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::{Arc, RwLock};
10
11use crate::acg::ir_builder::build_prompt_ir;
12use crate::acg::prompt_ir::PromptIR;
13use crate::acg::stability::{StabilityThresholds, analyze_stability};
14
15use crate::acg_profile::derive_acg_learning_key;
16use crate::error::{AdaptiveError, Result};
17use crate::learner::traits::Learner;
18use crate::storage::traits::StorageBackendDyn;
19use crate::types::cache::HotCache;
20use crate::types::records::{CallKind, RunRecord};
21
22/// Learner that derives prompt stability state for ACG.
23///
24/// This learner groups annotated LLM requests by derived ACG profile key,
25/// builds prompt IR observations, persists a bounded observation window, and
26/// updates the hot cache with the latest stability results.
27pub struct AcgLearner {
28    agent_id: String,
29    observation_window: usize,
30    thresholds: StabilityThresholds,
31}
32
33impl AcgLearner {
34    /// Create a new ACG learner.
35    ///
36    /// # Parameters
37    /// - `agent_id`: Agent identifier whose observations should be updated.
38    /// - `observation_window`: Maximum number of observations to retain per
39    ///   profile.
40    /// - `thresholds`: Stability thresholds used during analysis.
41    ///
42    /// # Returns
43    /// A configured [`AcgLearner`].
44    pub fn new(
45        agent_id: impl Into<String>,
46        observation_window: usize,
47        thresholds: StabilityThresholds,
48    ) -> Self {
49        Self {
50            agent_id: agent_id.into(),
51            observation_window,
52            thresholds,
53        }
54    }
55}
56
57impl Learner for AcgLearner {
58    fn process_run<'a>(
59        &'a self,
60        run: &'a RunRecord,
61        backend: &'a dyn StorageBackendDyn,
62        hot_cache: &'a Arc<RwLock<HotCache>>,
63    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>> {
64        Box::pin(async move {
65            let mut grouped_observations: HashMap<String, Vec<PromptIR>> = run
66                .calls
67                .iter()
68                .filter(|call| call.kind == CallKind::Llm)
69                .filter_map(|call| call.annotated_request.as_ref())
70                .filter_map(|request| {
71                    build_prompt_ir(request).ok().map(|prompt_ir| {
72                        (derive_acg_learning_key(&self.agent_id, request), prompt_ir)
73                    })
74                })
75                .fold(HashMap::new(), |mut grouped, (key, prompt_ir)| {
76                    grouped.entry(key).or_default().push(prompt_ir);
77                    grouped
78                });
79
80            if grouped_observations.is_empty() {
81                return Ok(());
82            }
83
84            let mut profile_stability = HashMap::new();
85            let mut profile_counts = HashMap::new();
86            let mut best_profile_seed: Option<(
87                Vec<PromptIR>,
88                crate::acg::stability::StabilityAnalysisResult,
89            )> = None;
90
91            for (profile_key, new_observations) in grouped_observations.drain() {
92                let existing = backend.load_observations(&profile_key).await?;
93                let mut window: VecDeque<PromptIR> =
94                    existing.unwrap_or_default().into_iter().collect();
95
96                for observation in new_observations {
97                    if window.len() >= self.observation_window {
98                        window.pop_front();
99                    }
100                    window.push_back(observation);
101                }
102
103                let observations_vec: Vec<PromptIR> = window.into_iter().collect();
104                backend
105                    .store_observations(&profile_key, &observations_vec)
106                    .await?;
107
108                let stability_result = analyze_stability(&observations_vec, &self.thresholds);
109                backend
110                    .store_stability(&profile_key, &stability_result)
111                    .await?;
112
113                profile_counts.insert(profile_key.clone(), stability_result.total_observations);
114                profile_stability.insert(profile_key, stability_result.clone());
115
116                let replace_best = best_profile_seed
117                    .as_ref()
118                    .map(|(_, current)| {
119                        (
120                            stability_result.stable_prefix_length,
121                            stability_result.total_observations,
122                        ) > (current.stable_prefix_length, current.total_observations)
123                    })
124                    .unwrap_or(true);
125                if replace_best {
126                    best_profile_seed = Some((observations_vec.clone(), stability_result.clone()));
127                }
128            }
129
130            if let Some((aggregate_observations, aggregate_stability)) = best_profile_seed.as_ref()
131            {
132                // Persist the runtime seed entry under plain agent_id so registration can
133                // rehydrate HotCache without scanning profile-specific keys.
134                backend
135                    .store_observations(&self.agent_id, aggregate_observations)
136                    .await?;
137                backend
138                    .store_stability(&self.agent_id, aggregate_stability)
139                    .await?;
140            }
141
142            let mut guard = hot_cache.write().map_err(|error| {
143                AdaptiveError::Internal(format!("hot cache lock poisoned: {error}"))
144            })?;
145            guard.acg_profiles.extend(profile_stability);
146            guard.acg_profile_observation_counts.extend(profile_counts);
147            if let Some((_, aggregate_stability)) = best_profile_seed {
148                guard.acg_observation_count = aggregate_stability.total_observations;
149                guard.acg_stability = Some(aggregate_stability);
150            }
151
152            Ok(())
153        })
154    }
155}
156
// Unit tests live in a separate file that is pulled in through `#[path]` and
// compiled only for test builds.
#[cfg(test)]
#[path = "../tests/unit/acg_learner_tests.rs"]
mod tests;