atm0s_reverse_proxy_relayer/proxy/
tls.rs

1use anyhow::anyhow;
2use tls_parser::{parse_tls_extensions, parse_tls_plaintext, TlsExtension, TlsMessage, TlsMessageHandshake};
3use tokio::net::TcpStream;
4
5use super::{ProxyDestination, ProxyDestinationDetector};
6
/// Destination detector that derives the proxy target from the SNI host name found in
/// the TLS ClientHello at the start of a connection.
///
/// The [`Default`] detector carries no service id; use `custom_service` to attach one.
#[derive(Default, Debug)]
pub struct TlsDestinationDetector {
    /// Service id copied into every detected destination (`None` when not configured).
    service: Option<u16>,
}
11
12impl TlsDestinationDetector {
13    pub fn custom_service(service: u16) -> Self {
14        Self { service: Some(service) }
15    }
16
17    fn get_domain(&self, packet: &[u8]) -> Option<String> {
18        log::info!("[TlsDomainDetector] check domain for buffer {} bytes", packet.len());
19        let res = match parse_tls_plaintext(packet) {
20            Ok(res) => res,
21            Err(e) => {
22                log::error!("parse_tls_plaintext error {:?}", e);
23                return None;
24            }
25        };
26
27        let tls_message = &res.1.msg[0];
28        if let TlsMessage::Handshake(TlsMessageHandshake::ClientHello(client_hello)) = tls_message {
29            // get the extensions
30            let extensions: &[u8] = client_hello.ext?;
31            // parse the extensions
32            let res = match parse_tls_extensions(extensions) {
33                Ok(res) => res,
34                Err(e) => {
35                    log::error!("parse_tls_extensions error {:?}", e);
36                    return None;
37                }
38            };
39            // iterate over the extensions and find the SNI
40            for extension in res.1 {
41                if let TlsExtension::SNI(sni) = extension {
42                    // get the hostname
43                    let hostname: &[u8] = sni[0].1;
44                    let s: String = match String::from_utf8(hostname.to_vec()) {
45                        Ok(v) => v,
46                        Err(e) => panic!("Invalid UTF-8 sequence: {}", e),
47                    };
48                    return Some(s);
49                }
50            }
51        }
52        None
53    }
54}
55
56impl ProxyDestinationDetector for TlsDestinationDetector {
57    async fn determine(&self, stream: &mut TcpStream) -> anyhow::Result<ProxyDestination> {
58        let mut buf = [0; 4096];
59        let buf_len = stream.peek(&mut buf).await?;
60        let domain = self.get_domain(&buf[..buf_len]).ok_or(anyhow!("domain not found"))?;
61        Ok(ProxyDestination {
62            domain,
63            service: self.service,
64            tls: true,
65        })
66    }
67}