use super::{
ConnectionContext, DEFAULT_MAX_MESSAGE_SIZE, DEFAULT_STDIO_BUFFER_SIZE, JsonRpcMessage, Result,
Transport, TransportType,
};
use std::str::FromStr;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter};
use tokio::sync::Mutex;
/// Tunables for [`StdioTransport`].
///
/// Build one explicitly, via [`StdioConfig::from_env`], or via [`Default`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StdioConfig {
    /// Maximum size, in bytes, of one inbound newline-delimited message.
    pub max_message_size: usize,
    /// Capacity, in bytes, of both the stdin read buffer and the stdout write buffer.
    pub buffer_size: usize,
}
impl StdioConfig {
#[must_use]
pub fn from_env() -> Self {
Self {
max_message_size: env_or("THOUGHTJACK_MAX_MESSAGE_SIZE", DEFAULT_MAX_MESSAGE_SIZE),
buffer_size: env_or("THOUGHTJACK_STDIO_BUFFER_SIZE", DEFAULT_STDIO_BUFFER_SIZE),
}
}
}
impl Default for StdioConfig {
    /// Returns the compiled-in limits, [`DEFAULT_MAX_MESSAGE_SIZE`] and
    /// [`DEFAULT_STDIO_BUFFER_SIZE`], ignoring the environment.
    fn default() -> Self {
        let max_message_size = DEFAULT_MAX_MESSAGE_SIZE;
        let buffer_size = DEFAULT_STDIO_BUFFER_SIZE;
        Self {
            max_message_size,
            buffer_size,
        }
    }
}
/// JSON-RPC transport over the process's standard input and output.
///
/// Inbound messages are read line by line from stdin, and outbound messages are written to
/// stdout. Each stream sits behind its own async mutex. A send and a receive can therefore
/// run concurrently, while concurrent calls in the same direction are serialized.
pub struct StdioTransport {
    /// Buffered stdin; locked for the whole of a `receive_message` call.
    reader: Mutex<BufReader<tokio::io::Stdin>>,
    /// Buffered stdout; locked for one write-and-flush per send.
    writer: Mutex<BufWriter<tokio::io::Stdout>>,
    /// Size limit and buffer capacities this transport was built with.
    config: StdioConfig,
    /// Connection metadata returned by `context` / `connection_context`.
    context: ConnectionContext,
}
impl StdioTransport {
#[must_use]
pub fn new() -> Self {
let config = StdioConfig::from_env();
Self {
reader: Mutex::new(BufReader::with_capacity(
config.buffer_size,
tokio::io::stdin(),
)),
writer: Mutex::new(BufWriter::with_capacity(
config.buffer_size,
tokio::io::stdout(),
)),
config,
context: ConnectionContext::stdio(),
}
}
#[must_use]
pub fn with_config(config: StdioConfig) -> Self {
Self {
reader: Mutex::new(BufReader::with_capacity(
config.buffer_size,
tokio::io::stdin(),
)),
writer: Mutex::new(BufWriter::with_capacity(
config.buffer_size,
tokio::io::stdout(),
)),
config,
context: ConnectionContext::stdio(),
}
}
#[must_use]
pub const fn context(&self) -> &ConnectionContext {
&self.context
}
}
impl Default for StdioTransport {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Debug for StdioTransport {
    /// Shows the configuration and connection context. The reader/writer mutexes are
    /// left out, which `finish_non_exhaustive` marks with a trailing `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("StdioTransport");
        builder.field("config", &self.config);
        builder.field("context", &self.context);
        builder.finish_non_exhaustive()
    }
}
#[async_trait::async_trait]
impl Transport for StdioTransport {
    /// Serializes `message` with `serde_json::to_string`, appends `\n` as the frame
    /// terminator, and flushes stdout so the peer sees it immediately.
    ///
    /// # Errors
    ///
    /// Returns an error if serialization fails or if writing/flushing stdout fails.
    async fn send_message(&self, message: &JsonRpcMessage) -> Result<()> {
        let serialized = serde_json::to_string(message)?;
        let mut writer = self.writer.lock().await;
        writer.write_all(serialized.as_bytes()).await?;
        writer.write_all(b"\n").await?;
        writer.flush().await?;
        drop(writer);
        Ok(())
    }

