use std::convert::Infallible;
use std::fmt;
use std::io;
use std::str::FromStr;
use camino::Utf8Path;
use camino::Utf8PathBuf;
use http::uri::Authority;
use tokio::net::{TcpStream, UnixStream};
#[cfg(feature = "tls")]
pub mod tls;
#[cfg(feature = "tls")]
pub use self::tls::HasTlsConnectionInfo;
#[cfg(feature = "tls")]
pub use self::tls::TlsConnectionInfo;
pub use crate::stream::duplex::DuplexAddr;
/// Application-level protocol associated with a connection.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Protocol {
    /// HTTP at the given version.
    Http(http::Version),
    /// gRPC.
    Grpc,
    /// WebSocket.
    WebSocket,
    /// Any other protocol, identified by a free-form name.
    Other(String),
}
/// Human-readable protocol name.
///
/// HTTP versions are rendered with the `http` crate's `Debug` form
/// (e.g. `HTTP/1.1`). `Grpc` and `WebSocket` print their names, and
/// `Other` prints its stored name verbatim.
impl std::fmt::Display for Protocol {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Http(version) => write!(f, "{:?}", version),
            Self::Grpc => f.write_str("gRPC"),
            Self::WebSocket => f.write_str("WebSocket"),
            Self::Other(name) => f.write_str(name),
        }
    }
}
impl Protocol {
pub fn http(version: http::Version) -> Self {
Self::Http(version)
}
pub fn grpc() -> Self {
Self::Grpc
}
pub fn web_socket() -> Self {
Self::WebSocket
}
}
impl From<http::Version> for Protocol {
fn from(version: http::Version) -> Self {
Self::Http(version)
}
}
/// Parses a protocol identifier.
///
/// `http/0.9`, `http/1.0`, `http/1.1`, `h2` and `h3` become
/// [`Protocol::Http`] at the matching version. `gRPC` and `WebSocket` become
/// their variants. Matching is exact and case-sensitive. Any other input is
/// kept verbatim in [`Protocol::Other`], so parsing never fails.
impl FromStr for Protocol {
    type Err = Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let parsed = match s {
            "http/0.9" => Self::Http(http::Version::HTTP_09),
            "http/1.0" => Self::Http(http::Version::HTTP_10),
            "http/1.1" => Self::Http(http::Version::HTTP_11),
            "h2" => Self::Http(http::Version::HTTP_2),
            "h3" => Self::Http(http::Version::HTTP_3),
            "gRPC" => Self::Grpc,
            "WebSocket" => Self::WebSocket,
            unknown => Self::Other(unknown.to_owned()),
        };
        Ok(parsed)
    }
}
#[cfg(feature = "stream")]
/// Rewrites an IPv4-mapped IPv6 socket address (`[::ffff:a.b.c.d]:port`)
/// as the plain IPv4 address `a.b.c.d:port`, keeping the port.
///
/// IPv4 addresses, and IPv6 addresses that are not IPv4-mapped, are
/// returned unchanged.
fn make_canonical(addr: std::net::SocketAddr) -> std::net::SocketAddr {
    let mapped_v4 = match addr {
        std::net::SocketAddr::V6(v6) => v6.ip().to_ipv4_mapped(),
        std::net::SocketAddr::V4(_) => None,
    };
    match mapped_v4 {
        Some(v4) => std::net::SocketAddr::from((v4, addr.port())),
        None => addr,
    }
}
/// Address of a Unix domain socket.
///
/// `path` is `None` for sockets that are not bound to a filesystem path,
/// such as either end of a socket pair. The `Default` value is unnamed.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub struct UnixAddr {
    /// Filesystem path the socket is bound to, if any.
    path: Option<Utf8PathBuf>,
}
impl UnixAddr {
pub fn is_named(&self) -> bool {
self.path.is_some()
}
pub fn path(&self) -> Option<&Utf8Path> {
self.path.as_deref()
}
pub fn from_pathbuf(path: Utf8PathBuf) -> Self {
Self { path: Some(path) }
}
pub fn unnamed() -> Self {
Self { path: None }
}
}
/// Renders the address as a `unix://` URI.
///
/// A named socket prints as `unix://<path>`; an unnamed one prints as a
/// bare `unix://`.
impl fmt::Display for UnixAddr {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.path() {
            Some(path) => write!(f, "unix://{}", path),
            None => f.write_str("unix://"),
        }
    }
}
/// Converts a standard-library Unix socket address.
///
/// An address without a pathname becomes an unnamed [`UnixAddr`].
///
/// # Errors
///
/// Returns an [`io::ErrorKind::InvalidData`] error when the socket path is
/// not valid UTF-8.
impl TryFrom<std::os::unix::net::SocketAddr> for UnixAddr {
    type Error = io::Error;

