1mod events;
2mod internals;
3mod types;
4
5use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
6use std::sync::Arc;
7
8use tokio::net::TcpStream;
9use tokio::sync::{mpsc, RwLock};
10use tokio::task::JoinHandle;
11use tracing::{debug, info};
12
13use crate::protocol::header::{FrameFlags, FrameType};
14
15pub use events::{
16 OnAck, OnConnect, OnDisconnect, OnError, OnEvent, OnFrame, OnMessage, OnPong, OnReconnecting,
17};
18pub use types::{MIPClientOptions, MIPError, MIPMessage, MIPResult};
19
20use events::Callbacks;
21use internals::{
22 cleanup_tasks, send_close_internal, send_frame, spawn_ping_task, spawn_read_task,
23 spawn_write_task, ClientCommand,
24};
25
/// Asynchronous client for the MIP protocol over TCP.
///
/// Build one with [`MIPClient::new`] (or [`create_client`]), open the connection with
/// [`MIPClient::connect`] and close it with [`MIPClient::disconnect`]. The callback
/// registration methods (`on_*`) return `&mut Self` so they can be chained.
pub struct MIPClient {
    /// Client identifier, initialised from `options.client_id` in `new`.
    client_id: String,
    /// Connection options; the same `Arc` is handed to the read task spawned by `connect`.
    options: Arc<RwLock<MIPClientOptions>>,
    /// Set by `connect` once the TCP connection is established and cleared by `disconnect`;
    /// shared with the background tasks.
    connected: Arc<AtomicBool>,
    /// Set by `connect`, cleared by `disconnect` and when the client is dropped; shared with
    /// the background tasks (presumably their stop signal — confirm in `internals`).
    running: Arc<AtomicBool>,
    /// Counter handed to `send_frame` and the ping task; presumably used to allocate
    /// outgoing message ids.
    msg_id_counter: Arc<AtomicU64>,
    /// Reset to 0 by a successful `connect`; shared with the read task.
    reconnect_attempts: Arc<AtomicU64>,

    /// User callbacks registered through the `on_*` methods; shared with the read task.
    callbacks: Arc<RwLock<Callbacks>>,
    /// Sender side of the write task's command queue; `None` until `connect` succeeds in
    /// creating it and again after `disconnect`.
    command_tx: Option<mpsc::Sender<ClientCommand>>,

