use crate::error::{Error, ErrorCode};
use crate::transport::{Receiver as TransportReceiver, Sender as TransportSender, Transport};
use crate::types::Message;
use futures_util::TryFutureExt;
use tokio::{
io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, BufWriter},
sync::mpsc::{self, Receiver, Sender},
};
use tokio_util::sync::CancellationToken;
#[cfg(feature = "server")]
use tokio::io::{Stdin, Stdout};
#[cfg(feature = "client")]
use self::options::StdIoOptions;
#[cfg(feature = "client")]
use tokio::process::{ChildStdin, ChildStdout};
#[cfg(all(feature = "client", target_os = "linux"))]
mod linux;
#[cfg(all(feature = "client", target_os = "windows"))]
mod windows;
#[cfg(feature = "client")]
pub(crate) mod options;
#[cfg(feature = "server")]
pub(crate) struct StdIoServer {
sender: StdIoSender,
receiver: StdIoReceiver,
}
#[cfg(feature = "client")]
pub(crate) struct StdIoClient {
sender: StdIoSender,
receiver: StdIoReceiver,
options: StdIoOptions,
}
pub(crate) struct StdIoSender {
tx: Sender<Message>,
rx: Option<Receiver<Message>>,
}
pub(crate) struct StdIoReceiver {
tx: Sender<Result<Message, Error>>,
rx: Receiver<Result<Message, Error>>,
}
impl Clone for StdIoSender {
#[inline]
fn clone(&self) -> Self {
Self {
tx: self.tx.clone(),
rx: None,
}
}
}
impl StdIoSender {
pub(crate) fn new() -> Self {
let (tx, rx) = mpsc::channel(100);
Self { tx, rx: Some(rx) }
}
pub(crate) fn start<T: AsyncWrite + Unpin + Send + 'static>(
&mut self,
mut writer: BufWriter<T>,
token: CancellationToken,
) {
let Some(mut receiver) = self.rx.take() else {
#[cfg(feature = "tracing")]
tracing::error!(logger = "neva", "The stdout writer already in use");
return;
};
tokio::spawn(async move {
loop {
tokio::select! {
biased;
_ = token.cancelled() => break,
resp = receiver.recv() => {
match resp {
Some(resp) => {
match serde_json::to_vec(&resp) {
Ok(mut json_bytes) => {
json_bytes.push(b'\n');
if let Err(_err) = writer.write_all(&json_bytes).await {
#[cfg(feature = "tracing")]
tracing::error!(
logger = "neva",
"stdout write error: {:?}", _err);
}
let _ = writer.flush().await;
},
Err(_err) => {
#[cfg(feature = "tracing")]
tracing::error!(
logger = "neva",
"Serialization error: {:?}", _err);
}
}
},
None => break,
}
}
}
}
});
}
}
impl StdIoReceiver {
pub(crate) fn new() -> Self {
let (tx, rx) = mpsc::channel(100);
Self { tx, rx }
}
pub(crate) fn start<T: AsyncRead + Unpin + Send + 'static>(
&self,
mut reader: BufReader<T>,
token: CancellationToken,
) {
let tx = self.tx.clone();
tokio::spawn(async move {
let mut line = String::new();
loop {
line.clear();
tokio::select! {
biased;
_ = token.cancelled() => break,
read_line = reader.read_line(&mut line) => {
match read_line {
Ok(0) => break, Ok(_) => {
let req = match serde_json::from_str(&line) {
Ok(req) => Ok(req),
Err(err) => Err(err.into()),
};
if let Err(_e) = tx.send(req).await {
#[cfg(feature = "tracing")]
tracing::error!(logger = "neva", "Failed to send request: {:?}", _e);
break;
}
}
Err(err) => {
let err = Err(err.into());
if let Err(_e) = tx.send(err).await {
#[cfg(feature = "tracing")]
tracing::error!(logger = "neva", "Failed to send error request: {:?}", _e);
}
break;
}
};
}
}
}
});
}
}
#[cfg(feature = "client")]
impl StdIoClient {
pub(crate) fn new(options: StdIoOptions) -> Self {
Self {
receiver: StdIoReceiver::new(),
sender: StdIoSender::new(),
options,
}
}
fn handshake(
&self,
token: CancellationToken,
) -> (BufReader<ChildStdout>, BufWriter<ChildStdin>) {
let options = &self.options;
#[cfg(target_os = "linux")]
let (job, mut child) =
linux::Job::new(options.command, &options.args).expect("Failed to handshake");
#[cfg(target_os = "windows")]
let (job, mut child) =
windows::Job::new(options.command, &options.args).expect("Failed to handshake");
#[cfg(all(not(target_os = "windows"), not(target_os = "linux")))]
let mut child = tokio::process::Command::new(options.command)
.args(options.args)
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.spawn()
.expect("Failed to handshake");
let stdin = child
.stdin
.take()
.expect("Failed to handshake: Inaccessible stdin");
let stdout = child
.stdout
.take()
.expect("Failed to handshake: Inaccessible stdout");
#[cfg(feature = "tracing")]
let child_id = child.id();
tokio::task::spawn(async move {
#[cfg(any(target_os = "windows", target_os = "linux"))]
let _job = job;
tokio::select! {
biased;
_ = child.wait() => {}
_ = token.cancelled() => {
if let Err(_e) = child.kill().await {
#[cfg(feature = "tracing")]
tracing::warn!(
logger = "neva",
pid = child_id,
"Failed to kill child process: {:?}", _e);
} else {
let _exit = child.wait().await;
#[cfg(feature = "tracing")]
tracing::trace!(
logger = "neva",
pid = child_id,
"Child exited with status: {:?}", _exit);
}
},
}
});
(BufReader::new(stdout), BufWriter::new(stdin))
}
}
#[cfg(feature = "server")]
impl StdIoServer {
pub(crate) fn new() -> Self {
Self {
receiver: StdIoReceiver::new(),
sender: StdIoSender::new(),
}
}
pub(crate) fn init() -> (BufReader<Stdin>, BufWriter<Stdout>) {
(
BufReader::new(tokio::io::stdin()),
BufWriter::new(tokio::io::stdout()),
)
}
}
impl TransportSender for StdIoSender {
async fn send(&mut self, msg: Message) -> Result<(), Error> {
self.tx
.send(msg)
.map_err(|err| Error::new(ErrorCode::InternalError, err))
.await
}
}
impl TransportReceiver for StdIoReceiver {
async fn recv(&mut self) -> Result<Message, Error> {
self.rx.recv().await.unwrap_or_else(|| {
Err(Error::new(
ErrorCode::InvalidRequest,
"Unexpected end of stream",
))
})
}
}
#[cfg(feature = "client")]
impl Transport for StdIoClient {
type Sender = StdIoSender;
type Receiver = StdIoReceiver;
fn start(&mut self) -> CancellationToken {
let token = CancellationToken::new();
let (reader, writer) = self.handshake(token.clone());
self.receiver.start(reader, token.clone());
self.sender.start(writer, token.clone());
#[cfg(feature = "tracing")]
tracing::info!(logger = "neva", "Connected: stdio");
token
}
#[inline]
fn split(self) -> (Self::Sender, Self::Receiver) {
(self.sender, self.receiver)
}
}
#[cfg(feature = "server")]
impl Transport for StdIoServer {
type Sender = StdIoSender;
type Receiver = StdIoReceiver;
fn start(&mut self) -> CancellationToken {
let token = CancellationToken::new();
let (reader, writer) = StdIoServer::init();
self.receiver.start(reader, token.clone());
self.sender.start(writer, token.clone());
#[cfg(feature = "tracing")]
tracing::info!(logger = "neva", "Listening: stdio");
token
}
#[inline]
fn split(self) -> (Self::Sender, Self::Receiver) {
(self.sender, self.receiver)
}
}
#[cfg(test)]
mod tests {
#[tokio::test]
#[cfg(all(feature = "client", target_os = "windows"))]
async fn it_tests_handshake() {
use super::options::StdIoOptions;
use crate::transport::StdIoClient;
use tokio_util::sync::CancellationToken;
let client = StdIoClient::new(StdIoOptions::new(
"cmd.exe",
["/c", "ping", "127.0.0.1", "-t"],
));
let token = CancellationToken::new();
let (_, _) = client.handshake(token.clone());
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
token.cancel();
let result = tokio::time::timeout(
std::time::Duration::from_secs(2),
tokio::process::Command::new("tasklist").output(),
)
.await
.unwrap();
assert!(
!String::from_utf8_lossy(&result.unwrap().stdout).contains("ping.exe"),
"Ping should be terminated"
);
}
#[tokio::test]
#[cfg(all(feature = "client", target_os = "linux"))]
async fn it_tests_handshake() {
use super::options::StdIoOptions;
use crate::transport::StdIoClient;
use tokio_util::sync::CancellationToken;
let client = StdIoClient::new(StdIoOptions::new("sh", ["-c", "sleep 300"]));
let token = CancellationToken::new();
let (_, _) = client.handshake(token.clone());
token.cancel();
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
let output = tokio::process::Command::new("pgrep")
.arg("-f")
.arg("sleep 300")
.output()
.await
.unwrap();
assert!(output.stdout.is_empty(), "Process still running");
}
}