1use rumq_core::mqtt4::{has_wildcards, matches, QoS, Packet, Connect, Publish, Subscribe, Unsubscribe};
2use tokio::sync::mpsc::{Receiver, Sender};
3use tokio::sync::mpsc::error::TrySendError;
4use tokio::select;
5use tokio::time::{self, Duration};
6use tokio::stream::StreamExt;
7
8use std::collections::{HashMap, VecDeque};
9use std::mem;
10use std::fmt;
11
12use crate::state::{self, MqttState};
13
14#[derive(Debug, thiserror::Error)]
15pub enum Error {
16 #[error("State")]
17 State(#[from] state::Error),
18 #[error("All senders down")]
19 AllSendersDown,
20 #[error("Error sending message on internal bus")]
21 Mpsc(#[from] TrySendError<RouterMessage>),
22}
23
24#[derive(Debug)]
28pub enum RouterMessage {
29 Connect(Connection),
31 Packet(Packet),
33 Packets(VecDeque<Packet>),
35 Death(String),
37 Pending(VecDeque<Publish>)
39}
40
41pub struct Connection {
42 pub connect: Connect,
43 pub handle: Option<Sender<RouterMessage>>
44}
45
46impl Connection {
47 pub fn new(connect: Connect, handle: Sender<RouterMessage>) -> Connection {
48 Connection {
49 connect,
50 handle: Some(handle)
51 }
52 }
53}
54
55impl fmt::Debug for Connection {
56 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57 write!(f, "{:?}", self.connect)
58 }
59}
60
61#[derive(Debug)]
62struct ActiveConnection {
63 pub state: MqttState,
64 pub outgoing: VecDeque<Packet>,
65 tx: Sender<RouterMessage>
66}
67
68impl ActiveConnection {
69 pub fn new(tx: Sender<RouterMessage>, state: MqttState) -> ActiveConnection {
70 ActiveConnection {
71 state,
72 outgoing: VecDeque::new(),
73 tx
74 }
75 }
76}
77
78#[derive(Debug)]
79struct InactiveConnection {
80 pub state: MqttState
81}
82
83impl InactiveConnection {
84 pub fn new(state: MqttState) -> InactiveConnection {
85 InactiveConnection {
86 state,
87 }
88 }
89}
90
91#[derive(Debug, Clone)]
92struct Subscriber {
93 client_id: String,
94 qos: QoS,
95}
96
97impl Subscriber {
98 pub fn new(id: &str, qos: QoS) -> Subscriber {
99 Subscriber {
100 client_id: id.to_owned(),
101 qos,
102 }
103 }
104}
105
106#[derive(Debug)]
107pub struct Router {
108 active_connections: HashMap<String, ActiveConnection>,
110 inactive_connections: HashMap<String, InactiveConnection>,
112 concrete_subscriptions: HashMap<String, Vec<Subscriber>>,
114 wild_subscriptions: HashMap<String, Vec<Subscriber>>,
116 retained_publishes: HashMap<String, Publish>,
118 data_rx: Receiver<(String, RouterMessage)>,
121}
122
123impl Router {
124 pub fn new(data_rx: Receiver<(String, RouterMessage)>) -> Self {
125 Router {
126 active_connections: HashMap::new(),
127 inactive_connections: HashMap::new(),
128 concrete_subscriptions: HashMap::new(),
129 wild_subscriptions: HashMap::new(),
130 retained_publishes: HashMap::new(),
131 data_rx,
132 }
133 }
134
135 pub async fn start(&mut self) -> Result<(), Error> {
136 let mut interval = time::interval(Duration::from_millis(10));
137
138 loop {
139 select! {
140 o = self.data_rx.recv() => {
141 let (id, mut message) = o.unwrap();
142 debug!("In router message. Id = {}, {:?}", id, message);
143 match self.reply(id.clone(), &mut message) {
144 Ok(Some(message)) => self.forward(&id, message),
145 Ok(None) => (),
146 Err(e) => {
147 error!("Incoming handle error = {:?}", e);
148 continue;
149 }
150 }
151
152 if let Err(e) = self.route(id, message) {
154 error!("Routing error = {:?}", e);
155 }
156 }
157 _ = interval.next() => {
158 for (_, connection) in self.active_connections.iter_mut() {
159 let pending = connection.outgoing.split_off(0);
160 if pending.len() > 0 {
161 let _ = connection.tx.try_send(RouterMessage::Packets(pending));
162 }
163 }
164 }
165 }
166
167
168 }
169 }
170
171 fn reply(&mut self, id: String, message: &mut RouterMessage) -> Result<Option<RouterMessage>, Error> {
179 match message {
180 RouterMessage::Connect(connection) => {
181 let handle = connection.handle.take().unwrap();
182 let message = self.handle_connect(connection.connect.clone(), handle)?;
183 Ok(message)
184 }
185 RouterMessage::Packet(packet) => {
186 let message = self.handle_incoming_packet(&id, packet.clone())?;
187 Ok(message)
188 }
189 _ => Ok(None)
190 }
191 }
192
193 fn route(&mut self, id: String, message: RouterMessage) -> Result<(), Error> {
194 match message {
195 RouterMessage::Packet(packet) => {
196 match packet {
197 Packet::Publish(publish) => self.match_subscriptions(&id, publish),
198 Packet::Subscribe(subscribe) => self.add_to_subscriptions(id, subscribe),
199 Packet::Unsubscribe(unsubscribe) => self.remove_from_subscriptions(id, unsubscribe),
200 Packet::Disconnect => self.deactivate(id),
201 _ => return Ok(())
202 }
203 }
204 RouterMessage::Death(id) => {
205 self.deactivate_and_forward_will(id.to_owned());
206 }
207 _ => ()
208 }
209 Ok(())
210 }
211
212 fn handle_connect(&mut self, connect: Connect, connection_handle: Sender<RouterMessage>) -> Result<Option<RouterMessage>, Error> {
213 let id = connect.client_id;
214 let clean_session = connect.clean_session;
215 let will = connect.last_will;
216
217 info!("Connect. Id = {:?}", id);
218 let reply = if clean_session {
219 self.inactive_connections.remove(&id);
220
221 let state = MqttState::new(clean_session, will);
222 self.active_connections.insert(id.clone(), ActiveConnection::new(connection_handle, state));
223 Some(RouterMessage::Pending(VecDeque::new()))
224 } else {
225 if let Some(connection) = self.inactive_connections.remove(&id) {
226 let pending = connection.state.outgoing_publishes.clone();
227 self.active_connections.insert(id.clone(), ActiveConnection::new(connection_handle, connection.state));
228 Some(RouterMessage::Pending(pending))
229 } else {
230 let state = MqttState::new(clean_session, will);
231 self.active_connections.insert(id.clone(), ActiveConnection::new(connection_handle, state));
232 Some(RouterMessage::Pending(VecDeque::new()))
233 }
234 };
235
236 if clean_session {
237 for (_, subscribers) in self.concrete_subscriptions.iter_mut() {
239 if let Some(index) = subscribers.iter().position(|s| s.client_id == id) {
240 subscribers.remove(index);
241 }
242 }
243
244 for (_, subscribers) in self.wild_subscriptions.iter_mut() {
245 if let Some(index) = subscribers.iter().position(|s| s.client_id == id) {
246 subscribers.remove(index);
247 }
248 }
249 }
250
251 Ok(reply)
252 }
253
254 fn match_subscriptions(&mut self, _id: &str, publish: Publish) {
255 if publish.retain {
256 if publish.payload.len() == 0 {
257 self.retained_publishes.remove(&publish.topic_name);
258 return
259 } else {
260 self.retained_publishes.insert(publish.topic_name.clone(), publish.clone());
261 }
262 }
263
264 let topic = &publish.topic_name;
265 if let Some(subscribers) = self.concrete_subscriptions.get(topic) {
266 let subscribers = subscribers.clone();
267 for subscriber in subscribers.iter() {
268 self.fill_subscriber(subscriber, publish.clone());
269 }
270 }
271
272 let wild_subscriptions = self.wild_subscriptions.clone();
275 for (filter, subscribers) in wild_subscriptions.into_iter() {
276 if matches(&topic, &filter) {
277 for subscriber in subscribers.into_iter() {
278 let publish = publish.clone();
279 self.fill_subscriber(&subscriber, publish);
280 }
281 }
282 };
283 }
284
285
286 fn add_to_subscriptions(&mut self, id: String, subscribe: Subscribe) {
287 for topic in subscribe.topics {
289 let mut filter = topic.topic_path.clone();
290 let qos = topic.qos;
291 let subscriber = Subscriber::new(&id, qos);
292
293 let subscriber = if let Some((f, subscriber)) = self.fix_overlapping_subscriptions(&id, &filter, qos) {
294 filter = f;
295 subscriber
296 } else {
297 subscriber
298 };
299
300 let subscriptions = if has_wildcards(&filter) {
302 &mut self.wild_subscriptions
303 } else {
304 &mut self.concrete_subscriptions
305 };
306
307 match subscriptions.get_mut(&filter) {
309 Some(subscribers) => {
311 if !subscribers.iter().any(|s| s.client_id == id) {
313 subscribers.push(subscriber.clone())
314 }
315 }
316 None => {
318 let mut subscribers = Vec::new();
319 subscribers.push(subscriber.clone());
320 subscriptions.insert(filter.to_owned(), subscribers);
321 }
322 }
323
324 if has_wildcards(&filter) {
326 let retained_publishes = self.retained_publishes.clone();
327 for (topic, publish) in retained_publishes.into_iter() {
328 if matches(&topic, &filter) {
329 self.fill_subscriber(&subscriber, publish);
330 }
331 }
332 } else {
333 if let Some(publish) = self.retained_publishes.get(&filter) {
334 let publish = publish.clone();
335 self.fill_subscriber(&subscriber, publish);
336 }
337 }
338 }
339 }
340
341 fn fix_overlapping_subscriptions(&mut self, id: &str, current_filter: &str, qos: QoS) -> Option<(String, Subscriber)> {
372 let mut subscriber = None;
373 let mut filter = current_filter.to_owned();
374 let mut qos = qos;
375
376 if has_wildcards(current_filter) {
379 for (existing_filter, subscribers) in self.concrete_subscriptions.iter_mut() {
380 if matches(existing_filter, current_filter) {
381 if let Some(index) = subscribers.iter().position(|s| s.client_id == id) {
382 let mut s = subscribers.remove(index);
383 if qos > s.qos {
384 s.qos = qos;
385 }
386 subscriber = Some(s);
387 }
388 }
389 }
390
391 for (existing_filter, subscribers) in self.wild_subscriptions.iter_mut() {
392 if matches(existing_filter, current_filter) {
395 if let Some(index) = subscribers.iter().position(|s| s.client_id == id) {
396 let s = subscribers.remove(index);
397
398 if s.qos > qos {
399 qos = s.qos
400 }
401
402 subscriber = Some(s);
403 }
404 } else if matches(current_filter, existing_filter) {
405 filter = existing_filter.clone();
408 if let Some(index) = subscribers.iter().position(|s| s.client_id == id) {
411 let s = subscribers.remove(index);
412
413 if s.qos > qos {
414 qos = s.qos
415 }
416
417 subscriber = Some(s);
418 }
419 }
420 }
421 }
422
423 match subscriber {
424 Some(mut subscriber) => {
425 subscriber.qos = qos;
426 Some((filter, subscriber))
427 }
428 None => None
429 }
430 }
431
432 fn remove_from_subscriptions(&mut self, id: String, unsubscribe: Unsubscribe) {
433 for topic in unsubscribe.topics.iter() {
434 if has_wildcards(topic) {
435 for (filter, subscribers) in self.concrete_subscriptions.iter_mut() {
438 if topic == filter {
439 if let Some(index) = subscribers.iter().position(|s| s.client_id == id) {
440 subscribers.remove(index);
441 }
442 }
443 }
444 } else {
445 for (filter, subscribers) in self.concrete_subscriptions.iter_mut() {
448 if topic == filter {
449 if let Some(index) = subscribers.iter().position(|s| s.client_id == id) {
450 subscribers.remove(index);
451 }
452 }
453 }
454 };
455 }
456 }
457
458 fn deactivate(&mut self, id: String) {
459 info!("Deactivating client due to disconnect packet. Id = {}", id);
460
461 if let Some(connection) = self.active_connections.remove(&id) {
462 if !connection.state.clean_session {
463 self.inactive_connections.insert(id, InactiveConnection::new(connection.state));
464 }
465 }
466 }
467
468 fn deactivate_and_forward_will(&mut self, id: String) {
469 info!("Deactivating client due to connection death. Id = {}", id);
470
471 if let Some(mut connection) = self.active_connections.remove(&id) {
472 if let Some(mut will) = connection.state.will.take() {
473 let topic = mem::replace(&mut will.topic, "".to_owned());
474 let message = mem::replace(&mut will.message, "".to_owned());
475 let qos = will.qos;
476
477 let publish = Publish::new(topic, qos, message);
478 self.match_subscriptions(&id, publish);
479 }
480
481 if !connection.state.clean_session {
482 self.inactive_connections.insert(id.clone(), InactiveConnection::new(connection.state));
483 }
484 }
485 }
486
487 fn handle_incoming_packet(&mut self, id: &str, packet: Packet) -> Result<Option<RouterMessage>, Error> {
489 if let Some(connection) = self.active_connections.get_mut(id) {
490 let reply = match connection.state.handle_incoming_mqtt_packet(packet) {
491 Ok(Some(reply)) => reply,
492 Ok(None) => return Ok(None),
493 Err(state::Error::Unsolicited(packet)) => {
494 error!("Unsolicited ack = {:?}. Id = {}", packet, id);
497 return Ok(None)
498 }
499 Err(e) => {
500 error!("State error = {:?}. Id = {}", e, id);
501 self.active_connections.remove(id);
502 return Err::<_, Error>(e.into())
503 }
504 };
505
506 return Ok(Some(reply))
507 }
508
509 Ok(None)
510 }
511
512 fn fill_subscriber(&mut self, subscriber: &Subscriber, mut publish: Publish) {
514 publish.qos = subscriber.qos;
515
516 if let Some(connection) = self.inactive_connections.get_mut(&subscriber.client_id) {
517 debug!("Forwarding publish to active connection. Id = {}, {:?}", subscriber.client_id, publish);
518 connection.state.handle_outgoing_publish(publish);
519 return
520 }
521
522 if let Some(connection) = self.active_connections.get_mut(&subscriber.client_id) {
523 let packet = connection.state.handle_outgoing_publish(publish);
524 connection.outgoing.push_back(packet);
525 }
526 }
527
528 fn forward(&mut self, id: &str, message: RouterMessage) {
529 if let Some(connection) = self.active_connections.get_mut(id) {
530 if let Err(e) = connection.tx.try_send(message) {
533 match e {
534 TrySendError::Full(_m) => {
535 error!("Slow connection during forward. Dropping handle and moving id to inactive list. Id = {}", id);
536 if let Some(connection) = self.active_connections.remove(id) {
537 self.inactive_connections.insert(id.to_owned(), InactiveConnection::new(connection.state));
538 }
539 }
540 TrySendError::Closed(_m) => {
541 error!("Closed connection. Forward failed");
542 self.active_connections.remove(id);
543 }
544 }
545 }
546 }
547 }
548}
549
550
551
552
#[cfg(test)]
mod test {
    // Placeholders whose bodies are not written yet. They are marked `#[ignore]` so
    // `cargo test` reports them as ignored rather than counting empty bodies as
    // passing coverage.

    #[test]
    #[ignore = "not implemented yet"]
    fn persistent_disconnected_and_dead_connections_are_moved_to_inactive_state() {}

    #[test]
    #[ignore = "not implemented yet"]
    fn persistent_reconnections_are_moved_from_inactive_to_active_state() {}

    #[test]
    #[ignore = "not implemented yet"]
    fn offline_messages_are_given_back_to_reconnected_persistent_connection() {}

    #[test]
    #[ignore = "not implemented yet"]
    fn remove_client_from_concrete_subscriptions_if_new_wildcard_subscription_matches_existing_concrete_subscription() {}

    #[test]
    #[ignore = "not implemented yet"]
    fn ignore_new_concrete_subscription_if_a_matching_wildcard_subscription_exists_for_the_client() {}

    #[test]
    #[ignore = "not implemented yet"]
    fn router_should_remove_the_connection_during_disconnect() {}

    #[test]
    #[ignore = "not implemented yet"]
    fn router_should_not_add_same_client_to_subscription_list() {}

    #[test]
    #[ignore = "not implemented yet"]
    fn router_saves_offline_messages_of_a_persistent_dead_connection() {}
}