gemini-cli-sdk 0.1.0

Rust SDK wrapping Google's Gemini CLI as a subprocess via JSON-RPC 2.0
Documentation
//! Message callback type for real-time message notifications.
//!
//! The [`MessageCallback`] type is the primary extension point for consumers
//! who need to observe messages as they stream — for logging, UI updates,
//! persistence, metrics, or any other side effect — without interfering with
//! the primary stream consumption.
//!
//! # Design
//!
//! Callbacks are `Arc<dyn Fn(Message) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync>`.
//! This makes them:
//!
//! - **Cloneable** — can be shared across tasks via `Arc::clone`.
//! - **Thread-safe** — `Send + Sync` bounds on the inner `Fn`.
//! - **Async-capable** — the callback can await I/O (e.g., writing to a database).
//! - **Convenient for sync work** — use [`sync_callback`] to wrap a plain closure;
//!   the future it returns resolves on its first poll.
//!
//! # Example
//!
//! ```rust
//! use gemini_cli_sdk::callback::{sync_callback, tracing_callback};
//!
//! // Simple logging callback
//! let cb = sync_callback(|msg| {
//!     println!("received: {:?}", msg.session_id());
//! });
//!
//! // Built-in tracing callback
//! let _trace_cb = tracing_callback();
//! ```

use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;

use crate::types::messages::Message;

// ── MessageCallback type alias ──────────────────────────────────────────────

/// Shared, type-erased handler that is called once for every [`Message`]
/// received from the agent.
///
/// Consumers use it for side effects such as logging, UI refreshes,
/// persistence or metrics, while the caller keeps driving the primary
/// message stream.
///
/// # Thread Safety
///
/// The trait object is bounded by `Send + Sync` and lives behind an [`Arc`].
/// One callback can therefore be duplicated cheaply with `Arc::clone` and
/// handed to any number of threads or tasks.
///
/// # Returned Future
///
/// Each invocation returns a pinned, boxed future that is `Send` and has
/// `()` as its output. Because `Box<dyn _>` defaults to a `'static` object
/// lifetime here, the future must own everything it uses.
///
/// # Creating Callbacks
///
/// For purely synchronous work, use [`sync_callback`]. For asynchronous
/// work, build one directly:
/// `Arc::new(|msg| Box::pin(async move { ... }))`.
pub type MessageCallback =
    Arc<dyn Fn(Message) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync>;

// ── Helper constructors ─────────────────────────────────────────────────────

/// Wrap a synchronous closure in a [`MessageCallback`].
///
/// This is the ergonomic constructor for callbacks that perform only
/// synchronous work. The closure receives ownership of each [`Message`]
/// and returns immediately; the resulting future resolves in a single poll.
///
/// # Example
///
/// ```rust
/// use gemini_cli_sdk::callback::sync_callback;
///
/// let cb = sync_callback(|msg| {
///     println!("message for session {:?}", msg.session_id());
/// });
/// ```
pub fn sync_callback<F>(f: F) -> MessageCallback
where
    F: Fn(Message) + Send + Sync + 'static,
{
    Arc::new(move |msg| {
        f(msg);
        Box::pin(async {})
    })
}