    fn try_from(addr: std::os::unix::net::SocketAddr) -> Result<Self, Self::Error> {
        let path = match addr.as_pathname() {
            None => None,
            Some(raw) => match Utf8Path::from_path(raw) {
                Some(utf8) => Some(utf8.to_owned()),
                None => {
                    return Err(io::Error::new(
                        io::ErrorKind::InvalidData,
                        "not a utf-8 path",
                    ))
                }
            },
        };
        Ok(Self { path })
    }
}
/// Converts a tokio Unix socket address.
///
/// An address without a pathname becomes an unnamed [`UnixAddr`].
///
/// # Errors
///
/// Returns an [`io::ErrorKind::InvalidData`] error when the socket path is
/// not valid UTF-8.
impl TryFrom<tokio::net::unix::SocketAddr> for UnixAddr {
    type Error = io::Error;

    fn try_from(addr: tokio::net::unix::SocketAddr) -> Result<Self, Self::Error> {
        let mut path = None;
        if let Some(raw) = addr.as_pathname() {
            let utf8 = Utf8Path::from_path(raw).ok_or_else(|| {
                io::Error::new(io::ErrorKind::InvalidData, "not a utf-8 path")
            })?;
            path = Some(utf8.to_owned());
        }
        Ok(Self { path })
    }
}
#[cfg(feature = "stream")]
/// Address of one end of a connection: a TCP socket, an in-memory duplex
/// stream, or a Unix domain socket.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum BraidAddr {
    /// A TCP socket address.
    Tcp(std::net::SocketAddr),
    /// An in-memory duplex connection; this variant carries no address data.
    Duplex,
    /// A Unix domain socket address (possibly unnamed).
    Unix(UnixAddr),
}
#[cfg(feature = "stream")]
/// TCP addresses print as `ip:port`, duplex as `<duplex>`, and Unix
/// sockets use the `unix://` form of [`UnixAddr`].
impl std::fmt::Display for BraidAddr {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            BraidAddr::Duplex => f.write_str("<duplex>"),
            BraidAddr::Tcp(socket) => write!(f, "{}", socket),
            BraidAddr::Unix(unix) => write!(f, "{}", unix),
        }
    }
}
#[cfg(feature = "stream")]
impl BraidAddr {
    /// Returns the TCP socket address, or `None` for non-TCP addresses.
    pub fn tcp(&self) -> Option<std::net::SocketAddr> {
        if let Self::Tcp(socket) = self {
            Some(*socket)
        } else {
            None
        }
    }

    /// Returns the path of a named Unix socket.
    ///
    /// TCP, duplex, and unnamed Unix addresses yield `None`.
    pub fn path(&self) -> Option<&Utf8Path> {
        match self {
            Self::Unix(unix) => unix.path(),
            Self::Tcp(_) | Self::Duplex => None,
        }
    }

