1use std::collections::BTreeMap;
2use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
3
4use crate::conversation::ParticipantPid;
5use crate::routing::{ConsumerId, ConsumerStateView, RoutingFunction};
6
7#[derive(Clone, Debug, PartialEq, Eq)]
9pub struct ConsumerRegistration {
10 consumer: ConsumerId,
11 participant: ParticipantPid,
12 state: ConsumerStateView,
13}
14
15impl ConsumerRegistration {
16 #[must_use]
18 pub fn new(participant: ParticipantPid, state: ConsumerStateView) -> Self {
19 Self {
20 consumer: state.consumer.clone(),
21 participant,
22 state,
23 }
24 }
25
26 #[must_use]
28 pub fn with_default_state(consumer: ConsumerId, participant: ParticipantPid) -> Self {
29 let state = ConsumerStateView::new(consumer, 0, 1, 0, Vec::new());
30 Self::new(participant, state)
31 }
32
33 #[must_use]
35 pub const fn consumer(&self) -> &ConsumerId {
36 &self.consumer
37 }
38
39 #[must_use]
41 pub const fn participant(&self) -> ParticipantPid {
42 self.participant
43 }
44
45 #[must_use]
47 pub const fn state(&self) -> &ConsumerStateView {
48 &self.state
49 }
50}
51
52#[derive(Clone, Debug)]
54pub struct ConsumerGroupSnapshot {
55 routing_function: RoutingFunction,
56 consumers: Arc<[ConsumerRegistration]>,
57}
58
59impl ConsumerGroupSnapshot {
60 #[must_use]
62 pub const fn routing_function(&self) -> &RoutingFunction {
63 &self.routing_function
64 }
65
66 #[must_use]
68 pub fn consumers(&self) -> &[ConsumerRegistration] {
69 &self.consumers
70 }
71
72 #[must_use]
74 pub fn consumer_ids(&self) -> Vec<ConsumerId> {
75 self.consumers
76 .iter()
77 .map(|registration| registration.consumer.clone())
78 .collect()
79 }
80}
81
82#[derive(Clone, Debug)]
84pub struct ConsumerGroup {
85 inner: Arc<GroupInner>,
86}
87
88impl ConsumerGroup {
89 #[must_use]
91 pub fn new(routing_function: RoutingFunction) -> Self {
92 Self {
93 inner: Arc::new(GroupInner {
94 routing_function,
95 state: RwLock::new(GroupState::default()),
96 }),
97 }
98 }
99
100 #[must_use]
102 pub fn routing_function(&self) -> RoutingFunction {
103 self.inner.routing_function.clone()
104 }
105
106 #[must_use]
108 pub fn consumers(&self) -> Vec<ConsumerId> {
109 read_group_state(&self.inner.state)
110 .consumers
111 .keys()
112 .cloned()
113 .collect()
114 }
115
116 #[must_use]
118 pub fn snapshot(&self) -> ConsumerGroupSnapshot {
119 let consumers = read_group_state(&self.inner.state)
120 .consumers
121 .values()
122 .cloned()
123 .collect::<Vec<_>>();
124 ConsumerGroupSnapshot {
125 routing_function: self.routing_function(),
126 consumers: Arc::from(consumers.into_boxed_slice()),
127 }
128 }
129
130 #[must_use = "the boolean reports whether the consumer was newly inserted"]
132 pub fn add_consumer(&self, registration: ConsumerRegistration) -> bool {
133 write_group_state(&self.inner.state)
134 .consumers
135 .insert(registration.consumer.clone(), registration)
136 .is_none()
137 }
138
139 #[must_use = "the boolean reports whether a consumer was actually removed"]
141 pub fn remove_consumer(&self, consumer: &ConsumerId) -> bool {
142 write_group_state(&self.inner.state)
143 .consumers
144 .remove(consumer)
145 .is_some()
146 }
147}
148
149#[derive(Debug)]
150struct GroupInner {
151 routing_function: RoutingFunction,
152 state: RwLock<GroupState>,
153}
154
155#[derive(Debug, Default)]
156struct GroupState {
157 consumers: BTreeMap<ConsumerId, ConsumerRegistration>,
158}
159
160fn read_group_state(lock: &RwLock<GroupState>) -> RwLockReadGuard<'_, GroupState> {
161 match lock.read() {
162 Ok(guard) => guard,
163 Err(poisoned) => poisoned.into_inner(),
164 }
165}
166
167fn write_group_state(lock: &RwLock<GroupState>) -> RwLockWriteGuard<'_, GroupState> {
168 match lock.write() {
169 Ok(guard) => guard,
170 Err(poisoned) => poisoned.into_inner(),
171 }
172}
173
#[cfg(test)]
mod tests {
    use super::{ConsumerGroup, ConsumerRegistration};
    use crate::conversation::ParticipantPid;
    use crate::routing::function::loader::{ModuleLoader, RoutingModule};
    use crate::routing::{ConsumerId, ConsumerStateView, RoutingDecision};

