1use std::future::Future;
2use std::pin::Pin;
3use std::sync::{atomic, Arc};
4use std::task::{Context, Poll};
5
6use crate::core;
7use futures::channel::oneshot;
8use futures::future;
9use futures::FutureExt;
10
11use parking_lot::Mutex;
12use slab::Slab;
13
14use crate::server_utils::cors::Origin;
15use crate::server_utils::hosts::Host;
16use crate::server_utils::reactor::TaskExecutor;
17use crate::server_utils::session::{SessionId, SessionStats};
18use crate::server_utils::Pattern;
19use crate::ws;
20
21use crate::error;
22use crate::metadata;
23
/// Hook that inspects an incoming WebSocket handshake request before the server handles it.
///
/// Any `Fn(&ws::Request) -> Option<ws::Response>` closure implements this trait automatically.
pub trait RequestMiddleware: Send + Sync + 'static {
    /// Decides how the handshake for `req` should continue; see [`MiddlewareAction`].
    fn process(&self, req: &ws::Request) -> MiddlewareAction;
}
31
32impl<F> RequestMiddleware for F
33where
34 F: Fn(&ws::Request) -> Option<ws::Response> + Send + Sync + 'static,
35{
36 fn process(&self, req: &ws::Request) -> MiddlewareAction {
37 (*self)(req).into()
38 }
39}
40
/// Decision returned by a [`RequestMiddleware`] for an incoming handshake request.
#[derive(Debug)]
pub enum MiddlewareAction {
    /// Continue with the regular handshake: origin and host checks run, and the response is
    /// built from the request itself.
    Proceed,
    /// Answer the handshake with a custom response instead of the default one.
    Respond {
        /// Response sent back to the client, provided the enabled checks below pass.
        response: ws::Response,
        /// Whether the `Origin` header is still checked against the allow-list first.
        validate_origin: bool,
        /// Whether the `Host` header is still checked against the allow-list first.
        validate_hosts: bool,
    },
}
56
57impl MiddlewareAction {
58 fn should_verify_origin(&self) -> bool {
59 use self::MiddlewareAction::*;
60
61 match *self {
62 Proceed => true,
63 Respond { validate_origin, .. } => validate_origin,
64 }
65 }
66
67 fn should_verify_hosts(&self) -> bool {
68 use self::MiddlewareAction::*;
69
70 match *self {
71 Proceed => true,
72 Respond { validate_hosts, .. } => validate_hosts,
73 }
74 }
75}
76
77impl From<Option<ws::Response>> for MiddlewareAction {
78 fn from(opt: Option<ws::Response>) -> Self {
79 match opt {
80 Some(res) => MiddlewareAction::Respond {
81 response: res,
82 validate_origin: true,
83 validate_hosts: true,
84 },
85 None => MiddlewareAction::Proceed,
86 }
87 }
88}
89
/// Per-session registry of liveness senders, one slot per registered `LivenessPoll`.
///
/// A slot holds `Some(sender)` until `Session::drop` takes it and fires it. The slot itself is
/// removed when the owning `LivenessPoll` is dropped.
type TaskSlab = Mutex<Slab<Option<oneshot::Sender<()>>>>;
92
/// Future that resolves once the sender stored in its `task_slab` slot fires or is dropped.
///
/// It is raced against each request's response future so that work is abandoned when the
/// owning session goes away.
#[derive(Debug)]
struct LivenessPoll {
    /// Shared registry holding the sending half for this poll.
    task_slab: Arc<TaskSlab>,
    /// Key of this poll's slot in `task_slab`; the slot is removed on drop.
    slab_handle: usize,
    /// Receiving half of the oneshot whose sender lives in the slot.
    rx: oneshot::Receiver<()>,
}
101
102impl LivenessPoll {
103 fn create(task_slab: Arc<TaskSlab>) -> Self {
104 const INITIAL_SIZE: usize = 4;
105
106 let (index, rx) = {
107 let mut task_slab = task_slab.lock();
108 if task_slab.len() == task_slab.capacity() {
109 let reserve = ::std::cmp::max(task_slab.capacity(), INITIAL_SIZE);
112 task_slab.reserve_exact(reserve);
113 }
114
115 let (tx, rx) = oneshot::channel();
116 let index = task_slab.insert(Some(tx));
117 (index, rx)
118 };
119
120 LivenessPoll {
121 task_slab,
122 slab_handle: index,
123 rx,
124 }
125 }
126}
127
128impl Future for LivenessPoll {
129 type Output = ();
130
131 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
132 let this = Pin::into_inner(self);
133 match Pin::new(&mut this.rx).poll(cx) {
137 Poll::Ready(_) => Poll::Ready(()),
138 Poll::Pending => Poll::Pending,
139 }
140 }
141}
142
143impl Drop for LivenessPoll {
144 fn drop(&mut self) {
145 self.task_slab.lock().remove(self.slab_handle);
147 }
148}
149
/// Per-connection WebSocket handler created by [`Factory`] for every accepted connection.
pub struct Session<M: core::Metadata, S: core::Middleware<M>> {
    /// Liveness flag shared with the outgoing `metadata::Sender`.
    /// Cleared on drop, or when a send reports `ConnectionClosed`.
    active: Arc<atomic::AtomicBool>,
    /// Connection details (session id, origin, protocols, sender); filled in during `on_request`.
    context: metadata::RequestContext,
    /// JSON-RPC handler that incoming messages are dispatched to.
    handler: Arc<core::MetaIoHandler<M, S>>,
    /// Builds per-session metadata from `context` during the handshake.
    meta_extractor: Arc<dyn metadata::MetaExtractor<M>>,
    /// Origin allow-list; `None` accepts any origin.
    allowed_origins: Option<Vec<Origin>>,
    /// Host allow-list; `None` accepts any host.
    allowed_hosts: Option<Vec<Host>>,
    /// Optional hook consulted before the handshake checks run.
    request_middleware: Option<Arc<dyn RequestMiddleware>>,
    /// Optional sink notified when this session closes.
    stats: Option<Arc<dyn SessionStats>>,
    /// Metadata extracted in `on_request`; `None` until the handshake has been processed.
    metadata: Option<M>,
    /// Executor that per-request response futures are spawned on.
    executor: TaskExecutor,
    /// Liveness senders for requests spawned by this session; fired on drop.
    task_slab: Arc<TaskSlab>,
}
163
164impl<M: core::Metadata, S: core::Middleware<M>> Drop for Session<M, S> {
165 fn drop(&mut self) {
166 self.active.store(false, atomic::Ordering::SeqCst);
167 if let Some(stats) = self.stats.as_ref() {
168 stats.close_session(self.context.session_id)
169 }
170
171 for (_index, task) in self.task_slab.lock().iter_mut() {
173 if let Some(task) = task.take() {
174 let _ = task.send(());
175 }
176 }
177 }
178}
179
180impl<M: core::Metadata, S: core::Middleware<M>> Session<M, S> {
181 fn read_origin<'a>(&self, req: &'a ws::Request) -> Option<&'a [u8]> {
182 req.header("origin").map(|x| &x[..])
183 }
184
185 fn verify_origin(&self, origin: Option<&[u8]>) -> Option<ws::Response> {
186 if !header_is_allowed(&self.allowed_origins, origin) {
187 warn!(
188 "Blocked connection to WebSockets server from untrusted origin: {:?}",
189 origin.and_then(|s| std::str::from_utf8(s).ok()),
190 );
191 Some(forbidden("URL Blocked", "Connection Origin has been rejected."))
192 } else {
193 None
194 }
195 }
196
197 fn verify_host(&self, req: &ws::Request) -> Option<ws::Response> {
198 let host = req.header("host").map(|x| &x[..]);
199 if !header_is_allowed(&self.allowed_hosts, host) {
200 warn!(
201 "Blocked connection to WebSockets server with untrusted host: {:?}",
202 host.and_then(|s| std::str::from_utf8(s).ok()),
203 );
204 Some(forbidden("URL Blocked", "Connection Host has been rejected."))
205 } else {
206 None
207 }
208 }
209}
210
211impl<M: core::Metadata, S: core::Middleware<M>> ws::Handler for Session<M, S>
212where
213 S::Future: Unpin,
214 S::CallFuture: Unpin,
215{
216 fn on_request(&mut self, req: &ws::Request) -> ws::Result<ws::Response> {
217 let action = if let Some(ref middleware) = self.request_middleware {
219 middleware.process(req)
220 } else {
221 MiddlewareAction::Proceed
222 };
223
224 let origin = self.read_origin(req);
225 if action.should_verify_origin() {
226 if let Some(response) = self.verify_origin(origin) {
228 return Ok(response);
229 }
230 }
231
232 if action.should_verify_hosts() {
233 if let Some(response) = self.verify_host(req) {
235 return Ok(response);
236 }
237 }
238
239 self.context.origin = origin
240 .and_then(|origin| ::std::str::from_utf8(origin).ok())
241 .map(Into::into);
242 self.context.protocols = req
243 .protocols()
244 .ok()
245 .map(|protos| protos.into_iter().map(Into::into).collect())
246 .unwrap_or_else(Vec::new);
247 self.metadata = Some(self.meta_extractor.extract(&self.context));
248
249 match action {
250 MiddlewareAction::Proceed => ws::Response::from_request(req).map(|mut res| {
251 if let Some(protocol) = self.context.protocols.get(0) {
252 res.set_protocol(protocol);
253 }
254 res
255 }),
256 MiddlewareAction::Respond { response, .. } => Ok(response),
257 }
258 }
259
260 fn on_message(&mut self, msg: ws::Message) -> ws::Result<()> {
261 let req = msg.as_text()?;
262 let out = self.context.out.clone();
263 let metadata = self
264 .metadata
265 .clone()
266 .expect("Metadata is always set in on_request; qed");
267
268 let poll_liveness = LivenessPoll::create(self.task_slab.clone());
272
273 let active_lock = self.active.clone();
274 let response = self.handler.handle_request(req, metadata);
275
276 let future = response.map(move |response| {
277 if !active_lock.load(atomic::Ordering::SeqCst) {
278 return;
279 }
280 if let Some(result) = response {
281 let res = out.send(result);
282 match res {
283 Err(error::Error::ConnectionClosed) => {
284 active_lock.store(false, atomic::Ordering::SeqCst);
285 }
286 Err(e) => {
287 warn!("Error while sending response: {:?}", e);
288 }
289 _ => {}
290 }
291 }
292 });
293
294 let future = future::select(future, poll_liveness);
295 self.executor.spawn(future);
296
297 Ok(())
298 }
299}
300
/// Builds a [`Session`] for every accepted WebSocket connection.
///
/// Each session shares the configuration held here.
pub struct Factory<M: core::Metadata, S: core::Middleware<M>> {
    /// Id assigned to the most recently created session; incremented before each use.
    session_id: SessionId,
    /// JSON-RPC handler shared by all sessions.
    handler: Arc<core::MetaIoHandler<M, S>>,
    /// Metadata extractor shared by all sessions.
    meta_extractor: Arc<dyn metadata::MetaExtractor<M>>,
    /// Origin allow-list handed to each session; `None` accepts any origin.
    allowed_origins: Option<Vec<Origin>>,
    /// Host allow-list handed to each session; `None` accepts any host.
    allowed_hosts: Option<Vec<Host>>,
    /// Optional handshake middleware handed to each session.
    request_middleware: Option<Arc<dyn RequestMiddleware>>,
    /// Optional sink notified when sessions open and close.
    stats: Option<Arc<dyn SessionStats>>,
    /// Executor cloned into each session (and its request context).
    executor: TaskExecutor,
}
311
312impl<M: core::Metadata, S: core::Middleware<M>> Factory<M, S> {
313 pub fn new(
314 handler: Arc<core::MetaIoHandler<M, S>>,
315 meta_extractor: Arc<dyn metadata::MetaExtractor<M>>,
316 allowed_origins: Option<Vec<Origin>>,
317 allowed_hosts: Option<Vec<Host>>,
318 request_middleware: Option<Arc<dyn RequestMiddleware>>,
319 stats: Option<Arc<dyn SessionStats>>,
320 executor: TaskExecutor,
321 ) -> Self {
322 Factory {
323 session_id: 0,
324 handler,
325 meta_extractor,
326 allowed_origins,
327 allowed_hosts,
328 request_middleware,
329 stats,
330 executor,
331 }
332 }
333}
334
335impl<M: core::Metadata, S: core::Middleware<M>> ws::Factory for Factory<M, S>
336where
337 S::Future: Unpin,
338 S::CallFuture: Unpin,
339{
340 type Handler = Session<M, S>;
341
342 fn connection_made(&mut self, sender: ws::Sender) -> Self::Handler {
343 self.session_id += 1;
344 if let Some(executor) = self.stats.as_ref() {
345 executor.open_session(self.session_id)
346 }
347 let active = Arc::new(atomic::AtomicBool::new(true));
348
349 Session {
350 active: active.clone(),
351 context: metadata::RequestContext {
352 session_id: self.session_id,
353 origin: None,
354 protocols: Vec::new(),
355 out: metadata::Sender::new(sender, active),
356 executor: self.executor.clone(),
357 },
358 handler: self.handler.clone(),
359 meta_extractor: self.meta_extractor.clone(),
360 allowed_origins: self.allowed_origins.clone(),
361 allowed_hosts: self.allowed_hosts.clone(),
362 stats: self.stats.clone(),
363 request_middleware: self.request_middleware.clone(),
364 metadata: None,
365 executor: self.executor.clone(),
366 task_slab: Arc::new(Mutex::new(Slab::with_capacity(0))),
367 }
368 }
369}
370
371fn header_is_allowed<T>(allowed: &Option<Vec<T>>, header: Option<&[u8]>) -> bool
372where
373 T: Pattern,
374{
375 let header = header.map(std::str::from_utf8);
376
377 match (header, allowed.as_ref()) {
378 (None, _) => true,
380 (_, None) => true,
382 (Some(Ok(val)), Some(values)) => {
384 for v in values {
385 if v.matches(val) {
386 return true;
387 }
388 }
389 false
390 }
391 _ => false,
393 }
394}
395
396fn forbidden(title: &str, message: &str) -> ws::Response {
397 let mut forbidden = ws::Response::new(403, "Forbidden", format!("{}\n{}\n", title, message).into_bytes());
398 {
399 let headers = forbidden.headers_mut();
400 headers.push(("Connection".to_owned(), b"close".to_vec()));
401 }
402 forbidden
403}