use parking_lot::Mutex;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::Ordering;
use std::time::Duration;
use bytes::Bytes;
use futures::StreamExt;
use tokio::io::{AsyncRead, AsyncWrite, BufReader};
use tokio::process::Child;
use tokio::sync::{Mutex as TokioMutex, mpsc};
use tokio_util::codec::{FramedRead, FramedWrite, LinesCodec};
use tracing::{debug, error, trace, warn};
use turbomcp_protocol::MessageId;
use turbomcp_transport_traits::{
AtomicMetrics, Transport, TransportCapabilities, TransportConfig, TransportError,
TransportEventEmitter, TransportFactory, TransportMessage, TransportMessageMetadata,
TransportMetrics, TransportResult, TransportState, TransportType, validate_request_size,
validate_response_size,
};
use uuid::Uuid;
/// Type-erased, pinned async byte source (process stdin, a child's stdout, or an in-memory pipe).
type BoxedAsyncRead = Pin<Box<dyn AsyncRead + Send + Sync + 'static>>;
/// Buffered wrapper around the boxed reader.
type BoxedAsyncBufRead = BufReader<BoxedAsyncRead>;
/// Type-erased, pinned async byte sink.
type BoxedAsyncWrite = Pin<Box<dyn AsyncWrite + Send + Sync + 'static>>;
/// Newline-delimited frame decoder over the buffered input stream.
type StdinReader = FramedRead<BoxedAsyncBufRead, LinesCodec>;
/// Newline-delimited frame encoder over the output stream.
type StdoutWriter = FramedWrite<BoxedAsyncWrite, LinesCodec>;
/// Where the transport's underlying byte streams come from.
enum StreamSource {
    /// Use the current process's own stdin/stdout.
    ProcessStdio,
    /// Caller-supplied streams (child-process pipes, in-memory duplex, ...).
    /// `Option` because the streams are taken exactly once during connect;
    /// a second connect on the same source fails.
    Raw {
        reader: Option<BoxedAsyncRead>,
        writer: Option<BoxedAsyncWrite>,
    },
}
impl std::fmt::Debug for StreamSource {
    /// Manual `Debug`: the boxed async streams have no `Debug` impl, so
    /// placeholder strings stand in for the stream handles.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ProcessStdio => write!(f, "ProcessStdio"),
            Self::Raw { reader, writer } => {
                let reader_repr = reader.as_ref().map(|_| "<async reader>");
                let writer_repr = writer.as_ref().map(|_| "<async writer>");
                f.debug_struct("Raw")
                    .field("reader", &reader_repr)
                    .field("writer", &writer_repr)
                    .finish()
            }
        }
    }
}
/// MCP transport speaking newline-delimited JSON over stdio-style streams.
pub struct StdioTransport {
    // Connection state; sync parking_lot mutex, held only for short non-async sections.
    state: Arc<Mutex<TransportState>>,
    capabilities: TransportCapabilities,
    config: Arc<Mutex<TransportConfig>>,
    // Lock-free counters (messages/bytes sent and received, connections).
    metrics: Arc<AtomicMetrics>,
    event_emitter: TransportEventEmitter,
    // Source of the underlying streams; consumed on first connect.
    stream_source: Arc<TokioMutex<StreamSource>>,
    // NOTE(review): never populated — setup_stdio_streams moves the framed
    // reader into the background task and only disconnect() writes None here.
    // Looks vestigial; confirm before removing.
    stdin_reader: Arc<TokioMutex<Option<StdinReader>>>,
    // Framed line writer; installed by setup_stdio_streams, used by send().
    stdout_writer: Arc<TokioMutex<Option<StdoutWriter>>>,
    // Receiving end of the channel fed by the background reader task.
    receive_channel: Arc<TokioMutex<Option<mpsc::Receiver<TransportMessage>>>>,
    // Background reader task; aborted on disconnect.
    _task_handle: Arc<TokioMutex<Option<tokio::task::JoinHandle<()>>>>,
}
impl std::fmt::Debug for StdioTransport {
    /// Manual `Debug`: several fields (framed streams, channel, task handle)
    /// are not `Debug`, so they are rendered as placeholder strings.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("StdioTransport");
        dbg.field("state", &self.state);
        dbg.field("capabilities", &self.capabilities);
        dbg.field("config", &self.config);
        dbg.field("metrics", &self.metrics);
        dbg.field("stream_source", &"<StreamSource>");
        dbg.field("stdin_reader", &"<StdinReader>");
        dbg.field("stdout_writer", &"<StdoutWriter>");
        dbg.field("receive_channel", &"<mpsc::Receiver>");
        dbg.field("_task_handle", &"<JoinHandle>");
        dbg.finish()
    }
}
impl StdioTransport {
#[must_use]
/// Create a disconnected transport bound to the current process's stdio.
///
/// The event-emitter's receiving half is dropped, so lifecycle events are
/// discarded unless [`Self::with_event_emitter`] is used instead.
pub fn new() -> Self {
    let (event_emitter, _receiver) = TransportEventEmitter::new();
    let capabilities = TransportCapabilities {
        supports_streaming: true,
        supports_bidirectional: true,
        supports_compression: false,
        supports_multiplexing: false,
        max_message_size: Some(turbomcp_protocol::MAX_MESSAGE_SIZE),
        compression_algorithms: Vec::new(),
        custom: std::collections::HashMap::new(),
    };
    let config = TransportConfig {
        transport_type: TransportType::Stdio,
        ..Default::default()
    };
    Self {
        state: Arc::new(Mutex::new(TransportState::Disconnected)),
        capabilities,
        config: Arc::new(Mutex::new(config)),
        metrics: Arc::new(AtomicMetrics::default()),
        event_emitter,
        stream_source: Arc::new(TokioMutex::new(StreamSource::ProcessStdio)),
        stdin_reader: Arc::new(TokioMutex::new(None)),
        stdout_writer: Arc::new(TokioMutex::new(None)),
        receive_channel: Arc::new(TokioMutex::new(None)),
        _task_handle: Arc::new(TokioMutex::new(None)),
    }
}
/// Build a transport over a spawned child process's piped stdio.
///
/// Takes ownership of the child's `stdin`/`stdout` handles; note the swap —
/// we *read* from the child's stdout and *write* to its stdin.
///
/// # Errors
/// Returns `ConfigurationError` if either handle was not captured with
/// `Stdio::piped()` at spawn time.
pub fn from_child(child: &mut Child) -> TransportResult<Self> {
    let Some(stdin) = child.stdin.take() else {
        return Err(TransportError::ConfigurationError(
            "Child process stdin was not piped. Use Stdio::piped() when spawning.".to_string(),
        ));
    };
    let Some(stdout) = child.stdout.take() else {
        return Err(TransportError::ConfigurationError(
            "Child process stdout was not piped. Use Stdio::piped() when spawning.".to_string(),
        ));
    };
    Self::from_raw(stdout, stdin)
}
/// Build a transport over arbitrary caller-supplied async streams.
///
/// Useful for tests (in-memory duplex pipes) and for child-process pipes.
/// The streams are boxed and handed over once on connect.
///
/// # Errors
/// Currently always succeeds; the `Result` keeps the signature stable.
pub fn from_raw<R, W>(reader: R, writer: W) -> TransportResult<Self>
where
    R: AsyncRead + Send + Sync + 'static,
    W: AsyncWrite + Send + Sync + 'static,
{
    let (event_emitter, _receiver) = TransportEventEmitter::new();
    let boxed_reader: BoxedAsyncRead = Box::pin(reader);
    let boxed_writer: BoxedAsyncWrite = Box::pin(writer);
    let source = StreamSource::Raw {
        reader: Some(boxed_reader),
        writer: Some(boxed_writer),
    };
    let capabilities = TransportCapabilities {
        supports_streaming: true,
        supports_bidirectional: true,
        supports_compression: false,
        supports_multiplexing: false,
        max_message_size: Some(turbomcp_protocol::MAX_MESSAGE_SIZE),
        compression_algorithms: Vec::new(),
        custom: std::collections::HashMap::new(),
    };
    let config = TransportConfig {
        transport_type: TransportType::Stdio,
        ..Default::default()
    };
    Ok(Self {
        state: Arc::new(Mutex::new(TransportState::Disconnected)),
        capabilities,
        config: Arc::new(Mutex::new(config)),
        metrics: Arc::new(AtomicMetrics::default()),
        event_emitter,
        stream_source: Arc::new(TokioMutex::new(source)),
        stdin_reader: Arc::new(TokioMutex::new(None)),
        stdout_writer: Arc::new(TokioMutex::new(None)),
        receive_channel: Arc::new(TokioMutex::new(None)),
        _task_handle: Arc::new(TokioMutex::new(None)),
    })
}
#[must_use]
/// Create a transport and install the given configuration.
pub fn with_config(config: TransportConfig) -> Self {
    let transport = Self::new();
    {
        let mut cfg = transport.config.lock();
        *cfg = config;
    }
    transport
}
#[must_use]
/// Create a transport that reports lifecycle events through the supplied
/// emitter (instead of the discarded default emitter that [`Self::new`] uses).
///
/// Fix: removed a dead `let (_, _) = TransportEventEmitter::new();` statement
/// that constructed and immediately dropped a second emitter/receiver pair.
pub fn with_event_emitter(event_emitter: TransportEventEmitter) -> Self {
    Self {
        state: Arc::new(Mutex::new(TransportState::Disconnected)),
        capabilities: TransportCapabilities {
            max_message_size: Some(turbomcp_protocol::MAX_MESSAGE_SIZE),
            supports_compression: false,
            supports_streaming: true,
            supports_bidirectional: true,
            supports_multiplexing: false,
            compression_algorithms: Vec::new(),
            custom: std::collections::HashMap::new(),
        },
        config: Arc::new(Mutex::new(TransportConfig {
            transport_type: TransportType::Stdio,
            ..Default::default()
        })),
        metrics: Arc::new(AtomicMetrics::default()),
        event_emitter,
        stream_source: Arc::new(TokioMutex::new(StreamSource::ProcessStdio)),
        stdin_reader: Arc::new(TokioMutex::new(None)),
        stdout_writer: Arc::new(TokioMutex::new(None)),
        receive_channel: Arc::new(TokioMutex::new(None)),
        _task_handle: Arc::new(TokioMutex::new(None)),
    }
}
/// Transition to `new_state`, emitting connected/disconnected events on the
/// transitions that have an observable meaning. No-op if the state is
/// unchanged.
///
/// NOTE(review): the state lock is held while the emitter runs; if an event
/// listener ever calls back into this transport's state it could deadlock —
/// confirm listeners are fire-and-forget.
fn set_state(&self, new_state: TransportState) {
    let mut state = self.state.lock();
    if *state != new_state {
        trace!("Stdio transport state: {:?} -> {:?}", *state, new_state);
        *state = new_state.clone();
        match new_state {
            TransportState::Connected => {
                self.event_emitter
                    .emit_connected(TransportType::Stdio, "stdio://".to_string());
            }
            TransportState::Disconnected => {
                self.event_emitter.emit_disconnected(
                    TransportType::Stdio,
                    "stdio://".to_string(),
                    None,
                );
            }
            TransportState::Failed { reason } => {
                // A failure is surfaced as a disconnect carrying the reason.
                self.event_emitter.emit_disconnected(
                    TransportType::Stdio,
                    "stdio://".to_string(),
                    Some(reason),
                );
            }
            // Intermediate states (Connecting/Disconnecting) emit nothing.
            _ => {}
        }
    }
}
#[allow(dead_code)]
/// Placeholder for a future heartbeat/keepalive mechanism; currently a no-op.
fn heartbeat(&self) {
}
/// Build the framed reader/writer pair and spawn the background reader task.
///
/// Installs the line-framed writer into `self.stdout_writer`, then spawns a
/// task that reads newline-delimited JSON lines, validates their size,
/// parses them, and forwards them through an mpsc channel consumed by
/// `receive()`.
///
/// # Errors
/// `ConfigurationError` if a `Raw` stream source was already consumed by a
/// previous connect (raw streams are one-shot).
async fn setup_stdio_streams(&self) -> TransportResult<()> {
    let mut stream_source = self.stream_source.lock().await;
    let mut stdin_reader: StdinReader = match &mut *stream_source {
        StreamSource::ProcessStdio => {
            // Wire up this process's own stdin/stdout.
            let stdin = tokio::io::stdin();
            let boxed_stdin: BoxedAsyncRead = Box::pin(stdin);
            let buffered_reader: BoxedAsyncBufRead = BufReader::new(boxed_stdin);
            let stdout: BoxedAsyncWrite = Box::pin(tokio::io::stdout());
            // Frame the writer with the protocol's maximum line length.
            *self.stdout_writer.lock().await = Some(FramedWrite::new(
                stdout,
                LinesCodec::new_with_max_length(turbomcp_protocol::MAX_MESSAGE_SIZE),
            ));
            FramedRead::new(
                buffered_reader,
                LinesCodec::new_with_max_length(turbomcp_protocol::MAX_MESSAGE_SIZE),
            )
        }
        StreamSource::Raw { reader, writer } => {
            // Raw streams are one-shot: take() leaves None behind, so a
            // second connect on the same source fails loudly here.
            let raw_reader = reader.take().ok_or_else(|| {
                TransportError::ConfigurationError(
                    "Raw reader stream already consumed".to_string(),
                )
            })?;
            let raw_writer = writer.take().ok_or_else(|| {
                TransportError::ConfigurationError(
                    "Raw writer stream already consumed".to_string(),
                )
            })?;
            let buffered_reader: BoxedAsyncBufRead = BufReader::new(raw_reader);
            *self.stdout_writer.lock().await = Some(FramedWrite::new(
                raw_writer,
                LinesCodec::new_with_max_length(turbomcp_protocol::MAX_MESSAGE_SIZE),
            ));
            FramedRead::new(
                buffered_reader,
                LinesCodec::new_with_max_length(turbomcp_protocol::MAX_MESSAGE_SIZE),
            )
        }
    };
    // Bounded channel: `send(...).await` in the task below blocks when full,
    // providing backpressure against a slow receive() consumer.
    let (tx, rx) = mpsc::channel(1000);
    *self.receive_channel.lock().await = Some(rx);
    {
        let sender = tx;
        let event_emitter = self.event_emitter.clone();
        let metrics = self.metrics.clone();
        let config = self.config.clone();
        // Background reader: owns the framed reader; aborted on disconnect.
        let task_handle = tokio::spawn(async move {
            while let Some(result) = stdin_reader.next().await {
                match result {
                    Ok(line) => {
                        trace!("Received line: {}", line);
                        let size = line.len();
                        // Clone limits out of the sync lock before awaiting.
                        let limits = {
                            let cfg = config.lock();
                            cfg.limits.clone()
                        };
                        // Oversized lines are reported and skipped, not fatal.
                        if let Err(e) = validate_response_size(size, &limits) {
                            error!("Response size validation failed: {}", e);
                            event_emitter.emit_error(
                                e.clone(),
                                Some("response size validation".to_string()),
                            );
                            continue;
                        }
                        match Self::parse_message(&line) {
                            Ok(message) => {
                                let size = message.size();
                                metrics.messages_received.fetch_add(1, Ordering::Relaxed);
                                metrics
                                    .bytes_received
                                    .fetch_add(size as u64, Ordering::Relaxed);
                                event_emitter.emit_message_received(message.id.clone(), size);
                                // Receiver dropped => transport disconnected; stop.
                                if let Err(e) = sender.send(message).await {
                                    debug!(
                                        error = %e,
                                        "Receive channel closed, stopping reader task"
                                    );
                                    break;
                                }
                            }
                            Err(e) => {
                                // Malformed lines are reported and skipped.
                                error!("Failed to parse message: {}", e);
                                event_emitter
                                    .emit_error(e, Some("message parsing".to_string()));
                            }
                        }
                    }
                    Err(e) => {
                        // I/O or codec error on the stream itself is fatal.
                        error!("Failed to read from stdin: {}", e);
                        event_emitter.emit_error(
                            TransportError::ReceiveFailed(e.to_string()),
                            Some("stdin read".to_string()),
                        );
                        break;
                    }
                }
            }
            debug!("Stdio reader task completed");
        });
        *self._task_handle.lock().await = Some(task_handle);
    }
    Ok(())
}
/// Parse one newline-delimited JSON line into a `TransportMessage`.
///
/// The JSON-RPC `id` (string or i64-representable number) becomes the
/// message id; anything else — including notifications without an `id` —
/// gets a fresh UUID.
///
/// # Errors
/// `ProtocolError` for empty lines, `SerializationFailed` for invalid JSON.
fn parse_message(line: &str) -> TransportResult<TransportMessage> {
    let line = line.trim();
    if line.is_empty() {
        return Err(TransportError::ProtocolError("Empty message".to_string()));
    }
    let json_value: serde_json::Value = serde_json::from_str(line)
        .map_err(|e| TransportError::SerializationFailed(e.to_string()))?;
    let message_id = match json_value.get("id") {
        Some(serde_json::Value::String(s)) => MessageId::from(s.clone()),
        Some(serde_json::Value::Number(n)) => match n.as_i64() {
            Some(i) => MessageId::from(i),
            // Numbers outside i64 range fall back to a synthetic id.
            None => MessageId::from(Uuid::new_v4()),
        },
        _ => MessageId::from(Uuid::new_v4()),
    };
    let payload = Bytes::from(line.to_string());
    let metadata = TransportMessageMetadata::with_content_type("application/json");
    Ok(TransportMessage::with_metadata(
        message_id, payload, metadata,
    ))
}
/// Validate a message payload and return it as one JSON line.
///
/// Rejects payloads that are not UTF-8, contain literal CR/LF bytes (the MCP
/// stdio framing is one message per line), or are not valid JSON. Escaped
/// sequences like `\n` inside JSON strings are fine — only raw bytes matter.
fn serialize_message(message: &TransportMessage) -> TransportResult<String> {
    let json_str = std::str::from_utf8(&message.payload)
        .map_err(|e| TransportError::SerializationFailed(e.to_string()))?;
    let has_line_break = json_str.bytes().any(|b| b == b'\n' || b == b'\r');
    if has_line_break {
        return Err(TransportError::ProtocolError(
            "Message contains embedded newlines (forbidden by MCP stdio specification)"
                .to_string(),
        ));
    }
    // Round-trip through serde only to prove the payload is valid JSON.
    serde_json::from_str::<serde_json::Value>(json_str)
        .map_err(|e| TransportError::SerializationFailed(e.to_string()))?;
    Ok(json_str.to_owned())
}
}
impl Transport for StdioTransport {
fn transport_type(&self) -> TransportType {
    TransportType::Stdio
}
fn capabilities(&self) -> &TransportCapabilities {
    &self.capabilities
}
/// Point-in-time clone of the state (async only to satisfy the trait).
fn state(&self) -> Pin<Box<dyn Future<Output = TransportState> + Send + '_>> {
    Box::pin(async move {
        self.state.lock().clone()
    })
}
/// Connect: set up the framed streams and the background reader task.
/// Idempotent — an already-connected transport returns `Ok(())` at once.
fn connect(&self) -> Pin<Box<dyn Future<Output = TransportResult<()>> + Send + '_>> {
    Box::pin(async move {
        if matches!(self.state().await, TransportState::Connected) {
            return Ok(());
        }
        self.set_state(TransportState::Connecting);
        if let Err(e) = self.setup_stdio_streams().await {
            self.metrics
                .failed_connections
                .fetch_add(1, Ordering::Relaxed);
            self.set_state(TransportState::Failed {
                reason: e.to_string(),
            });
            error!("Failed to connect stdio transport: {}", e);
            return Err(e);
        }
        self.metrics.connections.fetch_add(1, Ordering::Relaxed);
        self.set_state(TransportState::Connected);
        debug!("Stdio transport connected");
        Ok(())
    })
}
/// Disconnect: drop the framed streams and channel, abort the reader task.
/// Idempotent — an already-disconnected transport returns `Ok(())` at once.
fn disconnect(&self) -> Pin<Box<dyn Future<Output = TransportResult<()>> + Send + '_>> {
    Box::pin(async move {
        if matches!(self.state().await, TransportState::Disconnected) {
            return Ok(());
        }
        self.set_state(TransportState::Disconnecting);
        // Drop streams and channel first, then stop the background task.
        self.stdin_reader.lock().await.take();
        self.stdout_writer.lock().await.take();
        self.receive_channel.lock().await.take();
        let handle = self._task_handle.lock().await.take();
        if let Some(handle) = handle {
            handle.abort();
        }
        self.set_state(TransportState::Disconnected);
        debug!("Stdio transport disconnected");
        Ok(())
    })
}
/// Serialize and write one message as a single JSON line, then flush.
///
/// # Errors
/// `ConnectionFailed` when not connected; serialization errors from
/// `serialize_message`; request-size limit violations; `SendFailed` on
/// write/flush errors (a write error also marks the transport `Failed`).
fn send(
    &self,
    message: TransportMessage,
) -> Pin<Box<dyn Future<Output = TransportResult<()>> + Send + '_>> {
    Box::pin(async move {
        let state = self.state().await;
        if !matches!(state, TransportState::Connected) {
            return Err(TransportError::ConnectionFailed(format!(
                "Transport not connected: {state}"
            )));
        }
        let json_line = Self::serialize_message(&message)?;
        let size = json_line.len();
        // Clone config out of the sync lock before any await.
        let config = self.config.lock().clone();
        validate_request_size(size, &config.limits)?;
        // Tokio mutex held across the awaits below — serializes writers.
        let mut stdout_writer = self.stdout_writer.lock().await;
        if let Some(writer) = stdout_writer.as_mut() {
            if let Err(e) = writer.send(json_line).await {
                error!("Failed to send message: {}", e);
                self.set_state(TransportState::Failed {
                    reason: e.to_string(),
                });
                return Err(TransportError::SendFailed(e.to_string()));
            }
            // `use` is an item: it is in scope for the whole block, which is
            // why `writer.send(...)` above already resolved through SinkExt.
            use futures::SinkExt;
            if let Err(e) = SinkExt::<String>::flush(writer).await {
                error!("Failed to flush stdout: {}", e);
                return Err(TransportError::SendFailed(e.to_string()));
            }
            self.metrics.messages_sent.fetch_add(1, Ordering::Relaxed);
            self.metrics
                .bytes_sent
                .fetch_add(size as u64, Ordering::Relaxed);
            self.event_emitter.emit_message_sent(message.id, size);
            trace!("Sent message: {} bytes", size);
            Ok(())
        } else {
            Err(TransportError::SendFailed(
                "Stdout writer not available".to_string(),
            ))
        }
    })
}
/// Wait for the next message from the background reader task.
///
/// Resolves to `Ok(Some(message))` on success. A closed channel (reader
/// task gone) marks the transport `Failed` and returns `ReceiveFailed`.
/// Note: this never resolves to `Ok(None)` in the current implementation.
fn receive(
    &self,
) -> Pin<Box<dyn Future<Output = TransportResult<Option<TransportMessage>>> + Send + '_>> {
    Box::pin(async move {
        let state = self.state().await;
        if !matches!(state, TransportState::Connected) {
            return Err(TransportError::ConnectionFailed(format!(
                "Transport not connected: {state}"
            )));
        }
        // Tokio mutex held across recv().await — only one receiver at a time.
        let mut receive_channel = self.receive_channel.lock().await;
        if let Some(receiver) = receive_channel.as_mut() {
            match receiver.recv().await {
                Some(message) => {
                    trace!("Received message: {} bytes", message.size());
                    Ok(Some(message))
                }
                None => {
                    // All senders dropped: the reader task has exited.
                    warn!("Receive channel disconnected");
                    self.set_state(TransportState::Failed {
                        reason: "Receive channel disconnected".to_string(),
                    });
                    Err(TransportError::ReceiveFailed(
                        "Channel disconnected".to_string(),
                    ))
                }
            }
        } else {
            Err(TransportError::ReceiveFailed(
                "Receive channel not available".to_string(),
            ))
        }
    })
}
/// Point-in-time copy of the atomic counters.
fn metrics(&self) -> Pin<Box<dyn Future<Output = TransportMetrics> + Send + '_>> {
    Box::pin(async move {
        self.metrics.snapshot()
    })
}
fn endpoint(&self) -> Option<String> {
    Some("stdio://".to_string())
}
/// Validate and install a new configuration.
///
/// # Errors
/// `ConfigurationError` when the config targets a non-stdio transport or the
/// connect timeout is below 100 ms.
fn configure(
    &self,
    config: TransportConfig,
) -> Pin<Box<dyn Future<Output = TransportResult<()>> + Send + '_>> {
    Box::pin(async move {
        if config.transport_type != TransportType::Stdio {
            return Err(TransportError::ConfigurationError(format!(
                "Invalid transport type: {:?}",
                config.transport_type
            )));
        }
        let minimum_timeout = Duration::from_millis(100);
        if config.connect_timeout < minimum_timeout {
            return Err(TransportError::ConfigurationError(
                "Connect timeout too small".to_string(),
            ));
        }
        *self.config.lock() = config;
        debug!("Stdio transport configured");
        Ok(())
    })
}
}
/// Zero-sized factory producing [`StdioTransport`] instances.
#[derive(Debug, Default)]
pub struct StdioTransportFactory;
impl StdioTransportFactory {
    /// Construct the factory (const; no state to initialize).
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
impl TransportFactory for StdioTransportFactory {
    fn transport_type(&self) -> TransportType {
        TransportType::Stdio
    }
    /// Create a boxed stdio transport from `config`.
    ///
    /// # Errors
    /// `ConfigurationError` when the config targets a different transport.
    fn create(&self, config: TransportConfig) -> TransportResult<Box<dyn Transport>> {
        if config.transport_type == TransportType::Stdio {
            Ok(Box::new(StdioTransport::with_config(config)))
        } else {
            Err(TransportError::ConfigurationError(format!(
                "Invalid transport type: {:?}",
                config.transport_type
            )))
        }
    }
    /// Stdio is always available in-process.
    fn is_available(&self) -> bool {
        true
    }
}
impl Default for StdioTransport {
    /// Equivalent to [`StdioTransport::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
use super::*;
use pretty_assertions::assert_eq;
#[test]
fn test_stdio_transport_creation() {
    let transport = StdioTransport::new();
    assert_eq!(transport.transport_type(), TransportType::Stdio);
    assert!(transport.capabilities().supports_streaming);
    assert!(transport.capabilities().supports_bidirectional);
}
#[test]
fn test_stdio_transport_with_config() {
    let config = TransportConfig {
        transport_type: TransportType::Stdio,
        connect_timeout: Duration::from_secs(10),
        ..Default::default()
    };
    let transport = StdioTransport::with_config(config);
    // with_config must install the caller's config verbatim.
    assert_eq!(
        transport.config.lock().connect_timeout,
        Duration::from_secs(10)
    );
}
#[tokio::test]
async fn test_stdio_transport_state_management() {
    let transport = StdioTransport::new();
    // A fresh transport starts disconnected; connect() is required first.
    assert_eq!(transport.state().await, TransportState::Disconnected);
}
#[test]
fn test_message_parsing() {
    let json_line = r#"{"jsonrpc":"2.0","id":"test-123","method":"test","params":{}}"#;
    let message = StdioTransport::parse_message(json_line).unwrap();
    assert_eq!(message.id, MessageId::from("test-123"));
    assert_eq!(message.content_type(), Some("application/json"));
    assert!(!message.payload.is_empty());
}
#[test]
fn test_message_parsing_with_numeric_id() {
    // Numeric JSON-RPC ids are carried through as i64.
    let json_line = r#"{"jsonrpc":"2.0","id":42,"method":"test","params":{}}"#;
    let message = StdioTransport::parse_message(json_line).unwrap();
    assert_eq!(message.id, MessageId::from(42));
}
#[test]
fn test_message_parsing_without_id() {
    // A notification carries no "id" field, so the transport must synthesize
    // a UUID-based message id.
    // Fix: replaced a convoluted `match` whose `_` arm re-asserted a pattern
    // already known to be false with a single direct assertion.
    let json_line = r#"{"jsonrpc":"2.0","method":"notification","params":{}}"#;
    let message = StdioTransport::parse_message(json_line).unwrap();
    assert!(
        matches!(message.id, MessageId::Uuid(_)),
        "Expected UUID message ID"
    );
}
#[test]
fn test_message_parsing_invalid_json() {
    let invalid_json = "not json at all";
    let result = StdioTransport::parse_message(invalid_json);
    assert!(matches!(
        result,
        Err(TransportError::SerializationFailed(_))
    ));
}
#[test]
fn test_message_parsing_empty() {
    // Empty and whitespace-only lines are protocol errors, not parse errors.
    let result = StdioTransport::parse_message("");
    assert!(matches!(result, Err(TransportError::ProtocolError(_))));
    let result = StdioTransport::parse_message(" ");
    assert!(matches!(result, Err(TransportError::ProtocolError(_))));
}
#[test]
fn test_message_serialization() {
    // Valid single-line JSON round-trips unchanged.
    let json_str = r#"{"jsonrpc":"2.0","id":"test","method":"ping"}"#;
    let payload = Bytes::from(json_str);
    let message = TransportMessage::new(MessageId::from("test"), payload);
    let serialized = StdioTransport::serialize_message(&message).unwrap();
    assert_eq!(serialized, json_str);
}
#[test]
fn test_message_serialization_invalid_utf8() {
    // 0xFF/0xFE/0xFD can never appear in valid UTF-8, so serialization
    // must reject the payload.
    let payload = Bytes::from(vec![0xFF, 0xFE, 0xFD]);
    let message = TransportMessage::new(MessageId::from("test"), payload);
    let result = StdioTransport::serialize_message(&message);
    assert!(matches!(
        result,
        Err(TransportError::SerializationFailed(_))
    ));
}
#[test]
fn test_message_serialization_invalid_json() {
    let payload = Bytes::from("not json");
    let message = TransportMessage::new(MessageId::from("test"), payload);
    let result = StdioTransport::serialize_message(&message);
    assert!(matches!(
        result,
        Err(TransportError::SerializationFailed(_))
    ));
}
#[test]
fn test_message_serialization_embedded_newline_lf() {
    // The raw string below deliberately contains a literal LF byte.
    let json_with_newline = r#"{"jsonrpc":"2.0","id":"test","method":"test","params":{"text":"line1
line2"}}"#;
    let payload = Bytes::from(json_with_newline);
    let message = TransportMessage::new(MessageId::from("test"), payload);
    let result = StdioTransport::serialize_message(&message);
    assert!(
        matches!(result, Err(TransportError::ProtocolError(_))),
        "Expected ProtocolError for message with LF, got: {:?}",
        result
    );
}
#[test]
fn test_message_serialization_embedded_newline_crlf() {
    let json_with_crlf = "{\r\n\"jsonrpc\":\"2.0\",\"id\":\"test\"}";
    let payload = Bytes::from(json_with_crlf);
    let message = TransportMessage::new(MessageId::from("test"), payload);
    let result = StdioTransport::serialize_message(&message);
    assert!(
        matches!(result, Err(TransportError::ProtocolError(_))),
        "Expected ProtocolError for message with CRLF, got: {:?}",
        result
    );
}
#[test]
fn test_message_serialization_embedded_cr() {
    // A bare CR (no LF) is also a framing violation.
    let json_with_cr = "{\r\"jsonrpc\":\"2.0\",\"id\":\"test\"}";
    let payload = Bytes::from(json_with_cr);
    let message = TransportMessage::new(MessageId::from("test"), payload);
    let result = StdioTransport::serialize_message(&message);
    assert!(
        matches!(result, Err(TransportError::ProtocolError(_))),
        "Expected ProtocolError for message with CR, got: {:?}",
        result
    );
}
#[test]
fn test_message_serialization_valid_no_newlines() {
    let valid_json =
        r#"{"jsonrpc":"2.0","id":"test","method":"test","params":{"text":"single line"}}"#;
    let payload = Bytes::from(valid_json);
    let message = TransportMessage::new(MessageId::from("test"), payload);
    let result = StdioTransport::serialize_message(&message);
    assert!(
        result.is_ok(),
        "Valid message without newlines should be accepted"
    );
    assert_eq!(result.unwrap(), valid_json);
}
#[test]
fn test_message_serialization_escaped_newlines_allowed() {
    // Backslash-n inside a JSON string is two bytes ('\\', 'n'), not a
    // literal newline, and must pass the framing check.
    let json_with_escaped_newlines = r#"{"jsonrpc":"2.0","id":"test","method":"log","params":{"message":"line1\nline2\ntab:\there"}}"#;
    assert!(
        !json_with_escaped_newlines.contains('\n'),
        "Test setup error: raw string should not contain literal newline bytes"
    );
    assert!(
        !json_with_escaped_newlines.contains('\r'),
        "Test setup error: raw string should not contain literal CR bytes"
    );
    let payload = Bytes::from(json_with_escaped_newlines);
    let message = TransportMessage::new(MessageId::from("test"), payload);
    let result = StdioTransport::serialize_message(&message);
    assert!(
        result.is_ok(),
        "JSON with ESCAPED newlines (backslash-n) should be ALLOWED per MCP spec. Got: {:?}",
        result
    );
    assert_eq!(result.unwrap(), json_with_escaped_newlines);
}
#[test]
fn test_stdio_factory() {
    let factory = StdioTransportFactory::new();
    assert_eq!(factory.transport_type(), TransportType::Stdio);
    assert!(factory.is_available());
    let config = TransportConfig {
        transport_type: TransportType::Stdio,
        ..Default::default()
    };
    let transport = factory.create(config).unwrap();
    assert_eq!(transport.transport_type(), TransportType::Stdio);
}
#[test]
fn test_stdio_factory_invalid_config() {
    let factory = StdioTransportFactory::new();
    // The factory must reject configs for other transport types.
    let config = TransportConfig {
        transport_type: TransportType::Http,
        ..Default::default()
    };
    let result = factory.create(config);
    assert!(matches!(result, Err(TransportError::ConfigurationError(_))));
}
#[tokio::test]
async fn test_configuration_validation() {
    let transport = StdioTransport::new();
    let valid_config = TransportConfig {
        transport_type: TransportType::Stdio,
        connect_timeout: Duration::from_secs(5),
        ..Default::default()
    };
    assert!(transport.configure(valid_config).await.is_ok());
    let invalid_config = TransportConfig {
        transport_type: TransportType::Http,
        ..Default::default()
    };
    let result = transport.configure(invalid_config).await;
    assert!(matches!(result, Err(TransportError::ConfigurationError(_))));
    // Timeouts below the 100 ms floor enforced by configure() are rejected.
    let invalid_timeout_config = TransportConfig {
        transport_type: TransportType::Stdio,
        connect_timeout: Duration::from_millis(50),
        ..Default::default()
    };
    let result = transport.configure(invalid_timeout_config).await;
    assert!(matches!(result, Err(TransportError::ConfigurationError(_))));
}
#[test]
fn test_from_raw_creation() {
    // Two in-memory duplex pipes model the client<->server stdio pair.
    let (client_tx, server_rx) = tokio::io::duplex(1024);
    let (server_tx, client_rx) = tokio::io::duplex(1024);
    let transport = StdioTransport::from_raw(server_rx, server_tx).unwrap();
    assert_eq!(transport.transport_type(), TransportType::Stdio);
    assert!(transport.capabilities().supports_streaming);
    assert!(transport.capabilities().supports_bidirectional);
    let _client_transport = StdioTransport::from_raw(client_rx, client_tx).unwrap();
}
#[tokio::test]
async fn test_from_raw_connect_and_communicate() {
    let (client_tx, server_rx) = tokio::io::duplex(4096);
    let (server_tx, client_rx) = tokio::io::duplex(4096);
    let server_transport = StdioTransport::from_raw(server_rx, server_tx).unwrap();
    let client_transport = StdioTransport::from_raw(client_rx, client_tx).unwrap();
    // Full lifecycle: Disconnected -> Connected -> Disconnected on both ends.
    assert_eq!(server_transport.state().await, TransportState::Disconnected);
    assert_eq!(client_transport.state().await, TransportState::Disconnected);
    server_transport.connect().await.unwrap();
    client_transport.connect().await.unwrap();
    assert_eq!(server_transport.state().await, TransportState::Connected);
    assert_eq!(client_transport.state().await, TransportState::Connected);
    server_transport.disconnect().await.unwrap();
    client_transport.disconnect().await.unwrap();
    assert_eq!(server_transport.state().await, TransportState::Disconnected);
    assert_eq!(client_transport.state().await, TransportState::Disconnected);
}
#[test]
fn test_stream_source_debug() {
    let process_source = StreamSource::ProcessStdio;
    let debug_str = format!("{:?}", process_source);
    assert_eq!(debug_str, "ProcessStdio");
}
}