Skip to main content

imp_core/agent/
events.rs

1use std::path::PathBuf;
2
3use imp_llm::{AssistantMessage, Cost, Message, StreamEvent, Usage};
4use serde::{Deserialize, Serialize};
5use serde_json::json;
6
7use crate::mana_review::TurnManaReview;
8use crate::reference_monitor::PolicyTraceRecord;
9use crate::trace::TraceEvent;
10use crate::trust::Provenance;
11use crate::workflow::VerificationGate;
12
13use super::NextActionAssessment;
14
/// Point within a single agent turn at which a [`TimingEvent`] is recorded.
///
/// [`TimingStage::as_str`] yields the snake_case label that is written into the
/// `"stage"` field of `timing` trace payloads.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TimingStage {
    ContextAssemblyStart,
    ContextAssemblyEnd,
    LlmRequestStart,
    FirstStreamEvent,
    FirstTextDelta,
    FirstToolCall,
    MessageEnd,
    ToolExecutionStart,
    ToolExecutionEnd,
    PostTurnAssessmentStart,
    PostTurnAssessmentEnd,
}

impl TimingStage {
    /// Returns the stable snake_case name of this stage.
    ///
    /// Every name is a `'static` literal, so the call never allocates.
    pub fn as_str(self) -> &'static str {
        match self {
            TimingStage::ContextAssemblyStart => "context_assembly_start",
            TimingStage::ContextAssemblyEnd => "context_assembly_end",
            TimingStage::LlmRequestStart => "llm_request_start",
            TimingStage::FirstStreamEvent => "first_stream_event",
            TimingStage::FirstTextDelta => "first_text_delta",
            TimingStage::FirstToolCall => "first_tool_call",
            TimingStage::MessageEnd => "message_end",
            TimingStage::ToolExecutionStart => "tool_execution_start",
            TimingStage::ToolExecutionEnd => "tool_execution_end",
            TimingStage::PostTurnAssessmentStart => "post_turn_assessment_start",
            TimingStage::PostTurnAssessmentEnd => "post_turn_assessment_end",
        }
    }
}
47
48#[derive(Debug, Clone, PartialEq, Eq)]
49pub struct TimingEvent {
50    pub turn: u32,
51    pub stage: TimingStage,
52    pub since_turn_start_ms: u64,
53    pub since_llm_request_start_ms: Option<u64>,
54    pub duration_ms: Option<u64>,
55    pub label: Option<String>,
56    pub success: Option<bool>,
57}
58
59impl TimingEvent {
60    pub fn new(
61        turn: u32,
62        stage: TimingStage,
63        turn_started_at: std::time::Instant,
64        llm_request_started_at: Option<std::time::Instant>,
65    ) -> Self {
66        let now = std::time::Instant::now();
67        Self {
68            turn,
69            stage,
70            since_turn_start_ms: now.duration_since(turn_started_at).as_millis() as u64,
71            since_llm_request_start_ms: llm_request_started_at
72                .map(|started_at| now.duration_since(started_at).as_millis() as u64),
73            duration_ms: None,
74            label: None,
75            success: None,
76        }
77    }
78
79    pub fn with_duration_ms(mut self, duration_ms: u64) -> Self {
80        self.duration_ms = Some(duration_ms);
81        self
82    }
83
84    pub fn with_label(mut self, label: impl Into<String>) -> Self {
85        self.label = Some(label.into());
86        self
87    }
88
89    pub fn with_success(mut self, success: bool) -> Self {
90        self.success = Some(success);
91        self
92    }
93}
94
95#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
96#[serde(rename_all = "snake_case")]
97pub enum RecoveryCheckpointKind {
98    ProviderRequestStart,
99    AssistantToolCallObserved,
100    AssistantMessageFinalized,
101    ToolPlanCreated,
102    ToolExecutionStart,
103    ToolExecutionEnd,
104    ToolResultAddedToContext,
105    ProviderRequestCompleted,
106}
107
108impl RecoveryCheckpointKind {
109    pub fn as_str(self) -> &'static str {
110        match self {
111            Self::ProviderRequestStart => "provider_request_start",
112            Self::AssistantToolCallObserved => "assistant_tool_call_observed",
113            Self::AssistantMessageFinalized => "assistant_message_finalized",
114            Self::ToolPlanCreated => "tool_plan_created",
115            Self::ToolExecutionStart => "tool_execution_start",
116            Self::ToolExecutionEnd => "tool_execution_end",
117            Self::ToolResultAddedToContext => "tool_result_added_to_context",
118            Self::ProviderRequestCompleted => "provider_request_completed",
119        }
120    }
121}
122
123#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
124pub struct RecoveryCheckpoint {
125    pub version: u32,
126    pub turn: u32,
127    pub kind: RecoveryCheckpointKind,
128    pub tool_call_id: Option<String>,
129    pub tool_name: Option<String>,
130    pub args_hash: Option<String>,
131    pub success: Option<bool>,
132    pub error_class: Option<String>,
133    pub timestamp: u64,
134}
135
136/// Events emitted by the agent during execution.
137#[derive(Debug, Clone)]
138pub enum AgentEvent {
139    AgentStart {
140        model: String,
141        timestamp: u64,
142    },
143    AgentEnd {
144        usage: Usage,
145        cost: Cost,
146        status: crate::agent::RunFinalStatus,
147    },
148    TurnStart {
149        index: u32,
150    },
151    TurnAssessment {
152        index: u32,
153        assessment: NextActionAssessment,
154    },
155    TurnEnd {
156        index: u32,
157        message: AssistantMessage,
158        mana_review: TurnManaReview,
159    },
160    MessageStart {
161        message: Message,
162    },
163    MessageDelta {
164        delta: StreamEvent,
165    },
166    MessageEnd {
167        message: Message,
168    },
169    ToolExecutionStart {
170        tool_call_id: String,
171        tool_name: String,
172        args: serde_json::Value,
173    },
174    ToolOutputDelta {
175        tool_call_id: String,
176        text: String,
177    },
178    ToolExecutionEnd {
179        tool_call_id: String,
180        result: imp_llm::ToolResultMessage,
181        provenance: Option<Provenance>,
182    },
183    Warning {
184        message: String,
185    },
186    Timing {
187        timing: TimingEvent,
188    },
189    RecoveryCheckpoint {
190        checkpoint: RecoveryCheckpoint,
191    },
192    VerificationStarted {
193        gate: VerificationGate,
194    },
195    VerificationCompleted {
196        gate: VerificationGate,
197        closeout_effect: crate::workflow::VerificationCloseoutEffect,
198    },
199    EvidenceWritten {
200        path: PathBuf,
201    },
202    PolicyChecked {
203        record: PolicyTraceRecord,
204    },
205    Error {
206        error: String,
207    },
208}
209
210impl AgentEvent {
211    pub fn to_trace_event(&self, run_id: impl Into<String>) -> TraceEvent {
212        let run_id = run_id.into();
213        match self {
214            AgentEvent::AgentStart { model, timestamp } => TraceEvent::new(
215                run_id,
216                "agent.start",
217                json!({ "model": model, "source_timestamp": timestamp }),
218            ),
219            AgentEvent::AgentEnd {
220                usage,
221                cost,
222                status,
223            } => TraceEvent::new(
224                run_id,
225                "agent.end",
226                json!({
227                    "usage": format!("{usage:?}"),
228                    "cost": format!("{cost:?}"),
229                    "status": format!("{status:?}"),
230                }),
231            ),
232            AgentEvent::TurnStart { index } => {
233                TraceEvent::new(run_id, "turn.start", json!({ "index": index })).with_turn(*index)
234            }
235            AgentEvent::TurnAssessment { index, assessment } => TraceEvent::new(
236                run_id,
237                "turn.assessment",
238                json!({ "index": index, "assessment": format!("{assessment:?}") }),
239            )
240            .with_turn(*index),
241            AgentEvent::TurnEnd {
242                index,
243                message,
244                mana_review,
245            } => TraceEvent::new(
246                run_id,
247                "turn.end",
248                json!({
249                    "index": index,
250                    "message": format!("{message:?}"),
251                    "mana_review": format!("{mana_review:?}"),
252                }),
253            )
254            .with_turn(*index),
255            AgentEvent::MessageStart { message } => TraceEvent::new(
256                run_id,
257                "message.start",
258                json!({ "message": format!("{message:?}") }),
259            ),
260            AgentEvent::MessageDelta { delta } => TraceEvent::new(
261                run_id,
262                "message.delta",
263                json!({ "delta": format!("{delta:?}") }),
264            ),
265            AgentEvent::MessageEnd { message } => TraceEvent::new(
266                run_id,
267                "message.end",
268                json!({ "message": format!("{message:?}") }),
269            ),
270            AgentEvent::ToolExecutionStart {
271                tool_call_id,
272                tool_name,
273                args,
274            } => TraceEvent::new(
275                run_id,
276                "tool.execution.start",
277                json!({ "tool_call_id": tool_call_id, "tool_name": tool_name, "args": args }),
278            )
279            .with_tool_call_id(tool_call_id.clone()),
280            AgentEvent::ToolOutputDelta { tool_call_id, text } => TraceEvent::new(
281                run_id,
282                "tool.output.delta",
283                json!({ "tool_call_id": tool_call_id, "text": text }),
284            )
285            .with_tool_call_id(tool_call_id.clone()),
286            AgentEvent::ToolExecutionEnd {
287                tool_call_id,
288                result,
289                provenance,
290            } => TraceEvent::new(
291                run_id,
292                "tool.execution.end",
293                json!({
294                    "tool_call_id": tool_call_id,
295                    "result": format!("{result:?}"),
296                    "provenance": provenance,
297                }),
298            )
299            .with_tool_call_id(tool_call_id.clone()),
300            AgentEvent::Warning { message } => {
301                TraceEvent::new(run_id, "warning", json!({ "message": message }))
302            }
303            AgentEvent::Timing { timing } => TraceEvent::new(
304                run_id,
305                "timing",
306                json!({
307                    "turn": timing.turn,
308                    "stage": timing.stage.as_str(),
309                    "since_turn_start_ms": timing.since_turn_start_ms,
310                    "since_llm_request_start_ms": timing.since_llm_request_start_ms,
311                    "duration_ms": timing.duration_ms,
312                    "label": timing.label,
313                    "success": timing.success,
314                }),
315            )
316            .with_turn(timing.turn),
317            AgentEvent::RecoveryCheckpoint { checkpoint } => {
318                let mut event = TraceEvent::new(
319                    run_id,
320                    "recovery.checkpoint",
321                    json!({
322                        "version": checkpoint.version,
323                        "turn": checkpoint.turn,
324                        "kind": checkpoint.kind.as_str(),
325                        "tool_call_id": checkpoint.tool_call_id,
326                        "tool_name": checkpoint.tool_name,
327                        "args_hash": checkpoint.args_hash,
328                        "success": checkpoint.success,
329                        "error_class": checkpoint.error_class,
330                        "checkpoint_timestamp": checkpoint.timestamp,
331                    }),
332                )
333                .with_turn(checkpoint.turn);
334                if let Some(tool_call_id) = &checkpoint.tool_call_id {
335                    event = event.with_tool_call_id(tool_call_id.clone());
336                }
337                event
338            }
339            AgentEvent::VerificationStarted { gate } => TraceEvent::new(
340                run_id,
341                "verification.started",
342                verification_gate_payload(gate, None),
343            ),
344            AgentEvent::VerificationCompleted {
345                gate,
346                closeout_effect,
347            } => TraceEvent::new(
348                run_id,
349                "verification.completed",
350                verification_gate_payload(gate, Some(*closeout_effect)),
351            ),
352            AgentEvent::EvidenceWritten { path } => TraceEvent::new(
353                run_id,
354                "evidence.written",
355                json!({ "path": path.display().to_string() }),
356            ),
357            AgentEvent::PolicyChecked { record } => record.to_trace_event(run_id),
358            AgentEvent::Error { error } => {
359                TraceEvent::new(run_id, "error", json!({ "error": error }))
360            }
361        }
362    }
363}
364
365fn verification_gate_payload(
366    gate: &VerificationGate,
367    closeout_effect: Option<crate::workflow::VerificationCloseoutEffect>,
368) -> serde_json::Value {
369    json!({
370        "id": gate.id,
371        "name": gate.name,
372        "kind": gate.kind,
373        "required": gate.is_required(),
374        "status": gate.status,
375        "command": gate.command.as_ref().map(|command| &command.command),
376        "exit_code": gate.result.as_ref().and_then(|result| result.exit_code),
377        "summary": gate.result.as_ref().and_then(|result| result.summary.as_deref()).or(gate.reason.as_deref()),
378        "artifacts": gate.artifacts.iter().map(|artifact| json!({
379            "kind": artifact.kind,
380            "path": artifact.path.display().to_string(),
381            "summary": artifact.summary,
382            "bytes": artifact.bytes,
383            "redaction": artifact.redaction,
384        })).collect::<Vec<_>>(),
385        "closeout_effect": closeout_effect,
386    })
387}
388
#[cfg(test)]
mod trace_tests {
    use super::*;
    use serde_json::json;