    /// Rewrites an IPv4-mapped IPv6 TCP address as plain IPv4.
    ///
    /// Every other address is returned unchanged.
    pub fn canonical(self) -> Self {
        if let Self::Tcp(socket) = self {
            Self::Tcp(make_canonical(socket))
        } else {
            self
        }
    }
}
#[cfg(feature = "stream")]
/// Wraps a socket address as [`BraidAddr::Tcp`], first rewriting an
/// IPv4-mapped IPv6 address as plain IPv4.
impl From<std::net::SocketAddr> for BraidAddr {
    fn from(addr: std::net::SocketAddr) -> Self {
        let canonical = make_canonical(addr);
        BraidAddr::Tcp(canonical)
    }
}
#[cfg(feature = "stream")]
/// Converts a tokio Unix socket address into [`BraidAddr::Unix`].
///
/// # Errors
///
/// Fails with the same error as the `UnixAddr` conversion, i.e. when the
/// socket path is not valid UTF-8.
impl TryFrom<tokio::net::unix::SocketAddr> for BraidAddr {
    type Error = io::Error;

    fn try_from(addr: tokio::net::unix::SocketAddr) -> Result<Self, Self::Error> {
        UnixAddr::try_from(addr).map(BraidAddr::Unix)
    }
}
#[cfg(feature = "stream")]
/// Builds a TCP address from an `(ip, port)` pair.
///
/// Routed through the `From<SocketAddr>` conversion so that an IPv4-mapped
/// IPv6 address is canonicalized to IPv4 exactly as it is for an equivalent
/// `SocketAddr`. Otherwise the two conversions could produce unequal
/// `BraidAddr`s for the same endpoint.
impl From<(std::net::IpAddr, u16)> for BraidAddr {
    fn from((ip, port): (std::net::IpAddr, u16)) -> Self {
        Self::from(std::net::SocketAddr::new(ip, port))
    }
}
#[cfg(feature = "stream")]
/// Builds a TCP address from an IPv4 `(ip, port)` pair.
impl From<(std::net::Ipv4Addr, u16)> for BraidAddr {
    fn from(addr: (std::net::Ipv4Addr, u16)) -> Self {
        Self::Tcp(std::net::SocketAddr::from(addr))
    }
}
#[cfg(feature = "stream")]
/// Builds a TCP address from an IPv6 `(ip, port)` pair.
///
/// Routed through the `From<SocketAddr>` conversion so that an IPv4-mapped
/// address (`::ffff:a.b.c.d`) is canonicalized to IPv4, matching the result
/// for the equivalent `SocketAddr`.
impl From<(std::net::Ipv6Addr, u16)> for BraidAddr {
    fn from((ip, port): (std::net::Ipv6Addr, u16)) -> Self {
        Self::from(std::net::SocketAddr::new(std::net::IpAddr::V6(ip), port))
    }
}
#[cfg(feature = "stream")]
/// Treats a filesystem path as the address of a named Unix socket.
impl From<Utf8PathBuf> for BraidAddr {
    fn from(path: Utf8PathBuf) -> Self {
        let unix = UnixAddr::from_pathbuf(path);
        BraidAddr::Unix(unix)
    }
}
#[cfg(feature = "stream")]
/// Wraps a Unix socket address as [`BraidAddr::Unix`].
impl From<UnixAddr> for BraidAddr {
    fn from(unix: UnixAddr) -> Self {
        BraidAddr::Unix(unix)
    }
}
#[cfg(feature = "stream")]
/// Maps any duplex address to [`BraidAddr::Duplex`]; the input is discarded.
impl From<DuplexAddr> for BraidAddr {
    fn from(_: DuplexAddr) -> Self {
        BraidAddr::Duplex
    }
}
#[cfg(feature = "stream")]
/// Metadata describing a connection.
///
/// With the `stream` feature enabled, the address type defaults to
/// [`BraidAddr`].
#[derive(Debug, Clone)]
pub struct ConnectionInfo<Addr = BraidAddr> {
    /// Application protocol, if known.
    pub protocol: Option<Protocol>,
    /// URI authority associated with the connection, if any.
    pub authority: Option<Authority>,
    /// Address of this end of the connection.
    pub local_addr: Addr,
    /// Address of the peer.
    pub remote_addr: Addr,
    /// Buffer size, if one applies (set for duplex connections).
    pub buffer_size: Option<usize>,
}
#[cfg(not(feature = "stream"))]
/// Metadata describing a connection.
///
/// Without the `stream` feature there is no default address type.
#[derive(Debug, Clone)]
pub struct ConnectionInfo<Addr> {
    /// Application protocol, if known.
    pub protocol: Option<Protocol>,
    /// URI authority associated with the connection, if any.
    pub authority: Option<Authority>,
    /// Address of this end of the connection.
    pub local_addr: Addr,
    /// Address of the peer.
    pub remote_addr: Addr,
    /// Buffer size, if one applies (set for duplex connections).
    pub buffer_size: Option<usize>,
}
impl<Addr> Default for ConnectionInfo<Addr>
where
Addr: Default,
{
fn default() -> Self {
Self {
protocol: None,
authority: None,
local_addr: Addr::default(),
remote_addr: Addr::default(),
buffer_size: None,
}
}
}
#[cfg(feature = "stream")]
impl ConnectionInfo<BraidAddr> {
    /// Describes an in-memory duplex connection identified by `name`.
    ///
    /// Both endpoints are [`BraidAddr::Duplex`], and `buffer_size` is
    /// recorded.
    pub(crate) fn duplex(name: Authority, protocol: Option<Protocol>, buffer_size: usize) -> Self {
        let authority = Some(name);
        let buffer_size = Some(buffer_size);
        Self {
            protocol,
            authority,
            local_addr: BraidAddr::Duplex,
            remote_addr: BraidAddr::Duplex,
            buffer_size,
        }
    }
}
#[cfg(not(feature = "stream"))]
impl ConnectionInfo<DuplexAddr> {
    /// Describes an in-memory duplex connection identified by `name`.
    ///
    /// Both endpoints get a fresh `DuplexAddr`, and `buffer_size` is
    /// recorded.
    pub(crate) fn duplex(name: Authority, protocol: Option<Protocol>, buffer_size: usize) -> Self {
        let local_addr = DuplexAddr::new();
        let remote_addr = DuplexAddr::new();
        Self {
            protocol,
            authority: Some(name),
            local_addr,
            remote_addr,
            buffer_size: Some(buffer_size),
        }
    }
}
impl<Addr> ConnectionInfo<Addr> {
    /// Returns the address of this end of the connection.
    pub fn local_addr(&self) -> &Addr {
        &self.local_addr
    }

