// atm0s_reverse_proxy_relayer/proxy/tls.rs
use anyhow::anyhow;
use tls_parser::{parse_tls_extensions, parse_tls_plaintext, TlsExtension, TlsMessage, TlsMessageHandshake};
use tokio::net::TcpStream;

use super::{ProxyDestination, ProxyDestinationDetector};

/// Detects the proxy destination of a TLS connection from the SNI extension of
/// its ClientHello.
#[derive(Default, Debug)]
pub struct TlsDestinationDetector {
    /// Service id copied into every detected `ProxyDestination`; `None` unless
    /// built with `custom_service`.
    service: Option<u16>,
}

12impl TlsDestinationDetector {
13 pub fn custom_service(service: u16) -> Self {
14 Self { service: Some(service) }
15 }
16
17 fn get_domain(&self, packet: &[u8]) -> Option<String> {
18 log::info!("[TlsDomainDetector] check domain for buffer {} bytes", packet.len());
19 let res = match parse_tls_plaintext(packet) {
20 Ok(res) => res,
21 Err(e) => {
22 log::error!("parse_tls_plaintext error {:?}", e);
23 return None;
24 }
25 };
26
27 let tls_message = &res.1.msg[0];
28 if let TlsMessage::Handshake(TlsMessageHandshake::ClientHello(client_hello)) = tls_message {
29 let extensions: &[u8] = client_hello.ext?;
31 let res = match parse_tls_extensions(extensions) {
33 Ok(res) => res,
34 Err(e) => {
35 log::error!("parse_tls_extensions error {:?}", e);
36 return None;
37 }
38 };
39 for extension in res.1 {
41 if let TlsExtension::SNI(sni) = extension {
42 let hostname: &[u8] = sni[0].1;
44 let s: String = match String::from_utf8(hostname.to_vec()) {
45 Ok(v) => v,
46 Err(e) => panic!("Invalid UTF-8 sequence: {}", e),
47 };
48 return Some(s);
49 }
50 }
51 }
52 None
53 }
54}
55
56impl ProxyDestinationDetector for TlsDestinationDetector {
57 async fn determine(&self, stream: &mut TcpStream) -> anyhow::Result<ProxyDestination> {
58 let mut buf = [0; 4096];
59 let buf_len = stream.peek(&mut buf).await?;
60 let domain = self.get_domain(&buf[..buf_len]).ok_or(anyhow!("domain not found"))?;
61 Ok(ProxyDestination {
62 domain,
63 service: self.service,
64 tls: true,
65 })
66 }
67}