use std::io;
use std::net::SocketAddr;
use std::path::Path;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, TcpStream};
/// Default TCP listen address: IPv4 loopback only (`127.0.0.1`), port 47321.
pub const DEFAULT_TCP_ADDR: &str = "127.0.0.1:47321";
/// Platform default location of the admin endpoint.
///
/// - Linux: `linux_runtime_path("admin.sock")`.
/// - macOS: `<std::env::temp_dir()>/inferd/admin.sock`.
/// - Windows: the named-pipe path `DEFAULT_ADMIN_PIPE_PATH`.
/// - Any other target: `/tmp/inferd/admin.sock`.
///   NOTE(review): this fallback has no per-user component.
pub fn default_admin_addr() -> std::path::PathBuf {
    #[cfg(target_os = "linux")]
    {
        linux_runtime_path("admin.sock")
    }
    #[cfg(target_os = "macos")]
    {
        std::env::temp_dir().join("inferd").join("admin.sock")
    }
    #[cfg(windows)]
    {
        DEFAULT_ADMIN_PIPE_PATH.into()
    }
    #[cfg(not(any(target_os = "linux", target_os = "macos", windows)))]
    {
        std::path::Path::new("/tmp/inferd").join("admin.sock")
    }
}
/// Platform default location of the v2 inference endpoint.
///
/// - Linux: `linux_runtime_path("infer.v2.sock")`.
/// - macOS: `<std::env::temp_dir()>/inferd/infer.v2.sock`.
/// - Windows: the named-pipe path `DEFAULT_PIPE_V2_PATH`.
/// - Any other target: `/tmp/inferd/infer.v2.sock`.
///   NOTE(review): this fallback has no per-user component.
pub fn default_v2_addr() -> std::path::PathBuf {
    #[cfg(target_os = "linux")]
    {
        linux_runtime_path("infer.v2.sock")
    }
    #[cfg(target_os = "macos")]
    {
        std::env::temp_dir().join("inferd").join("infer.v2.sock")
    }
    #[cfg(windows)]
    {
        DEFAULT_PIPE_V2_PATH.into()
    }
    #[cfg(not(any(target_os = "linux", target_os = "macos", windows)))]
    {
        std::path::Path::new("/tmp/inferd").join("infer.v2.sock")
    }
}
/// Platform default location of the embedding endpoint.
///
/// - Linux: `linux_runtime_path("infer.embed.sock")`.
/// - macOS: `<std::env::temp_dir()>/inferd/infer.embed.sock`.
/// - Windows: the named-pipe path `DEFAULT_PIPE_EMBED_PATH`.
/// - Any other target: `/tmp/inferd/infer.embed.sock`.
///   NOTE(review): this fallback has no per-user component.
pub fn default_embed_addr() -> std::path::PathBuf {
    #[cfg(target_os = "linux")]
    {
        linux_runtime_path("infer.embed.sock")
    }
    #[cfg(target_os = "macos")]
    {
        std::env::temp_dir().join("inferd").join("infer.embed.sock")
    }
    #[cfg(windows)]
    {
        DEFAULT_PIPE_EMBED_PATH.into()
    }
    #[cfg(not(any(target_os = "linux", target_os = "macos", windows)))]
    {
        std::path::Path::new("/tmp/inferd").join("infer.embed.sock")
    }
}
/// Resolves the per-user path of the inferd socket file named `leaf` on Linux.
///
/// The first usable candidate wins:
/// 1. `$XDG_RUNTIME_DIR/inferd/<leaf>`
/// 2. `$HOME/.inferd/run/<leaf>`
/// 3. `/tmp/inferd-<uid>/<leaf>`, where `<uid>` comes from `nix::unistd::Uid::current()`
///
/// An environment variable is only used when it holds an absolute path. The XDG Base
/// Directory Specification says relative values must be treated as invalid and ignored.
/// A relative `HOME` would likewise make the socket location depend on the current
/// working directory, so a daemon and its clients could disagree on where the socket
/// lives. Empty values fail the same check.
///
/// This only computes the path; it does not create any directory.
///
/// NOTE(review): the `/tmp` fallback is a predictable name in a shared directory, so
/// another local user could create `/tmp/inferd-<uid>` first. Whoever creates that
/// directory should verify it is owned by the current user.
#[cfg(target_os = "linux")]
pub fn linux_runtime_path(leaf: &str) -> std::path::PathBuf {
    // Reads `var` as a directory path, keeping it only if it is absolute
    // (this also discards unset and empty values).
    fn absolute_dir_from_env(var: &str) -> Option<std::path::PathBuf> {
        std::env::var_os(var)
            .map(std::path::PathBuf::from)
            .filter(|dir| dir.is_absolute())
    }

    if let Some(mut path) = absolute_dir_from_env("XDG_RUNTIME_DIR") {
        path.push("inferd");
        path.push(leaf);
        return path;
    }
    if let Some(mut path) = absolute_dir_from_env("HOME") {
        path.push(".inferd");
        path.push("run");
        path.push(leaf);
        return path;
    }
    let uid = nix::unistd::Uid::current().as_raw();
    std::path::PathBuf::from(format!("/tmp/inferd-{uid}/{leaf}"))
}
/// An async, bidirectional byte stream that the daemon can serve requests over.
///
/// `Unpin` lets tokio's `AsyncReadExt`/`AsyncWriteExt` helpers be called on a
/// connection without pinning it first, and `Send` lets it move between tasks.
pub trait Connection
where
    Self: AsyncRead + AsyncWrite + Unpin + Send,
{
    /// Short static label for the underlying transport (e.g. `"tcp"`, `"unix"`).
    fn transport(&self) -> &'static str;
}

