jsonrpc_pubsub/
manager.rs1use std::collections::HashMap;
18use std::iter;
19use std::sync::{
20 atomic::{AtomicUsize, Ordering},
21 Arc,
22};
23
24use crate::core::futures::channel::oneshot;
25use crate::core::futures::{self, task, Future, FutureExt, TryFutureExt};
26use crate::{
27 typed::{Sink, Subscriber},
28 SubscriptionId,
29};
30
31use log::{error, warn};
32use parking_lot::Mutex;
33use rand::distributions::Alphanumeric;
34use rand::{thread_rng, Rng};
35
36pub type TaskExecutor = Arc<dyn futures::task::Spawn + Send + Sync>;
38
39type ActiveSubscriptions = Arc<Mutex<HashMap<SubscriptionId, oneshot::Sender<()>>>>;
40
41pub trait IdProvider {
43 type Id: Default + Into<SubscriptionId>;
45
46 fn next_id(&self) -> Self::Id;
48}
49
50#[derive(Clone, Debug)]
53pub struct NumericIdProvider {
54 current_id: Arc<AtomicUsize>,
55}
56
57impl NumericIdProvider {
58 pub fn new() -> Self {
60 Default::default()
61 }
62
63 pub fn with_id(id: AtomicUsize) -> Self {
66 Self {
67 current_id: Arc::new(id),
68 }
69 }
70}
71
72impl IdProvider for NumericIdProvider {
73 type Id = u64;
74
75 fn next_id(&self) -> Self::Id {
76 self.current_id.fetch_add(1, Ordering::AcqRel) as u64
77 }
78}
79
80impl Default for NumericIdProvider {
81 fn default() -> Self {
82 NumericIdProvider {
83 current_id: Arc::new(AtomicUsize::new(1)),
84 }
85 }
86}
87
88#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
91pub struct RandomStringIdProvider {
92 len: usize,
93}
94
95impl RandomStringIdProvider {
96 pub fn new() -> Self {
98 Default::default()
99 }
100
101 pub fn with_len(len: usize) -> Self {
104 Self { len }
105 }
106}
107
108impl IdProvider for RandomStringIdProvider {
109 type Id = String;
110
111 fn next_id(&self) -> Self::Id {
112 let mut rng = thread_rng();
113 let id: String = iter::repeat(())
114 .map(|()| rng.sample(Alphanumeric))
115 .take(self.len)
116 .collect();
117 id
118 }
119}
120
121impl Default for RandomStringIdProvider {
122 fn default() -> Self {
123 Self { len: 16 }
124 }
125}
126
127#[derive(Clone)]
132pub struct SubscriptionManager<I: IdProvider = RandomStringIdProvider> {
133 id_provider: I,
134 active_subscriptions: ActiveSubscriptions,
135 executor: TaskExecutor,
136}
137
138impl SubscriptionManager {
139 pub fn new(executor: TaskExecutor) -> Self {
143 Self {
144 id_provider: RandomStringIdProvider::default(),
145 active_subscriptions: Default::default(),
146 executor,
147 }
148 }
149}
150
151impl<I: IdProvider> SubscriptionManager<I> {
152 pub fn with_id_provider(id_provider: I, executor: TaskExecutor) -> Self {
155 Self {
156 id_provider,
157 active_subscriptions: Default::default(),
158 executor,
159 }
160 }
161
162 pub fn executor(&self) -> &TaskExecutor {
166 &self.executor
167 }
168
169 pub fn add<T, E, G, F>(&self, subscriber: Subscriber<T, E>, into_future: G) -> SubscriptionId
174 where
175 G: FnOnce(Sink<T, E>) -> F,
176 F: Future<Output = ()> + Send + 'static,
177 {
178 let id = self.id_provider.next_id();
179 let subscription_id: SubscriptionId = id.into();
180 if let Ok(sink) = subscriber.assign_id(subscription_id.clone()) {
181 let (tx, rx) = oneshot::channel();
182 let f = into_future(sink).fuse();
183 let rx = rx.map_err(|e| warn!("Error timing out: {:?}", e)).fuse();
184 let future = async move {
185 futures::pin_mut!(f);
186 futures::pin_mut!(rx);
187 futures::select! {
188 a = f => a,
189 _ = rx => (),
190 }
191 };
192
193 self.active_subscriptions.lock().insert(subscription_id.clone(), tx);
194 if self.executor.spawn_obj(task::FutureObj::new(Box::pin(future))).is_err() {
195 error!("Failed to spawn RPC subscription task");
196 }
197 }
198
199 subscription_id
200 }
201
202 pub fn cancel(&self, id: SubscriptionId) -> bool {
206 if let Some(tx) = self.active_subscriptions.lock().remove(&id) {
207 let _ = tx.send(());
208 return true;
209 }
210
211 false
212 }
213}
214
215impl<I: Default + IdProvider> SubscriptionManager<I> {
216 pub fn with_executor(executor: TaskExecutor) -> Self {
218 Self {
219 id_provider: Default::default(),
220 active_subscriptions: Default::default(),
221 executor,
222 }
223 }
224}
225
#[cfg(test)]
mod tests {
    use super::*;
    use crate::typed::Subscriber;
    use futures::{executor, stream};
    use futures::{FutureExt, StreamExt};

