use crate::error::{Error, Result};
use std::sync::Arc;
use tonic::transport::{Channel, Endpoint};
/// A gRPC client handle bound to a single endpoint.
///
/// Cloning copies the endpoint string and clones the underlying tonic
/// [`Channel`].
#[derive(Clone)]
pub struct GrpcClient {
    /// Endpoint URI exactly as supplied by the caller.
    endpoint: String,
    /// Transport channel used to issue requests to `endpoint`.
    channel: Channel,
    /// `true` when the channel was configured without TLS.
    insecure: bool,
}
impl GrpcClient {
pub fn builder(endpoint: impl Into<String>) -> GrpcClientBuilder {
GrpcClientBuilder::new(endpoint)
}
pub async fn new(endpoint: impl Into<String>) -> Result<Self> {
Self::builder(endpoint).connect().await
}
pub async fn new_secure(endpoint: impl Into<String>) -> Result<Self> {
Self::builder(endpoint).insecure(false).connect().await
}
pub fn connect_lazy(endpoint: impl Into<String>, insecure: bool) -> Result<Self> {
Self::builder(endpoint)
.insecure(insecure)
.lazy(true)
.connect_lazy()
}
async fn create_channel(endpoint: &str, insecure: bool) -> Result<Channel> {
let endpoint_builder = configure_endpoint(endpoint, insecure)?;
let channel = endpoint_builder
.connect()
.await
.map_err(|e| Error::Connection(e.to_string()))?;
Ok(channel)
}
pub fn channel(&self) -> Channel {
self.channel.clone()
}
pub fn endpoint(&self) -> &str {
&self.endpoint
}
pub fn is_insecure(&self) -> bool {
self.insecure
}
pub async fn health_check(&self) -> Result<()> {
let _ = self.channel.clone();
Ok(())
}
}
impl std::fmt::Debug for GrpcClient {
    /// Formats the endpoint and TLS flag.
    ///
    /// The `channel` field is not printed. `finish_non_exhaustive` appends `..`
    /// so readers can see that a field was omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GrpcClient")
            .field("endpoint", &self.endpoint)
            .field("insecure", &self.insecure)
            .finish_non_exhaustive()
    }
}
/// Builder for [`GrpcClient`], created via `GrpcClient::builder`.
///
/// `#[must_use]` is on the type so a builder that is created or configured and
/// then dropped, without calling `connect`/`connect_lazy`, triggers a warning.
#[derive(Debug, Clone)]
#[must_use]
pub struct GrpcClientBuilder {
    /// Endpoint URI as supplied by the caller.
    endpoint: String,
    /// When `true`, the endpoint is configured without TLS.
    insecure: bool,
    /// When `true`, `connect` defers dialing to first use.
    lazy: bool,
}
impl GrpcClientBuilder {
    /// Creates a builder for `endpoint`, inferring TLS from its scheme.
    ///
    /// Only an `http://` scheme selects plaintext. `https://`, any other scheme,
    /// and scheme-less endpoints default to TLS. The scheme is matched ignoring
    /// ASCII case, because URI schemes are case-insensitive (RFC 3986 §3.1).
    /// Eager connection is the default.
    fn new(endpoint: impl Into<String>) -> Self {
        // Returns `true` when `s` begins with `prefix`, ignoring ASCII case.
        // Works on bytes, so a non-ASCII `s` can never cause a char-boundary
        // panic.
        fn starts_with_ignore_ascii_case(s: &str, prefix: &str) -> bool {
            let (s, prefix) = (s.as_bytes(), prefix.as_bytes());
            s.len() >= prefix.len() && s[..prefix.len()].eq_ignore_ascii_case(prefix)
        }

        let endpoint = endpoint.into();
        let insecure = starts_with_ignore_ascii_case(&endpoint, "http://");
        Self {
            endpoint,
            insecure,
            lazy: false,
        }
    }

    /// Overrides the scheme-inferred TLS choice; `true` disables TLS.
    pub fn insecure(mut self, insecure: bool) -> Self {
        self.insecure = insecure;
        self
    }

    /// When `true`, [`connect`](Self::connect) builds the client without dialing.
    pub fn lazy(mut self, lazy: bool) -> Self {
        self.lazy = lazy;
        self
    }

