use std::time::Duration;
use bytes::BytesMut;
use tokio::io::{AsyncWriteExt, ReadHalf, WriteHalf};
use super::error::IpcError;
mod framing;
mod probe;
#[cfg(unix)]
mod unix;
#[cfg(windows)]
mod windows;
use framing::{decode_response_wire, recv_bincode_loop, recv_wire_loop};
#[cfg(unix)]
pub use unix::connect;
#[cfg(windows)]
pub use windows::{connect, IpcClientConnection};
/// Receive timeout of 300 seconds (5 minutes).
///
/// Per its name this is intended as the default receive timeout for
/// client-side connections; it is not referenced elsewhere in this file
/// (NOTE(review): confirm where callers apply it, e.g. via
/// `IpcConnection::set_recv_timeout`).
pub const DEFAULT_CLIENT_RECV_TIMEOUT: Duration = Duration::from_secs(300);
/// Platform byte stream underlying an `IpcConnection`: a Unix domain socket
/// stream on Unix.
#[cfg(unix)]
type StreamType = tokio::net::UnixStream;
/// Platform byte stream underlying an `IpcConnection`: the server end of a
/// named pipe on Windows.
#[cfg(windows)]
type StreamType = tokio::net::windows::named_pipe::NamedPipeServer;
/// One IPC connection over the platform stream (`StreamType`), split into
/// independent read and write halves.
///
/// Instances are produced by `IpcListener::accept` (on Unix built directly
/// there; on Windows by the platform module's `accept_windows`).
pub struct IpcConnection {
    /// Read half of the underlying stream; consumed by the framing and
    /// probe helpers.
    pub(super) reader: ReadHalf<StreamType>,
    /// Write half of the underlying stream; every `send*` method writes the
    /// encoded message here and then flushes.
    pub(super) writer: WriteHalf<StreamType>,
    /// Buffer handed to the framing/probe readers together with `reader`.
    /// Presumably holds bytes that have been read but not yet decoded into a
    /// complete message — confirm in `framing.rs`. `accept` on Unix starts
    /// it with 4096 bytes of capacity.
    pub(super) read_buf: BytesMut,
    /// Timeout applied by `recv` / `recv_wire` when `Some`; `None` means
    /// those calls wait indefinitely.
    pub(super) recv_timeout: Option<Duration>,
    /// Id assigned to the next FrameV1 request sent by
    /// `send_frame_v1_request`; advanced with wrapping arithmetic after each
    /// use. `accept` on Unix initialises it to 1.
    pub(super) next_frame_request_id: u64,
}
impl IpcConnection {
pub async fn try_serve_backend_handle_probe(
&mut self,
daemon: &running_process::broker::backend_handle::DaemonProcess,
) -> Result<bool, IpcError> {
probe::try_serve_backend_handle_probe(
&mut self.reader,
&mut self.writer,
&mut self.read_buf,
daemon,
)
.await
}
pub async fn send<T: serde::Serialize>(&mut self, msg: &T) -> Result<(), IpcError> {
let buf = crate::protocol::encode_message(msg)?;
self.writer.write_all(&buf).await?;
self.writer.flush().await?;
Ok(())
}
pub async fn send_prost<M: prost::Message>(&mut self, msg: &M) -> Result<(), IpcError> {
let buf = crate::protocol::wire_prost::encode_prost_message(msg)?;
self.writer.write_all(&buf).await?;
self.writer.flush().await?;
Ok(())
}
pub async fn send_frame_v1_request<M: prost::Message>(
&mut self,
msg: &M,
) -> Result<u64, IpcError> {
let request_id = self.next_frame_request_id;
self.next_frame_request_id = self.next_frame_request_id.wrapping_add(1);
let buf = crate::protocol::wire_frame::encode_frame_v1_request(msg, request_id)?;
self.writer.write_all(&buf).await?;
self.writer.flush().await?;
Ok(request_id)
}
pub async fn send_frame_v1_response<M: prost::Message>(
&mut self,
msg: &M,
request_id: u64,
) -> Result<(), IpcError> {
let buf = crate::protocol::wire_frame::encode_frame_v1_response(msg, request_id)?;
self.writer.write_all(&buf).await?;
self.writer.flush().await?;
Ok(())
}
pub fn set_recv_timeout(&mut self, timeout: Duration) {
self.recv_timeout = Some(timeout);
}
pub fn clear_recv_timeout(&mut self) {
self.recv_timeout = None;
}
pub fn recv_timeout(&self) -> Option<Duration> {
self.recv_timeout
}
pub async fn recv<T: serde::de::DeserializeOwned>(&mut self) -> Result<Option<T>, IpcError> {
match self.recv_timeout {
Some(t) => self.recv_with_timeout(t).await,
None => self.recv_loop().await,
}
}
pub async fn recv_with_timeout<T: serde::de::DeserializeOwned>(
&mut self,
timeout: Duration,
) -> Result<Option<T>, IpcError> {
match tokio::time::timeout(timeout, self.recv_loop()).await {
Ok(result) => result,
Err(_) => Err(IpcError::Timeout(timeout)),
}
}
pub async fn recv_wire<Bincode, Prost>(
&mut self,
) -> Result<Option<crate::protocol::DecodedWireMessage<Bincode, Prost>>, IpcError>
where
Bincode: serde::de::DeserializeOwned,
Prost: prost::Message + Default,
{
match self.recv_timeout {
Some(t) => self.recv_wire_with_timeout(t).await,
None => self.recv_wire_loop().await,
}
}
pub async fn recv_wire_with_timeout<Bincode, Prost>(
&mut self,
timeout: Duration,
) -> Result<Option<crate::protocol::DecodedWireMessage<Bincode, Prost>>, IpcError>
where
Bincode: serde::de::DeserializeOwned,
Prost: prost::Message + Default,
{
match tokio::time::timeout(timeout, self.recv_wire_loop()).await {
Ok(result) => result,
Err(_) => Err(IpcError::Timeout(timeout)),
}
}
pub async fn send_request(
&mut self,
request: &crate::protocol::Request,
wire: crate::protocol::wire_prost::WireFormat,
) -> Result<(), IpcError> {
match wire {
crate::protocol::wire_prost::WireFormat::BincodeV15 => self.send(request).await,
crate::protocol::wire_prost::WireFormat::ProstV16 => {
let request_id = crate::protocol::wire_prost::default_request_id(request);
let request = crate::protocol::wire_prost::request_to_prost(request, request_id);
self.send_prost(&request).await
}
crate::protocol::wire_prost::WireFormat::FrameV1 => {
let request_id = crate::protocol::wire_prost::default_request_id(request);
let request = crate::protocol::wire_prost::request_to_prost(request, request_id);
self.send_frame_v1_request(&request).await.map(|_| ())
}
}
}
pub async fn recv_response(&mut self) -> Result<Option<crate::protocol::Response>, IpcError> {
let message = self
.recv_wire::<crate::protocol::Response, crate::protocol::wire_prost::zccache_v1::Response>()
.await?;
decode_response_wire(message)
}
pub async fn recv_response_with_timeout(
&mut self,
timeout: Duration,
) -> Result<Option<crate::protocol::Response>, IpcError> {
let message = self
.recv_wire_with_timeout::<crate::protocol::Response, crate::protocol::wire_prost::zccache_v1::Response>(timeout)
.await?;
decode_response_wire(message)
}
async fn recv_loop<T: serde::de::DeserializeOwned>(&mut self) -> Result<Option<T>, IpcError> {
recv_bincode_loop(&mut self.reader, &mut self.read_buf).await
}
async fn recv_wire_loop<Bincode, Prost>(
&mut self,
) -> Result<Option<crate::protocol::DecodedWireMessage<Bincode, Prost>>, IpcError>
where
Bincode: serde::de::DeserializeOwned,
Prost: prost::Message + Default,
{
recv_wire_loop(&mut self.reader, &mut self.read_buf).await
}
}
/// Server-side IPC listener that yields [`IpcConnection`]s.
///
/// Created with `IpcListener::bind` and driven with `IpcListener::accept`.
pub struct IpcListener {
    /// Platform-specific listener state (a Unix socket listener, or a pool of
    /// named-pipe instances on Windows).
    pub(super) inner: ListenerInner,
}
/// Unix backing state for `IpcListener`: the bound Unix domain socket
/// listener.
#[cfg(unix)]
pub(super) struct ListenerInner {
    /// Listener bound by `IpcListener::bind`; `accept` takes connections
    /// from it.
    listener: tokio::net::UnixListener,
}
/// Windows backing state for `IpcListener`.
#[cfg(windows)]
pub(super) struct ListenerInner {
    /// Named-pipe name the pool instances were created on (copied from the
    /// `endpoint` passed to `IpcListener::bind`).
    pub(super) endpoint: String,
    /// Pipe server instances pre-created by `IpcListener::bind`.
    /// NOTE(review): presumably consumed and replenished by `accept_windows`
    /// in the platform module — confirm in `windows.rs`.
    pub(super) pool: std::collections::VecDeque<tokio::net::windows::named_pipe::NamedPipeServer>,
}
impl IpcListener {
    /// Binds a listener on `endpoint`.
    ///
    /// * **Unix:** `endpoint` is a filesystem path for a Unix domain socket.
    ///   If a socket file already exists at that path it is removed first
    ///   (best-effort), and missing parent directories are created. Any
    ///   non-socket file at the path is left untouched, so binding then
    ///   fails instead of deleting unrelated data.
    /// * **Windows:** `endpoint` is a named-pipe name. A pool of pipe server
    ///   instances is created up front. The first instance uses
    ///   `first_pipe_instance(true)` and is retried up to 8 times with capped
    ///   exponential backoff (20 ms doubling to at most 160 ms). Pool size is
    ///   `ZCCACHE_PIPE_POOL_SIZE` if it parses as `usize`. Otherwise it is
    ///   4 × available parallelism (64 if unknown), clamped to `16..=128`.
    ///
    /// NOTE: the Windows retry sleeps with `std::thread::sleep` (up to about
    /// 780 ms in total), which blocks the calling thread.
    ///
    /// # Errors
    /// I/O errors from directory creation, socket bind or pipe creation,
    /// converted into `IpcError`.
    pub fn bind(endpoint: &str) -> Result<Self, IpcError> {
        #[cfg(unix)]
        {
            use std::os::unix::fs::FileTypeExt;

            // A leftover socket file (e.g. from a previous process) would make
            // bind fail with EADDRINUSE, so remove it. Only socket files are
            // removed; `symlink_metadata` is used so a symlink is judged by
            // itself rather than its target. Removal stays best-effort.
            let is_socket = std::fs::symlink_metadata(endpoint)
                .map(|meta| meta.file_type().is_socket())
                .unwrap_or(false);
            if is_socket {
                let _ = std::fs::remove_file(endpoint);
            }
            if let Some(parent) = std::path::Path::new(endpoint).parent() {
                std::fs::create_dir_all(parent)?;
            }
            let listener = tokio::net::UnixListener::bind(endpoint)?;
            Ok(Self {
                inner: ListenerInner { listener },
            })
        }
        #[cfg(windows)]
        {
            use std::collections::VecDeque;
            use tokio::net::windows::named_pipe::ServerOptions;

            let pool_size = std::env::var("ZCCACHE_PIPE_POOL_SIZE")
                .ok()
                .and_then(|s| s.parse::<usize>().ok())
                .unwrap_or_else(|| {
                    std::thread::available_parallelism()
                        .map(|n| n.get().saturating_mul(4))
                        .unwrap_or(64)
                        .clamp(16, 128)
                });
            const FIRST_BIND_ATTEMPTS: u32 = 8;
            const FIRST_BIND_INITIAL_DELAY_MS: u64 = 20;
            const FIRST_BIND_MAX_DELAY_MS: u64 = 160;
            let mut pool = VecDeque::with_capacity(pool_size);
            // The first-instance bind can fail transiently (issue #774), so
            // retry with bounded exponential backoff before giving up.
            let first_pipe = {
                let mut attempt = 0u32;
                let mut delay_ms = FIRST_BIND_INITIAL_DELAY_MS;
                loop {
                    match ServerOptions::new()
                        .first_pipe_instance(true)
                        .create(endpoint)
                    {
                        Ok(p) => break p,
                        Err(e) => {
                            attempt += 1;
                            if attempt >= FIRST_BIND_ATTEMPTS {
                                return Err(e.into());
                            }
                            tracing::warn!(
                                attempt,
                                max_attempts = FIRST_BIND_ATTEMPTS,
                                error = %e,
                                endpoint = %endpoint,
                                "first pipe instance bind failed; retrying after backoff (issue #774)"
                            );
                            std::thread::sleep(std::time::Duration::from_millis(delay_ms));
                            delay_ms = (delay_ms * 2).min(FIRST_BIND_MAX_DELAY_MS);
                        }
                    }
                }
            };
            pool.push_back(first_pipe);
            // Remaining instances join the pipe the first one established.
            for _ in 1..pool_size {
                let pipe = ServerOptions::new()
                    .first_pipe_instance(false)
                    .create(endpoint)?;
                pool.push_back(pipe);
            }
            Ok(Self {
                inner: ListenerInner {
                    endpoint: endpoint.to_string(),
                    pool,
                },
            })
        }
    }

