1use itertools::Itertools;
2use std::error::Error;
3use std::marker::PhantomData;
4use std::{
5 collections::{BTreeSet, HashMap},
6 hash::Hash,
7};
8use wildmatch::WildMatch;
9
/// Marker trait for types that can identify a client.
///
/// Identifiers are used as `HashMap` keys and stored in `BTreeSet`s, so they
/// must be both hashable and totally ordered. Any `Ord + Hash` type gets this
/// trait automatically through the blanket impl that follows.
pub trait UniqueIdentifier: Ord + Eq + Hash {}

impl<Id> UniqueIdentifier for Id where Id: Ord + Hash {}
17
/// A published message as handed to a [`Client`].
///
/// Common traits are derived so messages can be inspected, copied and compared;
/// each derive only applies when `TMessage` itself implements that trait.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Message<'a, TMessage> {
    /// The published payload.
    pub contents: TMessage,
    /// Name of the channel the message was published on. This is the literal
    /// published channel, even when the message reached the client through a
    /// wildcard pattern subscription.
    pub source: &'a str,
}
22
23pub trait Client<TIdentifier: UniqueIdentifier, TMessage> {
106 fn get_id(&self) -> TIdentifier;
108
109 fn send(&mut self, message: &Message<TMessage>);
111}
112
/// Errors reported by [`PubSub`] operations.
///
/// The enum is field-less, so it is cheap to copy and can be compared directly
/// (e.g. `assert_eq!(result, Err(PubSubError::ClientNotSubscribedError))`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PubSubError {
    /// The client is already subscribed to the requested channel or pattern.
    ClientAlreadySubscribedError,
    /// The client is not subscribed to the requested channel or pattern.
    ClientNotSubscribedError,
    /// No subscription entry exists for the requested channel or pattern.
    ChannelDoesNotExistError,
    /// A client with the same identifier already exists.
    /// NOTE(review): not constructed by any `PubSub` method in this module.
    ClientWithIdentifierAlreadyExistsError,
    /// No client with the given identifier exists.
    /// NOTE(review): not constructed by any `PubSub` method in this module.
    ClientDoesNotExistError,
}

impl Error for PubSubError {}

