use std::collections::HashMap;
use std::str::FromStr;
use std::sync::RwLock;
use async_trait::async_trait;
use derive_more::{Deref, DerefMut};
use thiserror::Error;
use tonic::transport::{Channel, Endpoint};
use crate::config::MainConfig;
use crate::model::ValidShardedId;
use crate::prelude::{GrpcRequestInterceptor, Shard};
use crate::service::ServiceContext;
use crate::types::{ProjectId, RequestId};
#[derive(Debug, Error)]
pub enum GrpcClientError {
#[error(transparent)]
Connect(#[from] tonic::transport::Error),
#[error("Internal data routing error: {0}")]
Routing(String),
#[error("Malformed grpc endpoint address: {0}")]
BadAddress(String),
}
#[derive(Deref, DerefMut)]
pub struct ScopedGrpcClient<T> {
pub project_id: ValidShardedId<ProjectId>,
pub request_id: RequestId,
#[deref]
#[deref_mut]
inner: T,
}
impl<T> ScopedGrpcClient<T> {
pub fn new(
project_id: ValidShardedId<ProjectId>,
request_id: RequestId,
inner: T,
) -> Self {
Self {
project_id,
request_id,
inner,
}
}
}
#[async_trait]
pub trait GrpcClientType: Sync + Send {
type RawGrpcClient;
fn create_scoped_client(
project_id: ValidShardedId<ProjectId>,
request_id: RequestId,
channel: tonic::transport::Channel,
interceptor: GrpcRequestInterceptor,
) -> Self;
fn get_mut(&mut self) -> &mut ScopedGrpcClient<Self::RawGrpcClient>;
async fn create_channel(
address: &str,
) -> Result<tonic::transport::Channel, GrpcClientError> {
let channel = Endpoint::from_str(address)
.map_err(|_| GrpcClientError::BadAddress(address.to_string()))?
.connect()
.await?;
Ok(channel)
}
fn address_map(config: &MainConfig) -> &HashMap<u64, String>;
fn get_address(
config: &MainConfig,
_project_id: &ValidShardedId<ProjectId>,
) -> Result<String, GrpcClientError> {
let shard = Shard(0);
let address =
Self::address_map(config).get(&shard.0).ok_or_else(|| {
GrpcClientError::Routing(format!(
"No endpoint was found for shard {shard} in config (grpc \
client type {:?})",
std::any::type_name::<Self>(),
))
})?;
Ok(address.clone())
}
}
#[async_trait]
pub trait GrpcClientFactory: Send + Sync {
type ClientType;
async fn get_client(
&self,
request_id: &RequestId,
project_id: &ValidShardedId<ProjectId>,
) -> Result<Self::ClientType, GrpcClientError>;
}
pub struct GrpcClientProvider<T> {
service_context: ServiceContext,
channel_cache: RwLock<HashMap<String, Channel>>,
phantom: std::marker::PhantomData<T>,
}
impl<T: GrpcClientType> GrpcClientProvider<T> {
pub fn new(service_context: ServiceContext) -> Self {
Self {
service_context,
channel_cache: Default::default(),
phantom: Default::default(),
}
}
}
#[async_trait]
impl<T: GrpcClientType> GrpcClientFactory for GrpcClientProvider<T> {
type ClientType = T;
async fn get_client(
&self,
request_id: &RequestId,
project_id: &ValidShardedId<ProjectId>,
) -> Result<Self::ClientType, GrpcClientError> {
let address = T::get_address(
&self.service_context.get_config().main,
project_id,
)?;
let mut channel = None;
{
let cache = self.channel_cache.read().unwrap();
if let Some(ch) = cache.get(&address) {
channel = Some(ch.clone());
}
}
if channel.is_none() {
let temp_new_ch = T::create_channel(&address).await?;
{
let mut cache = self.channel_cache.write().unwrap();
if let Some(ch) = cache.get(&address) {
channel = Some(ch.clone());
} else {
cache.insert(address.clone(), temp_new_ch.clone());
channel = Some(temp_new_ch);
}
}
}
assert!(channel.is_some());
let interceptor = GrpcRequestInterceptor {
project_id: Some(project_id.clone()),
request_id: Some(request_id.clone()),
};
Ok(T::create_scoped_client(
project_id.clone(),
request_id.clone(),
channel.unwrap(),
interceptor,
))
}
}
pub mod test_helpers {
use std::sync::Arc;
use hyper::Uri;
use tempfile::TempPath;
use tokio::net::UnixStream;
use tower::service_fn;
use super::*;
pub struct TestGrpcClientProvider<T> {
cell_to_socket_path: HashMap<u16, Arc<TempPath>>,
channel_cache: RwLock<HashMap<u16, Channel>>,
phantom: std::marker::PhantomData<T>,
}
impl<T: GrpcClientType> TestGrpcClientProvider<T> {
pub fn new(cell_to_socket_path: HashMap<u16, Arc<TempPath>>) -> Self {
Self {
cell_to_socket_path,
channel_cache: Default::default(),
phantom: Default::default(),
}
}
pub fn new_single_shard(socket_path: Arc<TempPath>) -> Self {
let mut cell_to_socket_path = HashMap::with_capacity(1);
cell_to_socket_path.insert(0, socket_path);
Self {
cell_to_socket_path,
channel_cache: Default::default(),
phantom: Default::default(),
}
}
}
#[async_trait]
impl<T: GrpcClientType> GrpcClientFactory for TestGrpcClientProvider<T> {
type ClientType = T;
async fn get_client(
&self,
request_id: &RequestId,
project_id: &ValidShardedId<ProjectId>,
) -> Result<Self::ClientType, GrpcClientError> {
let _shard = Shard(0);
let cell: u16 = 0;
let socket = self
.cell_to_socket_path
.get(&cell)
.expect("Cell not found!")
.clone();
let mut channel = None;
{
let cache = self.channel_cache.read().unwrap();
if let Some(ch) = cache.get(&cell) {
channel = Some(ch.clone());
}
}
if channel.is_none() {
let temp_new_ch = Endpoint::try_from("http://example.url")
.unwrap()
.connect_with_connector(service_fn(move |_: Uri| {
let socket = Arc::clone(&socket);
async move { UnixStream::connect(&*socket).await }
}))
.await?;
{
let mut cache = self.channel_cache.write().unwrap();
if let Some(ch) = cache.get(&cell) {
channel = Some(ch.clone());
} else {
cache.insert(cell, temp_new_ch.clone());
channel = Some(temp_new_ch);
}
}
}
assert!(channel.is_some());
let interceptor = GrpcRequestInterceptor {
project_id: Some(project_id.clone()),
request_id: Some(request_id.clone()),
};
Ok(T::create_scoped_client(
project_id.clone(),
request_id.clone(),
channel.unwrap(),
interceptor,
))
}
}
}