1mod events;
2mod internals;
3mod types;
4
5use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
6use std::sync::Arc;
7
8use tokio::net::TcpStream;
9use tokio::sync::{mpsc, RwLock};
10use tokio::task::JoinHandle;
11use tracing::{debug, info};
12
13use crate::protocol::header::{FrameFlags, FrameType};
14
15pub use events::{
16 OnAck, OnConnect, OnDisconnect, OnError, OnEvent, OnFrame, OnMessage, OnPong, OnReconnecting,
17};
18pub use types::{MIPClientOptions, MIPError, MIPMessage, MIPResult};
19
20use events::Callbacks;
21use internals::{
22 cleanup_tasks, send_close_internal, send_frame, spawn_ping_task, spawn_read_task,
23 spawn_write_task, ClientCommand,
24};
25
/// Asynchronous MIP client built on a tokio `TcpStream`.
///
/// State that the background read, write and ping tasks also need is held behind
/// `Arc`s; `connect` hands clones of those `Arc`s to the tasks it spawns.
pub struct MIPClient {
    /// Connection settings (`host`, `port`, `ping_interval_ms`, `auto_reconnect`, ...);
    /// a clone is handed to the read task.
    options: Arc<RwLock<MIPClientOptions>>,
    /// Set once the TCP connection succeeds and cleared by `disconnect`; shared
    /// with the read and ping tasks.
    connected: Arc<AtomicBool>,
    /// Set by `connect`, cleared by `disconnect` and on drop; shared with the read
    /// and ping tasks (presumably as their stop signal — confirm in `internals`).
    running: Arc<AtomicBool>,
    /// Counter handed to `send_frame` and the ping task; presumably the source of
    /// the `u64` message ids returned by the send methods.
    msg_id_counter: Arc<AtomicU64>,
    /// Reset to 0 on a successful `connect`; shared with the read task.
    reconnect_attempts: Arc<AtomicU64>,

    /// User-registered event callbacks, shared with the read task.
    callbacks: Arc<RwLock<Callbacks>>,
    /// Sender half of the command channel consumed by the write task;
    /// `None` before `connect` succeeds and again after `disconnect`.
    command_tx: Option<mpsc::Sender<ClientCommand>>,

