use std::collections::HashMap;
use std::sync::{
Arc, RwLock,
atomic::{AtomicUsize, Ordering},
};
use nemo_flow::api::event::{Event, ScopeCategory};
use nemo_flow::api::scope::ScopeType;
use uuid::Uuid;
use crate::learner::traits::Learner;
use crate::storage::traits::StorageBackendDyn;
use crate::subscriber::{event_to_call_record, is_run_boundary};
use crate::types::cache::HotCache;
use crate::types::records::{CallRecord, RunRecord};
/// Groups a stream of scope events into per-run [`RunRecord`]s.
///
/// A run is opened by a run-boundary start event (as decided by `is_run_boundary`)
/// and closed by a later run-boundary event that is not a start. Nested scopes and
/// tool/LLM calls seen in between are mapped back to the run's root event UUID so
/// their call records can be attached to the correct open run.
pub(crate) struct RunAccumulator {
    /// Agent identifier copied into every [`RunRecord`] this accumulator opens.
    agent_id: String,
    /// Runs that have started but not yet finished, keyed by the UUID of the
    /// run's boundary start event.
    open_runs: HashMap<Uuid, RunRecord>,
    /// Maps the UUID of each currently tracked scope/call event to the root UUID
    /// of the run it belongs to; a run's root maps to itself.
    event_roots: HashMap<Uuid, Uuid>,
}

impl RunAccumulator {
    /// Creates an empty accumulator whose runs are attributed to `agent_id`.
    pub(crate) fn new(agent_id: String) -> Self {
        Self {
            agent_id,
            open_runs: HashMap::new(),
            event_roots: HashMap::new(),
        }
    }

    /// Number of runs that have started but not yet finished (test-only).
    #[cfg(test)]
    pub(crate) fn open_run_count(&self) -> usize {
        self.open_runs.len()
    }

    /// Feeds one event into the accumulator.
    ///
    /// Returns the completed [`RunRecord`] when `event` closes an open run, and
    /// `None` for every other event, including events that cannot be attributed
    /// to any open run.
    pub(crate) fn process_event(&mut self, event: &Event) -> Option<RunRecord> {
        if let Some(boundary_result) = self.process_run_boundary(event) {
            return boundary_result;
        }
        // The tracking helpers return `Option<()>` only so they can bail out with
        // `?` when no owning run is known; either way nothing is emitted here.
        match (event.scope_category(), event.scope_type()) {
            (Some(ScopeCategory::Start), Some(ScopeType::Tool | ScopeType::Llm)) => {
                self.track_call_start(event)?;
                None
            }
            (Some(ScopeCategory::End), Some(ScopeType::Tool | ScopeType::Llm)) => {
                self.track_call_end(event)?;
                None
            }
            (Some(ScopeCategory::Start), Some(scope_type)) => {
                self.track_nested_scope_start(event, scope_type)?;
                None
            }
            (Some(ScopeCategory::End), Some(scope_type)) => {
                self.track_nested_scope_end(event, scope_type);
                None
            }
            _ => None,
        }
    }

    /// Handles `event` if it is a run boundary.
    ///
    /// Returns `None` when `event` is not a run boundary, so the caller should
    /// keep processing it. Otherwise returns `Some(result)`, where `result` is
    /// the finished run, if any.
    fn process_run_boundary(&mut self, event: &Event) -> Option<Option<RunRecord>> {
        if !is_run_boundary(event) {
            return None;
        }
        if event.scope_category() == Some(ScopeCategory::Start) {
            self.start_run(event);
            return Some(None);
        }
        Some(self.finish_run(event))
    }

    /// Opens a new run rooted at `event`'s UUID, stamped with its timestamp.
    fn start_run(&mut self, event: &Event) {
        let root_uuid = event.uuid();
        self.event_roots.insert(root_uuid, root_uuid);
        let run = RunRecord {
            id: Uuid::now_v7(),
            agent_id: self.agent_id.clone(),
            calls: vec![],
            started_at: *event.timestamp(),
            ended_at: None,
        };
        self.open_runs.insert(root_uuid, run);
    }

