use crate::errors::Result;
use crate::transport::Transport;
use async_trait::async_trait;
use futures_core::Stream;
use serde_json::Value;
use std::pin::Pin;
use std::sync::Mutex;
use std::sync::atomic::{AtomicBool, Ordering};
/// In-memory [`Transport`] test double.
///
/// Events queued before `connect` are replayed through a channel handed out by
/// `read_messages`; every payload passed to `write` is captured for later inspection.
#[derive(Debug)]
pub struct MockTransport {
    /// Events waiting for delivery; drained into the message channel by `connect`.
    events: Mutex<Vec<Value>>,
    /// Every payload passed to `write`, in call order.
    written: Mutex<Vec<String>>,
    /// Receiving half of the channel created by `connect`; taken by `read_messages`.
    message_rx: tokio::sync::Mutex<Option<tokio::sync::mpsc::Receiver<Result<Value>>>>,
    /// Set by a successful `connect`, cleared by `close`.
    ready: AtomicBool,
    /// Exit code reported by `close`; `None` makes `close` report `0`.
    exit_code: Mutex<Option<i32>>,
    /// Set once `interrupt` has been called.
    interrupt_called: AtomicBool,
}
impl MockTransport {
    /// Builds a disconnected mock: no queued events, no captured writes, no message
    /// channel, no exit-code override, and the interrupt flag cleared.
    pub fn new() -> Self {
        Self {
            events: Mutex::new(Vec::new()),
            written: Mutex::new(Vec::new()),
            message_rx: tokio::sync::Mutex::new(None),
            ready: AtomicBool::new(false),
            exit_code: Mutex::new(None),
            interrupt_called: AtomicBool::new(false),
        }
    }

    /// Returns `true` once `interrupt` has been invoked on this mock.
    pub fn interrupt_called(&self) -> bool {
        self.interrupt_called.load(Ordering::Acquire)
    }

    /// Appends a single event to the queue that the next `connect` will deliver.
    ///
    /// A poisoned lock is recovered rather than propagated.
    pub fn enqueue_event(&self, value: Value) {
        let mut queue = self
            .events
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        queue.push(value);
    }

    /// Appends every event yielded by `values`, in order, to the pending queue.
    pub fn enqueue_events(&self, values: impl IntoIterator<Item = Value>) {
        let mut queue = self
            .events
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        queue.extend(values);
    }

    /// Queues the opening of a session: a thread-started event for `thread_id`
    /// followed by a turn-started event.
    pub fn enqueue_session(&self, thread_id: &str) {
        let opening = [
            super::builders::thread_started(thread_id),
            super::builders::turn_started(),
        ];
        self.enqueue_events(opening);
    }

    /// Queues the end of a turn: a completed agent message with id `"msg-auto"`
    /// carrying `message`, then a turn-completed event built from the fixed
    /// arguments `(100, 0, 50)`.
    pub fn enqueue_turn_complete(&self, message: &str) {
        let closing = [
            super::builders::agent_message_completed("msg-auto", message),
            super::builders::turn_completed(100, 0, 50),
        ];
        self.enqueue_events(closing);
    }

    /// Sets the exit code that `close` will report.
    pub fn set_exit_code(&self, code: i32) {
        let mut slot = self
            .exit_code
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        *slot = Some(code);
    }

    /// Number of events still waiting for the next `connect`.
    pub fn queued_count(&self) -> usize {
        let queue = self
            .events
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        queue.len()
    }

    /// Returns a snapshot of every payload passed to `write`, in call order.
    pub fn written_lines(&self) -> Vec<String> {
        let captured = self
            .written
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        captured.to_vec()
    }
}
impl Default for MockTransport {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Transport for MockTransport {
    /// Connects the mock, moving every queued event into a fresh message channel.
    ///
    /// Events are drained from the queue at connect time, so events enqueued afterwards
    /// are not delivered on this connection. A spawned task feeds the drained events into
    /// a bounded channel (capacity 256), whose receiver is later handed out by
    /// `read_messages`.
    ///
    /// # Errors
    ///
    /// Returns `Error::AlreadyConnected` if the mock is already connected.
    async fn connect(&self) -> Result<()> {
        // Hold the receiver slot for the entire check-and-install sequence. Without this,
        // two concurrent `connect` calls could both observe `ready == false`, and the
        // second would replace the first receiver with an empty channel.
        let mut rx_slot = self.message_rx.lock().await;
        if self.ready.load(Ordering::Acquire) {
            return Err(crate::Error::AlreadyConnected);
        }
        let events: Vec<Value> = self
            .events
            .lock()
            .unwrap_or_else(|e| e.into_inner())
            .drain(..)
            .collect();
        let (tx, rx) = tokio::sync::mpsc::channel(256);
        tokio::spawn(async move {
            for event in events {
                // The receiver was dropped (stream discarded); stop feeding.
                if tx.send(Ok(event)).await.is_err() {
                    break;
                }
            }
        });
        *rx_slot = Some(rx);
        self.ready.store(true, Ordering::Release);
        Ok(())
    }

