hydiscovery/
lib.rs

//! HyDiscovery provides service discovery either in-process, or via unix domain sockets and the filesystem.
2//!
3//! This works well for running a collection of microservices on a single machine.
4
5#![warn(missing_docs)]
6#![warn(missing_debug_implementations)]
7#![deny(unsafe_code)]
8
9use core::fmt;
10use std::borrow::Cow;
11use std::io;
12use std::sync::Arc;
13
14use hyperdriver::bridge::rt::TokioExecutor;
15use hyperdriver::client::conn::protocol::auto;
16
17use hyperdriver::client::conn::Stream as ClientStream;
18use hyperdriver::info::UnixAddr;
19use hyperdriver::server::AutoBuilder;
20use hyperdriver::stream::UnixStream;
21use pidfile::PidFile;
22
23use camino::{Utf8Path, Utf8PathBuf};
24use dashmap::mapref::one::{Ref, RefMut};
25use dashmap::DashMap;
26use hyper::Uri;
27use tower::make::Shared;
28
29mod transport;
30
31use hyperdriver::client::Client;
32pub use transport::GrpcScheme;
33pub use transport::RegistryTransport;
34pub use transport::Scheme;
35pub use transport::SvcScheme;
36pub use transport::TransportBuilder;
37
/// An error occurred while connecting to a service.
///
/// Variants which wrap an underlying error expose it through
/// [`std::error::Error::source`] (via `#[source]`).
#[derive(Debug, thiserror::Error)]
pub enum ConnectionError {
    /// The service name was not a valid authority (e.g. `svc://foo`)
    #[error("Invalid name: {0}")]
    InvalidName(String),

    /// Connection to the service timed out.
    ///
    /// Carries the service name and the timer error produced by tokio.
    #[error("Connection to {0} timed out")]
    ConnectionTimeout(String, #[source] tokio::time::error::Elapsed),

    /// The service URI is not a valid URI.
    #[error("Invalid URI: {0}")]
    InvalidUri(Uri),

    /// An IO error occurred while handshaking with the service.
    #[error("Handshake with {name}")]
    Handshake {
        /// Internal error.
        #[source]
        error: io::Error,

        /// Service name.
        name: String,
    },

    /// Error connecting to a duplex socket
    ///
    /// The display message includes the [`io::ErrorKind`] of the internal error.
    #[error("Error {} connecting to {name} over a duplex socket", .error.kind())]
    Duplex {
        /// Internal error.
        #[source]
        error: io::Error,

        /// Service name.
        name: String,
    },

    /// An IO error occurred while connecting to the service.
    ///
    /// The display message includes the [`io::ErrorKind`] of the internal error.
    #[error("Error {} connecting to {name} at {path}", .error.kind())]
    Unix {
        /// Internal IO error.
        #[source]
        error: io::Error,

        /// Path to unix socket.
        path: Utf8PathBuf,

        /// Service name.
        name: String,
    },
}
89
/// Internal error when something goes wrong during Bind.
///
/// Doesn't require the name, it will be added in context
/// farther up.
#[derive(Debug)]
pub(crate) enum InternalBindError {
    /// The service already has an acceptor handed out, or its socket is in use.
    AlreadyBound,

    /// The unix socket at the given path could not be removed or bound.
    SocketResetError(Utf8PathBuf, io::Error),

