batata_client/remote/
grpc_connection.rs1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use futures::StreamExt;
6use parking_lot::RwLock;
7use tokio::sync::mpsc;
8use tokio::time::timeout;
9use tonic::transport::Channel;
10use tracing::{debug, error, info};
11
12use crate::api::remote::{
13 ConnectionSetupRequest, RequestTrait, ResponseTrait, ServerCheckRequest, ServerCheckResponse,
14};
15use crate::api::{
16 bi_request_stream_client::BiRequestStreamClient, request_client::RequestClient, Payload,
17};
18use crate::common::{DEFAULT_HEARTBEAT_INTERVAL_MS, DEFAULT_TIMEOUT_MS, LABEL_APP_NAME, LABEL_SOURCE, LABEL_SOURCE_SDK};
19use crate::error::{BatataError, Result};
20use crate::remote::ServerAddress;
21
/// Lifecycle state of a [`GrpcConnection`].
///
/// Defaults to [`ConnectionState::Disconnected`], the state a freshly created
/// connection starts in.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum ConnectionState {
    /// No active connection: the initial state, after `disconnect`, or after
    /// a push-stream error.
    #[default]
    Disconnected,
    /// A `connect` attempt is in progress.
    Connecting,
    /// The handshake completed and the push stream is open.
    Connected,
    /// Not set by `GrpcConnection` itself.
    /// NOTE(review): presumably driven by reconnect logic elsewhere — confirm.
    Reconnecting,
}
30
31pub type ServerPushHandler = Arc<dyn Fn(Payload) -> Option<Payload> + Send + Sync>;
33
34pub struct GrpcConnection {
36 server_address: ServerAddress,
37 connection_id: Arc<RwLock<String>>,
38 state: Arc<RwLock<ConnectionState>>,
39 client_ip: String,
40 namespace: String,
41 app_name: String,
42 labels: HashMap<String, String>,
43
44 sender: Arc<RwLock<Option<mpsc::Sender<Payload>>>>,
46
47 push_handlers: Arc<RwLock<HashMap<String, ServerPushHandler>>>,
49
50 timeout_ms: u64,
52 #[allow(dead_code)]
53 heartbeat_interval_ms: u64,
54}
55
56impl GrpcConnection {
57 pub fn new(server_address: ServerAddress) -> Self {
59 Self {
60 server_address,
61 connection_id: Arc::new(RwLock::new(String::new())),
62 state: Arc::new(RwLock::new(ConnectionState::Disconnected)),
63 client_ip: get_local_ip(),
64 namespace: String::new(),
65 app_name: String::new(),
66 labels: HashMap::new(),
67 sender: Arc::new(RwLock::new(None)),
68 push_handlers: Arc::new(RwLock::new(HashMap::new())),
69 timeout_ms: DEFAULT_TIMEOUT_MS,
70 heartbeat_interval_ms: DEFAULT_HEARTBEAT_INTERVAL_MS,
71 }
72 }
73
74 pub fn with_namespace(mut self, namespace: &str) -> Self {
76 self.namespace = namespace.to_string();
77 self
78 }
79
80 pub fn with_app_name(mut self, app_name: &str) -> Self {
82 self.app_name = app_name.to_string();
83 self
84 }
85
86 pub fn with_labels(mut self, labels: HashMap<String, String>) -> Self {
88 self.labels = labels;
89 self
90 }
91
92 pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
94 self.timeout_ms = timeout_ms;
95 self
96 }
97
98 pub fn connection_id(&self) -> String {
100 self.connection_id.read().clone()
101 }
102
103 pub fn state(&self) -> ConnectionState {
105 *self.state.read()
106 }
107
108 pub fn is_connected(&self) -> bool {
110 *self.state.read() == ConnectionState::Connected
111 }
112
113 pub fn register_push_handler(&self, message_type: &str, handler: ServerPushHandler) {
115 self.push_handlers
116 .write()
117 .insert(message_type.to_string(), handler);
118 }
119
120 pub async fn connect(&self) -> Result<()> {
122 {
123 let mut state = self.state.write();
124 if *state == ConnectionState::Connected {
125 return Ok(());
126 }
127 *state = ConnectionState::Connecting;
128 }
129
130 let endpoint = self.server_address.grpc_endpoint();
131 info!("Connecting to server: {}", endpoint);
132
133 let channel = Channel::from_shared(endpoint.clone())
135 .map_err(|e| BatataError::connection_error(format!("Invalid endpoint: {}", e)))?
136 .connect_timeout(Duration::from_millis(self.timeout_ms))
137 .connect()
138 .await?;
139
140 let connection_id = self.server_check(&channel).await?;
142 *self.connection_id.write() = connection_id.clone();
143
144 let mut bi_client = BiRequestStreamClient::new(channel.clone());
146
147 let (tx, rx) = mpsc::channel::<Payload>(100);
149 *self.sender.write() = Some(tx.clone());
150
151 let outbound = tokio_stream::wrappers::ReceiverStream::new(rx);
153
154 let response = bi_client.request_bi_stream(outbound).await?;
156 let mut inbound = response.into_inner();
157
158 self.send_connection_setup(&tx).await?;
160
161 let push_handlers = self.push_handlers.clone();
163 let state = self.state.clone();
164 let tx_clone = tx.clone();
165
166 tokio::spawn(async move {
167 while let Some(result) = inbound.next().await {
168 match result {
169 Ok(payload) => {
170 if let Some(metadata) = &payload.metadata {
171 let msg_type = &metadata.r#type;
172 debug!("Received message type: {}", msg_type);
173
174 let handler_opt = {
176 let handlers = push_handlers.read();
177 handlers.get(msg_type).cloned()
178 };
179
180 if let Some(handler) = handler_opt
181 && let Some(response) = handler(payload)
182 && let Err(e) = tx_clone.send(response).await
183 {
184 error!("Failed to send response: {}", e);
185 }
186 }
187 }
188 Err(e) => {
189 error!("Stream error: {}", e);
190 *state.write() = ConnectionState::Disconnected;
191 break;
192 }
193 }
194 }
195 });
196
197 *self.state.write() = ConnectionState::Connected;
199 info!(
200 "Connected to server, connection_id: {}",
201 self.connection_id()
202 );
203
204 Ok(())
205 }
206
207 pub async fn disconnect(&self) {
209 *self.state.write() = ConnectionState::Disconnected;
210 *self.sender.write() = None;
211 info!("Disconnected from server");
212 }
213
214 pub async fn request<Req, Resp>(&self, request: &Req) -> Result<Resp>
216 where
217 Req: RequestTrait + serde::Serialize,
218 Resp: for<'de> serde::Deserialize<'de> + Default + ResponseTrait,
219 {
220 if !self.is_connected() {
221 return Err(BatataError::ClientNotStarted);
222 }
223
224 let payload = request.to_payload(&self.client_ip);
225
226 let channel = Channel::from_shared(self.server_address.grpc_endpoint())
228 .map_err(|e| BatataError::connection_error(format!("Invalid endpoint: {}", e)))?
229 .connect_timeout(Duration::from_millis(self.timeout_ms))
230 .connect()
231 .await?;
232
233 let mut client = RequestClient::new(channel);
234
235 let response = timeout(
236 Duration::from_millis(self.timeout_ms),
237 client.request(payload),
238 )
239 .await
240 .map_err(|_| BatataError::Timeout {
241 timeout_ms: self.timeout_ms,
242 })??;
243
244 let payload = response.into_inner();
245
246 let resp: Resp = payload
248 .body
249 .as_ref()
250 .and_then(|body| serde_json::from_slice(&body.value).ok())
251 .unwrap_or_default();
252
253 if !resp.is_success() {
254 return Err(BatataError::server_error(resp.error_code(), resp.message()));
255 }
256
257 Ok(resp)
258 }
259
260 pub async fn send(&self, payload: Payload) -> Result<()> {
262 let sender = self.sender.read().clone();
263 let sender = sender.ok_or(BatataError::ClientNotStarted)?;
264
265 sender
266 .send(payload)
267 .await
268 .map_err(|e| BatataError::connection_error(format!("Failed to send: {}", e)))
269 }
270
271 async fn server_check(&self, channel: &Channel) -> Result<String> {
273 let mut client = RequestClient::new(channel.clone());
274
275 let request = ServerCheckRequest::new();
276 let payload = request.to_payload(&self.client_ip);
277
278 let response = timeout(
279 Duration::from_millis(self.timeout_ms),
280 client.request(payload),
281 )
282 .await
283 .map_err(|_| BatataError::Timeout {
284 timeout_ms: self.timeout_ms,
285 })??;
286
287 let payload = response.into_inner();
288 let resp: ServerCheckResponse = payload
289 .body
290 .as_ref()
291 .and_then(|body| serde_json::from_slice(&body.value).ok())
292 .unwrap_or_default();
293
294 if !resp.is_success() {
295 return Err(BatataError::server_error(resp.error_code(), resp.message()));
296 }
297
298 Ok(resp.connection_id)
299 }
300
301 async fn send_connection_setup(&self, sender: &mpsc::Sender<Payload>) -> Result<()> {
303 let mut labels = self.labels.clone();
304 labels.insert(LABEL_SOURCE.to_string(), LABEL_SOURCE_SDK.to_string());
305 if !self.app_name.is_empty() {
306 labels.insert(LABEL_APP_NAME.to_string(), self.app_name.clone());
307 }
308
309 let request = ConnectionSetupRequest::new()
310 .with_labels(labels)
311 .with_tenant(self.namespace.clone());
312
313 let payload = request.to_payload(&self.client_ip);
314
315 sender
316 .send(payload)
317 .await
318 .map_err(|e| BatataError::connection_error(format!("Failed to send setup: {}", e)))
319 }
320}
321
322fn get_local_ip() -> String {
324 if let Ok(addrs) = if_addrs::get_if_addrs() {
325 for iface in addrs {
326 if !iface.is_loopback()
327 && let std::net::IpAddr::V4(ipv4) = iface.ip()
328 {
329 return ipv4.to_string();
330 }
331 }
332 }
333 "127.0.0.1".to_string()
334}