1use parking_lot::Mutex;
4use std::collections::HashMap;
5use std::fmt;
6use std::sync::Arc;
7
8use crate::core::futures::sync::mpsc;
9use crate::core::futures::{self, future, Future, Sink as FuturesSink};
10use crate::core::{self, BoxFuture};
11
12use crate::handler::{SubscribeRpcMethod, UnsubscribeRpcMethod};
13use crate::types::{PubSubMetadata, SinkResult, SubscriptionId, TransportError, TransportSender};
14
15pub struct Session {
18 active_subscriptions: Mutex<HashMap<(SubscriptionId, String), Box<Fn(SubscriptionId) + Send + 'static>>>,
19 transport: TransportSender,
20 on_drop: Mutex<Vec<Box<FnMut() + Send>>>,
21}
22
23impl fmt::Debug for Session {
24 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
25 fmt.debug_struct("pubsub::Session")
26 .field("active_subscriptions", &self.active_subscriptions.lock().len())
27 .field("transport", &self.transport)
28 .finish()
29 }
30}
31
32impl Session {
33 pub fn new(sender: TransportSender) -> Self {
36 Session {
37 active_subscriptions: Default::default(),
38 transport: sender,
39 on_drop: Default::default(),
40 }
41 }
42
43 pub fn sender(&self) -> TransportSender {
45 self.transport.clone()
46 }
47
48 pub fn on_drop<F: FnOnce() + Send + 'static>(&self, on_drop: F) {
50 let mut func = Some(on_drop);
51 self.on_drop.lock().push(Box::new(move || {
52 if let Some(f) = func.take() {
53 f();
54 }
55 }));
56 }
57
58 fn add_subscription<F>(&self, name: &str, id: &SubscriptionId, remove: F)
60 where
61 F: Fn(SubscriptionId) + Send + 'static,
62 {
63 let ret = self
64 .active_subscriptions
65 .lock()
66 .insert((id.clone(), name.into()), Box::new(remove));
67 if let Some(remove) = ret {
68 warn!("SubscriptionId collision. Unsubscribing previous client.");
69 remove(id.clone());
70 }
71 }
72
73 fn remove_subscription(&self, name: &str, id: &SubscriptionId) {
75 self.active_subscriptions.lock().remove(&(id.clone(), name.into()));
76 }
77}
78
79impl Drop for Session {
80 fn drop(&mut self) {
81 let mut active = self.active_subscriptions.lock();
82 for (id, remove) in active.drain() {
83 remove(id.0)
84 }
85
86 let mut on_drop = self.on_drop.lock();
87 for mut on_drop in on_drop.drain(..) {
88 on_drop();
89 }
90 }
91}
92
93#[derive(Debug, Clone)]
95pub struct Sink {
96 notification: String,
97 transport: TransportSender,
98}
99
100impl Sink {
101 pub fn notify(&self, val: core::Params) -> SinkResult {
103 let val = self.params_to_string(val);
104 self.transport.clone().send(val.0)
105 }
106
107 fn params_to_string(&self, val: core::Params) -> (String, core::Params) {
108 let notification = core::Notification {
109 jsonrpc: Some(core::Version::V2),
110 method: self.notification.clone(),
111 params: val,
112 };
113 (
114 core::to_string(¬ification).expect("Notification serialization never fails."),
115 notification.params,
116 )
117 }
118}
119
120impl FuturesSink for Sink {
121 type SinkItem = core::Params;
122 type SinkError = TransportError;
123
124 fn start_send(&mut self, item: Self::SinkItem) -> futures::StartSend<Self::SinkItem, Self::SinkError> {
125 let (val, params) = self.params_to_string(item);
126 self.transport.start_send(val).map(|result| match result {
127 futures::AsyncSink::Ready => futures::AsyncSink::Ready,
128 futures::AsyncSink::NotReady(_) => futures::AsyncSink::NotReady(params),
129 })
130 }
131
132 fn poll_complete(&mut self) -> futures::Poll<(), Self::SinkError> {
133 self.transport.poll_complete()
134 }
135
136 fn close(&mut self) -> futures::Poll<(), Self::SinkError> {
137 self.transport.close()
138 }
139}
140
141#[derive(Debug)]
144pub struct Subscriber {
145 notification: String,
146 transport: TransportSender,
147 sender: crate::oneshot::Sender<Result<SubscriptionId, core::Error>>,
148}
149
150impl Subscriber {
151 pub fn new_test<T: Into<String>>(
155 method: T,
156 ) -> (
157 Self,
158 crate::oneshot::Receiver<Result<SubscriptionId, core::Error>>,
159 mpsc::Receiver<String>,
160 ) {
161 let (sender, id_receiver) = crate::oneshot::channel();
162 let (transport, transport_receiver) = mpsc::channel(1);
163
164 let subscriber = Subscriber {
165 notification: method.into(),
166 transport,
167 sender,
168 };
169
170 (subscriber, id_receiver, transport_receiver)
171 }
172
173 pub fn assign_id(self, id: SubscriptionId) -> Result<Sink, ()> {
177 let Self {
178 notification,
179 transport,
180 sender,
181 } = self;
182 sender
183 .send(Ok(id))
184 .map(|_| Sink {
185 notification,
186 transport,
187 })
188 .map_err(|_| ())
189 }
190
191 pub fn assign_id_async(self, id: SubscriptionId) -> impl Future<Item = Sink, Error = ()> {
196 let Self {
197 notification,
198 transport,
199 sender,
200 } = self;
201 sender
202 .send_and_wait(Ok(id))
203 .map(|_| Sink {
204 notification,
205 transport,
206 })
207 .map_err(|_| ())
208 }
209
210 pub fn reject(self, error: core::Error) -> Result<(), ()> {
214 self.sender.send(Err(error)).map_err(|_| ())
215 }
216
217 pub fn reject_async(self, error: core::Error) -> impl Future<Item = (), Error = ()> {
222 self.sender.send_and_wait(Err(error)).map(|_| ()).map_err(|_| ())
223 }
224}
225
226pub fn new_subscription<M, F, G>(notification: &str, subscribe: F, unsubscribe: G) -> (Subscribe<F, G>, Unsubscribe<G>)
228where
229 M: PubSubMetadata,
230 F: SubscribeRpcMethod<M>,
231 G: UnsubscribeRpcMethod<M>,
232{
233 let unsubscribe = Arc::new(unsubscribe);
234 let subscribe = Subscribe {
235 notification: notification.to_owned(),
236 unsubscribe: unsubscribe.clone(),
237 subscribe,
238 };
239
240 let unsubscribe = Unsubscribe {
241 notification: notification.into(),
242 unsubscribe,
243 };
244
245 (subscribe, unsubscribe)
246}
247
248fn subscription_rejected() -> core::Error {
249 core::Error {
250 code: core::ErrorCode::ServerError(-32091),
251 message: "Subscription rejected".into(),
252 data: None,
253 }
254}
255
256fn subscriptions_unavailable() -> core::Error {
257 core::Error {
258 code: core::ErrorCode::ServerError(-32090),
259 message: "Subscriptions are not available on this transport.".into(),
260 data: None,
261 }
262}
263
264pub struct Subscribe<F, G> {
266 notification: String,
267 subscribe: F,
268 unsubscribe: Arc<G>,
269}
270
271impl<M, F, G> core::RpcMethod<M> for Subscribe<F, G>
272where
273 M: PubSubMetadata,
274 F: SubscribeRpcMethod<M>,
275 G: UnsubscribeRpcMethod<M>,
276{
277 fn call(&self, params: core::Params, meta: M) -> BoxFuture<core::Value> {
278 match meta.session() {
279 Some(session) => {
280 let (tx, rx) = crate::oneshot::channel();
281
282 let subscriber = Subscriber {
284 notification: self.notification.clone(),
285 transport: session.sender(),
286 sender: tx,
287 };
288 self.subscribe.call(params, meta, subscriber);
289
290 let unsub = self.unsubscribe.clone();
291 let notification = self.notification.clone();
292 let subscribe_future = rx.map_err(|_| subscription_rejected()).and_then(move |result| {
293 futures::done(match result {
294 Ok(id) => {
295 session.add_subscription(¬ification, &id, move |id| {
296 let _ = unsub.call(id, None).wait();
297 });
298 Ok(id.into())
299 }
300 Err(e) => Err(e),
301 })
302 });
303 Box::new(subscribe_future)
304 }
305 None => Box::new(future::err(subscriptions_unavailable())),
306 }
307 }
308}
309
310pub struct Unsubscribe<G> {
312 notification: String,
313 unsubscribe: Arc<G>,
314}
315
316impl<M, G> core::RpcMethod<M> for Unsubscribe<G>
317where
318 M: PubSubMetadata,
319 G: UnsubscribeRpcMethod<M>,
320{
321 fn call(&self, params: core::Params, meta: M) -> BoxFuture<core::Value> {
322 let id = match params {
323 core::Params::Array(ref vec) if vec.len() == 1 => SubscriptionId::parse_value(&vec[0]),
324 _ => None,
325 };
326 match (meta.session(), id) {
327 (Some(session), Some(id)) => {
328 session.remove_subscription(&self.notification, &id);
329 Box::new(self.unsubscribe.call(id, Some(meta)))
330 }
331 (Some(_), None) => Box::new(future::err(core::Error::invalid_params("Expected subscription id."))),
332 _ => Box::new(future::err(subscriptions_unavailable())),
333 }
334 }
335}
336
#[cfg(test)]
mod tests {
	use crate::core;
	use crate::core::futures::sync::mpsc;
	use crate::core::futures::{Async, Future, Stream};
	use crate::core::RpcMethod;
	use crate::types::{PubSubMetadata, SubscriptionId};
	use std::sync::atomic::{AtomicBool, Ordering};
	use std::sync::Arc;

