Skip to main content

adk_plugin/
callbacks.rs

1//! Plugin callback types
2//!
3//! Defines the callback function signatures used by plugins.
4
5use adk_core::{
6    CallbackContext, Content, Event, InvocationContext, LlmRequest, LlmResponse, Result, Tool,
7};
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::Arc;
11
12/// Callback invoked when a user message is received.
13///
14/// Can modify the message content before processing.
15/// Return `Ok(Some(content))` to replace the message, `Ok(None)` to keep original.
16pub type OnUserMessageCallback = Box<
17    dyn Fn(
18            Arc<dyn InvocationContext>,
19            Content,
20        ) -> Pin<Box<dyn Future<Output = Result<Option<Content>>> + Send>>
21        + Send
22        + Sync,
23>;
24
25/// Callback invoked for each event generated by the agent.
26///
27/// Can modify events before they are yielded.
28/// Return `Ok(Some(event))` to replace the event, `Ok(None)` to keep original.
29pub type OnEventCallback = Box<
30    dyn Fn(
31            Arc<dyn InvocationContext>,
32            Event,
33        ) -> Pin<Box<dyn Future<Output = Result<Option<Event>>> + Send>>
34        + Send
35        + Sync,
36>;
37
38/// Callback invoked before the agent run starts.
39///
40/// Can perform setup, validation, or early exit.
41/// Return `Ok(Some(content))` to skip the run and return this content.
42/// Return `Ok(None)` to continue with the run.
43pub type BeforeRunCallback = Box<
44    dyn Fn(
45            Arc<dyn InvocationContext>,
46        ) -> Pin<Box<dyn Future<Output = Result<Option<Content>>> + Send>>
47        + Send
48        + Sync,
49>;
50
51/// Callback invoked after the agent run completes.
52///
53/// Used for cleanup, logging, metrics collection.
54/// This callback does NOT emit events - it's for side effects only.
55pub type AfterRunCallback = Box<
56    dyn Fn(Arc<dyn InvocationContext>) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync,
57>;
58
59/// Callback invoked when a model error occurs.
60///
61/// Can handle or transform model errors.
62/// Return `Ok(Some(response))` to provide a fallback response.
63/// Return `Ok(None)` to propagate the original error.
64pub type OnModelErrorCallback = Box<
65    dyn Fn(
66            Arc<dyn CallbackContext>,
67            LlmRequest,
68            String, // error message
69        ) -> Pin<Box<dyn Future<Output = Result<Option<LlmResponse>>> + Send>>
70        + Send
71        + Sync,
72>;
73
74/// Callback invoked when a tool error occurs.
75///
76/// Can handle or transform tool errors.
77/// Return `Ok(Some(result))` to provide a fallback result.
78/// Return `Ok(None)` to propagate the original error.
79pub type OnToolErrorCallback = Box<
80    dyn Fn(
81            Arc<dyn CallbackContext>,
82            Arc<dyn Tool>,
83            serde_json::Value, // args
84            String,            // error message
85        ) -> Pin<Box<dyn Future<Output = Result<Option<serde_json::Value>>> + Send>>
86        + Send
87        + Sync,
88>;
89
90/// Helper to create a simple logging callback for user messages.
91pub fn log_user_messages() -> OnUserMessageCallback {
92    Box::new(|_ctx, content| {
93        Box::pin(async move {
94            tracing::info!(role = %content.role, parts = ?content.parts.len(), "User message received");
95            Ok(None)
96        })
97    })
98}
99
100/// Helper to create a simple logging callback for events.
101pub fn log_events() -> OnEventCallback {
102    Box::new(|_ctx, event| {
103        Box::pin(async move {
104            tracing::info!(
105                id = %event.id,
106                author = %event.author,
107                partial = event.llm_response.partial,
108                "Event generated"
109            );
110            Ok(None)
111        })
112    })
113}
114
115/// Helper to create a metrics collection callback.
116pub fn collect_metrics(
117    on_run_start: impl Fn() + Send + Sync + 'static,
118    on_run_end: impl Fn() + Send + Sync + 'static,
119) -> (BeforeRunCallback, AfterRunCallback) {
120    let start_fn = Arc::new(on_run_start);
121    let end_fn = Arc::new(on_run_end);
122
123    let before = Box::new(move |_ctx: Arc<dyn InvocationContext>| {
124        let f = start_fn.clone();
125        Box::pin(async move {
126            f();
127            Ok(None)
128        }) as Pin<Box<dyn Future<Output = Result<Option<Content>>> + Send>>
129    });
130
131    let after = Box::new(move |_ctx: Arc<dyn InvocationContext>| {
132        let f = end_fn.clone();
133        Box::pin(async move {
134            f();
135        }) as Pin<Box<dyn Future<Output = ()> + Send>>
136    });
137
138    (before, after)
139}