use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use crate::types::messages::Message;
/// Shared, thread-safe asynchronous handler for [`Message`] values.
///
/// The handler takes ownership of each message it is given and returns a boxed
/// `Send` future that resolves to `()`. Because it is stored behind an [`Arc`],
/// cloning a callback only bumps a reference count, and the `Send + Sync`
/// bounds allow the same callback to be shared between threads.
pub type MessageCallback = Arc<
    dyn Fn(Message) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>
        + Send
        + Sync
        + 'static,
>;
pub fn sync_callback<F>(f: F) -> MessageCallback
where
F: Fn(Message) + Send + Sync + 'static,
{
Arc::new(move |msg| {
f(msg);
Box::pin(async {})
})
}
/// Returns a [`MessageCallback`] that reports every message through `tracing`.
///
/// Each variant produces one event:
/// - `System`: INFO `"System message"` with a `session_id` field.
/// - `Assistant` / `User`: DEBUG `"Assistant message"` / `"User message"`.
/// - `Result`: INFO `"Result message"` with `stop_reason` and `is_error` fields.
/// - `StreamEvent`: TRACE `"Stream event"` with an `event_type` field.
///
/// The message is not used for anything else.
pub fn tracing_callback() -> MessageCallback {
    sync_callback(|message| match message {
        Message::System(system) => {
            tracing::info!(session_id = %system.session_id, "System message");
        }
        Message::Assistant(_) => {
            tracing::debug!("Assistant message");
        }
        Message::User(_) => {
            tracing::debug!("User message");
        }
        Message::Result(result) => {
            tracing::info!(
                stop_reason = %result.stop_reason,
                is_error = result.is_error,
                "Result message"
            );
        }
        Message::StreamEvent(event) => {
            tracing::trace!(event_type = %event.event_type, "Stream event");
        }
    })
}
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use serde_json::Value;

    use super::{sync_callback, tracing_callback};
    use crate::types::messages::{
        AssistantMessage, AssistantMessageInner, Message, ResultMessage, StreamEvent, SystemMessage,
        Usage, UserMessage, UserMessageInner,
    };

    fn make_system_message() -> Message {
        Message::System(SystemMessage {
            subtype: "init".to_owned(),
            session_id: "sess-test".to_owned(),
            cwd: "/tmp".to_owned(),
            tools: vec![],
            mcp_servers: vec![],
            model: "gemini-2.5-pro".to_owned(),
            extra: Value::Object(Default::default()),
        })
    }

    fn make_assistant_message() -> Message {
        Message::Assistant(AssistantMessage {
            message: AssistantMessageInner {
                role: "assistant".to_owned(),
                content: vec![],
                model: "gemini-2.5-pro".to_owned(),
                stop_reason: "end_turn".to_owned(),
                stop_sequence: None,
                extra: Value::Object(Default::default()),
            },
            session_id: "sess-test".to_owned(),
        })
    }

    fn make_user_message() -> Message {
        Message::User(UserMessage {
            message: UserMessageInner {
                role: "user".to_owned(),
                content: vec![],
                extra: Value::Object(Default::default()),
            },
            session_id: "sess-test".to_owned(),
        })
    }

    fn make_result_message() -> Message {
        Message::Result(ResultMessage {
            subtype: "success".to_owned(),
            is_error: false,
            duration_ms: 42.0,
            duration_api_ms: 38.0,
            num_turns: 1,
            session_id: "sess-test".to_owned(),
            usage: Usage::default(),
            stop_reason: "end_turn".to_owned(),
            extra: Value::Object(Default::default()),
        })
    }

    fn make_stream_event_message() -> Message {
        Message::StreamEvent(StreamEvent {
            event_type: "tool_call_start".to_owned(),
            data: Value::Object(Default::default()),
            session_id: "sess-test".to_owned(),
        })
    }

    /// One instance of every `Message` variant, in declaration order.
    fn all_message_variants() -> Vec<Message> {
        vec![
            make_system_message(),
            make_assistant_message(),
            make_user_message(),
            make_result_message(),
            make_stream_event_message(),
        ]
    }

    /// One event observed by [`CapturingSubscriber`]: its level plus every
    /// recorded field rendered with `Debug` (including the `message` field).
    #[derive(Debug, Clone)]
    struct CapturedEvent {
        level: tracing::Level,
        fields: Vec<(String, String)>,
    }

    impl CapturedEvent {
        /// Returns the rendered value of the field called `name`, if recorded.
        fn field(&self, name: &str) -> Option<&str> {
            self.fields
                .iter()
                .find(|(key, _)| key == name)
                .map(|(_, value)| value.as_str())
        }
    }

    /// Collects `(field name, rendered value)` pairs from a single event.
    struct FieldCollector(Vec<(String, String)>);

    impl tracing::field::Visit for FieldCollector {
        fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) {
            self.0.push((field.name().to_owned(), format!("{value:?}")));
        }
    }

    /// Minimal subscriber that enables everything and stores each event.
    ///
    /// Without an installed subscriber the `tracing` macros are disabled and
    /// never evaluate their fields, so a test that installs nothing checks
    /// almost nothing. Spans are not used by the code under test and are
    /// ignored.
    #[derive(Clone, Default)]
    struct CapturingSubscriber {
        events: Arc<Mutex<Vec<CapturedEvent>>>,
    }

    impl tracing::Subscriber for CapturingSubscriber {
        fn enabled(&self, _metadata: &tracing::Metadata<'_>) -> bool {
            true
        }

        fn new_span(&self, _span: &tracing::span::Attributes<'_>) -> tracing::span::Id {
            tracing::span::Id::from_u64(1)
        }

        fn record(&self, _span: &tracing::span::Id, _values: &tracing::span::Record<'_>) {}

        fn record_follows_from(&self, _span: &tracing::span::Id, _follows: &tracing::span::Id) {}

        fn event(&self, event: &tracing::Event<'_>) {
            let mut collector = FieldCollector(Vec::new());
            event.record(&mut collector);
            self.events
                .lock()
                .expect("mutex not poisoned")
                .push(CapturedEvent {
                    level: *event.metadata().level(),
                    fields: collector.0,
                });
        }

        fn enter(&self, _span: &tracing::span::Id) {}

        fn exit(&self, _span: &tracing::span::Id) {}
    }

    /// Runs `tracing_callback` on `msg` with a capturing subscriber installed
    /// for the current thread and returns the events it emitted.
    async fn capture_tracing_events(msg: Message) -> Vec<CapturedEvent> {
        let subscriber = CapturingSubscriber::default();
        let events = Arc::clone(&subscriber.events);
        let cb = tracing_callback();
        {
            // `set_default` is thread-local; `#[tokio::test]` uses a
            // current-thread runtime, so the await below stays on this thread.
            let _guard = tracing::subscriber::set_default(subscriber);
            cb(msg).await;
        }
        let guard = events.lock().expect("mutex not poisoned");
        guard.clone()
    }

    /// Number of events expected at `level`. This is zero when the crate's
    /// `tracing` features statically compile that level out.
    fn expected_event_count(level: tracing::Level) -> usize {
        usize::from(level <= tracing::level_filters::STATIC_MAX_LEVEL)
    }

    #[tokio::test]
    async fn test_sync_callback_receives_message() {
        let captured: Arc<Mutex<Vec<Message>>> = Arc::new(Mutex::new(Vec::new()));
        let captured_clone = Arc::clone(&captured);
        let cb = sync_callback(move |msg| {
            captured_clone
                .lock()
                .expect("mutex not poisoned")
                .push(msg);
        });
        let msg = make_system_message();
        cb(msg.clone()).await;
        let messages = captured.lock().expect("mutex not poisoned");
        assert_eq!(messages.len(), 1, "callback must be called exactly once");
        assert_eq!(
            messages[0], msg,
            "captured message must equal the one passed in"
        );
    }

    #[tokio::test]
    async fn test_sync_callback_receives_multiple_messages() {
        let captured: Arc<Mutex<Vec<Message>>> = Arc::new(Mutex::new(Vec::new()));
        let captured_clone = Arc::clone(&captured);
        let cb = sync_callback(move |msg| {
            captured_clone
                .lock()
                .expect("mutex not poisoned")
                .push(msg);
        });
        let variants = all_message_variants();
        for msg in &variants {
            cb(msg.clone()).await;
        }
        let messages = captured.lock().expect("mutex not poisoned");
        assert_eq!(
            messages.len(),
            variants.len(),
            "all messages must be captured"
        );
        for (i, (got, expected)) in messages.iter().zip(variants.iter()).enumerate() {
            assert_eq!(got, expected, "message at index {i} must match");
        }
    }

    #[test]
    fn test_sync_callback_runs_handler_before_future_is_polled() {
        let called = Arc::new(Mutex::new(false));
        let called_clone = Arc::clone(&called);
        let cb = sync_callback(move |_msg| {
            *called_clone.lock().expect("mutex not poisoned") = true;
        });
        let fut = cb(make_system_message());
        assert!(
            *called.lock().expect("mutex not poisoned"),
            "sync handler must run when the callback is called, not when awaited"
        );
        drop(fut);
    }

    #[tokio::test]
    async fn test_tracing_callback_does_not_panic() {
        let cb = tracing_callback();
        for msg in all_message_variants() {
            cb(msg).await;
        }
    }

    #[tokio::test]
    async fn test_tracing_callback_emits_expected_level_and_message() {
        let expected = [
            (tracing::Level::INFO, "System message"),
            (tracing::Level::DEBUG, "Assistant message"),
            (tracing::Level::DEBUG, "User message"),
            (tracing::Level::INFO, "Result message"),
            (tracing::Level::TRACE, "Stream event"),
        ];
        for (msg, &(level, text)) in all_message_variants().into_iter().zip(expected.iter()) {
            let events = capture_tracing_events(msg).await;
            assert_eq!(
                events.len(),
                expected_event_count(level),
                "unexpected number of events for {text:?}"
            );
            if let Some(event) = events.first() {
                assert_eq!(event.level, level, "wrong level for {text:?}");
                assert_eq!(event.field("message"), Some(text));
            }
        }
    }

    #[tokio::test]
    async fn test_tracing_callback_records_message_fields() {
        let system = capture_tracing_events(make_system_message()).await;
        if let Some(event) = system.first() {
            assert_eq!(event.field("session_id"), Some("sess-test"));
        }

        let result = capture_tracing_events(make_result_message()).await;
        if let Some(event) = result.first() {
            assert_eq!(event.field("stop_reason"), Some("end_turn"));
            assert_eq!(event.field("is_error"), Some("false"));
        }

        let stream = capture_tracing_events(make_stream_event_message()).await;
        if let Some(event) = stream.first() {
            assert_eq!(event.field("event_type"), Some("tool_call_start"));
        }
    }

    #[tokio::test]
    async fn test_tracing_callback_error_result_records_is_error() {
        let error_result = Message::Result(ResultMessage {
            subtype: "error".to_owned(),
            is_error: true,
            duration_ms: 0.0,
            duration_api_ms: 0.0,
            num_turns: 0,
            session_id: "sess-err".to_owned(),
            usage: Usage::default(),
            stop_reason: "error".to_owned(),
            extra: Value::Object(Default::default()),
        });
        let events = capture_tracing_events(error_result).await;
        assert_eq!(events.len(), expected_event_count(tracing::Level::INFO));
        if let Some(event) = events.first() {
            assert_eq!(event.level, tracing::Level::INFO);
            assert_eq!(event.field("is_error"), Some("true"));
            assert_eq!(event.field("stop_reason"), Some("error"));
        }
    }

    #[tokio::test]
    async fn test_callback_is_cloneable() {
        let counter: Arc<Mutex<u32>> = Arc::new(Mutex::new(0));
        let counter_clone = Arc::clone(&counter);
        let cb = sync_callback(move |_msg| {
            *counter_clone.lock().expect("mutex not poisoned") += 1;
        });
        let cb2 = Arc::clone(&cb);
        cb(make_system_message()).await;
        cb2(make_result_message()).await;
        let count = *counter.lock().expect("mutex not poisoned");
        assert_eq!(count, 2, "both cloned callbacks must share state");
    }

    #[test]
    fn test_callback_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<super::MessageCallback>();
    }
}