#![cfg_attr(docsrs, feature(doc_auto_cfg))]
#![deny(missing_docs)]
//! Socket address type that can name either an ordinary TCP address or a socket
//! handed over by systemd.
//!
//! The central type is [`SocketAddr`]. It is parsed from a string: input starting
//! with `systemd://` is treated as the name of a systemd socket, anything else is
//! first tried as a `std::net::SocketAddr` and otherwise handed to the private
//! `resolv_addr` module (presumably `host:port` strings - confirm against that module).
//! The parsed value is turned into a listener with [`SocketAddr::bind`] or one of
//! the feature-gated async variants.
//!
//! Systemd support is only compiled in on Linux with the `enable_systemd` feature;
//! without it, parsing a systemd socket name returns an error.

/// Error types returned by the functions and conversions of this crate.
pub mod error;
// Private module providing `ResolvAddr`, the representation used for addresses
// that do not parse as `std::net::SocketAddr`.
mod resolv_addr;
use std::convert::{TryFrom, TryInto};
use std::fmt;
use std::ffi::{OsStr, OsString};
use crate::error::*;
use crate::resolv_addr::ResolvAddr;
#[cfg(not(all(target_os = "linux", feature = "enable_systemd")))]
use std::convert::Infallible as Never;
#[cfg(all(target_os = "linux", feature = "enable_systemd"))]
pub(crate) mod systemd_sockets {
use std::fmt;
use std::sync::Mutex;
use libsystemd::activation::FileDescriptor;
use libsystemd::errors::SdError as LibSystemdError;
/// Error returned by `take` when the one-time initialization of the socket store failed.
///
/// It only holds a reference to the `InitError` kept in the global store, so every
/// call that hits the failed store reports the same underlying error.
#[derive(Debug)]
pub(crate) struct Error(&'static Mutex<InitError>);

/// Value stored per descriptor name: the recognized socket, or `Err(())` when the
/// conversion from the received file descriptor failed.
type StoredSocket = Result<Socket, ()>;

impl fmt::Display for Error {
    /// Writes the stored error followed by every error in its `source()` chain,
    /// each one separated by `": "`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        use std::error::Error as _;

        let init_error = self.0.lock().expect("mutex poisoned");
        fmt::Display::fmt(&*init_error, f)?;

        // Walk the chain of causes until it runs out.
        let mut cause = init_error.source();
        loop {
            match cause {
                Some(error) => {
                    write!(f, ": {}", error)?;
                    cause = error.source();
                }
                None => return Ok(()),
            }
        }
    }
}

// The chain is already flattened into `Display`, so no `source()` is exposed here.
impl std::error::Error for Error {}
/// Explicitly initializes the global socket store, if it is not initialized yet.
///
/// `protected` is forwarded to `SystemdSockets::new`; the store is created in
/// "explicit" mode. A failure is returned to the caller and nothing is stored,
/// so a later call may try again.
///
/// # Safety
///
/// This forwards to the unsafe `SystemdSockets::new`. With `protected == false`
/// the single-thread check is skipped.
/// NOTE(review): the exact contract the caller must uphold is not spelled out in
/// this file - confirm before relying on it.
pub(crate) unsafe fn init(protected: bool) -> Result<(), InitError> {
    let outcome = SYSTEMD_SOCKETS.get_or_try_init(|| match SystemdSockets::new(protected, true) {
        // The cell stores a `Result`; an explicit init only ever stores the success case.
        Ok(sockets) => Ok(Ok(sockets)),
        Err(error) => Err(error),
    });
    match outcome {
        Ok(_) => Ok(()),
        Err(error) => Err(error),
    }
}
pub(crate) fn take(name: &str) -> Result<Option<StoredSocket>, Error> {
let sockets = SYSTEMD_SOCKETS.get_or_init(|| SystemdSockets::new_protected(false).map_err(Mutex::new));
match sockets {
Ok(sockets) => Ok(sockets.take(name)),
Err(error) => Err(Error(error))
}
}
/// Reasons why receiving the sockets from systemd can fail.
#[derive(Debug)]
pub(crate) enum InitError {
    /// `/proc/self/status` could not be opened.
    OpenStatus(std::io::Error),
    /// Reading a line from `/proc/self/status` failed.
    ReadStatus(std::io::Error),
    /// No line starting with `Threads:` was found in `/proc/self/status`.
    ThreadCountNotFound,
    /// The `Threads:` entry was something other than `1`.
    MultipleThreads,
    /// Error reported by `libsystemd` while receiving the descriptors.
    LibSystemd(LibSystemdError),
}

impl From<LibSystemdError> for InitError {
    fn from(value: LibSystemdError) -> Self {
        InitError::LibSystemd(value)
    }
}

impl fmt::Display for InitError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let message = match self {
            // The wrapped libsystemd error is shown in place of a message of our own.
            Self::LibSystemd(error) => return fmt::Display::fmt(error, f),
            Self::OpenStatus(_) => "failed to open /proc/self/status",
            Self::ReadStatus(_) => "failed to read /proc/self/status",
            Self::ThreadCountNotFound => "/proc/self/status doesn't contain Threads entry",
            Self::MultipleThreads => "there is more than one thread running",
        };
        f.write_str(message)
    }
}

