nemo_flow_adaptive/
acg_learner.rs1use std::collections::{HashMap, VecDeque};
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::{Arc, RwLock};
10
11use crate::acg::ir_builder::build_prompt_ir;
12use crate::acg::prompt_ir::PromptIR;
13use crate::acg::stability::{StabilityThresholds, analyze_stability};
14
15use crate::acg_profile::derive_acg_learning_key;
16use crate::error::{AdaptiveError, Result};
17use crate::learner::traits::Learner;
18use crate::storage::traits::StorageBackendDyn;
19use crate::types::cache::HotCache;
20use crate::types::records::{CallKind, RunRecord};
21
22pub struct AcgLearner {
28 agent_id: String,
29 observation_window: usize,
30 thresholds: StabilityThresholds,
31}
32
33impl AcgLearner {
34 pub fn new(
45 agent_id: impl Into<String>,
46 observation_window: usize,
47 thresholds: StabilityThresholds,
48 ) -> Self {
49 Self {
50 agent_id: agent_id.into(),
51 observation_window,
52 thresholds,
53 }
54 }
55}
56
57impl Learner for AcgLearner {
58 fn process_run<'a>(
59 &'a self,
60 run: &'a RunRecord,
61 backend: &'a dyn StorageBackendDyn,
62 hot_cache: &'a Arc<RwLock<HotCache>>,
63 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>> {
64 Box::pin(async move {
65 let mut grouped_observations: HashMap<String, Vec<PromptIR>> = run
66 .calls
67 .iter()
68 .filter(|call| call.kind == CallKind::Llm)
69 .filter_map(|call| call.annotated_request.as_ref())
70 .filter_map(|request| {
71 build_prompt_ir(request).ok().map(|prompt_ir| {
72 (derive_acg_learning_key(&self.agent_id, request), prompt_ir)
73 })
74 })
75 .fold(HashMap::new(), |mut grouped, (key, prompt_ir)| {
76 grouped.entry(key).or_default().push(prompt_ir);
77 grouped
78 });
79
80 if grouped_observations.is_empty() {
81 return Ok(());
82 }
83
84 let mut profile_stability = HashMap::new();
85 let mut profile_counts = HashMap::new();
86 let mut best_profile_seed: Option<(
87 Vec<PromptIR>,
88 crate::acg::stability::StabilityAnalysisResult,
89 )> = None;
90
91 for (profile_key, new_observations) in grouped_observations.drain() {
92 let existing = backend.load_observations(&profile_key).await?;
93 let mut window: VecDeque<PromptIR> =
94 existing.unwrap_or_default().into_iter().collect();
95
96 for observation in new_observations {
97 if window.len() >= self.observation_window {
98 window.pop_front();
99 }
100 window.push_back(observation);
101 }
102
103 let observations_vec: Vec<PromptIR> = window.into_iter().collect();
104 backend
105 .store_observations(&profile_key, &observations_vec)
106 .await?;
107
108 let stability_result = analyze_stability(&observations_vec, &self.thresholds);
109 backend
110 .store_stability(&profile_key, &stability_result)
111 .await?;
112
113 profile_counts.insert(profile_key.clone(), stability_result.total_observations);
114 profile_stability.insert(profile_key, stability_result.clone());
115
116 let replace_best = best_profile_seed
117 .as_ref()
118 .map(|(_, current)| {
119 (
120 stability_result.stable_prefix_length,
121 stability_result.total_observations,
122 ) > (current.stable_prefix_length, current.total_observations)
123 })
124 .unwrap_or(true);
125 if replace_best {
126 best_profile_seed = Some((observations_vec.clone(), stability_result.clone()));
127 }
128 }
129
130 if let Some((aggregate_observations, aggregate_stability)) = best_profile_seed.as_ref()
131 {
132 backend
135 .store_observations(&self.agent_id, aggregate_observations)
136 .await?;
137 backend
138 .store_stability(&self.agent_id, aggregate_stability)
139 .await?;
140 }
141
142 let mut guard = hot_cache.write().map_err(|error| {
143 AdaptiveError::Internal(format!("hot cache lock poisoned: {error}"))
144 })?;
145 guard.acg_profiles.extend(profile_stability);
146 guard.acg_profile_observation_counts.extend(profile_counts);
147 if let Some((_, aggregate_stability)) = best_profile_seed {
148 guard.acg_observation_count = aggregate_stability.total_observations;
149 guard.acg_stability = Some(aggregate_stability);
150 }
151
152 Ok(())
153 })
154 }
155}
156
// Unit tests live in a separate file (located via the `path` attribute) and are
// compiled only for test builds.
#[path = "../tests/unit/acg_learner_tests.rs"]
#[cfg(test)]
mod tests;