	use super::{new_subscription, Session, Sink, Subscriber};

	/// Creates a session backed by a fresh channel (buffer 1) and returns it
	/// together with the channel's receiving end.
	fn session() -> (Session, mpsc::Receiver<String>) {
		let (tx, rx) = mpsc::channel(1);
		(Session::new(tx), rx)
	}

	#[test]
	fn should_unregister_on_drop() {
		// given
		let id = SubscriptionId::Number(1);
		let called = Arc::new(AtomicBool::new(false));
		let called2 = called.clone();
		let session = session().0;
		session.add_subscription("test", &id, move |id| {
			assert_eq!(id, SubscriptionId::Number(1));
			called2.store(true, Ordering::SeqCst);
		});

		// when
		drop(session);

		// then
		assert!(called.load(Ordering::SeqCst));
	}

	#[test]
	fn should_remove_subscription() {
		// given
		let id = SubscriptionId::Number(1);
		let called = Arc::new(AtomicBool::new(false));
		let called2 = called.clone();
		let session = session().0;
		session.add_subscription("test", &id, move |id| {
			assert_eq!(id, SubscriptionId::Number(1));
			called2.store(true, Ordering::SeqCst);
		});

		// when
		session.remove_subscription("test", &id);
		drop(session);

		// then: a removed subscription's cancel callback must not fire
		assert!(!called.load(Ordering::SeqCst));
	}

