// jsonrpc_ws_server/session.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::{atomic, Arc};
4use std::task::{Context, Poll};
5
6use crate::core;
7use futures::channel::oneshot;
8use futures::future;
9use futures::FutureExt;
10
11use parking_lot::Mutex;
12use slab::Slab;
13
14use crate::server_utils::cors::Origin;
15use crate::server_utils::hosts::Host;
16use crate::server_utils::reactor::TaskExecutor;
17use crate::server_utils::session::{SessionId, SessionStats};
18use crate::server_utils::Pattern;
19use crate::ws;
20
21use crate::error;
22use crate::metadata;
23
24/// Middleware to intercept server requests.
25/// You can either terminate the request (by returning a response)
26/// or just proceed with standard JSON-RPC handling.
27pub trait RequestMiddleware: Send + Sync + 'static {
28	/// Process a request and decide what to do next.
29	fn process(&self, req: &ws::Request) -> MiddlewareAction;
30}
31
32impl<F> RequestMiddleware for F
33where
34	F: Fn(&ws::Request) -> Option<ws::Response> + Send + Sync + 'static,
35{
36	fn process(&self, req: &ws::Request) -> MiddlewareAction {
37		(*self)(req).into()
38	}
39}
40
41/// Request middleware action
42#[derive(Debug)]
43pub enum MiddlewareAction {
44	/// Proceed with standard JSON-RPC behaviour.
45	Proceed,
46	/// Terminate the request and return a response.
47	Respond {
48		/// Response to return
49		response: ws::Response,
50		/// Should origin be validated before returning the response?
51		validate_origin: bool,
52		/// Should hosts be validated before returning the response?
53		validate_hosts: bool,
54	},
55}
56
57impl MiddlewareAction {
58	fn should_verify_origin(&self) -> bool {
59		use self::MiddlewareAction::*;
60
61		match *self {
62			Proceed => true,
63			Respond { validate_origin, .. } => validate_origin,
64		}
65	}
66
67	fn should_verify_hosts(&self) -> bool {
68		use self::MiddlewareAction::*;
69
70		match *self {
71			Proceed => true,
72			Respond { validate_hosts, .. } => validate_hosts,
73		}
74	}
75}
76
77impl From<Option<ws::Response>> for MiddlewareAction {
78	fn from(opt: Option<ws::Response>) -> Self {
79		match opt {
80			Some(res) => MiddlewareAction::Respond {
81				response: res,
82				validate_origin: true,
83				validate_hosts: true,
84			},
85			None => MiddlewareAction::Proceed,
86		}
87	}
88}
89
90// the slab is only inserted into when live.
91type TaskSlab = Mutex<Slab<Option<oneshot::Sender<()>>>>;
92
93// future for checking session liveness.
94// this returns `NotReady` until the session it corresponds to is dropped.
95#[derive(Debug)]
96struct LivenessPoll {
97	task_slab: Arc<TaskSlab>,
98	slab_handle: usize,
99	rx: oneshot::Receiver<()>,
100}
101
102impl LivenessPoll {
103	fn create(task_slab: Arc<TaskSlab>) -> Self {
104		const INITIAL_SIZE: usize = 4;
105
106		let (index, rx) = {
107			let mut task_slab = task_slab.lock();
108			if task_slab.len() == task_slab.capacity() {
109				// grow the size if necessary.
110				// we don't expect this to get so big as to overflow.
111				let reserve = ::std::cmp::max(task_slab.capacity(), INITIAL_SIZE);
112				task_slab.reserve_exact(reserve);
113			}
114
115			let (tx, rx) = oneshot::channel();
116			let index = task_slab.insert(Some(tx));
117			(index, rx)
118		};
119
120		LivenessPoll {
121			task_slab,
122			slab_handle: index,
123			rx,
124		}
125	}
126}
127
128impl Future for LivenessPoll {
129	type Output = ();
130
131	fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
132		let this = Pin::into_inner(self);
133		// if the future resolves ok then we've been signalled to return.
134		// it should never be cancelled, but if it was the session definitely
135		// isn't live.
136		match Pin::new(&mut this.rx).poll(cx) {
137			Poll::Ready(_) => Poll::Ready(()),
138			Poll::Pending => Poll::Pending,
139		}
140	}
141}
142
143impl Drop for LivenessPoll {
144	fn drop(&mut self) {
145		// remove the entry from the slab if it hasn't been destroyed yet.
146		self.task_slab.lock().remove(self.slab_handle);
147	}
148}
149
150pub struct Session<M: core::Metadata, S: core::Middleware<M>> {
151	active: Arc<atomic::AtomicBool>,
152	context: metadata::RequestContext,
153	handler: Arc<core::MetaIoHandler<M, S>>,
154	meta_extractor: Arc<dyn metadata::MetaExtractor<M>>,
155	allowed_origins: Option<Vec<Origin>>,
156	allowed_hosts: Option<Vec<Host>>,
157	request_middleware: Option<Arc<dyn RequestMiddleware>>,
158	stats: Option<Arc<dyn SessionStats>>,
159	metadata: Option<M>,
160	executor: TaskExecutor,
161	task_slab: Arc<TaskSlab>,
162}
163
164impl<M: core::Metadata, S: core::Middleware<M>> Drop for Session<M, S> {
165	fn drop(&mut self) {
166		self.active.store(false, atomic::Ordering::SeqCst);
167		if let Some(stats) = self.stats.as_ref() {
168			stats.close_session(self.context.session_id)
169		}
170
171		// signal to all still-live tasks that the session has been dropped.
172		for (_index, task) in self.task_slab.lock().iter_mut() {
173			if let Some(task) = task.take() {
174				let _ = task.send(());
175			}
176		}
177	}
178}
179
180impl<M: core::Metadata, S: core::Middleware<M>> Session<M, S> {
181	fn read_origin<'a>(&self, req: &'a ws::Request) -> Option<&'a [u8]> {
182		req.header("origin").map(|x| &x[..])
183	}
184
185	fn verify_origin(&self, origin: Option<&[u8]>) -> Option<ws::Response> {
186		if !header_is_allowed(&self.allowed_origins, origin) {
187			warn!(
188				"Blocked connection to WebSockets server from untrusted origin: {:?}",
189				origin.and_then(|s| std::str::from_utf8(s).ok()),
190			);
191			Some(forbidden("URL Blocked", "Connection Origin has been rejected."))
192		} else {
193			None
194		}
195	}
196
197	fn verify_host(&self, req: &ws::Request) -> Option<ws::Response> {
198		let host = req.header("host").map(|x| &x[..]);
199		if !header_is_allowed(&self.allowed_hosts, host) {
200			warn!(
201				"Blocked connection to WebSockets server with untrusted host: {:?}",
202				host.and_then(|s| std::str::from_utf8(s).ok()),
203			);
204			Some(forbidden("URL Blocked", "Connection Host has been rejected."))
205		} else {
206			None
207		}
208	}
209}
210
211impl<M: core::Metadata, S: core::Middleware<M>> ws::Handler for Session<M, S>
212where
213	S::Future: Unpin,
214	S::CallFuture: Unpin,
215{
216	fn on_request(&mut self, req: &ws::Request) -> ws::Result<ws::Response> {
217		// Run middleware
218		let action = if let Some(ref middleware) = self.request_middleware {
219			middleware.process(req)
220		} else {
221			MiddlewareAction::Proceed
222		};
223
224		let origin = self.read_origin(req);
225		if action.should_verify_origin() {
226			// Verify request origin.
227			if let Some(response) = self.verify_origin(origin) {
228				return Ok(response);
229			}
230		}
231
232		if action.should_verify_hosts() {
233			// Verify host header.
234			if let Some(response) = self.verify_host(req) {
235				return Ok(response);
236			}
237		}
238
239		self.context.origin = origin
240			.and_then(|origin| ::std::str::from_utf8(origin).ok())
241			.map(Into::into);
242		self.context.protocols = req
243			.protocols()
244			.ok()
245			.map(|protos| protos.into_iter().map(Into::into).collect())
246			.unwrap_or_else(Vec::new);
247		self.metadata = Some(self.meta_extractor.extract(&self.context));
248
249		match action {
250			MiddlewareAction::Proceed => ws::Response::from_request(req).map(|mut res| {
251				if let Some(protocol) = self.context.protocols.get(0) {
252					res.set_protocol(protocol);
253				}
254				res
255			}),
256			MiddlewareAction::Respond { response, .. } => Ok(response),
257		}
258	}
259
260	fn on_message(&mut self, msg: ws::Message) -> ws::Result<()> {
261		let req = msg.as_text()?;
262		let out = self.context.out.clone();
263		let metadata = self
264			.metadata
265			.clone()
266			.expect("Metadata is always set in on_request; qed");
267
268		// TODO: creation requires allocating a `oneshot` channel and acquiring a
269		// mutex. we could alternatively do this lazily upon first poll if
270		// it becomes a bottleneck.
271		let poll_liveness = LivenessPoll::create(self.task_slab.clone());
272
273		let active_lock = self.active.clone();
274		let response = self.handler.handle_request(req, metadata);
275
276		let future = response.map(move |response| {
277			if !active_lock.load(atomic::Ordering::SeqCst) {
278				return;
279			}
280			if let Some(result) = response {
281				let res = out.send(result);
282				match res {
283					Err(error::Error::ConnectionClosed) => {
284						active_lock.store(false, atomic::Ordering::SeqCst);
285					}
286					Err(e) => {
287						warn!("Error while sending response: {:?}", e);
288					}
289					_ => {}
290				}
291			}
292		});
293
294		let future = future::select(future, poll_liveness);
295		self.executor.spawn(future);
296
297		Ok(())
298	}
299}
300
301pub struct Factory<M: core::Metadata, S: core::Middleware<M>> {
302	session_id: SessionId,
303	handler: Arc<core::MetaIoHandler<M, S>>,
304	meta_extractor: Arc<dyn metadata::MetaExtractor<M>>,
305	allowed_origins: Option<Vec<Origin>>,
306	allowed_hosts: Option<Vec<Host>>,
307	request_middleware: Option<Arc<dyn RequestMiddleware>>,
308	stats: Option<Arc<dyn SessionStats>>,
309	executor: TaskExecutor,
310}
311
312impl<M: core::Metadata, S: core::Middleware<M>> Factory<M, S> {
313	pub fn new(
314		handler: Arc<core::MetaIoHandler<M, S>>,
315		meta_extractor: Arc<dyn metadata::MetaExtractor<M>>,
316		allowed_origins: Option<Vec<Origin>>,
317		allowed_hosts: Option<Vec<Host>>,
318		request_middleware: Option<Arc<dyn RequestMiddleware>>,
319		stats: Option<Arc<dyn SessionStats>>,
320		executor: TaskExecutor,
321	) -> Self {
322		Factory {
323			session_id: 0,
324			handler,
325			meta_extractor,
326			allowed_origins,
327			allowed_hosts,
328			request_middleware,
329			stats,
330			executor,
331		}
332	}
333}
334
335impl<M: core::Metadata, S: core::Middleware<M>> ws::Factory for Factory<M, S>
336where
337	S::Future: Unpin,
338	S::CallFuture: Unpin,
339{
340	type Handler = Session<M, S>;
341
342	fn connection_made(&mut self, sender: ws::Sender) -> Self::Handler {
343		self.session_id += 1;
344		if let Some(executor) = self.stats.as_ref() {
345			executor.open_session(self.session_id)
346		}
347		let active = Arc::new(atomic::AtomicBool::new(true));
348
349		Session {
350			active: active.clone(),
351			context: metadata::RequestContext {
352				session_id: self.session_id,
353				origin: None,
354				protocols: Vec::new(),
355				out: metadata::Sender::new(sender, active),
356				executor: self.executor.clone(),
357			},
358			handler: self.handler.clone(),
359			meta_extractor: self.meta_extractor.clone(),
360			allowed_origins: self.allowed_origins.clone(),
361			allowed_hosts: self.allowed_hosts.clone(),
362			stats: self.stats.clone(),
363			request_middleware: self.request_middleware.clone(),
364			metadata: None,
365			executor: self.executor.clone(),
366			task_slab: Arc::new(Mutex::new(Slab::with_capacity(0))),
367		}
368	}
369}
370
371fn header_is_allowed<T>(allowed: &Option<Vec<T>>, header: Option<&[u8]>) -> bool
372where
373	T: Pattern,
374{
375	let header = header.map(std::str::from_utf8);
376
377	match (header, allowed.as_ref()) {
378		// Always allow if Origin/Host is not specified
379		(None, _) => true,
380		// Always allow if Origin/Host validation is disabled
381		(_, None) => true,
382		// Validate Origin
383		(Some(Ok(val)), Some(values)) => {
384			for v in values {
385				if v.matches(val) {
386					return true;
387				}
388			}
389			false
390		}
391		// Disallow in other cases
392		_ => false,
393	}
394}
395
396fn forbidden(title: &str, message: &str) -> ws::Response {
397	let mut forbidden = ws::Response::new(403, "Forbidden", format!("{}\n{}\n", title, message).into_bytes());
398	{
399		let headers = forbidden.headers_mut();
400		headers.push(("Connection".to_owned(), b"close".to_vec()));
401	}
402	forbidden
403}