impl std::fmt::Display for PubSubError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Every variant maps to a fixed sentence; write it verbatim.
        let text = match self {
            Self::ClientAlreadySubscribedError => "Client already subscribed to channel.",
            Self::ClientNotSubscribedError => "Client is not subscribed to channel.",
            Self::ChannelDoesNotExistError => "Channel does not exist.",
            Self::ClientDoesNotExistError => "Client does not exist.",
            Self::ClientWithIdentifierAlreadyExistsError => {
                "Client with that identifier already exists."
            }
        };
        f.write_str(text)
    }
}
139
/// A publish/subscribe hub that routes messages published on named channels to
/// registered clients.
///
/// Channel names are borrowed for `'a`. Subscriptions whose name contains `*`
/// or `?` are kept in a separate map and matched as wildcard patterns when a
/// message is published.
///
/// The struct itself places no trait bounds on its parameters; the `impl`
/// blocks that need `Client` / `UniqueIdentifier` declare them. `Clone` and
/// `Debug` are available whenever all type parameters implement them.
#[derive(Clone, Debug)]
pub struct PubSub<'a, TClient, TIdentifier, TMessage> {
    /// Registered clients, keyed by their identifier.
    clients: HashMap<TIdentifier, TClient>,
    /// Literal channel name -> identifiers of the clients subscribed to it.
    channels: HashMap<&'a str, BTreeSet<TIdentifier>>,
    /// Wildcard channel pattern -> identifiers of the clients subscribed to it.
    pattern_channels: HashMap<&'a str, BTreeSet<TIdentifier>>,
    /// Ties the message type to the hub without storing a value of it.
    phantom: PhantomData<TMessage>,
}
153
/// Reports whether `channel` contains a wildcard character (`*` or `?`).
///
/// Such names are stored as pattern subscriptions rather than literal channels.
fn channel_is_pattern(channel: &str) -> bool {
    channel.chars().any(|ch| matches!(ch, '*' | '?'))
}
157
158impl<
167 'a,
168 TClient: Client<TIdentifier, TMessage>,
169 TIdentifier: UniqueIdentifier,
170 TMessage: Clone + Copy,
171 > PubSub<'a, TClient, TIdentifier, TMessage>
172{
173 pub fn new() -> PubSub<'a, TClient, TIdentifier, TMessage> {
178 PubSub {
179 clients: HashMap::new(),
180 channels: HashMap::new(),
181 pattern_channels: HashMap::new(),
182 phantom: PhantomData,
183 }
184 }
185
186 pub fn add_client(&mut self, client: TClient) {
188 let token = client.get_id();
189 self.clients.insert(token, client);
190 }
191
192 pub fn remove_client(&mut self, client: TClient) {
194 let identifier = &client.get_id();
195 self.clients.remove(identifier);
196
197 for subbed_clients in self.channels.values_mut() {
198 subbed_clients.remove(identifier);
199 }
200
201 for subbed_clients in self.pattern_channels.values_mut() {
202 subbed_clients.remove(identifier);
203 }
204 }
205
206 fn get_channels_for_subscription(
207 &mut self,
208 channel: &'a str,
209 ) -> &mut HashMap<&'a str, BTreeSet<TIdentifier>> {
210 match channel_is_pattern(channel) {
211 true => &mut self.pattern_channels,
212 false => &mut self.channels,
213 }
214 }
215
216 pub fn sub_client(&mut self, client: TClient, channel: &'a str) -> Result<(), PubSubError> {
221 let target_channels = self.get_channels_for_subscription(channel);
222
223 let subbed_clients = target_channels.entry(channel).or_insert_with(BTreeSet::new);
224
225 let result = subbed_clients.insert(client.get_id());
226
227 if result {
228 Ok(())
229 } else {
230 Err(PubSubError::ClientAlreadySubscribedError)
231 }
232 }
233
234 pub fn unsub_client(&mut self, client: TClient, channel: &'a str) -> Result<(), PubSubError> {
239 let target_channels = self.get_channels_for_subscription(channel);
240
241 if let Some(subbed_clients) = target_channels.get_mut(channel) {
242 match subbed_clients.remove(&client.get_id()) {
243 true => Ok(()),
244 false => Err(PubSubError::ClientNotSubscribedError),
245 }
246 } else {
247 Err(PubSubError::ChannelDoesNotExistError)
248 }
249 }
250
251 pub fn pub_message<TInputMessage: Into<TMessage>>(
253 &mut self,
254 channel: &str,
255 msg: TInputMessage,
256 ) {
257 let msg_ref = msg.into();
258
259 let message = Message {
260 contents: msg_ref,
261 source: channel,
262 };
263
264 let pattern_client_identifiers = self
265 .pattern_channels
266 .iter()
267 .filter(|(pattern, _)| WildMatch::new(pattern) == channel)
268 .map(|(_, clients)| clients.iter())
269 .flatten();
270
271 let subbed_clients = self.channels.get_mut(channel);
272 let subbed_client_identifiers = subbed_clients.iter().map(|client| client.iter()).flatten();
273
274 let unique_client_identifiers = subbed_client_identifiers
275 .chain(pattern_client_identifiers)
276 .unique();
277
278 for identifier in unique_client_identifiers {
279 if let Some(client) = self.clients.get_mut(identifier) {
280 client.send(&message);
281 }
282 }
283 }
284}
285
286impl<
287 'a,
288 TClient: Client<TIdentifier, TMessage>,
289 TIdentifier: UniqueIdentifier,
290 TMessage: Clone + Copy,
291 > Default for PubSub<'a, TClient, TIdentifier, TMessage>
292{
293 fn default() -> Self {
294 Self::new()
295 }
296}