general_pub_sub/
lib.rs

1use itertools::Itertools;
2use std::error::Error;
3use std::marker::PhantomData;
4use std::{
5    collections::{BTreeSet, HashMap},
6    hash::Hash,
7};
8use wildmatch::WildMatch;
9
/// A unique identifier for a `Client`.
///
/// The `PubSub` stores its clients keyed by identifier and keeps ordered sets
/// of identifiers per channel, so an identifier only has to be orderable and
/// hashable: it must implement (or derive) `core::cmp::Ord` and
/// `std::hash::Hash`. Uniqueness is a property of how the `PubSub` stores
/// clients, not something this trait can check.
pub trait UniqueIdentifier: Ord + Eq + Hash {}

/// Blanket implementation: any type that is `Ord + Hash` is a
/// `UniqueIdentifier` automatically, so identifier types never opt in by hand.
impl<T> UniqueIdentifier for T where T: Ord + Hash {}
17
/// A message as it is handed to a `Client`.
///
/// The common traits are derived so a message can be debug-printed, copied and
/// compared whenever its payload type supports the same operation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Message<'a, TMessage> {
    /// The payload that was published.
    pub contents: TMessage,
    /// The name of the channel the payload was published to.
    pub source: &'a str,
}
22
23/// A PubSub Client
24///
25/// Trait describing a generic PubSub Client.
26///
27/// The identifier can be any data type so long as it conforms to
28/// the `UniqueIdentifier` trait.
29///
30/// Message can also be of any type.
31///
32/// # Examples
33///
34/// Basic Usage:
35///
36/// ```
37/// struct BasicClient {
38///   id: u32   
39/// }
40///
41/// impl Client<u32, &str> for BasicClient {
42///   fn get_id(&self) -> u32 {
43///      return self.id;
44///   }
45///
46///   fn send(&self, message: &str) {
47///       println!("Client ({}) Received: {}", self.id, message);
48///   }
49/// }
50/// ```
51///
52/// Multi-client Example:
53///
54/// ```
55/// struct ConsoleClient {
56///   id: u32
57/// }
58///
59/// impl Client<u32, &str> for ConsoleClient {
60///   fn get_id(&self) -> u32 {
61///      return self.id;
62///   }
63///
64///   fn send(&self, message: &str) {
65///       println!("Client ({}) Received: {}", self.id, message);
66///   }
67/// }
68///
69/// struct TcpClient {
70///   id: &str,
71///   stream: std::net::TcpStream
72/// }
73///
74/// impl Client<&str, &str> for TcpClient {
75///   fn get_id(&self) -> &str {
76///     return self.id;
77///   }
78///
79///   fn send(&self, message: &str) {
80///     self.stream.write(format!("Client ({}) Received: {}", self.id, message).as_bytes())
81///   }
82/// }
83///
84/// enum Clients {
85///   Console(ConsoleClient),
86///   Tcp(TcpClient)
87/// }
88///
89/// impl Client<&str, &str> for Clients {
90///   fn get_id(&self) -> &str {
91///     match self {
92///       Self::Console(client) => client.get_id().to_string(),
93///       Self::Tcp(client) => client.get_id()
94///     }
95///   }
96///
97///   fn send(&self, message: &str) {
98///     match self {
99///       Self::Console(client) => client.send(message),
100///       Self::Console(client) => client.send(message)
101///     }
102///   }
103/// }
104/// ```
105pub trait Client<TIdentifier: UniqueIdentifier, TMessage> {
106    /// Gets the `ID` of the `Client`. Must be unique.
107    fn get_id(&self) -> TIdentifier;
108
109    /// Sends a `Message` to a `Client`.
110    fn send(&mut self, message: &Message<TMessage>);
111}
112
/// PubSubError is used for errors specific to `PubSub` (such as adding or removing `Client`s)
///
/// `Clone`, `PartialEq` and `Eq` are derived so callers can compare a returned
/// error against an expected variant (for example with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PubSubError {
    /// The `Client` is already subscribed to the `Channel`.
    ClientAlreadySubscribedError,
    /// The `Client` is not subscribed to the `Channel`.
    ClientNotSubscribedError,
    /// The `Channel` does not exist.
    ChannelDoesNotExistError,
    /// A `Client` with the same identifier already exists.
    ClientWithIdentifierAlreadyExistsError,
    /// The `Client` does not exist.
    ClientDoesNotExistError,
}

impl Error for PubSubError {}