    /// Returns the address of the peer.
    pub fn remote_addr(&self) -> &Addr {
        &self.remote_addr
    }

    /// Converts both addresses with `f`, keeping every other field as-is.
    ///
    /// `f` is called exactly twice: first on the local address, then on the
    /// remote address. It only needs to be `FnMut`, so stateful closures are
    /// accepted; any `Fn` closure still works as before.
    pub fn map<T, F>(self, mut f: F) -> ConnectionInfo<T>
    where
        F: FnMut(Addr) -> T,
    {
        // Bind in a fixed order so the local-then-remote call order is explicit.
        let local_addr = f(self.local_addr);
        let remote_addr = f(self.remote_addr);
        ConnectionInfo {
            protocol: self.protocol,
            authority: self.authority,
            local_addr,
            remote_addr,
            buffer_size: self.buffer_size,
        }
    }
}
impl<Addr> TryFrom<&TcpStream> for ConnectionInfo<Addr>
where
Addr: From<std::net::SocketAddr>,
{
type Error = io::Error;
fn try_from(stream: &TcpStream) -> Result<Self, Self::Error> {
let local_addr = stream.local_addr()?;
let remote_addr = stream.peer_addr()?;
Ok(Self {
protocol: None,
authority: None,
local_addr: local_addr.into(),
remote_addr: remote_addr.into(),
buffer_size: None,
})
}
}
impl<Addr> TryFrom<&UnixStream> for ConnectionInfo<Addr>
where
Addr: From<UnixAddr>,
{
type Error = io::Error;
fn try_from(stream: &UnixStream) -> Result<Self, Self::Error> {
let local_addr = match stream.local_addr() {
Ok(addr) => addr.try_into().expect("unix socket address"),
Err(e) if matches!(e.kind(), io::ErrorKind::InvalidInput) => UnixAddr::unnamed(),
Err(e) => return Err(e),
};
let remote_addr = match stream.peer_addr() {
Ok(addr) => addr.try_into().expect("unix socket address"),
Err(e) if matches!(e.kind(), io::ErrorKind::InvalidInput) => UnixAddr::unnamed(),
Err(e) => return Err(e),
};
Ok(Self {
protocol: None,
authority: None,
local_addr: local_addr.into(),
remote_addr: remote_addr.into(),
buffer_size: None,
})
}
}
/// Types that can describe the connection they represent.
pub trait HasConnectionInfo {
    /// Address type used for both endpoints; printable via `Display` and `Debug`.
    type Addr: fmt::Display + fmt::Debug;