    /// Records `data` in the captured-writes list; never fails.
    async fn write(&self, data: &str) -> Result<()> {
        self.written
            .lock()
            .unwrap_or_else(|e| e.into_inner())
            .push(data.to_string());
        Ok(())
    }

    /// Hands out the message stream created by the last `connect`.
    ///
    /// The receiver can be taken only once per connection. If there is no receiver (not
    /// connected, or already taken), or the slot is momentarily locked (e.g. a `connect`
    /// is in progress), this returns a one-item stream yielding `Error::TransportClosed`.
    fn read_messages(&self) -> Pin<Box<dyn Stream<Item = Result<Value>> + Send>> {
        match self.message_rx.try_lock() {
            Ok(mut guard) => match guard.take() {
                Some(rx) => Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)),
                None => Box::pin(tokio_stream::iter(std::iter::once(Err(
                    crate::Error::TransportClosed,
                )))),
            },
            Err(_) => Box::pin(tokio_stream::iter(std::iter::once(Err(
                crate::Error::TransportClosed,
            )))),
        }
    }

    /// No-op; always succeeds.
    async fn end_input(&self) -> Result<()> {
        Ok(())
    }

    /// Records that an interrupt was requested (see `interrupt_called`); always succeeds.
    async fn interrupt(&self) -> Result<()> {
        self.interrupt_called.store(true, Ordering::Release);
        Ok(())
    }

    /// Returns `true` between a successful `connect` and the next `close`.
    fn is_ready(&self) -> bool {
        self.ready.load(Ordering::Acquire)
    }

    /// Marks the mock as not ready and reports the configured exit code.
    ///
    /// Returns `Some(code)` as set by `set_exit_code`, or `Some(0)` if none was set.
    /// The message channel is left untouched.
    async fn close(&self) -> Result<Option<i32>> {
        self.ready.store(false, Ordering::Release);
        let code = self
            .exit_code
            .lock()
            .unwrap_or_else(|e| e.into_inner())
            .unwrap_or(0);
        Ok(Some(code))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio_stream::StreamExt;

    /// Queued events are replayed in full after `connect`, and `close` clears readiness.
    #[tokio::test]
    async fn mock_transport_lifecycle() {
        let mock = MockTransport::new();
        assert!(!mock.is_ready());
        mock.enqueue_session("test-thread");
        mock.enqueue_turn_complete("Hello!");
        assert_eq!(mock.queued_count(), 4);
        mock.connect().await.unwrap();
        assert!(mock.is_ready());
        // `connect` drains the pending queue into the channel.
        assert_eq!(mock.queued_count(), 0);
        let mut stream = mock.read_messages();
        let mut count = 0;
        while let Some(item) = stream.next().await {
            // Fail loudly on an error item instead of silently ending the loop.
            item.unwrap();
            count += 1;
        }
        assert_eq!(count, 4);
        mock.close().await.unwrap();
        assert!(!mock.is_ready());
    }

    /// Every `write` payload is captured verbatim and in order.
    #[tokio::test]
    async fn mock_transport_captures_writes() {
        let mock = MockTransport::new();
        mock.enqueue_event(serde_json::json!({"type": "thread.started", "thread_id": "t1"}));
        mock.connect().await.unwrap();
        mock.write("test data").await.unwrap();
        mock.write("more data").await.unwrap();
        assert_eq!(mock.written_lines(), ["test data", "more data"]);
    }

    /// A second `connect` on a connected mock fails with `AlreadyConnected`.
    #[tokio::test]
    async fn mock_transport_double_connect_fails() {
        let mock = MockTransport::new();
        mock.connect().await.unwrap();
        let result = mock.connect().await;
        assert!(matches!(result, Err(crate::Error::AlreadyConnected)));
    }

    /// `close` reports `Some(0)` by default and the configured code once set.
    #[tokio::test]
    async fn mock_transport_close_reports_exit_code() {
        let mock = MockTransport::new();
        assert_eq!(mock.close().await.unwrap(), Some(0));
        mock.set_exit_code(3);
        assert_eq!(mock.close().await.unwrap(), Some(3));
    }

    /// `interrupt` flips the observable `interrupt_called` flag.
    #[tokio::test]
    async fn mock_transport_records_interrupt() {
        let mock = MockTransport::new();
        assert!(!mock.interrupt_called());
        mock.interrupt().await.unwrap();
        assert!(mock.interrupt_called());
    }

    /// The message stream can be taken only once; later calls yield `TransportClosed`.
    #[tokio::test]
    async fn mock_transport_message_stream_is_single_use() {
        let mock = MockTransport::new();
        mock.connect().await.unwrap();
        let _first = mock.read_messages();
        let mut second = mock.read_messages();
        assert!(matches!(
            second.next().await,
            Some(Err(crate::Error::TransportClosed))
        ));
        assert!(second.next().await.is_none());
    }
}