    /// Writes `bytes` verbatim (no newline is appended) and flushes stdout.
    ///
    /// # Errors
    ///
    /// Returns an error if writing or flushing stdout fails.
    async fn send_raw(&self, bytes: &[u8]) -> Result<()> {
        let mut writer = self.writer.lock().await;
        writer.write_all(bytes).await?;
        writer.flush().await?;
        drop(writer);
        Ok(())
    }

    /// Reads the next newline-delimited JSON-RPC message from stdin.
    ///
    /// Lines are framed on `\n`. One trailing `\r` (CRLF framing) is not counted toward the
    /// size limit, and surrounding whitespace is trimmed before parsing. Lines that are blank,
    /// longer than `max_message_size` bytes, not valid UTF-8, or not valid JSON-RPC are
    /// logged and skipped. At most `max_message_size + 1` bytes of any line are buffered, so
    /// an oversized line is drained without being held in memory.
    ///
    /// Returns `Ok(None)` once stdin reaches end-of-stream with no pending data. A final line
    /// without a trailing `\n` is still processed.
    ///
    /// Cancellation: bytes are consumed from the reader as they are copied into a local
    /// buffer. If this future is dropped part-way through a line, that partial line is lost
    /// and the next call starts mid-line.
    ///
    /// # Errors
    ///
    /// Returns an error if reading from stdin fails.
    // The reader guard is deliberately held for the whole loop: releasing it between chunks
    // would let a concurrent caller interleave reads within the same line.
    #[allow(clippy::significant_drop_tightening)]
    async fn receive_message(&self) -> Result<Option<JsonRpcMessage>> {
        let mut reader = self.reader.lock().await;
        let max_size = self.config.max_message_size;
        // One byte of headroom lets a message of exactly `max_size` bytes followed by `\r`
        // fit in the buffer. `saturating_add` keeps `usize::MAX` from wrapping to 0.
        let read_limit = max_size.saturating_add(1);
        let mut buf: Vec<u8> = Vec::with_capacity(read_limit.min(64 * 1024));
        loop {
            buf.clear();
            let mut overflowed = false;
            loop {
                let available = reader.fill_buf().await?;
                if available.is_empty() {
                    // End of stream: report it only when no partial line is pending.
                    if buf.is_empty() {
                        return Ok(None);
                    }
                    break;
                }
                let newline = available.iter().position(|&b| b == b'\n');
                // Bytes belonging to the current line in this chunk (excluding the `\n`).
                let chunk_len = newline.unwrap_or(available.len());
                if !overflowed {
                    let remaining_cap = read_limit.saturating_sub(buf.len());
                    buf.extend_from_slice(&available[..chunk_len.min(remaining_cap)]);
                    overflowed = chunk_len > remaining_cap;
                }
                // Once overflowed, keep consuming (and discarding) until the line ends.
                reader.consume(newline.map_or(chunk_len, |pos| pos + 1));
                if newline.is_some() {
                    break;
                }
            }
            // A `\r` left over from CRLF framing is a terminator, not payload.
            if buf.last() == Some(&b'\r') {
                buf.pop();
            }
            // `overflowed` catches lines beyond the read cap. The length check catches a
            // line that filled the one byte of headroom, i.e. exactly `max_size + 1` bytes.
            if overflowed || buf.len() > max_size {
                tracing::warn!(
                    limit = max_size,
                    "message exceeds size limit (read capped), skipping"
                );
                continue;
            }
            let line = match std::str::from_utf8(&buf) {
                Ok(s) => s,
                Err(e) => {
                    tracing::warn!("invalid UTF-8 in message, skipping line: {e}");
                    continue;
                }
            };
            let trimmed = line.trim();
            if trimmed.is_empty() {
                continue;
            }
            match serde_json::from_str::<JsonRpcMessage>(trimmed) {
                Ok(message) => return Ok(Some(message)),
                Err(e) => {
                    tracing::warn!(
                        error = %e,
                        line = %sanitize_for_log(trimmed, 200),
                        "invalid JSON-RPC message, skipping"
                    );
                }
            }
        }
    }