    /// Returns metadata describing this connection.
    fn info(&self) -> ConnectionInfo<Self::Addr>;
}
impl HasConnectionInfo for TcpStream {
type Addr = std::net::SocketAddr;
fn info(&self) -> ConnectionInfo<Self::Addr> {
self.try_into()
.expect("connection info should be available")
}
}
impl HasConnectionInfo for UnixStream {
type Addr = UnixAddr;
fn info(&self) -> ConnectionInfo<Self::Addr> {
ConnectionInfo {
local_addr: self
.local_addr()
.expect("unable to get local address")
.try_into()
.expect("utf-8 unix socket address"),
remote_addr: self
.peer_addr()
.expect("unable to get peer address")
.try_into()
.expect("utf-8 unix socket address"),
..Default::default()
}
}
}
#[cfg(test)]
mod tests {
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};

    use http::Version;
    use tokio::net::{TcpListener, UnixListener};

    use super::*;

    #[test]
    fn protocol_display() {
        assert_eq!(Protocol::http(Version::HTTP_11).to_string(), "HTTP/1.1");
        assert_eq!(Protocol::http(Version::HTTP_2).to_string(), "HTTP/2.0");
        assert_eq!(Protocol::http(Version::HTTP_3).to_string(), "HTTP/3.0");
        assert_eq!(Protocol::http(Version::HTTP_10).to_string(), "HTTP/1.0");
        assert_eq!(Protocol::grpc().to_string(), "gRPC");
        assert_eq!(Protocol::web_socket().to_string(), "WebSocket");
        assert_eq!(Protocol::Other("foo".into()).to_string(), "foo");
    }

    #[test]
    fn parse_protocol() {
        assert_eq!(
            Protocol::from_str("http/0.9").unwrap(),
            Protocol::http(Version::HTTP_09)
        );
        assert_eq!(
            Protocol::from_str("http/1.1").unwrap(),
            Protocol::http(Version::HTTP_11)
        );
        assert_eq!(
            Protocol::from_str("h2").unwrap(),
            Protocol::http(Version::HTTP_2)
        );
        assert_eq!(
            Protocol::from_str("h3").unwrap(),
            Protocol::http(Version::HTTP_3)
        );
        assert_eq!(
            Protocol::from_str("http/1.0").unwrap(),
            Protocol::http(Version::HTTP_10)
        );
        assert_eq!(Protocol::from_str("gRPC").unwrap(), Protocol::grpc());
        assert_eq!(
            Protocol::from_str("WebSocket").unwrap(),
            Protocol::web_socket()
        );
        assert_eq!(
            Protocol::from_str("foo").unwrap(),
            Protocol::Other("foo".into())
        )
    }

    // `make_canonical` is only compiled with the `stream` feature, so this
    // test must be gated the same way or the test build breaks without it.
    #[cfg(feature = "stream")]
    #[test]
    fn test_make_canonical() {
        assert_eq!(
            make_canonical("[::1]:8080".parse().unwrap()),
            "[::1]:8080".parse().unwrap()
        );
        assert_eq!(
            make_canonical("[::ffff:192.0.2.128]:8080".parse().unwrap()),
            "192.0.2.128:8080".parse().unwrap()
        )
    }

    #[test]
    fn connection_info_default() {
        let info = ConnectionInfo::<DuplexAddr>::default();
        assert_eq!(info.protocol, None);
        assert_eq!(info.authority, None);
        assert_eq!(info.local_addr, DuplexAddr::new());
        assert_eq!(info.remote_addr, DuplexAddr::new());
        assert_eq!(info.buffer_size, None);
    }

    #[test]
    fn unix_addr() {
        let addr = UnixAddr::from_pathbuf("/tmp/foo.sock".into());
        assert!(addr.is_named());
        assert_eq!(addr.path(), Some("/tmp/foo.sock".into()));
        let addr = UnixAddr::unnamed();
        assert!(!addr.is_named());
        assert_eq!(addr.path(), None);
    }

    #[test]
    fn unix_addr_display() {
        assert_eq!(
            UnixAddr::from_pathbuf("/tmp/foo.sock".into()).to_string(),
            "unix:///tmp/foo.sock"
        );
        assert_eq!(UnixAddr::unnamed().to_string(), "unix://");
    }

    #[test]
    fn connection_info_map() {
        let info = ConnectionInfo {
            protocol: Some(Protocol::http(Version::HTTP_11)),
            authority: Some("example.com".parse().unwrap()),
            local_addr: "local",
            remote_addr: "remote",
            buffer_size: Some(1024),
        };
        let mapped = info.map(|addr| addr.to_string());
        assert_eq!(mapped.protocol, Some(Protocol::http(Version::HTTP_11)));
        assert_eq!(mapped.authority, Some("example.com".parse().unwrap()));
        assert_eq!(mapped.local_addr, "local".to_string());
        assert_eq!(mapped.remote_addr, "remote".to_string());
        assert_eq!(mapped.buffer_size, Some(1024));
    }

    #[tokio::test]
    async fn unix_connection_info_unnamed() {
        // Keep the peer alive so the peer-address lookup sees a connected socket.
        let (a, _b) = UnixStream::pair().expect("pair");
        let info: ConnectionInfo<UnixAddr> = ConnectionInfo::try_from(&a).unwrap();
        assert_eq!(info.local_addr(), &UnixAddr::unnamed());
        assert_eq!(info.remote_addr(), &UnixAddr::unnamed());
    }

    #[tokio::test]
    async fn unix_connection_info_named() {
        let tmp = tempfile::TempDir::with_prefix("unix-connection-info").unwrap();
        let path = tmp.path().join("socket.sock");
        let listener = UnixListener::bind(&path).unwrap();
        let conn = UnixStream::connect(&path).await.unwrap();
        let info: ConnectionInfo<UnixAddr> = ConnectionInfo::try_from(&conn).unwrap();
        assert_eq!(
            info.remote_addr(),
            &UnixAddr::from_pathbuf(path.try_into().unwrap())
        );
        drop(listener);
    }

    #[tokio::test]
    async fn tcp_connection_info() {
        // Bind loopback explicitly: connecting to the unspecified address
        // (0.0.0.0) is not portable across platforms.
        let listener = TcpListener::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
            .await
            .unwrap();
        let addr = listener.local_addr().unwrap();
        let conn = TcpStream::connect(addr).await.unwrap();
        let info: ConnectionInfo<std::net::SocketAddr> = ConnectionInfo::try_from(&conn).unwrap();
        assert_eq!(info.remote_addr().ip(), IpAddr::V4(Ipv4Addr::LOCALHOST));
        assert_eq!(info.remote_addr().port(), addr.port());
    }
}