impl std::error::Error for InitError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::OpenStatus(error) | Self::ReadStatus(error) => Some(error),
            Self::ThreadCountNotFound | Self::MultipleThreads => None,
            // `Display` already prints the libsystemd error itself, so the chain
            // continues with *its* source instead of repeating it.
            Self::LibSystemd(error) => error.source(),
        }
    }
}
/// A descriptor received from systemd whose type has been recognized.
pub(crate) enum Socket {
    /// A descriptor for which `is_inet()` returned true, wrapped as a std listener.
    TcpListener(std::net::TcpListener),
}

impl std::convert::TryFrom<FileDescriptor> for Socket {
    type Error = ();

    /// Converts a received descriptor into a [`Socket`].
    ///
    /// The close-on-exec flag is set on the descriptor in both outcomes. For a
    /// descriptor that is not an inet socket `Err(())` is returned and the raw
    /// descriptor is deliberately left open (it is released with `into_raw_fd`
    /// and never closed here).
    fn try_from(value: FileDescriptor) -> Result<Self, Self::Error> {
        use libsystemd::activation::IsType;
        use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};

        /// Best-effort: adds `FD_CLOEXEC` to `fd`; failures of `fcntl` are ignored.
        fn set_cloexec(fd: RawFd) {
            let flags = unsafe { libc::fcntl(fd, libc::F_GETFD) };
            // Nothing to do if the flags can't be read or the flag is already set.
            if flags == -1 || (flags & libc::FD_CLOEXEC) != 0 {
                return;
            }
            unsafe {
                libc::fcntl(fd, libc::F_SETFD, flags | libc::FD_CLOEXEC);
            }
        }

        if !value.is_inet() {
            set_cloexec(value.into_raw_fd());
            return Err(());
        }

        // SAFETY: `into_raw_fd` releases ownership of the descriptor, so the
        // listener becomes its only owner.
        // NOTE(review): this assumes an inet descriptor is a listening TCP socket - confirm.
        let listener = unsafe { std::net::TcpListener::from_raw_fd(value.into_raw_fd()) };
        set_cloexec(listener.as_raw_fd());
        Ok(Socket::TcpListener(listener))
    }
}
/// Descriptors received from systemd, keyed by their name.
///
/// The map sits behind a mutex so that entries can be removed through `&self`.
struct SystemdSockets(std::sync::Mutex<std::collections::HashMap<String, StoredSocket>>);

impl SystemdSockets {
    /// Safe wrapper around [`SystemdSockets::new`] that always enables protection.
    fn new_protected(explicit: bool) -> Result<Self, InitError> {
        // NOTE(review): presumably sound because `protected == true` enables the
        // single-thread check inside `new` - confirm.
        unsafe { Self::new(true, explicit) }
    }

