//! Async MIP protocol client (`mip_client/client/mod.rs`).

1mod events;
2mod internals;
3mod types;
4
5use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
6use std::sync::Arc;
7
8use tokio::net::TcpStream;
9use tokio::sync::{mpsc, RwLock};
10use tokio::task::JoinHandle;
11use tracing::{debug, info};
12
13use crate::protocol::header::{FrameFlags, FrameType};
14
15pub use events::{
16    OnAck, OnConnect, OnDisconnect, OnError, OnEvent, OnFrame, OnMessage, OnPong, OnReconnecting,
17};
18pub use types::{MIPClientOptions, MIPError, MIPMessage, MIPResult};
19
20use events::Callbacks;
21use internals::{
22    cleanup_tasks, send_close_internal, send_frame, spawn_ping_task, spawn_read_task,
23    spawn_write_task, ClientCommand,
24};
25
26// ============================================================================
27// MIP Client
28// ============================================================================
29
30/// Async MIP protocol client with auto-reconnection support
31pub struct MIPClient {
32    options: Arc<RwLock<MIPClientOptions>>,
33    connected: Arc<AtomicBool>,
34    running: Arc<AtomicBool>,
35    msg_id_counter: Arc<AtomicU64>,
36    reconnect_attempts: Arc<AtomicU64>,
37
38    callbacks: Arc<RwLock<Callbacks>>,
39    command_tx: Option<mpsc::Sender<ClientCommand>>,
40
41    read_task: Option<JoinHandle<()>>,
42    write_task: Option<JoinHandle<()>>,
43    ping_task: Option<JoinHandle<()>>,
44}
45
46impl MIPClient {
47    pub fn new(options: MIPClientOptions) -> Self {
48        Self {
49            options: Arc::new(RwLock::new(options)),
50            connected: Arc::new(AtomicBool::new(false)),
51            running: Arc::new(AtomicBool::new(false)),
52            msg_id_counter: Arc::new(AtomicU64::new(0)),
53            reconnect_attempts: Arc::new(AtomicU64::new(0)),
54            callbacks: Arc::new(RwLock::new(Callbacks::default())),
55            command_tx: None,
56            read_task: None,
57            write_task: None,
58            ping_task: None,
59        }
60    }
61
62    /// Check if the client is connected
63    pub fn is_connected(&self) -> bool {
64        self.connected.load(Ordering::SeqCst)
65    }
66
67    // --------------------------------------------------------------------------
68    // Event Registration
69    // --------------------------------------------------------------------------
70
71    /// Register connect event callback
72    pub fn on_connect<F>(&mut self, callback: F) -> &mut Self
73    where
74        F: Fn() + Send + Sync + 'static,
75    {
76        let callbacks = self.callbacks.clone();
77        tokio::task::block_in_place(|| {
78            tokio::runtime::Handle::current().block_on(async {
79                callbacks.write().await.on_connect.push(Arc::new(callback));
80            });
81        });
82        self
83    }
84
85    /// Register disconnect event callback
86    pub fn on_disconnect<F>(&mut self, callback: F) -> &mut Self
87    where
88        F: Fn() + Send + Sync + 'static,
89    {
90        let callbacks = self.callbacks.clone();
91        tokio::task::block_in_place(|| {
92            tokio::runtime::Handle::current().block_on(async {
93                callbacks
94                    .write()
95                    .await
96                    .on_disconnect
97                    .push(Arc::new(callback));
98            });
99        });
100        self
101    }
102
103    /// Register reconnecting event callback
104    pub fn on_reconnecting<F>(&mut self, callback: F) -> &mut Self
105    where
106        F: Fn(u32) + Send + Sync + 'static,
107    {
108        let callbacks = self.callbacks.clone();
109        tokio::task::block_in_place(|| {
110            tokio::runtime::Handle::current().block_on(async {
111                callbacks
112                    .write()
113                    .await
114                    .on_reconnecting
115                    .push(Arc::new(callback));
116            });
117        });
118        self
119    }
120
121    /// Register message event callback
122    pub fn on_message<F>(&mut self, callback: F) -> &mut Self
123    where
124        F: Fn(MIPMessage) + Send + Sync + 'static,
125    {
126        let callbacks = self.callbacks.clone();
127        tokio::task::block_in_place(|| {
128            tokio::runtime::Handle::current().block_on(async {
129                callbacks.write().await.on_message.push(Arc::new(callback));
130            });
131        });
132        self
133    }
134
135    /// Register event callback
136    pub fn on_event<F>(&mut self, callback: F) -> &mut Self
137    where
138        F: Fn(MIPMessage) + Send + Sync + 'static,
139    {
140        let callbacks = self.callbacks.clone();
141        tokio::task::block_in_place(|| {
142            tokio::runtime::Handle::current().block_on(async {
143                callbacks.write().await.on_event.push(Arc::new(callback));
144            });
145        });
146        self
147    }
148
149    /// Register ACK event callback
150    pub fn on_ack<F>(&mut self, callback: F) -> &mut Self
151    where
152        F: Fn(u64) + Send + Sync + 'static,
153    {
154        let callbacks = self.callbacks.clone();
155        tokio::task::block_in_place(|| {
156            tokio::runtime::Handle::current().block_on(async {
157                callbacks.write().await.on_ack.push(Arc::new(callback));
158            });
159        });
160        self
161    }
162
163    /// Register pong event callback
164    pub fn on_pong<F>(&mut self, callback: F) -> &mut Self
165    where
166        F: Fn() + Send + Sync + 'static,
167    {
168        let callbacks = self.callbacks.clone();
169        tokio::task::block_in_place(|| {
170            tokio::runtime::Handle::current().block_on(async {
171                callbacks.write().await.on_pong.push(Arc::new(callback));
172            });
173        });
174        self
175    }
176
177    /// Register error event callback
178    pub fn on_error<F>(&mut self, callback: F) -> &mut Self
179    where
180        F: Fn(MIPError) + Send + Sync + 'static,
181    {
182        let callbacks = self.callbacks.clone();
183        tokio::task::block_in_place(|| {
184            tokio::runtime::Handle::current().block_on(async {
185                callbacks.write().await.on_error.push(Arc::new(callback));
186            });
187        });
188        self
189    }
190
191    /// Register raw frame event callback
192    pub fn on_frame<F>(&mut self, callback: F) -> &mut Self
193    where
194        F: Fn(crate::protocol::header::Header, Vec<u8>) + Send + Sync + 'static,
195    {
196        let callbacks = self.callbacks.clone();
197        tokio::task::block_in_place(|| {
198            tokio::runtime::Handle::current().block_on(async {
199                callbacks.write().await.on_frame.push(Arc::new(callback));
200            });
201        });
202        self
203    }
204
205    // --------------------------------------------------------------------------
206    // Public API
207    // --------------------------------------------------------------------------
208
209    /// Connect to the MIP server
210    pub async fn connect(&mut self) -> MIPResult<()> {
211        if self.connected.load(Ordering::SeqCst) {
212            return Ok(());
213        }
214
215        self.running.store(true, Ordering::SeqCst);
216
217        let options = self.options.read().await;
218        let addr = format!("{}:{}", options.host, options.port);
219        drop(options);
220
221        debug!("Connecting to {}", addr);
222
223        let stream = TcpStream::connect(&addr)
224            .await
225            .map_err(|e| MIPError::Connection(e.to_string()))?;
226
227        let (read_half, write_half) = stream.into_split();
228
229        // Create command channel
230        let (command_tx, command_rx) = mpsc::channel::<ClientCommand>(100);
231        self.command_tx = Some(command_tx);
232
233        self.connected.store(true, Ordering::SeqCst);
234        self.reconnect_attempts.store(0, Ordering::SeqCst);
235
236        // Start read task
237        let read_task = spawn_read_task(
238            read_half,
239            self.connected.clone(),
240            self.running.clone(),
241            self.callbacks.clone(),
242            self.options.clone(),
243            self.reconnect_attempts.clone(),
244        );
245        self.read_task = Some(read_task);
246
247        // Start write task
248        let write_task = spawn_write_task(write_half, command_rx);
249        self.write_task = Some(write_task);
250
251        // Setup ping interval
252        let options = self.options.read().await;
253        let ping_interval = options.ping_interval_ms;
254        drop(options);
255
256        self.ping_task = spawn_ping_task(
257            ping_interval,
258            self.connected.clone(),
259            self.running.clone(),
260            self.command_tx.clone(),
261            self.msg_id_counter.clone(),
262        );
263
264        // Emit connect event
265        let callbacks = self.callbacks.read().await;
266        for callback in &callbacks.on_connect {
267            callback();
268        }
269
270        info!("Connected to {}", addr);
271        Ok(())
272    }
273
274    /// Disconnect from the server
275    pub async fn disconnect(&mut self) -> MIPResult<()> {
276        {
277            let mut options = self.options.write().await;
278            options.auto_reconnect = false;
279        }
280
281        self.running.store(false, Ordering::SeqCst);
282
283        // Send close frame
284        if let Some(tx) = &self.command_tx {
285            let _ = send_close_internal(tx, &self.connected, &self.msg_id_counter).await;
286            let _ = tx.send(ClientCommand::Disconnect).await;
287        }
288
289        cleanup_tasks(&mut self.ping_task, &mut self.read_task, &mut self.write_task).await;
290        self.connected.store(false, Ordering::SeqCst);
291        self.command_tx = None;
292
293        info!("Disconnected");
294        Ok(())
295    }
296
297    /// Subscribe to a topic
298    pub fn subscribe(&self, topic: &str, require_ack: bool) -> MIPResult<u64> {
299        let topic_bytes = topic.as_bytes();
300        let flags = if require_ack {
301            FrameFlags::ACK_REQUIRED
302        } else {
303            FrameFlags::NONE
304        };
305        send_frame(
306            self.command_tx.as_ref(),
307            &self.connected,
308            &self.msg_id_counter,
309            FrameType::Subscribe,
310            topic_bytes,
311            flags,
312        )
313    }
314
315    /// Unsubscribe from a topic
316    pub fn unsubscribe(&self, topic: &str, require_ack: bool) -> MIPResult<u64> {
317        let topic_bytes = topic.as_bytes();
318        let flags = if require_ack {
319            FrameFlags::ACK_REQUIRED
320        } else {
321            FrameFlags::NONE
322        };
323        send_frame(
324            self.command_tx.as_ref(),
325            &self.connected,
326            &self.msg_id_counter,
327            FrameType::Unsubscribe,
328            topic_bytes,
329            flags,
330        )
331    }
332
333    /// Publish a message to a topic
334    pub fn publish(&self, topic: &str, message: &str, flags: FrameFlags) -> MIPResult<u64> {
335        let topic_bytes = topic.as_bytes();
336        let message_bytes = message.as_bytes();
337
338        // Build payload: [topic_length (2 bytes)] [topic] [message]
339        let mut payload = Vec::with_capacity(2 + topic_bytes.len() + message_bytes.len());
340        payload.extend_from_slice(&(topic_bytes.len() as u16).to_be_bytes());
341        payload.extend_from_slice(topic_bytes);
342        payload.extend_from_slice(message_bytes);
343
344        send_frame(
345            self.command_tx.as_ref(),
346            &self.connected,
347            &self.msg_id_counter,
348            FrameType::Publish,
349            &payload,
350            flags,
351        )
352    }
353
354    /// Send a ping to the server
355    pub fn ping(&self) -> MIPResult<u64> {
356        send_frame(
357            self.command_tx.as_ref(),
358            &self.connected,
359            &self.msg_id_counter,
360            FrameType::Ping,
361            &[],
362            FrameFlags::NONE,
363        )
364    }
365
366    /// Send raw frame (advanced usage)
367    pub fn send_raw_frame(
368        &self,
369        frame_type: FrameType,
370        payload: &[u8],
371        flags: FrameFlags,
372    ) -> MIPResult<u64> {
373        send_frame(
374            self.command_tx.as_ref(),
375            &self.connected,
376            &self.msg_id_counter,
377            frame_type,
378            payload,
379            flags,
380        )
381    }
382}
383
384impl Drop for MIPClient {
385    fn drop(&mut self) {
386        self.running.store(false, Ordering::SeqCst);
387    }
388}
389
390// ============================================================================
391// Utility Functions
392// ============================================================================
393
394/// Get the name of a frame type
395pub fn get_frame_type_name(frame_type: FrameType) -> &'static str {
396    match frame_type {
397        FrameType::Hello => "HELLO",
398        FrameType::Subscribe => "SUBSCRIBE",
399        FrameType::Unsubscribe => "UNSUBSCRIBE",
400        FrameType::Publish => "PUBLISH",
401        FrameType::Event => "EVENT",
402        FrameType::Ack => "ACK",
403        FrameType::Error => "ERROR",
404        FrameType::Ping => "PING",
405        FrameType::Pong => "PONG",
406        FrameType::Close => "CLOSE",
407    }
408}
409
410/// Create a new MIP client with default options
411pub fn create_client() -> MIPClient {
412    MIPClient::new(MIPClientOptions::default())
413}