use std::time::Duration;
use bytes::BytesMut;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use super::error::IpcError;
/// Default receive timeout for client connections: 300 seconds.
///
/// NOTE(review): `connect` in this file initializes `recv_timeout` to `None`, so this value
/// is not applied automatically there; presumably callers pass it to `set_recv_timeout` —
/// confirm against call sites.
pub const DEFAULT_CLIENT_RECV_TIMEOUT: Duration = Duration::from_secs(300);
/// Byte stream wrapped by `IpcConnection` on Unix: a Unix-domain socket, used both for
/// connections returned by `IpcListener::accept` and by `connect`.
#[cfg(unix)]
type StreamType = tokio::net::UnixStream;
/// Byte stream wrapped by `IpcConnection` on Windows: the server end of a named pipe
/// (client-side connections use `IpcClientConnection` instead).
#[cfg(windows)]
type StreamType = tokio::net::windows::named_pipe::NamedPipeServer;
/// A message-oriented IPC connection over a split byte stream.
///
/// On Unix this is returned by both `IpcListener::accept` and `connect`; on Windows it is
/// the server side of a named pipe and is returned by `IpcListener::accept`.
pub struct IpcConnection {
/// Read half of the stream; incoming bytes are accumulated into `read_buf`.
reader: tokio::io::ReadHalf<StreamType>,
/// Write half of the stream; `send` writes and flushes one encoded message at a time.
writer: tokio::io::WriteHalf<StreamType>,
/// Bytes received but not yet decoded into a complete message; kept across `recv` calls.
read_buf: BytesMut,
/// Timeout used by `recv`; `None` means `recv` waits indefinitely.
recv_timeout: Option<Duration>,
}
/// Client end of a Windows named-pipe IPC connection, returned by `connect` on Windows.
///
/// Mirrors `IpcConnection` but wraps a `NamedPipeClient` rather than a `NamedPipeServer`.
#[cfg(windows)]
pub struct IpcClientConnection {
/// Read half of the pipe; incoming bytes are accumulated into `read_buf`.
reader: tokio::io::ReadHalf<tokio::net::windows::named_pipe::NamedPipeClient>,
/// Write half of the pipe; `send` writes and flushes one encoded message at a time.
writer: tokio::io::WriteHalf<tokio::net::windows::named_pipe::NamedPipeClient>,
/// Bytes received but not yet decoded into a complete message; kept across `recv` calls.
read_buf: BytesMut,
/// Timeout used by `recv`; `None` means `recv` waits indefinitely.
recv_timeout: Option<Duration>,
}
impl IpcConnection {
    /// Encodes `msg` with `crate::protocol::encode_message`, writes the whole frame, and
    /// flushes the writer so the peer can see it immediately.
    ///
    /// # Errors
    /// Returns an error if encoding fails or if the write/flush fails.
    pub async fn send<T: serde::Serialize>(&mut self, msg: &T) -> Result<(), IpcError> {
        let buf = crate::protocol::encode_message(msg)?;
        self.writer.write_all(&buf).await?;
        self.writer.flush().await?;
        Ok(())
    }

    /// Sets the timeout that [`recv`](Self::recv) applies to each message.
    pub fn set_recv_timeout(&mut self, timeout: Duration) {
        self.recv_timeout = Some(timeout);
    }

    /// Removes the receive timeout so [`recv`](Self::recv) waits indefinitely.
    pub fn clear_recv_timeout(&mut self) {
        self.recv_timeout = None;
    }

    /// Returns the currently configured receive timeout, if any.
    pub fn recv_timeout(&self) -> Option<Duration> {
        self.recv_timeout
    }

    /// Receives the next message, applying the configured receive timeout if one is set.
    ///
    /// Returns `Ok(None)` when the peer closed the stream with no partial message buffered.
    ///
    /// # Errors
    /// * `IpcError::Timeout` if a timeout is set and no complete message arrived in time.
    /// * `IpcError::ConnectionClosed` if the peer closed the stream mid-message.
    /// * Decode and I/O errors are propagated.
    pub async fn recv<T: serde::de::DeserializeOwned>(&mut self) -> Result<Option<T>, IpcError> {
        match self.recv_timeout {
            Some(t) => self.recv_with_timeout(t).await,
            None => self.recv_loop().await,
        }
    }

