#![warn(missing_docs)]
#![warn(missing_debug_implementations)]
#![deny(unsafe_code)]
use core::fmt;
use std::borrow::Cow;
use std::io;
use std::sync::Arc;
use hyperdriver::bridge::rt::TokioExecutor;
use hyperdriver::client::conn::protocol::auto;
use hyperdriver::client::conn::Stream as ClientStream;
use hyperdriver::info::UnixAddr;
use hyperdriver::server::AutoBuilder;
use hyperdriver::stream::UnixStream;
use pidfile::PidFile;
use camino::{Utf8Path, Utf8PathBuf};
use dashmap::mapref::one::{Ref, RefMut};
use dashmap::DashMap;
use hyper::Uri;
use tower::make::Shared;
mod transport;
use hyperdriver::client::Client;
pub use transport::GrpcScheme;
pub use transport::RegistryTransport;
pub use transport::Scheme;
pub use transport::SvcScheme;
pub use transport::TransportBuilder;
#[derive(Debug, thiserror::Error)]
pub enum ConnectionError {
#[error("Invalid name: {0}")]
InvalidName(String),
#[error("Connection to {0} timed out")]
ConnectionTimeout(String, #[source] tokio::time::error::Elapsed),
#[error("Invalid URI: {0}")]
InvalidUri(Uri),
#[error("Handshake with {name}")]
Handshake {
#[source]
error: io::Error,
name: String,
},
#[error("Error {} connecting to {name} over a duplex socket", .error.kind())]
Duplex {
#[source]
error: io::Error,
name: String,
},
#[error("Error {} connecting to {name} at {path}", .error.kind())]
Unix {
#[source]
error: io::Error,
path: Utf8PathBuf,
name: String,
},
}
/// Crate-internal reason a bind attempt failed; [`BindError`] wraps it together with the
/// service name.
#[derive(Debug)]
pub(crate) enum InternalBindError {
    /// The service is already bound: its in-process acceptor was already taken, this process
    /// already holds its PID lock, or its socket address is in use.
    AlreadyBound,
    /// The service socket at the given path could not be removed or bound.
    SocketResetError(Utf8PathBuf, io::Error),
    /// The PID file at the given path could not be locked.
    PidLockError(Utf8PathBuf, io::Error),
}
#[derive(Debug)]
pub struct BindError {
service: String,
inner: InternalBindError,
}
impl BindError {
fn new(service: String, inner: InternalBindError) -> Self {
Self { service, inner }
}
}
impl fmt::Display for BindError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match &self.inner {
InternalBindError::AlreadyBound => {
write!(f, "Service {} is already bound", self.service)
}
InternalBindError::SocketResetError(path, error) => {
write!(
f,
"Service {}: Unable to reset socket at {}: {}",
self.service, path, error
)
}
InternalBindError::PidLockError(path, error) => {
write!(
f,
"Service {}: Unable to lock PID file at {}: {}",
self.service, path, error
)
}
}
}
}
impl std::error::Error for BindError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match &self.inner {
InternalBindError::AlreadyBound => None,
InternalBindError::SocketResetError(_, error) => Some(error),
InternalBindError::PidLockError(_, error) => Some(error),
}
}
}
/// How the registry locates services.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[cfg_attr(feature = "serde", serde(rename_all = "lowercase"))]
pub enum ServiceDiscovery {
    /// Services live in this process and are reached over in-memory duplex streams (the default).
    #[default]
    InProcess,
    /// Services are reached over Unix domain sockets in a shared directory.
    Unix {
        /// Directory holding each service's `{service}.svc` socket and `{service}.pid` file.
        path: Utf8PathBuf,
    },
}
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[cfg_attr(feature = "serde", serde(default))]
pub struct RegistryConfig {
pub service_discovery: ServiceDiscovery,
#[cfg_attr(feature = "serde", serde(with = "humantime_serde"))]
pub connect_timeout: Option<std::time::Duration>,
pub buffer_size: usize,
#[cfg_attr(feature = "serde", serde(with = "humantime_serde"))]
pub proxy_timeout: std::time::Duration,
pub proxy_limit: usize,
}
impl Default for RegistryConfig {
fn default() -> Self {
Self {
service_discovery: Default::default(),
connect_timeout: None,
buffer_size: 1024 * 1024,
proxy_timeout: std::time::Duration::from_secs(30),
proxy_limit: 32,
}
}
}
/// A registry of named services, reachable in-process or over Unix domain sockets.
///
/// Clones share the same set of services. The configuration is behind its own `Arc` and is
/// modified through `Arc::make_mut`, so a setter called on one clone does not change the others.
#[derive(Clone, Default)]
pub struct ServiceRegistry {
    /// Per-service handles, shared by all clones.
    inner: Arc<InnerRegistry>,
    /// Configuration for this registry instance.
    config: Arc<RegistryConfig>,
}

