rust_govhost/
lib.rs

1use std::io::{self, Cursor, Read, Write};
2use std::net::TcpStream;
3use std::sync::{Arc, Mutex};
4
5
/// A TCP connection whose TLS ClientHello has already been read so that the
/// SNI host name can be inspected, while the consumed bytes stay readable.
///
/// [`SharedConn::new`] reads the start of the TLS handshake from the socket,
/// keeps those bytes in an in-memory buffer and extracts the server name.
/// The `Read` implementation hands the buffered bytes back before reading from
/// the socket again, so the stream can still be consumed from its first byte.
///
/// # Example
/// ```ignore
/// // `SharedConn` is crate-private, so this example is not run as a doctest.
/// let conn = SharedConn::new(stream)?;
/// let sni = conn.get_sni();
/// assert_eq!(sni, "google.com");
/// ```
pub(crate) struct SharedConn {
    /// The underlying client socket; `Write` calls are forwarded to it.
    pub stream: TcpStream,
    /// Bytes already read from `stream` while sniffing the SNI; the cursor
    /// position tracks how many of them have been replayed through `Read`.
    buffer: Arc<Mutex<Cursor<Vec<u8>>>>,

    /// Server name parsed from the ClientHello (empty if none was present).
    sni: String,
}
20
21impl SharedConn {
22    pub fn new(mut stream: TcpStream) -> Result<SharedConn, std::io::Error> {
23        let buffer = Arc::new(Mutex::new(Cursor::new(Vec::new())));
24
25        // read tls handshake from stream, and then put data into buffer
26        let mut buf: [u8; 1024] = [0_u8; 1024];
27        let n = stream.read(&mut buf)?;
28        if n > 0 {
29            let mut buffer = buffer.lock().unwrap();
30            buffer.get_mut().extend_from_slice(&buf[..n]);
31        }
32
33        let sni = parse_sni(&buf, n)?;
34
35        Ok(SharedConn {
36            stream,
37            buffer,
38            sni,
39        })
40    }
41
42    pub fn get_sni(&self) -> String {
43        self.sni.clone()
44    }
45}
46
47impl Read for SharedConn {
48    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
49        let mut buffer = self.buffer.lock().unwrap();
50        if buffer.position() < buffer.get_ref().len() as u64 {
51            buffer.read(buf)
52        } else {
53            self.stream.read(buf)
54        }
55    }
56}
57
58impl Write for SharedConn {
59    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
60        self.stream.write(buf)
61    }
62
63    // 实现flush方法
64    fn flush(&mut self) -> io::Result<()> {
65        // 同样,这里简单地将标准输出的缓冲区刷新,实际应用中应根据需要进行操作
66        self.stream.flush()
67    }
68}
69
/// Extracts the SNI host name from a raw TLS ClientHello.
///
/// `buf` must start at the TLS record header (the content-type byte). Only the
/// first `n` bytes are treated as data (clamped to `buf.len()`); nothing past
/// them is ever read.
///
/// Returns `Ok(String::new())` when the ClientHello is well-formed but has no
/// `server_name` extension (or no extensions at all).
///
/// # Errors
/// - `ErrorKind::Other` ("tls handshake is too short") when the data ends
///   before a field or length-prefixed vector it needs.
/// - `ErrorKind::InvalidData` when the bytes are not a handshake record
///   carrying a ClientHello, or the host name is not valid UTF-8.
fn parse_sni(buf: &[u8], n: usize) -> Result<String, io::Error> {
    /// Error for input that ends before a field it should contain.
    fn too_short() -> io::Error {
        io::Error::new(io::ErrorKind::Other, "tls handshake is too short")
    }

    /// Error for input that is structurally not what we expect.
    fn invalid(msg: &str) -> io::Error {
        io::Error::new(io::ErrorKind::InvalidData, msg.to_string())
    }

    /// Bounds-checked `data[start..start + len]`.
    fn take(data: &[u8], start: usize, len: usize) -> Result<&[u8], io::Error> {
        start
            .checked_add(len)
            .and_then(|end| data.get(start..end))
            .ok_or_else(too_short)
    }

    /// Big-endian u16 from the first two bytes of `bytes` (caller guarantees len >= 2).
    fn be16(bytes: &[u8]) -> usize {
        (usize::from(bytes[0]) << 8) | usize::from(bytes[1])
    }

    /// Offset just past the vector at `at` whose length prefix is `width` (1 or 2) bytes.
    fn skip_vector(data: &[u8], at: usize, width: usize) -> Result<usize, io::Error> {
        let prefix = take(data, at, width)?;
        let len = if width == 1 { usize::from(prefix[0]) } else { be16(prefix) };
        take(data, at + width, len)?;
        Ok(at + width + len)
    }

    /// Parses a `server_name` extension body (RFC 6066 §3) and returns the
    /// first `host_name` (type 0) entry, or "" if the list has none.
    fn parse_server_name(body: &[u8]) -> Result<String, io::Error> {
        let list_len = be16(take(body, 0, 2)?);
        let mut list = take(body, 2, list_len)?;
        while !list.is_empty() {
            let entry = take(list, 0, 3)?;
            let name_len = be16(&entry[1..3]);
            let name = take(list, 3, name_len)?;
            if entry[0] == 0 {
                return String::from_utf8(name.to_vec())
                    .map_err(|_| invalid("tls sni host name is not valid utf-8"));
            }
            list = &list[3 + name_len..];
        }
        Ok(String::new())
    }

    let data = &buf[..n.min(buf.len())];

    // Fixed prefix: record header (5) + handshake header (4) +
    // client_version (2) + random (32) = 43 bytes.
    let fixed = take(data, 0, 43)?;
    if fixed[0] != 0x16 {
        return Err(invalid("not a tls handshake record"));
    }
    if fixed[5] != 0x01 {
        return Err(invalid("not a tls client hello"));
    }
    // End offset of the ClientHello body, from the 24-bit handshake length.
    let hello_end = 9 + ((usize::from(fixed[6]) << 16) | be16(&fixed[7..9]));

    let cur = skip_vector(data, 43, 1)?; // legacy_session_id
    let cur = skip_vector(data, cur, 2)?; // cipher_suites
    let cur = skip_vector(data, cur, 1)?; // legacy_compression_methods
    if cur == hello_end {
        // ClientHello without an extensions block: no SNI.
        return Ok(String::new());
    }

    let extensions_len = be16(take(data, cur, 2)?);
    let mut extensions = take(data, cur + 2, extensions_len)?;
    while !extensions.is_empty() {
        let header = take(extensions, 0, 4)?;
        let ext_type = be16(&header[..2]);
        let ext_len = be16(&header[2..]);
        let ext_body = take(extensions, 4, ext_len)?;
        if ext_type == 0x0000 {
            return parse_server_name(ext_body);
        }
        extensions = &extensions[4 + ext_len..];
    }

    Ok(String::new())
}
142
// Tests for the SNI-sniffing connection above.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a TLS handshake record holding a ClientHello whose extension list
    /// is one unrelated extension followed by a `server_name` naming `host`.
    fn client_hello(host: &[u8]) -> Vec<u8> {
        let entry_len = 3 + host.len();
        let mut exts: Vec<u8> = vec![0x00, 0x0a, 0x00, 0x02, 0x00, 0x00];
        exts.extend_from_slice(&[0x00, 0x00]);
        exts.extend_from_slice(&((2 + entry_len) as u16).to_be_bytes());
        exts.extend_from_slice(&(entry_len as u16).to_be_bytes());
        exts.push(0x00);
        exts.extend_from_slice(&(host.len() as u16).to_be_bytes());
        exts.extend_from_slice(host);

        let mut body: Vec<u8> = vec![0x03, 0x03];
        body.extend_from_slice(&[0x11; 32]);
        body.push(0x00);
        body.extend_from_slice(&[0x00, 0x02, 0x13, 0x01]);
        body.extend_from_slice(&[0x01, 0x00]);
        body.extend_from_slice(&(exts.len() as u16).to_be_bytes());
        body.extend_from_slice(&exts);

        let mut hs: Vec<u8> = vec![0x01];
        hs.extend_from_slice(&(body.len() as u32).to_be_bytes()[1..]);
        hs.extend_from_slice(&body);

        let mut rec: Vec<u8> = vec![0x16, 0x03, 0x01];
        rec.extend_from_slice(&(hs.len() as u16).to_be_bytes());
        rec.extend_from_slice(&hs);
        rec
    }

    #[test]
    fn shared_conn_extracts_sni_and_replays_handshake() {
        use std::io::{Read, Write};
        use std::net::{TcpListener, TcpStream};
        use std::thread;

        let hello = client_hello(b"www.example.com");
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();

        let payload = hello.clone();
        let client = thread::spawn(move || {
            let mut socket = TcpStream::connect(addr).unwrap();
            socket.write_all(&payload).unwrap();
        });

        let (stream, _) = listener.accept().unwrap();
        let mut conn = SharedConn::new(stream).unwrap();
        client.join().unwrap();

        assert_eq!(conn.get_sni(), "www.example.com");

        // The bytes consumed while sniffing must be handed back unchanged.
        let mut replayed = vec![0_u8; hello.len()];
        conn.read_exact(&mut replayed).unwrap();
        assert_eq!(replayed, hello);
    }

    // Manual end-to-end check against a real TLS client. It binds the
    // privileged port 443 and blocks until a connection arrives, so it is
    // ignored by default. Run it with `cargo test -- --ignored`, then from
    // another shell:
    //   curl -vv --resolve www.baidu.com:443:127.0.0.1 https://www.baidu.com
    #[test]
    #[ignore = "binds port 443 and waits for a manual curl connection"]
    fn test_parse_sni() {
        use std::net::TcpListener;
        let listener = TcpListener::bind("0.0.0.0:443").unwrap();
        let (stream, _) = listener.accept().unwrap();
        let tls_conn = SharedConn::new(stream).unwrap();
        let sni = tls_conn.get_sni();
        // The curl command above requests www.baidu.com.
        assert_eq!(sni, "www.baidu.com");
    }
}