    /// Like [`recv`](Self::recv), but with an explicit `timeout` covering the whole message.
    ///
    /// Bytes already received when the timeout fires stay in the internal buffer, so a later
    /// call continues from them.
    ///
    /// # Errors
    /// Returns `IpcError::Timeout(timeout)` on expiry, otherwise the same errors as `recv`.
    pub async fn recv_with_timeout<T: serde::de::DeserializeOwned>(
        &mut self,
        timeout: Duration,
    ) -> Result<Option<T>, IpcError> {
        match tokio::time::timeout(timeout, self.recv_loop()).await {
            Ok(result) => result,
            Err(_) => Err(IpcError::Timeout(timeout)),
        }
    }

    /// Reads until `read_buf` holds a complete message, then decodes and returns it.
    async fn recv_loop<T: serde::de::DeserializeOwned>(&mut self) -> Result<Option<T>, IpcError> {
        // Minimum spare capacity made available before each read.
        const READ_CHUNK: usize = 4096;
        loop {
            // Earlier reads may already have buffered a complete message.
            if let Some(msg) = crate::protocol::decode_message::<T>(&mut self.read_buf)? {
                return Ok(Some(msg));
            }
            // Read directly into the buffer's spare capacity rather than through a zeroed
            // temporary array plus a copy. `read_buf` is cancel-safe, so a timeout that drops
            // this future loses no bytes. After `reserve` there is always spare capacity, so
            // a return of 0 can only mean EOF.
            self.read_buf.reserve(READ_CHUNK);
            let n = self.reader.read_buf(&mut self.read_buf).await?;
            if n == 0 {
                // EOF between messages is a clean close; EOF with leftover bytes means the
                // peer hung up in the middle of a frame.
                return if self.read_buf.is_empty() {
                    Ok(None)
                } else {
                    Err(IpcError::ConnectionClosed)
                };
            }
        }
    }
}
#[cfg(windows)]
impl IpcClientConnection {
    /// Encodes `msg` with `crate::protocol::encode_message`, writes the whole frame, and
    /// flushes the writer so the server can see it immediately.
    ///
    /// # Errors
    /// Returns an error if encoding fails or if the write/flush fails.
    pub async fn send<T: serde::Serialize>(&mut self, msg: &T) -> Result<(), IpcError> {
        let buf = crate::protocol::encode_message(msg)?;
        self.writer.write_all(&buf).await?;
        self.writer.flush().await?;
        Ok(())
    }

    /// Sets the timeout that [`recv`](Self::recv) applies to each message.
    pub fn set_recv_timeout(&mut self, timeout: Duration) {
        self.recv_timeout = Some(timeout);
    }

    /// Removes the receive timeout so [`recv`](Self::recv) waits indefinitely.
    pub fn clear_recv_timeout(&mut self) {
        self.recv_timeout = None;
    }

    /// Returns the currently configured receive timeout, if any.
    pub fn recv_timeout(&self) -> Option<Duration> {
        self.recv_timeout
    }

    /// Receives the next message, applying the configured receive timeout if one is set.
    ///
    /// Returns `Ok(None)` when the server closed the pipe with no partial message buffered.
    ///
    /// # Errors
    /// * `IpcError::Timeout` if a timeout is set and no complete message arrived in time.
    /// * `IpcError::ConnectionClosed` if the server closed the pipe mid-message.
    /// * Decode and I/O errors are propagated.
    pub async fn recv<T: serde::de::DeserializeOwned>(&mut self) -> Result<Option<T>, IpcError> {
        match self.recv_timeout {
            Some(t) => self.recv_with_timeout(t).await,
            None => self.recv_loop().await,
        }
    }

    /// Like [`recv`](Self::recv), but with an explicit `timeout` covering the whole message.
    ///
    /// Bytes already received when the timeout fires stay in the internal buffer, so a later
    /// call continues from them.
    ///
    /// # Errors
    /// Returns `IpcError::Timeout(timeout)` on expiry, otherwise the same errors as `recv`.
    pub async fn recv_with_timeout<T: serde::de::DeserializeOwned>(
        &mut self,
        timeout: Duration,
    ) -> Result<Option<T>, IpcError> {
        match tokio::time::timeout(timeout, self.recv_loop()).await {
            Ok(result) => result,
            Err(_) => Err(IpcError::Timeout(timeout)),
        }
    }

