use crate::transport::TransportLayer;
use crate::{AhpError, AhpNotification, AhpRequest, AhpResponse, Result, TransportConfig};
use async_trait::async_trait;
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::net::UnixStream;
use tokio::sync::{oneshot, Mutex};
/// Client transport that exchanges newline-delimited JSON messages over a
/// Unix domain socket.
///
/// Responses are matched to their requests by id in a background reader task.
pub struct UnixSocketTransport {
    /// Write half of the socket. The async mutex is held for a whole frame
    /// (payload, newline, flush) so concurrent senders cannot interleave bytes.
    writer: Arc<Mutex<tokio::io::WriteHalf<UnixStream>>>,
    /// Handle to the background task that reads and dispatches responses.
    _reader_task: Arc<tokio::task::JoinHandle<()>>,
    /// Requests still awaiting a response, keyed by request id.
    pending_requests: Arc<Mutex<HashMap<String, oneshot::Sender<AhpResponse>>>>,
    /// How long `send_request` waits for a response, in milliseconds.
    timeout_ms: u64,
}
impl UnixSocketTransport {
pub async fn connect(path: impl AsRef<Path>) -> Result<Self> {
Self::connect_with_config(path, &TransportConfig::default()).await
}
pub async fn connect_with_config(
path: impl AsRef<Path>,
config: &TransportConfig,
) -> Result<Self> {
let stream = UnixStream::connect(path.as_ref()).await.map_err(|e| {
AhpError::Transport(format!(
"Failed to connect to {}: {}",
path.as_ref().display(),
e
))
})?;
let pending_requests: Arc<Mutex<HashMap<String, oneshot::Sender<AhpResponse>>>> =
Arc::new(Mutex::new(HashMap::new()));
let pending_clone = pending_requests.clone();
let (reader, writer) = tokio::io::split(stream);
let reader = BufReader::new(reader);
let reader_task = tokio::spawn(async move {
let mut reader = reader;
loop {
let mut line = String::new();
match reader.read_line(&mut line).await {
Ok(0) => break, Ok(_) => {
if let Ok(response) = serde_json::from_str::<AhpResponse>(&line) {
let mut pending_guard = pending_clone.lock().await;
if let Some(sender) = pending_guard.remove(&response.id) {
let _ = sender.send(response);
}
}
}
Err(_) => break,
}
}
});
let transport = Self {
writer: Arc::new(Mutex::new(writer)),
_reader_task: Arc::new(reader_task),
pending_requests,
timeout_ms: config.timeout_ms,
};
Ok(transport)
}
}
#[async_trait]
impl TransportLayer for UnixSocketTransport {
async fn send_request(&self, request: AhpRequest) -> Result<AhpResponse> {
let (tx, rx) = oneshot::channel();
let request_id = request.id.clone();
let json = serde_json::to_string(&request)?;
{
let mut pending = self.pending_requests.lock().await;
pending.insert(request_id.clone(), tx);
}
let mut writer = self.writer.lock().await;
if let Err(e) = writer.write_all(json.as_bytes()).await {
self.pending_requests.lock().await.remove(&request_id);
return Err(e.into());
}
if let Err(e) = writer.write_all(b"\n").await {
self.pending_requests.lock().await.remove(&request_id);
return Err(e.into());
}
if let Err(e) = writer.flush().await {
self.pending_requests.lock().await.remove(&request_id);
return Err(e.into());
}
drop(writer);
match tokio::time::timeout(std::time::Duration::from_millis(self.timeout_ms), rx).await {
Ok(Ok(response)) => Ok(response),
Ok(Err(_)) => Err(AhpError::ConnectionClosed),
Err(_) => {
self.pending_requests.lock().await.remove(&request_id);
Err(AhpError::Timeout(self.timeout_ms))
}
}
}
async fn send_notification(&self, notification: AhpNotification) -> Result<()> {
let mut writer = self.writer.lock().await;
let json = serde_json::to_string(¬ification)?;
writer.write_all(json.as_bytes()).await?;
writer.write_all(b"\n").await?;
writer.flush().await?;
Ok(())
}
async fn close(&self) -> Result<()> {
Ok(())
}
}