hala_rproxy/
rproxy.rs

1use std::{
2    fmt::{Debug, Display},
3    io,
4    sync::{
5        atomic::{AtomicUsize, Ordering},
6        Arc,
7    },
8};
9
10use futures::{AsyncRead, AsyncWrite, Future};
11use hala_future::executor::future_spawn;
12
13use crate::{ConnId, Session};
14
15/// The inbound connection handshaker.
16pub trait Handshaker {
17    type Handshake<'a>: Future<Output = io::Result<Session>> + Send + 'a
18    where
19        Self: 'a;
20    /// Invoke inbound connection handshake processing and returns [`Session`] object.
21    fn handshake<C: AsyncWrite + AsyncRead + Send + 'static>(
22        &self,
23        conn_id: &ConnId<'_>,
24        conn: C,
25    ) -> Self::Handshake<'_>;
26}
27
28/// [Rproxy] listener should implement this trait.
29pub trait Listener {
30    /// Inbound connection type.
31    type Conn: AsyncRead + AsyncWrite + Send + 'static;
32
33    /// Future created by [`accept`](Rproxy::accept)
34    type Accept<'a>: Future<Output = Option<(ConnId<'static>, Self::Conn)>> + 'a
35    where
36        Self: 'a;
37
38    /// Accept next inbound connection.
39    fn accept(&mut self) -> Self::Accept<'_>;
40}
41
/// The stats of [`Rproxy`], created by [`stats`](Rproxy::stats) fn.
///
/// A plain value snapshot; it is `Copy` and can be compared or printed with `{:?}`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct RproxyStats {
    /// Number of sessions currently running (handshake succeeded, session not yet finished).
    pub actived: usize,
    /// Number of sessions that have finished since the proxy was created.
    pub closed: usize,
}
47
48impl Display for RproxyStats {
49    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50        write!(
51            f,
52            "stream reverse proxy: actived={}, closed={}",
53            self.actived, self.closed
54        )
55    }
56}
57
/// rgnix stream reverse proxy.
///
/// Holds the [`Handshaker`] used for every inbound connection plus the connection counters
/// reported by [`stats`](Rproxy::stats). Every field is reference counted, so a clone is a
/// cheap handle onto the same handshaker and the same counters.
pub struct Rproxy<H> {
    /// Handshaker shared with every spawned per-connection task.
    handshaker: Arc<H>,
    /// Number of sessions currently running.
    conns: Arc<AtomicUsize>,
    /// Number of sessions that have finished.
    closed_conns: Arc<AtomicUsize>,
}
64
65impl<H> Clone for Rproxy<H> {
66    fn clone(&self) -> Self {
67        Self {
68            handshaker: self.handshaker.clone(),
69            conns: self.conns.clone(),
70            closed_conns: self.closed_conns.clone(),
71        }
72    }
73}
74
75impl<H> Rproxy<H>
76where
77    H: Handshaker + Sync + Send + 'static,
78{
79    /// Create new [`Rproxy`] instance.
80    pub fn new(handshaker: H) -> Self {
81        Self {
82            handshaker: Arc::new(handshaker),
83            conns: Default::default(),
84            closed_conns: Default::default(),
85        }
86    }
87
88    /// Invoke inbound connection handshake.
89    pub async fn handshake<C: AsyncWrite + AsyncRead + Send + 'static>(
90        &self,
91        conn_id: &ConnId<'_>,
92        conn: C,
93    ) -> io::Result<()> {
94        let session = self.handshaker.handshake(conn_id, conn).await?;
95
96        self.conns.fetch_add(1, Ordering::Relaxed);
97
98        let r = session.await;
99
100        self.conns.fetch_sub(1, Ordering::Relaxed);
101        self.closed_conns.fetch_add(1, Ordering::Relaxed);
102
103        r
104    }
105    /// Start reverse proxy accept loop.
106    pub async fn accept<L: Listener + Debug>(&self, mut listener: L) {
107        log::debug!(target: "ReverseProxy", "{:?}, start gateway loop", listener);
108
109        while let Some((id, conn)) = listener.accept().await {
110            let this = self.clone();
111
112            // A new task should be started to perform the handshake.
113            // Because the function will not return until this inbound
114            // connection session is closed.
115            future_spawn(async move {
116                match this.handshake(&id, conn).await {
117                    Ok(_) => {
118                        log::debug!(target: "ReverseProxy", "handshake successfully, id={:?}", id);
119                    }
120                    Err(err) => {
121                        log::debug!(target: "ReverseProxy", "handshake error, id={:?}, {}", id, err);
122                    }
123                }
124            });
125        }
126
127        log::debug!(target: "ReverseProxy", "{:?}, stop gateway loop", listener);
128    }
129}
130
131impl<H> Rproxy<H> {
132    /// Get reverse proxy stats.
133    pub fn stats(&self) -> RproxyStats {
134        RproxyStats {
135            actived: self.conns.load(Ordering::Relaxed),
136            closed: self.closed_conns.load(Ordering::Relaxed),
137        }
138    }
139}