// fire_http_api/stream/server.rs
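//! Websocket based stream server.
//!
//! A single websocket connection multiplexes multiple streams: a client
//! opens a stream with a `SenderRequest` or `ReceiverRequest` message,
//! `SenderMessage`/`ReceiverMessage` frames then carry the data and
//! `SenderClose`/`ReceiverClose` tear the stream down again.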

use super::error::UnrecoverableError;
use super::message::{Message, MessageData, MessageKind};
use super::stream::{Stream, StreamKind};
use super::streamer::RawStreamer;

use std::borrow::Cow;
use std::collections::HashMap;
use std::future::poll_fn;
use std::net::SocketAddr;
use std::sync::Arc;
use std::task::Poll;

use tokio::sync::mpsc;
use tokio::time::{interval, Duration};

use tracing::{error, trace};

use fire::header::{Method, RequestHeader};
use fire::routes::{
	HyperRequest, ParamsNames, PathParams, RawRoute, RoutePath,
};
pub use fire::util::PinnedFuture;
use fire::ws::{self, JsonError, WebSocket};
use fire::{resources::Resources, Response};

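/// Identifies one stream on a connection by its action and direction.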
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct Request {
	action: Cow<'static, str>,
	kind: StreamKind,
}

pub trait IntoStreamHandler {
	type Stream: Stream;
	type Handler: StreamHandler;

	fn into_handler(self) -> Self::Handler;
}

pub trait StreamHandler {
	fn validate_requirements(
		&self,
		_params: &ParamsNames,
		_resources: &Resources,
	) {
	}

	/// Every `MessageData` needs to correspond with the `Stream` trait.
	///
	/// ## Warning
	/// You're not allowed to drop the streamer before you return from the
	/// function, else that may lead to a busy loop in the StreamServer
	/// (todo: improve that).
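	///
	/// ## Example
	///
	/// A minimal sketch of a handler that just echoes the request data
	/// back (`EchoHandler` is hypothetical and elides the
	/// `IntoStreamHandler` setup):
	///
	/// ```ignore
	/// struct EchoHandler;
	///
	/// impl StreamHandler for EchoHandler {
	/// 	fn handle<'a>(
	/// 		&'a self,
	/// 		req: MessageData,
	/// 		_header: &'a RequestHeader,
	/// 		_params: &'a PathParams,
	/// 		streamer: RawStreamer,
	/// 		_data: &'a Resources,
	/// 	) -> PinnedFuture<'a, Result<MessageData, UnrecoverableError>> {
	/// 		PinnedFuture::new(async move {
	/// 			// keep the streamer alive until we return
	/// 			let _streamer = streamer;
	/// 			Ok(req)
	/// 		})
	/// 	}
	/// }
	/// ```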
	fn handle<'a>(
		&'a self,
		req: MessageData,
		header: &'a RequestHeader,
		params: &'a PathParams,
		streamer: RawStreamer,
		data: &'a Resources,
	) -> PinnedFuture<'a, Result<MessageData, UnrecoverableError>>;
}

pub struct StreamServer {
	uri: &'static str,
	inner: Arc<HashMap<Request, Box<dyn StreamHandler + Send + Sync>>>,
}

impl StreamServer {
	pub fn new(uri: &'static str) -> Self {
		Self {
			uri,
			inner: Arc::new(HashMap::new()),
		}
	}

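	/// Registers a handler for its stream's `ACTION`/`KIND` pair.
	///
	/// # Panics
	/// Panics if the server was already cloned (for example after being
	/// registered as a route), since the inner map must still be uniquely
	/// owned.
	///
	/// A usage sketch (`MyStream` is a hypothetical `IntoStreamHandler`):
	///
	/// ```ignore
	/// let mut server = StreamServer::new("/stream");
	/// server.insert(MyStream);
	/// ```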
	pub fn insert<H>(&mut self, handler: H)
	where
		H: IntoStreamHandler,
		H::Handler: StreamHandler + Send + Sync + 'static,
	{
		Arc::get_mut(&mut self.inner).unwrap().insert(
			Request {
				action: H::Stream::ACTION.into(),
				kind: H::Stream::KIND,
			},
			Box::new(handler.into_handler()),
		);
	}
}

impl RawRoute for StreamServer {
	fn path(&self) -> RoutePath {
		RoutePath {
			method: Some(Method::GET),
			path: self.uri.into(),
		}
	}

	fn call<'a>(
		&'a self,
		req: &'a mut HyperRequest,
		address: SocketAddr,
		params: &'a PathParams,
		resources: &'a Resources,
	) -> PinnedFuture<'a, Option<fire::Result<Response>>> {
		PinnedFuture::new(async move {
			let (on_upgrade, ws_accept) = match ws::util::upgrade(req) {
				Ok(o) => o,
				Err(e) => return Some(Err(e)),
			};

			let header = fire::ws::util::hyper_req_to_header(req, address);
			let header = match header {
				Ok(h) => h,
				Err(e) => return Some(Err(e)),
			};

			let handlers = self.inner.clone();
			let resources = resources.clone();
			let params = params.clone();

			// we need to spawn a task because on_upgrade can only be
			// fulfilled after we've sent SWITCHING_PROTOCOLS
			tokio::task::spawn(async move {
				match on_upgrade.await {
					Ok(upgraded) => {
						let ws = WebSocket::new(upgraded).await;

						trace!("connection upgraded");

						let res = handle_connection(
							handlers, ws, header, params, resources,
						)
						.await;
						if let Err(e) = res {
							error!("websocket connection failed with {:?}", e);
						}
					}
					Err(e) => ws::util::upgrade_error(e),
				}
			});

			Some(Ok(ws::util::switching_protocols(ws_accept)))
		})
	}
}

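/// Drives a single upgraded websocket connection.
///
/// Dispatches incoming requests to the registered handlers, forwards
/// handler output back to the client, pings every 30 seconds and
/// closes streams once their handler finishes or panics.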
async fn handle_connection(
	handlers: Arc<HashMap<Request, Box<dyn StreamHandler + Send + Sync>>>,
	mut ws: WebSocket,
	header: RequestHeader,
	params: PathParams,
	data: Resources,
) -> Result<(), UnrecoverableError> {
	let mut receivers = Receivers::new();
	let mut senders = Senders::new();
	// data: (Request, MessageData)
	let (close_tx, mut close_rx) = mpsc::channel(10);
	let mut ping_interval = interval(Duration::from_secs(30));

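	// concurrently drive the websocket, the open receiver streams,
	// the ping timer and the handler close notifications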
	loop {
		tokio::select! {
			msg = ws.deserialize() => {
				let msg: Message = match msg {
					Ok(None) => return Ok(()),
					Ok(Some(m)) => m,
					Err(JsonError::ConnectionError(e)) => {
						return Err(e.to_string().into())
					},
					Err(JsonError::SerdeError(e)) => {
						error!("could not deserialize message {:?}", e);
						// just a json error, ignore the message
						continue
					}
				};

				trace!("received message {:?}", msg);

				let req = Request {
					action: msg.action,
					kind: msg.kind.into()
				};

				match msg.kind {
					k @ MessageKind::SenderRequest |
					k @ MessageKind::ReceiverRequest => {
						// no handler
						if !handlers.contains_key(&req) {
							error!("no handler for {:?} found", req);
							ws.serialize(&Message {
								kind: msg.kind.into_close(),
								action: req.action.clone(),
								data: MessageData::null()
							}).await.map_err(|e| e.to_string())?;
							continue
						}

						// we know the handler exists
						let (tx, rx) = mpsc::channel(10);

						let streamer = match req.kind {
							// the client wants to send us data
							StreamKind::Sender => {
								if !senders.insert(req.clone(), tx) {
									// the sender already exists,
									// don't create a new handler
									continue
								}
								RawStreamer::receiver(rx)
							},
							// the client wants to receive data from us
							StreamKind::Receiver => {
								if !receivers.insert(req.clone(), rx) {
									// the handler already exists
									continue
								}
								RawStreamer::sender(tx)
							}
						};

						// let's send a success message
						ws.serialize(&Message {
							kind: k,
							action: req.action.clone(),
							data: MessageData::null()
						}).await.map_err(|e| e.to_string())?;

						let data = data.clone();
						let handlers = handlers.clone();
						let msg_data = msg.data;
						let header = header.clone();
						let close_tx = close_tx.clone();
						let params = params.clone();

						// the first task only catches panics
						// and the second starts the handler.
						// we could also detect a panic when trying to send
						// or receive via an mpsc channel,
						// but that could lead to multiple close messages being
						// sent when the task successfully exits
						tokio::spawn(async move {
							let panic_close_tx = close_tx.clone();
							let panic_req = req.clone();

							let r = tokio::spawn(async move {
								let handler = match handlers.get(&req) {
									Some(h) => h,
									None => unreachable!()
								};

								let r = handler.handle(
									msg_data,
									&header,
									&params,
									streamer,
									&data
								).await;
								match r {
									Ok(m) => {
										let _ = close_tx.send((req, m)).await;
									},
									Err(e) => {
										error!("stream handler unrecoverable \
											error {:?}", e
										);
										let _ = close_tx.send(
											(req, MessageData::null())
										).await;
									}
								}
							}).await;

							if r.is_err() {
								// some error happened so let's send a close req
								let _ = panic_close_tx.send(
									(panic_req, MessageData::null())
								).await;
							}
						});

					},
					MessageKind::SenderMessage => {
						// if a handler is already closed don't do anything,
						// since it is guaranteed to get closed via close_tx.
						// if a handler does not exist this is a protocol
						// error, since a SenderRequest must have succeeded
						// before any SenderMessage may be sent
						senders.send(&req, msg.data).await;
					},
					MessageKind::ReceiverMessage => {
						// we should not receive this message
						// this is a protocol error
					},
					MessageKind::SenderClose => {
						senders.remove(&req);
					},
					MessageKind::ReceiverClose => {
						receivers.remove(&req);
					}
				}
			},
			(req, data) = receivers.recv(), if !receivers.is_empty() => {
				ws.serialize(&Message {
					kind: req.kind.into_kind_message(),
					action: req.action,
					data
				}).await.map_err(|e| e.to_string())?;
			},
			_ping = ping_interval.tick() => {
				ws.ping().await
					.map_err(|e| e.to_string())?;
			},
			msg = close_rx.recv() => {
				// cannot fail since we always have a close_tx
				let (req, data) = msg.unwrap();

				match req.kind {
					StreamKind::Sender => {
						if senders.remove(&req) {
							ws.serialize(&Message {
								kind: MessageKind::SenderClose,
								action: req.action,
								data
							}).await.map_err(|e| e.to_string())?;
						}
					},
					StreamKind::Receiver => {
						if receivers.remove(&req) {
							ws.serialize(&Message {
								kind: MessageKind::ReceiverClose,
								action: req.action,
								data
							}).await.map_err(|e| e.to_string())?;
						}
					}
				}
			}
		}
	}
}

struct Receivers {
	inner: HashMap<Request, mpsc::Receiver<MessageData>>,
	// we use a recv queue to make polling fairer: we poll every receiver,
	// check if it has data available and store everything in the queue.
	// the problem here is that we don't return on the first ready one and
	// always poll every receiver (which is not great)
	//
	// todo: use a crate for this (e.g. `tokio_stream::StreamMap`)
	recv_queue: Vec<(Request, MessageData)>,
}

impl Receivers {
	pub fn new() -> Self {
		Self {
			inner: HashMap::new(),
			recv_queue: vec![],
		}
	}

	pub fn is_empty(&self) -> bool {
		self.inner.is_empty() && self.recv_queue.is_empty()
	}

	/// if no receivers exist this will wait forever,
	/// so don't call recv when `is_empty` returns `true`
	pub async fn recv(&mut self) -> (Request, MessageData) {
		if let Some(msg) = self.recv_queue.pop() {
			return msg;
		}

		debug_assert!(!self.inner.is_empty(), "would wait forever");

		poll_fn(|ctx| {
			for (req, rx) in self.inner.iter_mut() {
				match rx.poll_recv(ctx) {
					Poll::Pending => continue,
					Poll::Ready(Some(data)) => {
						self.recv_queue.push((req.clone(), data))
					}
					// todo: maybe we should remove closed receivers,
					// but since they will probably be removed quickly
					// anyway it should not be a problem
					Poll::Ready(None) => continue,
				}
			}

			match self.recv_queue.pop() {
				Some(m) => Poll::Ready(m),
				None => Poll::Pending,
			}
		})
		.await
	}

	pub fn insert(
		&mut self,
		req: Request,
		recv: mpsc::Receiver<MessageData>,
	) -> bool {
		if self.inner.contains_key(&req) {
			return false;
		}

		self.inner.insert(req, recv).is_none()
	}

	pub fn remove(&mut self, req: &Request) -> bool {
		self.inner.remove(req).is_some()
	}
}

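/// The sending halves of all currently open client-to-server streams.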
struct Senders {
	inner: HashMap<Request, mpsc::Sender<MessageData>>,
}

impl Senders {
	pub fn new() -> Self {
		Self {
			inner: HashMap::new(),
		}
	}

	pub fn insert(
		&mut self,
		req: Request,
		sender: mpsc::Sender<MessageData>,
	) -> bool {
		if self.inner.contains_key(&req) {
			return false;
		}

		self.inner.insert(req, sender).is_none()
	}

	pub async fn send(&mut self, req: &Request, data: MessageData) {
		if let Some(sender) = self.inner.get(req) {
			// todo: should we send an error here?
			let _ = sender.send(data).await;
		}
	}

	pub fn remove(&mut self, req: &Request) -> bool {
		self.inner.remove(req).is_some()
	}
}