    /// The PID file at the given path could not be locked.
    PidLockError(Utf8PathBuf, io::Error),
}
102
/// An error occurred binding this service to the specified name.
///
/// Pairs the internal failure reason with the name of the service
/// which was being bound.
#[derive(Debug)]
pub struct BindError {
    // Name of the service which failed to bind.
    service: String,
    // What went wrong while binding.
    inner: InternalBindError,
}
109
110impl BindError {
111    fn new(service: String, inner: InternalBindError) -> Self {
112        Self { service, inner }
113    }
114}
115
116impl fmt::Display for BindError {
117    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
118        match &self.inner {
119            InternalBindError::AlreadyBound => {
120                write!(f, "Service {} is already bound", self.service)
121            }
122            InternalBindError::SocketResetError(path, error) => {
123                write!(
124                    f,
125                    "Service {}: Unable to reset socket at {}: {}",
126                    self.service, path, error
127                )
128            }
129            InternalBindError::PidLockError(path, error) => {
130                write!(
131                    f,
132                    "Service {}: Unable to lock PID file at {}: {}",
133                    self.service, path, error
134                )
135            }
136        }
137    }
138}
139
140impl std::error::Error for BindError {
141    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
142        match &self.inner {
143            InternalBindError::AlreadyBound => None,
144            InternalBindError::SocketResetError(_, error) => Some(error),
145            InternalBindError::PidLockError(_, error) => Some(error),
146        }
147    }
148}
149
/// Service discovery mechanism for services registered.
///
/// Defaults to [`ServiceDiscovery::InProcess`].
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[cfg_attr(feature = "serde", serde(rename_all = "lowercase"))]
pub enum ServiceDiscovery {
    /// Discover services in the same process, using an in-memory store and transport.
    #[default]
    InProcess,

    /// Discover services by looking for a well-known unix socket.
    ///
    /// Each service uses a `{service}.svc` socket and a `{service}.pid` PID file
    /// inside the configured directory.
    Unix {
        /// Path to the directory containing the unix sockets.
        path: Utf8PathBuf,
    },
}
165
/// Configuration for the service registry.
///
/// With the `serde` feature enabled, fields missing during deserialization
/// fall back to the values from [`RegistryConfig::default`].
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[cfg_attr(feature = "serde", serde(default))]
pub struct RegistryConfig {
    /// Service discovery mechanism.
    pub service_discovery: ServiceDiscovery,

    /// Connection timeout when finding a service.
    ///
    /// When `None` (the default), connection attempts wait indefinitely, and a
    /// warning is logged if the connection takes longer than 30 seconds.
    #[cfg_attr(feature = "serde", serde(with = "humantime_serde"))]
    pub connect_timeout: Option<std::time::Duration>,

    /// Buffer size for in-memory transports.
    ///
    /// Defaults to 1 MiB (`1024 * 1024`).
    pub buffer_size: usize,

    /// Proxy service timeout
    ///
    /// Defaults to 30 seconds.
    #[cfg_attr(feature = "serde", serde(with = "humantime_serde"))]
    pub proxy_timeout: std::time::Duration,

