use std::collections::HashSet;
use std::sync::{Arc, Mutex};
use async_trait::async_trait;
use awaken_contract::contract::event::AgentEvent;
use awaken_contract::contract::event_sink::EventSink;
use awaken_contract::contract::progress::{
TOOL_CALL_PROGRESS_ACTIVITY_TYPE, ToolCallProgressState,
};
pub struct ProgressForwardingSink {
parent_sink: Arc<dyn EventSink>,
seen_child_calls: Mutex<HashSet<String>>,
}
impl ProgressForwardingSink {
pub fn new(parent_sink: Arc<dyn EventSink>) -> Self {
Self {
parent_sink,
seen_child_calls: Mutex::new(HashSet::new()),
}
}
}
#[async_trait]
impl EventSink for ProgressForwardingSink {
async fn emit(&self, event: AgentEvent) {
match &event {
AgentEvent::ToolCallStart { id, .. } => {
self.seen_child_calls.lock().unwrap().insert(id.clone());
}
AgentEvent::ActivitySnapshot {
activity_type,
content,
..
} if activity_type == TOOL_CALL_PROGRESS_ACTIVITY_TYPE => {
if let Ok(state) = serde_json::from_value::<ToolCallProgressState>(content.clone())
&& self
.seen_child_calls
.lock()
.unwrap()
.contains(&state.call_id)
{
self.parent_sink.emit(event).await;
}
}
_ => {}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use awaken_contract::contract::progress::ProgressStatus;
use serde_json::json;
use std::sync::Mutex as StdMutex;
struct RecordingSink {
events: StdMutex<Vec<AgentEvent>>,
}
impl RecordingSink {
fn new() -> Self {
Self {
events: StdMutex::new(Vec::new()),
}
}
fn events(&self) -> Vec<AgentEvent> {
self.events.lock().unwrap().clone()
}
}
#[async_trait]
impl EventSink for RecordingSink {
async fn emit(&self, event: AgentEvent) {
self.events.lock().unwrap().push(event);
}
}
fn make_progress_snapshot(call_id: &str, tool_name: &str) -> AgentEvent {
let state = ToolCallProgressState {
schema: "tool-call-progress.v1".into(),
node_id: format!("tool_call:{call_id}"),
call_id: call_id.into(),
tool_name: tool_name.into(),
status: ProgressStatus::Running,
progress: Some(0.5),
loaded: None,
total: None,
message: Some("working...".into()),
parent_node_id: None,
parent_call_id: None,
run_id: None,
parent_run_id: None,
thread_id: None,
};
AgentEvent::ActivitySnapshot {
message_id: "msg-1".into(),
activity_type: TOOL_CALL_PROGRESS_ACTIVITY_TYPE.into(),
content: serde_json::to_value(state).unwrap(),
replace: None,
}
}
#[tokio::test]
async fn forwards_progress_from_known_child_call() {
let recorder = Arc::new(RecordingSink::new());
let sink = ProgressForwardingSink::new(recorder.clone());
sink.emit(AgentEvent::ToolCallStart {
id: "call-1".into(),
name: "search".into(),
})
.await;
sink.emit(make_progress_snapshot("call-1", "search")).await;
let events = recorder.events();
assert_eq!(events.len(), 1);
assert!(matches!(
&events[0],
AgentEvent::ActivitySnapshot { activity_type, .. }
if activity_type == TOOL_CALL_PROGRESS_ACTIVITY_TYPE
));
}
#[tokio::test]
async fn drops_progress_from_unknown_call() {
let recorder = Arc::new(RecordingSink::new());
let sink = ProgressForwardingSink::new(recorder.clone());
sink.emit(make_progress_snapshot("unknown-call", "deep_tool"))
.await;
assert!(recorder.events().is_empty());
}
#[tokio::test]
async fn drops_non_progress_events() {
let recorder = Arc::new(RecordingSink::new());
let sink = ProgressForwardingSink::new(recorder.clone());
sink.emit(AgentEvent::TextDelta {
delta: "hello".into(),
})
.await;
sink.emit(AgentEvent::StepStart {
message_id: "m1".into(),
})
.await;
sink.emit(AgentEvent::StepEnd).await;
assert!(recorder.events().is_empty());
}
#[tokio::test]
async fn tracks_multiple_child_calls() {
let recorder = Arc::new(RecordingSink::new());
let sink = ProgressForwardingSink::new(recorder.clone());
sink.emit(AgentEvent::ToolCallStart {
id: "call-a".into(),
name: "tool_a".into(),
})
.await;
sink.emit(AgentEvent::ToolCallStart {
id: "call-b".into(),
name: "tool_b".into(),
})
.await;
sink.emit(make_progress_snapshot("call-a", "tool_a")).await;
sink.emit(make_progress_snapshot("call-b", "tool_b")).await;
sink.emit(make_progress_snapshot("call-c", "tool_c")).await;
assert_eq!(recorder.events().len(), 2);
}
#[tokio::test]
async fn ignores_non_progress_activity_snapshots() {
let recorder = Arc::new(RecordingSink::new());
let sink = ProgressForwardingSink::new(recorder.clone());
sink.emit(AgentEvent::ToolCallStart {
id: "call-1".into(),
name: "tool".into(),
})
.await;
sink.emit(AgentEvent::ActivitySnapshot {
message_id: "msg-1".into(),
activity_type: "file-activity".into(),
content: json!({"some": "data"}),
replace: None,
})
.await;
assert!(recorder.events().is_empty());
}
}