1#![warn(missing_docs)]
6#![warn(missing_debug_implementations)]
7#![deny(unsafe_code)]
8
9use core::fmt;
10use std::borrow::Cow;
11use std::io;
12use std::sync::Arc;
13
14use hyperdriver::bridge::rt::TokioExecutor;
15use hyperdriver::client::conn::protocol::auto;
16
17use hyperdriver::client::conn::Stream as ClientStream;
18use hyperdriver::info::UnixAddr;
19use hyperdriver::server::AutoBuilder;
20use hyperdriver::stream::UnixStream;
21use pidfile::PidFile;
22
23use camino::{Utf8Path, Utf8PathBuf};
24use dashmap::mapref::one::{Ref, RefMut};
25use dashmap::DashMap;
26use hyper::Uri;
27use tower::make::Shared;
28
29mod transport;
30
31use hyperdriver::client::Client;
32pub use transport::GrpcScheme;
33pub use transport::RegistryTransport;
34pub use transport::Scheme;
35pub use transport::SvcScheme;
36pub use transport::TransportBuilder;
37
/// Errors that can occur while connecting to a service through the registry.
#[derive(Debug, thiserror::Error)]
pub enum ConnectionError {
    /// The service name was rejected; the payload is the offending name.
    ///
    /// NOTE(review): not constructed in the code visible here — presumably
    /// produced by the transport layer; confirm.
    #[error("Invalid name: {0}")]
    InvalidName(String),

    /// The connection was not established before the configured timeout.
    ///
    /// Carries the service name and the [`tokio::time::error::Elapsed`] error,
    /// which is exposed as the error source.
    #[error("Connection to {0} timed out")]
    ConnectionTimeout(String, #[source] tokio::time::error::Elapsed),

    /// The URI was rejected; the payload is the offending URI.
    ///
    /// NOTE(review): not constructed in the code visible here — presumably
    /// produced by the transport layer; confirm.
    #[error("Invalid URI: {0}")]
    InvalidUri(Uri),

    /// An I/O error was reported for the handshake with the named service.
    ///
    /// NOTE(review): not constructed in the code visible here; confirm where
    /// the handshake happens.
    #[error("Handshake with {name}")]
    Handshake {
        /// The underlying I/O error.
        #[source]
        error: io::Error,

        /// Name of the service being connected to.
        name: String,
    },

    /// Connecting to an in-process service over a duplex socket failed.
    #[error("Error {} connecting to {name} over a duplex socket", .error.kind())]
    Duplex {
        /// The underlying I/O error.
        #[source]
        error: io::Error,

        /// Name of the service being connected to.
        name: String,
    },

    /// Connecting to a service over its Unix domain socket failed.
    #[error("Error {} connecting to {name} at {path}", .error.kind())]
    Unix {
        /// The underlying I/O error.
        #[source]
        error: io::Error,

        /// Path of the socket the connection was attempted on.
        path: Utf8PathBuf,

        /// Name of the service being connected to.
        name: String,
    },
}
89
/// The reason a bind attempt failed, without the service name attached.
///
/// Wrapped by [`BindError`], which adds the name of the service and provides
/// the `Display` and `Error` implementations.
#[derive(Debug)]
pub(crate) enum InternalBindError {
    /// The acceptor for the service was already handed out, the PID lock is
    /// already held by this handle, or the socket address is in use.
    AlreadyBound,

    /// Removing the old socket file, or binding the listener, failed at the
    /// given path.
    SocketResetError(Utf8PathBuf, io::Error),

    /// Creating the PID file lock at the given path failed.
    PidLockError(Utf8PathBuf, io::Error),
}
102
/// Error returned when a service could not be bound in the registry.
///
/// Combines the name of the service with the internal failure reason; the
/// reason's I/O error (if any) is available through `Error::source`.
#[derive(Debug)]
pub struct BindError {
    // Name of the service that was being bound.
    service: String,
    // What went wrong while binding.
    inner: InternalBindError,
}
109
110impl BindError {
111 fn new(service: String, inner: InternalBindError) -> Self {
112 Self { service, inner }
113 }
114}
115
116impl fmt::Display for BindError {
117 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
118 match &self.inner {
119 InternalBindError::AlreadyBound => {
120 write!(f, "Service {} is already bound", self.service)
121 }
122 InternalBindError::SocketResetError(path, error) => {
123 write!(
124 f,
125 "Service {}: Unable to reset socket at {}: {}",
126 self.service, path, error
127 )
128 }
129 InternalBindError::PidLockError(path, error) => {
130 write!(
131 f,
132 "Service {}: Unable to lock PID file at {}: {}",
133 self.service, path, error
134 )
135 }
136 }
137 }
138}
139
140impl std::error::Error for BindError {
141 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
142 match &self.inner {
143 InternalBindError::AlreadyBound => None,
144 InternalBindError::SocketResetError(_, error) => Some(error),
145 InternalBindError::PidLockError(_, error) => Some(error),
146 }
147 }
148}
149
/// How the registry locates and exposes services.
///
/// With the `serde` feature the variant names are (de)serialized in lowercase.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[cfg_attr(feature = "serde", serde(rename_all = "lowercase"))]
pub enum ServiceDiscovery {
    /// Services live in this process and are reached over in-memory duplex
    /// streams. This is the default.
    #[default]
    InProcess,

