use crate::{LookupService, ServiceDefinition};
use std::collections::HashSet;
use std::net::SocketAddr;
use tokio::sync::mpsc::Sender;
use tonic::transport::{
channel::{Change, Endpoint},
ClientTlsConfig,
};
#[derive(thiserror::Error, Debug)]
pub enum ProbeError {
#[error("Failed to resolve ServiceDefinition")]
ResolveServiceDefinition(#[source] anyhow::Error),
#[error("Changeset sender closed")]
ChangesetSenderClosed(#[source] anyhow::Error),
}
/// Periodically resolves a [`ServiceDefinition`] and reports endpoint changes on a channel.
///
/// Built with [`GrpcServiceProbe::new_with_reporter`]; TLS is enabled with
/// [`GrpcServiceProbe::with_tls`].
pub struct GrpcServiceProbe<Lookup: LookupService> {
    /// Service whose endpoints are resolved on every probe.
    service_definition: ServiceDefinition,
    /// Scheme used when building endpoint URIs (`http`, or `https` once TLS is configured).
    scheme: http::uri::Scheme,
    /// Lookup used to turn `service_definition` into socket addresses.
    dns_lookup: Lookup,
    /// Pause between two consecutive probes.
    probe_interval: tokio::time::Duration,
    /// Optional value passed to `Endpoint::timeout` for every built endpoint.
    endpoint_timeout: Option<tokio::time::Duration>,
    /// Optional value passed to `Endpoint::connect_timeout` for every built endpoint.
    endpoint_connect_timeout: Option<tokio::time::Duration>,
    /// Addresses committed after the last successfully reported changeset.
    endpoints: HashSet<SocketAddr>,
    /// Channel on which `Change::Insert` / `Change::Remove` events are sent.
    endpoint_reporter: Sender<Change<SocketAddr, Endpoint>>,
    /// TLS settings cloned into every built endpoint, if any.
    tls_config: Option<ClientTlsConfig>,
}
/// Settings consumed by [`GrpcServiceProbe::new_with_reporter`].
pub struct GrpcServiceProbeConfig<Lookup: LookupService> {
    /// Service whose endpoints should be discovered.
    pub service_definition: ServiceDefinition,
    /// Lookup implementation used to resolve `service_definition`.
    pub dns_lookup: Lookup,
    /// How long the probe loop sleeps between two resolutions.
    pub probe_interval: tokio::time::Duration,
    /// If set, passed to `Endpoint::timeout` for each discovered endpoint.
    pub endpoint_timeout: Option<tokio::time::Duration>,
    /// If set, passed to `Endpoint::connect_timeout` for each discovered endpoint.
    pub endpoint_connect_timeout: Option<tokio::time::Duration>,
}
impl<Lookup: LookupService> GrpcServiceProbe<Lookup> {
pub fn new_with_reporter(
config: GrpcServiceProbeConfig<Lookup>,
endpoint_reporter: Sender<Change<SocketAddr, Endpoint>>,
) -> GrpcServiceProbe<Lookup> {
Self {
service_definition: config.service_definition,
dns_lookup: config.dns_lookup,
probe_interval: config.probe_interval,
endpoint_timeout: config.endpoint_timeout,
endpoint_connect_timeout: config.endpoint_connect_timeout,
endpoints: HashSet::new(),
endpoint_reporter,
scheme: http::uri::Scheme::HTTP,
tls_config: None,
}
}
pub fn with_tls(self, tls_config: ClientTlsConfig) -> GrpcServiceProbe<Lookup> {
Self {
tls_config: Some(tls_config),
scheme: http::uri::Scheme::HTTPS,
..self
}
}
pub async fn probe(mut self) -> Result<(), anyhow::Error> {
loop {
self.probe_once().await.or_else(|err| {
if let ProbeError::ChangesetSenderClosed(_) = err {
Err(err)
} else {
Ok(())
}
})?;
tokio::time::sleep(self.probe_interval).await;
}
}
pub async fn probe_once(&mut self) -> Result<(), ProbeError> {
match self
.dns_lookup
.resolve_service_endpoints(&self.service_definition)
.await
{
Ok(endpoints) => {
let changeset = self.create_changeset(&endpoints).await;
self.report_and_commit(changeset, endpoints).await.map_err(|e| {
tracing::error!("Failed to report the discovered DNS changeset. The gRPC client has closed the channel therefore the DNS probe loop will exit.\n{:?}", e);
e
})?;
}
Err(err) => {
return Err(ProbeError::ResolveServiceDefinition(
err.context("failed to resolve ips from host"),
));
}
}
Ok(())
}
async fn create_changeset(
&mut self,
endpoints: &HashSet<SocketAddr>,
) -> Vec<Change<SocketAddr, Endpoint>> {
let mut changeset = Vec::new();
let remove_set: HashSet<SocketAddr> =
self.endpoints.difference(endpoints).copied().collect();
let add_set: HashSet<SocketAddr> = endpoints.difference(&self.endpoints).copied().collect();
changeset.extend(
add_set
.into_iter()
.filter_map(|addr| self.build_endpoint(&addr).map(|endpoint| (addr, endpoint)))
.map(|(addr, endpoint)| Change::Insert(addr, endpoint)),
);
changeset.extend(remove_set.into_iter().map(Change::Remove));
changeset
}
fn overwrite_endpoints(&mut self, current_ips: HashSet<SocketAddr>) {
self.endpoints = current_ips;
}
#[tracing::instrument(
skip(endpoints, self),
level = "debug",
name = "report-and-commit-endpoint-changeset"
)]
async fn report_and_commit(
&mut self,
changeset: Vec<Change<SocketAddr, Endpoint>>,
endpoints: HashSet<SocketAddr>,
) -> Result<(), ProbeError> {
for change in changeset {
if self.endpoint_reporter.send(change).await.is_err() {
return Err(ProbeError::ChangesetSenderClosed(anyhow::anyhow!("Tried to report endpoint changes on a closed channel, this is probably due to the gRPC client being dropped.")));
}
}
self.overwrite_endpoints(endpoints);
Ok(())
}
fn build_endpoint(&self, ip_address: &SocketAddr) -> Option<Endpoint> {
let uri = match ip_address.is_ipv6() {
false => format!(
"{}://{}:{}",
self.scheme,
ip_address.ip(),
ip_address.port()
),
true => format!(
"{}://[{}]:{}",
self.scheme,
ip_address.ip(),
ip_address.port()
),
};
let mut endpoint = Endpoint::from_shared(uri)
.map_err(|err| {
tracing::warn!("endpoint creation error: {:?}", err);
})
.ok()?;
if let Some(ref tls_config) = self.tls_config {
endpoint = endpoint
.tls_config(tls_config.clone())
.map_err(|err| {
tracing::warn!("tls error: {:?}", err);
err
})
.ok()?;
}
if let Some(ref timeout) = self.endpoint_timeout {
endpoint = endpoint.timeout(*timeout);
}
if let Some(ref connect_timeout) = self.endpoint_connect_timeout {
endpoint = endpoint.connect_timeout(*connect_timeout)
}
Some(endpoint)
}
}