// hala_rproxy/utils.rs

1use std::{
2    io,
3    sync::Arc,
4    task::{Poll, Waker},
5};
6
7use futures::Future;
8
9use hala_quic::QuicConnectionId;
10use hala_sync::{spin_simple, Lockable};
11use uuid::Uuid;
12
13/// The connection id of transport layer.
14#[non_exhaustive]
15#[derive(Debug, Clone, PartialEq, Eq, Hash)]
16pub enum ConnId<'a> {
17    /// Connection id for tcp.
18    Tcp(Uuid),
19    /// Connection id for quic stream.
20    QuicStream(QuicConnectionId<'a>, QuicConnectionId<'a>, u64),
21}
22
23impl<'a> ConnId<'a> {
24    /// Consume self and return an owned version [`ConnId`] instance.
25    #[inline]
26    pub fn into_owned(self) -> ConnId<'static> {
27        match self {
28            ConnId::Tcp(uuid) => ConnId::Tcp(uuid),
29            ConnId::QuicStream(scid, dcid, stream_id) => {
30                ConnId::QuicStream(scid.into_owned(), dcid.into_owned(), stream_id)
31            }
32        }
33    }
34}
35
/// Mutable state shared by every clone of a `Session`.
#[derive(Default)]
struct SessionFlag {
    /// Waker of the task that most recently polled the session while no close
    /// result was available.
    waker: Option<Waker>,
    /// Close result stored by `Session::closed_with`; taken by the first poll
    /// that observes it.
    closed: Option<io::Result<()>>,
}
41
42/// The session object that represent the inbound connection session which
43/// created by [`handshake`](super::Rproxy::handshake)
44///
45/// Using this object to wait session closed.
46#[derive(Clone)]
47pub struct Session {
48    pub id: ConnId<'static>,
49    flag: Arc<spin_simple::SpinMutex<SessionFlag>>,
50}
51
52impl Session {
53    /// Create new [`Rproxy`](crate::Rproxy) session with provided [`ConnId`]
54    pub fn new(id: ConnId<'static>) -> Self {
55        Self {
56            id,
57            flag: Default::default(),
58        }
59    }
60    /// Notify session closed with [`io::Result`]
61    pub fn closed_with(&self, r: io::Result<()>) {
62        let mut flag = self.flag.lock();
63
64        flag.closed = Some(r);
65
66        if let Some(waker) = flag.waker.take() {
67            waker.wake();
68        }
69    }
70}
71
72impl Future for Session {
73    type Output = io::Result<()>;
74
75    fn poll(
76        self: std::pin::Pin<&mut Self>,
77        cx: &mut std::task::Context<'_>,
78    ) -> std::task::Poll<Self::Output> {
79        let mut flag = self.flag.lock();
80
81        if let Some(r) = flag.closed.take() {
82            Poll::Ready(r)
83        } else {
84            flag.waker = Some(cx.waker().clone());
85            Poll::Pending
86        }
87    }
88}