    /// Services are reached over Unix domain sockets.
    Unix {
        /// Directory holding one `<service>.svc` socket and one
        /// `<service>.pid` PID file per service.
        path: Utf8PathBuf,
    },
}
165
/// Configuration for a [`ServiceRegistry`].
///
/// With the `serde` feature, missing fields fall back to the values from
/// [`RegistryConfig::default`].
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[cfg_attr(feature = "serde", serde(default))]
pub struct RegistryConfig {
    /// How services are located (in-process or via Unix sockets).
    pub service_discovery: ServiceDiscovery,

    /// Maximum time to wait for a connection to a service.
    ///
    /// When `None`, connection attempts wait indefinitely and a warning is
    /// logged after 30 seconds.
    #[cfg_attr(feature = "serde", serde(with = "humantime_serde"))]
    pub connect_timeout: Option<std::time::Duration>,

    /// Buffer size passed to the duplex connector for in-process connections.
    // NOTE(review): presumably in bytes — confirm against hyperdriver's duplex API.
    pub buffer_size: usize,

    /// Timeout used when proxying.
    // NOTE(review): not read by the code visible here; presumably consumed by
    // the transport module — confirm exact semantics.
    #[cfg_attr(feature = "serde", serde(with = "humantime_serde"))]
    pub proxy_timeout: std::time::Duration,

