use crate::error::{ClientError, Result};
use std::time::Duration;
use tonic::transport::{Channel, ClientTlsConfig, Endpoint};
/// A connected client: a tonic [`Channel`] together with the endpoint string
/// it was created from.
///
/// Instances are produced by [`Client::connect`], [`Client::connect_with_config`]
/// or [`ClientBuilder::build`].
#[derive(Debug, Clone)]
pub struct Client {
    /// Transport channel obtained from `Endpoint::connect`.
    channel: Channel,
    /// Endpoint string exactly as supplied by the caller.
    endpoint: String,
}
/// Step-by-step configuration for a [`Client`].
///
/// Obtain one from [`Client::builder`], chain the setters, then call
/// [`ClientBuilder::build`] to connect.
#[derive(Debug)]
pub struct ClientBuilder {
    /// Target endpoint string; starts out empty and must be set before `build`.
    endpoint: String,
    /// Value passed to `Endpoint::timeout`; `build` uses 10 seconds when `None`.
    timeout: Option<Duration>,
    /// Value passed to `Endpoint::connect_timeout`; `build` uses 5 seconds when `None`.
    connect_timeout: Option<Duration>,
    /// TLS settings passed to `Endpoint::tls_config`; skipped when `None`.
    tls_config: Option<ClientTlsConfig>,
}
impl Client {
    /// Connects to `endpoint` using the fixed settings applied by
    /// `create_channel` (10-second `timeout`, 5-second `connect_timeout`).
    ///
    /// # Errors
    ///
    /// Returns [`ClientError::InvalidEndpoint`] when the string is rejected by
    /// `Endpoint::from_shared`, or [`ClientError::Transport`] when connecting fails.
    pub async fn connect(endpoint: impl Into<String>) -> Result<Self> {
        let address: String = endpoint.into();
        let channel = Self::create_channel(&address).await?;
        Ok(Self {
            channel,
            endpoint: address,
        })
    }

    /// Returns a [`ClientBuilder`] with an empty endpoint and no timeouts or
    /// TLS configuration set.
    pub fn builder() -> ClientBuilder {
        ClientBuilder {
            endpoint: String::new(),
            timeout: None,
            connect_timeout: None,
            tls_config: None,
        }
    }

    /// Connects to `endpoint` after letting `config` customise the tonic
    /// [`Endpoint`].
    ///
    /// Unlike [`Client::connect`], no timeouts are applied here; the endpoint
    /// handed to `config` is exactly what `Endpoint::from_shared` produced.
    ///
    /// # Errors
    ///
    /// Returns [`ClientError::InvalidEndpoint`] when the string is rejected by
    /// `Endpoint::from_shared`, or [`ClientError::Transport`] when connecting fails.
    pub async fn connect_with_config<F>(endpoint: impl Into<String>, config: F) -> Result<Self>
    where
        F: FnOnce(Endpoint) -> Endpoint,
    {
        let address: String = endpoint.into();
        // The string is cloned because the client keeps the original for `endpoint()`.
        let base = Endpoint::from_shared(address.clone())
            .map_err(|err| ClientError::InvalidEndpoint(err.to_string()))?;
        let configured = config(base);
        let channel = configured
            .connect()
            .await
            .map_err(ClientError::Transport)?;
        Ok(Self {
            channel,
            endpoint: address,
        })
    }

    /// Borrows the underlying tonic [`Channel`].
    pub fn channel(&self) -> &Channel {
        &self.channel
    }

    /// Returns the endpoint string this client was created with.
    pub fn endpoint(&self) -> &str {
        self.endpoint.as_str()
    }

    /// Builds an [`Endpoint`] from `endpoint` with a 10-second `timeout` and a
    /// 5-second `connect_timeout`, then connects it.
    async fn create_channel(endpoint: &str) -> Result<Channel> {
        let target = Endpoint::from_shared(endpoint.to_owned())
            .map_err(|err| ClientError::InvalidEndpoint(err.to_string()))?
            .timeout(Duration::from_secs(10))
            .connect_timeout(Duration::from_secs(5));
        target.connect().await.map_err(ClientError::Transport)
    }
}
impl ClientBuilder {
    /// `Endpoint::timeout` value used by [`build`](Self::build) when
    /// [`timeout`](Self::timeout) was never called.
    const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10);

    /// `Endpoint::connect_timeout` value used by [`build`](Self::build) when
    /// [`connect_timeout`](Self::connect_timeout) was never called.
    const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(5);

    /// Sets the endpoint string to connect to, replacing any earlier value.
    pub fn endpoint(mut self, endpoint: impl Into<String>) -> Self {
        self.endpoint = endpoint.into();
        self
    }

    /// Sets the duration passed to `Endpoint::timeout`.
    pub fn timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    /// Sets the duration passed to `Endpoint::connect_timeout`.
    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
        self.connect_timeout = Some(timeout);
        self
    }

    /// Enables TLS using a fresh `ClientTlsConfig::new()`, replacing any
    /// previously stored TLS configuration.
    pub fn tls(mut self) -> Self {
        self.tls_config = Some(ClientTlsConfig::new());
        self
    }

    /// Enables TLS using the supplied `config`, replacing any previously
    /// stored TLS configuration.
    pub fn tls_config(mut self, config: ClientTlsConfig) -> Self {
        self.tls_config = Some(config);
        self
    }

    /// Validates the configuration, connects, and returns the [`Client`].
    ///
    /// Unset timeouts fall back to 10 seconds (`timeout`) and 5 seconds
    /// (`connect_timeout`). TLS is applied only when one of the TLS setters
    /// was called.
    ///
    /// # Errors
    ///
    /// * [`ClientError::InvalidEndpoint`] when the endpoint is empty or is
    ///   rejected by `Endpoint::from_shared`.
    /// * [`ClientError::Transport`] when applying the TLS configuration or
    ///   establishing the connection fails.
    pub async fn build(self) -> Result<Client> {
        if self.endpoint.is_empty() {
            return Err(ClientError::InvalidEndpoint(
                "Endpoint cannot be empty".to_string(),
            ));
        }

        // The string is cloned because the resulting client keeps the original.
        let mut endpoint = Endpoint::from_shared(self.endpoint.clone())
            .map_err(|e| ClientError::InvalidEndpoint(e.to_string()))?
            .timeout(self.timeout.unwrap_or(Self::DEFAULT_TIMEOUT))
            .connect_timeout(
                self.connect_timeout
                    .unwrap_or(Self::DEFAULT_CONNECT_TIMEOUT),
            );

        if let Some(tls_config) = self.tls_config {
            // Map explicitly, like every other transport error in this file,
            // rather than relying on an implicit `From` conversion via `?`.
            endpoint = endpoint
                .tls_config(tls_config)
                .map_err(ClientError::Transport)?;
        }

        let channel = endpoint.connect().await.map_err(ClientError::Transport)?;
        Ok(Client {
            channel,
            endpoint: self.endpoint,
        })
    }
}