    /// Waits for the next client connection.
    ///
    /// On Unix this accepts from the bound socket and wraps the stream in a
    /// fresh [`IpcConnection`]: 4 KiB initial read buffer, no receive
    /// timeout, FrameV1 request ids starting at 1. On Windows it delegates to
    /// `accept_windows` in the platform module.
    ///
    /// # Errors
    /// I/O errors from the underlying accept, converted into `IpcError`.
    pub async fn accept(&mut self) -> Result<IpcConnection, IpcError> {
        #[cfg(unix)]
        {
            let (stream, _addr) = self.inner.listener.accept().await?;
            let (reader, writer) = tokio::io::split(stream);
            Ok(IpcConnection {
                reader,
                writer,
                read_buf: BytesMut::with_capacity(4096),
                recv_timeout: None,
                next_frame_request_id: 1,
            })
        }
        #[cfg(windows)]
        {
            self.accept_windows().await
        }
    }
}
/// Returns a fresh IPC endpoint name for tests.
///
/// Names combine the current process id with a process-wide sequence number,
/// so concurrent test processes and repeated calls never collide. On Unix the
/// result is a socket path under `/tmp`, chosen because it is short. On
/// Windows it is a named-pipe name under `\\.\pipe\`.
pub fn unique_test_endpoint() -> String {
    use std::sync::atomic::{AtomicU64, Ordering};

    static NEXT_SEQ: AtomicU64 = AtomicU64::new(0);

    let seq = NEXT_SEQ.fetch_add(1, Ordering::Relaxed);
    let pid = std::process::id();

    if cfg!(windows) {
        format!(r"\\.\pipe\zccache-test-{pid}-{seq}")
    } else {
        format!("/tmp/zccache-test-{pid}-{seq}.sock")
    }
}
#[cfg(test)]
mod tests;