use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use async_trait::async_trait;
use futures_core::Stream;
use serde_json::Value;
use tokio::sync::{mpsc, Mutex};
use crate::transport::Transport;
use crate::Result;
/// In-memory [`Transport`] double for tests.
///
/// Messages given to `new` are queued onto an internal channel by `connect`. More items can be
/// injected with `push_message` / `push_error`. Every `write` payload is recorded for later
/// inspection.
pub struct MockTransport {
    /// Messages passed to `new`; drained into the channel by `connect`.
    messages: Mutex<Vec<Value>>,
    /// Every payload passed to `write`, in call order.
    writes: Mutex<Vec<String>>,
    /// Set by `connect`, cleared by `close`.
    ready: AtomicBool,
    /// Producer side of the message channel; set to `None` by `close_stream` / `close`.
    tx: Mutex<Option<mpsc::Sender<Result<Value>>>>,
    /// Consumer side of the message channel; taken by the first `read_messages` call.
    rx: Mutex<Option<mpsc::Receiver<Result<Value>>>>,
}
impl MockTransport {
pub fn new(messages: Vec<Value>) -> Self {
let (tx, rx) = mpsc::channel(1024);
Self {
messages: Mutex::new(messages),
writes: Mutex::new(Vec::new()),
ready: AtomicBool::new(false),
tx: Mutex::new(Some(tx)),
rx: Mutex::new(Some(rx)),
}
}
pub async fn captured_writes(&self) -> Vec<String> {
self.writes.lock().await.clone()
}
pub async fn last_write_json(&self) -> Option<Value> {
let writes = self.writes.lock().await;
writes.last().and_then(|s| serde_json::from_str(s).ok())
}
pub async fn push_message(&self, msg: Value) {
if let Some(tx) = self.tx.lock().await.as_ref() {
let _ = tx.send(Ok(msg)).await;
}
}
pub async fn push_error(&self, err: crate::Error) {
if let Some(tx) = self.tx.lock().await.as_ref() {
let _ = tx.send(Err(err)).await;
}
}
pub async fn close_stream(&self) {
*self.tx.lock().await = None;
}
}
#[async_trait]
impl Transport for MockTransport {
    /// Queues every preloaded message onto the read stream, then marks the transport ready.
    ///
    /// The preloaded list is drained, so a second `connect` queues nothing new. If the stream
    /// was already closed (`close_stream` / `close`), the preloaded messages are discarded.
    async fn connect(&self) -> Result<()> {
        let messages = std::mem::take(&mut *self.messages.lock().await);
        // Clone the sender and release the lock before awaiting. A send waiting for channel
        // capacity must not also block `close` / `close_stream` on the `tx` lock.
        let tx = self.tx.lock().await.clone();
        if let Some(tx) = tx {
            for msg in messages {
                // An error means the receiver was dropped; every later send would fail too.
                if tx.send(Ok(msg)).await.is_err() {
                    break;
                }
            }
        }
        self.ready.store(true, Ordering::Release);
        Ok(())
    }

    /// Records `data` for later inspection via `captured_writes`; never fails.
    async fn write(&self, data: &str) -> Result<()> {
        self.writes.lock().await.push(data.to_owned());
        Ok(())
    }

    /// Returns the stream of queued messages.
    ///
    /// Only the first call receives the channel; later calls get a stream that ends
    /// immediately. The stream ends once all senders are dropped and buffered items are yielded.
    fn read_messages(&self) -> Pin<Box<dyn Stream<Item = Result<Value>> + Send>> {
        // `try_lock` because this method is synchronous. Within this file only
        // `read_messages` locks `rx`, so a failure means a concurrent call is taking it.
        let rx = self.rx.try_lock().ok().and_then(|mut guard| guard.take());
        Box::pin(async_stream::stream! {
            // Without a receiver the stream simply finishes without yielding anything.
            if let Some(mut rx) = rx {
                while let Some(item) = rx.recv().await {
                    yield item;
                }
            }
        })
    }

    /// No-op for the mock.
    async fn end_input(&self) -> Result<()> {
        Ok(())
    }

    /// No-op for the mock.
    async fn interrupt(&self) -> Result<()> {
        Ok(())
    }

    /// Whether `connect` has run and `close` has not run since.
    fn is_ready(&self) -> bool {
        // Acquire pairs with the Release stores in `connect` / `close`.
        self.ready.load(Ordering::Acquire)
    }

    /// Marks the transport not ready and drops the stored sender.
    ///
    /// The read stream then ends after its buffered items. The mock always reports exit code 0.
    async fn close(&self) -> Result<Option<i32>> {
        self.ready.store(false, Ordering::Release);
        *self.tx.lock().await = None;
        Ok(Some(0))
    }
}
#[cfg(test)]
mod tests {
    use tokio_stream::StreamExt;

    use super::*;
    use crate::testing::{assistant_text, tool_call, tool_result};

    #[tokio::test]
    async fn test_mock_connect_sets_ready() {
        let t = MockTransport::new(vec![]);
        assert!(!t.is_ready(), "must not be ready before connect");
        t.connect().await.unwrap();
        assert!(t.is_ready(), "must be ready after connect");
    }