    /// Receives the named descriptors from systemd and stores them.
    ///
    /// * `protected` - when true, fails unless the process has exactly one thread;
    ///   the flag is also passed on to `receive_descriptors_with_names`.
    /// * `explicit` - when true and none of `LISTEN_PID`, `LISTEN_FDS`,
    ///   `LISTEN_FDNAMES` is set, an empty store is returned without doing anything else.
    ///
    /// # Safety
    ///
    /// NOTE(review): the precondition is not stated in this file; it appears to
    /// concern calling this with `protected == false` - confirm.
    unsafe fn new(protected: bool, explicit: bool) -> Result<Self, InitError> {
        use std::convert::TryFrom;

        const ACTIVATION_VARS: [&str; 3] = ["LISTEN_PID", "LISTEN_FDS", "LISTEN_FDNAMES"];

        if explicit && ACTIVATION_VARS.iter().all(|var| std::env::var_os(var).is_none()) {
            return Ok(Self(Mutex::new(std::collections::HashMap::new())));
        }

        if protected {
            Self::check_single_thread()?;
        }

        let received = libsystemd::activation::receive_descriptors_with_names(protected)?;
        let mut sockets = std::collections::HashMap::new();
        for (fd, name) in received {
            // Unrecognized descriptors are kept as `Err(())` under their name.
            sockets.insert(name, Socket::try_from(fd));
        }
        Ok(Self(Mutex::new(sockets)))
    }

    /// Succeeds only if the `Threads:` entry of `/proc/self/status` is `1`.
    fn check_single_thread() -> Result<(), InitError> {
        use std::io::BufRead;

        let file = std::fs::File::open("/proc/self/status").map_err(InitError::OpenStatus)?;
        let mut reader = std::io::BufReader::new(file);
        let mut buf = String::new();

        // `read_line` returns 0 at end of file.
        while reader.read_line(&mut buf).map_err(InitError::ReadStatus)? != 0 {
            if let Some(count) = buf.strip_prefix("Threads:") {
                return if count.trim() == "1" {
                    Ok(())
                } else {
                    Err(InitError::MultipleThreads)
                };
            }
            buf.clear();
        }
        Err(InitError::ThreadCountNotFound)
    }

    /// Removes and returns the entry stored under `name`, if there is one.
    fn take(&self, name: &str) -> Option<StoredSocket> {
        let mut sockets = self.0.lock().expect("poisoned mutex");
        sockets.remove(name)
    }
}