    /// Limit used when proxying.
    // NOTE(review): not read by the code visible here; presumably a
    // concurrency limit for the transport module — confirm.
    pub proxy_limit: usize,
}
188
189impl Default for RegistryConfig {
190 fn default() -> Self {
191 Self {
192 service_discovery: Default::default(),
193 connect_timeout: None,
194 buffer_size: 1024 * 1024,
195 proxy_timeout: std::time::Duration::from_secs(30),
196 proxy_limit: 32,
197 }
198 }
199}
200
/// A registry of named services that can be bound (served) and connected to.
///
/// Cloning is cheap: clones share the same table of services. The
/// configuration is also shared until one of the setters is called on a
/// clone, at which point that clone gets its own copy.
#[derive(Clone, Default)]
pub struct ServiceRegistry {
    // Shared table of service handles.
    inner: Arc<InnerRegistry>,
    // Configuration; copied on write by the setters via `Arc::make_mut`.
    config: Arc<RegistryConfig>,
}
208
209impl std::fmt::Debug for ServiceRegistry {
210 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
211 f.debug_struct("ServiceRegistry")
212 .field("config", &self.config)
213 .finish()
214 }
215}
216
217impl ServiceRegistry {
218 pub fn new() -> Self {
220 Self {
221 inner: Arc::new(InnerRegistry::default()),
222 config: Arc::new(RegistryConfig::default()),
223 }
224 }
225
226 pub fn new_with_config(config: RegistryConfig) -> Self {
228 Self {
229 inner: Arc::new(InnerRegistry::default()),
230 config: Arc::new(config),
231 }
232 }
233
234 #[inline]
235 fn config_mut(&mut self) -> &mut RegistryConfig {
236 Arc::make_mut(&mut self.config)
237 }
238
239 pub fn set_discovery(&mut self, discovery: ServiceDiscovery) {
241 self.config_mut().service_discovery = discovery;
242 }
243
244 pub fn set_connect_timeout(&mut self, timeout: std::time::Duration) {
246 self.config_mut().connect_timeout = Some(timeout);
247 }
248
249 pub fn set_buffer_size(&mut self, size: usize) {
253 self.config_mut().buffer_size = size;
254 }
255
256 pub fn is_available<S: AsRef<str>>(&self, service: S) -> bool {
258 self.inner.is_available(&self.config, service.as_ref())
259 }
260
261 #[tracing::instrument(skip_all, fields(service=tracing::field::Empty))]
263 pub async fn bind<'a, S>(
264 &'a self,
265 service: S,
266 ) -> Result<hyperdriver::server::conn::Acceptor, BindError>
267 where
268 S: Into<Cow<'a, str>>,
269 {
270 let name = service.into();
271 let span = tracing::Span::current();
272 span.record("service", name.as_ref());
273
274 self.inner
275 .bind(&self.config, &name)
276 .map_err(|err| BindError::new(name.into_owned(), err))
277 }
278
279 pub async fn server<'a, S, M, B, E>(
281 &'a self,
282 make_service: M,
283 name: S,
284 executor: E,
285 ) -> Result<
286 hyperdriver::server::Server<
287 hyperdriver::server::conn::Acceptor,
288 AutoBuilder<TokioExecutor>,
289 M,
290 B,
291 E,
292 >,
293 BindError,
294 >
295 where
296 S: Into<Cow<'a, str>>,
297 {
298 let acceptor = self.bind(name.into()).await?;
299 Ok(hyperdriver::server::Server::builder()
300 .with_acceptor(acceptor)
301 .with_auto_http()
302 .with_make_service(make_service)
303 .with_executor(executor))
304 }
305
306 pub fn router<A, B, E>(
308 &self,
309 acceptor: A,
310 executor: E,
311 ) -> hyperdriver::server::Server<A, AutoBuilder<TokioExecutor>, Shared<Client>, B, E>
312 where
313 A: hyperdriver::server::conn::Accept,
314 B: http_body::Body,
315 {
316 hyperdriver::server::Server::builder()
317 .with_acceptor(acceptor)
318 .with_auto_http()
319 .with_shared_service(self.client())
320 .with_executor(executor)
321 }
322
323 #[tracing::instrument(skip_all, fields(service=tracing::field::Empty))]
327 pub async fn connect<'a, S: Into<Cow<'a, str>>>(
328 &'a self,
329 service: S,
330 ) -> Result<ClientStream, ConnectionError> {
331 let service = service.into();
332 let span = tracing::Span::current();
333 span.record("service", service.as_ref());
334
335 self.inner.connect(&self.config, service).await
336 }
337
338 pub fn default_transport(&self) -> transport::RegistryTransport {
343 transport::RegistryTransport::with_default_schemes(self.clone())
344 }
345
346 pub fn transport_builder(&self) -> transport::TransportBuilder {
348 transport::RegistryTransport::builder(self.clone())
349 }
350
351 pub fn client(&self) -> Client {
353 let transport = self.default_transport();
354
355 Client::builder()
356 .with_transport(transport)
357 .with_protocol(auto::HttpConnectionBuilder::default())
358 .with_pool(Default::default())
359 .without_tls()
360 .build()
361 }
362}
363
/// Shared state behind a [`ServiceRegistry`]: the table of known services.
#[derive(Debug)]
struct InnerRegistry {
    // Service name -> handle; entries are created on first lookup.
    services: DashMap<String, ServiceHandle>,
}
369
370impl Default for InnerRegistry {
371 fn default() -> Self {
372 Self {
373 services: DashMap::new(),
374 }
375 }
376}
377
378impl InnerRegistry {
379 fn get_mut(&self, config: &RegistryConfig, service: &str) -> RefMut<'_, String, ServiceHandle> {
380 self.services
381 .entry(service.to_owned())
382 .or_insert_with(|| match &config.service_discovery {
383 ServiceDiscovery::InProcess => ServiceHandle::duplex(),
384 ServiceDiscovery::Unix { path } => ServiceHandle::unix(path, service),
385 })
386 }
387
388 fn get(&self, config: &RegistryConfig, service: &str) -> Ref<'_, String, ServiceHandle> {
389 if let Some(handle) = self.services.get(service) {
390 handle
391 } else {
392 self.get_mut(config, service).downgrade()
393 }
394 }
395
396 fn is_available(&self, config: &RegistryConfig, service: &str) -> bool {
397 let handle = self.get(config, service);
398 handle.is_available()
399 }
400
401 #[tracing::instrument(skip(self, config))]
403 async fn connect(
404 &self,
405 config: &RegistryConfig,
406 service: Cow<'_, str>,
407 ) -> Result<ClientStream, ConnectionError> {
408 let handle = self.get(config, service.as_ref());
409
410 connect_to_handle(config, handle.value(), service).await
411 }
412
413 fn bind(
415 &self,
416 config: &RegistryConfig,
417 service: &str,
418 ) -> Result<hyperdriver::server::conn::Acceptor, InternalBindError> {
419 let mut handle = self.get_mut(config, service);
420
421 handle.acceptor()
422 }
423}
424
/// State of the PID file that guards a Unix-socket service.
#[derive(Debug)]
enum PidLock {
    /// The PID file has not been locked by this handle; only its path is known.
    Path(Utf8PathBuf),

