batata_client/remote/
request_client.rs1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use parking_lot::RwLock;
6use tokio::time::timeout;
7use tonic::transport::Channel;
8use tracing::{debug, warn};
9
10use crate::api::remote::{RequestTrait, ResponseTrait};
11use crate::api::{request_client::RequestClient as GrpcRequestClient, Payload};
12use crate::common::DEFAULT_TIMEOUT_MS;
13use crate::error::{BatataError, Result};
14use crate::remote::{GrpcConnection, ServerListManager};
15
16pub struct RpcClient {
18 server_list: Arc<ServerListManager>,
19 connection: Arc<RwLock<Option<Arc<GrpcConnection>>>>,
20 namespace: String,
21 app_name: String,
22 labels: HashMap<String, String>,
23 timeout_ms: u64,
24 retry_times: u32,
25}
26
27impl RpcClient {
28 pub fn new(server_addresses: Vec<String>) -> Result<Self> {
30 let server_list = Arc::new(ServerListManager::new(server_addresses)?);
31
32 Ok(Self {
33 server_list,
34 connection: Arc::new(RwLock::new(None)),
35 namespace: String::new(),
36 app_name: String::new(),
37 labels: HashMap::new(),
38 timeout_ms: DEFAULT_TIMEOUT_MS,
39 retry_times: 3,
40 })
41 }
42
43 pub fn with_namespace(mut self, namespace: &str) -> Self {
45 self.namespace = namespace.to_string();
46 self
47 }
48
49 pub fn with_app_name(mut self, app_name: &str) -> Self {
51 self.app_name = app_name.to_string();
52 self
53 }
54
55 pub fn with_labels(mut self, labels: HashMap<String, String>) -> Self {
57 self.labels = labels;
58 self
59 }
60
61 pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
63 self.timeout_ms = timeout_ms;
64 self
65 }
66
67 pub fn with_retry(mut self, retry_times: u32) -> Self {
69 self.retry_times = retry_times;
70 self
71 }
72
73 pub async fn start(&self) -> Result<()> {
75 let server = self.server_list.current_server().clone();
76
77 let connection = GrpcConnection::new(server)
78 .with_namespace(&self.namespace)
79 .with_app_name(&self.app_name)
80 .with_labels(self.labels.clone())
81 .with_timeout(self.timeout_ms);
82
83 connection.connect().await?;
84
85 *self.connection.write() = Some(Arc::new(connection));
86
87 Ok(())
88 }
89
90 pub async fn stop(&self) {
92 let conn = self.connection.write().take();
93 if let Some(conn) = conn {
94 conn.disconnect().await;
95 }
96 }
97
98 pub fn is_connected(&self) -> bool {
100 self.connection
101 .read()
102 .as_ref()
103 .map(|c| c.is_connected())
104 .unwrap_or(false)
105 }
106
107 pub fn connection_id(&self) -> Option<String> {
109 self.connection.read().as_ref().map(|c| c.connection_id())
110 }
111
112 pub async fn request<Req, Resp>(&self, request: &Req) -> Result<Resp>
114 where
115 Req: RequestTrait + serde::Serialize + Clone,
116 Resp: for<'de> serde::Deserialize<'de> + Default + ResponseTrait,
117 {
118 let connection = self
119 .connection
120 .read()
121 .clone()
122 .ok_or(BatataError::ClientNotStarted)?;
123
124 let mut last_error = BatataError::NoAvailableServer;
125
126 for attempt in 0..=self.retry_times {
127 if attempt > 0 {
128 debug!("Retry attempt {} for request", attempt);
129 }
130
131 match connection.request::<Req, Resp>(request).await {
132 Ok(resp) => return Ok(resp),
133 Err(e) => {
134 if e.is_retryable() && attempt < self.retry_times {
135 warn!("Request failed, will retry: {}", e);
136 last_error = e;
137 tokio::time::sleep(Duration::from_millis(100 * (attempt as u64 + 1))).await;
138 continue;
139 }
140 return Err(e);
141 }
142 }
143 }
144
145 Err(last_error)
146 }
147
148 pub async fn send(&self, payload: Payload) -> Result<()> {
150 let connection = self
151 .connection
152 .read()
153 .clone()
154 .ok_or(BatataError::ClientNotStarted)?;
155
156 connection.send(payload).await
157 }
158
159 pub fn connection(&self) -> Option<Arc<GrpcConnection>> {
161 self.connection.read().clone()
162 }
163
164 pub async fn unary_request<Req, Resp>(&self, request: &Req) -> Result<Resp>
166 where
167 Req: RequestTrait + serde::Serialize,
168 Resp: for<'de> serde::Deserialize<'de> + Default + ResponseTrait,
169 {
170 let server = self.server_list.current_server();
171 let endpoint = server.grpc_endpoint();
172
173 let channel = Channel::from_shared(endpoint)
174 .map_err(|e| BatataError::connection_error(format!("Invalid endpoint: {}", e)))?
175 .connect_timeout(Duration::from_millis(self.timeout_ms))
176 .connect()
177 .await?;
178
179 let mut client = GrpcRequestClient::new(channel);
180
181 let client_ip = get_local_ip();
182 let payload = request.to_payload(&client_ip);
183
184 let response = timeout(
185 Duration::from_millis(self.timeout_ms),
186 client.request(payload),
187 )
188 .await
189 .map_err(|_| BatataError::Timeout {
190 timeout_ms: self.timeout_ms,
191 })??;
192
193 let payload = response.into_inner();
194
195 let resp: Resp = payload
196 .body
197 .as_ref()
198 .and_then(|body| serde_json::from_slice(&body.value).ok())
199 .unwrap_or_default();
200
201 if !resp.is_success() {
202 return Err(BatataError::server_error(resp.error_code(), resp.message()));
203 }
204
205 Ok(resp)
206 }
207}
208
209fn get_local_ip() -> String {
211 if let Ok(addrs) = if_addrs::get_if_addrs() {
212 for iface in addrs {
213 if !iface.is_loopback()
214 && let std::net::IpAddr::V4(ipv4) = iface.ip()
215 {
216 return ipv4.to_string();
217 }
218 }
219 }
220 "127.0.0.1".to_string()
221}