use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use parking_lot::RwLock;
use tokio::time::timeout;
use tonic::transport::Channel;
use tracing::{debug, warn};
use crate::api::remote::{RequestTrait, ResponseTrait};
use crate::api::{request_client::RequestClient as GrpcRequestClient, Payload};
use crate::common::DEFAULT_TIMEOUT_MS;
use crate::error::{BatataError, Result};
use crate::remote::{GrpcConnection, ServerListManager};
/// Client-side handle that owns the server list and, once started, a single
/// shared gRPC connection used to issue requests.
///
/// Configuration fields are set through the `with_*` builder methods and are
/// read when `start` builds the connection.
pub struct RpcClient {
/// Server list built from the addresses given to `new`; supplies the current server.
server_list: Arc<ServerListManager>,
/// Active connection: `None` until `start` succeeds and again after `stop`.
connection: Arc<RwLock<Option<Arc<GrpcConnection>>>>,
/// Namespace handed to the connection in `start` (empty by default).
namespace: String,
/// Application name handed to the connection in `start` (empty by default).
app_name: String,
/// Labels handed to the connection in `start` (empty by default).
labels: HashMap<String, String>,
/// Timeout in milliseconds; defaults to `DEFAULT_TIMEOUT_MS`. Passed to the
/// connection in `start` and used by `unary_request` for both the connect
/// timeout and the response deadline.
timeout_ms: u64,
/// Number of retries `request` performs after the first failed attempt (default 3).
retry_times: u32,
}
impl RpcClient {
/// Creates an unstarted client for the given server addresses.
///
/// The client starts with an empty namespace, app name and label set, a
/// timeout of `DEFAULT_TIMEOUT_MS` and 3 retries. No connection is opened
/// until `start` is called.
///
/// # Errors
///
/// Propagates any error returned by `ServerListManager::new`.
pub fn new(server_addresses: Vec<String>) -> Result<Self> {
    let manager = ServerListManager::new(server_addresses)?;
    let client = Self {
        server_list: Arc::new(manager),
        connection: Arc::new(RwLock::new(None)),
        namespace: String::new(),
        app_name: String::new(),
        labels: HashMap::new(),
        timeout_ms: DEFAULT_TIMEOUT_MS,
        retry_times: 3,
    };
    Ok(client)
}
pub fn with_namespace(mut self, namespace: &str) -> Self {
self.namespace = namespace.to_string();
self
}
pub fn with_app_name(mut self, app_name: &str) -> Self {
self.app_name = app_name.to_string();
self
}
pub fn with_labels(mut self, labels: HashMap<String, String>) -> Self {
self.labels = labels;
self
}
pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
self.timeout_ms = timeout_ms;
self
}
pub fn with_retry(mut self, retry_times: u32) -> Self {
self.retry_times = retry_times;
self
}
/// Connects to the server list's current server and stores the connection.
///
/// The connection is configured with this client's namespace, app name,
/// labels and timeout. It is stored only after `connect` succeeds, replacing
/// whatever was stored before.
///
/// # Errors
///
/// Propagates the error from `GrpcConnection::connect`; in that case the
/// stored connection is left untouched.
pub async fn start(&self) -> Result<()> {
    let target = self.server_list.current_server().clone();
    let fresh = GrpcConnection::new(target)
        .with_namespace(&self.namespace)
        .with_app_name(&self.app_name)
        .with_labels(self.labels.clone())
        .with_timeout(self.timeout_ms);
    fresh.connect().await?;

    // The guard is taken only after the last `.await`, so the lock is never
    // held across a suspension point.
    let mut slot = self.connection.write();
    *slot = Some(Arc::new(fresh));
    Ok(())
}
/// Removes the stored connection, if any, and disconnects it.
///
/// Does nothing when no connection is stored.
pub async fn stop(&self) {
    // Take the connection out inside its own scope so the write guard is
    // released before awaiting the disconnect.
    let detached = {
        let mut slot = self.connection.write();
        slot.take()
    };
    let Some(active) = detached else {
        return;
    };
    active.disconnect().await;
}
/// Reports whether a connection is stored and that connection says it is connected.
///
/// Returns `false` when no connection is stored.
pub fn is_connected(&self) -> bool {
    let slot = self.connection.read();
    match slot.as_ref() {
        Some(active) => active.is_connected(),
        None => false,
    }
}
/// Returns the stored connection's id, or `None` when no connection is stored.
pub fn connection_id(&self) -> Option<String> {
    let slot = self.connection.read();
    let active = slot.as_ref()?;
    Some(active.connection_id())
}
/// Sends `request` over the stored connection, retrying retryable failures.
///
/// The request is attempted once and then retried up to `retry_times` more
/// times, but only while the returned error reports `is_retryable()`. Before
/// retry number `k` (1-based) the task sleeps for `100 * k` milliseconds.
/// Every attempt uses the connection handle captured at the start of the call.
///
/// # Errors
///
/// * `BatataError::ClientNotStarted` when no connection is stored.
/// * Otherwise the error of the last attempt: the first non-retryable error,
///   or the retryable error that exhausted the retry budget.
pub async fn request<Req, Resp>(&self, request: &Req) -> Result<Resp>
where
    Req: RequestTrait + serde::Serialize + Clone,
    Resp: for<'de> serde::Deserialize<'de> + Default + ResponseTrait,
{
    // Clone the handle within a single statement so the read guard is
    // released before any `.await`.
    let connection = self
        .connection
        .read()
        .clone()
        .ok_or(BatataError::ClientNotStarted)?;

    // Every path out of the loop returns explicitly, so no placeholder
    // "last error" value is needed.
    let mut retries_used: u32 = 0;
    loop {
        let err = match connection.request::<Req, Resp>(request).await {
            Ok(resp) => return Ok(resp),
            Err(err) => err,
        };
        if !err.is_retryable() || retries_used >= self.retry_times {
            return Err(err);
        }
        warn!("Request failed, will retry: {}", err);
        // Cannot overflow: the increment only runs while below `retry_times`.
        retries_used += 1;
        // Linear backoff; `u64::from` is a lossless widening of the counter.
        tokio::time::sleep(Duration::from_millis(100 * u64::from(retries_used))).await;
        debug!("Retry attempt {} for request", retries_used);
    }
}
/// Sends `payload` over the stored connection without retrying.
///
/// # Errors
///
/// Returns `BatataError::ClientNotStarted` when no connection is stored;
/// otherwise returns whatever the connection's `send` returns.
pub async fn send(&self, payload: Payload) -> Result<()> {
    // The read guard is a temporary of this statement and is gone before the `.await`.
    let active = self.connection.read().as_ref().cloned();
    match active {
        Some(conn) => conn.send(payload).await,
        None => Err(BatataError::ClientNotStarted),
    }
}
/// Returns a shared handle to the stored connection, or `None` when none is stored.
pub fn connection(&self) -> Option<Arc<GrpcConnection>> {
    let slot = self.connection.read();
    slot.as_ref().cloned()
}
/// Performs a one-off request on a freshly opened channel to the current server.
///
/// This does not use the stored connection: each call builds a new channel
/// from the current server's gRPC endpoint, sends the request payload and
/// decodes the response body as JSON into `Resp`. `timeout_ms` is applied
/// both as the connect timeout and as the deadline for the response.
///
/// NOTE(review): a missing body, or a body that fails to deserialize, is
/// silently replaced by `Resp::default()`; the decode error is discarded.
/// Confirm this best-effort behavior is intended.
///
/// # Errors
///
/// * A connection error (message prefixed `Invalid endpoint: `) when the endpoint is rejected.
/// * The converted transport error when connecting fails.
/// * `BatataError::Timeout` when no response arrives within `timeout_ms`.
/// * The converted gRPC error when the call itself fails.
/// * A server error built from the response's code and message when the
///   decoded response is not successful.
pub async fn unary_request<Req, Resp>(&self, request: &Req) -> Result<Resp>
where
    Req: RequestTrait + serde::Serialize,
    Resp: for<'de> serde::Deserialize<'de> + Default + ResponseTrait,
{
    let deadline = Duration::from_millis(self.timeout_ms);

    let server = self.server_list.current_server();
    let endpoint = server.grpc_endpoint();
    let channel = Channel::from_shared(endpoint)
        .map_err(|e| BatataError::connection_error(format!("Invalid endpoint: {}", e)))?
        .connect_timeout(deadline)
        .connect()
        .await?;
    let mut client = GrpcRequestClient::new(channel);

    let client_ip = get_local_ip();
    let outgoing = request.to_payload(&client_ip);

    // Outer `Err` means the deadline elapsed; inner `Err` is the call's own failure.
    let response = match timeout(deadline, client.request(outgoing)).await {
        Ok(outcome) => outcome?,
        Err(_) => {
            return Err(BatataError::Timeout {
                timeout_ms: self.timeout_ms,
            });
        }
    };

    let incoming = response.into_inner();
    let resp: Resp = match incoming.body.as_ref() {
        Some(body) => serde_json::from_slice(&body.value).unwrap_or_default(),
        None => Resp::default(),
    };

    if resp.is_success() {
        Ok(resp)
    } else {
        Err(BatataError::server_error(resp.error_code(), resp.message()))
    }
}
}
fn get_local_ip() -> String {
if let Ok(addrs) = if_addrs::get_if_addrs() {
for iface in addrs {
if !iface.is_loopback()
&& let std::net::IpAddr::V4(ipv4) = iface.ip()
{
return ipv4.to_string();
}
}
}
"127.0.0.1".to_string()
}