Skip to main content

adk_plugin/
callbacks.rs

1//! Plugin callback types
2//!
3//! Defines the callback function signatures used by plugins.
4
5pub use adk_core::OnToolErrorCallback;
6use adk_core::{
7    CallbackContext, Content, Event, InvocationContext, LlmRequest, LlmResponse, Result,
8};
9use std::future::Future;
10use std::pin::Pin;
11use std::sync::Arc;
12
13/// Callback invoked when a user message is received.
14///
15/// Can modify the message content before processing.
16/// Return `Ok(Some(content))` to replace the message, `Ok(None)` to keep original.
17pub type OnUserMessageCallback = Box<
18    dyn Fn(
19            Arc<dyn InvocationContext>,
20            Content,
21        ) -> Pin<Box<dyn Future<Output = Result<Option<Content>>> + Send>>
22        + Send
23        + Sync,
24>;
25
26/// Callback invoked for each event generated by the agent.
27///
28/// Can modify events before they are yielded.
29/// Return `Ok(Some(event))` to replace the event, `Ok(None)` to keep original.
30pub type OnEventCallback = Box<
31    dyn Fn(
32            Arc<dyn InvocationContext>,
33            Event,
34        ) -> Pin<Box<dyn Future<Output = Result<Option<Event>>> + Send>>
35        + Send
36        + Sync,
37>;
38
39/// Callback invoked before the agent run starts.
40///
41/// Can perform setup, validation, or early exit.
42/// Return `Ok(Some(content))` to skip the run and return this content.
43/// Return `Ok(None)` to continue with the run.
44pub type BeforeRunCallback = Box<
45    dyn Fn(
46            Arc<dyn InvocationContext>,
47        ) -> Pin<Box<dyn Future<Output = Result<Option<Content>>> + Send>>
48        + Send
49        + Sync,
50>;
51
52/// Callback invoked after the agent run completes.
53///
54/// Used for cleanup, logging, metrics collection.
55/// This callback does NOT emit events - it's for side effects only.
56pub type AfterRunCallback = Box<
57    dyn Fn(Arc<dyn InvocationContext>) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync,
58>;
59
60/// Callback invoked when a model error occurs.
61///
62/// Can handle or transform model errors.
63/// Return `Ok(Some(response))` to provide a fallback response.
64/// Return `Ok(None)` to propagate the original error.
65pub type OnModelErrorCallback = Box<
66    dyn Fn(
67            Arc<dyn CallbackContext>,
68            LlmRequest,
69            String, // error message
70        ) -> Pin<Box<dyn Future<Output = Result<Option<LlmResponse>>> + Send>>
71        + Send
72        + Sync,
73>;
74
75/// Helper to create a simple logging callback for user messages.
76pub fn log_user_messages() -> OnUserMessageCallback {
77    Box::new(|_ctx, content| {
78        Box::pin(async move {
79            tracing::info!(role = %content.role, parts = ?content.parts.len(), "User message received");
80            Ok(None)
81        })
82    })
83}
84
85/// Helper to create a simple logging callback for events.
86pub fn log_events() -> OnEventCallback {
87    Box::new(|_ctx, event| {
88        Box::pin(async move {
89            tracing::info!(
90                id = %event.id,
91                author = %event.author,
92                partial = event.llm_response.partial,
93                "Event generated"
94            );
95            Ok(None)
96        })
97    })
98}
99
100/// Helper to create a metrics collection callback.
101pub fn collect_metrics(
102    on_run_start: impl Fn() + Send + Sync + 'static,
103    on_run_end: impl Fn() + Send + Sync + 'static,
104) -> (BeforeRunCallback, AfterRunCallback) {
105    let start_fn = Arc::new(on_run_start);
106    let end_fn = Arc::new(on_run_end);
107
108    let before = Box::new(move |_ctx: Arc<dyn InvocationContext>| {
109        let f = start_fn.clone();
110        Box::pin(async move {
111            f();
112            Ok(None)
113        }) as Pin<Box<dyn Future<Output = Result<Option<Content>>> + Send>>
114    });
115
116    let after = Box::new(move |_ctx: Arc<dyn InvocationContext>| {
117        let f = end_fn.clone();
118        Box::pin(async move {
119            f();
120        }) as Pin<Box<dyn Future<Output = ()> + Send>>
121    });
122
123    (before, after)
124}