use std::io;
use std::path::PathBuf;
use serde::{Deserialize, Serialize};
use tokio::net::UnixStream;
/// Everything a child process needs to connect to the parent's per-slot sockets.
///
/// Produced by `create_transport` in the parent and consumed by
/// `connect_transport` in the child; it derives serde so it can be shipped
/// across the process boundary.
/// NOTE(review): the actual channel used to pass it (env var, argv, pipe, ...)
/// is not visible in this file — confirm against the spawning code.
///
/// Variant and field order are part of the encoding for non-self-describing
/// serde formats, so avoid reordering them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ChildTransportInfo {
    /// Filesystem Unix sockets named `slot-<i>.sock` inside `dir`.
    NamedSockets {
        /// Directory containing the per-slot socket files.
        dir: PathBuf,
        /// Number of slot sockets (`slot-0.sock` .. `slot-<num_slots - 1>.sock`).
        num_slots: usize,
    },
    /// Linux abstract-namespace Unix sockets named `<prefix>-<i>`.
    /// Only exists on Linux builds.
    #[cfg(target_os = "linux")]
    AbstractSockets {
        /// Common name prefix; slot `i` lives at abstract name `<prefix>-<i>`.
        prefix: String,
        /// Number of slot sockets.
        num_slots: usize,
    },
}
/// Slot transport backed by Unix socket files in a per-process directory.
///
/// The parent side binds one listener per slot and later accepts one
/// connection on each; the child side connects to every slot socket.
pub struct NamedSocketTransport {
    /// Directory holding the `slot-<i>.sock` files.
    dir: PathBuf,
    /// Connected streams; index `i` is slot `i`.
    sockets: Vec<UnixStream>,
    /// Parent only: bound listeners waiting for the child; emptied once
    /// `accept_connections` finishes. Always empty on the child side.
    listeners: Vec<tokio::net::UnixListener>,
    /// Whether this is the creating (parent) side. Only the parent removes
    /// `dir` in `cleanup`.
    is_parent: bool,
}
impl NamedSocketTransport {
    /// Path of the socket file for `slot` inside `dir`.
    ///
    /// Shared by [`Self::create`] and [`Self::connect`] so parent and child
    /// always agree on the naming scheme.
    fn slot_path(dir: &std::path::Path, slot: usize) -> PathBuf {
        dir.join(format!("slot-{}.sock", slot))
    }

    /// Creates the parent side of the transport.
    ///
    /// Makes a per-process directory (`<temp_dir>/coglet-<pid>`) and binds one
    /// listening socket per slot inside it. The returned transport must still
    /// call [`Self::accept_connections`] once the child has been started with
    /// the returned [`ChildTransportInfo`].
    ///
    /// # Errors
    ///
    /// Returns any I/O error from creating the directory, removing a stale
    /// socket file, or binding/converting a listener. If binding fails after
    /// the directory was created, the directory is removed again (best effort,
    /// via `Drop`) rather than being left behind.
    pub async fn create(num_slots: usize) -> io::Result<(Self, ChildTransportInfo)> {
        use std::os::unix::net::UnixListener as StdUnixListener;
        use tokio::net::UnixListener;

        // NOTE(review): the directory name is predictable and `create_dir_all`
        // reuses an existing directory; confirm this is acceptable for the
        // deployment's temp dir (e.g. not shared with untrusted users).
        let dir = std::env::temp_dir().join(format!("coglet-{}", std::process::id()));
        std::fs::create_dir_all(&dir)?;
        tracing::debug!(transport_type = "named", dir = %dir.display(), num_slots, "Creating slot transport");

        // Build the transport *before* binding: if any step below fails and `?`
        // returns early, dropping `transport` runs `cleanup`, which removes the
        // directory and any socket files already bound in it.
        let mut transport = Self {
            dir: dir.clone(),
            sockets: Vec::with_capacity(num_slots),
            listeners: Vec::with_capacity(num_slots),
            is_parent: true,
        };

        for i in 0..num_slots {
            let path = Self::slot_path(&dir, i);
            // A leftover file at this path would make `bind` fail, so remove it.
            // Doing it unconditionally (ignoring NotFound) avoids the
            // exists()/remove race.
            match std::fs::remove_file(&path) {
                Ok(()) => {}
                Err(e) if e.kind() == io::ErrorKind::NotFound => {}
                Err(e) => return Err(e),
            }
            let std_listener = StdUnixListener::bind(&path)?;
            // Tokio expects the std listener to already be in non-blocking mode.
            std_listener.set_nonblocking(true)?;
            let listener = UnixListener::from_std(std_listener)?;
            tracing::trace!(slot = i, path = %path.display(), "Bound socket");
            transport.listeners.push(listener);
        }

        let child_info = ChildTransportInfo::NamedSockets { dir, num_slots };
        Ok((transport, child_info))
    }