    /// Proxy concurrency limit
    ///
    /// Defaults to 32.
    pub proxy_limit: usize,
}
188
189impl Default for RegistryConfig {
190    fn default() -> Self {
191        Self {
192            service_discovery: Default::default(),
193            connect_timeout: None,
194            buffer_size: 1024 * 1024,
195            proxy_timeout: std::time::Duration::from_secs(30),
196            proxy_limit: 32,
197        }
198    }
199}
200
/// Maintains the set of available services, and the connection
/// configurations for those services.
///
/// Both fields are reference counted, so cloning a registry does not copy the
/// set of services: clones share the same underlying service handles.
#[derive(Clone, Default)]
pub struct ServiceRegistry {
    // Shared table of service handles.
    inner: Arc<InnerRegistry>,
    // Registry configuration; mutated copy-on-write via `Arc::make_mut`.
    config: Arc<RegistryConfig>,
}
208
209impl std::fmt::Debug for ServiceRegistry {
210    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
211        f.debug_struct("ServiceRegistry")
212            .field("config", &self.config)
213            .finish()
214    }
215}
216
217impl ServiceRegistry {
218    /// Create a new registry with default configuration.
219    pub fn new() -> Self {
220        Self {
221            inner: Arc::new(InnerRegistry::default()),
222            config: Arc::new(RegistryConfig::default()),
223        }
224    }
225
226    /// Create a new registry with the specified configuration.
227    pub fn new_with_config(config: RegistryConfig) -> Self {
228        Self {
229            inner: Arc::new(InnerRegistry::default()),
230            config: Arc::new(config),
231        }
232    }
233
234    #[inline]
235    fn config_mut(&mut self) -> &mut RegistryConfig {
236        Arc::make_mut(&mut self.config)
237    }
238
239    /// Set the service discovery mechanism.
240    pub fn set_discovery(&mut self, discovery: ServiceDiscovery) {
241        self.config_mut().service_discovery = discovery;
242    }
243
244    /// Set the connection timeout for finding a service.
245    pub fn set_connect_timeout(&mut self, timeout: std::time::Duration) {
246        self.config_mut().connect_timeout = Some(timeout);
247    }
248
249    /// Set the buffer size for in-memory transports.
250    ///
251    /// See [`hyperdriver::stream::duplex::DuplexClient::connect`] for more information.
252    pub fn set_buffer_size(&mut self, size: usize) {
253        self.config_mut().buffer_size = size;
254    }
255
256    /// Check if a service is available, by name.
257    pub fn is_available<S: AsRef<str>>(&self, service: S) -> bool {
258        self.inner.is_available(&self.config, service.as_ref())
259    }
260
261    /// Get an acceptor which will be bound to a service with this name.
262    #[tracing::instrument(skip_all, fields(service=tracing::field::Empty))]
263    pub async fn bind<'a, S>(
264        &'a self,
265        service: S,
266    ) -> Result<hyperdriver::server::conn::Acceptor, BindError>
267    where
268        S: Into<Cow<'a, str>>,
269    {
270        let name = service.into();
271        let span = tracing::Span::current();
272        span.record("service", name.as_ref());
273
274        self.inner
275            .bind(&self.config, &name)
276            .map_err(|err| BindError::new(name.into_owned(), err))
277    }
278
279    /// Create a server which will bind to a service by name.
280    pub async fn server<'a, S, M, B, E>(
281        &'a self,
282        make_service: M,
283        name: S,
284        executor: E,
285    ) -> Result<
286        hyperdriver::server::Server<
287            hyperdriver::server::conn::Acceptor,
288            AutoBuilder<TokioExecutor>,
289            M,
290            B,
291            E,
292        >,
293        BindError,
294    >
295    where
296        S: Into<Cow<'a, str>>,
297    {
298        let acceptor = self.bind(name.into()).await?;
299        Ok(hyperdriver::server::Server::builder()
300            .with_acceptor(acceptor)
301            .with_auto_http()
302            .with_make_service(make_service)
303            .with_executor(executor))
304    }
305
306    /// Create a server which will use a registry transport to proxy requests to services.
307    pub fn router<A, B, E>(
308        &self,
309        acceptor: A,
310        executor: E,
311    ) -> hyperdriver::server::Server<A, AutoBuilder<TokioExecutor>, Shared<Client>, B, E>
312    where
313        A: hyperdriver::server::conn::Accept,
314        B: http_body::Body,
315    {
316        hyperdriver::server::Server::builder()
317            .with_acceptor(acceptor)
318            .with_auto_http()
319            .with_shared_service(self.client())
320            .with_executor(executor)
321    }
322
323    /// Connect to a service by name.
324    ///
325    /// Prefer using `client` instead of this method.
326    #[tracing::instrument(skip_all, fields(service=tracing::field::Empty))]
327    pub async fn connect<'a, S: Into<Cow<'a, str>>>(
328        &'a self,
329        service: S,
330    ) -> Result<ClientStream, ConnectionError> {
331        let service = service.into();
332        let span = tracing::Span::current();
333        span.record("service", service.as_ref());
334
335        self.inner.connect(&self.config, service).await
336    }
337
338    /// Create a transport for internal services, with default schemes.
339    ///
340    /// The default schemes are `grpc` and `svc`. `svc` uses the host to determine the service, and `grpc` uses the
341    /// first path component, and is suitable for gRPC services.
342    pub fn default_transport(&self) -> transport::RegistryTransport {
343        transport::RegistryTransport::with_default_schemes(self.clone())
344    }
345
346    /// Create a transport builder for internal services.
347    pub fn transport_builder(&self) -> transport::TransportBuilder {
348        transport::RegistryTransport::builder(self.clone())
349    }
350
351    /// Create a client which will connect to internal services.
352    pub fn client(&self) -> Client {
353        let transport = self.default_transport();
354
355        Client::builder()
356            .with_transport(transport)
357            .with_protocol(auto::HttpConnectionBuilder::default())
358            .with_pool(Default::default())
359            .without_tls()
360            .build()
361    }
362}
363
364/// Maintains the set of available services
365#[derive(Debug)]
366struct InnerRegistry {
367    services: DashMap<String, ServiceHandle>,
368}
369
370impl Default for InnerRegistry {
371    fn default() -> Self {
372        Self {
373            services: DashMap::new(),
374        }
375    }
376}
377
378impl InnerRegistry {
379    fn get_mut(&self, config: &RegistryConfig, service: &str) -> RefMut<'_, String, ServiceHandle> {
380        self.services
381            .entry(service.to_owned())
382            .or_insert_with(|| match &config.service_discovery {
383                ServiceDiscovery::InProcess => ServiceHandle::duplex(),
384                ServiceDiscovery::Unix { path } => ServiceHandle::unix(path, service),
385            })
386    }
387
388    fn get(&self, config: &RegistryConfig, service: &str) -> Ref<'_, String, ServiceHandle> {
389        if let Some(handle) = self.services.get(service) {
390            handle
391        } else {
392            self.get_mut(config, service).downgrade()
393        }
394    }
395
396    fn is_available(&self, config: &RegistryConfig, service: &str) -> bool {
397        let handle = self.get(config, service);
398        handle.is_available()
399    }
400
401    /// Connect to a service by name.
402    #[tracing::instrument(skip(self, config))]
403    async fn connect(
404        &self,
405        config: &RegistryConfig,
406        service: Cow<'_, str>,
407    ) -> Result<ClientStream, ConnectionError> {
408        let handle = self.get(config, service.as_ref());
409
410        connect_to_handle(config, handle.value(), service).await
411    }
412
413    /// Bind to a service by name.
414    fn bind(
415        &self,
416        config: &RegistryConfig,
417        service: &str,
418    ) -> Result<hyperdriver::server::conn::Acceptor, InternalBindError> {
419        let mut handle = self.get_mut(config, service);
420
421        handle.acceptor()
422    }
423}
424
/// Represents a discovered service which uses a PID file to lock binding the service.
#[derive(Debug)]
enum PidLock {
    /// Location of a PID file which this handle has not locked.
    Path(Utf8PathBuf),