    /// This handle holds the PID file.
    // The `PidFile` is never read, only kept alive, hence the lint allowance.
    // NOTE(review): presumably dropping it releases the lock — confirm against
    // the `pidfile` crate.
    #[allow(dead_code)]
    Lock(PidFile),
}
433
434impl PidLock {
435 fn is_available(&self) -> bool {
436 tracing::trace!("Checking PID file {self:?}");
437 match self {
438 PidLock::Path(path) => PidFile::is_locked(path.as_std_path())
439 .map_err(|error| tracing::warn!("Unable to inspect PID file: {error:?}"))
440 .unwrap_or(false),
441 PidLock::Lock(_) => true,
442 }
443 }
444}
445
/// Registry entry for a single service: everything needed to bind it and to
/// connect to it.
#[derive(Debug)]
enum ServiceHandle {
    /// An in-process service reached over duplex streams.
    Duplex {
        // Server side of the duplex pair; `None` once it has been handed out
        // by `acceptor()`.
        acceptor: Option<hyperdriver::server::conn::Acceptor>,
        // Client side of the duplex pair, used to open new connections.
        connector: hyperdriver::stream::duplex::DuplexClient,
    },
    /// A service reached over a Unix domain socket.
    Unix {
        // Path of the service's socket file (`<service>.svc`).
        path: Utf8PathBuf,
        // PID file (`<service>.pid`) guarding ownership of the socket.
        pidfile: PidLock,
    },
}
460
461impl ServiceHandle {
462 fn duplex() -> Self {
463 let (connector, acceptor) = hyperdriver::stream::duplex::pair();
464 Self::Duplex {
465 acceptor: Some(acceptor.into()),
466 connector,
467 }
468 }
469
470 fn unix(path: &Utf8Path, service: &str) -> Self {
471 let svcpath = path.join(format!("{service}.svc"));
472 let pidfile = path.join(format!("{service}.pid"));
473
474 Self::Unix {
475 path: svcpath,
476 pidfile: PidLock::Path(pidfile),
477 }
478 }
479
480 fn is_available(&self) -> bool {
481 match self {
482 ServiceHandle::Duplex { acceptor, .. } => acceptor.is_none(),
483 ServiceHandle::Unix { pidfile, .. } => pidfile.is_available(),
484 }
485 }
486
487 async fn connect(
488 &self,
489 config: &RegistryConfig,
490 name: Cow<'_, str>,
491 ) -> Result<hyperdriver::client::conn::Stream, ConnectionError> {
492 match self {
493 ServiceHandle::Duplex { connector, .. } => Ok(connector
494 .connect(config.buffer_size)
495 .await
496 .map(|stream| stream.into())
497 .map_err(|error| ConnectionError::Duplex {
498 error,
499 name: name.into_owned(),
500 }))?,
501 ServiceHandle::Unix { path, .. } => tokio::net::UnixStream::connect(path)
502 .await
503 .map(|stream| {
504 UnixStream::new(stream, Some(UnixAddr::from_pathbuf(path.clone()))).into()
505 })
506 .map_err(|error| ConnectionError::Unix {
507 error,
508 path: path.into(),
509 name: name.into_owned(),
510 }),
511 }
512 }
513
514 fn acceptor(&mut self) -> Result<hyperdriver::server::conn::Acceptor, InternalBindError> {
516 match self {
517 ServiceHandle::Duplex { acceptor, .. } => {
518 tracing::trace!("Preparing in-process acceptor");
519 acceptor.take().ok_or(InternalBindError::AlreadyBound)
520 }
521 ServiceHandle::Unix { ref path, pidfile } => {
522 tracing::trace!("Locking PID file");
523 let file = match pidfile {
524 PidLock::Path(ref path) => PidFile::new(path.clone()).map_err(|err| {
525 tracing::warn!(
526 "Encountered an error resetting the Pid file {path}: {}",
527 err
528 );
529 InternalBindError::PidLockError(path.clone(), err)
530 })?,
531 PidLock::Lock(_) => {
532 tracing::warn!("Service is already bound in this process");
533 return Err(InternalBindError::AlreadyBound);
534 }
535 };
536 *pidfile = PidLock::Lock(file);
537
538 tracing::trace!("Binding to socket at {path}");
539 if let Err(error) = std::fs::remove_file(path) {
540 match error.kind() {
541 io::ErrorKind::NotFound => {}
542 _ => {
543 tracing::error!("Unable to remove socket: {:#}", error);
544 return Err(InternalBindError::SocketResetError(path.clone(), error));
545 }
546 }
547 }
548
549 tokio::net::UnixListener::bind(path)
550 .map(|listener| listener.into())
551 .map_err(|error| match error.kind() {
552 io::ErrorKind::AddrInUse => {
553 tracing::warn!("Service is already bound");
554 InternalBindError::AlreadyBound
555 }
556 _ => {
557 tracing::error!("Unable to bind socket: {:#}", error);
558 InternalBindError::SocketResetError(path.clone(), error)
559 }
560 })
561 }
562 }
563 }
564}
565
566async fn connect_to_handle(
567 config: &RegistryConfig,
568 handle: &ServiceHandle,
569 name: Cow<'_, str>,
570) -> Result<ClientStream, ConnectionError> {
571 let request = handle.connect(config, name.clone());
572
573 let stream = if let Some(timeout) = &config.connect_timeout {
574 tracing::trace!("Waiting for connection to {name} with timeout");
575 match tokio::time::timeout(*timeout, request).await {
576 Ok(outcome) => outcome,
577 Err(elapsed) => {
578 tracing::warn!(
579 "Connection to {name} timed out after {timeout:?}",
580 name = name,
581 timeout = elapsed
582 );
583 return Err(ConnectionError::ConnectionTimeout(
584 name.into_owned(),
585 elapsed,
586 ));
587 }
588 }
589 } else {
590 tracing::trace!("Waiting for connection to {name} without timeout");
591
592 tokio::pin!(request);
595
596 let default_timeout = std::time::Duration::from_secs(30);
598 match tokio::time::timeout(default_timeout, &mut request).await {
599 Ok(Ok(stream)) => Ok(stream),
600 Err(_) => {
601 tracing::warn!(
602 "Waited {}s without a timeout for connection to {name}... continuing",
603 default_timeout.as_secs()
604 );
605 request.await
606 }
607 Ok(Err(error)) => Err(error),
608 }
609 }?;
610
611 Ok(stream)
612}
613
#[cfg(test)]
mod tests {
    use hyperdriver::info::{BraidAddr, HasConnectionInfo as _};

