// mip_client/client/mod.rs

mod events;
mod internals;
mod types;

use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;

use tokio::net::TcpStream;
use tokio::sync::{mpsc, RwLock};
use tokio::task::JoinHandle;
use tracing::{debug, info, warn};

use crate::protocol::header::{FrameFlags, FrameType};

pub use events::{
    OnAck, OnConnect, OnDisconnect, OnError, OnEvent, OnFrame, OnMessage, OnPong, OnReconnecting,
};
pub use types::{MIPClientOptions, MIPError, MIPMessage, MIPResult};

use events::Callbacks;
use internals::{
    cleanup_tasks, send_close_internal, send_frame, spawn_ping_task, spawn_read_task,
    spawn_write_task, ClientCommand,
};

26// ============================================================================
27// MIP Client
28// ============================================================================
29
30/// Async MIP protocol client with auto-reconnection support
31pub struct MIPClient {
32    client_id: String,
33    options: Arc<RwLock<MIPClientOptions>>,
34    connected: Arc<AtomicBool>,
35    running: Arc<AtomicBool>,
36    msg_id_counter: Arc<AtomicU64>,
37    reconnect_attempts: Arc<AtomicU64>,
38
39    callbacks: Arc<RwLock<Callbacks>>,
40    command_tx: Option<mpsc::Sender<ClientCommand>>,
41
42    read_task: Option<JoinHandle<()>>,
43    write_task: Option<JoinHandle<()>>,
44    ping_task: Option<JoinHandle<()>>,
45}
46
47impl MIPClient {
48    pub fn new(options: MIPClientOptions) -> Self {
49        Self {
50          client_id: options.client_id.clone(),
51          options: Arc::new(RwLock::new(options)),
52          connected: Arc::new(AtomicBool::new(false)),
53          running: Arc::new(AtomicBool::new(false)),
54          msg_id_counter: Arc::new(AtomicU64::new(0)),
55          reconnect_attempts: Arc::new(AtomicU64::new(0)),
56          callbacks: Arc::new(RwLock::new(Callbacks::default())),
57          command_tx: None,
58          read_task: None,
59          write_task: None,
60          ping_task: None,
61        }
62    }
63
64    /// Check if the client is connected
65    pub fn is_connected(&self) -> bool {
66        self.connected.load(Ordering::SeqCst)
67    }
68
69    // --------------------------------------------------------------------------
70    // Event Registration
71    // --------------------------------------------------------------------------
72
73    /// Register connect event callback
74    pub fn on_connect<F>(&mut self, callback: F) -> &mut Self
75    where
76        F: Fn() + Send + Sync + 'static,
77    {
78        let callbacks = self.callbacks.clone();
79        tokio::task::block_in_place(|| {
80            tokio::runtime::Handle::current().block_on(async {
81                callbacks.write().await.on_connect.push(Arc::new(callback));
82            });
83        });
84        self
85    }
86
87    /// Register disconnect event callback
88    pub fn on_disconnect<F>(&mut self, callback: F) -> &mut Self
89    where
90        F: Fn() + Send + Sync + 'static,
91    {
92        let callbacks = self.callbacks.clone();
93        tokio::task::block_in_place(|| {
94            tokio::runtime::Handle::current().block_on(async {
95                callbacks
96                    .write()
97                    .await
98                    .on_disconnect
99                    .push(Arc::new(callback));
100            });
101        });
102        self
103    }
104
105    /// Register reconnecting event callback
106    pub fn on_reconnecting<F>(&mut self, callback: F) -> &mut Self
107    where
108        F: Fn(u32) + Send + Sync + 'static,
109    {
110        let callbacks = self.callbacks.clone();
111        tokio::task::block_in_place(|| {
112            tokio::runtime::Handle::current().block_on(async {
113                callbacks
114                    .write()
115                    .await
116                    .on_reconnecting
117                    .push(Arc::new(callback));
118            });
119        });
120        self
121    }
122
123    /// Register message event callback
124    pub fn on_message<F>(&mut self, callback: F) -> &mut Self
125    where
126        F: Fn(MIPMessage) + Send + Sync + 'static,
127    {
128        let callbacks = self.callbacks.clone();
129        tokio::task::block_in_place(|| {
130            tokio::runtime::Handle::current().block_on(async {
131                callbacks.write().await.on_message.push(Arc::new(callback));
132            });
133        });
134        self
135    }
136
137    /// Register event callback
138    pub fn on_event<F>(&mut self, callback: F) -> &mut Self
139    where
140        F: Fn(MIPMessage) + Send + Sync + 'static,
141    {
142        let callbacks = self.callbacks.clone();
143        tokio::task::block_in_place(|| {
144            tokio::runtime::Handle::current().block_on(async {
145                callbacks.write().await.on_event.push(Arc::new(callback));
146            });
147        });
148        self
149    }
150
151    /// Register ACK event callback
152    pub fn on_ack<F>(&mut self, callback: F) -> &mut Self
153    where
154        F: Fn(u64) + Send + Sync + 'static,
155    {
156        let callbacks = self.callbacks.clone();
157        tokio::task::block_in_place(|| {
158            tokio::runtime::Handle::current().block_on(async {
159                callbacks.write().await.on_ack.push(Arc::new(callback));
160            });
161        });
162        self
163    }
164
165    /// Register pong event callback
166    pub fn on_pong<F>(&mut self, callback: F) -> &mut Self
167    where
168        F: Fn() + Send + Sync + 'static,
169    {
170        let callbacks = self.callbacks.clone();
171        tokio::task::block_in_place(|| {
172            tokio::runtime::Handle::current().block_on(async {
173                callbacks.write().await.on_pong.push(Arc::new(callback));
174            });
175        });
176        self
177    }
178
179    /// Register error event callback
180    pub fn on_error<F>(&mut self, callback: F) -> &mut Self
181    where
182        F: Fn(MIPError) + Send + Sync + 'static,
183    {
184        let callbacks = self.callbacks.clone();
185        tokio::task::block_in_place(|| {
186            tokio::runtime::Handle::current().block_on(async {
187                callbacks.write().await.on_error.push(Arc::new(callback));
188            });
189        });
190        self
191    }
192
193    /// Register raw frame event callback
194    pub fn on_frame<F>(&mut self, callback: F) -> &mut Self
195    where
196        F: Fn(crate::protocol::header::Header, Vec<u8>) + Send + Sync + 'static,
197    {
198        let callbacks = self.callbacks.clone();
199        tokio::task::block_in_place(|| {
200            tokio::runtime::Handle::current().block_on(async {
201                callbacks.write().await.on_frame.push(Arc::new(callback));
202            });
203        });
204        self
205    }
206
207    // --------------------------------------------------------------------------
208    // Public API
209    // --------------------------------------------------------------------------
210
211    /// Connect to the MIP server
212    pub async fn connect(&mut self) -> MIPResult<()> {
213        if self.connected.load(Ordering::SeqCst) {
214            return Ok(());
215        }
216
217        self.running.store(true, Ordering::SeqCst);
218
219        let options = self.options.read().await;
220        let addr = format!("{}:{}", options.host, options.port);
221        drop(options);
222
223        debug!("Connecting to {}", addr);
224
225        let stream = TcpStream::connect(&addr)
226            .await
227            .map_err(|e| MIPError::Connection(e.to_string()))?;
228
229        let (read_half, write_half) = stream.into_split();
230
231        // Create command channel
232        let (command_tx, command_rx) = mpsc::channel::<ClientCommand>(100);
233        self.command_tx = Some(command_tx);
234
235        self.connected.store(true, Ordering::SeqCst);
236        self.reconnect_attempts.store(0, Ordering::SeqCst);
237
238        // Start read task
239        let read_task = spawn_read_task(
240            read_half,
241            self.connected.clone(),
242            self.running.clone(),
243            self.callbacks.clone(),
244            self.options.clone(),
245            self.reconnect_attempts.clone(),
246        );
247        self.read_task = Some(read_task);
248
249        // Start write task
250        let write_task = spawn_write_task(write_half, command_rx);
251        self.write_task = Some(write_task);
252
253        // Setup ping interval
254        let options = self.options.read().await;
255        let ping_interval = options.ping_interval_ms;
256        drop(options);
257
258        self.ping_task = spawn_ping_task(
259            ping_interval,
260            self.connected.clone(),
261            self.running.clone(),
262            self.command_tx.clone(),
263            self.msg_id_counter.clone(),
264        );
265
266        // Emit connect event
267        let callbacks = self.callbacks.read().await;
268        for callback in &callbacks.on_connect {
269            callback();
270        }
271
272        let client_id = self.options.read().await.client_id.clone();
273        let payload = client_id.as_bytes();
274
275        if let Some(tx) = &self.command_tx {
276            let res = send_frame(
277              Some(tx),
278              &self.connected.clone(),
279              &self.msg_id_counter.clone(),
280              FrameType::Hello, payload,
281              FrameFlags::NONE
282            );
283
284            if let Err(e) = res {
285              println!("Failed to send HELLO frame: {}", e);
286                self.connected.store(false, Ordering::SeqCst);
287                self.running.store(false, Ordering::SeqCst);
288                return Err(e);
289            } else {
290              self.client_id = res.unwrap().to_string();
291            }
292        }
293
294        info!("Connected to {}", addr);
295        Ok(())
296    }
297
298    /// Disconnect from the server
299    pub async fn disconnect(&mut self) -> MIPResult<()> {
300        {
301            let mut options = self.options.write().await;
302            options.auto_reconnect = false;
303        }
304
305        self.running.store(false, Ordering::SeqCst);
306
307        // Send close frame
308        if let Some(tx) = &self.command_tx {
309            let _ = send_close_internal(tx, &self.connected, &self.msg_id_counter).await;
310            let _ = tx.send(ClientCommand::Disconnect).await;
311        }
312
313        cleanup_tasks(&mut self.ping_task, &mut self.read_task, &mut self.write_task).await;
314        self.connected.store(false, Ordering::SeqCst);
315        self.command_tx = None;
316
317        info!("Disconnected");
318        Ok(())
319    }
320
321    /// Subscribe to a topic
322    pub fn subscribe(&self, topic: &str, require_ack: bool) -> MIPResult<u64> {
323        let topic_bytes = topic.as_bytes();
324        let flags = if require_ack {
325            FrameFlags::ACK_REQUIRED
326        } else {
327            FrameFlags::NONE
328        };
329        send_frame(
330            self.command_tx.as_ref(),
331            &self.connected,
332            &self.msg_id_counter,
333            FrameType::Subscribe,
334            topic_bytes,
335            flags,
336        )
337    }
338
339    /// Unsubscribe from a topic
340    pub fn unsubscribe(&self, topic: &str, require_ack: bool) -> MIPResult<u64> {
341        let topic_bytes = topic.as_bytes();
342        let flags = if require_ack {
343            FrameFlags::ACK_REQUIRED
344        } else {
345            FrameFlags::NONE
346        };
347        send_frame(
348            self.command_tx.as_ref(),
349            &self.connected,
350            &self.msg_id_counter,
351            FrameType::Unsubscribe,
352            topic_bytes,
353            flags,
354        )
355    }
356
357    /// Publish a message to a topic
358    pub fn publish(&self, topic: &str, message: &str, flags: FrameFlags) -> MIPResult<u64> {
359        let topic_bytes = topic.as_bytes();
360        let message_bytes = message.as_bytes();
361
362        // Build payload: [topic_length (2 bytes)] [topic] [message]
363        let mut payload = Vec::with_capacity(2 + topic_bytes.len() + message_bytes.len());
364        payload.extend_from_slice(&(topic_bytes.len() as u16).to_be_bytes());
365        payload.extend_from_slice(topic_bytes);
366        payload.extend_from_slice(message_bytes);
367
368        send_frame(
369            self.command_tx.as_ref(),
370            &self.connected,
371            &self.msg_id_counter,
372            FrameType::Publish,
373            &payload,
374            flags,
375        )
376    }
377
378    /// Send a ping to the server
379    pub fn ping(&self) -> MIPResult<u64> {
380        send_frame(
381            self.command_tx.as_ref(),
382            &self.connected,
383            &self.msg_id_counter,
384            FrameType::Ping,
385            &[],
386            FrameFlags::NONE,
387        )
388    }
389
390    /// Send raw frame (advanced usage)
391    pub fn send_raw_frame(
392        &self,
393        frame_type: FrameType,
394        payload: &[u8],
395        flags: FrameFlags,
396    ) -> MIPResult<u64> {
397        send_frame(
398            self.command_tx.as_ref(),
399            &self.connected,
400            &self.msg_id_counter,
401            frame_type,
402            payload,
403            flags,
404        )
405    }
406}
407
408impl Drop for MIPClient {
409    fn drop(&mut self) {
410        self.running.store(false, Ordering::SeqCst);
411    }
412}
413
414// ============================================================================
415// Utility Functions
416// ============================================================================
417
418/// Get the name of a frame type
419pub fn get_frame_type_name(frame_type: FrameType) -> &'static str {
420    match frame_type {
421        FrameType::Hello => "HELLO",
422        FrameType::Subscribe => "SUBSCRIBE",
423        FrameType::Unsubscribe => "UNSUBSCRIBE",
424        FrameType::Publish => "PUBLISH",
425        FrameType::Event => "EVENT",
426        FrameType::Ack => "ACK",
427        FrameType::Error => "ERROR",
428        FrameType::Ping => "PING",
429        FrameType::Pong => "PONG",
430        FrameType::Close => "CLOSE",
431    }
432}
433
434/// Create a new MIP client with default options
435pub fn create_client() -> MIPClient {
436    MIPClient::new(MIPClientOptions::default())
437}