    // Handles of the background tasks spawned by `connect`.
    // NOTE: struct fields drop in declaration order, so keep this ordering stable.
    read_task: Option<JoinHandle<()>>,
    write_task: Option<JoinHandle<()>>,
    /// `spawn_ping_task` returns an `Option`; presumably `None` when pinging is
    /// disabled — confirm in `internals`.
    ping_task: Option<JoinHandle<()>>,
}
45
46impl MIPClient {
47 pub fn new(options: MIPClientOptions) -> Self {
48 Self {
49 options: Arc::new(RwLock::new(options)),
50 connected: Arc::new(AtomicBool::new(false)),
51 running: Arc::new(AtomicBool::new(false)),
52 msg_id_counter: Arc::new(AtomicU64::new(0)),
53 reconnect_attempts: Arc::new(AtomicU64::new(0)),
54 callbacks: Arc::new(RwLock::new(Callbacks::default())),
55 command_tx: None,
56 read_task: None,
57 write_task: None,
58 ping_task: None,
59 }
60 }
61
62 pub fn is_connected(&self) -> bool {
64 self.connected.load(Ordering::SeqCst)
65 }
66
67 pub fn on_connect<F>(&mut self, callback: F) -> &mut Self
73 where
74 F: Fn() + Send + Sync + 'static,
75 {
76 let callbacks = self.callbacks.clone();
77 tokio::task::block_in_place(|| {
78 tokio::runtime::Handle::current().block_on(async {
79 callbacks.write().await.on_connect.push(Arc::new(callback));
80 });
81 });
82 self
83 }
84
85 pub fn on_disconnect<F>(&mut self, callback: F) -> &mut Self
87 where
88 F: Fn() + Send + Sync + 'static,
89 {
90 let callbacks = self.callbacks.clone();
91 tokio::task::block_in_place(|| {
92 tokio::runtime::Handle::current().block_on(async {
93 callbacks
94 .write()
95 .await
96 .on_disconnect
97 .push(Arc::new(callback));
98 });
99 });
100 self
101 }
102
103 pub fn on_reconnecting<F>(&mut self, callback: F) -> &mut Self
105 where
106 F: Fn(u32) + Send + Sync + 'static,
107 {
108 let callbacks = self.callbacks.clone();
109 tokio::task::block_in_place(|| {
110 tokio::runtime::Handle::current().block_on(async {
111 callbacks
112 .write()
113 .await
114 .on_reconnecting
115 .push(Arc::new(callback));
116 });
117 });
118 self
119 }
120
121 pub fn on_message<F>(&mut self, callback: F) -> &mut Self
123 where
124 F: Fn(MIPMessage) + Send + Sync + 'static,
125 {
126 let callbacks = self.callbacks.clone();
127 tokio::task::block_in_place(|| {
128 tokio::runtime::Handle::current().block_on(async {
129 callbacks.write().await.on_message.push(Arc::new(callback));
130 });
131 });
132 self
133 }
134
135 pub fn on_event<F>(&mut self, callback: F) -> &mut Self
137 where
138 F: Fn(MIPMessage) + Send + Sync + 'static,
139 {
140 let callbacks = self.callbacks.clone();
141 tokio::task::block_in_place(|| {
142 tokio::runtime::Handle::current().block_on(async {
143 callbacks.write().await.on_event.push(Arc::new(callback));
144 });
145 });
146 self
147 }
148
149 pub fn on_ack<F>(&mut self, callback: F) -> &mut Self
151 where
152 F: Fn(u64) + Send + Sync + 'static,
153 {
154 let callbacks = self.callbacks.clone();
155 tokio::task::block_in_place(|| {
156 tokio::runtime::Handle::current().block_on(async {
157 callbacks.write().await.on_ack.push(Arc::new(callback));
158 });
159 });
160 self
161 }
162
163 pub fn on_pong<F>(&mut self, callback: F) -> &mut Self
165 where
166 F: Fn() + Send + Sync + 'static,
167 {
168 let callbacks = self.callbacks.clone();
169 tokio::task::block_in_place(|| {
170 tokio::runtime::Handle::current().block_on(async {
171 callbacks.write().await.on_pong.push(Arc::new(callback));
172 });
173 });
174 self
175 }
176
177 pub fn on_error<F>(&mut self, callback: F) -> &mut Self
179 where
180 F: Fn(MIPError) + Send + Sync + 'static,
181 {
182 let callbacks = self.callbacks.clone();
183 tokio::task::block_in_place(|| {
184 tokio::runtime::Handle::current().block_on(async {
185 callbacks.write().await.on_error.push(Arc::new(callback));
186 });
187 });
188 self
189 }
190
191 pub fn on_frame<F>(&mut self, callback: F) -> &mut Self
193 where
194 F: Fn(crate::protocol::header::Header, Vec<u8>) + Send + Sync + 'static,
195 {
196 let callbacks = self.callbacks.clone();
197 tokio::task::block_in_place(|| {
198 tokio::runtime::Handle::current().block_on(async {
199 callbacks.write().await.on_frame.push(Arc::new(callback));
200 });
201 });
202 self
203 }
204
205 pub async fn connect(&mut self) -> MIPResult<()> {
211 if self.connected.load(Ordering::SeqCst) {
212 return Ok(());
213 }
214
215 self.running.store(true, Ordering::SeqCst);
216
217 let options = self.options.read().await;
218 let addr = format!("{}:{}", options.host, options.port);
219 drop(options);
220
221 debug!("Connecting to {}", addr);
222
223 let stream = TcpStream::connect(&addr)
224 .await
225 .map_err(|e| MIPError::Connection(e.to_string()))?;
226
227 let (read_half, write_half) = stream.into_split();
228
229 let (command_tx, command_rx) = mpsc::channel::<ClientCommand>(100);
231 self.command_tx = Some(command_tx);
232
233 self.connected.store(true, Ordering::SeqCst);
234 self.reconnect_attempts.store(0, Ordering::SeqCst);
235
236 let read_task = spawn_read_task(
238 read_half,
239 self.connected.clone(),
240 self.running.clone(),
241 self.callbacks.clone(),
242 self.options.clone(),
243 self.reconnect_attempts.clone(),
244 );
245 self.read_task = Some(read_task);
246
247 let write_task = spawn_write_task(write_half, command_rx);
249 self.write_task = Some(write_task);
250
251 let options = self.options.read().await;
253 let ping_interval = options.ping_interval_ms;
254 drop(options);
255
256 self.ping_task = spawn_ping_task(
257 ping_interval,
258 self.connected.clone(),
259 self.running.clone(),
260 self.command_tx.clone(),
261 self.msg_id_counter.clone(),
262 );
263
264 let callbacks = self.callbacks.read().await;
266 for callback in &callbacks.on_connect {
267 callback();
268 }
269
270 info!("Connected to {}", addr);
271 Ok(())
272 }
273
274 pub async fn disconnect(&mut self) -> MIPResult<()> {
276 {
277 let mut options = self.options.write().await;
278 options.auto_reconnect = false;
279 }
280
281 self.running.store(false, Ordering::SeqCst);
282
283 if let Some(tx) = &self.command_tx {
285 let _ = send_close_internal(tx, &self.connected, &self.msg_id_counter).await;
286 let _ = tx.send(ClientCommand::Disconnect).await;
287 }
288
289 cleanup_tasks(&mut self.ping_task, &mut self.read_task, &mut self.write_task).await;
290 self.connected.store(false, Ordering::SeqCst);
291 self.command_tx = None;
292
293 info!("Disconnected");
294 Ok(())
295 }
296
297 pub fn subscribe(&self, topic: &str, require_ack: bool) -> MIPResult<u64> {
299 let topic_bytes = topic.as_bytes();
300 let flags = if require_ack {
301 FrameFlags::ACK_REQUIRED
302 } else {
303 FrameFlags::NONE
304 };
305 send_frame(
306 self.command_tx.as_ref(),
307 &self.connected,
308 &self.msg_id_counter,
309 FrameType::Subscribe,
310 topic_bytes,
311 flags,
312 )
313 }
314
315 pub fn unsubscribe(&self, topic: &str, require_ack: bool) -> MIPResult<u64> {
317 let topic_bytes = topic.as_bytes();
318 let flags = if require_ack {
319 FrameFlags::ACK_REQUIRED
320 } else {
321 FrameFlags::NONE
322 };
323 send_frame(
324 self.command_tx.as_ref(),
325 &self.connected,
326 &self.msg_id_counter,
327 FrameType::Unsubscribe,
328 topic_bytes,
329 flags,
330 )
331 }
332
333 pub fn publish(&self, topic: &str, message: &str, flags: FrameFlags) -> MIPResult<u64> {
335 let topic_bytes = topic.as_bytes();
336 let message_bytes = message.as_bytes();
337
338 let mut payload = Vec::with_capacity(2 + topic_bytes.len() + message_bytes.len());
340 payload.extend_from_slice(&(topic_bytes.len() as u16).to_be_bytes());
341 payload.extend_from_slice(topic_bytes);
342 payload.extend_from_slice(message_bytes);
343
344 send_frame(
345 self.command_tx.as_ref(),
346 &self.connected,
347 &self.msg_id_counter,
348 FrameType::Publish,
349 &payload,
350 flags,
351 )
352 }
353
354 pub fn ping(&self) -> MIPResult<u64> {
356 send_frame(
357 self.command_tx.as_ref(),
358 &self.connected,
359 &self.msg_id_counter,
360 FrameType::Ping,
361 &[],
362 FrameFlags::NONE,
363 )
364 }
365
366 pub fn send_raw_frame(
368 &self,
369 frame_type: FrameType,
370 payload: &[u8],
371 flags: FrameFlags,
372 ) -> MIPResult<u64> {
373 send_frame(
374 self.command_tx.as_ref(),
375 &self.connected,
376 &self.msg_id_counter,
377 frame_type,
378 payload,
379 flags,
380 )
381 }
382}
383
384impl Drop for MIPClient {
385 fn drop(&mut self) {
386 self.running.store(false, Ordering::SeqCst);
387 }
388}
389
390pub fn get_frame_type_name(frame_type: FrameType) -> &'static str {
396 match frame_type {
397 FrameType::Hello => "HELLO",
398 FrameType::Subscribe => "SUBSCRIBE",
399 FrameType::Unsubscribe => "UNSUBSCRIBE",
400 FrameType::Publish => "PUBLISH",
401 FrameType::Event => "EVENT",
402 FrameType::Ack => "ACK",
403 FrameType::Error => "ERROR",
404 FrameType::Ping => "PING",
405 FrameType::Pong => "PONG",
406 FrameType::Close => "CLOSE",
407 }
408}
409
410pub fn create_client() -> MIPClient {
412 MIPClient::new(MIPClientOptions::default())
413}