/// Process-wide store, filled at most once; a failed lazy initialization is cached too.
static SYSTEMD_SOCKETS: once_cell::sync::OnceCell<Result<SystemdSockets, Mutex<InitError>>> = once_cell::sync::OnceCell::new();
}
/// Socket address that is either an ordinary address or the name of a systemd socket.
///
/// Values are created by parsing a string (`FromStr`, `TryFrom<&str>`, `TryFrom<String>`
/// and the `OsStr` variants), by converting from the `std::net` address types, or with
/// [`SocketAddr::from_systemd_name`]. A string starting with `systemd://` is treated as a
/// systemd socket name.
///
/// Consume the value with [`SocketAddr::bind`] (or one of the feature-gated async
/// variants) to obtain a listener.
///
/// With the `serde` feature the type deserializes from a string using the same
/// parsing rules.
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(serde_crate::Deserialize), serde(crate = "serde_crate", try_from = "serde_str_helpers::DeserBorrowStr"))]
pub struct SocketAddr(SocketAddrInner);
impl SocketAddr {
/// Creates the address from a systemd socket name given **without** the
/// `systemd://` prefix.
///
/// # Errors
///
/// Fails if the name does not pass validation, and always fails when the crate
/// is built without systemd support.
pub fn from_systemd_name<T: Into<String>>(name: T) -> Result<Self, ParseError> {
    let owned_name: String = name.into();
    Self::inner_from_systemd_name(owned_name, false)
}
/// Validates a systemd socket name and wraps it in the matching variant.
///
/// `name` is the complete input; when `prefixed` is true it starts with
/// `systemd://` and only the part after the prefix is validated. The stored
/// string is always the complete input.
///
/// A name is rejected if it contains a non-ASCII character, an ASCII control
/// character (including DEL, `0x7F`) or `':'`, or if it is longer than 255 bytes.
/// The reported `pos` is the character index within the validated part, i.e. it
/// does not count the prefix.
#[cfg(all(target_os = "linux", feature = "enable_systemd"))]
fn inner_from_systemd_name(name: String, prefixed: bool) -> Result<Self, ParseError> {
    let real_systemd_name = if prefixed {
        &name[SYSTEMD_PREFIX.len()..]
    } else {
        &name
    };

    let name_len = real_systemd_name.len();
    // `is_ascii_control` covers 0x00..=0x1F as well as DEL (0x7F); systemd
    // rejects both ranges in descriptor names.
    let offending = real_systemd_name
        .chars()
        .enumerate()
        .find(|(_, c)| !c.is_ascii() || c.is_ascii_control() || *c == ':');

    match offending {
        Some((pos, c)) => Err(ParseErrorInner::InvalidCharacter { string: name, c, pos, }.into()),
        None if name_len > 255 => Err(ParseErrorInner::LongSocketName { string: name, len: name_len }.into()),
        None if prefixed => Ok(SocketAddr(SocketAddrInner::Systemd(name))),
        None => Ok(SocketAddr(SocketAddrInner::SystemdNoPrefix(name))),
    }
}

/// Without systemd support every systemd socket name is rejected.
#[cfg(not(all(target_os = "linux", feature = "enable_systemd")))]
fn inner_from_systemd_name(name: String, _prefixed: bool) -> Result<Self, ParseError> {
    Err(ParseError(ParseErrorInner::SystemdUnsupported(name)))
}
pub fn bind(self) -> Result<std::net::TcpListener, BindError> {
match self.0 {
SocketAddrInner::Ordinary(addr) => match std::net::TcpListener::bind(addr) {
Ok(socket) => Ok(socket),
Err(error) => Err(BindErrorInner::BindFailed { addr, error, }.into()),
},
SocketAddrInner::WithHostname(addr) => match std::net::TcpListener::bind(addr.as_str()) {
Ok(socket) => Ok(socket),
Err(error) => Err(BindErrorInner::BindOrResolvFailed { addr, error, }.into()),
},
SocketAddrInner::Systemd(socket_name) => Self::get_systemd(socket_name, true).map(|(socket, _)| socket),
SocketAddrInner::SystemdNoPrefix(socket_name) => Self::get_systemd(socket_name, false).map(|(socket, _)| socket),
}
}
/// Consumes the address and returns a listening `tokio::net::TcpListener`.
///
/// Ordinary addresses and host names are bound by tokio itself; a systemd
/// socket is obtained as a std listener first and then converted.
///
/// # Errors
///
/// Returns an error if binding (or resolving) fails, if the systemd socket
/// cannot be obtained, or if the conversion into the tokio type fails.
#[cfg(feature = "tokio")]
pub async fn bind_tokio(self) -> Result<tokio::net::TcpListener, TokioBindError> {
    match self.0 {
        SocketAddrInner::Ordinary(addr) => {
            let bound = tokio::net::TcpListener::bind(addr).await;
            bound.map_err(|error| TokioBindError::Bind(BindErrorInner::BindFailed { addr, error, }.into()))
        },
        SocketAddrInner::WithHostname(addr) => {
            let bound = tokio::net::TcpListener::bind(addr.as_str()).await;
            bound.map_err(|error| TokioBindError::Bind(BindErrorInner::BindOrResolvFailed { addr, error, }.into()))
        },
        SocketAddrInner::Systemd(socket_name) => {
            let (std_listener, addr) = Self::get_systemd(socket_name, true)?;
            let converted: Result<tokio::net::TcpListener, _> = std_listener.try_into();
            converted.map_err(|error| TokioConversionError { addr, error, }.into())
        },
        SocketAddrInner::SystemdNoPrefix(socket_name) => {
            let (std_listener, addr) = Self::get_systemd(socket_name, false)?;
            let converted: Result<tokio::net::TcpListener, _> = std_listener.try_into();
            converted.map_err(|error| TokioConversionError { addr, error, }.into())
        },
    }
}
/// Consumes the address and returns a listening `tokio_0_2::net::TcpListener`.
///
/// Ordinary addresses and host names are bound by tokio 0.2 itself; a systemd
/// socket is obtained as a std listener first and then converted.
///
/// # Errors
///
/// Returns an error if binding (or resolving) fails, if the systemd socket
/// cannot be obtained, or if the conversion into the tokio type fails.
#[cfg(feature = "tokio_0_2")]
pub async fn bind_tokio_0_2(self) -> Result<tokio_0_2::net::TcpListener, TokioBindError> {
    match self.0 {
        SocketAddrInner::Ordinary(addr) => {
            let bound = tokio_0_2::net::TcpListener::bind(addr).await;
            bound.map_err(|error| TokioBindError::Bind(BindErrorInner::BindFailed { addr, error, }.into()))
        },
        SocketAddrInner::WithHostname(addr) => {
            let bound = tokio_0_2::net::TcpListener::bind(addr.as_str()).await;
            bound.map_err(|error| TokioBindError::Bind(BindErrorInner::BindOrResolvFailed { addr, error, }.into()))
        },
        SocketAddrInner::Systemd(socket_name) => {
            let (std_listener, addr) = Self::get_systemd(socket_name, true)?;
            let converted: Result<tokio_0_2::net::TcpListener, _> = std_listener.try_into();
            converted.map_err(|error| TokioConversionError { addr, error, }.into())
        },
        SocketAddrInner::SystemdNoPrefix(socket_name) => {
            let (std_listener, addr) = Self::get_systemd(socket_name, false)?;
            let converted: Result<tokio_0_2::net::TcpListener, _> = std_listener.try_into();
            converted.map_err(|error| TokioConversionError { addr, error, }.into())
        },
    }
}
/// Consumes the address and returns a listening `tokio_0_3::net::TcpListener`.
///
/// Ordinary addresses and host names are bound by tokio 0.3 itself; a systemd
/// socket is obtained as a std listener first and then converted.
///
/// # Errors
///
/// Returns an error if binding (or resolving) fails, if the systemd socket
/// cannot be obtained, or if the conversion into the tokio type fails.
#[cfg(feature = "tokio_0_3")]
pub async fn bind_tokio_0_3(self) -> Result<tokio_0_3::net::TcpListener, TokioBindError> {
    match self.0 {
        SocketAddrInner::Ordinary(addr) => {
            let bound = tokio_0_3::net::TcpListener::bind(addr).await;
            bound.map_err(|error| TokioBindError::Bind(BindErrorInner::BindFailed { addr, error, }.into()))
        },
        SocketAddrInner::WithHostname(addr) => {
            let bound = tokio_0_3::net::TcpListener::bind(addr.as_str()).await;
            bound.map_err(|error| TokioBindError::Bind(BindErrorInner::BindOrResolvFailed { addr, error, }.into()))
        },
        SocketAddrInner::Systemd(socket_name) => {
            let (std_listener, addr) = Self::get_systemd(socket_name, true)?;
            let converted: Result<tokio_0_3::net::TcpListener, _> = std_listener.try_into();
            converted.map_err(|error| TokioConversionError { addr, error, }.into())
        },
        SocketAddrInner::SystemdNoPrefix(socket_name) => {
            let (std_listener, addr) = Self::get_systemd(socket_name, false)?;
            let converted: Result<tokio_0_3::net::TcpListener, _> = std_listener.try_into();
            converted.map_err(|error| TokioConversionError { addr, error, }.into())
        },
    }
}
/// Consumes the address and returns a listening `async_std::net::TcpListener`.
///
/// Ordinary addresses and host names are bound by async-std itself; a systemd
/// socket is obtained as a std listener first and converted with `Into`.
///
/// # Errors
///
/// Returns an error if binding (or resolving) fails or if the systemd socket
/// cannot be obtained.
#[cfg(feature = "async-std")]
pub async fn bind_async_std(self) -> Result<async_std::net::TcpListener, BindError> {
    match self.0 {
        SocketAddrInner::Ordinary(addr) => {
            let bound = async_std::net::TcpListener::bind(addr).await;
            bound.map_err(|error| BindErrorInner::BindFailed { addr, error, }.into())
        },
        SocketAddrInner::WithHostname(addr) => {
            let bound = async_std::net::TcpListener::bind(addr.as_str()).await;
            bound.map_err(|error| BindErrorInner::BindOrResolvFailed { addr, error, }.into())
        },
        SocketAddrInner::Systemd(socket_name) => {
            Self::get_systemd(socket_name, true).map(|(std_listener, _)| std_listener.into())
        },
        SocketAddrInner::SystemdNoPrefix(socket_name) => {
            Self::get_systemd(socket_name, false).map(|(std_listener, _)| std_listener.into())
        },
    }
}
/// Shared parsing routine behind all string conversions.
///
/// The input is classified in this order:
/// 1. starts with `systemd://` - validated as a prefixed systemd socket name;
/// 2. parses as `std::net::SocketAddr` - stored as an ordinary address;
/// 3. anything else - handed to `ResolvAddr::try_from_generic`.
///
/// Being generic lets owned inputs be stored without an extra copy where the
/// `Into<String>` conversion allows it.
fn try_from_generic<'a, T>(string: T) -> Result<Self, ParseError> where T: 'a + std::ops::Deref<Target=str> + Into<String> {
    if string.starts_with(SYSTEMD_PREFIX) {
        return Self::inner_from_systemd_name(string.into(), true);
    }

    if let Ok(addr) = string.parse::<std::net::SocketAddr>() {
        return Ok(SocketAddr(SocketAddrInner::Ordinary(addr)));
    }

    let resolv_addr = ResolvAddr::try_from_generic(string).map_err(ParseErrorInner::ResolvAddr)?;
    Ok(SocketAddr(SocketAddrInner::WithHostname(resolv_addr)))
}
/// Takes the listener registered under the given systemd socket name.
///
/// `socket_name` is the string stored in the address; when `prefixed` is true it
/// starts with `systemd://`, which is stripped before the lookup. On success the
/// listener is returned together with the address it came from, rebuilt in the
/// variant that matches `prefixed` so that it displays with the prefix either way.
///
/// # Errors
///
/// Fails if receiving the descriptors failed, if the descriptor stored under
/// that name was not recognized as a TCP listener, or if no descriptor with that
/// name is available (including when it was already taken).
#[cfg(all(target_os = "linux", feature = "enable_systemd"))]
fn get_systemd(socket_name: String, prefixed: bool) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> {
    use systemd_sockets::Socket;

    let real_systemd_name = if prefixed {
        &socket_name[SYSTEMD_PREFIX.len()..]
    } else {
        &socket_name
    };

    let socket = systemd_sockets::take(real_systemd_name).map_err(BindErrorInner::ReceiveDescriptors)?;
    // `match` instead of combinators so `socket_name` is moved into exactly one arm.
    match socket {
        Some(Ok(Socket::TcpListener(socket))) => {
            // `Systemd` holds names that include the prefix, `SystemdNoPrefix`
            // holds bare names; keep that invariant for the returned address.
            let addr = if prefixed {
                SocketAddrInner::Systemd(socket_name)
            } else {
                SocketAddrInner::SystemdNoPrefix(socket_name)
            };
            Ok((socket, addr))
        },
        Some(_) => Err(BindErrorInner::NotInetSocket(socket_name).into()),
        None => Err(BindErrorInner::MissingDescriptor(socket_name).into())
    }
}

/// Without systemd support no systemd address can exist, so this is never called
/// with an actual value.
#[cfg(not(all(target_os = "linux", feature = "enable_systemd")))]
fn get_systemd(socket_name: Never, _prefixed: bool) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> {
    match socket_name {}
}
}
/// Initializes the store of systemd sockets, if it is not initialized yet.
///
/// With systemd support compiled in, this receives the descriptors passed by
/// systemd in protected mode: if any of `LISTEN_PID`, `LISTEN_FDS` or
/// `LISTEN_FDNAMES` is set, the process must have exactly one thread. Without
/// systemd support this does nothing and returns `Ok(())`.
///
/// # Errors
///
/// Returns an error if the thread check or receiving the descriptors fails.
#[inline]
pub fn init() -> Result<(), error::InitError> {
    #[cfg(all(target_os = "linux", feature = "enable_systemd"))]
    {
        // NOTE(review): presumably sound because `protected == true` enables the
        // single-thread check - confirm against the contract of `systemd_sockets::init`.
        match unsafe { systemd_sockets::init(true) } {
            Ok(()) => Ok(()),
            Err(cause) => Err(error::InitError(cause)),
        }
    }
    #[cfg(not(all(target_os = "linux", feature = "enable_systemd")))]
    {
        Ok(())
    }
}
/// Initializes the store of systemd sockets without the protection used by [`init`].
///
/// Unlike [`init`], the single-thread check is skipped and `false` is passed on
/// to the function receiving the descriptors. Without systemd support this does
/// nothing and returns `Ok(())`.
///
/// # Safety
///
/// NOTE(review): the precondition is not spelled out in this file. The caller
/// takes over whatever the skipped protection guarantees - confirm the exact
/// contract before use.
///
/// # Errors
///
/// Returns an error if receiving the descriptors fails.
pub unsafe fn init_unprotected() -> Result<(), error::InitError> {
    #[cfg(all(target_os = "linux", feature = "enable_systemd"))]
    {
        let outcome = systemd_sockets::init(false);
        outcome.map_err(error::InitError)
    }
    #[cfg(not(all(target_os = "linux", feature = "enable_systemd")))]
    {
        Ok(())
    }
}
impl fmt::Display for SocketAddr {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::Display::fmt(&self.0, f)
}
}
/// Prefix that marks a string as a systemd socket name.
const SYSTEMD_PREFIX: &str = "systemd://";

