// potato_agent/agents/callbacks.rs
use std::fmt::Debug;

use potato_type::prompt::Prompt;
use potato_type::tools::ToolCall;
use serde_json::Value;

use crate::agents::run_context::AgentRunContext;
use crate::agents::types::AgentResponse;
8
/// Value returned by every hook on [`AgentCallback`].
///
/// All default hook implementations return [`CallbackAction::Continue`]. How the
/// other variants are acted upon is decided by the code that invokes the hooks,
/// which is not part of this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CallbackAction {
    /// The default decision of every hook; presumably lets the run proceed unchanged.
    Continue,
    /// Carries replacement response text.
    /// NOTE(review): presumably substitutes the model/tool output — confirm in the caller.
    OverrideResponse(String),
    /// Carries a reason message.
    /// NOTE(review): presumably stops the agent run with this reason — confirm in the caller.
    Abort(String),
}
19
20pub trait AgentCallback: Send + Sync + Debug {
22 fn before_model_call(&self, _ctx: &AgentRunContext, _prompt: &Prompt) -> CallbackAction {
23 CallbackAction::Continue
24 }
25
26 fn after_model_call(
27 &self,
28 _ctx: &AgentRunContext,
29 _response: &AgentResponse,
30 ) -> CallbackAction {
31 CallbackAction::Continue
32 }
33
34 fn before_tool_call(&self, _ctx: &AgentRunContext, _call: &ToolCall) -> CallbackAction {
35 CallbackAction::Continue
36 }
37
38 fn after_tool_call(
39 &self,
40 _ctx: &AgentRunContext,
41 _call: &ToolCall,
42 _result: &Value,
43 ) -> CallbackAction {
44 CallbackAction::Continue
45 }
46}
47
/// Stateless [`AgentCallback`] that emits a `tracing` info event from each hook.
///
/// It is a unit struct, so it is free to copy and construct (`LoggingCallback` or
/// `LoggingCallback::default()`).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct LoggingCallback;
50
51impl AgentCallback for LoggingCallback {
52 fn before_model_call(&self, ctx: &AgentRunContext, _prompt: &Prompt) -> CallbackAction {
53 tracing::info!(agent_id = %ctx.agent_id, iteration = ctx.iteration, "before model call");
54 CallbackAction::Continue
55 }
56
57 fn after_model_call(&self, ctx: &AgentRunContext, response: &AgentResponse) -> CallbackAction {
58 tracing::info!(
59 agent_id = %ctx.agent_id,
60 iteration = ctx.iteration,
61 response_len = response.response_text().len(),
62 "after model call"
63 );
64 CallbackAction::Continue
65 }
66
67 fn before_tool_call(&self, ctx: &AgentRunContext, call: &ToolCall) -> CallbackAction {
68 tracing::info!(agent_id = %ctx.agent_id, tool = %call.tool_name, "before tool call");
69 CallbackAction::Continue
70 }
71
72 fn after_tool_call(
73 &self,
74 ctx: &AgentRunContext,
75 call: &ToolCall,
76 _result: &Value,
77 ) -> CallbackAction {
78 tracing::info!(agent_id = %ctx.agent_id, tool = %call.tool_name, "after tool call");
79 CallbackAction::Continue
80 }
81}