    /// PID file locked by this handle.
    // The value is never read; presumably it is kept only so the lock lives
    // as long as the handle - verify against `PidFile`'s drop behavior.
    #[allow(dead_code)]
    Lock(PidFile),
}
433
434impl PidLock {
435    fn is_available(&self) -> bool {
436        tracing::trace!("Checking PID file {self:?}");
437        match self {
438            PidLock::Path(path) => PidFile::is_locked(path.as_std_path())
439                .map_err(|error| tracing::warn!("Unable to inspect PID file: {error:?}"))
440                .unwrap_or(false),
441            PidLock::Lock(_) => true,
442        }
443    }
444}
445
/// Handle to a service for creating new connections
///
/// This is the type held internally by the registry for a service.
#[derive(Debug)]
enum ServiceHandle {
    /// In-process service reached over an in-memory duplex transport.
    Duplex {
        // Server half; `None` once it has been handed out by `acceptor()`.
        acceptor: Option<hyperdriver::server::conn::Acceptor>,
        // Client half, used to open new connections.
        connector: hyperdriver::stream::duplex::DuplexClient,
    },
    /// Service reached over a unix domain socket.
    Unix {
        // Path to the service's socket.
        path: Utf8PathBuf,
        // PID file guarding who may bind the socket.
        pidfile: PidLock,
    },
}
460
461impl ServiceHandle {
462    fn duplex() -> Self {
463        let (connector, acceptor) = hyperdriver::stream::duplex::pair();
464        Self::Duplex {
465            acceptor: Some(acceptor.into()),
466            connector,
467        }
468    }
469
470    fn unix(path: &Utf8Path, service: &str) -> Self {
471        let svcpath = path.join(format!("{service}.svc"));
472        let pidfile = path.join(format!("{service}.pid"));
473
474        Self::Unix {
475            path: svcpath,
476            pidfile: PidLock::Path(pidfile),
477        }
478    }
479
480    fn is_available(&self) -> bool {
481        match self {
482            ServiceHandle::Duplex { acceptor, .. } => acceptor.is_none(),
483            ServiceHandle::Unix { pidfile, .. } => pidfile.is_available(),
484        }
485    }
486
487    async fn connect(
488        &self,
489        config: &RegistryConfig,
490        name: Cow<'_, str>,
491    ) -> Result<hyperdriver::client::conn::Stream, ConnectionError> {
492        match self {
493            ServiceHandle::Duplex { connector, .. } => Ok(connector
494                .connect(config.buffer_size)
495                .await
496                .map(|stream| stream.into())
497                .map_err(|error| ConnectionError::Duplex {
498                    error,
499                    name: name.into_owned(),
500                }))?,
501            ServiceHandle::Unix { path, .. } => tokio::net::UnixStream::connect(path)
502                .await
503                .map(|stream| {
504                    UnixStream::new(stream, Some(UnixAddr::from_pathbuf(path.clone()))).into()
505                })
506                .map_err(|error| ConnectionError::Unix {
507                    error,
508                    path: path.into(),
509                    name: name.into_owned(),
510                }),
511        }
512    }
513
514    /// Create an acceptor for this service.
515    fn acceptor(&mut self) -> Result<hyperdriver::server::conn::Acceptor, InternalBindError> {
516        match self {
517            ServiceHandle::Duplex { acceptor, .. } => {
518                tracing::trace!("Preparing in-process acceptor");
519                acceptor.take().ok_or(InternalBindError::AlreadyBound)
520            }
521            ServiceHandle::Unix { ref path, pidfile } => {
522                tracing::trace!("Locking PID file");
523                let file = match pidfile {
524                    PidLock::Path(ref path) => PidFile::new(path.clone()).map_err(|err| {
525                        tracing::warn!(
526                            "Encountered an error resetting the Pid file {path}: {}",
527                            err
528                        );
529                        InternalBindError::PidLockError(path.clone(), err)
530                    })?,
531                    PidLock::Lock(_) => {
532                        tracing::warn!("Service is already bound in this process");
533                        return Err(InternalBindError::AlreadyBound);
534                    }
535                };
536                *pidfile = PidLock::Lock(file);
537
538                tracing::trace!("Binding to socket at {path}");
539                if let Err(error) = std::fs::remove_file(path) {
540                    match error.kind() {
541                        io::ErrorKind::NotFound => {}
542                        _ => {
543                            tracing::error!("Unable to remove socket: {:#}", error);
544                            return Err(InternalBindError::SocketResetError(path.clone(), error));
545                        }
546                    }
547                }
548
549                tokio::net::UnixListener::bind(path)
550                    .map(|listener| listener.into())
551                    .map_err(|error| match error.kind() {
552                        io::ErrorKind::AddrInUse => {
553                            tracing::warn!("Service is already bound");
554                            InternalBindError::AlreadyBound
555                        }
556                        _ => {
557                            tracing::error!("Unable to bind socket: {:#}", error);
558                            InternalBindError::SocketResetError(path.clone(), error)
559                        }
560                    })
561            }
562        }
563    }
564}
565
566async fn connect_to_handle(
567    config: &RegistryConfig,
568    handle: &ServiceHandle,
569    name: Cow<'_, str>,
570) -> Result<ClientStream, ConnectionError> {
571    let request = handle.connect(config, name.clone());
572
573    let stream = if let Some(timeout) = &config.connect_timeout {
574        tracing::trace!("Waiting for connection to {name} with timeout");
575        match tokio::time::timeout(*timeout, request).await {
576            Ok(outcome) => outcome,
577            Err(elapsed) => {
578                tracing::warn!(
579                    "Connection to {name} timed out after {timeout:?}",
580                    name = name,
581                    timeout = elapsed
582                );
583                return Err(ConnectionError::ConnectionTimeout(
584                    name.into_owned(),
585                    elapsed,
586                ));
587            }
588        }
589    } else {
590        tracing::trace!("Waiting for connection to {name} without timeout");
591
592        // Pin the request future so it can be polled in two places: once during
593        // the timeout, and once after the timeout.
594        tokio::pin!(request);
595
596        // Apply a default timeout so we can warn when a service is taking a long time
597        let default_timeout = std::time::Duration::from_secs(30);
598        match tokio::time::timeout(default_timeout, &mut request).await {
599            Ok(Ok(stream)) => Ok(stream),
600            Err(_) => {
601                tracing::warn!(
602                    "Waited {}s without a timeout for connection to {name}... continuing",
603                    default_timeout.as_secs()
604                );
605                request.await
606            }
607            Ok(Err(error)) => Err(error),
608        }
609    }?;
610
611    Ok(stream)
612}
613
#[cfg(test)]
mod tests {
    use hyperdriver::info::{BraidAddr, HasConnectionInfo as _};