impl std::fmt::Debug for ServiceRegistry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only the configuration is printed; the service table is not shown.
        let mut out = f.debug_struct("ServiceRegistry");
        out.field("config", &self.config);
        out.finish()
    }
}
impl ServiceRegistry {
    /// Creates an empty registry with the default [`RegistryConfig`].
    pub fn new() -> Self {
        Self::new_with_config(RegistryConfig::default())
    }

    /// Creates an empty registry that uses `config`.
    pub fn new_with_config(config: RegistryConfig) -> Self {
        Self {
            inner: Arc::new(InnerRegistry::default()),
            config: Arc::new(config),
        }
    }

    /// Returns a mutable reference to this registry's configuration.
    ///
    /// `Arc::make_mut` clones the configuration first if another registry clone shares it, so
    /// changes only affect this instance. The set of services stays shared.
    #[inline]
    fn config_mut(&mut self) -> &mut RegistryConfig {
        Arc::make_mut(&mut self.config)
    }

    /// Sets how services are discovered.
    ///
    /// This only affects services whose handle has not been created yet. A name that has
    /// already been checked, bound or connected keeps the handle it was created with.
    pub fn set_discovery(&mut self, discovery: ServiceDiscovery) {
        self.config_mut().service_discovery = discovery;
    }

    /// Sets the timeout for establishing connections in [`ServiceRegistry::connect`].
    pub fn set_connect_timeout(&mut self, timeout: std::time::Duration) {
        self.config_mut().connect_timeout = Some(timeout);
    }

    /// Sets the buffer size used for in-process (duplex) connections.
    pub fn set_buffer_size(&mut self, size: usize) {
        self.config_mut().buffer_size = size;
    }

    /// Returns whether a server appears to be bound for `service`.
    ///
    /// For in-process services this is `true` once [`ServiceRegistry::bind`] has taken the
    /// acceptor. For Unix services it reflects whether the service's PID file is locked.
    /// Checking a name creates a handle for it if none exists yet.
    pub fn is_available<S: AsRef<str>>(&self, service: S) -> bool {
        self.inner.is_available(&self.config, service.as_ref())
    }

