bitconch_jsonrpc_ws_server/
session.rs

1use std;
2use std::sync::{atomic, Arc};
3
4use core;
5use core::futures::{Async, Future, Poll};
6use core::futures::sync::oneshot;
7
8use parking_lot::Mutex;
9use slab::Slab;
10
11use server_utils::Pattern;
12use server_utils::cors::Origin;
13use server_utils::hosts::Host;
14use server_utils::session::{SessionId, SessionStats};
15use server_utils::tokio::runtime::TaskExecutor;
16use ws;
17
18use error;
19use metadata;
20
21/// Middleware to intercept server requests.
22/// You can either terminate the request (by returning a response)
23/// or just proceed with standard JSON-RPC handling.
24pub trait RequestMiddleware: Send + Sync + 'static {
25	/// Process a request and decide what to do next.
26	fn process(&self, req: &ws::Request) -> MiddlewareAction;
27}
28
29impl<F> RequestMiddleware for F where
30	F: Fn(&ws::Request) -> Option<ws::Response> + Send + Sync + 'static,
31{
32	fn process(&self, req: &ws::Request) -> MiddlewareAction {
33		(*self)(req).into()
34	}
35}
36
37/// Request middleware action
38#[derive(Debug)]
39pub enum MiddlewareAction {
40	/// Proceed with standard JSON-RPC behaviour.
41	Proceed,
42	/// Terminate the request and return a response.
43	Respond {
44		/// Response to return
45		response: ws::Response,
46		/// Should origin be validated before returning the response?
47		validate_origin: bool,
48		/// Should hosts be validated before returning the response?
49		validate_hosts: bool,
50	},
51}
52
53impl MiddlewareAction {
54	fn should_verify_origin(&self) -> bool {
55		use self::MiddlewareAction::*;
56
57		match *self {
58			Proceed => true,
59			Respond { validate_origin, .. } => validate_origin,
60		}
61	}
62
63	fn should_verify_hosts(&self) -> bool {
64		use self::MiddlewareAction::*;
65
66		match *self {
67			Proceed => true,
68			Respond { validate_hosts, .. } => validate_hosts,
69		}
70	}
71}
72
73impl From<Option<ws::Response>> for MiddlewareAction {
74	fn from(opt: Option<ws::Response>) -> Self {
75		match opt {
76			Some(res) => MiddlewareAction::Respond { response: res, validate_origin: true, validate_hosts: true },
77			None => MiddlewareAction::Proceed,
78		}
79	}
80}
81
82// the slab is only inserted into when live.
83type TaskSlab = Mutex<Slab<Option<oneshot::Sender<()>>>>;
84
85// future for checking session liveness.
86// this returns `NotReady` until the session it corresponds to is dropped.
87#[derive(Debug)]
88struct LivenessPoll {
89	task_slab: Arc<TaskSlab>,
90	slab_handle: usize,
91	rx: oneshot::Receiver<()>,
92}
93
94impl LivenessPoll {
95	fn create(task_slab: Arc<TaskSlab>) -> Self {
96		const INITIAL_SIZE: usize = 4;
97
98		let (index, rx) = {
99			let mut task_slab = task_slab.lock();
100			if task_slab.len() == task_slab.capacity() {
101				// grow the size if necessary.
102				// we don't expect this to get so big as to overflow.
103				let reserve = ::std::cmp::max(task_slab.capacity(), INITIAL_SIZE);
104				task_slab.reserve_exact(reserve);
105			}
106
107			let (tx, rx) = oneshot::channel();
108			let index = task_slab.insert(Some(tx));
109			(index, rx)
110		};
111
112		LivenessPoll { task_slab: task_slab, slab_handle: index, rx: rx }
113	}
114}
115
116impl Future for LivenessPoll {
117	type Item = ();
118	type Error = ();
119
120	fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
121		// if the future resolves ok then we've been signalled to return.
122		// it should never be cancelled, but if it was the session definitely
123		// isn't live.
124		match self.rx.poll() {
125			Ok(Async::Ready(_)) | Err(_) => Ok(Async::Ready(())),
126			Ok(Async::NotReady) => Ok(Async::NotReady),
127		}
128	}
129}
130
131impl Drop for LivenessPoll {
132	fn drop(&mut self) {
133		// remove the entry from the slab if it hasn't been destroyed yet.
134		self.task_slab.lock().remove(self.slab_handle);
135	}
136}
137
138pub struct Session<M: core::Metadata, S: core::Middleware<M>> {
139	active: Arc<atomic::AtomicBool>,
140	context: metadata::RequestContext,
141	handler: Arc<core::MetaIoHandler<M, S>>,
142	meta_extractor: Arc<metadata::MetaExtractor<M>>,
143	allowed_origins: Option<Vec<Origin>>,
144	allowed_hosts: Option<Vec<Host>>,
145	request_middleware: Option<Arc<RequestMiddleware>>,
146	stats: Option<Arc<SessionStats>>,
147	metadata: Option<M>,
148	executor: TaskExecutor,
149	task_slab: Arc<TaskSlab>,
150}
151
152impl<M: core::Metadata, S: core::Middleware<M>> Drop for Session<M, S> {
153	fn drop(&mut self) {
154		self.active.store(false, atomic::Ordering::SeqCst);
155		self.stats.as_ref().map(|stats| stats.close_session(self.context.session_id));
156
157		// signal to all still-live tasks that the session has been dropped.
158		for (_index, task) in self.task_slab.lock().iter_mut() {
159			if let Some(task) = task.take() {
160				let _ = task.send(());
161			}
162		}
163	}
164}
165
166impl<M: core::Metadata, S: core::Middleware<M>> Session<M, S> {
167	fn read_origin<'a>(&self, req: &'a ws::Request) -> Option<&'a [u8]> {
168		req.header("origin").map(|x| &x[..])
169	}
170
171	fn verify_origin(&self, origin: Option<&[u8]>) -> Option<ws::Response> {
172		if !header_is_allowed(&self.allowed_origins, origin) {
173			warn!(
174				"Blocked connection to WebSockets server from untrusted origin: {:?}",
175				origin.and_then(|s| std::str::from_utf8(s).ok()),
176			);
177			Some(forbidden("URL Blocked", "Connection Origin has been rejected."))
178		} else {
179			None
180		}
181	}
182
183	fn verify_host(&self, req: &ws::Request) -> Option<ws::Response> {
184		let host = req.header("host").map(|x| &x[..]);
185		if !header_is_allowed(&self.allowed_hosts, host) {
186			warn!(
187				"Blocked connection to WebSockets server with untrusted host: {:?}",
188				host.and_then(|s| std::str::from_utf8(s).ok()),
189			);
190			Some(forbidden("URL Blocked", "Connection Host has been rejected."))
191		} else {
192			None
193		}
194	}
195}
196
197impl<M: core::Metadata, S: core::Middleware<M>> ws::Handler for Session<M, S> {
198	fn on_request(&mut self, req: &ws::Request) -> ws::Result<ws::Response> {
199		// Run middleware
200		let action = if let Some(ref middleware) = self.request_middleware {
201			middleware.process(req)
202		} else {
203			MiddlewareAction::Proceed
204		};
205
206		let origin = self.read_origin(req);
207		if action.should_verify_origin() {
208			// Verify request origin.
209			if let Some(response) = self.verify_origin(origin) {
210				return Ok(response);
211			}
212		}
213
214		if action.should_verify_hosts() {
215			// Verify host header.
216			if let Some(response) = self.verify_host(req) {
217				return Ok(response);
218			}
219		}
220
221		self.context.origin = origin.and_then(|origin| ::std::str::from_utf8(origin).ok()).map(Into::into);
222		self.context.protocols = req.protocols().ok()
223			.map(|protos| protos.into_iter().map(Into::into).collect())
224			.unwrap_or_else(Vec::new);
225		self.metadata = Some(self.meta_extractor.extract(&self.context));
226
227		match action {
228			MiddlewareAction::Proceed => ws::Response::from_request(req).map(|mut res| {
229				if let Some(protocol) = self.context.protocols.get(0) {
230					res.set_protocol(protocol);
231				}
232				res
233			}),
234			MiddlewareAction::Respond { response, .. } => Ok(response),
235		}
236	}
237
238	fn on_message(&mut self, msg: ws::Message) -> ws::Result<()> {
239		let req = msg.as_text()?;
240		let out = self.context.out.clone();
241		let metadata = self.metadata.clone().expect("Metadata is always set in on_request; qed");
242
243		// TODO: creation requires allocating a `oneshot` channel and acquiring a
244		// mutex. we could alternatively do this lazily upon first poll if
245		// it becomes a bottleneck.
246		let poll_liveness = LivenessPoll::create(self.task_slab.clone());
247
248		let active_lock = self.active.clone();
249		let future = self.handler.handle_request(req, metadata)
250			.map(move |response| {
251				if !active_lock.load(atomic::Ordering::SeqCst) {
252					return;
253				}
254				if let Some(result) = response {
255					let res = out.send(result);
256					match res {
257						Err(error::Error(error::ErrorKind::ConnectionClosed, _)) => {
258							active_lock.store(false, atomic::Ordering::SeqCst);
259						},
260						Err(e) => {
261							warn!("Error while sending response: {:?}", e);
262						},
263						_ => {},
264					}
265				}
266			})
267			.select(poll_liveness)
268			.map(|_| ())
269			.map_err(|_| ());
270
271		self.executor.spawn(future);
272
273		Ok(())
274	}
275}
276
277pub struct Factory<M: core::Metadata, S: core::Middleware<M>> {
278	session_id: SessionId,
279	handler: Arc<core::MetaIoHandler<M, S>>,
280	meta_extractor: Arc<metadata::MetaExtractor<M>>,
281	allowed_origins: Option<Vec<Origin>>,
282	allowed_hosts: Option<Vec<Host>>,
283	request_middleware: Option<Arc<RequestMiddleware>>,
284	stats: Option<Arc<SessionStats>>,
285	executor: TaskExecutor,
286}
287
288impl<M: core::Metadata, S: core::Middleware<M>> Factory<M, S> {
289	pub fn new(
290		handler: Arc<core::MetaIoHandler<M, S>>,
291		meta_extractor: Arc<metadata::MetaExtractor<M>>,
292		allowed_origins: Option<Vec<Origin>>,
293		allowed_hosts: Option<Vec<Host>>,
294		request_middleware: Option<Arc<RequestMiddleware>>,
295		stats: Option<Arc<SessionStats>>,
296		executor: TaskExecutor,
297	) -> Self {
298		Factory {
299			session_id: 0,
300			handler: handler,
301			meta_extractor: meta_extractor,
302			allowed_origins: allowed_origins,
303			allowed_hosts: allowed_hosts,
304			request_middleware: request_middleware,
305			stats: stats,
306			executor: executor,
307		}
308	}
309}
310
311impl<M: core::Metadata, S: core::Middleware<M>> ws::Factory for Factory<M, S> {
312	type Handler = Session<M, S>;
313
314	fn connection_made(&mut self, sender: ws::Sender) -> Self::Handler {
315		self.session_id += 1;
316		self.stats.as_ref().map(|stats| stats.open_session(self.session_id));
317		let active = Arc::new(atomic::AtomicBool::new(true));
318
319		Session {
320			active: active.clone(),
321			context: metadata::RequestContext {
322				session_id: self.session_id,
323				origin: None,
324				protocols: Vec::new(),
325				out: metadata::Sender::new(sender, active),
326				executor: self.executor.clone(),
327			},
328			handler: self.handler.clone(),
329			meta_extractor: self.meta_extractor.clone(),
330			allowed_origins: self.allowed_origins.clone(),
331			allowed_hosts: self.allowed_hosts.clone(),
332			stats: self.stats.clone(),
333			request_middleware: self.request_middleware.clone(),
334			metadata: None,
335			executor: self.executor.clone(),
336			task_slab: Arc::new(Mutex::new(Slab::with_capacity(0))),
337		}
338	}
339}
340
341fn header_is_allowed<T>(allowed: &Option<Vec<T>>, header: Option<&[u8]>) -> bool where
342	T: Pattern,
343{
344	let header = header.map(std::str::from_utf8);
345
346	match (header, allowed.as_ref()) {
347		// Always allow if Origin/Host is not specified
348		(None, _) => true,
349		// Always allow if Origin/Host validation is disabled
350		(_, None) => true,
351		// Validate Origin
352		(Some(Ok(val)), Some(values)) => {
353			for v in values {
354				if v.matches(val) {
355					return true
356				}
357			}
358			false
359		},
360		// Disallow in other cases
361		_ => false,
362	}
363}
364
365
366fn forbidden(title: &str, message: &str) -> ws::Response {
367	let mut forbidden = ws::Response::new(403, "Forbidden");
368	forbidden.set_body(
369		format!("{}\n{}\n", title, message).as_bytes()
370	);
371	{
372		let headers = forbidden.headers_mut();
373		headers.push(("Connection".to_owned(), "close".as_bytes().to_vec()));
374	}
375	forbidden
376}