claude_cli_sdk/
callback.rs1use std::sync::Arc;
35
36use crate::types::messages::Message;
37
38pub type MessageCallback = Arc<dyn Fn(Message) -> Option<Message> + Send + Sync>;
45
46#[inline]
48pub fn apply_callback(msg: Message, callback: Option<&MessageCallback>) -> Option<Message> {
49 match callback {
50 Some(cb) => cb(msg),
51 None => Some(msg),
52 }
53}
54
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::content::{ContentBlock, TextBlock};
    use crate::types::messages::*;

    /// Builds a minimal `Message::System` fixture.
    fn make_system_msg() -> Message {
        Message::System(SystemMessage {
            subtype: "init".into(),
            session_id: "s1".into(),
            cwd: "/tmp".into(),
            tools: vec![],
            mcp_servers: vec![],
            model: "test".into(),
            extra: serde_json::Value::Object(Default::default()),
        })
    }

    /// Builds a `Message::Assistant` fixture whose only content block is `text`.
    fn make_assistant_msg(text: &str) -> Message {
        Message::Assistant(AssistantMessage {
            message: AssistantMessageInner {
                id: "m1".into(),
                content: vec![ContentBlock::Text(TextBlock { text: text.into() })],
                model: "test".into(),
                stop_reason: None,
                usage: Usage::default(),
            },
            session_id: None,
            extra: serde_json::Value::Object(Default::default()),
        })
    }

    /// Returns the text of the first content block when `msg` is an assistant
    /// message that starts with a text block; `None` for anything else.
    fn assistant_text(msg: &Message) -> Option<&str> {
        match msg {
            Message::Assistant(assistant) => match assistant.message.content.first() {
                Some(ContentBlock::Text(block)) => Some(block.text.as_str()),
                _ => None,
            },
            _ => None,
        }
    }

    #[test]
    fn no_callback_passes_through() {
        // The variant must survive the pass-through...
        let system = apply_callback(make_system_msg(), None);
        assert!(matches!(system, Some(Message::System(_))));

        // ...and so must the payload.
        let assistant = apply_callback(make_assistant_msg("unchanged"), None);
        assert_eq!(
            assistant.as_ref().and_then(assistant_text),
            Some("unchanged")
        );
    }

    #[test]
    fn callback_can_filter() {
        let filter: MessageCallback = Arc::new(|msg: Message| {
            if matches!(msg, Message::System(_)) {
                None
            } else {
                Some(msg)
            }
        });

        assert!(apply_callback(make_system_msg(), Some(&filter)).is_none());

        // Messages that are not filtered out must come back intact.
        let kept = apply_callback(make_assistant_msg("hi"), Some(&filter));
        assert_eq!(kept.as_ref().and_then(assistant_text), Some("hi"));
    }

    #[test]
    fn callback_can_transform() {
        // Upper-cases the text of assistant messages; leaves everything else alone.
        let transform: MessageCallback = Arc::new(|msg: Message| {
            let upper = assistant_text(&msg).map(str::to_uppercase);
            Some(match upper {
                Some(text) => make_assistant_msg(&text),
                None => msg,
            })
        });

        let rewritten = apply_callback(make_assistant_msg("hello"), Some(&transform));
        assert_eq!(
            rewritten.as_ref().and_then(assistant_text),
            Some("HELLO")
        );

        let untouched = apply_callback(make_system_msg(), Some(&transform));
        assert!(matches!(untouched, Some(Message::System(_))));
    }

    #[test]
    fn callback_can_observe() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        let count = Arc::new(AtomicUsize::new(0));
        let count_clone = Arc::clone(&count);

        // Relaxed is sufficient: the counter is only touched from this thread.
        let observer: MessageCallback = Arc::new(move |msg: Message| {
            count_clone.fetch_add(1, Ordering::Relaxed);
            Some(msg)
        });

        // An observer must not swallow the messages it sees.
        assert!(apply_callback(make_system_msg(), Some(&observer)).is_some());
        assert!(apply_callback(make_assistant_msg("test"), Some(&observer)).is_some());

        assert_eq!(count.load(Ordering::Relaxed), 2);
    }
}