    #[tokio::test]
    async fn test_mock_write_capture() {
        let t = MockTransport::new(vec![]);
        t.connect().await.unwrap();
        t.write(r#"{"method":"ping"}"#).await.unwrap();
        t.write(r#"{"method":"pong"}"#).await.unwrap();
        let writes = t.captured_writes().await;
        assert_eq!(writes.len(), 2);
        assert_eq!(writes[0], r#"{"method":"ping"}"#);
        assert_eq!(writes[1], r#"{"method":"pong"}"#);
    }

    #[tokio::test]
    async fn test_mock_last_write_json() {
        let t = MockTransport::new(vec![]);
        assert!(t.last_write_json().await.is_none(), "no writes yet");
        t.write(r#"{"method":"ping"}"#).await.unwrap();
        t.write("not json").await.unwrap();
        assert!(
            t.last_write_json().await.is_none(),
            "last write is not valid JSON"
        );
        t.write(r#"{"method":"pong"}"#).await.unwrap();
        assert_eq!(
            t.last_write_json().await,
            Some(serde_json::json!({ "method": "pong" }))
        );
    }

    #[tokio::test]
    async fn test_mock_read_messages_yields_preloaded() {
        let msg1 = assistant_text("Hello");
        let msg2 = assistant_text("World");
        let t = MockTransport::new(vec![msg1.clone(), msg2.clone()]);
        t.connect().await.unwrap();
        t.close_stream().await;
        let mut stream = t.read_messages();
        let got1 = stream.next().await.expect("first item").unwrap();
        let got2 = stream.next().await.expect("second item").unwrap();
        let done = stream.next().await;
        assert_eq!(got1, msg1);
        assert_eq!(got2, msg2);
        assert!(done.is_none(), "stream must terminate after sender is closed");
    }

    #[tokio::test]
    async fn test_mock_read_messages_second_call_is_empty() {
        let t = MockTransport::new(vec![assistant_text("Only once")]);
        t.connect().await.unwrap();
        t.close_stream().await;
        let mut first = t.read_messages();
        let mut second = t.read_messages();
        assert!(
            second.next().await.is_none(),
            "the receiver can only be handed out once"
        );
        assert!(first.next().await.is_some(), "first stream owns the receiver");
    }

    #[tokio::test]
    async fn test_mock_push_message_after_connect() {
        let t = std::sync::Arc::new(MockTransport::new(vec![]));
        t.connect().await.unwrap();
        let reader = {
            let t2 = t.clone();
            tokio::spawn(async move {
                let mut stream = t2.read_messages();
                let mut collected: Vec<Value> = Vec::new();
                while let Some(item) = stream.next().await {
                    collected.push(item.unwrap());
                }
                collected
            })
        };
        let first = assistant_text("Dynamic 1");
        let second = assistant_text("Dynamic 2");
        t.push_message(first.clone()).await;
        t.push_message(second.clone()).await;
        t.close_stream().await;
        let collected = reader.await.unwrap();
        assert_eq!(
            collected,
            vec![first, second],
            "expected exactly the 2 dynamically pushed messages, in order"
        );
    }

    #[tokio::test]
    async fn test_mock_close_terminates_stream() {
        let t = MockTransport::new(vec![]);
        t.connect().await.unwrap();
        let mut stream = t.read_messages();
        let exit_code = t.close().await.unwrap();
        assert_eq!(exit_code, Some(0));
        assert!(!t.is_ready());
        let next = stream.next().await;
        assert!(next.is_none(), "stream must terminate after close()");
    }

    #[test]
    fn test_scenario_builder_exchange() {
        use crate::testing::ScenarioBuilder;
        let transport = ScenarioBuilder::new("sess-42")
            .with_model("gemini-2.0-flash")
            .exchange(vec![assistant_text("Hi"), assistant_text("Bye")])
            .build();
        let b = ScenarioBuilder::new("sess-42").with_model("gemini-2.0-flash");
        assert_eq!(b.session_id(), "sess-42");
        assert_eq!(b.model(), "gemini-2.0-flash");
        drop(transport);
    }

    #[test]
    fn test_assistant_text_builder_shape() {
        let v = assistant_text("Hello, world!");
        assert_eq!(v["jsonrpc"], "2.0");
        assert_eq!(v["method"], "session/update");
        assert_eq!(v["params"]["sessionUpdate"], "agent_message_chunk");
        assert_eq!(v["params"]["content"]["type"], "text");
        assert_eq!(v["params"]["content"]["text"], "Hello, world!");
    }

    #[test]
    fn test_tool_call_builder_shape() {
        let v = tool_call("tc-1", "Read file", "file_read");
        assert_eq!(v["jsonrpc"], "2.0");
        assert_eq!(v["method"], "session/update");
        assert_eq!(v["params"]["sessionUpdate"], "tool_call");
        assert_eq!(v["params"]["toolCallId"], "tc-1");
        assert_eq!(v["params"]["title"], "Read file");
        assert_eq!(v["params"]["kind"], "file_read");
        assert_eq!(v["params"]["status"], "pending");
    }

    #[test]
    fn test_tool_result_builder_shape() {
        let v = tool_result("tc-1", "contents of file");
        assert_eq!(v["jsonrpc"], "2.0");
        assert_eq!(v["method"], "session/update");
        assert_eq!(v["params"]["sessionUpdate"], "tool_call_update");
        assert_eq!(v["params"]["toolCallId"], "tc-1");
        assert_eq!(v["params"]["status"], "completed");
        assert_eq!(v["params"]["content"][0]["type"], "text");
        assert_eq!(v["params"]["content"][0]["text"], "contents of file");
    }
}