use crate::services::remote::protocol::{AgentRequest, AgentResponse};
use std::collections::HashMap;
use std::io;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncWrite, AsyncWriteExt};
use tokio::sync::{mpsc, oneshot};
use tracing::warn;
/// Default capacity of the per-request streaming `data` channel.
const DEFAULT_DATA_CHANNEL_CAPACITY: usize = 64;
/// Initial per-channel request timeout (ten seconds); adjustable at runtime
/// through `AgentChannel::set_request_timeout`.
const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_millis(10_000);
/// Test hook: when non-zero, `AgentChannel::request_with_data` sleeps this many
/// microseconds after receiving each data chunk (0 disables the delay).
pub static TEST_RECV_DELAY_US: AtomicU64 = AtomicU64::new(0);
/// Errors produced by [`AgentChannel`] operations.
#[derive(Debug, thiserror::Error)]
pub enum ChannelError {
    /// Wraps an underlying I/O failure.
    #[error("IO error: {0}")]
    Io(#[from] io::Error),

    /// Wraps a JSON (de)serialization failure.
    #[error("JSON error: {0}")]
    Json(#[from] serde_json::Error),

    /// The channel is marked disconnected, the request could not be queued,
    /// or the result sender was dropped before delivering a value.
    #[error("Channel closed")]
    ChannelClosed,

    /// The request was cancelled.
    #[error("Request cancelled")]
    Cancelled,

    /// No final result arrived within the configured request timeout.
    #[error("Request timed out")]
    Timeout,

    /// An error string delivered for the request: either the agent's `error`
    /// field or a locally generated disconnect message ("connection closed").
    #[error("Remote error: {0}")]
    Remote(String),
}
/// Routing state for one in-flight request, stored in the pending table under
/// its request id.
struct PendingRequest {
    /// Forwards streamed `data` chunks to the requester.
    data_tx: mpsc::Sender<serde_json::Value>,
    /// Delivers the final outcome once: `Ok(result)` or `Err(message)`.
    result_tx: oneshot::Sender<Result<serde_json::Value, String>>,
}
/// Type-erased response source, so the transport can be replaced at runtime.
type BoxedReader = Box<dyn AsyncBufRead + Send + Unpin>;
/// Type-erased request sink, so the transport can be replaced at runtime.
type BoxedWriter = Box<dyn AsyncWrite + Send + Unpin>;
/// Request/response channel to a remote agent over a line-oriented transport.
///
/// Requests are queued to a background write task. Responses are parsed by a
/// background read task and routed to the matching pending request by id.
pub struct AgentChannel {
    /// Queue of serialized request lines consumed by the write task.
    write_tx: mpsc::Sender<String>,
    /// In-flight requests keyed by request id; shared with the read task.
    pending: Arc<Mutex<HashMap<u64, PendingRequest>>>,
    /// Next request id to hand out (starts at 1).
    next_id: AtomicU64,
    /// Whether the transport is currently considered usable.
    connected: Arc<std::sync::atomic::AtomicBool>,
    /// Runtime used by the `*_blocking` wrappers.
    runtime_handle: tokio::runtime::Handle,
    /// Capacity of each request's streaming data channel.
    data_channel_capacity: usize,
    /// Current request timeout, stored as milliseconds.
    request_timeout_ms: AtomicU64,
    /// Hands replacement readers to the read task.
    new_reader_tx: mpsc::Sender<BoxedReader>,
    /// Hands replacement writers to the write task.
    new_writer_tx: mpsc::Sender<BoxedWriter>,
}
impl AgentChannel {
    /// Creates a channel to an agent child process. Requests go to its stdin
    /// and responses come from its stdout. Uses `DEFAULT_DATA_CHANNEL_CAPACITY`
    /// for each request's streaming data channel.
    ///
    /// # Panics
    ///
    /// Panics when called outside a Tokio runtime; see
    /// [`AgentChannel::from_transport`].
    pub fn new(
        reader: tokio::io::BufReader<tokio::process::ChildStdout>,
        writer: tokio::process::ChildStdin,
    ) -> Self {
        Self::with_capacity(reader, writer, DEFAULT_DATA_CHANNEL_CAPACITY)
    }

    /// Same as [`AgentChannel::new`], but with an explicit per-request data
    /// channel capacity.
    ///
    /// # Panics
    ///
    /// Panics when called outside a Tokio runtime; see
    /// [`AgentChannel::from_transport`].
    pub fn with_capacity(
        reader: tokio::io::BufReader<tokio::process::ChildStdout>,
        writer: tokio::process::ChildStdin,
        data_channel_capacity: usize,
    ) -> Self {
        Self::from_transport(reader, writer, data_channel_capacity)
    }

    /// Creates a channel over an arbitrary buffered reader / writer pair.
    ///
    /// Request lines are written to `writer` and response lines are read from
    /// `reader`. Two background tasks are spawned on the current Tokio
    /// runtime: [`Self::write_task`] and [`Self::read_task`]. The channel starts
    /// out marked as connected.
    ///
    /// `data_channel_capacity` bounds how many streamed `data` chunks are
    /// buffered per request before the read task waits for the consumer. A
    /// value of 0 is treated as 1, because `tokio::sync::mpsc::channel`
    /// panics on a zero capacity.
    ///
    /// # Panics
    ///
    /// Panics when called outside a Tokio runtime (`Handle::current` and
    /// `tokio::spawn` both require one).
    pub fn from_transport<R, W>(reader: R, writer: W, data_channel_capacity: usize) -> Self
    where
        R: AsyncBufRead + Unpin + Send + 'static,
        W: AsyncWrite + Unpin + Send + 'static,
    {
        let pending: Arc<Mutex<HashMap<u64, PendingRequest>>> =
            Arc::new(Mutex::new(HashMap::new()));
        let connected = Arc::new(std::sync::atomic::AtomicBool::new(true));
        let runtime_handle = tokio::runtime::Handle::current();
        // Outgoing request lines, drained by the write task.
        let (write_tx, write_rx) = mpsc::channel::<String>(64);
        // Replacement transports handed over by `replace_transport`.
        let (new_reader_tx, new_reader_rx) = mpsc::channel::<BoxedReader>(1);
        let (new_writer_tx, new_writer_rx) = mpsc::channel::<BoxedWriter>(1);

        tokio::spawn(Self::write_task(
            Box::new(writer),
            write_rx,
            new_writer_rx,
            connected.clone(),
        ));
        tokio::spawn(Self::read_task(
            Box::new(reader),
            new_reader_rx,
            pending.clone(),
            connected.clone(),
        ));

        Self {
            write_tx,
            pending,
            next_id: AtomicU64::new(1),
            connected,
            runtime_handle,
            data_channel_capacity: data_channel_capacity.max(1),
            request_timeout_ms: AtomicU64::new(Self::timeout_to_millis(DEFAULT_REQUEST_TIMEOUT)),
            new_reader_tx,
            new_writer_tx,
        }
    }

    /// Background task that writes queued request lines to the transport.
    ///
    /// Each message is written and then flushed. If the write or the flush
    /// fails, the channel is marked disconnected, the failed message is
    /// dropped, and the task waits for a replacement writer. A replacement can
    /// also arrive at any time while the writer is healthy. The task exits
    /// when either the request queue or the replacement channel closes.
    async fn write_task(
        mut writer: BoxedWriter,
        mut write_rx: mpsc::Receiver<String>,
        mut new_writer_rx: mpsc::Receiver<BoxedWriter>,
        connected: Arc<std::sync::atomic::AtomicBool>,
    ) {
        loop {
            tokio::select! {
                msg = write_rx.recv() => {
                    let Some(msg) = msg else { break };
                    let write_ok = writer.write_all(msg.as_bytes()).await.is_ok()
                        && writer.flush().await.is_ok();
                    if !write_ok {
                        connected.store(false, Ordering::SeqCst);
                        // Nothing more can be written until a new transport arrives.
                        match new_writer_rx.recv().await {
                            Some(new_writer) => writer = new_writer,
                            None => break,
                        }
                    }
                }
                new_writer = new_writer_rx.recv() => {
                    match new_writer {
                        Some(w) => writer = w,
                        None => break,
                    }
                }
            }
        }
    }

    /// Background task that reads response lines and routes them to pending
    /// requests.
    ///
    /// On EOF or a read error, the channel is marked disconnected and every
    /// pending request is failed with "connection closed". The task then waits
    /// for a replacement reader. Lines that do not parse as an `AgentResponse`
    /// are ignored. The task exits when the replacement channel closes.
    async fn read_task(
        mut reader: BoxedReader,
        mut new_reader_rx: mpsc::Receiver<BoxedReader>,
        pending: Arc<Mutex<HashMap<u64, PendingRequest>>>,
        connected: Arc<std::sync::atomic::AtomicBool>,
    ) {
        let mut line = String::new();
        loop {
            line.clear();
            tokio::select! {
                read_result = reader.read_line(&mut line) => {
                    match read_result {
                        Ok(0) | Err(_) => {
                            connected.store(false, Ordering::SeqCst);
                            Self::drain_pending(&pending);
                            match new_reader_rx.recv().await {
                                // Must also mark the channel connected here; otherwise a
                                // reconnect after EOF leaves it "disconnected" forever.
                                Some(new_reader) => Self::install_reader(
                                    &mut reader,
                                    new_reader,
                                    &pending,
                                    &connected,
                                ),
                                None => break,
                            }
                        }
                        Ok(_) => {
                            if let Ok(resp) = serde_json::from_str::<AgentResponse>(&line) {
                                Self::handle_response(&pending, resp).await;
                            }
                        }
                    }
                }
                new_reader = new_reader_rx.recv() => {
                    match new_reader {
                        Some(r) => Self::install_reader(&mut reader, r, &pending, &connected),
                        None => break,
                    }
                }
            }
        }
    }

    /// Swaps in a replacement reader. Requests still pending on the old
    /// transport are failed first, then the channel is marked connected.
    fn install_reader(
        reader: &mut BoxedReader,
        new_reader: BoxedReader,
        pending: &Arc<Mutex<HashMap<u64, PendingRequest>>>,
        connected: &std::sync::atomic::AtomicBool,
    ) {
        Self::drain_pending(pending);
        *reader = new_reader;
        connected.store(true, Ordering::SeqCst);
    }

    /// Fails and removes every pending request with the error
    /// "connection closed". Requests whose receiver is already gone are
    /// logged.
    fn drain_pending(pending: &Arc<Mutex<HashMap<u64, PendingRequest>>>) {
        let mut map = pending.lock().unwrap();
        for (id, req) in map.drain() {
            if req.result_tx.send(Err("connection closed".to_string())).is_err() {
                warn!("request {id}: receiver dropped during disconnect cleanup");
            }
        }
    }

    /// Routes one parsed response to its pending request.
    ///
    /// A `data` payload is forwarded on the request's data channel. This waits
    /// for capacity, which stalls the read loop while the consumer is behind.
    /// If the data receiver is gone, the request is removed and nothing else
    /// is processed. A `result` (preferred) or `error` completes the request
    /// and removes it. Responses for unknown ids are ignored.
    async fn handle_response(
        pending: &Arc<Mutex<HashMap<u64, PendingRequest>>>,
        resp: AgentResponse,
    ) {
        let id = resp.id;

        if let Some(data) = resp.data {
            // Clone the sender so the lock is not held across the `.await`.
            let data_tx = {
                let map = pending.lock().unwrap();
                map.get(&id).map(|req| req.data_tx.clone())
            };
            if let Some(tx) = data_tx {
                if tx.send(data).await.is_err() {
                    warn!("request {}: data receiver dropped mid-stream", id);
                    let mut map = pending.lock().unwrap();
                    map.remove(&id);
                    return;
                }
            }
        }

        let outcome = match (resp.result, resp.error) {
            (Some(result), _) => Ok(result),
            (None, Some(error)) => Err(error),
            // Data-only frame: the request stays pending.
            (None, None) => return,
        };

        let req = {
            let mut map = pending.lock().unwrap();
            map.remove(&id)
        };
        let Some(req) = req else { return };
        if req.result_tx.send(outcome).is_err() {
            warn!("request {}: result receiver dropped", id);
        }
    }

    /// Returns whether the channel currently considers its transport usable.
    ///
    /// The flag becomes false after a read/write failure, EOF, or a request
    /// timeout. It becomes true again once a replacement reader is installed.
    pub fn is_connected(&self) -> bool {
        self.connected.load(Ordering::SeqCst)
    }

    /// Hands the writer to the write task and then the reader to the read
    /// task. Returns `false` (after logging) if either task is gone.
    async fn send_transport(&self, reader: BoxedReader, writer: BoxedWriter) -> bool {
        if self.new_writer_tx.send(writer).await.is_err() {
            warn!("replace_transport: write task is gone, cannot reconnect");
            return false;
        }
        if self.new_reader_tx.send(reader).await.is_err() {
            warn!("replace_transport: read task is gone, cannot reconnect");
            return false;
        }
        true
    }

    /// Replaces the underlying transport.
    ///
    /// The writer is handed over first, then the reader. When the read task
    /// installs the new reader, it fails all requests still pending on the
    /// old transport and marks the channel connected. This returns once both
    /// hand-offs are queued, not when they have been applied.
    pub async fn replace_transport<R, W>(&self, reader: R, writer: W)
    where
        R: AsyncBufRead + Unpin + Send + 'static,
        W: AsyncWrite + Unpin + Send + 'static,
    {
        let _ = self.send_transport(Box::new(reader), Box::new(writer)).await;
    }

    /// Blocking variant of [`AgentChannel::replace_transport`].
    ///
    /// After the hand-off it busy-waits (yielding the thread) until
    /// [`AgentChannel::is_connected`] reports true. It returns immediately if
    /// a background task is gone, since no reconnect can then happen.
    ///
    /// NOTE(review): if the channel is already marked connected, this may
    /// return before the swap is applied. If the new reader fails immediately
    /// after installation, the wait may never observe `true`.
    ///
    /// # Panics
    ///
    /// Tokio's `Handle::block_on` panics when called from within an
    /// asynchronous execution context.
    pub fn replace_transport_blocking<R, W>(&self, reader: R, writer: W)
    where
        R: AsyncBufRead + Unpin + Send + 'static,
        W: AsyncWrite + Unpin + Send + 'static,
    {
        let handed_off = self
            .runtime_handle
            .block_on(self.send_transport(Box::new(reader), Box::new(writer)));
        if !handed_off {
            return;
        }
        while !self.is_connected() {
            std::thread::yield_now();
        }
    }

    /// Sets the timeout used by [`AgentChannel::request`] and
    /// [`AgentChannel::request_with_data`]. Values beyond `u64::MAX`
    /// milliseconds saturate instead of wrapping.
    pub fn set_request_timeout(&self, timeout: Duration) {
        self.request_timeout_ms
            .store(Self::timeout_to_millis(timeout), Ordering::SeqCst);
    }

    /// Converts a duration to whole milliseconds, saturating at `u64::MAX`.
    fn timeout_to_millis(timeout: Duration) -> u64 {
        u64::try_from(timeout.as_millis()).unwrap_or(u64::MAX)
    }

    /// Returns the currently configured request timeout.
    fn request_timeout(&self) -> Duration {
        Duration::from_millis(self.request_timeout_ms.load(Ordering::SeqCst))
    }

    /// Removes `id` from the pending table, dropping its senders. Used when
    /// the caller gives up on a request, so any late response is ignored.
    fn forget_request(&self, id: u64) {
        let mut map = self.pending.lock().unwrap();
        map.remove(&id);
    }

    /// Registers a new pending request and queues its request line.
    ///
    /// Returns the request id plus the data and result receivers. Fails with
    /// `ChannelClosed` when the channel is disconnected or the write queue is
    /// closed. In the latter case the pending entry is removed again.
    // Tuple of (id, data receiver, result receiver); internal only.
    #[allow(clippy::type_complexity)]
    async fn start_request(
        &self,
        method: &str,
        params: serde_json::Value,
    ) -> Result<
        (
            u64,
            mpsc::Receiver<serde_json::Value>,
            oneshot::Receiver<Result<serde_json::Value, String>>,
        ),
        ChannelError,
    > {
        if !self.is_connected() {
            return Err(ChannelError::ChannelClosed);
        }
        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
        let (data_tx, data_rx) = mpsc::channel(self.data_channel_capacity);
        let (result_tx, result_rx) = oneshot::channel();
        {
            let mut map = self.pending.lock().unwrap();
            map.insert(id, PendingRequest { data_tx, result_tx });
        }
        let line = AgentRequest::new(id, method, params).to_json_line();
        if self.write_tx.send(line).await.is_err() {
            // The request never reached the transport; don't leak its entry.
            self.forget_request(id);
            return Err(ChannelError::ChannelClosed);
        }
        Ok((id, data_rx, result_rx))
    }

    /// Sends a request and waits for its final result, discarding any
    /// streamed data chunks.
    ///
    /// # Errors
    ///
    /// - `ChannelClosed` if the channel is disconnected or the result sender
    ///   is dropped.
    /// - `Remote` with the error string delivered for the request.
    /// - `Timeout` if no result arrives within the request timeout. A timeout
    ///   also forgets the request and marks the channel disconnected.
    pub async fn request(
        &self,
        method: &str,
        params: serde_json::Value,
    ) -> Result<serde_json::Value, ChannelError> {
        let (id, mut data_rx, result_rx) = self.start_request(method, params).await?;
        let timeout = self.request_timeout();
        let result = tokio::time::timeout(timeout, async {
            // Keep draining so the read task never blocks on this request's data channel.
            while data_rx.recv().await.is_some() {}
            result_rx
                .await
                .map_err(|_| ChannelError::ChannelClosed)?
                .map_err(ChannelError::Remote)
        })
        .await;
        match result {
            Ok(inner) => inner,
            Err(_elapsed) => {
                warn!("request '{}' timed out after {:?}", method, timeout);
                self.forget_request(id);
                self.connected.store(false, Ordering::SeqCst);
                Err(ChannelError::Timeout)
            }
        }
    }

    /// Sends a request and returns the raw receivers for its streamed data
    /// and final result. No timeout is applied; the caller drives both
    /// receivers.
    ///
    /// # Errors
    ///
    /// `ChannelClosed` if the channel is disconnected or the request cannot
    /// be queued.
    pub async fn request_streaming(
        &self,
        method: &str,
        params: serde_json::Value,
    ) -> Result<
        (
            mpsc::Receiver<serde_json::Value>,
            oneshot::Receiver<Result<serde_json::Value, String>>,
        ),
        ChannelError,
    > {
        let (_id, data_rx, result_rx) = self.start_request(method, params).await?;
        Ok((data_rx, result_rx))
    }

    /// Blocking variant of [`AgentChannel::request`].
    ///
    /// # Panics
    ///
    /// Tokio's `Handle::block_on` panics when called from within an
    /// asynchronous execution context.
    pub fn request_blocking(
        &self,
        method: &str,
        params: serde_json::Value,
    ) -> Result<serde_json::Value, ChannelError> {
        self.runtime_handle.block_on(self.request(method, params))
    }

    /// Sends a request and collects every streamed data chunk together with
    /// the final result.
    ///
    /// When the `TEST_RECV_DELAY_US` test hook is non-zero, it sleeps that
    /// many microseconds after each chunk.
    ///
    /// # Errors
    ///
    /// Same as [`AgentChannel::request`]. On timeout, the chunks received so
    /// far are discarded.
    pub async fn request_with_data(
        &self,
        method: &str,
        params: serde_json::Value,
    ) -> Result<(Vec<serde_json::Value>, serde_json::Value), ChannelError> {
        let (id, mut data_rx, result_rx) = self.start_request(method, params).await?;
        let timeout = self.request_timeout();
        let result = tokio::time::timeout(timeout, async {
            let mut data = Vec::new();
            while let Some(chunk) = data_rx.recv().await {
                data.push(chunk);
                let delay_us = TEST_RECV_DELAY_US.load(Ordering::Relaxed);
                if delay_us > 0 {
                    tokio::time::sleep(Duration::from_micros(delay_us)).await;
                }
            }
            let result = result_rx
                .await
                .map_err(|_| ChannelError::ChannelClosed)?
                .map_err(ChannelError::Remote)?;
            Ok((data, result))
        })
        .await;
        match result {
            Ok(inner) => inner,
            Err(_elapsed) => {
                warn!("streaming request timed out after {:?}", timeout);
                self.forget_request(id);
                self.connected.store(false, Ordering::SeqCst);
                Err(ChannelError::Timeout)
            }
        }
    }

    /// Blocking variant of [`AgentChannel::request_with_data`].
    ///
    /// # Panics
    ///
    /// Tokio's `Handle::block_on` panics when called from within an
    /// asynchronous execution context.
    pub fn request_with_data_blocking(
        &self,
        method: &str,
        params: serde_json::Value,
    ) -> Result<(Vec<serde_json::Value>, serde_json::Value), ChannelError> {
        self.runtime_handle
            .block_on(self.request_with_data(method, params))
    }

    /// Blocking variant of [`AgentChannel::request_streaming`]. Only the
    /// sending is blocking; the returned receivers are consumed by the
    /// caller.
    ///
    /// # Panics
    ///
    /// Tokio's `Handle::block_on` panics when called from within an
    /// asynchronous execution context.
    #[allow(clippy::type_complexity)]
    pub fn request_streaming_blocking(
        &self,
        method: &str,
        params: serde_json::Value,
    ) -> Result<
        (
            mpsc::Receiver<serde_json::Value>,
            oneshot::Receiver<Result<serde_json::Value, String>>,
        ),
        ChannelError,
    > {
        self.runtime_handle
            .block_on(self.request_streaming(method, params))
    }

    /// Asks the agent to cancel `request_id` by sending a `cancel` request.
    /// It is subject to the same timeout and errors as
    /// [`AgentChannel::request`].
    pub async fn cancel(&self, request_id: u64) -> Result<(), ChannelError> {
        use crate::services::remote::protocol::cancel_params;
        self.request("cancel", cancel_params(request_id)).await?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
// TODO(review): no unit tests are defined here yet; the reconnect (EOF then
// replace_transport) and request-timeout paths of AgentChannel are uncovered.
}