// gemini-cli-sdk 0.1.0
//
// Rust SDK wrapping Google's Gemini CLI as a subprocess via JSON-RPC 2.0.
// Documentation
//! In-memory Transport implementation for unit tests.

use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};

use async_trait::async_trait;
use futures_core::Stream;
use serde_json::Value;
use tokio::sync::{mpsc, Mutex};

use crate::transport::Transport;
use crate::Result;

/// In-memory transport for unit testing.
///
/// Pre-loaded with JSON values that are yielded when `read_messages()` is called.
/// Captures all writes for assertion in tests.
///
/// Every field uses interior mutability (an async `Mutex` or an atomic), so
/// all operations take `&self`.
///
/// # Example
///
/// ```rust,no_run
/// use gemini_cli_sdk::testing::{MockTransport, assistant_text};
/// use gemini_cli_sdk::transport::Transport;
///
/// #[tokio::main]
/// async fn main() {
///     let transport = MockTransport::new(vec![assistant_text("Hello!")]);
///     transport.connect().await.unwrap();
///     assert!(transport.is_ready());
/// }
/// ```
pub struct MockTransport {
    /// Pre-loaded messages to drain into the channel on `connect()`.
    /// Left empty once `connect()` has taken them.
    messages: Mutex<Vec<Value>>,
    /// All raw strings written via `write()`, in order.
    writes: Mutex<Vec<String>>,
    /// Set to `true` after `connect()` succeeds and back to `false` by
    /// `close()`.
    ready: AtomicBool,
    /// Sender half of the bounded message channel. `None` after `close()` or
    /// `close_stream()` is called; dropping it is what lets the reader's
    /// stream terminate.
    tx: Mutex<Option<mpsc::Sender<Result<Value>>>>,
    /// Receiver half. Taken (moved out) when `read_messages()` is called, so
    /// only the first caller gets a stream backed by the channel.
    rx: Mutex<Option<mpsc::Receiver<Result<Value>>>>,
}

impl MockTransport {
    /// Create a new mock transport with pre-loaded messages.
    ///
    /// The messages are buffered until `connect()` is called, at which point
    /// they are drained into the internal channel and become readable via
    /// `read_messages()`.
    pub fn new(messages: Vec<Value>) -> Self {
        let (tx, rx) = mpsc::channel(1024);
        Self {
            messages: Mutex::new(messages),
            writes: Mutex::new(Vec::new()),
            ready: AtomicBool::new(false),
            tx: Mutex::new(Some(tx)),
            rx: Mutex::new(Some(rx)),
        }
    }

    /// Return all strings that were written via `write()`, cloned.
    pub async fn captured_writes(&self) -> Vec<String> {
        self.writes.lock().await.clone()
    }

    /// Return the last written string parsed as JSON, or `None` if nothing has
    /// been written yet or the last write was not valid JSON.
    pub async fn last_write_json(&self) -> Option<Value> {
        let writes = self.writes.lock().await;
        writes.last().and_then(|s| serde_json::from_str(s).ok())
    }

    /// Push an additional message into the stream after `connect()`.
    ///
    /// Useful for simulating server-sent events mid-session. Silently drops
    /// the message if the sender has been closed.
    pub async fn push_message(&self, msg: Value) {
        if let Some(tx) = self.tx.lock().await.as_ref() {
            let _ = tx.send(Ok(msg)).await;
        }
    }

    /// Push an error into the stream, simulating a transport failure.
    ///
    /// The error will be yielded as an `Err` item by `read_messages()`.
    pub async fn push_error(&self, err: crate::Error) {
        if let Some(tx) = self.tx.lock().await.as_ref() {
            let _ = tx.send(Err(err)).await;
        }
    }

    /// Drop the sender, causing `read_messages()` to terminate after all
    /// queued items are consumed.
    pub async fn close_stream(&self) {
        *self.tx.lock().await = None;
    }
}

#[async_trait]
impl Transport for MockTransport {
    /// Load pre-configured messages into the channel and mark the transport
    /// as ready. Subsequent calls to `push_message` / `push_error` will also
    /// be visible through `read_messages()`.
    ///
    /// If the sender has already been dropped (`close()` / `close_stream()`),
    /// the pre-loaded messages are discarded; the call still returns `Ok(())`.
    async fn connect(&self) -> Result<()> {
        // Take the pre-loaded batch; the guard is a temporary of this
        // statement, so `messages` is unlocked before any send below.
        let preloaded = std::mem::take(&mut *self.messages.lock().await);

        // The sender lock is held for the whole batch so that `close()` /
        // `close_stream()` cannot drop the sender halfway through it.
        let sender_slot = self.tx.lock().await;
        if let Some(tx) = sender_slot.as_ref() {
            for msg in preloaded {
                // Ignore send errors — if the receiver is gone the test will
                // notice when it tries to read.
                let _ = tx.send(Ok(msg)).await;
            }
        }
        drop(sender_slot);

        self.ready.store(true, Ordering::Release);
        Ok(())
    }

