jsonrpc_pubsub/
manager.rs

//! The SubscriptionManager used to manage subscription-based RPCs.
//!
//! The manager provides four main things in terms of functionality:
//!
//! 1. The ability to create unique subscription IDs through the
//! use of the `IdProvider` trait. Two implementations are available
//! out of the box, a `NumericIdProvider` and a `RandomStringIdProvider`.
//!
//! 2. An executor with which to drive `Future`s to completion.
//!
//! 3. A way to add new subscriptions. A subscription is supplied as a closure
//! that receives the subscriber's `Sink` and returns a `Future`, typically one
//! that forwards a `Stream` of values into that sink. The manager drives this
//! future to completion, and the values sent into the sink become notifications
//! that can be consumed by the client.
//!
//! 4. A way to cancel any currently active subscription.
16
17use std::collections::HashMap;
18use std::iter;
19use std::sync::{
20	atomic::{AtomicUsize, Ordering},
21	Arc,
22};
23
24use crate::core::futures::channel::oneshot;
25use crate::core::futures::{self, task, Future, FutureExt, TryFutureExt};
26use crate::{
27	typed::{Sink, Subscriber},
28	SubscriptionId,
29};
30
31use log::{error, warn};
32use parking_lot::Mutex;
33use rand::distributions::Alphanumeric;
34use rand::{thread_rng, Rng};
35
36/// Cloneable `Spawn` handle.
37pub type TaskExecutor = Arc<dyn futures::task::Spawn + Send + Sync>;
38
39type ActiveSubscriptions = Arc<Mutex<HashMap<SubscriptionId, oneshot::Sender<()>>>>;
40
41/// Trait used to provide unique subscription IDs.
42pub trait IdProvider {
43	/// A unique ID used to identify a subscription.
44	type Id: Default + Into<SubscriptionId>;
45
46	/// Returns the next ID for the subscription.
47	fn next_id(&self) -> Self::Id;
48}
49
/// Thread-safe, incrementing integer that can be used
/// as a subscription ID.
#[derive(Debug, Clone)]
pub struct NumericIdProvider {
	/// Next value to hand out. It sits behind an `Arc`, so every clone of the
	/// provider shares (and advances) the same counter.
	current_id: Arc<AtomicUsize>,
}
56
57impl NumericIdProvider {
58	/// Create a new NumericIdProvider.
59	pub fn new() -> Self {
60		Default::default()
61	}
62
63	/// Create a new NumericIdProvider starting from
64	/// the given ID.
65	pub fn with_id(id: AtomicUsize) -> Self {
66		Self {
67			current_id: Arc::new(id),
68		}
69	}
70}
71
72impl IdProvider for NumericIdProvider {
73	type Id = u64;
74
75	fn next_id(&self) -> Self::Id {
76		self.current_id.fetch_add(1, Ordering::AcqRel) as u64
77	}
78}
79
80impl Default for NumericIdProvider {
81	fn default() -> Self {
82		NumericIdProvider {
83			current_id: Arc::new(AtomicUsize::new(1)),
84		}
85	}
86}
87
/// Generates random alphanumeric strings to be used
/// as subscription IDs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RandomStringIdProvider {
	/// Number of characters in every generated ID.
	len: usize,
}
94
95impl RandomStringIdProvider {
96	/// Create a new RandomStringIdProvider.
97	pub fn new() -> Self {
98		Default::default()
99	}
100
101	/// Create a new RandomStringIdProvider, which will generate
102	/// random id strings of the given length.
103	pub fn with_len(len: usize) -> Self {
104		Self { len }
105	}
106}
107
108impl IdProvider for RandomStringIdProvider {
109	type Id = String;
110
111	fn next_id(&self) -> Self::Id {
112		let mut rng = thread_rng();
113		let id: String = iter::repeat(())
114			.map(|()| rng.sample(Alphanumeric))
115			.take(self.len)
116			.collect();
117		id
118	}
119}
120
121impl Default for RandomStringIdProvider {
122	fn default() -> Self {
123		Self { len: 16 }
124	}
125}
126
127/// Subscriptions manager.
128///
129/// Takes care of assigning unique subscription ids and
130/// driving the sinks into completion.
131#[derive(Clone)]
132pub struct SubscriptionManager<I: IdProvider = RandomStringIdProvider> {
133	id_provider: I,
134	active_subscriptions: ActiveSubscriptions,
135	executor: TaskExecutor,
136}
137
138impl SubscriptionManager {
139	/// Creates a new SubscriptionManager.
140	///
141	/// Uses `RandomStringIdProvider` as the ID provider.
142	pub fn new(executor: TaskExecutor) -> Self {
143		Self {
144			id_provider: RandomStringIdProvider::default(),
145			active_subscriptions: Default::default(),
146			executor,
147		}
148	}
149}
150
151impl<I: IdProvider> SubscriptionManager<I> {
152	/// Creates a new SubscriptionManager with the specified
153	/// ID provider.
154	pub fn with_id_provider(id_provider: I, executor: TaskExecutor) -> Self {
155		Self {
156			id_provider,
157			active_subscriptions: Default::default(),
158			executor,
159		}
160	}
161
162	/// Borrows the internal task executor.
163	///
164	/// This can be used to spawn additional tasks on the underlying event loop.
165	pub fn executor(&self) -> &TaskExecutor {
166		&self.executor
167	}
168
169	/// Creates new subscription for given subscriber.
170	///
171	/// Second parameter is a function that converts Subscriber Sink into a Future.
172	/// This future will be driven to completion by the underlying event loop
173	pub fn add<T, E, G, F>(&self, subscriber: Subscriber<T, E>, into_future: G) -> SubscriptionId
174	where
175		G: FnOnce(Sink<T, E>) -> F,
176		F: Future<Output = ()> + Send + 'static,
177	{
178		let id = self.id_provider.next_id();
179		let subscription_id: SubscriptionId = id.into();
180		if let Ok(sink) = subscriber.assign_id(subscription_id.clone()) {
181			let (tx, rx) = oneshot::channel();
182			let f = into_future(sink).fuse();
183			let rx = rx.map_err(|e| warn!("Error timing out: {:?}", e)).fuse();
184			let future = async move {
185				futures::pin_mut!(f);
186				futures::pin_mut!(rx);
187				futures::select! {
188					a = f => a,
189					_ = rx => (),
190				}
191			};
192
193			self.active_subscriptions.lock().insert(subscription_id.clone(), tx);
194			if self.executor.spawn_obj(task::FutureObj::new(Box::pin(future))).is_err() {
195				error!("Failed to spawn RPC subscription task");
196			}
197		}
198
199		subscription_id
200	}
201
202	/// Cancel subscription.
203	///
204	/// Returns true if subscription existed or false otherwise.
205	pub fn cancel(&self, id: SubscriptionId) -> bool {
206		if let Some(tx) = self.active_subscriptions.lock().remove(&id) {
207			let _ = tx.send(());
208			return true;
209		}
210
211		false
212	}
213}
214
215impl<I: Default + IdProvider> SubscriptionManager<I> {
216	/// Creates a new SubscriptionManager.
217	pub fn with_executor(executor: TaskExecutor) -> Self {
218		Self {
219			id_provider: Default::default(),
220			active_subscriptions: Default::default(),
221			executor,
222		}
223	}
224}
225
#[cfg(test)]
mod tests {
	use super::*;
	use crate::typed::Subscriber;
	use futures::{executor, stream};
	use futures::{FutureExt, StreamExt};

