use crate::capability_binding::CapabilityTrace;
use crate::thought_stream::ThoughtEvent;
use crate::types::{Timestamp, StructuredContent, ProvenanceChain};
use crate::Result;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing;
use uuid::Uuid;
/// A single entry in an [`ExecutionTrace`].
///
/// Serialized with serde's internally-tagged representation: the variant name
/// is written under a `"type"` key alongside the variant's own fields.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum TraceEvent {
/// A thought event from the thought stream.
Thought(ThoughtEvent),
/// A capability invocation; `TraceEvent::is_failure` treats it as a failure
/// when its `success` flag is `false`.
Capability(CapabilityTrace),
/// A context access, described by an operation name and a query string.
ContextAccess {
/// Name of the operation that was performed.
operation: String,
/// Query text associated with the operation.
query: String,
/// When the access happened.
timestamp: Timestamp,
/// Elapsed time of the access — milliseconds judging by the name; TODO confirm at call sites.
duration_ms: u64,
},
/// A transition between two named execution states.
ExecutionState {
/// Name of the state being left.
from: String,
/// Name of the state being entered.
to: String,
/// When the transition happened.
timestamp: Timestamp,
/// Free-form JSON attached to the transition.
metadata: serde_json::Value,
},
/// A user intervention, recorded as an action plus the reason given for it.
UserIntervention {
/// What the user did.
action: String,
/// Why the user did it.
reason: String,
/// When the intervention happened.
timestamp: Timestamp,
},
}
impl TraceEvent {
    /// Returns the moment this event occurred.
    ///
    /// Thought events report their own `timestamp`, capability traces report
    /// `metadata.timestamp`, and the inline variants report their `timestamp`
    /// field.
    pub fn timestamp(&self) -> Timestamp {
        match self {
            Self::Thought(thought) => thought.timestamp,
            Self::Capability(cap) => cap.metadata.timestamp,
            Self::ContextAccess { timestamp, .. }
            | Self::ExecutionState { timestamp, .. }
            | Self::UserIntervention { timestamp, .. } => *timestamp,
        }
    }

    /// Returns `true` only for a capability trace whose `success` flag is
    /// `false`; every other variant is never a failure.
    ///
    /// The match is written out exhaustively so that adding a new variant
    /// forces a decision about whether it can represent a failure.
    pub fn is_failure(&self) -> bool {
        match self {
            Self::Capability(cap) => !cap.success,
            Self::Thought(_)
            | Self::ContextAccess { .. }
            | Self::ExecutionState { .. }
            | Self::UserIntervention { .. } => false,
        }
    }
}
/// An ordered record of the events produced by one agent.
///
/// Events are kept in the order they were appended (see `add_event`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionTrace {
/// Identifier of this trace; `ExecutionTrace::new` assigns a random v4 UUID.
pub id: Uuid,
/// When the trace was created.
pub created_at: Timestamp,
/// Identifier of the agent whose events this trace holds.
pub agent_id: String,
/// Recorded events, in insertion order.
pub events: Vec<TraceEvent>,
/// Summary data; the counts are filled in by `TraceRecorder::finalize`,
/// not kept in sync as events are added.
pub metadata: TraceMetadata,
}
impl ExecutionTrace {
    /// Creates an empty trace for `agent_id` with a fresh random id, a
    /// `created_at` of now, and default metadata.
    pub fn new(agent_id: impl Into<String>) -> Self {
        Self {
            id: Uuid::new_v4(),
            created_at: Timestamp::now(),
            agent_id: agent_id.into(),
            events: Vec::new(),
            metadata: TraceMetadata::default(),
        }
    }

    /// Appends `event` to the end of the trace.
    ///
    /// This does not touch `metadata`; its counts are computed by
    /// `TraceRecorder::finalize`.
    pub fn add_event(&mut self, event: TraceEvent) {
        self.events.push(event);
    }

    /// Number of recorded events.
    pub fn len(&self) -> usize {
        self.events.len()
    }