    /// Accepts one child connection on each of the first `num_slots` listeners,
    /// in slot order, then drops all listeners.
    ///
    /// # Errors
    ///
    /// * `InvalidInput` if `num_slots` exceeds the number of bound listeners
    ///   (e.g. on a child-side transport, or when called a second time).
    /// * Any I/O error from `accept`; streams accepted before the failure stay
    ///   in the transport.
    pub async fn accept_connections(&mut self, num_slots: usize) -> io::Result<()> {
        if num_slots > self.listeners.len() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!(
                    "cannot accept {} slot connections: only {} listeners are bound",
                    num_slots,
                    self.listeners.len()
                ),
            ));
        }
        for (i, listener) in self.listeners.iter().take(num_slots).enumerate() {
            tracing::trace!(slot = i, "Waiting for child connection");
            let (stream, _) = listener.accept().await?;
            self.sockets.push(stream);
            tracing::trace!(slot = i, "Child connected");
        }
        self.listeners.clear();
        Ok(())
    }

    /// Creates the child side by connecting to `slot-0.sock` ..
    /// `slot-<num_slots - 1>.sock` inside `dir`, in order.
    ///
    /// The resulting transport is not the parent, so it never deletes `dir`.
    ///
    /// # Errors
    ///
    /// Returns the first connection error; earlier streams are dropped.
    pub async fn connect(dir: PathBuf, num_slots: usize) -> io::Result<Self> {
        let mut sockets = Vec::with_capacity(num_slots);
        for i in 0..num_slots {
            let path = Self::slot_path(&dir, i);
            tracing::trace!(slot = i, path = %path.display(), "Connecting to socket");
            let stream = UnixStream::connect(&path).await?;
            sockets.push(stream);
            tracing::trace!(slot = i, "Connected");
        }
        Ok(Self {
            dir,
            sockets,
            listeners: Vec::new(),
            is_parent: false,
        })
    }

    /// Mutable access to the connected stream for `slot`, if present.
    pub fn slot_socket(&mut self, slot: usize) -> Option<&mut UnixStream> {
        self.sockets.get_mut(slot)
    }

    /// Takes ownership of all connected streams, leaving the transport empty.
    pub fn drain_sockets(&mut self) -> Vec<UnixStream> {
        std::mem::take(&mut self.sockets)
    }

    /// Directory that holds the slot socket files.
    pub fn dir(&self) -> &PathBuf {
        &self.dir
    }

    /// Number of connected streams currently held (0 after [`Self::drain_sockets`]).
    pub fn num_slots(&self) -> usize {
        self.sockets.len()
    }

    /// Removes the socket directory if this is the parent side and it still exists.
    ///
    /// Already-connected streams are unaffected; only the filesystem entries go away.
    ///
    /// # Errors
    ///
    /// Propagates any error from `remove_dir_all`.
    pub fn cleanup(&mut self) -> io::Result<()> {
        if self.is_parent && self.dir.exists() {
            tracing::debug!(dir = %self.dir.display(), "Cleaning up socket directory");
            std::fs::remove_dir_all(&self.dir)?;
        }
        Ok(())
    }
}
impl Drop for NamedSocketTransport {
fn drop(&mut self) {
if let Err(e) = self.cleanup() {
tracing::warn!(error = %e, "Failed to cleanup socket directory");
}
}
}
/// Slot transport backed by Linux abstract-namespace Unix sockets.
///
/// Unlike the named-socket transport it has no filesystem presence, so there
/// is no directory to clean up.
#[cfg(target_os = "linux")]
pub struct AbstractSocketTransport {
    /// Name prefix; slot `i` is bound at abstract name `<prefix>-<i>`.
    // Kept for diagnostics/symmetry but not read by any method in this file,
    // hence the lint allowance.
    #[allow(dead_code)] prefix: String,
    /// Connected streams; index `i` is slot `i`.
    sockets: Vec<UnixStream>,
    /// Parent only: bound listeners waiting for the child; emptied once
    /// `accept_connections` finishes. Always empty on the child side.
    listeners: Vec<tokio::net::UnixListener>,
}
#[cfg(target_os = "linux")]
impl AbstractSocketTransport {
    /// Abstract socket name for `slot` under `prefix`.
    ///
    /// Shared by [`Self::create`] and [`Self::connect`] so parent and child
    /// always agree on the naming scheme.
    fn slot_name(prefix: &str, slot: usize) -> String {
        format!("{}-{}", prefix, slot)
    }