	// Executor shared by all tests.
	//
	// This shared executor is used to prevent `Too many open files` errors
	// on systems with a lot of cores.
	lazy_static::lazy_static! {
		static ref EXECUTOR: executor::ThreadPool = executor::ThreadPool::new()
			.expect("Failed to create thread pool executor for tests");
	}

	/// `Spawn` implementation that hands every task to the shared `EXECUTOR`.
	pub struct TestTaskExecutor;
	impl task::Spawn for TestTaskExecutor {
		fn spawn_obj(&self, future: task::FutureObj<'static, ()>) -> Result<(), task::SpawnError> {
			EXECUTOR.spawn_obj(future)
		}

		fn status(&self) -> Result<(), task::SpawnError> {
			EXECUTOR.status()
		}
	}

	#[test]
	fn making_a_numeric_id_provider_works() {
		let provider = NumericIdProvider::new();
		let expected_id = 1;
		let actual_id = provider.next_id();

		assert_eq!(actual_id, expected_id);
	}

	#[test]
	fn default_numeric_id_provider_works() {
		let provider: NumericIdProvider = Default::default();
		let expected_id = 1;
		let actual_id = provider.next_id();

		assert_eq!(actual_id, expected_id);
	}

	#[test]
	fn numeric_id_provider_with_id_works() {
		let provider = NumericIdProvider::with_id(AtomicUsize::new(5));
		let expected_id = 5;
		let actual_id = provider.next_id();

		assert_eq!(actual_id, expected_id);
	}

