bitconch_jsonrpc_ws_server/
session.rs1use std;
2use std::sync::{atomic, Arc};
3
4use core;
5use core::futures::{Async, Future, Poll};
6use core::futures::sync::oneshot;
7
8use parking_lot::Mutex;
9use slab::Slab;
10
11use server_utils::Pattern;
12use server_utils::cors::Origin;
13use server_utils::hosts::Host;
14use server_utils::session::{SessionId, SessionStats};
15use server_utils::tokio::runtime::TaskExecutor;
16use ws;
17
18use error;
19use metadata;
20
21pub trait RequestMiddleware: Send + Sync + 'static {
25 fn process(&self, req: &ws::Request) -> MiddlewareAction;
27}
28
29impl<F> RequestMiddleware for F where
30 F: Fn(&ws::Request) -> Option<ws::Response> + Send + Sync + 'static,
31{
32 fn process(&self, req: &ws::Request) -> MiddlewareAction {
33 (*self)(req).into()
34 }
35}
36
37#[derive(Debug)]
39pub enum MiddlewareAction {
40 Proceed,
42 Respond {
44 response: ws::Response,
46 validate_origin: bool,
48 validate_hosts: bool,
50 },
51}
52
53impl MiddlewareAction {
54 fn should_verify_origin(&self) -> bool {
55 use self::MiddlewareAction::*;
56
57 match *self {
58 Proceed => true,
59 Respond { validate_origin, .. } => validate_origin,
60 }
61 }
62
63 fn should_verify_hosts(&self) -> bool {
64 use self::MiddlewareAction::*;
65
66 match *self {
67 Proceed => true,
68 Respond { validate_hosts, .. } => validate_hosts,
69 }
70 }
71}
72
73impl From<Option<ws::Response>> for MiddlewareAction {
74 fn from(opt: Option<ws::Response>) -> Self {
75 match opt {
76 Some(res) => MiddlewareAction::Respond { response: res, validate_origin: true, validate_hosts: true },
77 None => MiddlewareAction::Proceed,
78 }
79 }
80}
81
82type TaskSlab = Mutex<Slab<Option<oneshot::Sender<()>>>>;
84
85#[derive(Debug)]
88struct LivenessPoll {
89 task_slab: Arc<TaskSlab>,
90 slab_handle: usize,
91 rx: oneshot::Receiver<()>,
92}
93
94impl LivenessPoll {
95 fn create(task_slab: Arc<TaskSlab>) -> Self {
96 const INITIAL_SIZE: usize = 4;
97
98 let (index, rx) = {
99 let mut task_slab = task_slab.lock();
100 if task_slab.len() == task_slab.capacity() {
101 let reserve = ::std::cmp::max(task_slab.capacity(), INITIAL_SIZE);
104 task_slab.reserve_exact(reserve);
105 }
106
107 let (tx, rx) = oneshot::channel();
108 let index = task_slab.insert(Some(tx));
109 (index, rx)
110 };
111
112 LivenessPoll { task_slab: task_slab, slab_handle: index, rx: rx }
113 }
114}
115
116impl Future for LivenessPoll {
117 type Item = ();
118 type Error = ();
119
120 fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
121 match self.rx.poll() {
125 Ok(Async::Ready(_)) | Err(_) => Ok(Async::Ready(())),
126 Ok(Async::NotReady) => Ok(Async::NotReady),
127 }
128 }
129}
130
131impl Drop for LivenessPoll {
132 fn drop(&mut self) {
133 self.task_slab.lock().remove(self.slab_handle);
135 }
136}
137
138pub struct Session<M: core::Metadata, S: core::Middleware<M>> {
139 active: Arc<atomic::AtomicBool>,
140 context: metadata::RequestContext,
141 handler: Arc<core::MetaIoHandler<M, S>>,
142 meta_extractor: Arc<metadata::MetaExtractor<M>>,
143 allowed_origins: Option<Vec<Origin>>,
144 allowed_hosts: Option<Vec<Host>>,
145 request_middleware: Option<Arc<RequestMiddleware>>,
146 stats: Option<Arc<SessionStats>>,
147 metadata: Option<M>,
148 executor: TaskExecutor,
149 task_slab: Arc<TaskSlab>,
150}
151
152impl<M: core::Metadata, S: core::Middleware<M>> Drop for Session<M, S> {
153 fn drop(&mut self) {
154 self.active.store(false, atomic::Ordering::SeqCst);
155 self.stats.as_ref().map(|stats| stats.close_session(self.context.session_id));
156
157 for (_index, task) in self.task_slab.lock().iter_mut() {
159 if let Some(task) = task.take() {
160 let _ = task.send(());
161 }
162 }
163 }
164}
165
166impl<M: core::Metadata, S: core::Middleware<M>> Session<M, S> {
167 fn read_origin<'a>(&self, req: &'a ws::Request) -> Option<&'a [u8]> {
168 req.header("origin").map(|x| &x[..])
169 }
170
171 fn verify_origin(&self, origin: Option<&[u8]>) -> Option<ws::Response> {
172 if !header_is_allowed(&self.allowed_origins, origin) {
173 warn!(
174 "Blocked connection to WebSockets server from untrusted origin: {:?}",
175 origin.and_then(|s| std::str::from_utf8(s).ok()),
176 );
177 Some(forbidden("URL Blocked", "Connection Origin has been rejected."))
178 } else {
179 None
180 }
181 }
182
183 fn verify_host(&self, req: &ws::Request) -> Option<ws::Response> {
184 let host = req.header("host").map(|x| &x[..]);
185 if !header_is_allowed(&self.allowed_hosts, host) {
186 warn!(
187 "Blocked connection to WebSockets server with untrusted host: {:?}",
188 host.and_then(|s| std::str::from_utf8(s).ok()),
189 );
190 Some(forbidden("URL Blocked", "Connection Host has been rejected."))
191 } else {
192 None
193 }
194 }
195}
196
197impl<M: core::Metadata, S: core::Middleware<M>> ws::Handler for Session<M, S> {
198 fn on_request(&mut self, req: &ws::Request) -> ws::Result<ws::Response> {
199 let action = if let Some(ref middleware) = self.request_middleware {
201 middleware.process(req)
202 } else {
203 MiddlewareAction::Proceed
204 };
205
206 let origin = self.read_origin(req);
207 if action.should_verify_origin() {
208 if let Some(response) = self.verify_origin(origin) {
210 return Ok(response);
211 }
212 }
213
214 if action.should_verify_hosts() {
215 if let Some(response) = self.verify_host(req) {
217 return Ok(response);
218 }
219 }
220
221 self.context.origin = origin.and_then(|origin| ::std::str::from_utf8(origin).ok()).map(Into::into);
222 self.context.protocols = req.protocols().ok()
223 .map(|protos| protos.into_iter().map(Into::into).collect())
224 .unwrap_or_else(Vec::new);
225 self.metadata = Some(self.meta_extractor.extract(&self.context));
226
227 match action {
228 MiddlewareAction::Proceed => ws::Response::from_request(req).map(|mut res| {
229 if let Some(protocol) = self.context.protocols.get(0) {
230 res.set_protocol(protocol);
231 }
232 res
233 }),
234 MiddlewareAction::Respond { response, .. } => Ok(response),
235 }
236 }
237
238 fn on_message(&mut self, msg: ws::Message) -> ws::Result<()> {
239 let req = msg.as_text()?;
240 let out = self.context.out.clone();
241 let metadata = self.metadata.clone().expect("Metadata is always set in on_request; qed");
242
243 let poll_liveness = LivenessPoll::create(self.task_slab.clone());
247
248 let active_lock = self.active.clone();
249 let future = self.handler.handle_request(req, metadata)
250 .map(move |response| {
251 if !active_lock.load(atomic::Ordering::SeqCst) {
252 return;
253 }
254 if let Some(result) = response {
255 let res = out.send(result);
256 match res {
257 Err(error::Error(error::ErrorKind::ConnectionClosed, _)) => {
258 active_lock.store(false, atomic::Ordering::SeqCst);
259 },
260 Err(e) => {
261 warn!("Error while sending response: {:?}", e);
262 },
263 _ => {},
264 }
265 }
266 })
267 .select(poll_liveness)
268 .map(|_| ())
269 .map_err(|_| ());
270
271 self.executor.spawn(future);
272
273 Ok(())
274 }
275}
276
277pub struct Factory<M: core::Metadata, S: core::Middleware<M>> {
278 session_id: SessionId,
279 handler: Arc<core::MetaIoHandler<M, S>>,
280 meta_extractor: Arc<metadata::MetaExtractor<M>>,
281 allowed_origins: Option<Vec<Origin>>,
282 allowed_hosts: Option<Vec<Host>>,
283 request_middleware: Option<Arc<RequestMiddleware>>,
284 stats: Option<Arc<SessionStats>>,
285 executor: TaskExecutor,
286}
287
288impl<M: core::Metadata, S: core::Middleware<M>> Factory<M, S> {
289 pub fn new(
290 handler: Arc<core::MetaIoHandler<M, S>>,
291 meta_extractor: Arc<metadata::MetaExtractor<M>>,
292 allowed_origins: Option<Vec<Origin>>,
293 allowed_hosts: Option<Vec<Host>>,
294 request_middleware: Option<Arc<RequestMiddleware>>,
295 stats: Option<Arc<SessionStats>>,
296 executor: TaskExecutor,
297 ) -> Self {
298 Factory {
299 session_id: 0,
300 handler: handler,
301 meta_extractor: meta_extractor,
302 allowed_origins: allowed_origins,
303 allowed_hosts: allowed_hosts,
304 request_middleware: request_middleware,
305 stats: stats,
306 executor: executor,
307 }
308 }
309}
310
311impl<M: core::Metadata, S: core::Middleware<M>> ws::Factory for Factory<M, S> {
312 type Handler = Session<M, S>;
313
314 fn connection_made(&mut self, sender: ws::Sender) -> Self::Handler {
315 self.session_id += 1;
316 self.stats.as_ref().map(|stats| stats.open_session(self.session_id));
317 let active = Arc::new(atomic::AtomicBool::new(true));
318
319 Session {
320 active: active.clone(),
321 context: metadata::RequestContext {
322 session_id: self.session_id,
323 origin: None,
324 protocols: Vec::new(),
325 out: metadata::Sender::new(sender, active),
326 executor: self.executor.clone(),
327 },
328 handler: self.handler.clone(),
329 meta_extractor: self.meta_extractor.clone(),
330 allowed_origins: self.allowed_origins.clone(),
331 allowed_hosts: self.allowed_hosts.clone(),
332 stats: self.stats.clone(),
333 request_middleware: self.request_middleware.clone(),
334 metadata: None,
335 executor: self.executor.clone(),
336 task_slab: Arc::new(Mutex::new(Slab::with_capacity(0))),
337 }
338 }
339}
340
341fn header_is_allowed<T>(allowed: &Option<Vec<T>>, header: Option<&[u8]>) -> bool where
342 T: Pattern,
343{
344 let header = header.map(std::str::from_utf8);
345
346 match (header, allowed.as_ref()) {
347 (None, _) => true,
349 (_, None) => true,
351 (Some(Ok(val)), Some(values)) => {
353 for v in values {
354 if v.matches(val) {
355 return true
356 }
357 }
358 false
359 },
360 _ => false,
362 }
363}
364
365
366fn forbidden(title: &str, message: &str) -> ws::Response {
367 let mut forbidden = ws::Response::new(403, "Forbidden");
368 forbidden.set_body(
369 format!("{}\n{}\n", title, message).as_bytes()
370 );
371 {
372 let headers = forbidden.headers_mut();
373 headers.push(("Connection".to_owned(), "close".as_bytes().to_vec()));
374 }
375 forbidden
376}