    /// Closes the run that `event` ends and returns it with `ended_at` set.
    ///
    /// Returns `None` if no matching run is open. Every `event_roots` entry that
    /// still points at the finished run is dropped as well. Such entries come
    /// from scopes or calls whose end event never arrived. Without this cleanup
    /// they would live for as long as the accumulator, and late child events
    /// could keep attaching new mappings to a run that no longer exists.
    fn finish_run(&mut self, event: &Event) -> Option<RunRecord> {
        let root_uuid = self
            .event_roots
            .remove(&event.uuid())
            .unwrap_or_else(|| event.uuid());
        self.event_roots.retain(|_, root| *root != root_uuid);
        let mut run = self.open_runs.remove(&root_uuid)?;
        run.ended_at = Some(*event.timestamp());
        Some(run)
    }

    /// Registers a nested (non-agent) scope under its run's root.
    ///
    /// Returns `None` when no owning run can be inferred.
    fn track_nested_scope_start(&mut self, event: &Event, scope_type: ScopeType) -> Option<()> {
        if scope_type != ScopeType::Agent {
            let root_uuid = self.infer_root_uuid(event)?;
            self.event_roots.insert(event.uuid(), root_uuid);
        }
        Some(())
    }

    /// Forgets a nested (non-agent) scope once its end event arrives.
    fn track_nested_scope_end(&mut self, event: &Event, scope_type: ScopeType) {
        if scope_type != ScopeType::Agent {
            self.event_roots.remove(&event.uuid());
        }
    }

    /// Registers a tool/LLM call under its run and appends its call record, if any.
    ///
    /// Returns `None` when no owning run can be inferred.
    fn track_call_start(&mut self, event: &Event) -> Option<()> {
        let root_uuid = self.infer_root_uuid(event)?;
        self.event_roots.insert(event.uuid(), root_uuid);
        if let Some(record) = event_to_call_record(event)
            && let Some(run) = self.open_runs.get_mut(&root_uuid)
        {
            run.calls.push(record);
        }
        Some(())
    }

    /// Closes the most recent open call with the event's name and applies any
    /// LLM response metadata carried by the end event.
    ///
    /// Returns `None` (leaving `event_roots` untouched) when no owning run can be
    /// inferred.
    fn track_call_end(&mut self, event: &Event) -> Option<()> {
        let root_uuid = self.infer_root_uuid(event)?;
        if let Some(run) = self.open_runs.get_mut(&root_uuid)
            && let Some(call) = find_open_call(run, event.name())
        {
            call.ended_at = Some(*event.timestamp());
            apply_llm_end_metadata(call, event);
        }
        self.event_roots.remove(&event.uuid());
        Some(())
    }

