atm0s_sdn_pub_sub/relay/source_binding.rs

1use std::{
2    collections::{hash_map::Entry, HashMap, VecDeque},
3    sync::Arc,
4};
5
6use atm0s_sdn_identity::NodeId;
7use atm0s_sdn_utils::awaker::Awaker;
8
9use super::{ChannelUuid, LocalSubId};
10
11struct ChannelContainer {
12    sources: Vec<NodeId>,
13    subs: Vec<LocalSubId>,
14}
15
16#[derive(Debug, PartialEq, Eq)]
17pub enum SourceBindingAction {
18    Subscribe(ChannelUuid),
19    Unsubscribe(ChannelUuid),
20}
21
22pub struct SourceBinding {
23    channels: HashMap<ChannelUuid, ChannelContainer>,
24    actions: VecDeque<SourceBindingAction>,
25    awaker: Arc<dyn Awaker>,
26}
27
28impl SourceBinding {
29    pub fn new() -> Self {
30        Self {
31            channels: HashMap::new(),
32            actions: VecDeque::new(),
33            awaker: Arc::new(atm0s_sdn_utils::awaker::MockAwaker::default()),
34        }
35    }
36
37    pub fn set_awaker(&mut self, awaker: Arc<dyn Awaker>) {
38        self.awaker = awaker;
39    }
40
41    pub fn on_local_sub(&mut self, channel: ChannelUuid, sub: LocalSubId) -> Option<Vec<NodeId>> {
42        match self.channels.entry(channel) {
43            Entry::Occupied(mut entry) => {
44                // only push to subs when not exist
45                if !entry.get().subs.contains(&sub) {
46                    if entry.get().subs.is_empty() {
47                        self.actions.push_back(SourceBindingAction::Subscribe(channel));
48                        self.awaker.notify();
49                    }
50                    entry.get_mut().subs.push(sub);
51                    if entry.get().sources.is_empty() {
52                        None
53                    } else {
54                        Some(entry.get().sources.clone())
55                    }
56                } else {
57                    None
58                }
59            }
60            Entry::Vacant(entry) => {
61                entry.insert(ChannelContainer { sources: vec![], subs: vec![sub] });
62                self.actions.push_back(SourceBindingAction::Subscribe(channel));
63                self.awaker.notify();
64                None
65            }
66        }
67    }
68
69    pub fn on_local_unsub(&mut self, channel: ChannelUuid, sub: LocalSubId) -> Option<Vec<NodeId>> {
70        let container = self.channels.get_mut(&channel)?;
71        let index = container.subs.iter().position(|x| *x == sub)?;
72        container.subs.remove(index);
73
74        if container.subs.is_empty() {
75            self.actions.push_back(SourceBindingAction::Unsubscribe(channel));
76            self.awaker.notify();
77        }
78
79        if container.subs.is_empty() && container.sources.is_empty() {
80            self.channels.remove(&channel);
81            None
82        } else {
83            if container.sources.is_empty() {
84                None
85            } else {
86                Some(container.sources.clone())
87            }
88        }
89    }
90
91    pub fn on_source_added(&mut self, channel: ChannelUuid, source: NodeId) -> Option<Vec<LocalSubId>> {
92        match self.channels.entry(channel) {
93            Entry::Occupied(mut entry) => {
94                // only push to sources when not exist
95                if !entry.get().sources.contains(&source) {
96                    entry.get_mut().sources.push(source);
97                    if entry.get().subs.is_empty() {
98                        None
99                    } else {
100                        Some(entry.get().subs.clone())
101                    }
102                } else {
103                    None
104                }
105            }
106            Entry::Vacant(entry) => {
107                entry.insert(ChannelContainer { sources: vec![source], subs: vec![] });
108                None
109            }
110        }
111    }
112
113    pub fn on_source_removed(&mut self, channel: ChannelUuid, source: NodeId) -> Option<Vec<LocalSubId>> {
114        let container = self.channels.get_mut(&channel)?;
115        let index = container.sources.iter().position(|x| *x == source)?;
116        container.sources.remove(index);
117
118        if container.subs.is_empty() && container.sources.is_empty() {
119            self.channels.remove(&channel);
120            None
121        } else {
122            if container.subs.is_empty() {
123                None
124            } else {
125                Some(container.subs.clone())
126            }
127        }
128    }
129
130    pub fn sources_for(&self, channel: ChannelUuid) -> Vec<NodeId> {
131        self.channels.get(&channel).map(|x| x.sources.clone()).unwrap_or_default()
132    }
133
134    pub fn consumers_for(&self, channel: ChannelUuid) -> Vec<LocalSubId> {
135        self.channels.get(&channel).map(|x| x.subs.clone()).unwrap_or_default()
136    }
137
138    pub fn pop_action(&mut self) -> Option<SourceBindingAction> {
139        self.actions.pop_front()
140    }
141}
142
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use atm0s_sdn_utils::awaker::Awaker;

    use crate::relay::source_binding::SourceBindingAction;

    use super::SourceBinding;

    #[test]
    fn source_for_should_correct() {
        let mut binding = SourceBinding::new();
        assert_eq!(binding.sources_for(1), vec![]);

        binding.on_source_added(1, 1000);
        binding.on_source_added(1, 1001);

        assert_eq!(binding.sources_for(1), vec![1000, 1001]);
    }

    #[test]
    fn local_sub_unsub_should_correct() {
        let awake = Arc::new(atm0s_sdn_utils::awaker::MockAwaker::default());
        let mut binding = SourceBinding::new();
        binding.set_awaker(awake.clone());

        assert_eq!(binding.on_source_added(1, 1000), None);
        assert_eq!(binding.on_source_added(1, 1001), None);

        assert_eq!(binding.on_local_unsub(1, 10), None);
        assert_eq!(binding.on_local_unsub(1, 11), None);

        assert_eq!(binding.on_local_sub(1, 10), Some(vec![1000, 1001]));
        assert_eq!(binding.pop_action(), Some(SourceBindingAction::Subscribe(1)));
        assert_eq!(binding.pop_action(), None);
        assert_eq!(awake.pop_awake_count(), 1);

        assert_eq!(binding.on_local_sub(1, 10), None); // already sub
        assert_eq!(binding.on_local_unsub(1, 11), None);

        assert_eq!(binding.on_local_sub(1, 11), Some(vec![1000, 1001]));
        assert_eq!(binding.pop_action(), None);
        assert_eq!(awake.pop_awake_count(), 0);

        assert_eq!(binding.on_local_unsub(1, 10), Some(vec![1000, 1001]));
        assert_eq!(binding.pop_action(), None);
        assert_eq!(awake.pop_awake_count(), 0);

        assert_eq!(binding.on_local_unsub(1, 11), Some(vec![1000, 1001]));
        assert_eq!(binding.pop_action(), Some(SourceBindingAction::Unsubscribe(1)));
        assert_eq!(binding.pop_action(), None);
        assert_eq!(awake.pop_awake_count(), 1);

        assert_eq!(binding.on_local_unsub(1, 10), None);
        assert_eq!(binding.on_local_unsub(1, 11), None);
    }

    #[test]
    fn source_add_remove_should_correct() {
        let mut binding = SourceBinding::new();

        assert_eq!(binding.on_local_sub(1, 10), None);
        assert_eq!(binding.on_local_sub(1, 11), None);

        assert_eq!(binding.on_source_removed(1, 1000), None);
        assert_eq!(binding.on_source_removed(1, 1001), None);

        assert_eq!(binding.on_source_added(1, 1000), Some(vec![10, 11]));
        assert_eq!(binding.on_source_added(1, 1001), Some(vec![10, 11]));

        assert_eq!(binding.on_source_added(1, 1000), None); // already added
        assert_eq!(binding.on_source_added(1, 1001), None); // already added

        assert_eq!(binding.on_source_removed(1, 1000), Some(vec![10, 11]));
        assert_eq!(binding.on_source_removed(1, 1001), Some(vec![10, 11]));

        assert_eq!(binding.on_source_removed(1, 1000), None); // already removed
        assert_eq!(binding.on_source_removed(1, 1001), None); // already removed
    }

    /// A channel with only local subscribers is created on the first sub, dropped
    /// on the last unsub, and subscribed again from scratch afterwards.
    #[test]
    fn local_only_channel_lifecycle_should_correct() {
        let awake = Arc::new(atm0s_sdn_utils::awaker::MockAwaker::default());
        let mut binding = SourceBinding::new();
        binding.set_awaker(awake.clone());

        // First sub on an unknown channel creates it and requests a subscribe.
        assert_eq!(binding.on_local_sub(2, 20), None);
        assert_eq!(binding.pop_action(), Some(SourceBindingAction::Subscribe(2)));
        assert_eq!(binding.pop_action(), None);
        assert_eq!(awake.pop_awake_count(), 1);
        assert_eq!(binding.consumers_for(2), vec![20]);

        // Last unsub with no sources requests an unsubscribe and forgets the channel.
        assert_eq!(binding.on_local_unsub(2, 20), None);
        assert_eq!(binding.pop_action(), Some(SourceBindingAction::Unsubscribe(2)));
        assert_eq!(binding.pop_action(), None);
        assert_eq!(awake.pop_awake_count(), 1);
        assert!(binding.consumers_for(2).is_empty());
        assert!(binding.sources_for(2).is_empty());

        // Subscribing again after teardown starts over with a fresh subscribe.
        assert_eq!(binding.on_local_sub(2, 21), None);
        assert_eq!(binding.pop_action(), Some(SourceBindingAction::Subscribe(2)));
        assert_eq!(binding.pop_action(), None);
        assert_eq!(awake.pop_awake_count(), 1);
    }
}