// rumq_broker/router.rs

1use rumq_core::mqtt4::{has_wildcards, matches, QoS, Packet, Connect, Publish, Subscribe, Unsubscribe};
2use tokio::sync::mpsc::{Receiver, Sender};
3use tokio::sync::mpsc::error::TrySendError;
4use tokio::select;
5use tokio::time::{self, Duration};
6use tokio::stream::StreamExt;
7
8use std::collections::{HashMap, VecDeque};
9use std::mem;
10use std::fmt;
11
12use crate::state::{self, MqttState};
13
14#[derive(Debug, thiserror::Error)]
15pub enum Error {
16    #[error("State")]
17    State(#[from] state::Error),
18    #[error("All senders down")]
19    AllSendersDown,
20    #[error("Error sending message on internal bus")]
21    Mpsc(#[from] TrySendError<RouterMessage>),
22}
23
24/// Router message to orchestrate data between connections. We can also
25/// use this to send control signals to connections to modify their behavior
26/// dynamically from the console
27#[derive(Debug)]
28pub enum RouterMessage {
29    /// Client id and connection handle
30    Connect(Connection),
31    /// Packet
32    Packet(Packet),
33    /// Packets
34    Packets(VecDeque<Packet>),
35    /// Disconnects a client from active connections list. Will handling
36    Death(String),
37    /// Pending messages of the previous connection
38    Pending(VecDeque<Publish>)
39}
40
41pub struct Connection {
42    pub connect: Connect,
43    pub handle: Option<Sender<RouterMessage>>
44}
45
46impl Connection {
47    pub fn new(connect: Connect, handle: Sender<RouterMessage>) -> Connection {
48        Connection {
49            connect,
50            handle: Some(handle)
51        }
52    }
53}
54
55impl fmt::Debug for Connection {
56    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57        write!(f, "{:?}", self.connect)
58    }
59}
60
61#[derive(Debug)]
62struct ActiveConnection {
63    pub state: MqttState,
64    pub outgoing: VecDeque<Packet>,
65    tx: Sender<RouterMessage>
66}
67
68impl ActiveConnection {
69    pub fn new(tx: Sender<RouterMessage>, state: MqttState) -> ActiveConnection {
70        ActiveConnection {
71            state,
72            outgoing: VecDeque::new(),
73            tx
74        }
75    }
76}
77
78#[derive(Debug)]
79struct InactiveConnection {
80    pub state: MqttState
81}
82
83impl InactiveConnection {
84    pub fn new(state: MqttState) -> InactiveConnection {
85        InactiveConnection {
86            state,
87        }
88    }
89}
90
91#[derive(Debug, Clone)]
92struct Subscriber {
93    client_id: String,
94    qos: QoS,
95}
96
97impl Subscriber {
98    pub fn new(id: &str, qos: QoS) -> Subscriber {
99        Subscriber {
100            client_id: id.to_owned(),
101            qos,
102        }
103    }
104}
105
106#[derive(Debug)]
107pub struct Router {
108    // handles to all active connections. used to route data
109    active_connections:     HashMap<String, ActiveConnection>,
110    // inactive persistent connections
111    inactive_connections:   HashMap<String, InactiveConnection>,
112    // maps concrete subscriptions to interested subscribers
113    concrete_subscriptions: HashMap<String, Vec<Subscriber>>,
114    // maps wildcard subscriptions to interested subscribers
115    wild_subscriptions:     HashMap<String, Vec<Subscriber>>,
116    // retained publishes
117    retained_publishes:     HashMap<String, Publish>,
118    // channel receiver to receive data from all the active_connections.
119    // each connection will have a tx handle
120    data_rx:                Receiver<(String, RouterMessage)>,
121}
122
123impl Router {
124    pub fn new(data_rx: Receiver<(String, RouterMessage)>) -> Self {
125        Router {
126            active_connections: HashMap::new(),
127            inactive_connections: HashMap::new(),
128            concrete_subscriptions: HashMap::new(),
129            wild_subscriptions: HashMap::new(),
130            retained_publishes: HashMap::new(),
131            data_rx,
132        }
133    }
134
135    pub async fn start(&mut self) -> Result<(), Error> {
136        let mut interval = time::interval(Duration::from_millis(10));
137        
138        loop {
139            select! {
140                o = self.data_rx.recv() => {
141                    let (id, mut message) = o.unwrap();                   
142                    debug!("In router message. Id = {}, {:?}", id, message);
143                    match self.reply(id.clone(), &mut message) {
144                        Ok(Some(message)) => self.forward(&id, message),
145                        Ok(None) => (),
146                        Err(e) => {
147                            error!("Incoming handle error = {:?}", e);
148                            continue;
149                        }
150                    }
151
152                    // adds routes, routes the message to other connections etc etc
153                    if let Err(e) = self.route(id, message) {
154                        error!("Routing error = {:?}", e);
155                    }
156                }
157                _ = interval.next() => {
158                    for (_, connection) in self.active_connections.iter_mut() {
159                        let pending = connection.outgoing.split_off(0);
160                        if pending.len() > 0 {
161                            let _ = connection.tx.try_send(RouterMessage::Packets(pending));
162                        }
163                    }
164                }
165            }
166
167
168        }
169    }
170
171    /// replys back to the connection sending the message
172    /// doesn't modify any routing information of the router
173    /// all the routing and route modifications due to subscriptions
174    /// are part of `route` method
175    /// generates reply to send backto the connection. Shouldn't touch anything except active and 
176    /// inactive connections
177    /// No routing modifications here
178    fn reply(&mut self, id: String, message: &mut RouterMessage) -> Result<Option<RouterMessage>, Error> {
179        match message {
180            RouterMessage::Connect(connection) => {
181                let handle = connection.handle.take().unwrap();
182                let message = self.handle_connect(connection.connect.clone(), handle)?;
183                Ok(message)
184            }
185            RouterMessage::Packet(packet) => {
186                let message = self.handle_incoming_packet(&id, packet.clone())?;
187                Ok(message)
188            }
189            _ => Ok(None)
190        }
191    }
192
193    fn route(&mut self, id: String, message: RouterMessage) -> Result<(), Error> {
194        match message {
195            RouterMessage::Packet(packet) => {
196                match packet {
197                    Packet::Publish(publish) => self.match_subscriptions(&id, publish),
198                    Packet::Subscribe(subscribe) => self.add_to_subscriptions(id, subscribe),
199                    Packet::Unsubscribe(unsubscribe) => self.remove_from_subscriptions(id, unsubscribe),
200                    Packet::Disconnect => self.deactivate(id),
201                    _ => return Ok(())
202                }
203            }
204            RouterMessage::Death(id) => {
205                self.deactivate_and_forward_will(id.to_owned());
206            }
207            _ => () 
208        }
209        Ok(())
210    }
211
212    fn handle_connect(&mut self, connect: Connect, connection_handle: Sender<RouterMessage>) -> Result<Option<RouterMessage>, Error> {
213        let id = connect.client_id;
214        let clean_session = connect.clean_session;
215        let will = connect.last_will;
216
217        info!("Connect. Id = {:?}", id);
218        let reply = if clean_session {
219            self.inactive_connections.remove(&id);
220
221            let state = MqttState::new(clean_session, will);
222            self.active_connections.insert(id.clone(), ActiveConnection::new(connection_handle, state));
223            Some(RouterMessage::Pending(VecDeque::new()))
224        } else {
225            if let Some(connection) = self.inactive_connections.remove(&id) {
226                let pending = connection.state.outgoing_publishes.clone();
227                self.active_connections.insert(id.clone(), ActiveConnection::new(connection_handle, connection.state));
228                Some(RouterMessage::Pending(pending))
229            } else {
230                let state = MqttState::new(clean_session, will);
231                self.active_connections.insert(id.clone(), ActiveConnection::new(connection_handle, state));
232                Some(RouterMessage::Pending(VecDeque::new()))
233            }
234        };
235
236        if clean_session {
237            // FIXME: This is costly for every clean connection with a lot of subscribers
238            for (_, subscribers) in self.concrete_subscriptions.iter_mut() {
239                if let Some(index) = subscribers.iter().position(|s| s.client_id == id) {
240                    subscribers.remove(index);
241                }
242            }
243
244            for (_, subscribers) in self.wild_subscriptions.iter_mut() {
245                if let Some(index) = subscribers.iter().position(|s| s.client_id == id) {
246                    subscribers.remove(index);
247                }
248            }
249        }
250
251        Ok(reply)
252    }
253
254    fn match_subscriptions(&mut self, _id: &str, publish: Publish) {
255        if publish.retain {
256            if publish.payload.len() == 0 {
257                self.retained_publishes.remove(&publish.topic_name);
258                return
259            } else {
260                self.retained_publishes.insert(publish.topic_name.clone(), publish.clone());
261            }
262        }
263
264        let topic = &publish.topic_name;
265        if let Some(subscribers) = self.concrete_subscriptions.get(topic) {
266            let subscribers = subscribers.clone();
267            for subscriber in subscribers.iter() {
268                self.fill_subscriber(subscriber, publish.clone());
269            }
270        }
271
272        // TODO: O(n) which happens during every publish. publish perf is going to be
273        // linearly degraded based on number of wildcard subscriptions. fix this
274        let wild_subscriptions = self.wild_subscriptions.clone(); 
275        for (filter, subscribers) in wild_subscriptions.into_iter() {
276            if matches(&topic, &filter) {
277                for subscriber in subscribers.into_iter() {
278                    let publish = publish.clone();
279                    self.fill_subscriber(&subscriber, publish);
280                }
281            }
282        };
283    }
284
285
286    fn add_to_subscriptions(&mut self, id: String, subscribe: Subscribe) {
287        // Each subscribe message can send multiple topics to subscribe to. handle dupicates
288        for topic in subscribe.topics {
289            let mut filter = topic.topic_path.clone();
290            let qos = topic.qos;
291            let subscriber = Subscriber::new(&id, qos);
292
293            let subscriber = if let Some((f, subscriber)) = self.fix_overlapping_subscriptions(&id, &filter, qos) {
294                filter = f;
295                subscriber
296            } else {
297                subscriber
298            };
299
300            // a publish happens on a/b/c.
301            let subscriptions = if has_wildcards(&filter) {
302                &mut self.wild_subscriptions
303            } else {
304                &mut self.concrete_subscriptions
305            };
306
307            // add client id to subscriptions
308            match subscriptions.get_mut(&filter) {
309                // push client id to the list of clients intrested in this subspcription
310                Some(subscribers) => {
311                    // don't add the same id twice
312                    if !subscribers.iter().any(|s| s.client_id == id) {
313                        subscribers.push(subscriber.clone())
314                    }
315                }
316                // create a new subscription and push the client id
317                None => {
318                    let mut subscribers = Vec::new();
319                    subscribers.push(subscriber.clone());
320                    subscriptions.insert(filter.to_owned(), subscribers);
321                }
322            }
323
324            // Handle retained publishes after subscription duplicates are handled above
325            if has_wildcards(&filter) {
326                let retained_publishes = self.retained_publishes.clone();
327                for (topic, publish) in retained_publishes.into_iter() {
328                    if matches(&topic, &filter) {
329                        self.fill_subscriber(&subscriber, publish);
330                    }
331                }
332            } else {
333                if let Some(publish) = self.retained_publishes.get(&filter) {
334                    let publish = publish.clone();
335                    self.fill_subscriber(&subscriber, publish);
336                }
337            }
338        }
339    }
340
341    /// removes the subscriber from subscription if the current subscription is wider than the
342    /// existing subscription and returns it
343    ///
344    /// if wildcard subscription:
345    /// move subscriber from concrete to wild subscription with greater qos
346    /// move subscriber from existing wild to current wild subscription if current is wider
347    /// move subscriber from current wild to existing wild if the existing wild is wider
348    /// returns the subscriber and the wild subscription it is to be added to
349    /// none implies that there are no overlapping subscriptions for this subscriber
350    /// new subscriber a/+/c (qos 1) matches existing subscription a/b/c
351    /// subscriber should be moved from a/b/c to a/+/c
352    ///
353    /// * if the new subscription is wider than existing subscription, move the subscriber to wider
354    /// subscription with highest qos
355    ///
356    /// * any new wildcard subsciption checks for matching concrete subscriber
357    /// * if matches, add the subscriber to `wild_subscriptions` with greatest qos 
358    ///
359    /// * any new concrete subscriber checks for matching wildcard subscriber
360    /// * if matches, add the subscriber to `wild_subscriptions` (instead of concrete subscription) with greatest qos
361    /// 
362    /// coming to overlapping wildcard subscriptions
363    ///
364    /// * new subsciber-a a/+/c/d  mathes subscriber-a in a/# 
365    /// * add subscriber-a to a/# directly with highest qos
366    /// 
367    /// * new subscriber a/# matches with existing a/+/c/d
368    /// * remove subscriber from a/+/c/d and move it to a/# with highest qos
369    /// 
370    /// * finally a subscriber won't be part of multiple subscriptions
371    fn fix_overlapping_subscriptions(&mut self, id: &str, current_filter: &str, qos: QoS) -> Option<(String, Subscriber)> {
372        let mut subscriber = None;
373        let mut filter = current_filter.to_owned();
374        let mut qos = qos;
375
376        // subscriber in concrete_subscriptions a/b/c/d matchs new subscription a/+/c/d on same
377        // subscriber. move it from concrete to wild
378        if has_wildcards(current_filter) {
379            for (existing_filter, subscribers) in self.concrete_subscriptions.iter_mut() {
380                if matches(existing_filter, current_filter) {
381                    if let Some(index) = subscribers.iter().position(|s| s.client_id == id) {
382                        let mut s = subscribers.remove(index);
383                        if qos > s.qos {
384                            s.qos = qos;
385                        }
386                        subscriber = Some(s);
387                    }
388                }
389            }
390
391            for (existing_filter, subscribers) in self.wild_subscriptions.iter_mut() {
392                // current filter is wider than existing filter. remove subscriber (if it exists)
393                // from current filter
394                if matches(existing_filter, current_filter) {
395                    if let Some(index) = subscribers.iter().position(|s| s.client_id == id) {
396                        let s = subscribers.remove(index);
397
398                        if s.qos > qos {
399                            qos = s.qos
400                        }
401
402                        subscriber = Some(s);
403                    }
404                } else if matches(current_filter, existing_filter) {
405                    // existing filter is wider than current filter, return the subscriber with
406                    // wider subscription (to be added outside this method)
407                    filter = existing_filter.clone();
408                    // remove the subscriber and update the global qos (this subscriber will be
409                    // added again outside)
410                    if let Some(index) = subscribers.iter().position(|s| s.client_id == id) {
411                        let s = subscribers.remove(index);
412
413                        if s.qos > qos {
414                            qos = s.qos
415                        }
416
417                        subscriber = Some(s);
418                    }   
419                }
420            }
421        } 
422
423        match subscriber {
424            Some(mut subscriber) => {
425                subscriber.qos = qos;
426                Some((filter, subscriber))
427            }
428            None => None
429        }
430    }
431
432    fn remove_from_subscriptions(&mut self, id: String, unsubscribe: Unsubscribe) {
433        for topic in unsubscribe.topics.iter() {
434            if has_wildcards(topic) {
435                // remove client from the concrete subscription list incase of a matching wildcard
436                // subscription
437                for (filter, subscribers) in self.concrete_subscriptions.iter_mut() {
438                    if topic == filter {
439                        if let Some(index) = subscribers.iter().position(|s| s.client_id == id) {
440                            subscribers.remove(index);
441                        }
442                    }
443                }
444            } else {
445                // ignore a new concrete subscription if the client already has a matching wildcard
446                // subscription
447                for (filter, subscribers) in self.concrete_subscriptions.iter_mut() {
448                    if topic == filter {
449                        if let Some(index) = subscribers.iter().position(|s| s.client_id == id) {
450                            subscribers.remove(index);
451                        }
452                    }
453                }
454            };
455        }
456    }
457
458    fn deactivate(&mut self, id: String) {
459        info!("Deactivating client due to disconnect packet. Id = {}", id);
460
461        if let Some(connection) = self.active_connections.remove(&id) {
462            if !connection.state.clean_session {
463                self.inactive_connections.insert(id, InactiveConnection::new(connection.state));
464            }
465        }
466    }
467
468    fn deactivate_and_forward_will(&mut self, id: String) {
469        info!("Deactivating client due to connection death. Id = {}", id);
470
471        if let Some(mut connection) = self.active_connections.remove(&id) {
472            if let Some(mut will) = connection.state.will.take() {
473                let topic = mem::replace(&mut will.topic, "".to_owned());
474                let message = mem::replace(&mut will.message, "".to_owned());
475                let qos = will.qos;
476
477                let publish = Publish::new(topic, qos, message);
478                self.match_subscriptions(&id, publish);
479            }
480
481            if !connection.state.clean_session {
482                self.inactive_connections.insert(id.clone(), InactiveConnection::new(connection.state));
483            }
484        }
485    }
486
487    /// Saves state and sends network reply back to the connection
488    fn handle_incoming_packet(&mut self, id: &str, packet: Packet) -> Result<Option<RouterMessage>, Error> {
489        if let Some(connection) = self.active_connections.get_mut(id) {
490            let reply = match connection.state.handle_incoming_mqtt_packet(packet) {
491                Ok(Some(reply)) => reply,
492                Ok(None) => return Ok(None),
493                Err(state::Error::Unsolicited(packet)) => {
494                    // NOTE: Some clients seems to be sending pending acks after reconnection
495                    // even during a clean session. Let's be little lineant for now
496                    error!("Unsolicited ack = {:?}. Id = {}", packet, id);
497                    return Ok(None)
498                }
499                Err(e) => {
500                    error!("State error = {:?}. Id = {}", e, id);
501                    self.active_connections.remove(id);
502                    return Err::<_, Error>(e.into())
503                }
504            };
505
506            return Ok(Some(reply))
507        }
508
509        Ok(None)
510    }
511
512    // forwards data to the connection with the following id
513    fn fill_subscriber(&mut self, subscriber: &Subscriber, mut publish: Publish)  {
514        publish.qos = subscriber.qos;
515
516        if let Some(connection) = self.inactive_connections.get_mut(&subscriber.client_id) {
517            debug!("Forwarding publish to active connection. Id = {}, {:?}", subscriber.client_id, publish);
518            connection.state.handle_outgoing_publish(publish); 
519            return
520        }
521
522        if let Some(connection) = self.active_connections.get_mut(&subscriber.client_id) {
523            let packet = connection.state.handle_outgoing_publish(publish); 
524            connection.outgoing.push_back(packet);
525        }
526    }
527
528    fn forward(&mut self, id: &str, message: RouterMessage) {
529        if let Some(connection) = self.active_connections.get_mut(id) {
530            // slow connections should be moved to inactive connections. This drops tx handle of the
531            // connection leading to connection disconnection
532            if let Err(e) = connection.tx.try_send(message) {
533                match e {
534                    TrySendError::Full(_m) => {
535                        error!("Slow connection during forward. Dropping handle and moving id to inactive list. Id = {}", id);
536                        if let Some(connection) = self.active_connections.remove(id) {
537                            self.inactive_connections.insert(id.to_owned(), InactiveConnection::new(connection.state));
538                        }
539                    }
540                    TrySendError::Closed(_m) => {
541                        error!("Closed connection. Forward failed");
542                        self.active_connections.remove(id);
543                    }
544                }
545            }
546        }
547    }
548}
549
550
551
552
#[cfg(test)]
mod test {
    // Placeholders for behavior that still needs coverage. They are `#[ignore]`d so that
    // `cargo test` does not report them as passing while they assert nothing.

    #[test]
    #[ignore]
    fn persistent_disconnected_and_dead_connections_are_moved_to_inactive_state() {}

    #[test]
    #[ignore]
    fn persistent_reconnections_are_moved_from_inactive_to_active_state() {}

    #[test]
    #[ignore]
    fn offline_messages_are_given_back_to_reconnected_persistent_connection() {}

    #[test]
    #[ignore]
    fn remove_client_from_concrete_subscriptions_if_new_wildcard_subscription_matches_existing_concrete_subscription() {
        // client subscribing to a/b/c and a/+/c should receive message only once when
        // a publish happens on a/b/c
    }

    #[test]
    #[ignore]
    fn ignore_new_concrete_subscription_if_a_matching_wildcard_subscription_exists_for_the_client() {}

    #[test]
    #[ignore]
    fn router_should_remove_the_connection_during_disconnect() {}

    #[test]
    #[ignore]
    fn router_should_not_add_same_client_to_subscription_list() {}

    #[test]
    #[ignore]
    fn router_saves_offline_messages_of_a_persistent_dead_connection() {}
}
577    fn router_should_not_add_same_client_to_subscription_list() {}
578
579    #[test]
580    fn router_saves_offline_messages_of_a_persistent_dead_connection() {}
581}