    /// Binds `service` and returns an acceptor for its incoming connections.
    ///
    /// # Errors
    ///
    /// Returns a [`BindError`] if the service is already bound, or if its PID file or socket
    /// cannot be set up.
    #[tracing::instrument(skip_all, fields(service=tracing::field::Empty))]
    pub async fn bind<'a, S>(
        &'a self,
        service: S,
    ) -> Result<hyperdriver::server::conn::Acceptor, BindError>
    where
        S: Into<Cow<'a, str>>,
    {
        let name = service.into();
        let span = tracing::Span::current();
        span.record("service", name.as_ref());
        self.inner
            .bind(&self.config, &name)
            .map_err(|err| BindError::new(name.into_owned(), err))
    }

    /// Binds `name` and returns a server that runs `make_service` on `executor`, using
    /// hyperdriver's automatic HTTP protocol builder.
    ///
    /// # Errors
    ///
    /// Returns a [`BindError`] under the same conditions as [`ServiceRegistry::bind`].
    pub async fn server<'a, S, M, B, E>(
        &'a self,
        make_service: M,
        name: S,
        executor: E,
    ) -> Result<
        hyperdriver::server::Server<
            hyperdriver::server::conn::Acceptor,
            AutoBuilder<TokioExecutor>,
            M,
            B,
            E,
        >,
        BindError,
    >
    where
        S: Into<Cow<'a, str>>,
    {
        let acceptor = self.bind(name.into()).await?;
        Ok(hyperdriver::server::Server::builder()
            .with_acceptor(acceptor)
            .with_auto_http()
            .with_make_service(make_service)
            .with_executor(executor))
    }

    /// Returns a server on `acceptor` that forwards every request it receives to this
    /// registry's [`ServiceRegistry::client`].
    pub fn router<A, B, E>(
        &self,
        acceptor: A,
        executor: E,
    ) -> hyperdriver::server::Server<A, AutoBuilder<TokioExecutor>, Shared<Client>, B, E>
    where
        A: hyperdriver::server::conn::Accept,
        B: http_body::Body,
    {
        hyperdriver::server::Server::builder()
            .with_acceptor(acceptor)
            .with_auto_http()
            .with_shared_service(self.client())
            .with_executor(executor)
    }

    /// Opens a raw client stream to `service`.
    ///
    /// If a connect timeout is configured, it bounds the attempt. Otherwise the attempt is not
    /// bounded; a warning is logged if it takes longer than 30 seconds.
    ///
    /// # Errors
    ///
    /// Returns a [`ConnectionError`] if the connection cannot be established or times out.
    #[tracing::instrument(skip_all, fields(service=tracing::field::Empty))]
    pub async fn connect<'a, S: Into<Cow<'a, str>>>(
        &'a self,
        service: S,
    ) -> Result<ClientStream, ConnectionError> {
        let service = service.into();
        let span = tracing::Span::current();
        span.record("service", service.as_ref());
        self.inner.connect(&self.config, service).await
    }

    /// Returns a transport over this registry that uses the default URI schemes.
    pub fn default_transport(&self) -> transport::RegistryTransport {
        transport::RegistryTransport::with_default_schemes(self.clone())
    }

    /// Returns a builder for a custom transport over this registry.
    pub fn transport_builder(&self) -> transport::TransportBuilder {
        transport::RegistryTransport::builder(self.clone())
    }

    /// Builds an HTTP client (without TLS) that reaches services through
    /// [`ServiceRegistry::default_transport`].
    pub fn client(&self) -> Client {
        let transport = self.default_transport();
        Client::builder()
            .with_transport(transport)
            .with_protocol(auto::HttpConnectionBuilder::default())
            .with_pool(Default::default())
            .without_tls()
            .build()
    }
}
#[derive(Debug)]
struct InnerRegistry {
services: DashMap<String, ServiceHandle>,
}
impl Default for InnerRegistry {
fn default() -> Self {
Self {
services: DashMap::new(),
}
}
}
impl InnerRegistry {
fn get_mut(&self, config: &RegistryConfig, service: &str) -> RefMut<'_, String, ServiceHandle> {
self.services
.entry(service.to_owned())
.or_insert_with(|| match &config.service_discovery {
ServiceDiscovery::InProcess => ServiceHandle::duplex(),
ServiceDiscovery::Unix { path } => ServiceHandle::unix(path, service),
})
}
fn get(&self, config: &RegistryConfig, service: &str) -> Ref<'_, String, ServiceHandle> {
if let Some(handle) = self.services.get(service) {
handle
} else {
self.get_mut(config, service).downgrade()
}
}
fn is_available(&self, config: &RegistryConfig, service: &str) -> bool {
let handle = self.get(config, service);
handle.is_available()
}
#[tracing::instrument(skip(self, config))]
async fn connect(
&self,
config: &RegistryConfig,
service: Cow<'_, str>,
) -> Result<ClientStream, ConnectionError> {
let handle = self.get(config, service.as_ref());
connect_to_handle(config, handle.value(), service).await
}
fn bind(
&self,
config: &RegistryConfig,
service: &str,
) -> Result<hyperdriver::server::conn::Acceptor, InternalBindError> {
let mut handle = self.get_mut(config, service);
handle.acceptor()
}
}
#[derive(Debug)]
enum PidLock {
Path(Utf8PathBuf),
#[allow(dead_code)]
Lock(PidFile),
}
impl PidLock {
fn is_available(&self) -> bool {
tracing::trace!("Checking PID file {self:?}");
match self {
PidLock::Path(path) => PidFile::is_locked(path.as_std_path())
.map_err(|error| tracing::warn!("Unable to inspect PID file: {error:?}"))
.unwrap_or(false),
PidLock::Lock(_) => true,
}
}
}
/// Connection state for a single named service.
#[derive(Debug)]
enum ServiceHandle {
    /// An in-process service reached over an in-memory duplex channel.
    Duplex {
        /// Server half; `Some` until the service is bound, then taken.
        acceptor: Option<hyperdriver::server::conn::Acceptor>,
        /// Client half used to open new connections.
        connector: hyperdriver::stream::duplex::DuplexClient,
    },
    /// A service reached over a Unix domain socket.
    Unix {
        /// Path of the service's socket (`{service}.svc`).
        path: Utf8PathBuf,
        /// The service's PID file; becomes `PidLock::Lock` once this process binds the service.
        pidfile: PidLock,
    },
}
impl ServiceHandle {
fn duplex() -> Self {
let (connector, acceptor) = hyperdriver::stream::duplex::pair();
Self::Duplex {
acceptor: Some(acceptor.into()),
connector,
}
}
fn unix(path: &Utf8Path, service: &str) -> Self {
let svcpath = path.join(format!("{service}.svc"));
let pidfile = path.join(format!("{service}.pid"));
Self::Unix {
path: svcpath,
pidfile: PidLock::Path(pidfile),
}
}
fn is_available(&self) -> bool {
match self {
ServiceHandle::Duplex { acceptor, .. } => acceptor.is_none(),
ServiceHandle::Unix { pidfile, .. } => pidfile.is_available(),
}
}
async fn connect(
&self,
config: &RegistryConfig,
name: Cow<'_, str>,
) -> Result<hyperdriver::client::conn::Stream, ConnectionError> {
match self {
ServiceHandle::Duplex { connector, .. } => Ok(connector
.connect(config.buffer_size)
.await
.map(|stream| stream.into())
.map_err(|error| ConnectionError::Duplex {
error,
name: name.into_owned(),
}))?,
ServiceHandle::Unix { path, .. } => tokio::net::UnixStream::connect(path)
.await
.map(|stream| {
UnixStream::new(stream, Some(UnixAddr::from_pathbuf(path.clone()))).into()
})
.map_err(|error| ConnectionError::Unix {
error,
path: path.into(),
name: name.into_owned(),
}),
}
}
fn acceptor(&mut self) -> Result<hyperdriver::server::conn::Acceptor, InternalBindError> {
match self {
ServiceHandle::Duplex { acceptor, .. } => {
tracing::trace!("Preparing in-process acceptor");
acceptor.take().ok_or(InternalBindError::AlreadyBound)
}
ServiceHandle::Unix { ref path, pidfile } => {
tracing::trace!("Locking PID file");
let file = match pidfile {
PidLock::Path(ref path) => PidFile::new(path.clone()).map_err(|err| {
tracing::warn!(
"Encountered an error resetting the Pid file {path}: {}",
err
);
InternalBindError::PidLockError(path.clone(), err)
})?,
PidLock::Lock(_) => {
tracing::warn!("Service is already bound in this process");
return Err(InternalBindError::AlreadyBound);
}
};
*pidfile = PidLock::Lock(file);
tracing::trace!("Binding to socket at {path}");
if let Err(error) = std::fs::remove_file(path) {
match error.kind() {
io::ErrorKind::NotFound => {}
_ => {
tracing::error!("Unable to remove socket: {:#}", error);
return Err(InternalBindError::SocketResetError(path.clone(), error));
}
}
}
tokio::net::UnixListener::bind(path)
.map(|listener| listener.into())
.map_err(|error| match error.kind() {
io::ErrorKind::AddrInUse => {
tracing::warn!("Service is already bound");
InternalBindError::AlreadyBound
}
_ => {
tracing::error!("Unable to bind socket: {:#}", error);
InternalBindError::SocketResetError(path.clone(), error)
}
})
}
}
}
}
/// Connects to `handle`, honouring `config.connect_timeout`.
///
/// With a configured timeout, the attempt is abandoned once the timeout elapses and
/// [`ConnectionError::ConnectionTimeout`] is returned. Without one, the attempt may run
/// indefinitely. A warning is logged after 30 seconds so that a stalled connection is visible,
/// but the attempt keeps going.
///
/// `name` is used for logging and error reporting only.
async fn connect_to_handle(
    config: &RegistryConfig,
    handle: &ServiceHandle,
    name: Cow<'_, str>,
) -> Result<ClientStream, ConnectionError> {
    let request = handle.connect(config, name.clone());

    if let Some(timeout) = config.connect_timeout {
        tracing::trace!("Waiting for connection to {name} with timeout");
        return match tokio::time::timeout(timeout, request).await {
            Ok(outcome) => outcome,
            Err(elapsed) => {
                // Log the configured duration; `Elapsed` itself carries no timing information.
                tracing::warn!("Connection to {name} timed out after {timeout:?}");
                Err(ConnectionError::ConnectionTimeout(
                    name.into_owned(),
                    elapsed,
                ))
            }
        };
    }

    tracing::trace!("Waiting for connection to {name} without timeout");
    // Pinned so the same in-flight attempt can be resumed after the warning deadline passes.
    tokio::pin!(request);
    let warn_after = std::time::Duration::from_secs(30);
    match tokio::time::timeout(warn_after, &mut request).await {
        Ok(outcome) => outcome,
        Err(_) => {
            tracing::warn!(
                "Waited {}s without a timeout for connection to {name}... continuing",
                warn_after.as_secs()
            );
            request.await
        }
    }
}
#[cfg(test)]
mod tests {
    use hyperdriver::info::{BraidAddr, HasConnectionInfo as _};