/// Internal representation of [`SocketAddr`].
///
/// When systemd support is not compiled in, the systemd variants hold the
/// uninhabited `Never` type and therefore can never be constructed.
#[derive(Debug, PartialEq)]
enum SocketAddrInner {
    /// Address that parsed as `std::net::SocketAddr`.
    Ordinary(std::net::SocketAddr),
    /// Address that did not parse as `std::net::SocketAddr`.
    WithHostname(resolv_addr::ResolvAddr),
    /// Systemd socket name stored **including** the `systemd://` prefix.
    #[cfg(all(target_os = "linux", feature = "enable_systemd"))]
    Systemd(String),
    /// Uninhabited placeholder used without systemd support.
    #[cfg(not(all(target_os = "linux", feature = "enable_systemd")))]
    #[allow(dead_code)]
    Systemd(Never),
    /// Systemd socket name stored **without** the `systemd://` prefix.
    #[cfg(all(target_os = "linux", feature = "enable_systemd"))]
    #[allow(dead_code)]
    SystemdNoPrefix(String),
    /// Uninhabited placeholder used without systemd support.
    #[cfg(not(all(target_os = "linux", feature = "enable_systemd")))]
    #[allow(dead_code)]
    SystemdNoPrefix(Never),
}

impl fmt::Display for SocketAddrInner {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            // The bare name gets the prefix added for display.
            Self::SystemdNoPrefix(name) => write!(f, "{}{}", SYSTEMD_PREFIX, name),
            // The stored string is shown as it is.
            Self::Systemd(name) => fmt::Display::fmt(name, f),
            Self::WithHostname(host) => fmt::Display::fmt(host, f),
            Self::Ordinary(ip) => fmt::Display::fmt(ip, f),
        }
    }
}
impl<I: Into<std::net::IpAddr>> From<(I, u16)> for SocketAddr {
fn from(value: (I, u16)) -> Self {
SocketAddr(SocketAddrInner::Ordinary(value.into()))
}
}
impl From<std::net::SocketAddr> for SocketAddr {
fn from(value: std::net::SocketAddr) -> Self {
SocketAddr(SocketAddrInner::Ordinary(value))
}
}
impl From<std::net::SocketAddrV4> for SocketAddr {
fn from(value: std::net::SocketAddrV4) -> Self {
SocketAddr(SocketAddrInner::Ordinary(value.into()))
}
}
impl From<std::net::SocketAddrV6> for SocketAddr {
fn from(value: std::net::SocketAddrV6) -> Self {
SocketAddr(SocketAddrInner::Ordinary(value.into()))
}
}
impl std::str::FromStr for SocketAddr {
type Err = ParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
SocketAddr::try_from_generic(s)
}
}
impl<'a> TryFrom<&'a str> for SocketAddr {
type Error = ParseError;
fn try_from(s: &'a str) -> Result<Self, Self::Error> {
SocketAddr::try_from_generic(s)
}
}
impl TryFrom<String> for SocketAddr {
type Error = ParseError;
fn try_from(s: String) -> Result<Self, Self::Error> {
SocketAddr::try_from_generic(s)
}
}
/// Parses the address from an OS string; input that is not valid UTF-8 is rejected.
impl<'a> TryFrom<&'a OsStr> for SocketAddr {
    type Error = ParseOsStrError;

