batata_client/remote/
grpc_connection.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use futures::StreamExt;
6use parking_lot::RwLock;
7use tokio::sync::mpsc;
8use tokio::time::timeout;
9use tonic::transport::Channel;
10use tracing::{debug, error, info};
11
12use crate::api::remote::{
13    ConnectionSetupRequest, RequestTrait, ResponseTrait, ServerCheckRequest, ServerCheckResponse,
14};
15use crate::api::{
16    bi_request_stream_client::BiRequestStreamClient, request_client::RequestClient, Payload,
17};
18use crate::common::{DEFAULT_HEARTBEAT_INTERVAL_MS, DEFAULT_TIMEOUT_MS, LABEL_APP_NAME, LABEL_SOURCE, LABEL_SOURCE_SDK};
19use crate::error::{BatataError, Result};
20use crate::remote::ServerAddress;
21
/// Lifecycle state of a [`GrpcConnection`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    /// No live connection; requests are rejected.
    Disconnected,
    /// A connection attempt is in progress.
    Connecting,
    /// Connected; requests and stream messages may be sent.
    Connected,
    /// Reconnect attempt in progress (not set by this file — presumably driven by a caller).
    Reconnecting,
}
30
31/// Server push message handler callback
32pub type ServerPushHandler = Arc<dyn Fn(Payload) -> Option<Payload> + Send + Sync>;
33
34/// gRPC connection manager for bidirectional streaming
35pub struct GrpcConnection {
36    server_address: ServerAddress,
37    connection_id: Arc<RwLock<String>>,
38    state: Arc<RwLock<ConnectionState>>,
39    client_ip: String,
40    namespace: String,
41    app_name: String,
42    labels: HashMap<String, String>,
43
44    // Channel sender for outgoing messages
45    sender: Arc<RwLock<Option<mpsc::Sender<Payload>>>>,
46
47    // Handler for server push messages
48    push_handlers: Arc<RwLock<HashMap<String, ServerPushHandler>>>,
49
50    // Timeout settings
51    timeout_ms: u64,
52    #[allow(dead_code)]
53    heartbeat_interval_ms: u64,
54}
55
56impl GrpcConnection {
57    /// Create a new gRPC connection
58    pub fn new(server_address: ServerAddress) -> Self {
59        Self {
60            server_address,
61            connection_id: Arc::new(RwLock::new(String::new())),
62            state: Arc::new(RwLock::new(ConnectionState::Disconnected)),
63            client_ip: get_local_ip(),
64            namespace: String::new(),
65            app_name: String::new(),
66            labels: HashMap::new(),
67            sender: Arc::new(RwLock::new(None)),
68            push_handlers: Arc::new(RwLock::new(HashMap::new())),
69            timeout_ms: DEFAULT_TIMEOUT_MS,
70            heartbeat_interval_ms: DEFAULT_HEARTBEAT_INTERVAL_MS,
71        }
72    }
73
74    /// Set namespace
75    pub fn with_namespace(mut self, namespace: &str) -> Self {
76        self.namespace = namespace.to_string();
77        self
78    }
79
80    /// Set app name
81    pub fn with_app_name(mut self, app_name: &str) -> Self {
82        self.app_name = app_name.to_string();
83        self
84    }
85
86    /// Set labels
87    pub fn with_labels(mut self, labels: HashMap<String, String>) -> Self {
88        self.labels = labels;
89        self
90    }
91
92    /// Set timeout
93    pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
94        self.timeout_ms = timeout_ms;
95        self
96    }
97
98    /// Get connection ID
99    pub fn connection_id(&self) -> String {
100        self.connection_id.read().clone()
101    }
102
103    /// Get current state
104    pub fn state(&self) -> ConnectionState {
105        *self.state.read()
106    }
107
108    /// Check if connected
109    pub fn is_connected(&self) -> bool {
110        *self.state.read() == ConnectionState::Connected
111    }
112
113    /// Register a push handler for a specific message type
114    pub fn register_push_handler(&self, message_type: &str, handler: ServerPushHandler) {
115        self.push_handlers
116            .write()
117            .insert(message_type.to_string(), handler);
118    }
119
120    /// Connect to server
121    pub async fn connect(&self) -> Result<()> {
122        {
123            let mut state = self.state.write();
124            if *state == ConnectionState::Connected {
125                return Ok(());
126            }
127            *state = ConnectionState::Connecting;
128        }
129
130        let endpoint = self.server_address.grpc_endpoint();
131        info!("Connecting to server: {}", endpoint);
132
133        // Create channel
134        let channel = Channel::from_shared(endpoint.clone())
135            .map_err(|e| BatataError::connection_error(format!("Invalid endpoint: {}", e)))?
136            .connect_timeout(Duration::from_millis(self.timeout_ms))
137            .connect()
138            .await?;
139
140        // First, do server check using unary RPC
141        let connection_id = self.server_check(&channel).await?;
142        *self.connection_id.write() = connection_id.clone();
143
144        // Create bidirectional stream
145        let mut bi_client = BiRequestStreamClient::new(channel.clone());
146
147        // Create channel for sending messages
148        let (tx, rx) = mpsc::channel::<Payload>(100);
149        *self.sender.write() = Some(tx.clone());
150
151        // Convert receiver to stream
152        let outbound = tokio_stream::wrappers::ReceiverStream::new(rx);
153
154        // Start bidirectional stream
155        let response = bi_client.request_bi_stream(outbound).await?;
156        let mut inbound = response.into_inner();
157
158        // Send connection setup request
159        self.send_connection_setup(&tx).await?;
160
161        // Spawn task to handle incoming messages
162        let push_handlers = self.push_handlers.clone();
163        let state = self.state.clone();
164        let tx_clone = tx.clone();
165
166        tokio::spawn(async move {
167            while let Some(result) = inbound.next().await {
168                match result {
169                    Ok(payload) => {
170                        if let Some(metadata) = &payload.metadata {
171                            let msg_type = &metadata.r#type;
172                            debug!("Received message type: {}", msg_type);
173
174                            // Handle server push - clone handler before await
175                            let handler_opt = {
176                                let handlers = push_handlers.read();
177                                handlers.get(msg_type).cloned()
178                            };
179
180                            if let Some(handler) = handler_opt
181                                && let Some(response) = handler(payload)
182                                && let Err(e) = tx_clone.send(response).await
183                            {
184                                error!("Failed to send response: {}", e);
185                            }
186                        }
187                    }
188                    Err(e) => {
189                        error!("Stream error: {}", e);
190                        *state.write() = ConnectionState::Disconnected;
191                        break;
192                    }
193                }
194            }
195        });
196
197        // Update state
198        *self.state.write() = ConnectionState::Connected;
199        info!(
200            "Connected to server, connection_id: {}",
201            self.connection_id()
202        );
203
204        Ok(())
205    }
206
207    /// Disconnect from server
208    pub async fn disconnect(&self) {
209        *self.state.write() = ConnectionState::Disconnected;
210        *self.sender.write() = None;
211        info!("Disconnected from server");
212    }
213
214    /// Send a request and wait for response
215    pub async fn request<Req, Resp>(&self, request: &Req) -> Result<Resp>
216    where
217        Req: RequestTrait + serde::Serialize,
218        Resp: for<'de> serde::Deserialize<'de> + Default + ResponseTrait,
219    {
220        if !self.is_connected() {
221            return Err(BatataError::ClientNotStarted);
222        }
223
224        let payload = request.to_payload(&self.client_ip);
225
226        // Use unary request
227        let channel = Channel::from_shared(self.server_address.grpc_endpoint())
228            .map_err(|e| BatataError::connection_error(format!("Invalid endpoint: {}", e)))?
229            .connect_timeout(Duration::from_millis(self.timeout_ms))
230            .connect()
231            .await?;
232
233        let mut client = RequestClient::new(channel);
234
235        let response = timeout(
236            Duration::from_millis(self.timeout_ms),
237            client.request(payload),
238        )
239        .await
240        .map_err(|_| BatataError::Timeout {
241            timeout_ms: self.timeout_ms,
242        })??;
243
244        let payload = response.into_inner();
245
246        // Deserialize response
247        let resp: Resp = payload
248            .body
249            .as_ref()
250            .and_then(|body| serde_json::from_slice(&body.value).ok())
251            .unwrap_or_default();
252
253        if !resp.is_success() {
254            return Err(BatataError::server_error(resp.error_code(), resp.message()));
255        }
256
257        Ok(resp)
258    }
259
260    /// Send a message through the stream (no response expected)
261    pub async fn send(&self, payload: Payload) -> Result<()> {
262        let sender = self.sender.read().clone();
263        let sender = sender.ok_or(BatataError::ClientNotStarted)?;
264
265        sender
266            .send(payload)
267            .await
268            .map_err(|e| BatataError::connection_error(format!("Failed to send: {}", e)))
269    }
270
271    /// Server check using unary RPC
272    async fn server_check(&self, channel: &Channel) -> Result<String> {
273        let mut client = RequestClient::new(channel.clone());
274
275        let request = ServerCheckRequest::new();
276        let payload = request.to_payload(&self.client_ip);
277
278        let response = timeout(
279            Duration::from_millis(self.timeout_ms),
280            client.request(payload),
281        )
282        .await
283        .map_err(|_| BatataError::Timeout {
284            timeout_ms: self.timeout_ms,
285        })??;
286
287        let payload = response.into_inner();
288        let resp: ServerCheckResponse = payload
289            .body
290            .as_ref()
291            .and_then(|body| serde_json::from_slice(&body.value).ok())
292            .unwrap_or_default();
293
294        if !resp.is_success() {
295            return Err(BatataError::server_error(resp.error_code(), resp.message()));
296        }
297
298        Ok(resp.connection_id)
299    }
300
301    /// Send connection setup request
302    async fn send_connection_setup(&self, sender: &mpsc::Sender<Payload>) -> Result<()> {
303        let mut labels = self.labels.clone();
304        labels.insert(LABEL_SOURCE.to_string(), LABEL_SOURCE_SDK.to_string());
305        if !self.app_name.is_empty() {
306            labels.insert(LABEL_APP_NAME.to_string(), self.app_name.clone());
307        }
308
309        let request = ConnectionSetupRequest::new()
310            .with_labels(labels)
311            .with_tenant(self.namespace.clone());
312
313        let payload = request.to_payload(&self.client_ip);
314
315        sender
316            .send(payload)
317            .await
318            .map_err(|e| BatataError::connection_error(format!("Failed to send setup: {}", e)))
319    }
320}
321
322/// Get local IP address
323fn get_local_ip() -> String {
324    if let Ok(addrs) = if_addrs::get_if_addrs() {
325        for iface in addrs {
326            if !iface.is_loopback()
327                && let std::net::IpAddr::V4(ipv4) = iface.ip()
328            {
329                return ipv4.to_string();
330            }
331        }
332    }
333    "127.0.0.1".to_string()
334}