use std::borrow::Cow;
use std::collections::BTreeMap;
use std::fmt;
use std::sync::Arc;
use camino::Utf8Path;
use futures_util::future::BoxFuture;
use futures_util::FutureExt as _;
use hyper::Uri;
use super::ConnectionError;
use super::ServiceRegistry;
use hyperdriver::client::conn::Stream as ClientStream;
pub use builder::TransportBuilder;
/// Maps request URIs that use a particular URI scheme to a service name.
///
/// `RegistryTransport` stores implementations keyed by the string returned from
/// [`Scheme::scheme`] and passes the result of [`Scheme::service`] to the
/// registry when connecting.
pub trait Scheme {
/// The URI scheme name this implementation is registered under.
fn scheme(&self) -> Cow<'_, str>;
/// Extract the service name from `uri`, borrowing from the URI.
///
/// Returns `None` when no service name can be found in the URI.
fn service<'u>(&self, uri: &'u Uri) -> Option<&'u str>;
}
/// Boxed `Scheme` trait object (`Send + Sync + 'static`), as stored in the scheme map.
type BoxScheme = Box<dyn Scheme + Sync + Send + 'static>;
/// Transport that resolves a request URI to a service name via a registered
/// [`Scheme`] and then connects to that service through a [`ServiceRegistry`].
#[derive(Clone)]
pub struct RegistryTransport {
// Registry used to open the connection once a service name is known.
registry: ServiceRegistry,
// Scheme handlers keyed by scheme name; held in an `Arc` so clones share one map.
schemes: Arc<BTreeMap<String, BoxScheme>>,
}
/// Debug adapter for a scheme map that prints only the keys (the scheme
/// names) as a list, so the stored values do not need to implement `Debug`.
struct SchemesDebug<'a, S>(&'a BTreeMap<String, S>);

impl<'a, S> fmt::Debug for SchemesDebug<'a, S> {
    /// Formats the map's keys, in the map's iteration order, as a debug list.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_list().entries(self.0.keys()).finish()
    }
}
impl fmt::Debug for RegistryTransport {
    /// Prints the registry and the names of the registered schemes; the
    /// scheme handlers themselves are not printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("RegistryTransport")
            .field("registry", &self.registry)
            .field("schemes", &SchemesDebug(&self.schemes))
            .finish()
    }
}
impl RegistryTransport {
    /// Start building a transport backed by `registry`.
    ///
    /// The returned builder has no schemes registered yet.
    pub fn builder(registry: ServiceRegistry) -> builder::TransportBuilder {
        let schemes = BTreeMap::new();
        builder::TransportBuilder { registry, schemes }
    }

    /// Build a transport backed by `registry` with the default schemes
    /// (as added by `TransportBuilder::add_default_schemes`) registered.
    pub fn with_default_schemes(registry: ServiceRegistry) -> Self {
        let with_defaults = Self::builder(registry).add_default_schemes();
        with_defaults.build()
    }

    /// Borrow the registry this transport connects through.
    pub fn registry(&self) -> &ServiceRegistry {
        let Self { registry, .. } = self;
        registry
    }
}
impl tower::Service<http::request::Parts> for RegistryTransport {
    type Response = ClientStream;
    type Error = ConnectionError;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    /// Always reports ready.
    fn poll_ready(
        &mut self,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        std::task::Poll::Ready(Ok(()))
    }

    /// Resolve the request URI to a service name and connect to it.
    ///
    /// The URI's scheme selects a registered [`Scheme`], which then extracts
    /// the service name from the URI. The returned future calls `connect` on
    /// a clone of the registry with that name.
    ///
    /// # Errors
    ///
    /// The future resolves to `ConnectionError::InvalidUri` (carrying the
    /// request URI) when the URI has no scheme, the scheme is not registered,
    /// or the scheme cannot extract a service name.
    fn call(&mut self, req: http::request::Parts) -> Self::Future {
        // Both lookups borrow `req.uri`; the name is copied out so the URI
        // can still be moved into the error on the failure path.
        let service = req
            .uri
            .scheme_str()
            .and_then(|name| self.schemes.get(name))
            .and_then(|scheme| scheme.service(&req.uri))
            .map(str::to_owned);

        match service {
            Some(service) => {
                // Clone only once a connection will actually be attempted;
                // the future must be `'static`, so it owns its registry handle.
                let registry = self.registry.clone();
                (async move { registry.connect(&service).await }).boxed()
            }
            None => futures_util::future::ready(Err(ConnectionError::InvalidUri(req.uri))).boxed(),
        }
    }
}
mod builder {
    use std::collections::BTreeMap;
    use std::fmt;
    use std::sync::Arc;