	#[test]
	fn should_unregister_in_case_of_collision() {
		// given
		let id = SubscriptionId::Number(1);
		let called = Arc::new(AtomicBool::new(false));
		let called2 = called.clone();
		let session = session().0;
		session.add_subscription("test", &id, move |id| {
			assert_eq!(id, SubscriptionId::Number(1));
			called2.store(true, Ordering::SeqCst);
		});

		// when
		session.add_subscription("test", &id, |_| {});

		// then
		assert!(called.load(Ordering::SeqCst));
	}

	#[test]
	fn should_send_notification_to_the_transport() {
		// given
		let (tx, mut rx) = mpsc::channel(1);
		let sink = Sink {
			notification: "test".into(),
			transport: tx,
		};

		// when
		sink.notify(core::Params::Array(vec![core::Value::Number(10.into())]))
			.wait()
			.unwrap();

		// then
		assert_eq!(
			rx.poll().unwrap(),
			Async::Ready(Some(r#"{"jsonrpc":"2.0","method":"test","params":[10]}"#.into()))
		);
	}

	#[test]
	fn should_assign_id() {
		// given
		let (transport, _) = mpsc::channel(1);
		let (tx, mut rx) = crate::oneshot::channel();
		let subscriber = Subscriber {
			notification: "test".into(),
			transport,
			sender: tx,
		};

		// when
		let sink = subscriber.assign_id_async(SubscriptionId::Number(5));

		// then
		assert_eq!(rx.poll().unwrap(), Async::Ready(Ok(SubscriptionId::Number(5))));
		let sink = sink.wait().unwrap();
		assert_eq!(sink.notification, "test".to_owned());
	}

	#[test]
	fn should_reject() {
		// given
		let (transport, _) = mpsc::channel(1);
		let (tx, mut rx) = crate::oneshot::channel();
		let subscriber = Subscriber {
			notification: "test".into(),
			transport,
			sender: tx,
		};
		let error = core::Error {
			code: core::ErrorCode::InvalidRequest,
			message: "Cannot start subscription now.".into(),
			data: None,
		};

		// when
		let reject = subscriber.reject_async(error.clone());

		// then
		assert_eq!(rx.poll().unwrap(), Async::Ready(Err(error)));
		reject.wait().unwrap();
	}

	#[derive(Clone, Default)]
	struct Metadata;
	impl core::Metadata for Metadata {}
	impl PubSubMetadata for Metadata {
		fn session(&self) -> Option<Arc<Session>> {
			Some(Arc::new(session().0))
		}
	}

	#[test]
	fn should_subscribe() {
		// given
		let called = Arc::new(AtomicBool::new(false));
		let called2 = called.clone();
		let (subscribe, _) = new_subscription(
			"test",
			move |params, _meta, _subscriber| {
				assert_eq!(params, core::Params::None);
				called2.store(true, Ordering::SeqCst);
			},
			|_id, _meta| Ok(core::Value::Bool(true)),
		);
		let meta = Metadata;

		// when
		let result = subscribe.call(core::Params::None, meta);

		// then: the handler ran, and because it dropped the subscriber
		// without answering, the call resolves to "Subscription rejected"
		assert!(called.load(Ordering::SeqCst));
		assert_eq!(
			result.wait(),
			Err(core::Error {
				code: core::ErrorCode::ServerError(-32091),
				message: "Subscription rejected".into(),
				data: None,
			})
		);
	}
}