1use super::error::UnrecoverableError;
2use super::message::{Message, MessageData, MessageKind};
3use super::stream::{Stream, StreamKind};
4use super::streamer::RawStreamer;
5
6use std::borrow::Cow;
7use std::collections::HashMap;
8use std::future::poll_fn;
9use std::net::SocketAddr;
10use std::sync::Arc;
11use std::task::Poll;
12
13use tokio::sync::mpsc;
14use tokio::time::{interval, Duration};
15
16use tracing::{error, trace};
17
18use fire::header::{Method, RequestHeader};
19use fire::routes::{
20 HyperRequest, ParamsNames, PathParams, RawRoute, RoutePath,
21};
22pub use fire::util::PinnedFuture;
23use fire::ws::{self, JsonError, WebSocket};
24use fire::{resources::Resources, Response};
25
26#[derive(Debug, Clone, PartialEq, Eq, Hash)]
27struct Request {
28 action: Cow<'static, str>,
29 kind: StreamKind,
30}
31
32pub trait IntoStreamHandler {
33 type Stream: Stream;
34 type Handler: StreamHandler;
35
36 fn into_handler(self) -> Self::Handler;
37}
38
39pub trait StreamHandler {
40 fn validate_requirements(
41 &self,
42 _params: &ParamsNames,
43 _resources: &Resources,
44 ) {
45 }
46
47 fn handle<'a>(
53 &'a self,
54 req: MessageData,
55 header: &'a RequestHeader,
56 params: &'a PathParams,
57 streamer: RawStreamer,
58 data: &'a Resources,
59 ) -> PinnedFuture<'a, Result<MessageData, UnrecoverableError>>;
60}
61
62pub struct StreamServer {
63 uri: &'static str,
64 inner: Arc<HashMap<Request, Box<dyn StreamHandler + Send + Sync>>>,
65}
66
67impl StreamServer {
68 pub fn new(uri: &'static str) -> Self {
69 Self {
70 uri,
71 inner: Arc::new(HashMap::new()),
72 }
73 }
74
75 pub fn insert<H>(&mut self, handler: H)
76 where
77 H: IntoStreamHandler,
78 H::Handler: StreamHandler + Send + Sync + 'static,
79 {
80 Arc::get_mut(&mut self.inner).unwrap().insert(
81 Request {
82 action: H::Stream::ACTION.into(),
83 kind: H::Stream::KIND,
84 },
85 Box::new(handler.into_handler()),
86 );
87 }
88}
89
90impl RawRoute for StreamServer {
91 fn path(&self) -> RoutePath {
92 RoutePath {
93 method: Some(Method::GET),
94 path: self.uri.into(),
95 }
96 }
97
98 fn call<'a>(
99 &'a self,
100 req: &'a mut HyperRequest,
101 address: SocketAddr,
102 params: &'a PathParams,
103 resources: &'a Resources,
104 ) -> PinnedFuture<'a, Option<fire::Result<Response>>> {
105 PinnedFuture::new(async move {
106 let (on_upgrade, ws_accept) = match ws::util::upgrade(req) {
107 Ok(o) => o,
108 Err(e) => return Some(Err(e)),
109 };
110
111 let header = fire::ws::util::hyper_req_to_header(req, address);
112 let header = match header {
113 Ok(h) => h,
114 Err(e) => return Some(Err(e)),
115 };
116
117 let handlers = self.inner.clone();
118 let resources = resources.clone();
119 let params = params.clone();
120
121 tokio::task::spawn(async move {
125 match on_upgrade.await {
126 Ok(upgraded) => {
127 let ws = WebSocket::new(upgraded).await;
128
129 trace!("connection upgraded");
130
131 let res = handle_connection(
132 handlers, ws, header, params, resources,
133 )
134 .await;
135 if let Err(e) = res {
136 error!("websocket connection failed with {:?}", e);
137 }
138 }
139 Err(e) => ws::util::upgrade_error(e),
140 }
141 });
142
143 Some(Ok(ws::util::switching_protocols(ws_accept)))
144 })
145 }
146}
147
148async fn handle_connection(
149 handlers: Arc<HashMap<Request, Box<dyn StreamHandler + Send + Sync>>>,
150 mut ws: WebSocket,
151 header: RequestHeader,
152 params: PathParams,
153 data: Resources,
154) -> Result<(), UnrecoverableError> {
155 let mut receivers = Receivers::new();
156 let mut senders = Senders::new();
157 let (close_tx, mut close_rx) = mpsc::channel(10);
159 let mut ping_interval = interval(Duration::from_secs(30));
160
161 loop {
162 tokio::select! {
163 msg = ws.deserialize() => {
164 let msg: Message = match msg {
165 Ok(None) => return Ok(()),
166 Ok(Some(m)) => m,
167 Err(JsonError::ConnectionError(e)) => {
168 return Err(e.to_string().into())
169 },
170 Err(JsonError::SerdeError(e)) => {
171 error!("could not deserialize message {:?}", e);
172 continue
174 }
175 };
176
177 trace!("received message {:?}", msg);
178
179 let req = Request {
180 action: msg.action,
181 kind: msg.kind.into()
182 };
183
184 match msg.kind {
185 k @ MessageKind::SenderRequest |
186 k @ MessageKind::ReceiverRequest => {
187 if !handlers.contains_key(&req) {
189 error!("no handler for {:?} found", req);
190 ws.serialize(&Message {
191 kind: msg.kind.into_close(),
192 action: req.action.clone(),
193 data: MessageData::null()
194 }).await.map_err(|e| e.to_string())?;
195 continue
196 }
197
198 let (tx, rx) = mpsc::channel(10);
200
201 let streamer = match req.kind {
202 StreamKind::Sender => {
204 if !senders.insert(req.clone(), tx) {
205 continue
208 }
209 RawStreamer::receiver(rx)
210 },
211 StreamKind::Receiver => {
213 if !receivers.insert(req.clone(), rx) {
214 continue
216 }
217 RawStreamer::sender(tx)
218 }
219 };
220
221 ws.serialize(&Message {
223 kind: k,
224 action: req.action.clone(),
225 data: MessageData::null()
226 }).await.map_err(|e| e.to_string())?;
227
228 let data = data.clone();
229 let handlers = handlers.clone();
230 let msg_data = msg.data;
231 let header = header.clone();
232 let close_tx = close_tx.clone();
233 let params = params.clone();
234
235 tokio::spawn(async move {
242 let panic_close_tx = close_tx.clone();
243 let panic_req = req.clone();
244
245 let r = tokio::spawn(async move {
246
247 let handler = match handlers.get(&req) {
248 Some(h) => h,
249 None => unreachable!()
250 };
251
252 let r = handler.handle(
253 msg_data,
254 &header,
255 ¶ms,
256 streamer,
257 &data
258 ).await;
259 match r {
260 Ok(m) => {
261 let _ = close_tx.send((req, m)).await;
262 },
263 Err(e) => {
264 error!("stream handler unrecoverable \
265 error {:?}", e
266 );
267 let _ = close_tx.send(
268 (req, MessageData::null())
269 ).await;
270 }
271 }
272 }).await;
273
274 if r.is_err() {
275 let _ = panic_close_tx.send(
277 (panic_req, MessageData::null())
278 ).await;
279 }
280 });
281
282 },
283 MessageKind::SenderMessage => {
284 let _ = senders.send(&req, msg.data).await;
290 },
291 MessageKind::ReceiverMessage => {
292 },
295 MessageKind::SenderClose => {
296 senders.remove(&req);
297 },
298 MessageKind::ReceiverClose => {
299 receivers.remove(&req);
300 }
301 }
302 },
303 (req, data) = receivers.recv(), if !receivers.is_empty() => {
304 ws.serialize(&Message {
305 kind: req.kind.into_kind_message(),
306 action: req.action,
307 data: data
308 }).await.map_err(|e| e.to_string())?;
309 },
310 _ping = ping_interval.tick() => {
311 ws.ping().await
312 .map_err(|e| e.to_string())?;
313 },
314 msg = close_rx.recv() => {
315 let (req, data) = msg.unwrap();
317
318 match req.kind {
319 StreamKind::Sender => {
320 if senders.remove(&req) {
321 ws.serialize(&Message {
322 kind: MessageKind::SenderClose,
323 action: req.action,
324 data: data
325 }).await.map_err(|e| e.to_string())?;
326 }
327 },
328 StreamKind::Receiver => {
329 if receivers.remove(&req) {
330 ws.serialize(&Message {
331 kind: MessageKind::ReceiverClose,
332 action: req.action,
333 data: data
334 }).await.map_err(|e| e.to_string())?;
335 }
336 }
337 }
338 }
339 }
340 }
341}
342
343struct Receivers {
344 inner: HashMap<Request, mpsc::Receiver<MessageData>>,
345 recv_queue: Vec<(Request, MessageData)>,
353}
354
355impl Receivers {
356 pub fn new() -> Self {
357 Self {
358 inner: HashMap::new(),
359 recv_queue: vec![],
360 }
361 }
362
363 pub fn is_empty(&self) -> bool {
364 self.inner.is_empty() && self.recv_queue.is_empty()
365 }
366
367 pub async fn recv(&mut self) -> (Request, MessageData) {
370 if let Some(msg) = self.recv_queue.pop() {
371 return msg;
372 }
373
374 debug_assert!(!self.inner.is_empty(), "will wait for ever");
375
376 poll_fn(|ctx| {
377 for (req, rx) in self.inner.iter_mut() {
378 match rx.poll_recv(ctx) {
379 Poll::Pending => continue,
380 Poll::Ready(Some(data)) => {
381 self.recv_queue.push((req.clone(), data))
382 }
383 Poll::Ready(None) => continue,
387 }
388 }
389
390 match self.recv_queue.pop() {
391 Some(m) => Poll::Ready(m),
392 None => Poll::Pending,
393 }
394 })
395 .await
396 }
397
398 pub fn insert(
399 &mut self,
400 req: Request,
401 recv: mpsc::Receiver<MessageData>,
402 ) -> bool {
403 if self.inner.contains_key(&req) {
404 return false;
405 }
406
407 self.inner.insert(req, recv).is_none()
408 }
409
410 pub fn remove(&mut self, req: &Request) -> bool {
411 self.inner.remove(req).is_some()
412 }
413}
414
415struct Senders {
416 inner: HashMap<Request, mpsc::Sender<MessageData>>,
417}
418
419impl Senders {
420 pub fn new() -> Self {
421 Self {
422 inner: HashMap::new(),
423 }
424 }
425
426 pub fn insert(
427 &mut self,
428 req: Request,
429 sender: mpsc::Sender<MessageData>,
430 ) -> bool {
431 if self.inner.contains_key(&req) {
432 return false;
433 }
434
435 self.inner.insert(req, sender).is_none()
436 }
437
438 pub async fn send(&mut self, req: &Request, data: MessageData) {
439 if let Some(sender) = self.inner.get(req) {
440 let _ = sender.send(data).await;
442 }
443 }
444
445 pub fn remove(&mut self, req: &Request) -> bool {
446 self.inner.remove(req).is_some()
447 }
448}