batata_client/remote/request_client.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use parking_lot::RwLock;
6use tokio::time::timeout;
7use tonic::transport::Channel;
8use tracing::{debug, warn};
9
10use crate::api::remote::{RequestTrait, ResponseTrait};
11use crate::api::{request_client::RequestClient as GrpcRequestClient, Payload};
12use crate::common::DEFAULT_TIMEOUT_MS;
13use crate::error::{BatataError, Result};
14use crate::remote::{GrpcConnection, ServerListManager};
15
16/// Client for making requests to Batata/Nacos server
17pub struct RpcClient {
18    server_list: Arc<ServerListManager>,
19    connection: Arc<RwLock<Option<Arc<GrpcConnection>>>>,
20    namespace: String,
21    app_name: String,
22    labels: HashMap<String, String>,
23    timeout_ms: u64,
24    retry_times: u32,
25}
26
27impl RpcClient {
28    /// Create a new RPC client
29    pub fn new(server_addresses: Vec<String>) -> Result<Self> {
30        let server_list = Arc::new(ServerListManager::new(server_addresses)?);
31
32        Ok(Self {
33            server_list,
34            connection: Arc::new(RwLock::new(None)),
35            namespace: String::new(),
36            app_name: String::new(),
37            labels: HashMap::new(),
38            timeout_ms: DEFAULT_TIMEOUT_MS,
39            retry_times: 3,
40        })
41    }
42
43    /// Set namespace
44    pub fn with_namespace(mut self, namespace: &str) -> Self {
45        self.namespace = namespace.to_string();
46        self
47    }
48
49    /// Set app name
50    pub fn with_app_name(mut self, app_name: &str) -> Self {
51        self.app_name = app_name.to_string();
52        self
53    }
54
55    /// Set labels
56    pub fn with_labels(mut self, labels: HashMap<String, String>) -> Self {
57        self.labels = labels;
58        self
59    }
60
61    /// Set timeout
62    pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
63        self.timeout_ms = timeout_ms;
64        self
65    }
66
67    /// Set retry times
68    pub fn with_retry(mut self, retry_times: u32) -> Self {
69        self.retry_times = retry_times;
70        self
71    }
72
73    /// Start the client and establish connection
74    pub async fn start(&self) -> Result<()> {
75        let server = self.server_list.current_server().clone();
76
77        let connection = GrpcConnection::new(server)
78            .with_namespace(&self.namespace)
79            .with_app_name(&self.app_name)
80            .with_labels(self.labels.clone())
81            .with_timeout(self.timeout_ms);
82
83        connection.connect().await?;
84
85        *self.connection.write() = Some(Arc::new(connection));
86
87        Ok(())
88    }
89
90    /// Stop the client
91    pub async fn stop(&self) {
92        let conn = self.connection.write().take();
93        if let Some(conn) = conn {
94            conn.disconnect().await;
95        }
96    }
97
98    /// Check if connected
99    pub fn is_connected(&self) -> bool {
100        self.connection
101            .read()
102            .as_ref()
103            .map(|c| c.is_connected())
104            .unwrap_or(false)
105    }
106
107    /// Get connection ID
108    pub fn connection_id(&self) -> Option<String> {
109        self.connection.read().as_ref().map(|c| c.connection_id())
110    }
111
112    /// Send request with retry
113    pub async fn request<Req, Resp>(&self, request: &Req) -> Result<Resp>
114    where
115        Req: RequestTrait + serde::Serialize + Clone,
116        Resp: for<'de> serde::Deserialize<'de> + Default + ResponseTrait,
117    {
118        let connection = self
119            .connection
120            .read()
121            .clone()
122            .ok_or(BatataError::ClientNotStarted)?;
123
124        let mut last_error = BatataError::NoAvailableServer;
125
126        for attempt in 0..=self.retry_times {
127            if attempt > 0 {
128                debug!("Retry attempt {} for request", attempt);
129            }
130
131            match connection.request::<Req, Resp>(request).await {
132                Ok(resp) => return Ok(resp),
133                Err(e) => {
134                    if e.is_retryable() && attempt < self.retry_times {
135                        warn!("Request failed, will retry: {}", e);
136                        last_error = e;
137                        tokio::time::sleep(Duration::from_millis(100 * (attempt as u64 + 1))).await;
138                        continue;
139                    }
140                    return Err(e);
141                }
142            }
143        }
144
145        Err(last_error)
146    }
147
148    /// Send request without retry (for fire-and-forget)
149    pub async fn send(&self, payload: Payload) -> Result<()> {
150        let connection = self
151            .connection
152            .read()
153            .clone()
154            .ok_or(BatataError::ClientNotStarted)?;
155
156        connection.send(payload).await
157    }
158
159    /// Get the current connection
160    pub fn connection(&self) -> Option<Arc<GrpcConnection>> {
161        self.connection.read().clone()
162    }
163
164    /// Simple unary request without connection state
165    pub async fn unary_request<Req, Resp>(&self, request: &Req) -> Result<Resp>
166    where
167        Req: RequestTrait + serde::Serialize,
168        Resp: for<'de> serde::Deserialize<'de> + Default + ResponseTrait,
169    {
170        let server = self.server_list.current_server();
171        let endpoint = server.grpc_endpoint();
172
173        let channel = Channel::from_shared(endpoint)
174            .map_err(|e| BatataError::connection_error(format!("Invalid endpoint: {}", e)))?
175            .connect_timeout(Duration::from_millis(self.timeout_ms))
176            .connect()
177            .await?;
178
179        let mut client = GrpcRequestClient::new(channel);
180
181        let client_ip = get_local_ip();
182        let payload = request.to_payload(&client_ip);
183
184        let response = timeout(
185            Duration::from_millis(self.timeout_ms),
186            client.request(payload),
187        )
188        .await
189        .map_err(|_| BatataError::Timeout {
190            timeout_ms: self.timeout_ms,
191        })??;
192
193        let payload = response.into_inner();
194
195        let resp: Resp = payload
196            .body
197            .as_ref()
198            .and_then(|body| serde_json::from_slice(&body.value).ok())
199            .unwrap_or_default();
200
201        if !resp.is_success() {
202            return Err(BatataError::server_error(resp.error_code(), resp.message()));
203        }
204
205        Ok(resp)
206    }
207}
208
209/// Get local IP address
210fn get_local_ip() -> String {
211    if let Ok(addrs) = if_addrs::get_if_addrs() {
212        for iface in addrs {
213            if !iface.is_loopback()
214                && let std::net::IpAddr::V4(ipv4) = iface.ip()
215            {
216                return ipv4.to_string();
217            }
218        }
219    }
220    "127.0.0.1".to_string()
221}