    use super::*;

    /// A fresh Unix handle points at `{name}.svc`, is not available, and has not taken the
    /// PID lock.
    #[test]
    fn test_service_handle() {
        let tmp = tempfile::tempdir().unwrap();
        let name = "service.with.dots";
        let handle = ServiceHandle::unix(tmp.path().try_into().unwrap(), name);
        assert!(!handle.is_available());
        let ServiceHandle::Unix { path, pidfile } = handle else {
            panic!("expected unix handle")
        };
        let expected =
            Utf8PathBuf::from_path_buf(tmp.path().join(format!("{}.svc", name))).unwrap();
        assert_eq!(path, expected);
        assert!(matches!(pidfile, PidLock::Path(_)));
    }

    /// Connecting to a bound Unix handle reports the socket path as the remote address.
    #[tokio::test]
    async fn connect_to_handle_unix() {
        let tmp = tempfile::tempdir().unwrap();
        let name = "service.with.dots";
        let mut handle = ServiceHandle::unix(tmp.path().try_into().unwrap(), name);
        let _accept = handle.acceptor().unwrap();
        let config = RegistryConfig::default();
        let name = Cow::Borrowed(name);
        let stream = connect_to_handle(&config, &handle, name.clone())
            .await
            .unwrap();
        let info = stream.info();
        let remote = info.remote_addr();
        match remote {
            BraidAddr::Unix(addr) => {
                assert_eq!(
                    addr.path().unwrap(),
                    tmp.path().join(format!("{}.svc", name))
                );
            }
            _ => panic!("expected Unix address"),
        }
    }

    /// Connecting to a Unix handle whose socket does not exist fails with `NotFound` and
    /// reports the expected socket path and service name.
    #[tokio::test]
    async fn connect_to_handle_unix_error() {
        let tmp = tempfile::tempdir().unwrap();
        let name = "service.with.dots";
        let handle = ServiceHandle::unix(tmp.path().try_into().unwrap(), name);
        let config = RegistryConfig::default();
        let result = connect_to_handle(&config, &handle, Cow::Borrowed(name)).await;
        match result.unwrap_err() {
            // Bind the reported name separately so it is compared against the requested one,
            // rather than shadowing it and comparing it with itself.
            ConnectionError::Unix {
                error,
                path,
                name: reported,
            } => {
                assert_eq!(error.kind(), io::ErrorKind::NotFound);
                assert_eq!(path, tmp.path().join(format!("{}.svc", name)));
                assert_eq!(reported, name);
            }
            _ => panic!("expected Unix error"),
        }
    }
}