1use std::collections::HashSet;
5use std::fmt;
6use std::future::Future;
7use std::pin::Pin;
8use std::sync::{
9 Arc, RwLock,
10 atomic::{AtomicUsize, Ordering},
11};
12use std::time::Duration;
13
14use crate::acg::CacheRequestFacts;
15use nemo_flow::api::event::Event;
16use nemo_flow::api::registry::{
17 scope_deregister_llm_request_intercept, scope_register_llm_request_intercept,
18};
19use nemo_flow::api::runtime::{
20 EventSubscriberFn, LlmExecutionFn, LlmRequestInterceptFn, LlmStreamExecutionFn, ToolExecutionFn,
21};
22use nemo_flow::codec::request::AnnotatedLlmRequest;
23use nemo_flow::plugin::{
24 ConfigReport, DiagnosticLevel, PluginError, PluginRegistration as ComponentRegistration,
25 PluginRegistrationContext as HostedRegistrationContext, rollback_registrations,
26};
27use uuid::Uuid;
28
29use crate::acg_component::{
30 build_provider_plugin, create_acg_llm_execution_intercept, create_acg_llm_request_intercept,
31 create_acg_llm_stream_execution_intercept, load_persisted_acg_state,
32};
33use crate::acg_learner::AcgLearner;
34use crate::adaptive_hints_intercept::AdaptiveHintsIntercept;
35use crate::cache_diagnostics::{self, CacheDiagnosticsTracker};
36use crate::config::{
37 AcgComponentConfig, AdaptiveConfig, AdaptiveHintsComponentConfig, TelemetryComponentConfig,
38 ToolParallelismComponentConfig,
39};
40use crate::context_helpers::resolve_agent_id;
41use crate::error::{AdaptiveError, Result};
42use crate::intercepts::create_tool_execution_intercept_with_mode;
43use crate::learner::latency::LatencySensitivityLearner;
44use crate::learner::traits::Learner;
45use crate::runtime::backend::build_backend;
46use crate::runtime::validation::validate_config;
47use crate::storage::traits::StorageBackendDyn;
48use crate::subscriber::create_subscriber_with_counter;
49use crate::tool_parallelism_learner::ToolParallelismLearner;
50use crate::types::cache::HotCache;
51
/// Wires the adaptive features (telemetry, adaptive hints, tool parallelism, ACG) into the
/// host plugin registry and owns the state they share.
pub struct AdaptiveRuntime {
    /// Configuration the runtime was constructed from.
    config: AdaptiveConfig,
    /// Validation report produced at construction time.
    report: ConfigReport,
    /// Agent id resolved by `register`; `None` until registration has started.
    registered_agent_id: Option<String>,
    /// Storage backend built from `config.state`, when that section is present.
    backend: Option<Arc<dyn StorageBackendDyn + Send + Sync>>,
    /// In-memory cache shared with every installed intercept and the drain task.
    hot_cache: Arc<RwLock<HotCache>>,
    /// Tracker passed to `cache_diagnostics::build_cache_request_facts`.
    cache_diagnostics_tracker: Arc<RwLock<CacheDiagnosticsTracker>>,
    /// Counter shared by the telemetry subscriber and the drain task; `wait_for_idle` polls it
    /// until it reaches zero (presumably: subscriber increments, drain task decrements).
    pending_events: Arc<AtomicUsize>,
    /// Sending half of the telemetry event channel, cloned into the subscriber.
    event_tx: tokio::sync::mpsc::UnboundedSender<Event>,
    /// Receiving half of the telemetry event channel; taken (set to `None`) when the drain
    /// task is started.
    event_rx: Option<tokio::sync::mpsc::UnboundedReceiver<Event>>,
    /// Handle of the spawned telemetry drain task, if one is running.
    drain_handle: Option<tokio::task::JoinHandle<()>>,
    /// Whether `register` completed successfully.
    registered: bool,
    /// Unique id embedded in every registration name so several runtimes can coexist.
    runtime_id: Uuid,
    /// Scopes that currently have a scope-bound ACG request intercept.
    bound_scopes: Arc<RwLock<HashSet<Uuid>>>,
    /// Host registrations owned by this runtime, rolled back on deregistration.
    registrations: Vec<ComponentRegistration>,
}
73
74impl fmt::Debug for AdaptiveRuntime {
75 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76 f.debug_struct("AdaptiveRuntime")
77 .field("runtime_id", &self.runtime_id)
78 .field("registered", &self.registered)
79 .finish_non_exhaustive()
80 }
81}
82
/// Registration batch handed to an [`AdaptiveFeature`] while it installs its hooks.
///
/// It gives the feature mutable access to the owning runtime and collects every host
/// registration the feature makes, so the batch can be committed or rolled back as a unit.
struct RegistrationContext<'a> {
    /// Runtime whose shared state the feature may read or update.
    runtime: &'a mut AdaptiveRuntime,
    /// Host registrations made so far in this batch.
    registrations: HostedRegistrationContext,
}
87
88impl<'a> RegistrationContext<'a> {
89 fn new(runtime: &'a mut AdaptiveRuntime) -> Self {
90 Self {
91 runtime,
92 registrations: HostedRegistrationContext::new(),
93 }
94 }
95
96 fn register_subscriber(&mut self, name: &str, callback: EventSubscriberFn) -> Result<()> {
97 self.registrations
98 .register_subscriber(name, callback)
99 .map_err(Into::into)
100 }
101
102 fn register_llm_request_intercept(
103 &mut self,
104 name: &str,
105 priority: i32,
106 break_chain: bool,
107 callback: LlmRequestInterceptFn,
108 ) -> Result<()> {
109 self.registrations
110 .register_llm_request_intercept(name, priority, break_chain, callback)
111 .map_err(Into::into)
112 }
113
114 fn register_llm_execution_intercept(
115 &mut self,
116 name: &str,
117 priority: i32,
118 callback: LlmExecutionFn,
119 ) -> Result<()> {
120 self.registrations
121 .register_llm_execution_intercept(name, priority, callback)
122 .map_err(Into::into)
123 }
124
125 fn register_llm_stream_execution_intercept(
126 &mut self,
127 name: &str,
128 priority: i32,
129 callback: LlmStreamExecutionFn,
130 ) -> Result<()> {
131 self.registrations
132 .register_llm_stream_execution_intercept(name, priority, callback)
133 .map_err(Into::into)
134 }
135
136 fn register_tool_execution_intercept(
137 &mut self,
138 name: &str,
139 priority: i32,
140 callback: ToolExecutionFn,
141 ) -> Result<()> {
142 self.registrations
143 .register_tool_execution_intercept(name, priority, callback)
144 .map_err(Into::into)
145 }
146
147 fn take_event_receiver(&mut self) -> Result<tokio::sync::mpsc::UnboundedReceiver<Event>> {
148 self.runtime
149 .event_rx
150 .take()
151 .ok_or_else(|| AdaptiveError::Internal("telemetry already registered".into()))
152 }
153
154 fn set_drain_task(&mut self, handle: tokio::task::JoinHandle<()>) {
155 self.runtime.drain_handle = Some(handle);
156 }
157
158 fn finish(self) -> Vec<ComponentRegistration> {
159 self.registrations.into_registrations()
160 }
161}
162
/// A feature that installs its host hooks through a [`RegistrationContext`].
///
/// Implementations register callbacks using the context's `register_*` helpers. If the
/// returned future resolves to an error, `AdaptiveRuntime::register_feature` rolls back
/// everything collected in the context, together with all earlier registrations.
trait AdaptiveFeature: Send + Sync + 'static {
    /// Registers this feature's hooks.
    ///
    /// The returned future borrows both `self` and `ctx` for `'a` and must be `Send`.
    fn register<'a>(
        &'a mut self,
        ctx: &'a mut RegistrationContext<'_>,
    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>;
}
169
170impl AdaptiveRuntime {
171 pub async fn new(config: AdaptiveConfig) -> Result<Self> {
184 let report = validate_config(&config);
185 if report.has_errors() {
186 let joined = report
187 .diagnostics
188 .iter()
189 .filter(|diagnostic| diagnostic.level == DiagnosticLevel::Error)
190 .map(|diagnostic| diagnostic.message.clone())
191 .collect::<Vec<_>>()
192 .join("; ");
193 return Err(AdaptiveError::InvalidConfig(joined));
194 }
195
196 let backend = match config.state.as_ref() {
197 Some(state) => Some(build_backend(&state.backend).await?),
198 None => None,
199 };
200 let (event_tx, event_rx) = tokio::sync::mpsc::unbounded_channel();
201
202 Ok(Self {
203 config,
204 report,
205 registered_agent_id: None,
206 backend,
207 hot_cache: Arc::new(RwLock::new(HotCache {
208 plan: None,
209 trie: None,
210 agent_hints_default: None,
211 acg_profiles: std::collections::HashMap::new(),
212 acg_profile_observation_counts: std::collections::HashMap::new(),
213 acg_stability: None,
214 acg_observation_count: 0,
215 })),
216 cache_diagnostics_tracker: Arc::new(RwLock::new(CacheDiagnosticsTracker::default())),
217 pending_events: Arc::new(AtomicUsize::new(0)),
218 event_tx,
219 event_rx: Some(event_rx),
220 drain_handle: None,
221 registered: false,
222 runtime_id: Uuid::now_v7(),
223 bound_scopes: Arc::new(RwLock::new(HashSet::new())),
224 registrations: vec![],
225 })
226 }
227
228 pub fn validate_config(config: &AdaptiveConfig) -> ConfigReport {
236 validate_config(config)
237 }
238
239 pub fn report(&self) -> &ConfigReport {
244 &self.report
245 }
246
247 pub fn wait_for_idle(&self) {
253 loop {
254 if self.pending_events.load(Ordering::SeqCst) == 0 {
255 return;
256 }
257 std::thread::sleep(Duration::from_millis(5));
258 }
259 }
260
261 #[must_use]
262 pub fn build_cache_request_facts(
273 &self,
274 agent_id: &str,
275 provider: &str,
276 annotated_request: &AnnotatedLlmRequest,
277 ) -> Option<CacheRequestFacts> {
278 cache_diagnostics::build_cache_request_facts(
279 agent_id,
280 provider,
281 annotated_request,
282 &self.hot_cache,
283 &self.cache_diagnostics_tracker,
284 )
285 }
286
287 fn acg_scope_registration_name(&self, scope_uuid: Uuid) -> String {
288 format!(
289 "adaptive_{}_acg_scope_request_{scope_uuid}",
290 self.runtime_id
291 )
292 }
293
294 pub fn bind_scope(&mut self, scope_uuid: Uuid) -> Result<()> {
307 if !self.registered {
308 return Err(AdaptiveError::RegistrationFailed(
309 "adaptive runtime must be registered before binding ACG request intercepts".into(),
310 ));
311 }
312
313 let agent_id = self.registered_agent_id.as_deref().ok_or_else(|| {
314 AdaptiveError::Internal("adaptive runtime missing registered agent id".into())
315 })?;
316 let acg_config = self.config.acg.as_ref().ok_or_else(|| {
317 AdaptiveError::InvalidConfig(
318 "adaptive runtime does not enable scope-bound ACG request intercepts".into(),
319 )
320 })?;
321 if self
322 .bound_scopes
323 .read()
324 .map_err(|error| AdaptiveError::Internal(error.to_string()))?
325 .contains(&scope_uuid)
326 {
327 return Ok(());
328 }
329
330 let provider = acg_config.provider.clone();
331 let priority = acg_config.priority;
332 let plugin = build_provider_plugin(&provider)?;
333 let registration_name = self.acg_scope_registration_name(scope_uuid);
334 scope_register_llm_request_intercept(
335 &scope_uuid,
336 ®istration_name,
337 priority,
338 false,
339 create_acg_llm_request_intercept(
340 self.hot_cache.clone(),
341 agent_id.to_string(),
342 provider.clone(),
343 plugin,
344 ),
345 )
346 .map_err(|error| {
347 AdaptiveError::RegistrationFailed(format!(
348 "scope-bound ACG llm request intercept: {error}"
349 ))
350 })?;
351
352 self.bound_scopes
353 .write()
354 .map_err(|error| AdaptiveError::Internal(error.to_string()))?
355 .insert(scope_uuid);
356
357 let bound_scopes = self.bound_scopes.clone();
358 self.registrations.push(ComponentRegistration::new(
359 "adaptive_scope",
360 registration_name.clone(),
361 Box::new(move || {
362 if let Ok(mut guard) = bound_scopes.write() {
363 guard.remove(&scope_uuid);
364 }
365 scope_deregister_llm_request_intercept(&scope_uuid, ®istration_name)
366 .map(|_| ())
367 .map_err(|error| {
368 PluginError::RegistrationFailed(format!(
369 "scope-bound ACG llm request intercept deregistration failed: {error}"
370 ))
371 })
372 }),
373 ));
374
375 Ok(())
376 }
377 pub async fn register(&mut self) -> Result<()> {
385 if self.registered {
386 return Ok(());
387 }
388
389 let agent_id = self.agent_id();
390 self.registered_agent_id = Some(agent_id.clone());
391 Self::seed_hot_cache(self.backend.clone(), self.hot_cache.clone(), &agent_id).await;
392
393 if self.config.acg.is_some()
394 && let Some(backend) = self.backend.as_ref()
395 && let Err(error) =
396 load_persisted_acg_state(&agent_id, backend.as_ref(), &self.hot_cache).await
397 {
398 eprintln!("nemo-flow-adaptive: acg hot cache seeding failed: {error}");
399 }
400
401 let mut pending = self.pending_features(&agent_id);
402
403 for feature in &mut pending {
404 self.register_feature(feature).await?;
405 }
406
407 self.registered = true;
408 Ok(())
409 }
410
411 fn agent_id(&self) -> String {
412 self.config
413 .agent_id
414 .clone()
415 .or_else(resolve_agent_id)
416 .unwrap_or_else(|| "default-agent".to_string())
417 }
418
419 async fn seed_hot_cache(
420 backend: Option<Arc<dyn StorageBackendDyn + Send + Sync>>,
421 hot_cache: Arc<RwLock<HotCache>>,
422 agent_id: &str,
423 ) {
424 let Some(backend) = backend else {
425 return;
426 };
427
428 match backend.load_plan_dyn(agent_id).await {
429 Ok(plan) => {
430 if let Ok(mut guard) = hot_cache.write() {
431 guard.plan = plan;
432 }
433 }
434 Err(error) => eprintln!("nemo-flow-adaptive: hot cache seeding failed: {error}"),
435 }
436 }
437
438 fn pending_features(&self, agent_id: &str) -> Vec<Box<dyn AdaptiveFeature>> {
439 let mut pending: Vec<Box<dyn AdaptiveFeature>> = vec![];
440 if let Some(config) = self.config.telemetry.clone()
441 && self.backend.is_some()
442 {
443 pending.push(Box::new(TelemetryFeature::new(
444 config,
445 agent_id.to_string(),
446 self.runtime_id,
447 self.config.acg.clone(),
448 )));
449 }
450 if let Some(config) = self.config.adaptive_hints.clone() {
451 pending.push(Box::new(AdaptiveHintsFeature::new(
452 config,
453 self.hot_cache.clone(),
454 agent_id.to_string(),
455 self.runtime_id,
456 )));
457 }
458 if let Some(config) = self.config.tool_parallelism.clone() {
459 pending.push(Box::new(ToolParallelismFeature::new(
460 config,
461 self.hot_cache.clone(),
462 self.runtime_id,
463 )));
464 }
465 if let Some(config) = self.config.acg.clone()
466 && self.backend.is_some()
467 {
468 pending.push(Box::new(AcgFeature::new(
469 config,
470 self.hot_cache.clone(),
471 self.bound_scopes.clone(),
472 agent_id.to_string(),
473 self.runtime_id,
474 )));
475 }
476 pending
477 }
478
479 async fn register_feature(&mut self, feature: &mut Box<dyn AdaptiveFeature>) -> Result<()> {
480 let mut ctx = RegistrationContext::new(self);
481 if let Err(error) = feature.register(&mut ctx).await {
482 let mut just_registered = ctx.finish();
483 rollback_registrations(&mut just_registered);
484 rollback_registrations(&mut self.registrations);
485 if let Some(handle) = self.drain_handle.take() {
486 handle.abort();
487 }
488 self.registered = false;
489 return Err(error);
490 }
491
492 let completed = ctx.finish();
493 self.registrations.extend(completed);
494 Ok(())
495 }
496
497 pub fn deregister(&mut self) -> Result<()> {
505 rollback_registrations(&mut self.registrations);
506 if let Ok(mut guard) = self.bound_scopes.write() {
507 guard.clear();
508 }
509 if let Some(handle) = self.drain_handle.take() {
510 handle.abort();
511 }
512 self.registered = false;
513 Ok(())
514 }
515
516 pub async fn shutdown(mut self) -> Result<()> {
524 self.deregister()
525 }
526}
527
528impl Drop for AdaptiveRuntime {
529 fn drop(&mut self) {
530 let _ = self.deregister();
531 }
532}
533
/// Telemetry feature that registers an event subscriber and spawns the drain task which
/// feeds the configured learners.
struct TelemetryFeature {
    /// Agent id passed to the drain task.
    agent_id: String,
    /// Name the event subscriber is registered under.
    subscriber_name: String,
    /// Learners built from the telemetry config; moved into the drain task on registration.
    learners: Vec<Box<dyn Learner>>,
}
539
540impl TelemetryFeature {
541 fn new(
542 config: TelemetryComponentConfig,
543 agent_id: String,
544 runtime_id: Uuid,
545 acg_config: Option<AcgComponentConfig>,
546 ) -> Self {
547 let subscriber_name = config
548 .subscriber_name
549 .unwrap_or_else(|| format!("adaptive_{runtime_id}_subscriber"));
550 Self {
551 learners: build_learners(&agent_id, &config.learners, acg_config.as_ref()),
552 agent_id,
553 subscriber_name,
554 }
555 }
556}
557
558impl AdaptiveFeature for TelemetryFeature {
559 fn register<'a>(
560 &'a mut self,
561 ctx: &'a mut RegistrationContext<'_>,
562 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>> {
563 Box::pin(async move {
564 let backend = ctx.runtime.backend.as_ref().cloned().ok_or_else(|| {
565 AdaptiveError::InvalidConfig("telemetry requires state backend".into())
566 })?;
567 let rx = ctx.take_event_receiver()?;
568 let cache = ctx.runtime.hot_cache.clone();
569 let agent_id = self.agent_id.clone();
570 let learners = std::mem::take(&mut self.learners);
571 let pending_events = ctx.runtime.pending_events.clone();
572 ctx.set_drain_task(tokio::spawn(async move {
573 crate::drain::drain_task_with_counter(
574 rx,
575 backend,
576 cache,
577 pending_events,
578 agent_id,
579 learners,
580 )
581 .await;
582 }));
583 ctx.register_subscriber(
584 &self.subscriber_name,
585 create_subscriber_with_counter(
586 ctx.runtime.event_tx.clone(),
587 ctx.runtime.pending_events.clone(),
588 ),
589 )
590 })
591 }
592}
593
/// Adaptive-hints feature that installs an `AdaptiveHintsIntercept` as an LLM request
/// intercept.
struct AdaptiveHintsFeature {
    /// Registration name (`adaptive_{runtime_id}_adaptive_hints_request`).
    name: String,
    /// Intercept priority taken from the config.
    priority: i32,
    /// `break_chain` flag taken from the config and forwarded to the host.
    break_chain: bool,
    /// Hot cache shared with the runtime.
    hot_cache: Arc<RwLock<HotCache>>,
    /// Agent id passed to the intercept.
    agent_id: String,
}
601
602impl AdaptiveHintsFeature {
603 fn new(
604 config: AdaptiveHintsComponentConfig,
605 hot_cache: Arc<RwLock<HotCache>>,
606 agent_id: String,
607 runtime_id: Uuid,
608 ) -> Self {
609 Self {
610 name: format!("adaptive_{runtime_id}_adaptive_hints_request"),
611 priority: config.priority,
612 break_chain: config.break_chain,
613 hot_cache,
614 agent_id,
615 }
616 }
617}
618
619impl AdaptiveFeature for AdaptiveHintsFeature {
620 fn register<'a>(
621 &'a mut self,
622 ctx: &'a mut RegistrationContext<'_>,
623 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>> {
624 Box::pin(async move {
625 let adaptive_hints =
626 AdaptiveHintsIntercept::new(self.hot_cache.clone(), self.agent_id.clone());
627 ctx.register_llm_request_intercept(
628 &self.name,
629 self.priority,
630 self.break_chain,
631 adaptive_hints.into_request_fn(),
632 )
633 })
634 }
635}
636
/// Tool-parallelism feature that installs a tool execution intercept.
struct ToolParallelismFeature {
    /// Registration name (`adaptive_{runtime_id}_tool_execution`).
    name: String,
    /// Intercept priority taken from the config.
    priority: i32,
    /// Hot cache shared with the runtime.
    hot_cache: Arc<RwLock<HotCache>>,
    /// Mode string forwarded to `create_tool_execution_intercept_with_mode`; the accepted
    /// values are defined there, not here.
    mode: String,
}
643
644impl ToolParallelismFeature {
645 fn new(
646 config: ToolParallelismComponentConfig,
647 hot_cache: Arc<RwLock<HotCache>>,
648 runtime_id: Uuid,
649 ) -> Self {
650 Self {
651 name: format!("adaptive_{runtime_id}_tool_execution"),
652 priority: config.priority,
653 hot_cache,
654 mode: config.mode,
655 }
656 }
657}
658
659impl AdaptiveFeature for ToolParallelismFeature {
660 fn register<'a>(
661 &'a mut self,
662 ctx: &'a mut RegistrationContext<'_>,
663 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>> {
664 Box::pin(async move {
665 ctx.register_tool_execution_intercept(
666 &self.name,
667 self.priority,
668 create_tool_execution_intercept_with_mode(
669 self.hot_cache.clone(),
670 self.mode.clone(),
671 ),
672 )
673 })
674 }
675}
676
/// ACG feature that installs global LLM execution and streaming-execution intercepts.
struct AcgFeature {
    /// Registration name of the execution intercept.
    execution_name: String,
    /// Registration name of the streaming-execution intercept.
    stream_name: String,
    /// Priority shared by both intercepts, taken from the config.
    priority: i32,
    /// Hot cache shared with the runtime.
    hot_cache: Arc<RwLock<HotCache>>,
    /// Scopes with a scope-bound ACG intercept; a non-empty set disables the global wrappers.
    bound_scopes: Arc<RwLock<HashSet<Uuid>>>,
    /// Agent id passed to the intercepts.
    agent_id: String,
    /// Provider name used to build the provider plugin.
    provider: String,
}
686
687impl AcgFeature {
688 fn new(
689 config: AcgComponentConfig,
690 hot_cache: Arc<RwLock<HotCache>>,
691 bound_scopes: Arc<RwLock<HashSet<Uuid>>>,
692 agent_id: String,
693 runtime_id: Uuid,
694 ) -> Self {
695 Self {
696 execution_name: format!("adaptive_{runtime_id}_acg_llm_execution"),
697 stream_name: format!("adaptive_{runtime_id}_acg_llm_stream_execution"),
698 priority: config.priority,
699 hot_cache,
700 bound_scopes,
701 agent_id,
702 provider: config.provider,
703 }
704 }
705}
706
707impl AdaptiveFeature for AcgFeature {
708 fn register<'a>(
709 &'a mut self,
710 ctx: &'a mut RegistrationContext<'_>,
711 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>> {
712 Box::pin(async move {
713 let plugin = build_provider_plugin(&self.provider)?;
714 let execution_intercept = create_acg_llm_execution_intercept(
715 self.hot_cache.clone(),
716 self.agent_id.clone(),
717 self.provider.clone(),
718 plugin.clone(),
719 );
720 let bound_scopes = self.bound_scopes.clone();
721 ctx.register_llm_execution_intercept(
722 &self.execution_name,
723 self.priority,
724 Arc::new(move |name, request, next| {
725 let execution_intercept = execution_intercept.clone();
726 let bound_scopes = bound_scopes.clone();
727 let name = name.to_string();
728 Box::pin(async move {
729 let has_bound_scopes = bound_scopes
730 .read()
731 .map(|guard| !guard.is_empty())
732 .unwrap_or(false);
733 if has_bound_scopes {
734 return next(request).await;
735 }
736 execution_intercept(&name, request, next).await
737 })
738 }),
739 )?;
740 let stream_intercept = create_acg_llm_stream_execution_intercept(
741 self.hot_cache.clone(),
742 self.agent_id.clone(),
743 self.provider.clone(),
744 plugin,
745 );
746 let bound_scopes = self.bound_scopes.clone();
747 ctx.register_llm_stream_execution_intercept(
748 &self.stream_name,
749 self.priority,
750 Arc::new(move |name, request, next| {
751 let stream_intercept = stream_intercept.clone();
752 let bound_scopes = bound_scopes.clone();
753 let name = name.to_string();
754 Box::pin(async move {
755 let has_bound_scopes = bound_scopes
756 .read()
757 .map(|guard| !guard.is_empty())
758 .unwrap_or(false);
759 if has_bound_scopes {
760 return next(request).await;
761 }
762 stream_intercept(&name, request, next).await
763 })
764 }),
765 )
766 })
767 }
768}
769
770fn build_learners(
771 agent_id: &str,
772 learners: &[String],
773 acg_config: Option<&AcgComponentConfig>,
774) -> Vec<Box<dyn Learner>> {
775 let mut built: Vec<Box<dyn Learner>> = vec![];
776 for learner in learners {
777 match learner.as_str() {
778 "latency_sensitivity" => built.push(Box::new(LatencySensitivityLearner::new(
779 agent_id,
780 crate::trie::builder::SensitivityConfig::default(),
781 ))),
782 "tool_parallelism" => built.push(Box::new(ToolParallelismLearner::new(agent_id))),
783 "acg" => {
784 if let Some(config) = acg_config {
785 built.push(Box::new(AcgLearner::new(
786 agent_id,
787 config.observation_window,
788 config.stability_thresholds.clone(),
789 )));
790 }
791 }
792 _ => {}
793 }
794 }
795 built
796}
797
798#[cfg(test)]
799#[path = "../../tests/unit/runtime_features_tests.rs"]
800mod tests;