    /// Returns `true` when no events have been recorded.
    pub fn is_empty(&self) -> bool {
        self.events.is_empty()
    }

    /// Returns references to the events for which `predicate` returns `true`,
    /// in recording order.
    ///
    /// Any `FnMut` is accepted: plain `fn` items (as before), non-capturing
    /// closures, and closures that capture or mutate local state.
    pub fn filter_events(
        &self,
        mut predicate: impl FnMut(&TraceEvent) -> bool,
    ) -> Vec<&TraceEvent> {
        self.events.iter().filter(|&event| predicate(event)).collect()
    }

    /// Returns the payload of every `TraceEvent::Thought`, in recording order.
    pub fn thought_events(&self) -> Vec<&ThoughtEvent> {
        self.events
            .iter()
            .filter_map(|e| match e {
                TraceEvent::Thought(event) => Some(event),
                _ => None,
            })
            .collect()
    }

    /// Returns the payload of every `TraceEvent::Capability`, in recording order.
    pub fn capability_traces(&self) -> Vec<&CapabilityTrace> {
        self.events
            .iter()
            .filter_map(|e| match e {
                TraceEvent::Capability(trace) => Some(trace),
                _ => None,
            })
            .collect()
    }

    /// Returns `true` when no event reports a failure (an empty trace counts
    /// as successful).
    pub fn is_successful(&self) -> bool {
        !self.events.iter().any(|e| e.is_failure())
    }