    /// Run identifier shared by every conversion in this module.
    const RUN_ID: &str = "run-1";

    #[test]
    fn agent_events_convert_to_trace_events() {
        // Run start: kind, run id and model name survive the conversion.
        {
            let trace = AgentEvent::AgentStart {
                model: "test-model".into(),
                timestamp: 123,
            }
            .to_trace_event(RUN_ID);
            assert_eq!(trace.kind, "agent.start");
            assert_eq!(trace.run_id, RUN_ID);
            assert_eq!(trace.payload["model"], "test-model");
        }

        // Tool start: the tool call id becomes the correlation id.
        {
            let trace = AgentEvent::ToolExecutionStart {
                tool_call_id: "call-1".into(),
                tool_name: "read".into(),
                args: json!({ "path": "README.md" }),
            }
            .to_trace_event(RUN_ID);
            assert_eq!(trace.kind, "tool.execution.start");
            assert_eq!(trace.correlation.tool_call_id.as_deref(), Some("call-1"));
            assert_eq!(trace.payload["tool_name"], "read");
        }

        // Timing: the turn is lifted onto the trace and the stage is rendered as text.
        {
            let measurement = TimingEvent {
                turn: 2,
                stage: TimingStage::LlmRequestStart,
                since_turn_start_ms: 10,
                since_llm_request_start_ms: None,
                duration_ms: None,
                label: None,
                success: None,
            };
            let trace = AgentEvent::Timing {
                timing: measurement,
            }
            .to_trace_event(RUN_ID);
            assert_eq!(trace.kind, "timing");
            assert_eq!(trace.turn, Some(2));
            assert_eq!(trace.payload["stage"], "llm_request_start");
        }
    }

    #[test]
    fn recovery_checkpoint_conversion_preserves_correlation() {
        let checkpoint = RecoveryCheckpoint {
            version: 1,
            turn: 3,
            kind: RecoveryCheckpointKind::ToolExecutionEnd,
            tool_call_id: Some("call-2".into()),
            tool_name: Some("bash".into()),
            args_hash: Some("abc".into()),
            success: Some(false),
            error_class: Some("timeout".into()),
            timestamp: 456,
        };
        let trace = AgentEvent::RecoveryCheckpoint { checkpoint }.to_trace_event(RUN_ID);

        assert_eq!(trace.kind, "recovery.checkpoint");
        assert_eq!(trace.turn, Some(3));
        assert_eq!(trace.correlation.tool_call_id.as_deref(), Some("call-2"));
        assert_eq!(trace.payload["kind"], "tool_execution_end");
        assert_eq!(trace.payload["error_class"], "timeout");
    }
}