    /// Finds the run root for `event`.
    ///
    /// The event's own UUID is looked up first. If it is not tracked, the lookup
    /// falls back to the event's parent UUID.
    fn infer_root_uuid(&self, event: &Event) -> Option<Uuid> {
        self.event_roots.get(&event.uuid()).copied().or_else(|| {
            event
                .parent_uuid()
                .and_then(|parent_uuid| self.event_roots.get(&parent_uuid).copied())
        })
    }
}
/// Returns the most recently added call in `run` that is named `event_name` and
/// has not ended yet, or `None` if there is no such call.
fn find_open_call<'a>(run: &'a mut RunRecord, event_name: &str) -> Option<&'a mut CallRecord> {
    // Scan from the back so that, among same-named open calls, the newest wins.
    run.calls
        .iter_mut()
        .rfind(|candidate| candidate.ended_at.is_none() && candidate.name == event_name)
}
/// Copies LLM response metadata from an end `event` onto `call`.
///
/// Does nothing unless the event's category is `"llm"`. Otherwise the event's
/// annotated response replaces `call.annotated_response`, even when it is `None`.
/// When a response is present, token usage, model name and tool-call count are
/// taken from it.
///
/// Counts are narrowed to `u32` with a checked conversion. A value that does not
/// fit in a `u32` is recorded as `None` rather than silently truncated or wrapped
/// the way an `as` cast would.
fn apply_llm_end_metadata(call: &mut CallRecord, event: &Event) {
    if event.category().map(|category| category.as_str()) != Some("llm") {
        return;
    }
    call.annotated_response = event.annotated_response().cloned();
    let Some(ref annotated) = call.annotated_response else {
        return;
    };
    if let Some(ref usage) = annotated.usage {
        call.output_tokens = usage
            .completion_tokens
            .and_then(|tokens| u32::try_from(tokens).ok());
        call.prompt_tokens = usage
            .prompt_tokens
            .and_then(|tokens| u32::try_from(tokens).ok());
        call.total_tokens = usage
            .total_tokens
            .and_then(|tokens| u32::try_from(tokens).ok());
    }
    call.model_name = annotated.model.clone();
    call.tool_call_count = annotated
        .tool_calls
        .as_ref()
        .and_then(|calls| u32::try_from(calls.len()).ok());
}
/// Persists `completed_run` through `backend`.
///
/// Returns `true` on success. On failure the error is written to stderr and
/// `false` is returned, so the caller can skip follow-up work for this run.
async fn store_run(
    backend: &Arc<dyn StorageBackendDyn + Send + Sync>,
    completed_run: &RunRecord,
) -> bool {
    match backend.store_run_dyn(completed_run).await {
        Ok(_) => true,
        Err(error) => {
            eprintln!("nemo-flow-adaptive drain: store_run failed: {error}");
            false
        }
    }
}
/// Runs each learner over `completed_run`, one after another in slice order.
///
/// A failing learner is reported on stderr. It does not prevent the remaining
/// learners from running.
async fn run_learners(
    learners: &[Box<dyn Learner>],
    completed_run: &RunRecord,
    backend: &Arc<dyn StorageBackendDyn + Send + Sync>,
    hot_cache: &Arc<RwLock<HotCache>>,
) {
    for learner in learners.iter() {
        let outcome = learner
            .process_run(completed_run, backend.as_ref(), hot_cache)
            .await;
        if let Err(error) = outcome {
            eprintln!("nemo-flow-adaptive drain: learner failed: {error}");
        }
    }
}
/// Reloads the stored plan for `agent_id` and installs it into `hot_cache`.
///
/// Failures are best-effort and reported on stderr:
/// - If loading the plan fails, the cache is left untouched.
/// - If the cache lock is poisoned (a previous writer panicked), the new plan
///   is not installed. This was previously silent, which let the cached plan go
///   stale with no indication why.
///
/// The write lock is taken only after the `await` completes. It is therefore
/// never held across a suspension point.
async fn refresh_hot_cache_plan(
    backend: &Arc<dyn StorageBackendDyn + Send + Sync>,
    hot_cache: &Arc<RwLock<HotCache>>,
    agent_id: &str,
) {
    let plan = match backend.load_plan_dyn(agent_id).await {
        Ok(plan) => plan,
        Err(error) => {
            eprintln!("nemo-flow-adaptive drain: load_plan failed: {error}");
            return;
        }
    };
    match hot_cache.write() {
        Ok(mut guard) => guard.plan = plan,
        Err(_) => eprintln!(
            "nemo-flow-adaptive drain: hot cache lock poisoned; plan not refreshed"
        ),
    }
}
/// Runs the drain loop without exposing a pending-event counter.
///
/// This is a convenience wrapper around [`drain_task_with_counter`]. It passes
/// a private counter that nothing else reads, so the per-event decrements made
/// by the loop are unobservable.
// NOTE(review): `dead_code` is allowed presumably because no caller in the crate
// currently uses this wrapper (callers use `drain_task_with_counter`) — confirm
// before removing either.
#[allow(dead_code)]
pub(crate) async fn drain_task(
    rx: tokio::sync::mpsc::UnboundedReceiver<Event>,
    backend: Arc<dyn StorageBackendDyn + Send + Sync>,
    hot_cache: Arc<RwLock<HotCache>>,
    agent_id: String,
    learners: Vec<Box<dyn Learner>>,
) {
    let unobserved_counter = Arc::new(AtomicUsize::new(0));
    drain_task_with_counter(rx, backend, hot_cache, unobserved_counter, agent_id, learners).await
}
pub(crate) async fn drain_task_with_counter(
mut rx: tokio::sync::mpsc::UnboundedReceiver<Event>,
backend: Arc<dyn StorageBackendDyn + Send + Sync>,
hot_cache: Arc<RwLock<HotCache>>,
pending_events: Arc<AtomicUsize>,
agent_id: String,
learners: Vec<Box<dyn Learner>>,
) {
let mut accumulator = RunAccumulator::new(agent_id.clone());
while let Some(event) = rx.recv().await {
if let Some(completed_run) = accumulator.process_event(&event) {
if !store_run(&backend, &completed_run).await {
pending_events.fetch_sub(1, Ordering::SeqCst);
continue;
}
run_learners(&learners, &completed_run, &backend, &hot_cache).await;
refresh_hot_cache_plan(&backend, &hot_cache, &agent_id).await;
}
pending_events.fetch_sub(1, Ordering::SeqCst);
}
}
#[cfg(test)]
#[path = "../tests/unit/drain_tests.rs"]
mod tests;