    /// Time between the first and the last recorded event's timestamps.
    ///
    /// Returns `None` when the trace is empty, or when the span cannot be
    /// converted to a `std::time::Duration` (e.g. the last event's timestamp
    /// precedes the first's).
    pub fn duration(&self) -> Option<std::time::Duration> {
        // `first()?` already covers the empty case.
        let first = self.events.first()?.timestamp();
        let last = self.events.last()?.timestamp();
        last.as_datetime()
            .signed_duration_since(first.as_datetime())
            .to_std()
            .ok()
    }
}
/// Summary information attached to an [`ExecutionTrace`].
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TraceMetadata {
/// Number of events in the trace; set by `TraceRecorder::finalize`, not
/// maintained by `ExecutionTrace::add_event`.
pub event_count: usize,
/// Number of events for which `TraceEvent::is_failure` is `true`; set by
/// `TraceRecorder::finalize`.
pub failure_count: usize,
/// Free-form JSON for resource usage; not written anywhere in this file.
pub resource_usage: serde_json::Value,
/// Arbitrary caller-defined JSON; not written anywhere in this file.
pub custom: serde_json::Value,
}
/// Records events into an in-progress [`ExecutionTrace`] and keeps an archive
/// of finalized traces.
///
/// Both pieces of state sit behind `Arc<tokio::sync::RwLock<_>>`, so cloning a
/// recorder is cheap and every clone shares the same trace and archive.
#[derive(Clone)]
pub struct TraceRecorder {
/// Trace currently receiving events.
trace: Arc<RwLock<ExecutionTrace>>,
/// Traces completed by `finalize`, appended in the order they were finalized.
archive: Arc<RwLock<Vec<ExecutionTrace>>>,
}
impl TraceRecorder {
pub fn new(agent_id: impl Into<String>) -> Self {
Self {
trace: Arc::new(RwLock::new(ExecutionTrace::new(agent_id))),
archive: Arc::new(RwLock::new(Vec::new())),
}
}
pub async fn record(&self, event: TraceEvent) {
let mut trace = self.trace.write().await;
trace.add_event(event);
}
pub async fn record_thought(&self, event: ThoughtEvent) {
self.record(TraceEvent::Thought(event)).await;
}
pub async fn record_capability(&self, trace: CapabilityTrace) {
self.record(TraceEvent::Capability(trace)).await;
}
pub async fn current_trace(&self) -> ExecutionTrace {
let trace = self.trace.read().await;
trace.clone()
}
pub async fn finalize(&self) -> ExecutionTrace {
let completed = {
let mut trace = self.trace.write().await;
let mut completed = trace.clone();
completed.metadata.event_count = completed.len();
completed.metadata.failure_count = completed
.events
.iter()
.filter(|e| e.is_failure())
.count();
let mut archive = self.archive.write().await;
archive.push(completed.clone());
completed
};
let agent_id = completed.agent_id.clone();
*self.trace.write().await = ExecutionTrace::new(agent_id);
completed
}
pub async fn archive(&self) -> Vec<ExecutionTrace> {
let archive = self.archive.read().await;
archive.clone()
}
pub async fn clear_archive(&self) {
let mut archive = self.archive.write().await;
archive.clear();
}
pub async fn replay(trace: &ExecutionTrace) -> Result<()> {
tracing::info!("Replaying trace {} with {} events", trace.id, trace.len());
for (idx, event) in trace.events.iter().enumerate() {
tracing::debug!("Replaying event {}: {:?}", idx, event);
match event {
TraceEvent::Thought(thought) => {
tracing::debug!("Thought event: {:?}", thought.event_type);
}
TraceEvent::Capability(cap) => {
tracing::debug!("Capability event: {}", cap.capability_name);
if !cap.success {
return Err(crate::Error::ExecutionError("Failed capability in replay".into()));
}
}
TraceEvent::ContextAccess { operation, .. } => {
tracing::debug!("Context access: {}", operation);
}
TraceEvent::ExecutionState { from, to, .. } => {
tracing::debug!("State transition: {} -> {}", from, to);
}
TraceEvent::UserIntervention { action, .. } => {
tracing::info!("User intervention during replay: {}", action);
}
}
}
Ok(())
}
}
impl Default for TraceRecorder {
fn default() -> Self {
Self::new("default")
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_execution_trace_creation() {
        let trace = ExecutionTrace::new("test_agent");
        assert_eq!(trace.agent_id, "test_agent");
        assert!(trace.is_empty());
        // An empty trace has no failures and no measurable span.
        assert!(trace.is_successful());
        assert!(trace.duration().is_none());
    }

    #[tokio::test]
    async fn test_trace_recorder() {
        let recorder = TraceRecorder::new("test_agent");
        let thought = ThoughtEvent::observation(
            StructuredContent::text("test observation"),
            ProvenanceChain::new(),
        );
        recorder.record_thought(thought).await;
        let trace = recorder.current_trace().await;
        assert_eq!(trace.len(), 1);
        let events = trace.thought_events();
        assert_eq!(events.len(), 1);
    }

    #[test]
    fn test_trace_duration() {
        let mut trace = ExecutionTrace::new("test");
        trace.add_event(TraceEvent::ContextAccess {
            operation: "test".into(),
            query: "test".into(),
            timestamp: Timestamp::now(),
            duration_ms: 0,
        });
        let duration = trace.duration();
        assert!(duration.is_some());
        assert_eq!(duration.unwrap().as_millis(), 0);
    }

    #[tokio::test]
    async fn test_finalize_archives_and_resets() {
        let recorder = TraceRecorder::new("test_agent");
        recorder
            .record(TraceEvent::UserIntervention {
                action: "pause".into(),
                reason: "test".into(),
                timestamp: Timestamp::now(),
            })
            .await;

        let completed = recorder.finalize().await;
        assert_eq!(completed.agent_id, "test_agent");
        assert_eq!(completed.len(), 1);
        assert_eq!(completed.metadata.event_count, 1);
        assert_eq!(completed.metadata.failure_count, 0);

        // The recorder starts over with a fresh, empty trace for the same agent.
        let current = recorder.current_trace().await;
        assert!(current.is_empty());
        assert_eq!(current.agent_id, "test_agent");
        assert_ne!(current.id, completed.id);

        // The finalized trace is archived exactly once.
        let archive = recorder.archive().await;
        assert_eq!(archive.len(), 1);
        assert_eq!(archive[0].id, completed.id);

        // A trace with no failed capabilities replays cleanly.
        assert!(TraceRecorder::replay(&completed).await.is_ok());

        recorder.clear_archive().await;
        assert!(recorder.archive().await.is_empty());
    }
}