    use super::*;

    /// A fresh unix handle is unavailable, and derives its socket path from
    /// the full service name (dots included).
    #[test]
    fn test_service_handle() {
        let tmp = tempfile::tempdir().unwrap();
        let name = "service.with.dots";
        let handle = ServiceHandle::unix(tmp.path().try_into().unwrap(), name);

        assert!(!handle.is_available());

        let ServiceHandle::Unix { path, pidfile } = handle else {
            panic!("expected unix handle")
        };

        let expected =
            Utf8PathBuf::from_path_buf(tmp.path().join(format!("{}.svc", name))).unwrap();

        assert_eq!(path, expected);
        assert!(matches!(pidfile, PidLock::Path(_)));
    }

    /// Connecting to a bound unix handle yields a stream whose remote address
    /// is the service socket.
    #[tokio::test]
    async fn connect_to_handle_unix() {
        let tmp = tempfile::tempdir().unwrap();
        let name = "service.with.dots";
        let mut handle = ServiceHandle::unix(tmp.path().try_into().unwrap(), name);

        // Keep the acceptor alive so the socket stays bound during the test.
        let _accept = handle.acceptor().unwrap();

        let config = RegistryConfig::default();
        let name = Cow::Borrowed(name);

        let stream = connect_to_handle(&config, &handle, name.clone())
            .await
            .unwrap();

        let info = stream.info();
        let remote = info.remote_addr();
        match remote {
            BraidAddr::Unix(addr) => {
                assert_eq!(
                    addr.path().unwrap(),
                    tmp.path().join(format!("{}.svc", name))
                );
            }
            _ => panic!("expected Unix address"),
        }
    }

    /// Connecting to an unbound unix handle reports the socket path and the
    /// service name in the error.
    #[tokio::test]
    async fn connect_to_handle_unix_error() {
        let tmp = tempfile::tempdir().unwrap();
        let service = "service.with.dots";
        let handle = ServiceHandle::unix(tmp.path().try_into().unwrap(), service);

        let config = RegistryConfig::default();

        let result = connect_to_handle(&config, &handle, Cow::Borrowed(service)).await;

        match result.unwrap_err() {
            ConnectionError::Unix { error, path, name } => {
                assert_eq!(error.kind(), io::ErrorKind::NotFound);
                assert_eq!(path, tmp.path().join(format!("{}.svc", service)));
                // Compare against the requested service name, not the binding itself.
                assert_eq!(name, service);
            }
            _ => panic!("expected Unix error"),
        }
    }
}