    // Handles of the background tasks spawned by `connect`.
    read_task: Option<JoinHandle<()>>,
    write_task: Option<JoinHandle<()>>,
    ping_task: Option<JoinHandle<()>>,
}
46
47impl MIPClient {
48 pub fn new(options: MIPClientOptions) -> Self {
49 Self {
50 client_id: options.client_id.clone(),
51 options: Arc::new(RwLock::new(options)),
52 connected: Arc::new(AtomicBool::new(false)),
53 running: Arc::new(AtomicBool::new(false)),
54 msg_id_counter: Arc::new(AtomicU64::new(0)),
55 reconnect_attempts: Arc::new(AtomicU64::new(0)),
56 callbacks: Arc::new(RwLock::new(Callbacks::default())),
57 command_tx: None,
58 read_task: None,
59 write_task: None,
60 ping_task: None,
61 }
62 }
63
64 pub fn is_connected(&self) -> bool {
66 self.connected.load(Ordering::SeqCst)
67 }
68
69 pub fn on_connect<F>(&mut self, callback: F) -> &mut Self
75 where
76 F: Fn() + Send + Sync + 'static,
77 {
78 let callbacks = self.callbacks.clone();
79 tokio::task::block_in_place(|| {
80 tokio::runtime::Handle::current().block_on(async {
81 callbacks.write().await.on_connect.push(Arc::new(callback));
82 });
83 });
84 self
85 }
86
87 pub fn on_disconnect<F>(&mut self, callback: F) -> &mut Self
89 where
90 F: Fn() + Send + Sync + 'static,
91 {
92 let callbacks = self.callbacks.clone();
93 tokio::task::block_in_place(|| {
94 tokio::runtime::Handle::current().block_on(async {
95 callbacks
96 .write()
97 .await
98 .on_disconnect
99 .push(Arc::new(callback));
100 });
101 });
102 self
103 }
104
105 pub fn on_reconnecting<F>(&mut self, callback: F) -> &mut Self
107 where
108 F: Fn(u32) + Send + Sync + 'static,
109 {
110 let callbacks = self.callbacks.clone();
111 tokio::task::block_in_place(|| {
112 tokio::runtime::Handle::current().block_on(async {
113 callbacks
114 .write()
115 .await
116 .on_reconnecting
117 .push(Arc::new(callback));
118 });
119 });
120 self
121 }
122
123 pub fn on_message<F>(&mut self, callback: F) -> &mut Self
125 where
126 F: Fn(MIPMessage) + Send + Sync + 'static,
127 {
128 let callbacks = self.callbacks.clone();
129 tokio::task::block_in_place(|| {
130 tokio::runtime::Handle::current().block_on(async {
131 callbacks.write().await.on_message.push(Arc::new(callback));
132 });
133 });
134 self
135 }
136
137 pub fn on_event<F>(&mut self, callback: F) -> &mut Self
139 where
140 F: Fn(MIPMessage) + Send + Sync + 'static,
141 {
142 let callbacks = self.callbacks.clone();
143 tokio::task::block_in_place(|| {
144 tokio::runtime::Handle::current().block_on(async {
145 callbacks.write().await.on_event.push(Arc::new(callback));
146 });
147 });
148 self
149 }
150
151 pub fn on_ack<F>(&mut self, callback: F) -> &mut Self
153 where
154 F: Fn(u64) + Send + Sync + 'static,
155 {
156 let callbacks = self.callbacks.clone();
157 tokio::task::block_in_place(|| {
158 tokio::runtime::Handle::current().block_on(async {
159 callbacks.write().await.on_ack.push(Arc::new(callback));
160 });
161 });
162 self
163 }
164
165 pub fn on_pong<F>(&mut self, callback: F) -> &mut Self
167 where
168 F: Fn() + Send + Sync + 'static,
169 {
170 let callbacks = self.callbacks.clone();
171 tokio::task::block_in_place(|| {
172 tokio::runtime::Handle::current().block_on(async {
173 callbacks.write().await.on_pong.push(Arc::new(callback));
174 });
175 });
176 self
177 }
178
179 pub fn on_error<F>(&mut self, callback: F) -> &mut Self
181 where
182 F: Fn(MIPError) + Send + Sync + 'static,
183 {
184 let callbacks = self.callbacks.clone();
185 tokio::task::block_in_place(|| {
186 tokio::runtime::Handle::current().block_on(async {
187 callbacks.write().await.on_error.push(Arc::new(callback));
188 });
189 });
190 self
191 }
192
193 pub fn on_frame<F>(&mut self, callback: F) -> &mut Self
195 where
196 F: Fn(crate::protocol::header::Header, Vec<u8>) + Send + Sync + 'static,
197 {
198 let callbacks = self.callbacks.clone();
199 tokio::task::block_in_place(|| {
200 tokio::runtime::Handle::current().block_on(async {
201 callbacks.write().await.on_frame.push(Arc::new(callback));
202 });
203 });
204 self
205 }
206
207 pub async fn connect(&mut self) -> MIPResult<()> {
213 if self.connected.load(Ordering::SeqCst) {
214 return Ok(());
215 }
216
217 self.running.store(true, Ordering::SeqCst);
218
219 let options = self.options.read().await;
220 let addr = format!("{}:{}", options.host, options.port);
221 drop(options);
222
223 debug!("Connecting to {}", addr);
224
225 let stream = TcpStream::connect(&addr)
226 .await
227 .map_err(|e| MIPError::Connection(e.to_string()))?;
228
229 let (read_half, write_half) = stream.into_split();
230
231 let (command_tx, command_rx) = mpsc::channel::<ClientCommand>(100);
233 self.command_tx = Some(command_tx);
234
235 self.connected.store(true, Ordering::SeqCst);
236 self.reconnect_attempts.store(0, Ordering::SeqCst);
237
238 let read_task = spawn_read_task(
240 read_half,
241 self.connected.clone(),
242 self.running.clone(),
243 self.callbacks.clone(),
244 self.options.clone(),
245 self.reconnect_attempts.clone(),
246 );
247 self.read_task = Some(read_task);
248
249 let write_task = spawn_write_task(write_half, command_rx);
251 self.write_task = Some(write_task);
252
253 let options = self.options.read().await;
255 let ping_interval = options.ping_interval_ms;
256 drop(options);
257
258 self.ping_task = spawn_ping_task(
259 ping_interval,
260 self.connected.clone(),
261 self.running.clone(),
262 self.command_tx.clone(),
263 self.msg_id_counter.clone(),
264 );
265
266 let callbacks = self.callbacks.read().await;
268 for callback in &callbacks.on_connect {
269 callback();
270 }
271
272 let client_id = self.options.read().await.client_id.clone();
273 let payload = client_id.as_bytes();
274
275 if let Some(tx) = &self.command_tx {
276 let res = send_frame(
277 Some(tx),
278 &self.connected.clone(),
279 &self.msg_id_counter.clone(),
280 FrameType::Hello, payload,
281 FrameFlags::NONE
282 );
283
284 if let Err(e) = res {
285 println!("Failed to send HELLO frame: {}", e);
286 self.connected.store(false, Ordering::SeqCst);
287 self.running.store(false, Ordering::SeqCst);
288 return Err(e);
289 } else {
290 self.client_id = res.unwrap().to_string();
291 }
292 }
293
294 info!("Connected to {}", addr);
295 Ok(())
296 }
297
298 pub async fn disconnect(&mut self) -> MIPResult<()> {
300 {
301 let mut options = self.options.write().await;
302 options.auto_reconnect = false;
303 }
304
305 self.running.store(false, Ordering::SeqCst);
306
307 if let Some(tx) = &self.command_tx {
309 let _ = send_close_internal(tx, &self.connected, &self.msg_id_counter).await;
310 let _ = tx.send(ClientCommand::Disconnect).await;
311 }
312
313 cleanup_tasks(&mut self.ping_task, &mut self.read_task, &mut self.write_task).await;
314 self.connected.store(false, Ordering::SeqCst);
315 self.command_tx = None;
316
317 info!("Disconnected");
318 Ok(())
319 }
320
321 pub fn subscribe(&self, topic: &str, require_ack: bool) -> MIPResult<u64> {
323 let topic_bytes = topic.as_bytes();
324 let flags = if require_ack {
325 FrameFlags::ACK_REQUIRED
326 } else {
327 FrameFlags::NONE
328 };
329 send_frame(
330 self.command_tx.as_ref(),
331 &self.connected,
332 &self.msg_id_counter,
333 FrameType::Subscribe,
334 topic_bytes,
335 flags,
336 )
337 }
338
339 pub fn unsubscribe(&self, topic: &str, require_ack: bool) -> MIPResult<u64> {
341 let topic_bytes = topic.as_bytes();
342 let flags = if require_ack {
343 FrameFlags::ACK_REQUIRED
344 } else {
345 FrameFlags::NONE
346 };
347 send_frame(
348 self.command_tx.as_ref(),
349 &self.connected,
350 &self.msg_id_counter,
351 FrameType::Unsubscribe,
352 topic_bytes,
353 flags,
354 )
355 }
356
357 pub fn publish(&self, topic: &str, message: &str, flags: FrameFlags) -> MIPResult<u64> {
359 let topic_bytes = topic.as_bytes();
360 let message_bytes = message.as_bytes();
361
362 let mut payload = Vec::with_capacity(2 + topic_bytes.len() + message_bytes.len());
364 payload.extend_from_slice(&(topic_bytes.len() as u16).to_be_bytes());
365 payload.extend_from_slice(topic_bytes);
366 payload.extend_from_slice(message_bytes);
367
368 send_frame(
369 self.command_tx.as_ref(),
370 &self.connected,
371 &self.msg_id_counter,
372 FrameType::Publish,
373 &payload,
374 flags,
375 )
376 }
377
378 pub fn ping(&self) -> MIPResult<u64> {
380 send_frame(
381 self.command_tx.as_ref(),
382 &self.connected,
383 &self.msg_id_counter,
384 FrameType::Ping,
385 &[],
386 FrameFlags::NONE,
387 )
388 }
389
390 pub fn send_raw_frame(
392 &self,
393 frame_type: FrameType,
394 payload: &[u8],
395 flags: FrameFlags,
396 ) -> MIPResult<u64> {
397 send_frame(
398 self.command_tx.as_ref(),
399 &self.connected,
400 &self.msg_id_counter,
401 frame_type,
402 payload,
403 flags,
404 )
405 }
406}
407
408impl Drop for MIPClient {
409 fn drop(&mut self) {
410 self.running.store(false, Ordering::SeqCst);
411 }
412}
413
414pub fn get_frame_type_name(frame_type: FrameType) -> &'static str {
420 match frame_type {
421 FrameType::Hello => "HELLO",
422 FrameType::Subscribe => "SUBSCRIBE",
423 FrameType::Unsubscribe => "UNSUBSCRIBE",
424 FrameType::Publish => "PUBLISH",
425 FrameType::Event => "EVENT",
426 FrameType::Ack => "ACK",
427 FrameType::Error => "ERROR",
428 FrameType::Ping => "PING",
429 FrameType::Pong => "PONG",
430 FrameType::Close => "CLOSE",
431 }
432}
433
434pub fn create_client() -> MIPClient {
436 MIPClient::new(MIPClientOptions::default())
437}