rumqttd/router/
shared_subs.rs

1use rand::Rng;
2use serde::{Deserialize, Serialize};
3
4pub struct SharedGroup {
5    // using Vec over HashSet for maintaining order of iter
6    clients: Vec<String>,
7    // Index into clients, allows us to skip doing iter everytime
8    current_client_index: usize,
9    pub cursor: (u64, u64),
10    pub strategy: Strategy,
11}
12
13impl SharedGroup {
14    pub fn new(cursor: (u64, u64), strategy: Strategy) -> Self {
15        SharedGroup {
16            clients: vec![],
17            current_client_index: 0,
18            cursor,
19            strategy,
20        }
21    }
22
23    pub fn is_empty(&self) -> bool {
24        self.clients.is_empty()
25    }
26
27    pub fn current_client(&self) -> Option<&String> {
28        self.clients.get(self.current_client_index)
29    }
30
31    pub fn add_client(&mut self, client: String) {
32        self.clients.push(client)
33    }
34
35    pub fn remove_client(&mut self, client: &String) {
36        // remove client from vec
37        self.clients.retain(|c| c != client);
38
39        // if there are no clients left, we have to avoid % by 0
40        if !self.clients.is_empty() {
41            // Make sure that we are within bounds and that next client is the correct client.
42            self.current_client_index %= self.clients.len();
43        }
44    }
45
46    pub fn update_next_client(&mut self) {
47        match self.strategy {
48            Strategy::RoundRobin => {
49                self.current_client_index = (self.current_client_index + 1) % self.clients.len();
50            }
51            Strategy::Random => {
52                self.current_client_index = rand::thread_rng().gen_range(0..self.clients.len());
53            }
54            Strategy::Sticky => {}
55        }
56    }
57}
58
59#[derive(Serialize, Deserialize, Debug, Clone, Default, PartialEq)]
60#[serde(rename_all = "lowercase")]
61pub enum Strategy {
62    #[default]
63    RoundRobin,
64    Random,
65    Sticky,
66}
67
#[cfg(test)]
mod tests {
    use super::{SharedGroup, Strategy};

    /// Round-robin group over clients `A`, `B`, `C` with `A` selected.
    fn abc_round_robin() -> SharedGroup {
        SharedGroup {
            clients: ["A", "B", "C"].iter().map(|name| name.to_string()).collect(),
            current_client_index: 0,
            cursor: (0, 0),
            strategy: Strategy::RoundRobin,
        }
    }

    /// Advances the group once per entry in `expected` and checks the
    /// selected index after every step.
    fn assert_steps(group: &mut SharedGroup, expected: &[usize]) {
        for &index in expected {
            group.update_next_client();
            assert_eq!(group.current_client_index, index);
        }
    }

    #[test]
    fn performs_round_robin() {
        let mut group = abc_round_robin();

        // A full lap over three clients ends where it started.
        assert_steps(&mut group, &[1, 2, 0]);

        // Adding a client must not move the selection.
        group.add_client(String::from("D"));
        assert_eq!(group.current_client_index, 0);
    }

    #[test]
    fn handles_round_robin_when_start_removed() {
        // [ A, B, C ] with A (index 0) selected.
        // Removing A leaves [ B, C ]; B, the next client, must be selected.
        let mut group = abc_round_robin();

        group.remove_client(&String::from("A"));
        assert_eq!(group.current_client_index, 0);

        assert_steps(&mut group, &[1, 0]);
    }

    #[test]
    fn handles_round_robin_when_last_removed() {
        // [ A, B, C ] advanced until C (index 2) is selected.
        // Removing C leaves [ A, B ]; the selection must wrap to A.
        let mut group = abc_round_robin();

        assert_steps(&mut group, &[1, 2]);

        group.remove_client(&String::from("C"));
        assert_eq!(group.current_client_index, 0);
    }
}