    use super::BoxScheme;
    use super::RegistryTransport;
    use super::Scheme;
    use super::SchemesDebug;
    use crate::ServiceRegistry;

    /// Builder that collects scheme handlers before producing a
    /// [`RegistryTransport`].
    pub struct TransportBuilder {
        pub(crate) registry: ServiceRegistry,
        pub(crate) schemes: BTreeMap<String, BoxScheme>,
    }

    impl TransportBuilder {
        /// Register `scheme` under the name returned by its `scheme()` method.
        ///
        /// Registering a second handler with the same name replaces the first.
        pub fn add_scheme<S>(mut self, scheme: S) -> Self
        where
            S: Scheme + Send + Sync + 'static,
        {
            // Take an owned copy of the name first so the borrow of `scheme`
            // ends before it is moved into the box.
            let name = scheme.scheme().into_owned();
            self.schemes.insert(name, Box::new(scheme));
            self
        }

        /// Register the default `SvcScheme` and `GrpcScheme` handlers.
        pub fn add_default_schemes(self) -> Self {
            let with_svc = self.add_scheme(super::SvcScheme::default());
            with_svc.add_scheme(super::GrpcScheme::default())
        }

        /// Finish building, moving the registry and the collected schemes
        /// into a [`RegistryTransport`].
        pub fn build(self) -> RegistryTransport {
            let Self { registry, schemes } = self;
            RegistryTransport {
                registry,
                schemes: Arc::new(schemes),
            }
        }
    }

    impl fmt::Debug for TransportBuilder {
        /// Prints the registry and the names of the schemes added so far.
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.debug_struct("TransportBuilder")
                .field("registry", &self.registry)
                .field("schemes", &SchemesDebug(&self.schemes))
                .finish()
        }
    }
}
/// Scheme handler that uses the URI host as the service name.
///
/// The wrapped string is the URI scheme name this handler is registered under.
#[derive(Debug)]
pub struct SvcScheme(String);

impl SvcScheme {
    /// Create a handler registered under the given scheme name.
    pub fn new<S>(scheme: S) -> Self
    where
        S: Into<String>,
    {
        Self(scheme.into())
    }
}

impl Default for SvcScheme {
    /// A handler registered under the scheme name `svc`.
    fn default() -> Self {
        Self::new("svc")
    }
}
impl Scheme for SvcScheme {
    /// The scheme name this handler was constructed with.
    ///
    /// Borrows from `self` rather than cloning the stored `String`; callers
    /// that need ownership can call `into_owned()` on the result.
    fn scheme(&self) -> Cow<'_, str> {
        Cow::Borrowed(self.0.as_str())
    }

    /// The service name is the URI host; `None` when the URI has no host.
    fn service<'u>(&self, uri: &'u Uri) -> Option<&'u str> {
        uri.host()
    }
}
/// Scheme handler that takes the service name from the URI path.
///
/// The wrapped string is the URI scheme name this handler is registered under.
#[derive(Debug)]
pub struct GrpcScheme(String);

impl GrpcScheme {
    /// Create a handler registered under the given scheme name.
    pub fn new<S>(scheme: S) -> Self
    where
        S: Into<String>,
    {
        Self(scheme.into())
    }
}