impl std::fmt::Display for PubSubError {
    /// Writes the human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive on purpose: a new variant must be given a description.
        let description = match self {
            Self::ClientAlreadySubscribedError => "Client already subscribed to channel.",
            Self::ClientNotSubscribedError => "Client is not subscribed to channel.",
            Self::ChannelDoesNotExistError => "Channel does not exist.",
            Self::ClientDoesNotExistError => "Client does not exist.",
            Self::ClientWithIdentifierAlreadyExistsError => {
                "Client with that identifier already exists."
            }
        };
        f.write_str(description)
    }
}
139
140/// A PubSub
141#[derive(Clone)]
142pub struct PubSub<
143    'a,
144    TClient: Client<TIdentifier, TMessage>,
145    TIdentifier: UniqueIdentifier,
146    TMessage,
147> {
148    clients: HashMap<TIdentifier, TClient>,
149    channels: HashMap<&'a str, BTreeSet<TIdentifier>>,
150    pattern_channels: HashMap<&'a str, BTreeSet<TIdentifier>>,
151    phantom: PhantomData<TMessage>,
152}
153
/// Returns `true` when the channel name contains a wildcard character
/// (`*` or `?`), i.e. when it should be treated as a pattern rather than as an
/// exact channel name.
fn channel_is_pattern(channel: &str) -> bool {
    channel
        .chars()
        .any(|character| matches!(character, '*' | '?'))
}
157
158/// Implementation for a `PubSub`
159///
160/// The standard workflow for a `PubSub` is to:
161///
162/// 1. Create a new `PubSub`.
163/// 2. Add one or more `Clients`.
164/// 3. Subscribe the `Clients` to `Channels` of interest.
165/// 4. Publish `Messages` to the `Channels`. The `Message` is broadcast to all `Clients` subscribed to the `Channel`.
166impl<
167        'a,
168        TClient: Client<TIdentifier, TMessage>,
169        TIdentifier: UniqueIdentifier,
170        TMessage: Clone + Copy,
171    > PubSub<'a, TClient, TIdentifier, TMessage>
172{
173    /// Creates a new `PubSub`
174    ///
175    /// All `Clients` of the `PubSub` must use the same type of `Identifier`
176    /// and receive the same type of `Message`.
177    pub fn new() -> PubSub<'a, TClient, TIdentifier, TMessage> {
178        PubSub {
179            clients: HashMap::new(),
180            channels: HashMap::new(),
181            pattern_channels: HashMap::new(),
182            phantom: PhantomData,
183        }
184    }
185
186    /// Adds a `Client` to the `PubSub`
187    pub fn add_client(&mut self, client: TClient) {
188        let token = client.get_id();
189        self.clients.insert(token, client);
190    }
191
192    // Unsubscribes a `Client` from all `Channels` and removes the `Client` from the `PubSub`.
193    pub fn remove_client(&mut self, client: TClient) {
194        let identifier = &client.get_id();
195        self.clients.remove(identifier);
196
197        for subbed_clients in self.channels.values_mut() {
198            subbed_clients.remove(identifier);
199        }
200
201        for subbed_clients in self.pattern_channels.values_mut() {
202            subbed_clients.remove(identifier);
203        }
204    }
205
206    fn get_channels_for_subscription(
207        &mut self,
208        channel: &'a str,
209    ) -> &mut HashMap<&'a str, BTreeSet<TIdentifier>> {
210        match channel_is_pattern(channel) {
211            true => &mut self.pattern_channels,
212            false => &mut self.channels,
213        }
214    }
215
216    /// Subscribes a `Client` to a `Channel`.
217    ///
218    /// Results in a `PubSubError` when a `Client` attempts to subscribe to a
219    /// `Channel` that it is already subscribed to.
220    pub fn sub_client(&mut self, client: TClient, channel: &'a str) -> Result<(), PubSubError> {
221        let target_channels = self.get_channels_for_subscription(channel);
222
223        let subbed_clients = target_channels.entry(channel).or_insert_with(BTreeSet::new);
224
225        let result = subbed_clients.insert(client.get_id());
226
227        if result {
228            Ok(())
229        } else {
230            Err(PubSubError::ClientAlreadySubscribedError)
231        }
232    }
233
234    /// Unsubscribes a `Client` from a `Channel`
235    ///
236    /// Results in a `PubSubError` when a `Client` attempts to unsubscribe
237    /// from a `Channel` it is not subscribed to.
238    pub fn unsub_client(&mut self, client: TClient, channel: &'a str) -> Result<(), PubSubError> {
239        let target_channels = self.get_channels_for_subscription(channel);
240
241        if let Some(subbed_clients) = target_channels.get_mut(channel) {
242            match subbed_clients.remove(&client.get_id()) {
243                true => Ok(()),
244                false => Err(PubSubError::ClientNotSubscribedError),
245            }
246        } else {
247            Err(PubSubError::ChannelDoesNotExistError)
248        }
249    }
250
251    /// Publishes a `Message` to all `Clients` subscribed to the provided `Channel`.
252    pub fn pub_message<TInputMessage: Into<TMessage>>(
253        &mut self,
254        channel: &str,
255        msg: TInputMessage,
256    ) {
257        let msg_ref = msg.into();
258
259        let message = Message {
260            contents: msg_ref,
261            source: channel,
262        };
263
264        let pattern_client_identifiers = self
265            .pattern_channels
266            .iter()
267            .filter(|(pattern, _)| WildMatch::new(pattern) == channel)
268            .map(|(_, clients)| clients.iter())
269            .flatten();
270
271        let subbed_clients = self.channels.get_mut(channel);
272        let subbed_client_identifiers = subbed_clients.iter().map(|client| client.iter()).flatten();
273
274        let unique_client_identifiers = subbed_client_identifiers
275            .chain(pattern_client_identifiers)
276            .unique();
277
278        for identifier in unique_client_identifiers {
279            if let Some(client) = self.clients.get_mut(identifier) {
280                client.send(&message);
281            }
282        }
283    }
284}
285
286impl<
287        'a,
288        TClient: Client<TIdentifier, TMessage>,
289        TIdentifier: UniqueIdentifier,
290        TMessage: Clone + Copy,
291    > Default for PubSub<'a, TClient, TIdentifier, TMessage>
292{
293    fn default() -> Self {
294        Self::new()
295    }
296}