/// Create a [`MessageCallback`] that emits one structured [`tracing`] event
/// per message and tags every event with the message's `session_id`.
///
/// | Variant         | Level   | Fields                                  |
/// |-----------------|---------|-----------------------------------------|
/// | `System`        | `INFO`  | `session_id`                            |
/// | `Assistant`     | `DEBUG` | `session_id`                            |
/// | `User`          | `DEBUG` | `session_id`                            |
/// | `Result`        | `INFO`  | `session_id`, `stop_reason`, `is_error` |
/// | `StreamEvent`   | `TRACE` | `session_id`, `event_type`              |
///
/// `session_id` is recorded on every variant, not only `System`. This lets a
/// subscriber attribute events to the right session when several sessions
/// are active at the same time.
///
/// All work happens synchronously via [`sync_callback`]. Events are emitted
/// when the callback is invoked, and the returned future resolves on its
/// first poll.
///
/// # Example
///
/// ```rust
/// use gemini_cli_sdk::callback::tracing_callback;
///
/// let cb = tracing_callback();
/// // Pass `cb` to `ClientConfig::builder().on_message(cb).build()`.
/// ```
pub fn tracing_callback() -> MessageCallback {
    sync_callback(|msg| match &msg {
        Message::System(s) => {
            tracing::info!(session_id = %s.session_id, "System message");
        }
        Message::Assistant(a) => {
            tracing::debug!(session_id = %a.session_id, "Assistant message");
        }
        Message::User(u) => {
            tracing::debug!(session_id = %u.session_id, "User message");
        }
        Message::Result(r) => {
            tracing::info!(
                session_id = %r.session_id,
                stop_reason = %r.stop_reason,
                is_error = r.is_error,
                "Result message"
            );
        }
        Message::StreamEvent(e) => {
            tracing::trace!(
                session_id = %e.session_id,
                event_type = %e.event_type,
                "Stream event"
            );
        }
    })
}