/// TCP connections report the transport `"tcp"`.
impl Connection for TcpStream {
    fn transport(&self) -> &'static str {
        "tcp"
    }
}

/// Unix-domain-socket connections report the transport `"unix"`.
#[cfg(unix)]
impl Connection for tokio::net::UnixStream {
    fn transport(&self) -> &'static str {
        "unix"
    }
}
/// Binds a TCP listener on `addr`.
///
/// `addr` must be a literal socket address such as `"127.0.0.1:47321"` or
/// `"[::1]:0"`; host names are not resolved.
///
/// # Errors
/// Returns [`io::ErrorKind::InvalidInput`] when `addr` does not parse as a
/// [`SocketAddr`]. Otherwise it propagates whatever error binding produces.
pub async fn bind_tcp(addr: &str) -> io::Result<TcpListener> {
    match addr.parse::<SocketAddr>() {
        Ok(socket_addr) => TcpListener::bind(socket_addr).await,
        Err(e) => Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            format!("bad tcp addr: {e}"),
        )),
    }
}
/// Prepares `path` for a fresh Unix-domain-socket bind.
///
/// - A symlink at `path` is refused with [`io::ErrorKind::InvalidInput`], so a planted
///   link cannot redirect the bind or the `chmod` that follows it.
/// - A leftover socket (e.g. from a previous run) is removed.
/// - Any other existing entry (regular file, directory, ...) is refused with
///   [`io::ErrorKind::AlreadyExists`] instead of being deleted, so a misconfigured
///   path cannot destroy unrelated data.
/// - Missing parent directories are created.
///
/// `label` names the socket in error messages (e.g. `"uds"`, `"admin uds"`).
#[cfg(unix)]
fn prepare_uds_path(path: &Path, label: &str) -> io::Result<()> {
    use std::os::unix::fs::FileTypeExt;

    // Inspect the directory entry itself, not a symlink target. If the lookup fails
    // (typically because nothing exists yet), carry on and let `bind` surface any
    // real problem.
    if let Ok(meta) = std::fs::symlink_metadata(path) {
        let file_type = meta.file_type();
        if file_type.is_symlink() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("{label} path is a symlink (refused): {}", path.display()),
            ));
        }
        if !file_type.is_socket() {
            return Err(io::Error::new(
                io::ErrorKind::AlreadyExists,
                format!("{label} path exists and is not a socket (refused): {}", path.display()),
            ));
        }
        std::fs::remove_file(path)?;
    }
    // A bare file name has an empty parent; there is nothing to create in that case.
    if let Some(parent) = path.parent().filter(|p| !p.as_os_str().is_empty()) {
        std::fs::create_dir_all(parent)?;
    }
    Ok(())
}

