use crate::{
service_probe::{GrpcServiceProbe, GrpcServiceProbeConfig},
DnsResolver, LookupService, ServiceDefinition,
};
use anyhow::Context as _;
use http::Request;
use std::{
convert::TryInto,
net::SocketAddr,
task::{Context, Poll},
};
use tokio::time::Duration;
use tonic::transport::channel::Channel;
use tonic::transport::ClientTlsConfig;
use tonic::{body::Body, client::GrpcService};
use tower::Service;
/// Buffer size passed to `Channel::balance_channel`, i.e. the capacity of the
/// channel through which the service probe reports endpoint changes.
const GRPC_REPORT_ENDPOINTS_CHANNEL_SIZE: usize = 1024;
/// A tonic [`Channel`] whose endpoint set is fed by a background service probe.
///
/// Build one with [`LoadBalancedChannel::builder`]; the wrapped [`Channel`] can
/// be recovered with `Channel::from(..)` / `.into()`.
#[derive(Clone, Debug)]
pub struct LoadBalancedChannel(Channel);
impl From<LoadBalancedChannel> for Channel {
fn from(channel: LoadBalancedChannel) -> Self {
channel.0
}
}
impl LoadBalancedChannel {
    /// Starts configuring a [`LoadBalancedChannel`] for `service_definition`.
    ///
    /// The returned builder has no lookup service set. Unless one is supplied
    /// through [`LoadBalancedChannelBuilder::lookup_service`], building the
    /// channel falls back to a [`DnsResolver`] created from the system
    /// configuration.
    pub fn builder<S>(service_definition: S) -> LoadBalancedChannelBuilder<DnsResolver, S>
    where
        S: TryInto<ServiceDefinition> + Send + Sync + 'static,
        S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + Sync,
    {
        LoadBalancedChannelBuilder::<DnsResolver, S>::new_with_service(service_definition)
    }
}
/// Delegates every `tower::Service` operation to the wrapped [`Channel`] via its
/// [`GrpcService`] implementation; the associated types are taken verbatim
/// from that implementation.
impl Service<Request<Body>> for LoadBalancedChannel {
    type Response = http::Response<<Channel as GrpcService<Body>>::ResponseBody>;
    type Error = <Channel as GrpcService<Body>>::Error;
    type Future = <Channel as GrpcService<Body>>::Future;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let inner = &mut self.0;
        <Channel as GrpcService<Body>>::poll_ready(inner, cx)
    }

    fn call(&mut self, request: Request<Body>) -> Self::Future {
        let inner = &mut self.0;
        <Channel as GrpcService<Body>>::call(inner, request)
    }
}
/// Controls whether building a channel resolves the service's endpoints up front.
///
/// Derives the common value traits so the strategy can be logged, compared and
/// reused across several builders.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResolutionStrategy {
    /// Return the channel immediately and let the background probe discover
    /// endpoints after the channel has been created.
    Lazy,
    /// Run one probe before returning the channel. Building fails if that probe
    /// errors or does not finish within `timeout`.
    Eager {
        /// Upper bound on how long the initial probe may take.
        timeout: Duration,
    },
}
/// Builder for [`LoadBalancedChannel`].
///
/// `T` is the lookup service used to resolve the target. It starts out as
/// [`DnsResolver`] and is replaced by `lookup_service`. `S` is anything that can
/// be converted into a [`ServiceDefinition`].
pub struct LoadBalancedChannelBuilder<T, S> {
    /// Target service, converted into a `ServiceDefinition` when the channel is built.
    service_definition: S,
    /// Explicit lookup service; `None` means a system-configured `DnsResolver`
    /// is created when the channel is built.
    lookup_service: Option<T>,
    /// Whether the initial resolution happens before the channel is returned.
    resolution_strategy: ResolutionStrategy,
    /// Interval between re-resolutions; `None` selects the built-in default.
    probe_interval: Option<Duration>,
    /// Per-endpoint timeout; also used as the connect timeout when
    /// `connect_timeout` is unset.
    timeout: Option<Duration>,
    /// Per-endpoint connect timeout.
    connect_timeout: Option<Duration>,
    /// TLS settings; the domain name is replaced by the service hostname on build.
    tls_config: Option<ClientTlsConfig>,
}
impl<S> LoadBalancedChannelBuilder<DnsResolver, S>
where
    S: TryInto<ServiceDefinition> + 'static,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + Sync,
{
    /// Creates a builder for `service_definition`.
    ///
    /// Every optional setting starts out unset, no lookup service is stored,
    /// and the resolution strategy is [`ResolutionStrategy::Lazy`].
    pub fn new_with_service(service_definition: S) -> LoadBalancedChannelBuilder<DnsResolver, S> {
        LoadBalancedChannelBuilder {
            service_definition,
            lookup_service: None,
            resolution_strategy: ResolutionStrategy::Lazy,
            probe_interval: None,
            timeout: None,
            connect_timeout: None,
            tls_config: None,
        }
    }

    /// Replaces the default DNS resolution with a custom [`LookupService`].
    ///
    /// All other settings already configured on the builder carry over unchanged.
    pub fn lookup_service<T: LookupService + Send + Sync + 'static>(
        self,
        lookup_service: T,
    ) -> LoadBalancedChannelBuilder<T, S> {
        let LoadBalancedChannelBuilder {
            service_definition,
            probe_interval,
            resolution_strategy,
            timeout,
            connect_timeout,
            tls_config,
            lookup_service: _,
        } = self;

        LoadBalancedChannelBuilder {
            service_definition,
            probe_interval,
            resolution_strategy,
            timeout,
            connect_timeout,
            tls_config,
            lookup_service: Some(lookup_service),
        }
    }
}
impl<T: LookupService + Send + Sync + 'static + Sized, S> LoadBalancedChannelBuilder<T, S>
where
    S: TryInto<ServiceDefinition> + 'static,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + Sync,
{
    /// Sets how often the background probe re-resolves the service.
    ///
    /// When never called, [`channel`](Self::channel) uses a 10 second interval.
    pub fn dns_probe_interval(self, interval: Duration) -> LoadBalancedChannelBuilder<T, S> {
        Self {
            probe_interval: Some(interval),
            ..self
        }
    }

    /// Sets the per-endpoint timeout handed to the service probe.
    ///
    /// It also serves as the connect timeout unless
    /// [`connect_timeout`](Self::connect_timeout) is set explicitly.
    pub fn timeout(self, timeout: Duration) -> LoadBalancedChannelBuilder<T, S> {
        Self {
            timeout: Some(timeout),
            ..self
        }
    }

    /// Sets the per-endpoint connect timeout handed to the service probe.
    ///
    /// For connection setup it takes precedence over [`timeout`](Self::timeout).
    pub fn connect_timeout(self, connection_timeout: Duration) -> LoadBalancedChannelBuilder<T, S> {
        Self {
            connect_timeout: Some(connection_timeout),
            ..self
        }
    }

    /// Chooses whether [`channel`](Self::channel) resolves endpoints before
    /// returning ([`ResolutionStrategy::Eager`]) or not ([`ResolutionStrategy::Lazy`]).
    pub fn resolution_strategy(
        self,
        resolution_strategy: ResolutionStrategy,
    ) -> LoadBalancedChannelBuilder<T, S> {
        Self {
            resolution_strategy,
            ..self
        }
    }

    /// Enables TLS for connections made by the channel.
    ///
    /// When the channel is built, the config's domain name is overwritten with
    /// the service definition's hostname.
    pub fn with_tls(self, tls_config: ClientTlsConfig) -> LoadBalancedChannelBuilder<T, S> {
        Self {
            tls_config: Some(tls_config),
            ..self
        }
    }

    /// Builds the [`LoadBalancedChannel`] and spawns its background service probe.
    ///
    /// Uses the lookup service supplied via `lookup_service`. Without one, it
    /// uses a [`DnsResolver`] created from the system configuration.
    ///
    /// # Errors
    ///
    /// Fails in any of these cases:
    /// - the system DNS configuration cannot be loaded (only when no lookup
    ///   service was supplied);
    /// - the service definition cannot be converted into a [`ServiceDefinition`];
    /// - with [`ResolutionStrategy::Eager`], the initial probe fails or exceeds
    ///   its timeout.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime, because the probe is started
    /// with `tokio::spawn`.
    pub async fn channel(mut self) -> Result<LoadBalancedChannel, anyhow::Error> {
        match self.lookup_service.take() {
            Some(lookup_service) => self.channel_inner(lookup_service).await,
            None => {
                let resolver = DnsResolver::from_system_config()
                    .await
                    .context("failed to create a DNS resolver from the system configuration")?;
                self.channel_inner(resolver).await
            }
        }
    }

    /// Shared implementation of [`channel`](Self::channel) once a concrete
    /// lookup service is known.
    async fn channel_inner<U>(self, lookup_service: U) -> Result<LoadBalancedChannel, anyhow::Error>
    where
        U: LookupService + Send + Sync + 'static + Sized,
    {
        // Re-resolution interval used when `dns_probe_interval` was never called.
        const DEFAULT_PROBE_INTERVAL: Duration = Duration::from_secs(10);

        // The probe pushes endpoint changes through `sender`; `channel`
        // balances requests across whatever endpoints it has been told about.
        let (channel, sender) =
            Channel::balance_channel::<SocketAddr>(GRPC_REPORT_ENDPOINTS_CHANNEL_SIZE);

        let config = GrpcServiceProbeConfig {
            service_definition: self
                .service_definition
                .try_into()
                .map_err(Into::into)
                .map_err(|err| anyhow::anyhow!(err))?,
            dns_lookup: lookup_service,
            endpoint_timeout: self.timeout,
            // An explicit connect timeout wins; otherwise reuse the general timeout.
            endpoint_connect_timeout: self.connect_timeout.or(self.timeout),
            probe_interval: self.probe_interval.unwrap_or(DEFAULT_PROBE_INTERVAL),
        };

        // `ClientTlsConfig::domain_name` sets the name the server certificate is
        // checked against. Use the service hostname, presumably because endpoints
        // are dialed by resolved `SocketAddr` rather than by name.
        let tls_config = self
            .tls_config
            .map(|tls_config| tls_config.domain_name(config.service_definition.hostname()));

        let mut service_probe = GrpcServiceProbe::new_with_reporter(config, sender);
        if let Some(tls_config) = tls_config {
            service_probe = service_probe.with_tls(tls_config);
        }

        if let ResolutionStrategy::Eager { timeout } = self.resolution_strategy {
            // The outer `?` handles the elapsed timeout; the inner one handles
            // a probe that finished but failed.
            tokio::time::timeout(timeout, service_probe.probe_once())
                .await
                .context("timed out while attempting to resolve IPs")?
                .context("failed to resolve IPs")?;
        }

        // The probe runs detached for the rest of the program; its JoinHandle
        // is intentionally dropped.
        tokio::spawn(service_probe.probe());
        Ok(LoadBalancedChannel(channel))
    }
}
// Compile-time guard: the crate stops building if either the channel or its
// default builder ever loses its `Send` implementation.
const _: () = {
    const fn require_send<T: Send>() {}
    require_send::<LoadBalancedChannel>();
    require_send::<LoadBalancedChannelBuilder<DnsResolver, ServiceDefinition>>();
};