use crate::errors::{Error, Result};
use async_trait::async_trait;
use futures_core::Stream;
use serde_json::Value;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use tokio::sync::Mutex;
/// A connection to a peer process or service that accepts text messages and
/// yields decoded JSON messages.
///
/// Every method takes `&self`, and the trait requires `Send + Sync`, so one
/// transport can be shared between tasks. (In `CliTransport`, outgoing
/// messages are newline-terminated lines on the child's stdin and incoming
/// messages are JSON lines read from its stdout.)
#[async_trait]
pub trait Transport: Send + Sync {
    /// Establishes the connection.
    async fn connect(&self) -> Result<()>;

    /// Returns `true` while the transport considers itself connected.
    fn is_ready(&self) -> bool;

    /// Sends one message, given as text, to the peer.
    async fn write(&self, data: &str) -> Result<()>;

    /// Signals that no further messages will be written.
    async fn end_input(&self) -> Result<()>;

    /// Returns the stream of incoming messages; each item is either a decoded
    /// JSON value or an error.
    fn read_messages(&self) -> Pin<Box<dyn Stream<Item = Result<Value>> + Send>>;

    /// Asks the peer to interrupt its current work.
    async fn interrupt(&self) -> Result<()>;

    /// Shuts the transport down, returning an exit code when one is available.
    async fn close(&self) -> Result<Option<i32>>;