    /// Creates the parent side of the transport.
    ///
    /// Binds one abstract-namespace listener per slot, named
    /// `coglet-<pid>-<slot>`. The returned transport must still call
    /// [`Self::accept_connections`] once the child has been started with the
    /// returned [`ChildTransportInfo`].
    ///
    /// # Errors
    ///
    /// Returns any I/O error from building an abstract address or from
    /// binding/converting a listener. Listeners bound before the failure are
    /// dropped, which releases their abstract names.
    pub async fn create(num_slots: usize) -> io::Result<(Self, ChildTransportInfo)> {
        use std::os::linux::net::SocketAddrExt;
        use std::os::unix::net::{SocketAddr, UnixListener as StdUnixListener};
        use tokio::net::UnixListener;

        let prefix = format!("coglet-{}", std::process::id());
        tracing::debug!(transport_type = "abstract", prefix = %prefix, num_slots, "Creating slot transport");

        let mut listeners = Vec::with_capacity(num_slots);
        for i in 0..num_slots {
            let name = Self::slot_name(&prefix, i);
            let addr = SocketAddr::from_abstract_name(name.as_bytes())?;
            let std_listener = StdUnixListener::bind_addr(&addr)?;
            // Tokio expects the std listener to already be in non-blocking mode.
            std_listener.set_nonblocking(true)?;
            let listener = UnixListener::from_std(std_listener)?;
            tracing::trace!(slot = i, name = %name, "Bound abstract socket");
            listeners.push(listener);
        }

        let transport = Self {
            prefix: prefix.clone(),
            sockets: Vec::with_capacity(num_slots),
            listeners,
        };
        let child_info = ChildTransportInfo::AbstractSockets { prefix, num_slots };
        Ok((transport, child_info))
    }