    /// Reads until `read_buf` holds a complete message, then decodes and returns it.
    async fn recv_loop<T: serde::de::DeserializeOwned>(&mut self) -> Result<Option<T>, IpcError> {
        // Minimum spare capacity made available before each read.
        const READ_CHUNK: usize = 4096;
        loop {
            // Earlier reads may already have buffered a complete message.
            if let Some(msg) = crate::protocol::decode_message::<T>(&mut self.read_buf)? {
                return Ok(Some(msg));
            }
            // Read directly into the buffer's spare capacity rather than through a zeroed
            // temporary array plus a copy. `read_buf` is cancel-safe, so a timeout that drops
            // this future loses no bytes. After `reserve` there is always spare capacity, so
            // a return of 0 can only mean EOF.
            self.read_buf.reserve(READ_CHUNK);
            let n = self.reader.read_buf(&mut self.read_buf).await?;
            if n == 0 {
                // EOF between messages is a clean close; EOF with leftover bytes means the
                // server hung up in the middle of a frame.
                return if self.read_buf.is_empty() {
                    Ok(None)
                } else {
                    Err(IpcError::ConnectionClosed)
                };
            }
        }
    }
}
/// Server-side endpoint that hands out `IpcConnection`s via `accept`.
pub struct IpcListener {
inner: ListenerInner,
}
/// Unix backing state: a bound Unix-domain socket listener.
#[cfg(unix)]
struct ListenerInner {
listener: tokio::net::UnixListener,
}
/// Windows backing state: the pipe name plus a FIFO of pre-created named-pipe server
/// instances; `accept` takes instances from the front and creates replacements.
#[cfg(windows)]
struct ListenerInner {
/// Pipe name, reused when creating replacement instances.
endpoint: String,
/// Server instances created but not yet handed out as connections.
pool: std::collections::VecDeque<tokio::net::windows::named_pipe::NamedPipeServer>,
}
impl IpcListener {
    /// Binds a listener to `endpoint`.
    ///
    /// * **Unix**: `endpoint` is a filesystem path for a Unix-domain socket. Any existing
    ///   file at that path is removed first (removal errors are ignored), and the parent
    ///   directory is created if it is missing.
    /// * **Windows**: `endpoint` is a named-pipe name. A pool of server instances is created
    ///   up front so several clients can connect at once. The pool is sized from
    ///   `std::thread::available_parallelism()` (4 if unavailable) and capped at 16.
    ///
    /// # Errors
    /// Returns an error if the parent directory, the socket, or any pipe instance cannot be
    /// created.
    pub fn bind(endpoint: &str) -> Result<Self, IpcError> {
        #[cfg(unix)]
        {
            // Best-effort cleanup of a socket file left by a previous run; if the path is
            // still occupied, `bind` below reports the error.
            let _ = std::fs::remove_file(endpoint);
            if let Some(parent) = std::path::Path::new(endpoint).parent() {
                std::fs::create_dir_all(parent)?;
            }
            let listener = tokio::net::UnixListener::bind(endpoint)?;
            Ok(Self {
                inner: ListenerInner { listener },
            })
        }
        #[cfg(windows)]
        {
            use std::collections::VecDeque;
            use tokio::net::windows::named_pipe::ServerOptions;
            let pool_size = std::thread::available_parallelism()
                .map(|n| n.get())
                .unwrap_or(4)
                .min(16);
            let mut pool = VecDeque::with_capacity(pool_size);
            for i in 0..pool_size {
                // `first_pipe_instance(true)` on the first instance makes `bind` fail if
                // another server already owns this pipe name.
                let pipe = ServerOptions::new()
                    .first_pipe_instance(i == 0)
                    .create(endpoint)?;
                pool.push_back(pipe);
            }
            Ok(Self {
                inner: ListenerInner {
                    endpoint: endpoint.to_string(),
                    pool,
                },
            })
        }
    }