    /// Builds the client, connecting eagerly unless [`lazy`](Self::lazy) was set.
    ///
    /// # Errors
    ///
    /// Returns an error if the URI is invalid or TLS cannot be configured. In
    /// eager mode it also returns an error if the connection cannot be
    /// established.
    pub async fn connect(self) -> Result<GrpcClient> {
        if self.lazy {
            return self.connect_lazy();
        }
        let channel = GrpcClient::create_channel(&self.endpoint, self.insecure).await?;
        Ok(GrpcClient {
            endpoint: self.endpoint,
            channel,
            insecure: self.insecure,
        })
    }

    /// Builds the client without dialing; the `lazy` flag is not consulted.
    ///
    /// NOTE(review): tonic's lazy channel spawns a background worker, so this
    /// presumably has to be called inside a Tokio runtime — confirm.
    ///
    /// # Errors
    ///
    /// Returns an error if the URI is invalid or TLS cannot be configured.
    pub fn connect_lazy(self) -> Result<GrpcClient> {
        let channel = configure_endpoint(&self.endpoint, self.insecure)?.connect_lazy();
        Ok(GrpcClient {
            endpoint: self.endpoint,
            channel,
            insecure: self.insecure,
        })
    }
}
fn configure_endpoint(endpoint: &str, insecure: bool) -> Result<Endpoint> {
let mut endpoint_builder = Endpoint::from_shared(endpoint.to_string())
.map_err(|e| Error::Connection(e.to_string()))?;
if !insecure {
endpoint_builder = endpoint_builder
.tls_config(tonic::transport::ClientTlsConfig::new())
.map_err(|e| Error::Connection(e.to_string()))?;
}
endpoint_builder = endpoint_builder
.keep_alive_timeout(std::time::Duration::from_secs(10))
.keep_alive_while_idle(true)
.http2_keep_alive_interval(std::time::Duration::from_secs(30));
endpoint_builder = endpoint_builder.http2_adaptive_window(true);
endpoint_builder = endpoint_builder
.initial_connection_window_size(65535 * 32) .initial_stream_window_size(65535 * 32);
endpoint_builder = endpoint_builder.tcp_nodelay(true);
endpoint_builder = endpoint_builder
.connect_timeout(std::time::Duration::from_secs(5))
.timeout(std::time::Duration::from_secs(30));
Ok(endpoint_builder)
}
/// A shared registry of [`GrpcClient`]s keyed by name.
///
/// The map sits behind an `Arc`, so clones of the pool share it. A client added
/// through one handle is visible through every other.
#[derive(Clone, Default)]
pub struct GrpcClientPool {
    /// Clients keyed by caller-chosen name, guarded by a `parking_lot` read/write lock.
    clients: Arc<parking_lot::RwLock<std::collections::HashMap<String, GrpcClient>>>,
}
impl GrpcClientPool {
pub fn new() -> Self {
Self {
clients: Arc::new(parking_lot::RwLock::new(std::collections::HashMap::new())),
}
}
pub fn add(&self, name: impl Into<String>, client: GrpcClient) {
let mut clients = self.clients.write();
clients.insert(name.into(), client);
}
pub fn get(&self, name: &str) -> Option<GrpcClient> {
let clients = self.clients.read();
clients.get(name).cloned()
}
pub fn remove(&self, name: &str) -> Option<GrpcClient> {
let mut clients = self.clients.write();
clients.remove(name)
}
pub fn names(&self) -> Vec<String> {
let clients = self.clients.read();
clients.keys().cloned().collect()
}
pub fn clear(&self) {
let mut clients = self.clients.write();
clients.clear();
}
}
impl std::fmt::Debug for GrpcClientPool {
    /// Formats the pool as the list of registered names; clients are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let guard = self.clients.read();
        let names: Vec<&String> = guard.keys().collect();
        f.debug_struct("GrpcClientPool")
            .field("clients", &names)
            .finish()
    }
}
#[cfg(test)]
mod tests {
    // Every test runs under `#[tokio::test]`, including the synchronous lazy
    // constructors, so that each one executes inside a Tokio runtime.
    use super::{GrpcClient, GrpcClientPool};

    #[tokio::test]
    async fn builder_creates_lazy_client() {
        let client = GrpcClient::builder("http://localhost:50051")
            .lazy(true)
            .connect_lazy()
            .expect("lazy connect should not fail");
        assert!(client.is_insecure());
        assert_eq!(client.endpoint(), "http://localhost:50051");
    }

    #[tokio::test]
    async fn builder_respects_lazy_on_connect() {
        let client = GrpcClient::builder("http://localhost:50051")
            .lazy(true)
            .connect()
            .await
            .expect("lazy connect should not open socket");
        assert!(client.is_insecure());
        assert_eq!(client.endpoint(), "http://localhost:50051");
    }