    /// Diagnostic output collected from the peer so far.
    ///
    /// The default implementation collects nothing and returns an empty string.
    fn collected_stderr(&self) -> String {
        String::default()
    }
}
/// Transport that talks to a CLI child process over its stdio pipes.
///
/// Outgoing messages are written to the child's stdin; stdout is read line by
/// line and parsed as JSON; stderr is forwarded to an optional callback and
/// accumulated for `collected_stderr`. Mutable state lives behind mutexes or an
/// atomic so the `Transport` methods can all take `&self`.
///
/// NOTE: Rust drops fields in declaration order (here `process` is dropped
/// before `stdin`); keep that in mind before reordering the fields.
pub struct CliTransport {
/// Arguments passed to the CLI executable when `connect` spawns it.
cli_args: Vec<String>,
/// Path of the CLI executable to spawn.
cli_path: std::path::PathBuf,
/// Variables added to the child's environment (via `Command::envs`).
env: std::collections::HashMap<String, String>,
/// The spawned child: set by `connect`, taken by `close`. A `std` mutex is
/// used; in this file it is only locked for short, non-async sections.
process: std::sync::Mutex<Option<tokio::process::Child>>,
/// Write half to the child; `None` before `connect` and after `end_input`.
stdin: Mutex<Option<tokio::process::ChildStdin>>,
/// Receiver fed by the stdout reader task; handed out once by `read_messages`.
message_rx: Mutex<Option<tokio::sync::mpsc::Receiver<Result<Value>>>>,
/// Join handle of the stdout reader task; awaited by `close`.
reader_handle: Mutex<Option<tokio::task::JoinHandle<()>>>,
/// Optional callback invoked with each line the child writes to stderr.
stderr_callback: Option<crate::config::StderrCallback>,
/// Every stderr line seen so far, joined with `\n`; shared with the stderr task.
stderr_buf: Arc<std::sync::Mutex<String>>,
/// Optional token; once cancelled, the stdout reader interrupts the child and
/// stops reading.
cancel: Option<tokio_util::sync::CancellationToken>,
/// How long `close` waits for the child to exit before killing it; `None`
/// means 5 seconds.
close_timeout: Option<std::time::Duration>,
/// Set by a successful `connect`, cleared by `close`.
ready: AtomicBool,
}
impl CliTransport {
    /// Creates a transport that will launch `cli_path` with `cli_args` once
    /// [`Transport::connect`] is called, adding `env` to the child's environment.
    ///
    /// No process is spawned here: the child, its pipes, the message receiver
    /// and the reader handle all start out empty, the stderr buffer starts
    /// empty, and the transport starts out not ready.
    ///
    /// * `stderr_callback` - invoked with every line the child writes to stderr.
    /// * `cancel` - when cancelled, the stdout reader interrupts the child.
    /// * `close_timeout` - how long `close` waits for the child before killing it.
    pub fn new(
        cli_path: std::path::PathBuf,
        cli_args: Vec<String>,
        env: std::collections::HashMap<String, String>,
        stderr_callback: Option<crate::config::StderrCallback>,
        cancel: Option<tokio_util::sync::CancellationToken>,
        close_timeout: Option<std::time::Duration>,
    ) -> Self {
        Self {
            // Configuration supplied by the caller.
            cli_path,
            cli_args,
            env,
            stderr_callback,
            cancel,
            close_timeout,
            // Runtime state: every mutex wraps `None` / an empty `String`, and
            // the ready flag is `false`, which is exactly what `Default` yields.
            process: Default::default(),
            stdin: Default::default(),
            message_rx: Default::default(),
            reader_handle: Default::default(),
            stderr_buf: Default::default(),
            ready: Default::default(),
        }
    }
}
fn send_interrupt_signal(pid: u32) {
#[cfg(unix)]
{
use nix::sys::signal::{Signal, kill};
use nix::unistd::Pid;
let _ = kill(Pid::from_raw(pid as i32), Signal::SIGINT);
}
#[cfg(windows)]
{
unsafe {
windows_sys::Win32::System::Console::GenerateConsoleCtrlEvent(
windows_sys::Win32::System::Console::CTRL_BREAK_EVENT,
pid,
);
}
}
}
#[async_trait]
impl Transport for CliTransport {
async fn connect(&self) -> Result<()> {
if self.ready.load(Ordering::Acquire) {
return Err(Error::AlreadyConnected);
}
let mut cmd = tokio::process::Command::new(&self.cli_path);
cmd.args(&self.cli_args)
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.envs(&self.env);
#[cfg(windows)]
cmd.creation_flags(windows_sys::Win32::System::Threading::CREATE_NEW_PROCESS_GROUP);
let mut child = cmd.spawn().map_err(Error::SpawnFailed)?;
let child_pid: Option<u32> = child.id();
let stdout = child.stdout.take().ok_or(Error::NotConnected)?;
let stdin = child.stdin.take().ok_or(Error::NotConnected)?;
let stderr = child.stderr.take();
let (tx, rx) = tokio::sync::mpsc::channel::<Result<Value>>(256);
let cancel_token = self.cancel.clone();
let reader_handle = tokio::spawn(async move {
use tokio::io::{AsyncBufReadExt, BufReader};
let reader = BufReader::new(stdout);
let mut lines = reader.lines();
loop {
let line = if let Some(ref token) = cancel_token {
tokio::select! {
_ = token.cancelled() => {
tracing::debug!("Reader cancelled via CancellationToken — sending interrupt");
if let Some(pid) = child_pid {
send_interrupt_signal(pid);
}
break;
}
result = lines.next_line() => result,
}
} else {
lines.next_line().await
};
match line {
Ok(Some(line)) => {
let line = line.trim().to_string();
if line.is_empty() {
continue;
}
match serde_json::from_str::<Value>(&line) {
Ok(value) => {
if tx.send(Ok(value)).await.is_err() {
break;
}
}
Err(e) => {
tracing::warn!("JSONL parse error: {e} — line: {line}");
let _ = tx.send(Err(crate::errors::Error::Json(e))).await;
}
}
}
Ok(None) => break,
Err(e) => {
tracing::error!("stdout read error: {e}");
let _ = tx.send(Err(crate::errors::Error::ReadFailed(e))).await;
break;
}
}
}
});
if let Some(stderr) = stderr {
let buf = Arc::clone(&self.stderr_buf);
let cb = self.stderr_callback.as_ref().map(Arc::clone);
tokio::spawn(async move {
use tokio::io::{AsyncBufReadExt, BufReader};
let reader = BufReader::new(stderr);
let mut lines = reader.lines();
while let Ok(Some(line)) = lines.next_line().await {
if let Some(ref callback) = cb {
callback(&line);
}
let mut guard = buf.lock().unwrap_or_else(|e| e.into_inner());
if !guard.is_empty() {
guard.push('\n');
}
guard.push_str(&line);
}
});
}
*self.process.lock().unwrap_or_else(|e| e.into_inner()) = Some(child);
*self.stdin.lock().await = Some(stdin);
*self.message_rx.lock().await = Some(rx);
*self.reader_handle.lock().await = Some(reader_handle);
self.ready.store(true, Ordering::Release);
Ok(())
}
async fn write(&self, data: &str) -> Result<()> {
use tokio::io::AsyncWriteExt;
let mut guard = self.stdin.lock().await;
let stdin = guard.as_mut().ok_or(Error::NotConnected)?;
stdin
.write_all(data.as_bytes())
.await
.map_err(Error::WriteFailed)?;
stdin.write_all(b"\n").await.map_err(Error::WriteFailed)?;
stdin.flush().await.map_err(Error::WriteFailed)?;
Ok(())
}
fn read_messages(&self) -> Pin<Box<dyn Stream<Item = Result<Value>> + Send>> {
match self.message_rx.try_lock() {
Ok(mut guard) => match guard.take() {
Some(rx) => Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)),
None => Box::pin(tokio_stream::iter(std::iter::once(Err(
crate::errors::Error::TransportClosed,
)))),
},
Err(_) => Box::pin(tokio_stream::iter(std::iter::once(Err(
crate::errors::Error::TransportClosed,
)))),
}
}
async fn end_input(&self) -> Result<()> {
let mut guard = self.stdin.lock().await;
*guard = None;
Ok(())
}
async fn interrupt(&self) -> Result<()> {
if let Some(pid) = self
.process
.lock()
.unwrap_or_else(|e| e.into_inner())
.as_ref()
.and_then(|c| c.id())
{
send_interrupt_signal(pid);
}
Ok(())
}
fn is_ready(&self) -> bool {
self.ready.load(Ordering::Acquire)
}
async fn close(&self) -> Result<Option<i32>> {
self.end_input().await?;
if let Some(handle) = self.reader_handle.lock().await.take() {
let _ = handle.await;
}
let mut child = self
.process
.lock()
.unwrap_or_else(|e| e.into_inner())
.take();
let timeout = self
.close_timeout
.unwrap_or(std::time::Duration::from_secs(5));
let exit_code = if let Some(ref mut child) = child {
match tokio::time::timeout(timeout, child.wait()).await {
Ok(Ok(status)) => status.code(),
_ => {
let _ = child.start_kill();
None
}
}
} else {
None
};
self.ready.store(false, Ordering::Release);
Ok(exit_code)
}
fn collected_stderr(&self) -> String {
self.stderr_buf
.lock()
.unwrap_or_else(|e| e.into_inner())
.clone()
}
}