    // Thread pool shared by every manager created in this module.
    lazy_static::lazy_static! {
        static ref EXECUTOR: executor::ThreadPool = executor::ThreadPool::new()
            .expect("Failed to create thread pool executor for tests");
    }

    /// `Spawn` implementation that forwards everything to the shared `EXECUTOR`.
    pub struct TestTaskExecutor;

    impl task::Spawn for TestTaskExecutor {
        fn spawn_obj(&self, future: task::FutureObj<'static, ()>) -> Result<(), task::SpawnError> {
            EXECUTOR.spawn_obj(future)
        }

        fn status(&self) -> Result<(), task::SpawnError> {
            EXECUTOR.status()
        }
    }

    /// Wraps a `TestTaskExecutor` as the manager's `TaskExecutor` type.
    fn test_executor() -> TaskExecutor {
        Arc::new(TestTaskExecutor)
    }

    #[test]
    fn making_a_numeric_id_provider_works() {
        let provider = NumericIdProvider::new();
        assert_eq!(provider.next_id(), 1);
    }

    #[test]
    fn default_numeric_id_provider_works() {
        let provider = NumericIdProvider::default();
        assert_eq!(provider.next_id(), 1);
    }

    #[test]
    fn numeric_id_provider_with_id_works() {
        let provider = NumericIdProvider::with_id(AtomicUsize::new(5));
        assert_eq!(provider.next_id(), 5);
    }

    #[test]
    fn random_string_provider_returns_id_with_correct_default_len() {
        let generated = RandomStringIdProvider::new().next_id();
        assert_eq!(generated.len(), 16);
    }

    #[test]
    fn random_string_provider_returns_id_with_correct_user_given_len() {
        let wanted = 10;
        let generated = RandomStringIdProvider::with_len(wanted).next_id();
        assert_eq!(generated.len(), wanted);
    }

    #[test]
    fn new_subscription_manager_defaults_to_random_string_provider() {
        let manager = SubscriptionManager::new(test_executor());
        let subscriber = Subscriber::<u64>::new_test("test_subTest").0;
        let items = stream::iter(vec![Ok(Ok(1))]);

        let id = manager.add(subscriber, move |sink| items.forward(sink).map(|_| ()));

        assert!(matches!(id, SubscriptionId::String(_)));
    }

    #[test]
    fn new_subscription_manager_works_with_numeric_id_provider() {
        let manager =
            SubscriptionManager::with_id_provider(NumericIdProvider::default(), test_executor());
        let subscriber = Subscriber::<u64>::new_test("test_subTest").0;
        let items = stream::iter(vec![Ok(Ok(1))]);

        let id = manager.add(subscriber, move |sink| items.forward(sink).map(|_| ()));

        assert!(matches!(id, SubscriptionId::Number(_)));
    }

    #[test]
    fn new_subscription_manager_works_with_random_string_provider() {
        let manager =
            SubscriptionManager::with_id_provider(RandomStringIdProvider::default(), test_executor());
        let subscriber = Subscriber::<u64>::new_test("test_subTest").0;
        let items = stream::iter(vec![Ok(Ok(1))]);

        let id = manager.add(subscriber, move |sink| items.forward(sink).map(|_| ()));

        assert!(matches!(id, SubscriptionId::String(_)));
    }

    #[test]
    fn subscription_is_canceled_if_it_existed() {
        let manager = SubscriptionManager::<NumericIdProvider>::with_executor(test_executor());
        let (subscriber, _recv, _) = Subscriber::<u64>::new_test("test_subTest");

        // Keep the sending half alive so the subscription stream never ends
        // on its own before `cancel` is called.
        let (mut values_tx, values_rx) = futures::channel::mpsc::channel(8);
        values_tx.start_send(1).unwrap();
        let id = manager.add(subscriber, move |sink| {
            values_rx.map(|v| Ok(Ok(v))).forward(sink).map(|_| ())
        });

        assert!(manager.cancel(id));
    }

    #[test]
    fn subscription_is_not_canceled_because_it_didnt_exist() {
        let manager = SubscriptionManager::new(test_executor());
        let unknown: SubscriptionId = 23u32.into();

        assert!(!manager.cancel(unknown));
    }

    #[test]
    fn is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}

        assert_send_sync::<SubscriptionManager>();
    }
}