	#[test]
	fn numeric_id_provider_clones_share_counter() {
		// Clones share one `Arc<AtomicUsize>`, so they never hand out the same ID.
		let provider = NumericIdProvider::new();
		let cloned = provider.clone();

		assert_eq!(provider.next_id(), 1);
		assert_eq!(cloned.next_id(), 2);
		assert_eq!(provider.next_id(), 3);
	}

	#[test]
	fn random_string_provider_returns_id_with_correct_default_len() {
		let provider = RandomStringIdProvider::new();
		let expected_len = 16;
		let actual_len = provider.next_id().len();

		assert_eq!(actual_len, expected_len);
	}

	#[test]
	fn random_string_provider_returns_id_with_correct_user_given_len() {
		let expected_len = 10;
		let provider = RandomStringIdProvider::with_len(expected_len);
		let actual_len = provider.next_id().len();

		assert_eq!(actual_len, expected_len);
	}

	#[test]
	fn new_subscription_manager_defaults_to_random_string_provider() {
		let manager = SubscriptionManager::new(Arc::new(TestTaskExecutor));
		let subscriber = Subscriber::<u64>::new_test("test_subTest").0;
		let stream = stream::iter(vec![Ok(Ok(1))]);

		let id = manager.add(subscriber, move |sink| stream.forward(sink).map(|_| ()));

		assert!(matches!(id, SubscriptionId::String(_)))
	}

	#[test]
	fn new_subscription_manager_works_with_numeric_id_provider() {
		let id_provider = NumericIdProvider::default();
		let manager = SubscriptionManager::with_id_provider(id_provider, Arc::new(TestTaskExecutor));

		let subscriber = Subscriber::<u64>::new_test("test_subTest").0;
		let stream = stream::iter(vec![Ok(Ok(1))]);

		let id = manager.add(subscriber, move |sink| stream.forward(sink).map(|_| ()));

		assert!(matches!(id, SubscriptionId::Number(_)))
	}

	#[test]
	fn new_subscription_manager_works_with_random_string_provider() {
		let id_provider = RandomStringIdProvider::default();
		let manager = SubscriptionManager::with_id_provider(id_provider, Arc::new(TestTaskExecutor));

		let subscriber = Subscriber::<u64>::new_test("test_subTest").0;
		let stream = stream::iter(vec![Ok(Ok(1))]);

		let id = manager.add(subscriber, move |sink| stream.forward(sink).map(|_| ()));

		assert!(matches!(id, SubscriptionId::String(_)))
	}

	#[test]
	fn subscription_is_canceled_if_it_existed() {
		let manager = SubscriptionManager::<NumericIdProvider>::with_executor(Arc::new(TestTaskExecutor));
		// Need to bind the ID receiver here (unlike the other tests), or else the subscriber
		// will think the client has disconnected and `add` will not update `active_subscriptions`.
		// The notification receiver is kept alive as well so that the sink presumably stays
		// connected and the subscription is still running when it gets cancelled.
		let (subscriber, _recv, _notifications) = Subscriber::<u64>::new_test("test_subTest");

		let (mut tx, rx) = futures::channel::mpsc::channel(8);
		tx.start_send(1).unwrap();
		let id = manager.add(subscriber, move |sink| {
			let rx = rx.map(|v| Ok(Ok(v)));
			rx.forward(sink).map(|_| ())
		});

		assert!(manager.cancel(id.clone()));
		// The handle was consumed by the first cancel, so a second one finds nothing.
		assert!(!manager.cancel(id));
	}

	#[test]
	fn subscription_is_not_canceled_because_it_didnt_exist() {
		let manager = SubscriptionManager::new(Arc::new(TestTaskExecutor));

		let id: SubscriptionId = 23u32.into();
		assert!(!manager.cancel(id));
	}

	#[test]
	fn is_send_sync() {
		fn send_sync<T: Send + Sync>() {}

		send_sync::<SubscriptionManager>();
		send_sync::<SubscriptionManager<NumericIdProvider>>();
	}
}