    /// Waits for the next client and returns its connection, with no receive timeout set.
    ///
    /// On Windows the front pipe instance is awaited while it is still owned by the pool.
    /// If this future is dropped before completing (e.g. it loses a `select!`), that instance,
    /// and any client already attached to it, remains queued for the next call. A replacement
    /// instance is created after every connect attempt, successful or not, so failures do
    /// not drain the pool. An empty pool is refilled instead of panicking.
    ///
    /// # Errors
    /// Returns an error if accepting/connecting fails, or on Windows if a replacement pipe
    /// instance cannot be created.
    pub async fn accept(&mut self) -> Result<IpcConnection, IpcError> {
        #[cfg(unix)]
        {
            let (stream, _addr) = self.inner.listener.accept().await?;
            let (reader, writer) = tokio::io::split(stream);
            Ok(IpcConnection {
                reader,
                writer,
                read_buf: BytesMut::with_capacity(4096),
                recv_timeout: None,
            })
        }
        #[cfg(windows)]
        {
            use tokio::net::windows::named_pipe::ServerOptions;
            let endpoint = &self.inner.endpoint;
            let new_instance = || {
                ServerOptions::new()
                    .first_pipe_instance(false)
                    .create(endpoint)
            };
            // An earlier failed replenish may have left the pool empty; recover instead of
            // panicking.
            if self.inner.pool.is_empty() {
                self.inner.pool.push_back(new_instance()?);
            }
            // `NamedPipeServer::connect` is cancel-safe, so awaiting it in place means a
            // dropped `accept` future does not lose the instance.
            let connected = self.inner.pool[0].connect().await;
            let pipe = self
                .inner
                .pool
                .pop_front()
                .expect("pipe pool is non-empty: refilled above");
            // Replenish before inspecting the outcome so a failed connect does not shrink
            // the pool.
            self.inner.pool.push_back(new_instance()?);
            connected?;
            let (reader, writer) = tokio::io::split(pipe);
            Ok(IpcConnection {
                reader,
                writer,
                read_buf: BytesMut::with_capacity(4096),
                recv_timeout: None,
            })
        }
    }
}
#[cfg(unix)]
/// Opens a client connection to the Unix-domain socket at `endpoint`.
///
/// The returned connection starts with an empty read buffer and no receive timeout.
///
/// # Errors
/// Returns an I/O error if the socket cannot be connected.
pub async fn connect(endpoint: &str) -> Result<IpcConnection, IpcError> {
    let (reader, writer) = tokio::net::UnixStream::connect(endpoint)
        .await
        .map(tokio::io::split)?;
    Ok(IpcConnection {
        read_buf: BytesMut::with_capacity(4096),
        recv_timeout: None,
        reader,
        writer,
    })
}
#[cfg(windows)]
/// Opens a client connection to the named pipe `endpoint`.
///
/// While every server instance is busy, the open is retried with exponential backoff:
/// starting at 10 ms, doubling, and capped at 500 ms per wait. The attempt budget is 50.
///
/// # Errors
/// * `ErrorKind::ConnectionRefused` (wrapped in `IpcError::Io`) if the pipe is still busy
///   after the retry budget. The message reports the actual time spent waiting.
/// * Any other open error is returned immediately as `IpcError::Io`.
pub async fn connect(endpoint: &str) -> Result<IpcClientConnection, IpcError> {
    use tokio::net::windows::named_pipe::ClientOptions;
    // Win32 ERROR_PIPE_BUSY: all server instances of the pipe are currently in use.
    const ERROR_PIPE_BUSY: i32 = 231;
    const MAX_PIPE_BUSY_RETRIES: u32 = 50;
    const INITIAL_BACKOFF_MS: u64 = 10;
    const MAX_BACKOFF_MS: u64 = 500;

    // Measure real elapsed time for the error message. The previous estimate
    // (`backoff_ms * attempts / 2000`) reported about 12s for a schedule that actually
    // sleeps about 22s.
    let started = tokio::time::Instant::now();
    let mut attempts = 0u32;
    let mut backoff_ms = INITIAL_BACKOFF_MS;
    let client = loop {
        match ClientOptions::new().open(endpoint) {
            Ok(client) => break client,
            Err(e) if e.raw_os_error() == Some(ERROR_PIPE_BUSY) => {
                attempts += 1;
                if attempts >= MAX_PIPE_BUSY_RETRIES {
                    return Err(IpcError::Io(std::io::Error::new(
                        std::io::ErrorKind::ConnectionRefused,
                        format!(
                            "cannot connect to daemon at {endpoint}: \
                             all pipe instances busy after {attempts} retries (~{:.0}s). \
                             The daemon may be overloaded — reduce parallel compilation jobs \
                             or restart the daemon with `zccache stop && zccache start`",
                            started.elapsed().as_secs_f64()
                        ),
                    )));
                }
                tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
                backoff_ms = (backoff_ms * 2).min(MAX_BACKOFF_MS);
            }
            Err(e) => return Err(IpcError::Io(e)),
        }
    };
    let (reader, writer) = tokio::io::split(client);
    Ok(IpcClientConnection {
        reader,
        writer,
        read_buf: BytesMut::with_capacity(4096),
        recv_timeout: None,
    })
}
/// Returns a fresh IPC endpoint name for tests.
///
/// Each name embeds the current process id and a per-process sequence number, so repeated
/// calls within a process never return the same name. On Unix the result is a socket path
/// under `/tmp`; on Windows it is a named-pipe name.
pub fn unique_test_endpoint() -> String {
    use std::sync::atomic::{AtomicU64, Ordering};

    static NEXT_SEQ: AtomicU64 = AtomicU64::new(0);

    let (pid, seq) = (
        std::process::id(),
        NEXT_SEQ.fetch_add(1, Ordering::Relaxed),
    );

    #[cfg(windows)]
    {
        format!(r"\\.\pipe\zccache-test-{pid}-{seq}")
    }
    #[cfg(unix)]
    {
        format!("/tmp/zccache-test-{pid}-{seq}.sock")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::{Request, Response};

    /// Number of round trips in `test_multiple_messages`.
    const ROUNDS: usize = 5;
    /// Number of concurrent clients in `test_parallel_connections`.
    const CLIENTS: usize = 8;

    /// A single Ping/Pong exchange between one client and the listener.
    #[tokio::test]
    async fn test_ping_pong() {
        let endpoint = unique_test_endpoint();
        let mut listener = IpcListener::bind(&endpoint).unwrap();

        let server = tokio::spawn(async move {
            let mut conn = listener.accept().await.unwrap();
            let request: Option<Request> = conn.recv().await.unwrap();
            assert_eq!(request, Some(Request::Ping));
            conn.send(&Response::Pong).await.unwrap();
        });

        let mut client = connect(&endpoint).await.unwrap();
        client.send(&Request::Ping).await.unwrap();
        let reply: Option<Response> = client.recv().await.unwrap();
        assert_eq!(reply, Some(Response::Pong));

        server.await.unwrap();
    }

    /// Several sequential round trips over one connection.
    #[tokio::test]
    async fn test_multiple_messages() {
        let endpoint = unique_test_endpoint();
        let mut listener = IpcListener::bind(&endpoint).unwrap();

        let server = tokio::spawn(async move {
            let mut conn = listener.accept().await.unwrap();
            for _round in 0..ROUNDS {
                let request: Option<Request> = conn.recv().await.unwrap();
                assert_eq!(request, Some(Request::Ping));
                conn.send(&Response::Pong).await.unwrap();
            }
        });

        let mut client = connect(&endpoint).await.unwrap();
        for _round in 0..ROUNDS {
            client.send(&Request::Ping).await.unwrap();
            let reply: Option<Response> = client.recv().await.unwrap();
            assert_eq!(reply, Some(Response::Pong));
        }

        server.await.unwrap();
    }

    /// When the server accepts and immediately drops the connection, the client's `recv`
    /// must report a clean EOF, `ConnectionClosed`, or an I/O error — never a message.
    #[tokio::test]
    async fn test_connection_closed() {
        let endpoint = unique_test_endpoint();
        let mut listener = IpcListener::bind(&endpoint).unwrap();

        let server = tokio::spawn(async move {
            // The accepted connection is dropped when this task finishes.
            let _conn = listener.accept().await.unwrap();
        });

        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        let mut client = connect(&endpoint).await.unwrap();
        // Give the server task time to drop its end.
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        let outcome: Result<Option<Response>, _> = client.recv().await;
        assert!(
            matches!(
                outcome,
                Ok(None) | Err(IpcError::ConnectionClosed) | Err(IpcError::Io(_))
            ),
            "unexpected result: {outcome:?}"
        );

        server.await.unwrap();
    }

    /// Several clients connect concurrently while the server accepts them one at a time.
    #[tokio::test]
    async fn test_parallel_connections() {
        let endpoint = unique_test_endpoint();
        let mut listener = IpcListener::bind(&endpoint).unwrap();

        let server = tokio::spawn(async move {
            for _ in 0..CLIENTS {
                let mut conn = listener.accept().await.unwrap();
                let request: Option<Request> = conn.recv().await.unwrap();
                assert_eq!(request, Some(Request::Ping));
                conn.send(&Response::Pong).await.unwrap();
            }
        });

        let clients: Vec<_> = (0..CLIENTS)
            .map(|_| {
                let endpoint = endpoint.clone();
                tokio::spawn(async move {
                    let mut client = connect(&endpoint).await.unwrap();
                    client.send(&Request::Ping).await.unwrap();
                    let reply: Option<Response> = client.recv().await.unwrap();
                    assert_eq!(reply, Some(Response::Pong));
                })
            })
            .collect();

        for handle in clients {
            handle.await.unwrap();
        }
        server.await.unwrap();
    }
}