Skip to main content

claude_cli_sdk/
callback.rs

1//! Message callback — a simple observe/filter hook for SDK consumers.
2//!
3//! The [`MessageCallback`] is invoked for each [`Message`]
4//! received from the CLI. It can:
5//!
6//! - **Observe** messages (return `Some(msg)` unchanged)
7//! - **Transform** messages (return `Some(modified_msg)`)
8//! - **Filter** messages (return `None` to suppress)
9//!
10//! This is a lightweight SDK-level hook for observing, transforming, and filtering messages.
11//!
12//! # Example
13//!
14//! ```rust
15//! use std::sync::Arc;
16//! use claude_cli_sdk::callback::MessageCallback;
17//! use claude_cli_sdk::Message;
18//!
19//! // Log all messages, pass them through unchanged:
20//! let logger: MessageCallback = Arc::new(|msg: Message| {
21//!     eprintln!("received: {msg:?}");
22//!     Some(msg)
23//! });
24//!
25//! // Filter out system messages:
26//! let filter: MessageCallback = Arc::new(|msg: Message| {
27//!     match &msg {
28//!         Message::System(_) => None,
29//!         _ => Some(msg),
30//!     }
31//! });
32//! ```
33
34use std::sync::Arc;
35
36use crate::types::messages::Message;
37
38/// Optional callback invoked for each message received from the CLI.
39///
40/// - Return `Some(msg)` to pass the message through (possibly transformed).
41/// - Return `None` to filter the message out of the stream.
42///
43/// When no callback is configured, all messages pass through unchanged.
44pub type MessageCallback = Arc<dyn Fn(Message) -> Option<Message> + Send + Sync>;
45
46/// Apply a message callback to a message, or pass through if no callback is set.
47#[inline]
48pub fn apply_callback(msg: Message, callback: Option<&MessageCallback>) -> Option<Message> {
49    match callback {
50        Some(cb) => cb(msg),
51        None => Some(msg),
52    }
53}
54
55// ── Tests ────────────────────────────────────────────────────────────────────
56
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::content::{ContentBlock, TextBlock};
    use crate::types::messages::*;

    /// Build a minimal `init` system message.
    fn make_system_msg() -> Message {
        Message::System(SystemMessage {
            subtype: "init".into(),
            session_id: "s1".into(),
            cwd: "/tmp".into(),
            tools: vec![],
            mcp_servers: vec![],
            model: "test".into(),
            extra: serde_json::Value::Object(Default::default()),
        })
    }

    /// Build an assistant message whose content is a single text block.
    fn make_assistant_msg(text: &str) -> Message {
        Message::Assistant(AssistantMessage {
            message: AssistantMessageInner {
                id: "m1".into(),
                content: vec![ContentBlock::Text(TextBlock { text: text.into() })],
                model: "test".into(),
                stop_reason: None,
                usage: Usage::default(),
            },
            session_id: None,
            extra: serde_json::Value::Object(Default::default()),
        })
    }

    /// Return the text of an assistant message that holds exactly one text
    /// block, or `None` for any other message shape.
    ///
    /// Lets assertions check that the message coming out of
    /// `apply_callback` is the one expected, not merely that one exists.
    fn single_text(msg: &Message) -> Option<&str> {
        match msg {
            Message::Assistant(assistant) => match assistant.message.content.as_slice() {
                [ContentBlock::Text(t)] => Some(&*t.text),
                _ => None,
            },
            _ => None,
        }
    }

    #[test]
    fn no_callback_passes_through() {
        // With no callback, the exact input message must come back out.
        let result = apply_callback(make_assistant_msg("hello"), None);
        assert_eq!(result.as_ref().and_then(single_text), Some("hello"));

        assert!(matches!(
            apply_callback(make_system_msg(), None),
            Some(Message::System(_))
        ));
    }

    #[test]
    fn callback_can_filter() {
        let filter: MessageCallback = Arc::new(|msg| match &msg {
            Message::System(_) => None,
            _ => Some(msg),
        });

        assert!(apply_callback(make_system_msg(), Some(&filter)).is_none());

        // Non-system messages must survive the filter unchanged.
        let kept = apply_callback(make_assistant_msg("hi"), Some(&filter));
        assert_eq!(kept.as_ref().and_then(single_text), Some("hi"));
    }

    #[test]
    fn callback_can_transform() {
        // Replace system messages with an assistant message; pass others through.
        let transform: MessageCallback = Arc::new(|msg| match msg {
            Message::System(_) => Some(make_assistant_msg("replaced")),
            other => Some(other),
        });

        let replaced = apply_callback(make_system_msg(), Some(&transform));
        assert_eq!(replaced.as_ref().and_then(single_text), Some("replaced"));

        let untouched = apply_callback(make_assistant_msg("hello"), Some(&transform));
        assert_eq!(untouched.as_ref().and_then(single_text), Some("hello"));
    }

    #[test]
    fn callback_can_observe() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        let count = Arc::new(AtomicUsize::new(0));
        let seen = Arc::clone(&count);

        let observer: MessageCallback = Arc::new(move |msg| {
            seen.fetch_add(1, Ordering::Relaxed);
            Some(msg)
        });

        // An observer must not drop messages while counting them.
        assert!(matches!(
            apply_callback(make_system_msg(), Some(&observer)),
            Some(Message::System(_))
        ));
        let passed = apply_callback(make_assistant_msg("test"), Some(&observer));
        assert_eq!(passed.as_ref().and_then(single_text), Some("test"));

        assert_eq!(count.load(Ordering::Relaxed), 2);
    }
}