    #[tokio::test]
    async fn builder_can_create_secure_lazy_client() {
        let client = GrpcClient::builder("https://example.com:443")
            .insecure(false)
            .connect_lazy()
            .expect("lazy TLS client should be configured");
        assert!(!client.is_insecure());
    }

    #[tokio::test]
    async fn builder_infers_tls_from_https_scheme() {
        let client = GrpcClient::builder("https://example.com:443")
            .connect_lazy()
            .expect("lazy TLS client should be configured");
        assert!(!client.is_insecure());
    }

    #[tokio::test]
    async fn connect_lazy_creates_client() {
        let client = GrpcClient::connect_lazy("http://localhost:50051", true)
            .expect("connect_lazy should work");
        assert_eq!(client.endpoint(), "http://localhost:50051");
        assert!(client.is_insecure());
    }

    #[tokio::test]
    async fn client_debug_includes_endpoint() {
        let client = GrpcClient::connect_lazy("http://localhost:50051", true).unwrap();
        let debug_str = format!("{:?}", client);
        assert!(debug_str.contains("GrpcClient"));
        assert!(debug_str.contains("http://localhost:50051"));
    }

    #[tokio::test]
    async fn test_grpc_client_pool_new() {
        let pool = GrpcClientPool::new();
        assert_eq!(pool.names().len(), 0);
    }

    #[tokio::test]
    async fn test_grpc_client_pool_add_get() {
        let pool = GrpcClientPool::new();
        let client = GrpcClient::connect_lazy("http://localhost:50051", true).unwrap();
        pool.add("service1", client.clone());
        let retrieved = pool.get("service1");
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().endpoint(), client.endpoint());
    }

    #[tokio::test]
    async fn test_grpc_client_pool_get_missing() {
        let pool = GrpcClientPool::new();
        assert!(pool.get("absent").is_none());
    }

    #[tokio::test]
    async fn test_grpc_client_pool_add_replaces_existing() {
        let pool = GrpcClientPool::new();
        pool.add(
            "svc",
            GrpcClient::connect_lazy("http://localhost:50051", true).unwrap(),
        );
        pool.add(
            "svc",
            GrpcClient::connect_lazy("http://localhost:50052", true).unwrap(),
        );
        assert_eq!(pool.names().len(), 1);
        assert_eq!(
            pool.get("svc").unwrap().endpoint(),
            "http://localhost:50052"
        );
    }

    #[tokio::test]
    async fn test_grpc_client_pool_remove() {
        let pool = GrpcClientPool::new();
        let client = GrpcClient::connect_lazy("http://localhost:50051", true).unwrap();
        pool.add("service1", client);
        let removed = pool
            .remove("service1")
            .expect("client should be present before removal");
        assert_eq!(removed.endpoint(), "http://localhost:50051");
        assert!(pool.get("service1").is_none());
        assert!(pool.remove("service1").is_none());
    }

    #[tokio::test]
    async fn test_grpc_client_pool_names() {
        let pool = GrpcClientPool::new();
        let client1 = GrpcClient::connect_lazy("http://localhost:50051", true).unwrap();
        let client2 = GrpcClient::connect_lazy("http://localhost:50052", true).unwrap();
        pool.add("service1", client1);
        pool.add("service2", client2);
        let names = pool.names();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"service1".to_string()));
        assert!(names.contains(&"service2".to_string()));
    }

    #[tokio::test]
    async fn test_grpc_client_pool_clear() {
        let pool = GrpcClientPool::new();
        let client = GrpcClient::connect_lazy("http://localhost:50051", true).unwrap();
        pool.add("service1", client);
        assert_eq!(pool.names().len(), 1);
        pool.clear();
        assert_eq!(pool.names().len(), 0);
    }

    #[tokio::test]
    async fn test_grpc_client_pool_clone() {
        let pool1 = GrpcClientPool::new();
        let client = GrpcClient::connect_lazy("http://localhost:50051", true).unwrap();
        pool1.add("service1", client);
        let pool2 = pool1.clone();
        assert!(pool2.get("service1").is_some());
    }

    #[tokio::test]
    async fn test_grpc_client_pool_debug() {
        let pool = GrpcClientPool::new();
        let client = GrpcClient::connect_lazy("http://localhost:50051", true).unwrap();
        pool.add("service1", client);
        let debug_str = format!("{:?}", pool);
        assert!(debug_str.contains("GrpcClientPool"));
        assert!(debug_str.contains("service1"));
    }
}