/// Binds a Unix domain socket at `path`.
///
/// A stale socket at `path` is replaced and missing parent directories are created.
/// The socket is then chmod'ed to `0o660` (read/write for owner and group). When
/// `group` is given, the socket's group ownership is changed to that group.
///
/// # Errors
/// - [`io::ErrorKind::InvalidInput`] if `path` is a symlink.
/// - [`io::ErrorKind::AlreadyExists`] if `path` exists but is not a socket.
/// - [`io::ErrorKind::NotFound`] if `group` names a group that does not exist.
/// - Any error from creating directories, binding, `chmod`, or `chown`.
#[cfg(unix)]
pub async fn bind_uds(path: &Path, group: Option<&str>) -> io::Result<tokio::net::UnixListener> {
    use std::os::unix::fs::PermissionsExt;
    prepare_uds_path(path, "uds")?;
    let listener = tokio::net::UnixListener::bind(path)?;
    // Until this chmod runs, the socket carries the mode produced by the process umask.
    std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o660))?;
    if let Some(group_name) = group {
        chown_to_group(path, group_name)?;
    }
    Ok(listener)
}

/// Binds the admin Unix domain socket at `path`, restricted to the owner (`0o600`).
///
/// A stale socket at `path` is replaced and missing parent directories are created.
///
/// # Errors
/// - [`io::ErrorKind::InvalidInput`] if `path` is a symlink.
/// - [`io::ErrorKind::AlreadyExists`] if `path` exists but is not a socket.
/// - Any error from creating directories, binding, or `chmod`.
#[cfg(unix)]
pub async fn bind_admin_uds(path: &Path) -> io::Result<tokio::net::UnixListener> {
    use std::os::unix::fs::PermissionsExt;
    prepare_uds_path(path, "admin uds")?;
    let listener = tokio::net::UnixListener::bind(path)?;
    // Until this chmod runs, the socket carries the mode produced by the process umask.
    std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o600))?;
    Ok(listener)
}
/// Stand-in for `bind_admin_uds` on targets without Unix domain sockets.
///
/// Always fails with [`io::ErrorKind::Unsupported`]; the error message points
/// callers at `bind_admin_pipe`.
#[cfg(not(unix))]
pub async fn bind_admin_uds(_path: &Path) -> io::Result<()> {
    const MESSAGE: &str =
        "Unix domain sockets are not supported on this platform; use bind_admin_pipe";
    Err(io::Error::new(io::ErrorKind::Unsupported, MESSAGE))
}
/// Creates one server instance of the admin named pipe at `path`.
///
/// The pipe gets the security descriptor from
/// `PipeSecurityDescriptor::current_user_only()`. NOTE(review): by its name this
/// limits access to the current user; confirm in `windows_security`.
///
/// `first` is forwarded to `ServerOptions::first_pipe_instance`. When it is `true`,
/// creation fails if an instance with this pipe name already exists.
#[cfg(windows)]
#[allow(unsafe_code)] // raw SECURITY_ATTRIBUTES pointer handed to tokio below
pub fn bind_admin_pipe(
    path: &str,
    first: bool,
) -> io::Result<tokio::net::windows::named_pipe::NamedPipeServer> {
    use crate::windows_security::PipeSecurityDescriptor;
    use tokio::net::windows::named_pipe::ServerOptions;

    let mut security = PipeSecurityDescriptor::current_user_only()?;
    // SAFETY: `create_with_security_attributes_raw` requires a pointer to an
    // initialized SECURITY_ATTRIBUTES. This assumes `as_attrs_ptr` returns such a
    // pointer into `security` (TODO confirm in `windows_security`). `security`
    // stays alive until after the call returns.
    let created = unsafe {
        ServerOptions::new()
            .first_pipe_instance(first)
            .create_with_security_attributes_raw(path, security.as_attrs_ptr())
    };
    // Dropped explicitly here so it visibly outlives the unsafe call above.
    drop(security);
    created
}
/// Stand-in for `bind_uds` on targets without Unix domain sockets.
///
/// Always fails with [`io::ErrorKind::Unsupported`]; the error message points
/// callers at `bind_named_pipe` or TCP.
#[cfg(not(unix))]
pub async fn bind_uds(_path: &Path, _group: Option<&str>) -> io::Result<()> {
    const MESSAGE: &str =
        "Unix domain sockets are not supported on this platform; use bind_named_pipe or TCP";
    Err(io::Error::new(io::ErrorKind::Unsupported, MESSAGE))
}
/// Default named-pipe path for the `inferd-infer` endpoint (Windows only).
#[cfg(windows)]
pub const DEFAULT_PIPE_PATH: &str = r"\\.\pipe\inferd-infer";
/// Default named-pipe path for the v2 inference endpoint (Windows only);
/// returned by `default_v2_addr` on Windows.
#[cfg(windows)]
pub const DEFAULT_PIPE_V2_PATH: &str = r"\\.\pipe\inferd-infer-v2";
/// Default named-pipe path for the admin endpoint (Windows only);
/// returned by `default_admin_addr` on Windows.
#[cfg(windows)]
pub const DEFAULT_ADMIN_PIPE_PATH: &str = r"\\.\pipe\inferd-admin";
/// Default named-pipe path for the embedding endpoint (Windows only);
/// returned by `default_embed_addr` on Windows.
#[cfg(windows)]
pub const DEFAULT_PIPE_EMBED_PATH: &str = r"\\.\pipe\inferd-infer-embed";
/// Creates one server instance of a named pipe at `path` for the service endpoints.
///
/// The pipe gets the security descriptor from
/// `PipeSecurityDescriptor::current_user_only()`. NOTE(review): by its name this
/// limits access to the current user; confirm in `windows_security`.
///
/// `first` is forwarded to `ServerOptions::first_pipe_instance`. When it is `true`,
/// creation fails if an instance with this pipe name already exists.
#[cfg(windows)]
#[allow(unsafe_code)] // raw SECURITY_ATTRIBUTES pointer handed to tokio below
pub fn bind_named_pipe(
    path: &str,
    first: bool,
) -> io::Result<tokio::net::windows::named_pipe::NamedPipeServer> {
    use crate::windows_security::PipeSecurityDescriptor;
    use tokio::net::windows::named_pipe::ServerOptions;

    let mut security = PipeSecurityDescriptor::current_user_only()?;
    // SAFETY: `create_with_security_attributes_raw` requires a pointer to an
    // initialized SECURITY_ATTRIBUTES. This assumes `as_attrs_ptr` returns such a
    // pointer into `security` (TODO confirm in `windows_security`). `security`
    // stays alive until after the call returns.
    let created = unsafe {
        ServerOptions::new()
            .first_pipe_instance(first)
            .create_with_security_attributes_raw(path, security.as_attrs_ptr())
    };
    // Dropped explicitly here so it visibly outlives the unsafe call above.
    drop(security);
    created
}
/// Server-side named-pipe connections report the transport `"pipe"`.
#[cfg(windows)]
impl Connection for tokio::net::windows::named_pipe::NamedPipeServer {
    fn transport(&self) -> &'static str {
        "pipe"
    }
}
/// Changes the group ownership of `path` to the group named `group_name`,
/// leaving the owning user unchanged.
///
/// # Errors
/// - [`io::ErrorKind::NotFound`] if no group called `group_name` exists.
/// - An `Other` error wrapping a failed group lookup (`"getgrnam: ..."`) or a
///   failed ownership change (`"chown: ..."`).
#[cfg(unix)]
fn chown_to_group(path: &Path, group_name: &str) -> io::Result<()> {
    let lookup = nix::unistd::Group::from_name(group_name)
        .map_err(|e| io::Error::other(format!("getgrnam: {e}")))?;
    let Some(group) = lookup else {
        return Err(io::Error::new(
            io::ErrorKind::NotFound,
            format!("group not found: {group_name}"),
        ));
    };
    nix::unistd::chown(path, None, Some(group.gid))
        .map_err(|e| io::Error::other(format!("chown: {e}")))
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Round-trips "ping"/"pong" over a TCP listener on an ephemeral loopback port.
    #[tokio::test]
    async fn bind_tcp_accepts_a_connection() {
        let listener = bind_tcp("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server = tokio::spawn(async move {
            let (mut sock, _) = listener.accept().await.unwrap();
            let mut buf = [0u8; 4];
            sock.read_exact(&mut buf).await.unwrap();
            assert_eq!(&buf, b"ping");
            sock.write_all(b"pong").await.unwrap();
        });
        let mut client = TcpStream::connect(addr).await.unwrap();
        client.write_all(b"ping").await.unwrap();
        let mut buf = [0u8; 4];
        client.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"pong");
        server.await.unwrap();
    }

    /// Unparseable addresses are reported as `InvalidInput`, not as a bind failure.
    #[tokio::test]
    async fn bind_tcp_rejects_garbage_addr() {
        let err = bind_tcp("not-an-addr").await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }

    /// The service socket is created with mode 0o660 and accepts a client.
    #[cfg(unix)]
    #[tokio::test]
    async fn bind_uds_creates_socket_and_accepts() {
        use std::os::unix::fs::PermissionsExt;
        use tempfile::tempdir;
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.sock");
        let listener = bind_uds(&path, None).await.unwrap();
        let mode = std::fs::metadata(&path).unwrap().permissions().mode();
        assert_eq!(mode & 0o777, 0o660);
        let server = tokio::spawn(async move {
            let (mut sock, _) = listener.accept().await.unwrap();
            let mut buf = [0u8; 4];
            sock.read_exact(&mut buf).await.unwrap();
            assert_eq!(&buf, b"ping");
        });
        let mut client = tokio::net::UnixStream::connect(&path).await.unwrap();
        client.write_all(b"ping").await.unwrap();
        server.await.unwrap();
    }

    /// A socket file left behind by an earlier listener is replaced on the next bind.
    #[cfg(unix)]
    #[tokio::test]
    async fn bind_uds_replaces_stale_socket() {
        use tempfile::tempdir;
        let dir = tempdir().unwrap();
        let path = dir.path().join("stale.sock");
        drop(bind_uds(&path, None).await.unwrap());
        assert!(path.exists(), "dropping the listener should leave the socket file");
        let _listener = bind_uds(&path, None).await.expect("rebind over stale socket");
        tokio::net::UnixStream::connect(&path)
            .await
            .expect("connect to rebound socket");
    }

    /// The admin socket creates its missing parent directory and is owner-only.
    #[cfg(unix)]
    #[tokio::test]
    async fn bind_admin_uds_creates_parent_and_is_owner_only() {
        use std::os::unix::fs::PermissionsExt;
        use tempfile::tempdir;
        let dir = tempdir().unwrap();
        let path = dir.path().join("run").join("admin.sock");
        let _listener = bind_admin_uds(&path).await.unwrap();
        let mode = std::fs::metadata(&path).unwrap().permissions().mode();
        assert_eq!(mode & 0o777, 0o600);
    }

    /// Round-trips "ping"/"pong" over a uniquely named pipe.
    #[cfg(windows)]
    #[tokio::test]
    async fn bind_named_pipe_accepts_a_connection() {
        use std::sync::atomic::{AtomicU64, Ordering};
        use tokio::net::windows::named_pipe::ClientOptions;
        // pid + timestamp + per-process counter keeps names unique across runs and calls.
        static COUNTER: AtomicU64 = AtomicU64::new(0);
        let pid = std::process::id();
        let n = COUNTER.fetch_add(1, Ordering::Relaxed);
        let ts = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let path = format!(r"\\.\pipe\inferd-endpoint-test-{pid}-{ts}-{n}");
        let mut server = bind_named_pipe(&path, true).expect("bind named pipe");
        let server_task = tokio::spawn(async move {
            server.connect().await.expect("server connect");
            let mut buf = [0u8; 4];
            server.read_exact(&mut buf).await.unwrap();
            assert_eq!(&buf, b"ping");
            server.write_all(b"pong").await.unwrap();
        });
        let mut client = ClientOptions::new()
            .open(&path)
            .expect("client open named pipe");
        client.write_all(b"ping").await.unwrap();
        let mut buf = [0u8; 4];
        client.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"pong");
        server_task.await.unwrap();
    }

    /// A symlink at the service socket path is refused rather than followed or deleted.
    #[cfg(unix)]
    #[tokio::test]
    async fn bind_uds_refuses_symlink_path() {
        use tempfile::tempdir;
        let dir = tempdir().unwrap();
        let target = dir.path().join("real.sock");
        std::fs::write(&target, b"").unwrap();
        let symlink = dir.path().join("link.sock");
        std::os::unix::fs::symlink(&target, &symlink).unwrap();
        let err = bind_uds(&symlink, None).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }

    /// A symlink at the admin socket path is refused rather than followed or deleted.
    #[cfg(unix)]
    #[tokio::test]
    async fn bind_admin_uds_refuses_symlink_path() {
        use tempfile::tempdir;
        let dir = tempdir().unwrap();
        let target = dir.path().join("real.sock");
        std::fs::write(&target, b"").unwrap();
        let symlink = dir.path().join("admin-link.sock");
        std::os::unix::fs::symlink(&target, &symlink).unwrap();
        let err = bind_admin_uds(&symlink).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }
}