// ── Tests ───────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use serde_json::Value;

    use super::{sync_callback, tracing_callback, MessageCallback};
    use crate::types::messages::{
        AssistantMessage, AssistantMessageInner, Message, ResultMessage, StreamEvent, SystemMessage,
        Usage, UserMessage, UserMessageInner,
    };

    // ── Shared fixture values ───────────────────────────────────────────

    /// Session id stamped on every fixture produced by the builders below.
    const SESSION_ID: &str = "sess-test";

    /// Model name used by the fixtures that carry one.
    const MODEL: &str = "gemini-2.5-pro";

    /// An empty JSON object, used for every `extra` / `data` payload.
    fn empty_object() -> Value {
        Value::Object(Default::default())
    }

    // ── Fixture builders, one per `Message` variant ─────────────────────

    fn make_system_message() -> Message {
        Message::System(SystemMessage {
            subtype: "init".to_owned(),
            session_id: SESSION_ID.to_owned(),
            cwd: "/tmp".to_owned(),
            tools: Vec::new(),
            mcp_servers: Vec::new(),
            model: MODEL.to_owned(),
            extra: empty_object(),
        })
    }

    fn make_assistant_message() -> Message {
        let inner = AssistantMessageInner {
            role: "assistant".to_owned(),
            content: Vec::new(),
            model: MODEL.to_owned(),
            stop_reason: "end_turn".to_owned(),
            stop_sequence: None,
            extra: empty_object(),
        };
        Message::Assistant(AssistantMessage {
            message: inner,
            session_id: SESSION_ID.to_owned(),
        })
    }

    fn make_user_message() -> Message {
        let inner = UserMessageInner {
            role: "user".to_owned(),
            content: Vec::new(),
            extra: empty_object(),
        };
        Message::User(UserMessage {
            message: inner,
            session_id: SESSION_ID.to_owned(),
        })
    }

    fn make_result_message() -> Message {
        Message::Result(ResultMessage {
            subtype: "success".to_owned(),
            is_error: false,
            duration_ms: 42.0,
            duration_api_ms: 38.0,
            num_turns: 1,
            session_id: SESSION_ID.to_owned(),
            usage: Usage::default(),
            stop_reason: "end_turn".to_owned(),
            extra: empty_object(),
        })
    }

    fn make_stream_event_message() -> Message {
        Message::StreamEvent(StreamEvent {
            event_type: "tool_call_start".to_owned(),
            data: empty_object(),
            session_id: SESSION_ID.to_owned(),
        })
    }

    /// One instance of each of the five variants, in declaration order.
    fn all_message_variants() -> Vec<Message> {
        vec![
            make_system_message(),
            make_assistant_message(),
            make_user_message(),
            make_result_message(),
            make_stream_event_message(),
        ]
    }

    /// Build a `sync_callback` that appends every message it receives to a
    /// shared log. Returns the callback along with a handle to that log.
    fn recording_callback() -> (MessageCallback, Arc<Mutex<Vec<Message>>>) {
        let log = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&log);
        let cb = sync_callback(move |msg| {
            sink.lock().expect("mutex not poisoned").push(msg);
        });
        (cb, log)
    }

    // ── sync_callback delivery ──────────────────────────────────────────

    /// A single invocation is recorded exactly once, and the recorded
    /// message equals the one that was sent.
    #[tokio::test]
    async fn test_sync_callback_receives_message() {
        let (cb, log) = recording_callback();
        let sent = make_system_message();

        // Invoke the callback and drive its future to completion.
        cb(sent.clone()).await;

        let seen = log.lock().expect("mutex not poisoned");
        assert_eq!(seen.len(), 1, "callback must be called exactly once");
        assert_eq!(
            seen[0], sent,
            "captured message must equal the one passed in"
        );
    }

    /// Every message in a sequence is recorded, in the order it was sent.
    #[tokio::test]
    async fn test_sync_callback_receives_multiple_messages() {
        let (cb, log) = recording_callback();
        let expected = all_message_variants();

        for msg in expected.iter().cloned() {
            cb(msg).await;
        }

        let seen = log.lock().expect("mutex not poisoned");
        assert_eq!(
            seen.len(),
            expected.len(),
            "all messages must be captured"
        );
        for (i, (got, want)) in seen.iter().zip(&expected).enumerate() {
            assert_eq!(got, want, "message at index {i} must match");
        }
    }

    // ── tracing_callback robustness ─────────────────────────────────────

    /// `tracing_callback` accepts every `Message` variant without panicking.
    /// No subscriber is installed, so the tracing macros have nothing to
    /// deliver to; the test only checks that the callback runs cleanly.
    #[tokio::test]
    async fn test_tracing_callback_does_not_panic() {
        let cb = tracing_callback();

        for variant in all_message_variants() {
            // Must not panic, whatever the variant.
            cb(variant).await;
        }
    }

    /// An error-flavoured `Result` message is also handled without panicking.
    #[tokio::test]
    async fn test_tracing_callback_error_result_does_not_panic() {
        let cb = tracing_callback();

        let failed = ResultMessage {
            subtype: "error".to_owned(),
            is_error: true,
            duration_ms: 0.0,
            duration_api_ms: 0.0,
            num_turns: 0,
            session_id: "sess-err".to_owned(),
            usage: Usage::default(),
            stop_reason: "error".to_owned(),
            extra: empty_object(),
        };

        cb(Message::Result(failed)).await;
    }

    // ── Arc semantics and auto traits ───────────────────────────────────

    /// Clones of a `MessageCallback` point at the same closure and therefore
    /// share its captured state.
    #[tokio::test]
    async fn test_callback_is_cloneable() {
        let hits: Arc<Mutex<u32>> = Arc::new(Mutex::new(0));
        let hits_in_cb = Arc::clone(&hits);

        let first = sync_callback(move |_msg| {
            *hits_in_cb.lock().expect("mutex not poisoned") += 1;
        });
        // Duplicate the Arc; both handles call the same closure.
        let second = Arc::clone(&first);

        first(make_system_message()).await;
        second(make_result_message()).await;

        let total = *hits.lock().expect("mutex not poisoned");
        assert_eq!(total, 2, "both cloned callbacks must share state");
    }

    /// Compile-time guard: `MessageCallback` must stay `Send + Sync`.
    ///
    /// The body does nothing at runtime. It fails to build if the alias
    /// ever loses either bound.
    #[test]
    fn test_callback_is_send_sync() {
        fn require_send_sync<T: Send + Sync>() {}
        require_send_sync::<MessageCallback>();
    }
}