    fn try_from(s: &'a OsStr) -> Result<Self, Self::Error> {
        let utf8 = match s.to_str() {
            Some(utf8) => utf8,
            None => return Err(ParseOsStrError::InvalidUtf8),
        };
        <SocketAddr as TryFrom<&str>>::try_from(utf8).map_err(Into::into)
    }
}

/// Parses the address from an owned OS string; input that is not valid UTF-8 is rejected.
impl TryFrom<OsString> for SocketAddr {
    type Error = ParseOsStrError;

    fn try_from(s: OsString) -> Result<Self, Self::Error> {
        let utf8 = match s.into_string() {
            Ok(utf8) => utf8,
            Err(_) => return Err(ParseOsStrError::InvalidUtf8),
        };
        <SocketAddr as TryFrom<String>>::try_from(utf8).map_err(Into::into)
    }
}
/// Conversion used by the serde `try_from` attribute on [`SocketAddr`]; applies
/// the same parsing as the other string conversions.
#[cfg(feature = "serde")]
impl<'a> TryFrom<serde_str_helpers::DeserBorrowStr<'a>> for SocketAddr {
    type Error = ParseError;

    fn try_from(s: serde_str_helpers::DeserBorrowStr<'a>) -> Result<Self, Self::Error> {
        Self::try_from_generic(s)
    }
}
/// Command-line argument parsing via the `parse_arg` crate.
#[cfg(feature = "parse_arg")]
impl parse_arg::ParseArg for SocketAddr {
    type Error = ParseOsStrError;

