gemini_cli_sdk/callback.rs
1//! Message callback type for real-time message notifications.
2//!
3//! The [`MessageCallback`] type is the primary extension point for consumers
4//! who need to observe messages as they stream — for logging, UI updates,
5//! persistence, metrics, or any other side effect — without interfering with
6//! the primary stream consumption.
7//!
8//! # Design
9//!
//! Callbacks are `Arc<dyn Fn(Message) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync>`.
11//! This makes them:
12//!
13//! - **Cloneable** — can be shared across tasks via `Arc::clone`.
14//! - **Thread-safe** — `Send + Sync` bounds on the inner `Fn`.
15//! - **Async-capable** — the callback can await I/O (e.g., writing to a database).
//! - **Low overhead for sync work** — use [`sync_callback`] to wrap a plain closure; the future it returns completes on its first poll (one small boxed future per call).
17//!
18//! # Example
19//!
20//! ```rust
21//! use gemini_cli_sdk::callback::{sync_callback, tracing_callback};
22//!
23//! // Simple logging callback
24//! let cb = sync_callback(|msg| {
25//! println!("received: {:?}", msg.session_id());
26//! });
27//!
28//! // Built-in tracing callback
29//! let _trace_cb = tracing_callback();
30//! ```
31
32use std::future::Future;
33use std::pin::Pin;
34use std::sync::Arc;
35
36use crate::types::messages::Message;
37
38// ── MessageCallback type alias ──────────────────────────────────────────────
39
40/// Callback invoked for each message received from the agent.
41///
42/// Used for side effects (logging, UI updates, persistence) while the primary
43/// message stream is consumed by the caller.
44///
45/// # Thread Safety
46///
47/// `MessageCallback` is `Send + Sync`, so it can be safely shared across
48/// threads and tasks via `Arc::clone`.
49///
50/// # Creating Callbacks
51///
52/// Use [`sync_callback`] for synchronous side effects, or construct directly
53/// with `Arc::new(|msg| Box::pin(async move { ... }))` for async work.
54pub type MessageCallback = Arc<
55 dyn Fn(Message) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync,
56>;
57
58// ── Helper constructors ─────────────────────────────────────────────────────
59
60/// Wrap a synchronous closure in a [`MessageCallback`].
61///
62/// This is the ergonomic constructor for callbacks that perform only
63/// synchronous work. The closure receives ownership of each [`Message`]
64/// and returns immediately; the resulting future resolves in a single poll.
65///
66/// # Example
67///
68/// ```rust
69/// use gemini_cli_sdk::callback::sync_callback;
70///
71/// let cb = sync_callback(|msg| {
72/// println!("message for session {:?}", msg.session_id());
73/// });
74/// ```
75pub fn sync_callback<F>(f: F) -> MessageCallback
76where
77 F: Fn(Message) + Send + Sync + 'static,
78{
79 Arc::new(move |msg| {
80 f(msg);
81 Box::pin(async {})
82 })
83}
84
85/// Create a [`MessageCallback`] that emits structured [`tracing`] events for
86/// every message variant.
87///
88/// | Variant | Level | Fields |
89/// |-----------------|---------|---------------------------------|
90/// | `System` | `INFO` | `session_id` |
91/// | `Assistant` | `DEBUG` | _(message variant only)_ |
92/// | `User` | `DEBUG` | _(message variant only)_ |
93/// | `Result` | `INFO` | `stop_reason`, `is_error` |
94/// | `StreamEvent` | `TRACE` | `event_type` |
95///
96/// # Example
97///
98/// ```rust
99/// use gemini_cli_sdk::callback::tracing_callback;
100///
101/// let cb = tracing_callback();
102/// // Pass `cb` to `ClientConfig::builder().on_message(cb).build()`.
103/// ```
104pub fn tracing_callback() -> MessageCallback {
105 sync_callback(|msg| {
106 match &msg {
107 Message::System(s) => {
108 tracing::info!(session_id = %s.session_id, "System message");
109 }
110 Message::Assistant(_) => {
111 tracing::debug!("Assistant message");
112 }
113 Message::User(_) => {
114 tracing::debug!("User message");
115 }
116 Message::Result(r) => {
117 tracing::info!(
118 stop_reason = %r.stop_reason,
119 is_error = r.is_error,
120 "Result message"
121 );
122 }
123 Message::StreamEvent(e) => {
124 tracing::trace!(
125 event_type = %e.event_type,
126 "Stream event"
127 );
128 }
129 }
130 })
131}
132
133// ── Tests ───────────────────────────────────────────────────────────────────
134
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use serde_json::Value;

    use super::{sync_callback, tracing_callback};
    use crate::types::messages::{
        AssistantMessage, AssistantMessageInner, Message, ResultMessage, StreamEvent, SystemMessage,
        Usage, UserMessage, UserMessageInner,
    };

    // ── Fixtures ────────────────────────────────────────────────────────

    /// Session id carried by every standard fixture.
    const SESSION_ID: &str = "sess-test";

    /// Model name carried by the fixtures that record one.
    const MODEL: &str = "gemini-2.5-pro";

    /// A fresh, empty JSON object (`{}`).
    fn empty_object() -> Value {
        Value::Object(Default::default())
    }

    fn make_system_message() -> Message {
        Message::System(SystemMessage {
            subtype: "init".to_owned(),
            session_id: SESSION_ID.to_owned(),
            cwd: "/tmp".to_owned(),
            tools: vec![],
            mcp_servers: vec![],
            model: MODEL.to_owned(),
            extra: empty_object(),
        })
    }

    fn make_assistant_message() -> Message {
        let inner = AssistantMessageInner {
            role: "assistant".to_owned(),
            content: vec![],
            model: MODEL.to_owned(),
            stop_reason: "end_turn".to_owned(),
            stop_sequence: None,
            extra: empty_object(),
        };
        Message::Assistant(AssistantMessage {
            message: inner,
            session_id: SESSION_ID.to_owned(),
        })
    }

    fn make_user_message() -> Message {
        let inner = UserMessageInner {
            role: "user".to_owned(),
            content: vec![],
            extra: empty_object(),
        };
        Message::User(UserMessage {
            message: inner,
            session_id: SESSION_ID.to_owned(),
        })
    }

    fn make_result_message() -> Message {
        Message::Result(ResultMessage {
            subtype: "success".to_owned(),
            is_error: false,
            duration_ms: 42.0,
            duration_api_ms: 38.0,
            num_turns: 1,
            session_id: SESSION_ID.to_owned(),
            usage: Usage::default(),
            stop_reason: "end_turn".to_owned(),
            extra: empty_object(),
        })
    }

    fn make_stream_event_message() -> Message {
        Message::StreamEvent(StreamEvent {
            event_type: "tool_call_start".to_owned(),
            data: empty_object(),
            session_id: SESSION_ID.to_owned(),
        })
    }

    /// One message of each of the five variants, in the order
    /// System, Assistant, User, Result, StreamEvent.
    fn all_message_variants() -> Vec<Message> {
        let builders: [fn() -> Message; 5] = [
            make_system_message,
            make_assistant_message,
            make_user_message,
            make_result_message,
            make_stream_event_message,
        ];
        builders.iter().map(|build| build()).collect()
    }

    /// Build a `sync_callback` that appends every message it receives to a
    /// shared log; returns the callback together with a handle to that log.
    fn recording_callback() -> (super::MessageCallback, Arc<Mutex<Vec<Message>>>) {
        let log = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&log);
        let cb = sync_callback(move |msg| sink.lock().expect("mutex not poisoned").push(msg));
        (cb, log)
    }

    // ── sync_callback ───────────────────────────────────────────────────

    /// A single invocation reaches the wrapped closure exactly once, with the
    /// same message that was passed in.
    #[tokio::test]
    async fn test_sync_callback_receives_message() {
        let (cb, log) = recording_callback();
        let msg = make_system_message();

        cb(msg.clone()).await;

        let seen = log.lock().expect("mutex not poisoned");
        assert_eq!(seen.len(), 1, "callback must be called exactly once");
        assert_eq!(seen[0], msg, "captured message must equal the one passed in");
    }

    /// Every message reaches the closure, and in the order it was sent.
    #[tokio::test]
    async fn test_sync_callback_receives_multiple_messages() {
        let (cb, log) = recording_callback();
        let variants = all_message_variants();

        for msg in variants.iter().cloned() {
            cb(msg).await;
        }

        let seen = log.lock().expect("mutex not poisoned");
        assert_eq!(seen.len(), variants.len(), "all messages must be captured");
        for (i, expected) in variants.iter().enumerate() {
            assert_eq!(&seen[i], expected, "message at index {i} must match");
        }
    }

    // ── tracing_callback ────────────────────────────────────────────────

    /// `tracing_callback` must accept every `Message` variant without
    /// panicking. No subscriber is installed here, so the tracing macros are
    /// no-ops; this exercises the per-variant dispatch rather than the
    /// formatting of recorded fields.
    #[tokio::test]
    async fn test_tracing_callback_does_not_panic() {
        let cb = tracing_callback();
        for msg in all_message_variants() {
            cb(msg).await;
        }
    }

    /// An error-flavoured `Result` message must also be handled without
    /// panicking.
    #[tokio::test]
    async fn test_tracing_callback_error_result_does_not_panic() {
        let cb = tracing_callback();
        let failed = ResultMessage {
            subtype: "error".to_owned(),
            is_error: true,
            duration_ms: 0.0,
            duration_api_ms: 0.0,
            num_turns: 0,
            session_id: "sess-err".to_owned(),
            usage: Usage::default(),
            stop_reason: "error".to_owned(),
            extra: empty_object(),
        };

        cb(Message::Result(failed)).await;
    }

    // ── Cloneability and Send + Sync ────────────────────────────────────

    /// Cloning the `Arc` yields a second handle to the *same* closure, so
    /// state captured by that closure is shared between both handles.
    #[tokio::test]
    async fn test_callback_is_cloneable() {
        let hits = Arc::new(Mutex::new(0_u32));
        let hits_in_cb = Arc::clone(&hits);
        let cb = sync_callback(move |_msg| {
            *hits_in_cb.lock().expect("mutex not poisoned") += 1;
        });
        let cb2 = Arc::clone(&cb);

        cb(make_system_message()).await;
        cb2(make_result_message()).await;

        let total = *hits.lock().expect("mutex not poisoned");
        assert_eq!(total, 2, "both cloned callbacks must share state");
    }

    /// Compile-time guard: this stops compiling if `MessageCallback` ever
    /// loses its `Send + Sync` bounds. Nothing is checked at run time.
    #[test]
    fn test_callback_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<super::MessageCallback>();
    }
}