    /// Accepts one child connection on each of the first `num_slots` listeners,
    /// in slot order, then drops all listeners.
    ///
    /// # Errors
    ///
    /// * `InvalidInput` if `num_slots` exceeds the number of bound listeners
    ///   (e.g. on a child-side transport, or when called a second time).
    /// * Any I/O error from `accept`; streams accepted before the failure stay
    ///   in the transport.
    pub async fn accept_connections(&mut self, num_slots: usize) -> io::Result<()> {
        if num_slots > self.listeners.len() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!(
                    "cannot accept {} slot connections: only {} listeners are bound",
                    num_slots,
                    self.listeners.len()
                ),
            ));
        }
        for (i, listener) in self.listeners.iter().take(num_slots).enumerate() {
            tracing::trace!(slot = i, "Waiting for child connection");
            let (stream, _) = listener.accept().await?;
            self.sockets.push(stream);
            tracing::trace!(slot = i, "Child connected");
        }
        self.listeners.clear();
        Ok(())
    }

    /// Creates the child side by connecting to `<prefix>-0` ..
    /// `<prefix>-<num_slots - 1>` in the abstract namespace, in order.
    ///
    /// # Errors
    ///
    /// Returns the first address/connection/conversion error; earlier streams
    /// are dropped.
    pub async fn connect(prefix: String, num_slots: usize) -> io::Result<Self> {
        use std::os::linux::net::SocketAddrExt;
        use std::os::unix::net::SocketAddr;

        let mut sockets = Vec::with_capacity(num_slots);
        for i in 0..num_slots {
            let name = Self::slot_name(&prefix, i);
            let addr = SocketAddr::from_abstract_name(name.as_bytes())?;
            tracing::trace!(slot = i, name = %name, "Connecting to abstract socket");
            // NOTE(review): this is a blocking std connect inside an async fn.
            // It is presumably quick for a local listener, but it could stall
            // the executor if the listener's backlog is full — confirm acceptable.
            let std_stream = std::os::unix::net::UnixStream::connect_addr(&addr)?;
            std_stream.set_nonblocking(true)?;
            let stream = UnixStream::from_std(std_stream)?;
            sockets.push(stream);
            tracing::trace!(slot = i, "Connected");
        }
        Ok(Self {
            prefix,
            sockets,
            listeners: Vec::new(),
        })
    }

    /// Mutable access to the connected stream for `slot`, if present.
    pub fn slot_socket(&mut self, slot: usize) -> Option<&mut UnixStream> {
        self.sockets.get_mut(slot)
    }

    /// Takes ownership of all connected streams, leaving the transport empty.
    pub fn drain_sockets(&mut self) -> Vec<UnixStream> {
        std::mem::take(&mut self.sockets)
    }

    /// Number of connected streams currently held (0 after [`Self::drain_sockets`]).
    pub fn num_slots(&self) -> usize {
        self.sockets.len()
    }
}
/// Platform-selected slot transport.
///
/// `Abstract` (Linux abstract-namespace sockets) only exists on Linux builds;
/// `Named` (socket files in a directory) is available everywhere.
pub enum SlotTransport {
    /// Filesystem Unix sockets in a per-process directory.
    Named(NamedSocketTransport),
    /// Linux abstract-namespace Unix sockets.
    #[cfg(target_os = "linux")]
    Abstract(AbstractSocketTransport),
}
impl SlotTransport {
pub fn slot_socket(&mut self, slot: usize) -> Option<&mut UnixStream> {
match self {
Self::Named(t) => t.slot_socket(slot),
#[cfg(target_os = "linux")]
Self::Abstract(t) => t.slot_socket(slot),
}
}
pub fn drain_sockets(&mut self) -> Vec<UnixStream> {
match self {
Self::Named(t) => t.drain_sockets(),
#[cfg(target_os = "linux")]
Self::Abstract(t) => t.drain_sockets(),
}
}
pub fn num_slots(&self) -> usize {
match self {
Self::Named(t) => t.num_slots(),
#[cfg(target_os = "linux")]
Self::Abstract(t) => t.num_slots(),
}
}
pub async fn accept_connections(&mut self, num_slots: usize) -> io::Result<()> {
match self {
Self::Named(t) => t.accept_connections(num_slots).await,
#[cfg(target_os = "linux")]
Self::Abstract(t) => t.accept_connections(num_slots).await,
}
}
}
/// Creates the parent side of the slot transport for this platform.
///
/// Linux builds use abstract-namespace sockets; every other platform uses
/// socket files in a temp directory. Returns the transport together with the
/// [`ChildTransportInfo`] the child should hand to [`connect_transport`].
pub async fn create_transport(num_slots: usize) -> io::Result<(SlotTransport, ChildTransportInfo)> {
    #[cfg(target_os = "linux")]
    {
        AbstractSocketTransport::create(num_slots)
            .await
            .map(|(abstract_transport, child_info)| (SlotTransport::Abstract(abstract_transport), child_info))
    }
    #[cfg(not(target_os = "linux"))]
    {
        NamedSocketTransport::create(num_slots)
            .await
            .map(|(named_transport, child_info)| (SlotTransport::Named(named_transport), child_info))
    }
}
/// Creates the child side of the slot transport from the parent-supplied `info`,
/// connecting to every slot socket it describes.
pub async fn connect_transport(info: ChildTransportInfo) -> io::Result<SlotTransport> {
    match info {
        ChildTransportInfo::NamedSockets { dir, num_slots } => NamedSocketTransport::connect(dir, num_slots)
            .await
            .map(SlotTransport::Named),
        #[cfg(target_os = "linux")]
        ChildTransportInfo::AbstractSockets { prefix, num_slots } => {
            AbstractSocketTransport::connect(prefix, num_slots)
                .await
                .map(SlotTransport::Abstract)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `info` to JSON and parses it back.
    fn json_roundtrip(info: &ChildTransportInfo) -> ChildTransportInfo {
        let encoded = serde_json::to_string(info).expect("serialize ChildTransportInfo");
        serde_json::from_str(&encoded).expect("deserialize ChildTransportInfo")
    }

    #[test]
    fn child_transport_info_roundtrips() {
        let expected_dir = PathBuf::from("/tmp/coglet-123");
        let original = ChildTransportInfo::NamedSockets {
            dir: expected_dir.clone(),
            num_slots: 3,
        };
        match json_roundtrip(&original) {
            ChildTransportInfo::NamedSockets { dir, num_slots } => {
                assert_eq!(dir, expected_dir);
                assert_eq!(num_slots, 3);
            }
            #[cfg(target_os = "linux")]
            _ => panic!("Wrong variant"),
        }
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn abstract_socket_info_roundtrips() {
        let original = ChildTransportInfo::AbstractSockets {
            prefix: String::from("coglet-456"),
            num_slots: 2,
        };
        match json_roundtrip(&original) {
            ChildTransportInfo::AbstractSockets { prefix, num_slots } => {
                assert_eq!(prefix, "coglet-456");
                assert_eq!(num_slots, 2);
            }
            _ => panic!("Wrong variant"),
        }
    }
}