Skip to main content

gemini_cli_sdk/
callback.rs

1//! Message callback type for real-time message notifications.
2//!
3//! The [`MessageCallback`] type is the primary extension point for consumers
4//! who need to observe messages as they stream — for logging, UI updates,
5//! persistence, metrics, or any other side effect — without interfering with
6//! the primary stream consumption.
7//!
8//! # Design
9//!
//! Callbacks are `Arc<dyn Fn(Message) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync>`.
11//! This makes them:
12//!
13//! - **Cloneable** — can be shared across tasks via `Arc::clone`.
14//! - **Thread-safe** — `Send + Sync` bounds on the inner `Fn`.
15//! - **Async-capable** — the callback can await I/O (e.g., writing to a database).
//! - **Cheap for sync work** — use [`sync_callback`] to wrap a plain closure; the future it returns completes on its first poll.
17//!
18//! # Example
19//!
20//! ```rust
21//! use gemini_cli_sdk::callback::{sync_callback, tracing_callback};
22//!
23//! // Simple logging callback
24//! let cb = sync_callback(|msg| {
25//!     println!("received: {:?}", msg.session_id());
26//! });
27//!
28//! // Built-in tracing callback
29//! let _trace_cb = tracing_callback();
30//! ```
31
32use std::future::Future;
33use std::pin::Pin;
34use std::sync::Arc;
35
36use crate::types::messages::Message;
37
38// ── MessageCallback type alias ──────────────────────────────────────────────
39
40/// Callback invoked for each message received from the agent.
41///
42/// Used for side effects (logging, UI updates, persistence) while the primary
43/// message stream is consumed by the caller.
44///
45/// # Thread Safety
46///
47/// `MessageCallback` is `Send + Sync`, so it can be safely shared across
48/// threads and tasks via `Arc::clone`.
49///
50/// # Creating Callbacks
51///
52/// Use [`sync_callback`] for synchronous side effects, or construct directly
53/// with `Arc::new(|msg| Box::pin(async move { ... }))` for async work.
54pub type MessageCallback = Arc<
55    dyn Fn(Message) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync,
56>;
57
58// ── Helper constructors ─────────────────────────────────────────────────────
59
60/// Wrap a synchronous closure in a [`MessageCallback`].
61///
62/// This is the ergonomic constructor for callbacks that perform only
63/// synchronous work. The closure receives ownership of each [`Message`]
64/// and returns immediately; the resulting future resolves in a single poll.
65///
66/// # Example
67///
68/// ```rust
69/// use gemini_cli_sdk::callback::sync_callback;
70///
71/// let cb = sync_callback(|msg| {
72///     println!("message for session {:?}", msg.session_id());
73/// });
74/// ```
75pub fn sync_callback<F>(f: F) -> MessageCallback
76where
77    F: Fn(Message) + Send + Sync + 'static,
78{
79    Arc::new(move |msg| {
80        f(msg);
81        Box::pin(async {})
82    })
83}
84
85/// Create a [`MessageCallback`] that emits structured [`tracing`] events for
86/// every message variant.
87///
88/// | Variant         | Level   | Fields                          |
89/// |-----------------|---------|---------------------------------|
90/// | `System`        | `INFO`  | `session_id`                    |
91/// | `Assistant`     | `DEBUG` | _(message variant only)_        |
92/// | `User`          | `DEBUG` | _(message variant only)_        |
93/// | `Result`        | `INFO`  | `stop_reason`, `is_error`       |
94/// | `StreamEvent`   | `TRACE` | `event_type`                    |
95///
96/// # Example
97///
98/// ```rust
99/// use gemini_cli_sdk::callback::tracing_callback;
100///
101/// let cb = tracing_callback();
102/// // Pass `cb` to `ClientConfig::builder().on_message(cb).build()`.
103/// ```
104pub fn tracing_callback() -> MessageCallback {
105    sync_callback(|msg| {
106        match &msg {
107            Message::System(s) => {
108                tracing::info!(session_id = %s.session_id, "System message");
109            }
110            Message::Assistant(_) => {
111                tracing::debug!("Assistant message");
112            }
113            Message::User(_) => {
114                tracing::debug!("User message");
115            }
116            Message::Result(r) => {
117                tracing::info!(
118                    stop_reason = %r.stop_reason,
119                    is_error = r.is_error,
120                    "Result message"
121                );
122            }
123            Message::StreamEvent(e) => {
124                tracing::trace!(
125                    event_type = %e.event_type,
126                    "Stream event"
127                );
128            }
129        }
130    })
131}
132
133// ── Tests ───────────────────────────────────────────────────────────────────
134
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use serde_json::Value;

    use super::{sync_callback, tracing_callback};
    use crate::types::messages::{
        AssistantMessage, AssistantMessageInner, Message, ResultMessage, StreamEvent, SystemMessage,
        Usage, UserMessage, UserMessageInner,
    };

    /// Session id shared by every standard fixture message.
    const SESSION_ID: &str = "sess-test";

    /// Model name shared by the system and assistant fixtures.
    const MODEL: &str = "gemini-2.5-pro";

    // ── Test fixture helpers ────────────────────────────────────────────

    /// An empty JSON object, used for every `extra` catch-all field.
    fn empty_extra() -> Value {
        Value::Object(Default::default())
    }

    fn make_system_message() -> Message {
        let system = SystemMessage {
            subtype: "init".to_owned(),
            session_id: SESSION_ID.to_owned(),
            cwd: "/tmp".to_owned(),
            tools: vec![],
            mcp_servers: vec![],
            model: MODEL.to_owned(),
            extra: empty_extra(),
        };
        Message::System(system)
    }

    fn make_assistant_message() -> Message {
        let inner = AssistantMessageInner {
            role: "assistant".to_owned(),
            content: vec![],
            model: MODEL.to_owned(),
            stop_reason: "end_turn".to_owned(),
            stop_sequence: None,
            extra: empty_extra(),
        };
        Message::Assistant(AssistantMessage {
            message: inner,
            session_id: SESSION_ID.to_owned(),
        })
    }

    fn make_user_message() -> Message {
        let inner = UserMessageInner {
            role: "user".to_owned(),
            content: vec![],
            extra: empty_extra(),
        };
        Message::User(UserMessage {
            message: inner,
            session_id: SESSION_ID.to_owned(),
        })
    }

    fn make_result_message() -> Message {
        let result = ResultMessage {
            subtype: "success".to_owned(),
            is_error: false,
            duration_ms: 42.0,
            duration_api_ms: 38.0,
            num_turns: 1,
            session_id: SESSION_ID.to_owned(),
            usage: Usage::default(),
            stop_reason: "end_turn".to_owned(),
            extra: empty_extra(),
        };
        Message::Result(result)
    }

    fn make_stream_event_message() -> Message {
        let event = StreamEvent {
            event_type: "tool_call_start".to_owned(),
            data: empty_extra(),
            session_id: SESSION_ID.to_owned(),
        };
        Message::StreamEvent(event)
    }

    /// One instance of each of the five `Message` variants, in declaration
    /// order, for exhaustive tests.
    fn all_message_variants() -> Vec<Message> {
        vec![
            make_system_message(),
            make_assistant_message(),
            make_user_message(),
            make_result_message(),
            make_stream_event_message(),
        ]
    }

    /// Build a `sync_callback` that appends every message it receives to a
    /// shared log. Returns the callback together with that log.
    fn recording_callback() -> (super::MessageCallback, Arc<Mutex<Vec<Message>>>) {
        let log: Arc<Mutex<Vec<Message>>> = Arc::default();
        let sink = Arc::clone(&log);
        let cb = sync_callback(move |msg| {
            sink.lock().expect("mutex not poisoned").push(msg);
        });
        (cb, log)
    }

    // ── test_sync_callback_receives_message ─────────────────────────────

    /// A single invocation is recorded exactly once, and the recorded
    /// message equals the one passed in.
    #[tokio::test]
    async fn test_sync_callback_receives_message() {
        let (cb, log) = recording_callback();
        let sent = make_system_message();

        // Invoke the callback and drive its future to completion.
        cb(sent.clone()).await;

        let received = log.lock().expect("mutex not poisoned");
        assert_eq!(received.len(), 1, "callback must be called exactly once");
        assert_eq!(
            received[0], sent,
            "captured message must equal the one passed in"
        );
    }

    /// Every message sent through the callback is recorded, in order.
    #[tokio::test]
    async fn test_sync_callback_receives_multiple_messages() {
        let (cb, log) = recording_callback();
        let expected = all_message_variants();

        for msg in expected.iter().cloned() {
            cb(msg).await;
        }

        let received = log.lock().expect("mutex not poisoned");
        assert_eq!(
            received.len(),
            expected.len(),
            "all messages must be captured"
        );
        for (i, (got, want)) in received.iter().zip(&expected).enumerate() {
            assert_eq!(got, want, "message at index {i} must match");
        }
    }

    // ── test_tracing_callback_does_not_panic ────────────────────────────

    /// `tracing_callback` accepts every `Message` variant without panicking.
    ///
    /// This test installs no tracing subscriber. The macros are expected to
    /// be no-ops in that case, and the test relies on that.
    #[tokio::test]
    async fn test_tracing_callback_does_not_panic() {
        let cb = tracing_callback();

        for msg in all_message_variants() {
            // Must not panic for any variant.
            cb(msg).await;
        }
    }

    /// `tracing_callback` accepts an error `Result` message without panicking.
    #[tokio::test]
    async fn test_tracing_callback_error_result_does_not_panic() {
        let cb = tracing_callback();

        let failed = ResultMessage {
            subtype: "error".to_owned(),
            is_error: true,
            duration_ms: 0.0,
            duration_api_ms: 0.0,
            num_turns: 0,
            session_id: "sess-err".to_owned(),
            usage: Usage::default(),
            stop_reason: "error".to_owned(),
            extra: empty_extra(),
        };

        cb(Message::Result(failed)).await;
    }

    // ── Cloneability and Send + Sync ────────────────────────────────────

    /// Clones of a `MessageCallback` share one underlying closure (Arc
    /// semantics), so state captured by that closure is shared as well.
    #[tokio::test]
    async fn test_callback_is_cloneable() {
        let calls: Arc<Mutex<u32>> = Arc::default();
        let calls_in_cb = Arc::clone(&calls);

        let cb = sync_callback(move |_| {
            *calls_in_cb.lock().expect("mutex not poisoned") += 1;
        });

        // Both handles point at the same closure.
        let cb2 = Arc::clone(&cb);

        cb(make_system_message()).await;
        cb2(make_result_message()).await;

        let total = *calls.lock().expect("mutex not poisoned");
        assert_eq!(total, 2, "both cloned callbacks must share state");
    }

    /// Compile-time guard: `MessageCallback` must stay `Send + Sync`.
    ///
    /// The body does nothing at runtime. It fails to compile if those bounds
    /// ever regress.
    #[test]
    fn test_callback_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<super::MessageCallback>();
    }
}