    use super::*;

    /// A fresh Unix handle points at `<dir>/<name>.svc`, holds an unlocked PID
    /// path, and is not available.
    #[test]
    fn test_service_handle() {
        let tmp = tempfile::tempdir().unwrap();
        let name = "service.with.dots";
        let handle = ServiceHandle::unix(tmp.path().try_into().unwrap(), name);

        assert!(!handle.is_available());

        let ServiceHandle::Unix { path, pidfile } = handle else {
            panic!("expected unix handle")
        };

        let expected =
            Utf8PathBuf::from_path_buf(tmp.path().join(format!("{}.svc", name))).unwrap();

        assert_eq!(path, expected);
        assert!(matches!(pidfile, PidLock::Path(_)));
    }

    /// Connecting to a bound Unix handle yields a stream whose remote address
    /// is the service's socket path.
    #[tokio::test]
    async fn connect_to_handle_unix() {
        let tmp = tempfile::tempdir().unwrap();
        let name = "service.with.dots";
        let mut handle = ServiceHandle::unix(tmp.path().try_into().unwrap(), name);

        // Keep the acceptor alive so the socket stays bound for the test.
        let _accept = handle.acceptor().unwrap();

        let config = RegistryConfig::default();
        let name = Cow::Borrowed(name);

        let stream = connect_to_handle(&config, &handle, name.clone())
            .await
            .unwrap();

        let info = stream.info();
        let remote = info.remote_addr();
        match remote {
            BraidAddr::Unix(addr) => {
                assert_eq!(
                    addr.path().unwrap(),
                    tmp.path().join(format!("{}.svc", name))
                );
            }
            _ => panic!("expected Unix address"),
        }
    }

    /// Connecting to an unbound Unix handle reports `NotFound` together with
    /// the socket path and the service name.
    #[tokio::test]
    async fn connect_to_handle_unix_error() {
        let tmp = tempfile::tempdir().unwrap();
        let service = "service.with.dots";
        let handle = ServiceHandle::unix(tmp.path().try_into().unwrap(), service);

        let config = RegistryConfig::default();

        let result = connect_to_handle(&config, &handle, Cow::Borrowed(service)).await;

        match result.unwrap_err() {
            ConnectionError::Unix { error, path, name } => {
                assert_eq!(error.kind(), io::ErrorKind::NotFound);
                // Compare against the requested service name, not the value
                // carried by the error, so a wrong name or path is caught.
                assert_eq!(path, tmp.path().join(format!("{}.svc", service)));
                assert_eq!(name, service);
            }
            _ => panic!("expected Unix error"),
        }
    }
}