atm0s_reverse_proxy_relayer/
proxy.rs

1use std::{net::SocketAddr, sync::Arc};
2
3use protocol::proxy::ProxyDestination;
4use tokio::{
5    net::{TcpListener, TcpStream},
6    select,
7    sync::mpsc::{channel, Receiver, Sender},
8};
9
10pub mod http;
11pub mod rtsp;
12pub mod tls;
13
14pub trait ProxyDestinationDetector: Send + Sync + 'static {
15    fn determine(&self, stream: &mut TcpStream) -> impl std::future::Future<Output = anyhow::Result<ProxyDestination>> + Send;
16}
17
18pub struct ProxyTcpListener<Detector> {
19    listener: TcpListener,
20    detector: Arc<Detector>,
21    rx: Receiver<(ProxyDestination, TcpStream)>,
22    tx: Sender<(ProxyDestination, TcpStream)>,
23}
24
25impl<Detector: ProxyDestinationDetector> ProxyTcpListener<Detector> {
26    pub async fn new(addr: SocketAddr, detector: Detector) -> anyhow::Result<Self> {
27        let (tx, rx) = channel(10);
28        Ok(Self {
29            detector: detector.into(),
30            listener: TcpListener::bind(addr).await?,
31            tx,
32            rx,
33        })
34    }
35
36    pub async fn recv(&mut self) -> anyhow::Result<(ProxyDestination, TcpStream)> {
37        loop {
38            select! {
39                event = self.listener.accept() => {
40                    let (mut stream, remote) = event?;
41                    let tx = self.tx.clone();
42                    let detector = self.detector.clone();
43                    tokio::spawn(async move {
44                        match detector.determine(&mut stream).await {
45                            Ok(destination) => {
46                                log::info!("[ProxyTcpListener] determine destination {}, service {:?} for remote {remote}", destination.domain, destination.service);
47                                tx.send((destination, stream)).await.expect("tcp listener channel should work");
48                            },
49                            Err(err) => {
50                                log::info!("[ProxyTcpListener] determine destination for {remote} error {err}");
51                            },
52                        }
53                    });
54                },
55                out = self.rx.recv() => {
56                    break Ok(out.expect("tcp listener channel should work"))
57                }
58            }
59        }
60    }
61}