fire_stream/handler/
server.rs

1use super::{SendBack, StreamSender, StreamReceiver, Configurator};
2use crate::error::{ResponseError, TaskError};
3use crate::util::{watch, poll_fn};
4use crate::packet::{
5	Packet, Kind, Flags, PacketHeader, PacketBytes, PacketError
6};
7use crate::server::Config;
8
9use std::collections::{HashMap, hash_map::Entry};
10use std::future::Future;
11use std::task::Poll;
12use std::marker::PhantomData;
13use std::pin::Pin;
14
15use tokio::sync::{mpsc, oneshot};
16
17
18/// A receiver that waits on messages from the handler (client)
19pub(crate) struct Receiver<P> {
20	inner: mpsc::Receiver<Message<P>>,
21	cfg: watch::Sender<Config>
22}
23
24impl<P> Receiver<P> {
25	/// Receive a new message from the client
26	pub async fn receive(&mut self) -> Option<Message<P>> {
27		self.inner.recv().await
28	}
29
30	pub fn update_config(&self, cfg: Config) {
31		self.cfg.send(cfg);
32	}
33
34	pub fn configurator(&self) -> Configurator<Config> {
35		Configurator::new(self.cfg.clone())
36	}
37}
38
39/// All different kinds of messages.
40#[derive(Debug)]
41pub enum Message<P> {
42	Request(P, ResponseSender<P>),
43	// a request to receive a sender stream
44	RequestSender(P, StreamReceiver<P>),
45	// a request to receive a receiving stream
46	RequestReceiver(P, StreamSender<P>)
47}
48
49/// A sender used to respond to a request.
50#[derive(Debug)]
51pub struct ResponseSender<P> {
52	pub(crate) inner: oneshot::Sender<P>
53}
54
55impl<P> ResponseSender<P> {
56	pub(crate) fn new(inner: oneshot::Sender<P>) -> Self {
57		Self { inner }
58	}
59
60	/// Sends the packet as a response, adding the correct flags.
61	/// 
62	/// If this returns an Error it either means the connection was closed
63	/// or the requestor does not care about the response anymore.
64	pub fn send(
65		self,
66		packet: P
67	) -> Result<(), ResponseError> {
68		self.inner.send(packet)
69			.map_err(|_| ResponseError::ConnectionClosed)
70	}
71}
72
73pub enum Response<P> {
74	Request(oneshot::Receiver<P>),
75	// a request to receive a receiving stream
76	Receiver(mpsc::Receiver<P>)
77}
78
79/// A list of receivers that wait on a packet.
80struct WaitingOnServer<P, B> {
81	// hashmap because we need to check if the id is free
82	inner: HashMap<u32, Response<P>>,
83	marker: PhantomData<B>
84}
85
86
87impl<P, B> WaitingOnServer<P, B>
88where
89	P: Packet<B>,
90	B: PacketBytes
91{
92	fn new() -> Self {
93		Self {
94			inner: HashMap::new(),
95			marker: PhantomData
96		}
97	}
98
99	fn insert(
100		&mut self,
101		id: u32,
102		receiver: Response<P>
103	) -> Result<(), TaskError> {
104		match self.inner.entry(id) {
105			Entry::Occupied(occ) => Err(TaskError::ExistingId(*occ.key())),
106			Entry::Vacant(v) => {
107				v.insert(receiver);
108				Ok(())
109			}
110		}
111	}
112
113	pub async fn to_send(&mut self) -> Option<P> {
114		if self.inner.is_empty() {
115			return None
116		}
117
118		let (packet, rem) = poll_fn(|ctx| {
119
120			for (id, resp) in &mut self.inner {
121				match resp {
122					Response::Request(resp) => {
123						match Pin::new(resp).poll(ctx) {
124							Poll::Pending => {},
125							Poll::Ready(Ok(mut packet)) => {
126								// set kind::Stream and set the id
127								let flags = Flags::new(Kind::Response);
128								packet.header_mut().set_flags(flags);
129								packet.header_mut().set_id(*id);
130
131								return Poll::Ready((packet, Some(*id)))
132							},
133							Poll::Ready(Err(_)) => {
134								// channel closed
135								let mut p = P::empty();
136								let flags = Flags::new(Kind::NoResponse);
137								p.header_mut().set_flags(flags);
138								p.header_mut().set_id(*id);
139
140								return Poll::Ready((p, Some(*id)))
141							}
142						}
143					},
144					Response::Receiver(resp) => {
145						match resp.poll_recv(ctx) {
146							Poll::Pending => {},
147							Poll::Ready(Some(mut packet)) => {
148								// set kind::Stream and set the id
149								let flags = Flags::new(Kind::Stream);
150								packet.header_mut().set_flags(flags);
151								packet.header_mut().set_id(*id);
152
153								return Poll::Ready((packet, None))
154							},
155							Poll::Ready(None) => {
156								// channel closed
157
158								let mut p = P::empty();
159								let flags = Flags::new(Kind::StreamClosed);
160								p.header_mut().set_flags(flags);
161								p.header_mut().set_id(*id);
162
163								return Poll::Ready((p, Some(*id)))
164							}
165						}
166					}
167				}
168			}
169
170			Poll::Pending
171		}).await;
172
173		if let Some(rem) = rem {
174			self.inner.remove(&rem);
175		}
176
177		Some(packet)
178	}
179
180	pub fn close_all(&mut self) {
181		for resp in self.inner.values_mut() {
182			match resp {
183				Response::Request(resp) => resp.close(),
184				Response::Receiver(resp) => resp.close()
185			}
186		}
187	}
188
189	pub fn close(&mut self, id: u32) {
190		match self.inner.get_mut(&id) {
191			Some(Response::Request(resp)) => resp.close(),
192			Some(Response::Receiver(resp)) => resp.close(),
193			_ => {}
194		}
195	}
196}
197
198/// A handler that is responsible for the server.
199pub struct Handler<P, B>
200where
201	P: Packet<B>,
202	B: PacketBytes
203{
204	/// messages that should be sent to the server.
205	msg_to_server: mpsc::Sender<Message<P>>,
206	/// messages that are waiting on a packet from the client
207	waiting_on_client: HashMap<u32, mpsc::Sender<P>>,
208	/// messages that are waiting on a packet from the server
209	waiting_on_server: WaitingOnServer<P, B>
210}
211
212impl<P, B> Handler<P, B>
213where
214	P: Packet<B>,
215	B: PacketBytes
216{
217	/// Creates a new handler, return a receiver which can listens on new
218	/// messages.
219	pub(crate) fn new(
220		cfg: Config
221	) -> (Receiver<P>, watch::Receiver<Config>, Self) {
222		let (tx, rx) = mpsc::channel(10);
223		let (cfg_tx, cfg_rx) = watch::channel(cfg);
224
225		(
226			Receiver {
227				inner: rx,
228				cfg: cfg_tx
229			},
230			cfg_rx,
231			Self {
232				msg_to_server: tx,
233				waiting_on_client: HashMap::new(),
234				waiting_on_server: WaitingOnServer::new()
235			}
236		)
237	}
238
239	pub(crate) fn ping_packet(&self) -> P {
240		let mut p = P::empty();
241		let flags = Flags::new(Kind::Ping);
242		p.header_mut().set_flags(flags);
243		p
244	}
245
246	fn stream_close_packet(&self, id: u32) -> P {
247		let mut p = P::empty();
248		let flags = Flags::new(Kind::StreamClosed);
249		p.header_mut().set_flags(flags);
250		p.header_mut().set_id(id);
251		p
252	}
253
254	/// Should be called with a packet from the client.
255	pub(crate) async fn send(
256		&mut self,
257		packet: P
258	) -> Result<SendBack<P>, TaskError> {
259		let flags = packet.header().flags();
260		let id = packet.header().id();
261		let kind = flags.kind();
262
263		match kind {
264			Kind::Request => {
265				let (tx, rx) = oneshot::channel();
266
267				self.waiting_on_server.insert(id, Response::Request(rx))?;
268
269				let sr = self.msg_to_server.send(Message::Request(
270					packet,
271					ResponseSender::new(tx)
272				)).await;
273
274				match sr {
275					Ok(_) => Ok(SendBack::None),
276					// the server has no interest anymore
277					// Let's close the connection
278					Err(_) => Ok(SendBack::CloseWithPacket)
279				}
280			},
281			Kind::RequestReceiver => {
282				let (tx, rx) = mpsc::channel(10);	
283				self.waiting_on_server.insert(id, Response::Receiver(rx))?;
284
285				let sr = self.msg_to_server.send(Message::RequestReceiver(
286					packet,
287					StreamSender::new(tx)
288				)).await;
289
290				match sr {
291					Ok(_) => Ok(SendBack::None),
292					// the server has no interest anymore
293					// Let's close the connection
294					Err(_) => Ok(SendBack::CloseWithPacket)
295				}
296			},
297			Kind::RequestSender => {
298				let (tx, rx) = mpsc::channel(10);
299
300				match self.waiting_on_client.entry(id) {
301					Entry::Occupied(occ) => {
302						return Err(TaskError::ExistingId(*occ.key()))
303					},
304					Entry::Vacant(v) => {
305						v.insert(tx);
306					}
307				}
308
309				let sr = self.msg_to_server.send(Message::RequestSender(
310					packet,
311					StreamReceiver::new(rx)
312				)).await;
313
314				match sr {
315					Ok(_) => Ok(SendBack::None),
316					// the server has no interest anymore
317					// Let's close the connection
318					Err(_) => Ok(SendBack::CloseWithPacket)
319				}
320			},
321			Kind::Stream => {
322				match self.waiting_on_client.entry(id) {
323					Entry::Occupied(mut occ) => {
324						if let Err(_) = occ.get_mut().send(packet).await {
325							// since the stream is closed we should remove it
326							occ.remove_entry();
327							// we should send a response telling the other side
328							// that the stream is closed
329							let p = self.stream_close_packet(id);
330							Ok(SendBack::Packet(p))
331						} else {
332							Ok(SendBack::None)
333						}
334					},
335					Entry::Vacant(_) => {
336						// since the client could send multiple streams
337						// before we can send him a streamclosed
338						// we just ignore this packet
339						let p = self.stream_close_packet(id);
340						Ok(SendBack::Packet(p))
341					}
342				}
343			},
344			Kind::StreamClosed => {
345				let _ = self.waiting_on_client.remove(&id);
346				self.waiting_on_server.close(id);
347				Ok(SendBack::None)
348			},
349			Kind::Close => Ok(SendBack::Close),
350			Kind::Ping => Ok(SendBack::None),
351			k => Err(TaskError::WrongPacketKind(k.to_str()))
352		}
353	}
354
355	/// returns None if nothing is left to be done
356	/// if close=true is once set this cannot be reversed 
357	pub async fn to_send(&mut self) -> Option<P> {
358		self.waiting_on_server.to_send().await
359	}
360
361	fn malformed_request(&self, id: u32) -> P {
362		let mut p = P::empty();
363		// todo maybe replace with a malformed request
364		let flags = Flags::new(Kind::MalformedRequest);
365		p.header_mut().set_flags(flags);
366		p.header_mut().set_id(id);
367
368		p
369	}
370
371	/// we received a packet which had a malformed body
372	pub(crate) fn packet_error(
373		&mut self,
374		header: P::Header,
375		e: PacketError
376	) -> Result<SendBack<P>, TaskError> {
377		let flags = header.flags();
378		let id = header.id();
379		let kind = flags.kind();
380
381		match kind {
382			Kind::Request => Ok(SendBack::Packet(self.malformed_request(id))),
383			Kind::RequestSender |
384			Kind::RequestReceiver => {
385				Ok(SendBack::Packet(self.stream_close_packet(id)))
386			},
387			Kind::Stream => {
388				// ignore a stream packet which had an error
389				tracing::error!(
390					"failed to parse stream packet {} {:?}",
391					header.id(),
392					e
393				);
394				Ok(SendBack::None)
395			},
396			// this should not have a user generated so this should never fail
397			Kind::Close |
398			Kind::Ping |
399			Kind::StreamClosed => Err(TaskError::Packet(e)),
400			k => Err(TaskError::WrongPacketKind(k.to_str()))
401		}
402	}
403
404	pub fn close(&mut self) -> P {
405		self.waiting_on_server.close_all();
406
407		let mut p = P::empty();
408		let flags = Flags::new(Kind::Close);
409		p.header_mut().set_flags(flags);
410
411		p
412	}
413}