impl Default for GrpcScheme {
    /// A handler registered under the scheme name `grpc`.
    fn default() -> Self {
        Self::new("grpc")
    }
}
impl Scheme for GrpcScheme {
    /// The scheme name this handler was constructed with.
    ///
    /// Borrows from `self` rather than cloning the stored `String`; callers
    /// that need ownership can call `into_owned()` on the result.
    fn scheme(&self) -> Cow<'_, str> {
        Cow::Borrowed(self.0.as_str())
    }

    /// The service name is the second component of the URI path.
    ///
    /// For a path such as `/service/method` the first component is the root,
    /// so this yields `service`. Returns `None` when the path has fewer than
    /// two components.
    fn service<'u>(&self, uri: &'u Uri) -> Option<&'u str> {
        let path = Utf8Path::new(uri.path());
        path.components().nth(1).map(|c| c.as_str())
    }
}
// Tests for the scheme handlers and for `RegistryTransport` driven as a tower service.
#[cfg(test)]
mod tests {
use std::time::Duration;
use tower::{make::Shared, Service, ServiceExt};
use crate::{ConnectionError, ServiceDiscovery};
use hyperdriver::info::HasConnectionInfo;
use hyperdriver::IntoRequestParts as _;
use hyperdriver::{body::Body, info::BraidAddr};
use static_assertions::assert_impl_one;
use super::*;
type BoxError = Box<dyn std::error::Error + Sync + std::marker::Send + 'static>;
// Compile-time check that `RegistryTransport` implements hyperdriver's client `Transport` trait.
assert_impl_one!(RegistryTransport: hyperdriver::client::conn::Transport);
// The default svc scheme takes the service name from the URI host.
#[test]
fn test_svc_scheme() {
let scheme = SvcScheme::default();
let uri = "svc://service".parse().unwrap();
assert_eq!(scheme.service(&uri), Some("service"));
}
// The default grpc scheme takes the service name from the path, not from the host.
#[test]
fn test_grpc_scheme() {
let scheme = GrpcScheme::default();
let uri = "grpc://host/service/method".parse().unwrap();
assert_eq!(scheme.service(&uri), Some("service"));
}
// Minimal service that is always ready and answers every request with an empty-bodied response.
#[derive(Debug, Clone)]
struct Svc;
impl Service<http::Request<Body>> for Svc {
type Response = http::Response<Body>;
type Error = BoxError;
type Future = std::future::Ready<Result<Self::Response, Self::Error>>;
fn poll_ready(
&mut self,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
std::task::Poll::Ready(Ok(()))
}
fn call(&mut self, _req: http::Request<Body>) -> Self::Future {
let res = http::Response::new(Body::empty());
std::future::ready(Ok(res))
}
}
// Serves `Svc` under the name "service", connects to it through the transport
// with a `svc://service` URI, and checks the remote address reported by the stream.
#[tokio::test]
async fn test_transport() {
let registry = ServiceRegistry::new();
let svc = Svc;
let server = registry
.server(
Shared::new(svc),
"service",
hyperdriver::bridge::rt::TokioExecutor::new(),
)
.await
.unwrap();
// The oneshot channel is the graceful-shutdown signal for the spawned server.
let (tx, rx) = tokio::sync::oneshot::channel();
tokio::spawn(async move {
server
.with_graceful_shutdown(async move { rx.await.unwrap() })
.await
.unwrap();
});
let transport = RegistryTransport::with_default_schemes(registry);
let uri = "svc://service".into_request_parts();
let stream = transport.oneshot(uri).await.unwrap();
let info = stream.info();
assert_eq!(info.remote_addr(), &BraidAddr::Duplex);
// Signal the server to shut down now that the connection has been checked.
tx.send(()).unwrap();
}
// With no server registered and a zero connect timeout, connecting is expected
// to fail with `ConnectionTimeout`.
#[tokio::test]
async fn test_transport_not_found() {
let mut registry = ServiceRegistry::new();
registry.config_mut().connect_timeout = Some(Duration::ZERO);
let transport = RegistryTransport::with_default_schemes(registry);
let uri = "svc://service".into_request_parts();
let err = transport.oneshot(uri).await.unwrap_err();
assert!(matches!(err, ConnectionError::ConnectionTimeout(_, _)));
}
// With Unix service discovery pointed at a fresh temporary directory, connecting
// is expected to fail with a `Unix` error whose kind is `NotFound` and whose path
// is `<tmp>/<name>.svc`.
#[tokio::test]
async fn test_transport_unix_not_found() {
let tmp = tempfile::tempdir().unwrap();
let mut registry = ServiceRegistry::new();
let cfg = registry.config_mut();
cfg.connect_timeout = Some(Duration::ZERO);
cfg.service_discovery = ServiceDiscovery::Unix {
path: Utf8Path::from_path(tmp.path()).unwrap().to_owned(),
};
let transport = RegistryTransport::with_default_schemes(registry);
let uri = "svc://service".into_request_parts();
let err = transport.oneshot(uri).await.unwrap_err();
match err {
ConnectionError::Unix { error, path, name } => {
assert_eq!(error.kind(), std::io::ErrorKind::NotFound);
assert_eq!(path, tmp.path().join(format!("{}.svc", name)));
assert_eq!(name, "service");
}
_ => panic!("unexpected error: {:?}", err),
}
}
}