    /// Always [`TransportType::Stdio`].
    fn transport_type(&self) -> TransportType {
        TransportType::Stdio
    }

    /// No-op: every send already flushes stdout, so nothing is left to finalize.
    async fn finalize_response(&self) -> Result<()> {
        Ok(())
    }

    /// Returns an owned copy of this transport's connection metadata.
    fn connection_context(&self) -> ConnectionContext {
        self.context.clone()
    }

    /// Exposes `self` as `Any` so callers can downcast to `StdioTransport`.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
/// Prepares untrusted text for a log line.
///
/// Keeps at most `max_len` characters (not bytes). Every control character except tab is
/// replaced with U+FFFD, so embedded newlines or escape sequences cannot forge log output.
fn sanitize_for_log(input: &str, max_len: usize) -> String {
    let mut out = String::new();
    for ch in input.chars().take(max_len) {
        let keep = ch == '\t' || !ch.is_control();
        out.push(if keep { ch } else { '\u{FFFD}' });
    }
    out
}
fn env_or<T: FromStr>(name: &str, default: T) -> T {
match std::env::var(name) {
Ok(v) => v.parse().unwrap_or_else(|_| {
tracing::warn!(name, value = %v, "invalid env var value, using default");
default
}),
Err(_) => default,
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_stdio_config_default() {
        let config = StdioConfig::default();
        assert_eq!(config.max_message_size, DEFAULT_MAX_MESSAGE_SIZE);
        assert_eq!(config.buffer_size, DEFAULT_STDIO_BUFFER_SIZE);
    }

    #[test]
    fn test_env_or_default() {
        let result: usize = env_or("THOUGHTJACK_TEST_NONEXISTENT_VAR_12345", 42);
        assert_eq!(result, 42);
    }

    #[test]
    fn test_stdio_transport_debug() {
        let transport = StdioTransport::new();
        let debug = format!("{transport:?}");
        assert!(debug.contains("StdioTransport"));
        assert!(debug.contains("config"));
    }

    #[test]
    fn test_stdio_transport_type() {
        let transport = StdioTransport::new();
        assert_eq!(transport.transport_type(), TransportType::Stdio);
    }

    #[test]
    fn test_stdio_context() {
        let transport = StdioTransport::new();
        let ctx = transport.context();
        assert_eq!(ctx.connection_id, 0);
        assert!(ctx.remote_addr.is_none());
        assert!(ctx.is_exclusive);
    }

    #[test]
    fn sanitize_for_log_truncates_long_input() {
        let long_input = "a".repeat(500);
        let sanitized = sanitize_for_log(&long_input, 200);
        assert_eq!(sanitized.len(), 200);
        assert!(sanitized.chars().all(|c| c == 'a'));
    }

    #[test]
    fn sanitize_for_log_truncates_by_chars_not_bytes() {
        // Each 'é' is two bytes in UTF-8; the limit counts characters.
        let sanitized = sanitize_for_log("ééé", 2);
        assert_eq!(sanitized, "éé");
    }

    #[test]
    fn sanitize_for_log_replaces_control_chars() {
        let input = "hello\x00world\x0Bfoo\tbar";
        let sanitized = sanitize_for_log(input, 200);
        // NUL and vertical tab are replaced; tab is deliberately preserved.
        assert_eq!(sanitized, "hello\u{FFFD}world\u{FFFD}foo\tbar");
    }

    #[test]
    fn sanitize_for_log_empty_input() {
        let sanitized = sanitize_for_log("", 200);
        assert!(sanitized.is_empty());
    }

    #[test]
    fn with_config_applies_custom_values() {
        let config = StdioConfig {
            max_message_size: 1024,
            buffer_size: 512,
        };
        let transport = StdioTransport::with_config(config);
        assert_eq!(transport.config.max_message_size, 1024);
        assert_eq!(transport.config.buffer_size, 512);
    }

    #[test]
    fn connection_context_returns_stdio_context() {
        let transport = StdioTransport::new();
        let ctx = transport.connection_context();
        assert_eq!(ctx.connection_id, 0);
        assert!(ctx.remote_addr.is_none());
        assert!(ctx.is_exclusive);
    }
}