    /// Writes the description of `std::net::SocketAddr` followed by a note about
    /// the systemd form.
    fn describe_type<W: fmt::Write>(mut writer: W) -> fmt::Result {
        std::net::SocketAddr::describe_type(&mut writer)
            .and_then(|()| write!(writer, " or a systemd socket name prefixed with systemd://"))
    }

    fn parse_arg(arg: &OsStr) -> Result<Self, Self::Error> {
        Self::try_from(arg)
    }

    fn parse_owned_arg(arg: OsString) -> Result<Self, Self::Error> {
        Self::try_from(arg)
    }
}
#[cfg(test)]
mod tests {
    use super::{SocketAddr, SocketAddrInner};

    // Without systemd support every systemd name is rejected, so the
    // `should_panic` tests below hold in both configurations.

    #[test]
    fn parse_ordinary() {
        assert_eq!("127.0.0.1:42".parse::<SocketAddr>().unwrap().0, SocketAddrInner::Ordinary(([127, 0, 0, 1], 42).into()));
    }

    #[test]
    #[cfg(all(target_os = "linux", feature = "enable_systemd"))]
    fn parse_systemd() {
        assert_eq!("systemd://foo".parse::<SocketAddr>().unwrap().0, SocketAddrInner::Systemd("systemd://foo".to_owned()));
    }

