use crate::model::RetryConfig;
use crate::retries::RetryState;
use dashmap::DashMap;
use http::Uri;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
use tonic::transport::{Channel, Endpoint};
use tonic::{Code, Status};
use tracing::{debug, debug_span, warn, Instrument};
/// Cloneable gRPC client handle bound to a single endpoint.
///
/// The cached connection and the client factory are held behind `Arc`s, so
/// clones of this handle share them.
#[derive(Clone)]
pub struct GrpcClient<T: Clone> {
/// URI used to build the tonic `Endpoint` when a connection is created.
endpoint: Uri,
/// Connect timeout and retry settings applied to every call.
config: GrpcClientConfig,
/// Cached connection; `None` until the first `get`, and reset to `None`
/// by `call` after a failure that requires reconnecting.
client: Arc<Mutex<Option<GrpcClientConnection<T>>>>,
/// Builds the typed client `T` from a tonic `Channel`.
client_factory: Arc<dyn Fn(Channel) -> T + Send + Sync + 'static>,
/// Name recorded as `target_name` on the tracing span of each call.
target_name: String,
}
impl<T: Clone> GrpcClient<T> {
    /// Creates a client handle for `endpoint`.
    ///
    /// No connection is set up here: the cached connection starts out empty
    /// and is populated by the first `call`. `client_factory` turns a tonic
    /// `Channel` into the typed client `T`; `target_name` is only used to
    /// label tracing spans.
    pub fn new(
        target_name: impl AsRef<str>,
        client_factory: impl Fn(Channel) -> T + Send + Sync + 'static,
        endpoint: Uri,
        config: GrpcClientConfig,
    ) -> Self {
        let target_name = String::from(target_name.as_ref());
        let client = Arc::new(Mutex::new(None));
        Self {
            endpoint,
            config,
            client,
            client_factory: Arc::new(client_factory),
            target_name,
        }
    }

    /// Invokes `f` with the typed client, retrying while the failure requires
    /// a reconnect (see `requires_reconnect`).
    ///
    /// Every attempt fetches the connection through `get` and runs `f` inside
    /// a `"gRPC call"` debug span carrying `target_name` and `description`.
    /// When the returned status requires a reconnect, the cached connection is
    /// discarded so the next attempt builds a new one, and the `RetryState`
    /// created from `config.retries_on_unavailable` decides whether another
    /// attempt is made.
    ///
    /// # Errors
    ///
    /// Returns the `Status` produced by `f` when it is not retriable or when
    /// no retries are left, or a `Status` wrapping the transport error if the
    /// endpoint could not be turned into a connection.
    pub async fn call<F, R>(&self, description: impl AsRef<str>, f: F) -> Result<R, Status>
    where
        F: for<'a> Fn(&'a mut T) -> Pin<Box<dyn Future<Output = Result<R, Status>> + 'a + Send>>
            + Send,
    {
        let mut retries = RetryState::new(&self.config.retries_on_unavailable);
        let span = debug_span!(
            "gRPC call",
            target_name = self.target_name,
            description = description.as_ref()
        );

        loop {
            retries.start_attempt();

            let mut connection = self
                .get()
                .await
                .map_err(|err| Status::from_error(Box::new(err)))?;

            let status = match f(&mut connection.client)
                .instrument(span.clone())
                .await
            {
                Ok(value) => return Ok(value),
                Err(status) => status,
            };

            // Anything that does not call for a reconnect is reported as-is.
            if !requires_reconnect(&status) {
                span.in_scope(|| {
                    warn!("gRPC call failed: {:?}, not retriable", status);
                });
                return Err(status);
            }

            // Forget the cached connection so the next `get` creates a new one.
            drop(self.client.lock().await.take());

            if retries.failed_attempt().await {
                span.in_scope(|| {
                    debug!("gRPC call failed with {:?}, retrying", status);
                });
            } else {
                span.in_scope(|| {
                    warn!("gRPC call failed: {:?}, no more retries", status);
                });
                return Err(status);
            }
        }
    }

    /// Returns a clone of the cached connection, creating and caching it first
    /// if the cache is empty.
    ///
    /// The mutex is held for the whole lookup-or-create sequence. The channel
    /// is obtained with `connect_lazy`, so this function itself only fails
    /// when `Endpoint::new` rejects the URI.
    async fn get(&self) -> Result<GrpcClientConnection<T>, tonic::transport::Error> {
        let mut slot = self.client.lock().await;

        if let Some(existing) = &*slot {
            return Ok(existing.clone());
        }

        let channel = Endpoint::new(self.endpoint.clone())?
            .connect_timeout(self.config.connect_timeout)
            .connect_lazy();
        let connection = GrpcClientConnection {
            client: (self.client_factory)(channel),
        };
        *slot = Some(connection.clone());
        Ok(connection)
    }
}
/// Cloneable gRPC client handle that talks to many endpoints.
///
/// Connections are cached per endpoint `Uri`; the cache and the client factory
/// are held behind `Arc`s, so clones of this handle share them.
#[derive(Clone)]
pub struct MultiTargetGrpcClient<T: Clone> {
/// Connect timeout and retry settings applied to every call.
config: GrpcClientConfig,
/// Cached connections keyed by endpoint; an entry is removed by `call`
/// after a failure that requires reconnecting.
clients: Arc<DashMap<Uri, GrpcClientConnection<T>>>,
/// Builds the typed client `T` from a tonic `Channel`.
client_factory: Arc<dyn Fn(Channel) -> T + Send + Sync>,
/// Name recorded as `target_name` on the tracing span of each call.
target_name: String,
}
impl<T: Clone> MultiTargetGrpcClient<T> {
    /// Creates a client handle with an empty per-endpoint connection cache.
    ///
    /// `client_factory` turns a tonic `Channel` into the typed client `T`;
    /// `target_name` is only used to label tracing spans.
    pub fn new(
        target_name: impl AsRef<str>,
        client_factory: impl Fn(Channel) -> T + Send + Sync + 'static,
        config: GrpcClientConfig,
    ) -> Self {
        let clients = Arc::new(DashMap::new());
        let target_name = String::from(target_name.as_ref());
        Self {
            config,
            clients,
            client_factory: Arc::new(client_factory),
            target_name,
        }
    }