    /// Loads a routing function that selects the first consumer it is given,
    /// or yields `RoutingDecision::none` when the list is empty.
    fn function() -> crate::routing::RoutingFunction {
        ModuleLoader::new().load(RoutingModule::new(b"group-test", |_msg, consumers| {
            consumers
                .first()
                .map_or_else(RoutingDecision::none, |first| {
                    RoutingDecision::select(first.consumer.clone())
                })
        }))
    }

    /// Builds a registration for consumer `id` under participant `pid`.
    fn registration(id: &str, pid: u64) -> ConsumerRegistration {
        let participant = ParticipantPid::new(pid);
        let state = ConsumerStateView::new(ConsumerId::new(id), 0, 1, 0, Vec::new());
        ConsumerRegistration::new(participant, state)
    }

    #[test]
    fn new_group_has_routing_function_and_empty_consumer_set() {
        let routing_function = function();
        let group = ConsumerGroup::new(routing_function.clone());

        // The group must hand back the routing function it was created with.
        let from_group = group.routing_function();
        assert_eq!(from_group.content_hash(), routing_function.content_hash());

        let registered = group.consumers();
        assert!(registered.is_empty());

        let rendered = format!("{group:?}");
        assert!(rendered.contains("ConsumerGroup"));
    }

    #[test]
    fn consumer_set_is_ordered_and_deduplicated() {
        let group = ConsumerGroup::new(function());

        // (id, pid, whether the insert is expected to report a new consumer)
        let insertions = [
            ("B", 2, true),
            ("A", 1, true),
            ("B", 22, false),
            ("C", 3, true),
        ];
        for (id, pid, newly_inserted) in insertions {
            assert_eq!(group.add_consumer(registration(id, pid)), newly_inserted);
        }

        assert_eq!(
            ids(group.consumers()),
            vec![String::from("A"), String::from("B"), String::from("C")]
        );

        // The second insert for "B" replaces the first registration.
        let snapshot = group.snapshot();
        let second = &snapshot.consumers()[1];
        assert_eq!(second.participant(), ParticipantPid::new(22));
    }

    #[test]
    fn remove_consumer_affects_future_snapshots_only() {
        let group = ConsumerGroup::new(function());
        for (id, pid) in [("A", 1), ("B", 2), ("C", 3)] {
            let _ = group.add_consumer(registration(id, pid));
        }
        let before = group.snapshot();

        let target = ConsumerId::new("B");
        assert!(group.remove_consumer(&target));
        assert!(!group.remove_consumer(&target));

        assert_eq!(
            ids(group.consumers()),
            vec![String::from("A"), String::from("C")]
        );
        // The snapshot taken before the removal still lists all three ids.
        assert_eq!(ids(before.consumer_ids()), vec!["A", "B", "C"]);
    }

    /// Converts consumer ids to owned strings for comparison in assertions.
    fn ids(consumers: Vec<ConsumerId>) -> Vec<String> {
        let mut names = Vec::with_capacity(consumers.len());
        for consumer in consumers {
            names.push(consumer.as_str().to_owned());
        }
        names
    }
}