    /// Append `data` to the captured-writes list.
    ///
    /// Never fails; always returns `Ok(())`.
    async fn write(&self, data: &str) -> Result<()> {
        self.writes.lock().await.push(data.to_string());
        Ok(())
    }

    /// Return a stream over the internal mpsc channel.
    ///
    /// The receiver is moved out on the first call; subsequent calls return an
    /// immediately-terminating empty stream. This mirrors the production
    /// `GeminiTransport` contract where only one reader is expected.
    fn read_messages(&self) -> Pin<Box<dyn Stream<Item = Result<Value>> + Send>> {
        // This method is synchronous, so the receiver slot is claimed with
        // `try_lock`. A contended lock or an already-taken receiver both
        // yield `None`, which becomes an empty stream rather than a panic.
        let receiver = self
            .rx
            .try_lock()
            .ok()
            .and_then(|mut slot| slot.take());

        // One stream covers both cases: with a receiver it forwards items
        // until every sender is dropped; without one it ends on first poll.
        Box::pin(async_stream::stream! {
            if let Some(mut receiver) = receiver {
                while let Some(item) = receiver.recv().await {
                    yield item;
                }
            }
        })
    }

    /// No-op: the mock transport has no real stdin to close.
    async fn end_input(&self) -> Result<()> {
        Ok(())
    }

    /// No-op: the mock transport has no real process to interrupt.
    async fn interrupt(&self) -> Result<()> {
        Ok(())
    }

    /// Report whether `connect()` has completed and `close()` has not been
    /// called since.
    fn is_ready(&self) -> bool {
        // `Acquire` pairs with the `Release` stores in `connect()`/`close()`.
        self.ready.load(Ordering::Acquire)
    }

    /// Mark the transport as not ready and drop the sender, terminating the
    /// stream. Returns `Some(0)` to indicate a clean exit.
    async fn close(&self) -> Result<Option<i32>> {
        self.ready.store(false, Ordering::Release);
        *self.tx.lock().await = None;
        Ok(Some(0))
    }
}

// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use tokio_stream::StreamExt;

    use super::*;
    use crate::testing::{assistant_text, tool_call, tool_result};

    // ── connect / is_ready ────────────────────────────────────────────────────

    /// `is_ready()` flips from false to true across `connect()`.
    #[tokio::test]
    async fn test_mock_connect_sets_ready() {
        let transport = MockTransport::new(Vec::new());
        assert!(!transport.is_ready(), "must not be ready before connect");

        transport.connect().await.unwrap();
        assert!(transport.is_ready(), "must be ready after connect");
    }

    // ── write capture ─────────────────────────────────────────────────────────

    /// Writes are recorded verbatim and in call order.
    #[tokio::test]
    async fn test_mock_write_capture() {
        let transport = MockTransport::new(Vec::new());
        transport.connect().await.unwrap();

        transport.write(r#"{"method":"ping"}"#).await.unwrap();
        transport.write(r#"{"method":"pong"}"#).await.unwrap();

        let captured = transport.captured_writes().await;
        assert_eq!(captured.len(), 2);
        assert_eq!(captured[0], r#"{"method":"ping"}"#);
        assert_eq!(captured[1], r#"{"method":"pong"}"#);
    }

    // ── read_messages: pre-loaded ─────────────────────────────────────────────

    /// Pre-loaded messages come out in order, then the stream ends.
    #[tokio::test]
    async fn test_mock_read_messages_yields_preloaded() {
        let first = assistant_text("Hello");
        let second = assistant_text("World");

        let transport = MockTransport::new(vec![first.clone(), second.clone()]);
        transport.connect().await.unwrap();

        // With the sender dropped, the stream ends once the queue is empty.
        transport.close_stream().await;

        let mut stream = transport.read_messages();
        let received_first = stream.next().await.expect("first item").unwrap();
        let received_second = stream.next().await.expect("second item").unwrap();
        let after_last = stream.next().await;

        assert_eq!(received_first, first);
        assert_eq!(received_second, second);
        assert!(
            after_last.is_none(),
            "stream must terminate after sender is closed"
        );
    }

    // ── push_message after connect ────────────────────────────────────────────

    /// Messages pushed after `connect()` reach a concurrently running reader.
    #[tokio::test]
    async fn test_mock_push_message_after_connect() {
        let transport = std::sync::Arc::new(MockTransport::new(Vec::new()));
        transport.connect().await.unwrap();

        // The reader runs on its own task and drains the channel until the
        // sender is dropped.
        let reader_side = transport.clone();
        let reader = tokio::spawn(async move {
            let mut seen: Vec<Value> = Vec::new();
            let mut stream = reader_side.read_messages();
            while let Some(item) = stream.next().await {
                seen.push(item.unwrap());
            }
            seen
        });

        // Feed two messages, then end the stream so the reader task returns.
        transport.push_message(assistant_text("Dynamic 1")).await;
        transport.push_message(assistant_text("Dynamic 2")).await;
        transport.close_stream().await;

        let seen = reader.await.unwrap();
        assert_eq!(seen.len(), 2, "expected exactly 2 dynamically pushed messages");
    }

    // ── close terminates stream ───────────────────────────────────────────────

    /// `close()` reports exit code 0, clears readiness and ends the stream.
    #[tokio::test]
    async fn test_mock_close_terminates_stream() {
        let transport = MockTransport::new(Vec::new());
        transport.connect().await.unwrap();

        let mut stream = transport.read_messages();

        // Dropping the sender inside close() is what lets recv() return None.
        let exit_code = transport.close().await.unwrap();
        assert_eq!(exit_code, Some(0));
        assert!(!transport.is_ready());

        // Nothing was queued, so the very next poll must be the end.
        let after_close = stream.next().await;
        assert!(after_close.is_none(), "stream must terminate after close()");
    }

    // ── ScenarioBuilder ───────────────────────────────────────────────────────

    /// The builder produces a transport without panicking and exposes the
    /// session id and model it was configured with.
    #[test]
    fn test_scenario_builder_exchange() {
        use crate::testing::ScenarioBuilder;

        // The built transport's contents are opaque here; it only has to be
        // constructible. The underscore binding keeps it alive to the end of
        // the test without an "unused variable" warning.
        let _transport = ScenarioBuilder::new("sess-42")
            .with_model("gemini-2.0-flash")
            .exchange(vec![assistant_text("Hi"), assistant_text("Bye")])
            .build();

        let builder = ScenarioBuilder::new("sess-42").with_model("gemini-2.0-flash");
        assert_eq!(builder.session_id(), "sess-42");
        assert_eq!(builder.model(), "gemini-2.0-flash");
    }

    // ── Builder shape tests ───────────────────────────────────────────────────

    /// `assistant_text` emits an `agent_message_chunk` session update.
    #[test]
    fn test_assistant_text_builder_shape() {
        let msg = assistant_text("Hello, world!");
        assert_eq!(msg["jsonrpc"], "2.0");
        assert_eq!(msg["method"], "session/update");

        let params = &msg["params"];
        assert_eq!(params["sessionUpdate"], "agent_message_chunk");
        assert_eq!(params["content"]["type"], "text");
        assert_eq!(params["content"]["text"], "Hello, world!");
    }

    /// `tool_call` emits a pending `tool_call` session update.
    #[test]
    fn test_tool_call_builder_shape() {
        let msg = tool_call("tc-1", "Read file", "file_read");
        assert_eq!(msg["jsonrpc"], "2.0");
        assert_eq!(msg["method"], "session/update");

        let params = &msg["params"];
        assert_eq!(params["sessionUpdate"], "tool_call");
        assert_eq!(params["toolCallId"], "tc-1");
        assert_eq!(params["title"], "Read file");
        assert_eq!(params["kind"], "file_read");
        assert_eq!(params["status"], "pending");
    }

    /// `tool_result` emits a completed `tool_call_update` with text content.
    #[test]
    fn test_tool_result_builder_shape() {
        let msg = tool_result("tc-1", "contents of file");
        assert_eq!(msg["jsonrpc"], "2.0");
        assert_eq!(msg["method"], "session/update");

        let params = &msg["params"];
        assert_eq!(params["sessionUpdate"], "tool_call_update");
        assert_eq!(params["toolCallId"], "tc-1");
        assert_eq!(params["status"], "completed");

        let first_content = &params["content"][0];
        assert_eq!(first_content["type"], "text");
        assert_eq!(first_content["text"], "contents of file");
    }
}