    /// Invokes `f` with the typed client for `endpoint`, retrying while the
    /// failure requires a reconnect (see `requires_reconnect`).
    ///
    /// Every attempt fetches the connection through `get` and runs `f` inside
    /// a `"gRPC call"` debug span carrying `target_name`, `endpoint` and
    /// `description`. When the returned status requires a reconnect, the
    /// cached connection for `endpoint` is removed so the next attempt builds
    /// a new one, and the `RetryState` created from
    /// `config.retries_on_unavailable` decides whether another attempt is
    /// made.
    ///
    /// # Errors
    ///
    /// Returns the `Status` produced by `f` when it is not retriable or when
    /// no retries are left, or a `Status` wrapping the transport error if the
    /// endpoint could not be turned into a connection.
    pub async fn call<F, R>(
        &self,
        description: impl AsRef<str>,
        endpoint: Uri,
        f: F,
    ) -> Result<R, Status>
    where
        F: for<'a> Fn(&'a mut T) -> Pin<Box<dyn Future<Output = Result<R, Status>> + 'a + Send>>
            + Send,
    {
        let mut retries = RetryState::new(&self.config.retries_on_unavailable);
        let span = debug_span!(
            "gRPC call",
            target_name = self.target_name,
            endpoint = endpoint.to_string(),
            description = description.as_ref()
        );

        loop {
            retries.start_attempt();

            let mut connection = self
                .get(endpoint.clone())
                .map_err(|err| Status::from_error(Box::new(err)))?;

            let status = match f(&mut connection.client)
                .instrument(span.clone())
                .await
            {
                Ok(value) => return Ok(value),
                Err(status) => status,
            };

            // Anything that does not call for a reconnect is reported as-is.
            if !requires_reconnect(&status) {
                span.in_scope(|| {
                    warn!("gRPC call failed: {:?}, not retriable", status);
                });
                return Err(status);
            }

            // Forget the cached connection so the next `get` creates a new one.
            self.clients.remove(&endpoint);

            if retries.failed_attempt().await {
                span.in_scope(|| {
                    debug!("gRPC call failed with {:?}, retrying", status);
                });
            } else {
                span.in_scope(|| {
                    warn!("gRPC call failed: {:?}, no more retries", status);
                });
                return Err(status);
            }
        }
    }

    /// Returns a clone of the cached connection for `endpoint`, inserting a
    /// newly built one if the map has no entry for it yet.
    ///
    /// The channel is obtained with `connect_lazy`, so this function itself
    /// only fails when `Endpoint::new` rejects the URI; in that case nothing
    /// is inserted.
    fn get(&self, endpoint: Uri) -> Result<GrpcClientConnection<T>, tonic::transport::Error> {
        let connect_timeout = self.config.connect_timeout;
        let key = endpoint.clone();

        let slot = self.clients.entry(key).or_try_insert_with(move || {
            let channel = Endpoint::new(endpoint)?
                .connect_timeout(connect_timeout)
                .connect_lazy();
            Ok(GrpcClientConnection {
                client: (self.client_factory)(channel),
            })
        })?;

        Ok(slot.clone())
    }
}
/// Wrapper around the typed client `T` produced by a client factory; this is
/// the value stored in the connection caches and handed out by `get`.
#[derive(Clone)]
pub struct GrpcClientConnection<T: Clone> {
/// The typed client passed (by mutable reference) to the call closure.
client: T,
}
/// Settings shared by `GrpcClient` and `MultiTargetGrpcClient`.
#[derive(Debug, Clone)]
pub struct GrpcClientConfig {
/// Value passed to `Endpoint::connect_timeout` when a channel is created.
pub connect_timeout: Duration,
/// Retry configuration used to build the `RetryState` that governs
/// re-attempts after a failure requiring a reconnect.
pub retries_on_unavailable: RetryConfig,
}
impl Default for GrpcClientConfig {
fn default() -> Self {
Self {
connect_timeout: Duration::from_secs(10),
retries_on_unavailable: RetryConfig::default(),
}
}
}
/// Reports whether a failed call should drop its cached connection and be
/// retried: true only when the status code is `Code::Unavailable`.
fn requires_reconnect(e: &Status) -> bool {
    matches!(e.code(), Code::Unavailable)
}