1use super::{SendBack, StreamSender, StreamReceiver, Configurator};
2use crate::error::{ResponseError, TaskError};
3use crate::util::{watch, poll_fn};
4use crate::packet::{
5 Packet, Kind, Flags, PacketHeader, PacketBytes, PacketError
6};
7use crate::server::Config;
8
9use std::collections::{HashMap, hash_map::Entry};
10use std::future::Future;
11use std::task::Poll;
12use std::marker::PhantomData;
13use std::pin::Pin;
14
15use tokio::sync::{mpsc, oneshot};
16
17
18pub(crate) struct Receiver<P> {
20 inner: mpsc::Receiver<Message<P>>,
21 cfg: watch::Sender<Config>
22}
23
24impl<P> Receiver<P> {
25 pub async fn receive(&mut self) -> Option<Message<P>> {
27 self.inner.recv().await
28 }
29
30 pub fn update_config(&self, cfg: Config) {
31 self.cfg.send(cfg);
32 }
33
34 pub fn configurator(&self) -> Configurator<Config> {
35 Configurator::new(self.cfg.clone())
36 }
37}
38
39#[derive(Debug)]
41pub enum Message<P> {
42 Request(P, ResponseSender<P>),
43 RequestSender(P, StreamReceiver<P>),
45 RequestReceiver(P, StreamSender<P>)
47}
48
49#[derive(Debug)]
51pub struct ResponseSender<P> {
52 pub(crate) inner: oneshot::Sender<P>
53}
54
55impl<P> ResponseSender<P> {
56 pub(crate) fn new(inner: oneshot::Sender<P>) -> Self {
57 Self { inner }
58 }
59
60 pub fn send(
65 self,
66 packet: P
67 ) -> Result<(), ResponseError> {
68 self.inner.send(packet)
69 .map_err(|_| ResponseError::ConnectionClosed)
70 }
71}
72
73pub enum Response<P> {
74 Request(oneshot::Receiver<P>),
75 Receiver(mpsc::Receiver<P>)
77}
78
79struct WaitingOnServer<P, B> {
81 inner: HashMap<u32, Response<P>>,
83 marker: PhantomData<B>
84}
85
86
87impl<P, B> WaitingOnServer<P, B>
88where
89 P: Packet<B>,
90 B: PacketBytes
91{
92 fn new() -> Self {
93 Self {
94 inner: HashMap::new(),
95 marker: PhantomData
96 }
97 }
98
99 fn insert(
100 &mut self,
101 id: u32,
102 receiver: Response<P>
103 ) -> Result<(), TaskError> {
104 match self.inner.entry(id) {
105 Entry::Occupied(occ) => Err(TaskError::ExistingId(*occ.key())),
106 Entry::Vacant(v) => {
107 v.insert(receiver);
108 Ok(())
109 }
110 }
111 }
112
113 pub async fn to_send(&mut self) -> Option<P> {
114 if self.inner.is_empty() {
115 return None
116 }
117
118 let (packet, rem) = poll_fn(|ctx| {
119
120 for (id, resp) in &mut self.inner {
121 match resp {
122 Response::Request(resp) => {
123 match Pin::new(resp).poll(ctx) {
124 Poll::Pending => {},
125 Poll::Ready(Ok(mut packet)) => {
126 let flags = Flags::new(Kind::Response);
128 packet.header_mut().set_flags(flags);
129 packet.header_mut().set_id(*id);
130
131 return Poll::Ready((packet, Some(*id)))
132 },
133 Poll::Ready(Err(_)) => {
134 let mut p = P::empty();
136 let flags = Flags::new(Kind::NoResponse);
137 p.header_mut().set_flags(flags);
138 p.header_mut().set_id(*id);
139
140 return Poll::Ready((p, Some(*id)))
141 }
142 }
143 },
144 Response::Receiver(resp) => {
145 match resp.poll_recv(ctx) {
146 Poll::Pending => {},
147 Poll::Ready(Some(mut packet)) => {
148 let flags = Flags::new(Kind::Stream);
150 packet.header_mut().set_flags(flags);
151 packet.header_mut().set_id(*id);
152
153 return Poll::Ready((packet, None))
154 },
155 Poll::Ready(None) => {
156 let mut p = P::empty();
159 let flags = Flags::new(Kind::StreamClosed);
160 p.header_mut().set_flags(flags);
161 p.header_mut().set_id(*id);
162
163 return Poll::Ready((p, Some(*id)))
164 }
165 }
166 }
167 }
168 }
169
170 Poll::Pending
171 }).await;
172
173 if let Some(rem) = rem {
174 self.inner.remove(&rem);
175 }
176
177 Some(packet)
178 }
179
180 pub fn close_all(&mut self) {
181 for resp in self.inner.values_mut() {
182 match resp {
183 Response::Request(resp) => resp.close(),
184 Response::Receiver(resp) => resp.close()
185 }
186 }
187 }
188
189 pub fn close(&mut self, id: u32) {
190 match self.inner.get_mut(&id) {
191 Some(Response::Request(resp)) => resp.close(),
192 Some(Response::Receiver(resp)) => resp.close(),
193 _ => {}
194 }
195 }
196}
197
198pub struct Handler<P, B>
200where
201 P: Packet<B>,
202 B: PacketBytes
203{
204 msg_to_server: mpsc::Sender<Message<P>>,
206 waiting_on_client: HashMap<u32, mpsc::Sender<P>>,
208 waiting_on_server: WaitingOnServer<P, B>
210}
211
212impl<P, B> Handler<P, B>
213where
214 P: Packet<B>,
215 B: PacketBytes
216{
217 pub(crate) fn new(
220 cfg: Config
221 ) -> (Receiver<P>, watch::Receiver<Config>, Self) {
222 let (tx, rx) = mpsc::channel(10);
223 let (cfg_tx, cfg_rx) = watch::channel(cfg);
224
225 (
226 Receiver {
227 inner: rx,
228 cfg: cfg_tx
229 },
230 cfg_rx,
231 Self {
232 msg_to_server: tx,
233 waiting_on_client: HashMap::new(),
234 waiting_on_server: WaitingOnServer::new()
235 }
236 )
237 }
238
239 pub(crate) fn ping_packet(&self) -> P {
240 let mut p = P::empty();
241 let flags = Flags::new(Kind::Ping);
242 p.header_mut().set_flags(flags);
243 p
244 }
245
246 fn stream_close_packet(&self, id: u32) -> P {
247 let mut p = P::empty();
248 let flags = Flags::new(Kind::StreamClosed);
249 p.header_mut().set_flags(flags);
250 p.header_mut().set_id(id);
251 p
252 }
253
254 pub(crate) async fn send(
256 &mut self,
257 packet: P
258 ) -> Result<SendBack<P>, TaskError> {
259 let flags = packet.header().flags();
260 let id = packet.header().id();
261 let kind = flags.kind();
262
263 match kind {
264 Kind::Request => {
265 let (tx, rx) = oneshot::channel();
266
267 self.waiting_on_server.insert(id, Response::Request(rx))?;
268
269 let sr = self.msg_to_server.send(Message::Request(
270 packet,
271 ResponseSender::new(tx)
272 )).await;
273
274 match sr {
275 Ok(_) => Ok(SendBack::None),
276 Err(_) => Ok(SendBack::CloseWithPacket)
279 }
280 },
281 Kind::RequestReceiver => {
282 let (tx, rx) = mpsc::channel(10);
283 self.waiting_on_server.insert(id, Response::Receiver(rx))?;
284
285 let sr = self.msg_to_server.send(Message::RequestReceiver(
286 packet,
287 StreamSender::new(tx)
288 )).await;
289
290 match sr {
291 Ok(_) => Ok(SendBack::None),
292 Err(_) => Ok(SendBack::CloseWithPacket)
295 }
296 },
297 Kind::RequestSender => {
298 let (tx, rx) = mpsc::channel(10);
299
300 match self.waiting_on_client.entry(id) {
301 Entry::Occupied(occ) => {
302 return Err(TaskError::ExistingId(*occ.key()))
303 },
304 Entry::Vacant(v) => {
305 v.insert(tx);
306 }
307 }
308
309 let sr = self.msg_to_server.send(Message::RequestSender(
310 packet,
311 StreamReceiver::new(rx)
312 )).await;
313
314 match sr {
315 Ok(_) => Ok(SendBack::None),
316 Err(_) => Ok(SendBack::CloseWithPacket)
319 }
320 },
321 Kind::Stream => {
322 match self.waiting_on_client.entry(id) {
323 Entry::Occupied(mut occ) => {
324 if let Err(_) = occ.get_mut().send(packet).await {
325 occ.remove_entry();
327 let p = self.stream_close_packet(id);
330 Ok(SendBack::Packet(p))
331 } else {
332 Ok(SendBack::None)
333 }
334 },
335 Entry::Vacant(_) => {
336 let p = self.stream_close_packet(id);
340 Ok(SendBack::Packet(p))
341 }
342 }
343 },
344 Kind::StreamClosed => {
345 let _ = self.waiting_on_client.remove(&id);
346 self.waiting_on_server.close(id);
347 Ok(SendBack::None)
348 },
349 Kind::Close => Ok(SendBack::Close),
350 Kind::Ping => Ok(SendBack::None),
351 k => Err(TaskError::WrongPacketKind(k.to_str()))
352 }
353 }
354
355 pub async fn to_send(&mut self) -> Option<P> {
358 self.waiting_on_server.to_send().await
359 }
360
361 fn malformed_request(&self, id: u32) -> P {
362 let mut p = P::empty();
363 let flags = Flags::new(Kind::MalformedRequest);
365 p.header_mut().set_flags(flags);
366 p.header_mut().set_id(id);
367
368 p
369 }
370
371 pub(crate) fn packet_error(
373 &mut self,
374 header: P::Header,
375 e: PacketError
376 ) -> Result<SendBack<P>, TaskError> {
377 let flags = header.flags();
378 let id = header.id();
379 let kind = flags.kind();
380
381 match kind {
382 Kind::Request => Ok(SendBack::Packet(self.malformed_request(id))),
383 Kind::RequestSender |
384 Kind::RequestReceiver => {
385 Ok(SendBack::Packet(self.stream_close_packet(id)))
386 },
387 Kind::Stream => {
388 tracing::error!(
390 "failed to parse stream packet {} {:?}",
391 header.id(),
392 e
393 );
394 Ok(SendBack::None)
395 },
396 Kind::Close |
398 Kind::Ping |
399 Kind::StreamClosed => Err(TaskError::Packet(e)),
400 k => Err(TaskError::WrongPacketKind(k.to_str()))
401 }
402 }
403
404 pub fn close(&mut self) -> P {
405 self.waiting_on_server.close_all();
406
407 let mut p = P::empty();
408 let flags = Flags::new(Kind::Close);
409 p.header_mut().set_flags(flags);
410
411 p
412 }
413}