    #[test]
    #[cfg(not(all(target_os = "linux", feature = "enable_systemd")))]
    #[should_panic]
    fn parse_systemd() {
        "systemd://foo".parse::<SocketAddr>().unwrap();
    }

    #[test]
    #[should_panic]
    fn parse_systemd_fail_control() {
        "systemd://foo\n".parse::<SocketAddr>().unwrap();
    }

    #[test]
    #[should_panic]
    fn parse_systemd_fail_colon() {
        "systemd://foo:".parse::<SocketAddr>().unwrap();
    }

    #[test]
    #[should_panic]
    fn parse_systemd_fail_non_ascii() {
        "systemd://fooá".parse::<SocketAddr>().unwrap();
    }

    #[test]
    #[should_panic]
    fn parse_systemd_fail_too_long() {
        // 256 bytes after the prefix: one more than the 255-byte limit.
        format!("systemd://{}", "x".repeat(256)).parse::<SocketAddr>().unwrap();
    }

    #[test]
    #[cfg_attr(not(all(target_os = "linux", feature = "enable_systemd")), should_panic)]
    fn no_prefix_parse_systemd() {
        SocketAddr::from_systemd_name("foo").unwrap();
    }

    #[test]
    #[cfg_attr(not(all(target_os = "linux", feature = "enable_systemd")), should_panic)]
    fn no_prefix_parse_systemd_max_len() {
        // Exactly 255 bytes is the longest name that is still accepted.
        SocketAddr::from_systemd_name("x".repeat(255)).unwrap();
    }

    #[test]
    #[should_panic]
    fn no_prefix_parse_systemd_fail_non_ascii() {
        SocketAddr::from_systemd_name("fooá").unwrap();
    }

    #[test]
    #[should_panic]
    fn no_prefix_parse_systemd_fail_too_long() {
        // 256 bytes: one more than the 255-byte limit.
        SocketAddr::from_systemd_name("x".repeat(256)).unwrap();
    }
}