//! Trojan request framing (`borer_core/proto/trojan.rs`).

use std::io::ErrorKind;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};

use anyhow::{Context, anyhow};
use bytes::{BufMut, BytesMut};
use socks5_proto::Address;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};

use crate::{CRLF, stream::peekable::AsyncPeek};

use super::padding::Padding;
11
12/// Trojan request
13///
14/// ```plain
15/// +-----------------------+---------+----------------+---------+----------+
16/// | hex(SHA224(password)) |  CRLF   | Trojan Request |  CRLF   | Payload  |
17/// +-----------------------+---------+----------------+---------+----------+
18/// |          56           | X'0D0A' |    Variable    | X'0D0A' | Variable |
19/// +-----------------------+---------+----------------+---------+----------+
20///
21/// where Trojan Request is a SOCKS5-like request:
22///
23/// +-----+------+----------+----------+
24/// | CMD | ATYP | DST.ADDR | DST.PORT |
25/// +-----+------+----------+----------+
26/// |  1  |  1   | Variable |    2     |
27/// +-----+------+----------+----------+
28///
29/// ```
30#[derive(Clone, Debug)]
31/// Parsed Trojan request header and destination metadata.
32pub struct Request {
33    pub hash: String,
34    pub command: Command,
35    pub address: Address,
36}
37
38impl Request {
39    const ATYP_IPV4: u8 = 0x01;
40    const ATYP_FQDN: u8 = 0x03;
41    const ATYP_IPV6: u8 = 0x04;
42
43    /// Create a Trojan request value from its already-decoded parts.
44    pub const fn new(hash: String, command: Command, address: Address) -> Self {
45        Self {
46            hash,
47            command,
48            address,
49        }
50    }
51
52    /// Peek the Trojan hash prefix without consuming the underlying stream.
53    pub async fn peek_head<T>(r: &mut T) -> anyhow::Result<Vec<u8>>
54    where
55        T: AsyncRead + AsyncPeek + Unpin,
56    {
57        let mut buf = Vec::new();
58        for _i in 0..56 {
59            let b1 = r.peek_u8().await.context("trojan peek u8 failed")?;
60            if b1 == b'\r' {
61                let b2 = r.peek_u8().await.context("trojan peek u8 failed")?;
62                if b2 == b'\n' {
63                    buf.push(b1);
64                    buf.push(b2);
65                    break;
66                }
67            } else {
68                buf.push(b1);
69            }
70        }
71
72        Ok(buf)
73    }
74
75    /// Read a full Trojan request from the stream.
76    pub async fn read_from<R>(r: &mut R) -> anyhow::Result<Self>
77    where
78        R: AsyncRead + Unpin,
79    {
80        let mut buf: [u8; 56] = [0; 56];
81        let len = r
82            .read(&mut buf[..])
83            .await
84            .context("trojan read hash failed")?;
85        if len != 56 {
86            return Err(anyhow!("the Request not Trojan"));
87        }
88
89        let hash = String::from_utf8_lossy(&buf[..]).to_string();
90
91        let _crlf = r.read_u16().await?;
92
93        let (cmd, addr) = Self::read_address_from(r)
94            .await
95            .context("trojan read Address failed")?;
96
97        let _crlf = r.read_u16().await?;
98
99        if let Command::Padding = cmd {
100            let _padding = Padding::read_from(r)
101                .await
102                .context("trojan read padding failed")?;
103        }
104
105        Ok(Self::new(hash, cmd, addr))
106    }
107
108    /// Read the Trojan command and destination address from the stream.
109    pub async fn read_address_from<R>(r: &mut R) -> anyhow::Result<(Command, Address)>
110    where
111        R: AsyncRead + Unpin,
112    {
113        let cmd = r.read_u8().await.context("address read cmd failed")?;
114        let cmd = Command::try_from(cmd).map_err(|cmd| anyhow!("Unknown cmd {cmd}"))?;
115
116        let atyp = r.read_u8().await.context("address read atyp failed")?;
117
118        match atyp {
119            Self::ATYP_IPV4 => {
120                let mut buf = [0; 6];
121                r.read_exact(&mut buf)
122                    .await
123                    .context("address read ipv4 failed")?;
124
125                let addr = Ipv4Addr::new(buf[0], buf[1], buf[2], buf[3]);
126
127                let port = u16::from_be_bytes([buf[4], buf[5]]);
128
129                let addr = Address::SocketAddress(SocketAddr::from((addr, port)));
130                Ok((cmd, addr))
131            }
132            Self::ATYP_FQDN => {
133                let len = r.read_u8().await? as usize;
134
135                let mut buf = vec![0; len + 2];
136                r.read_exact(&mut buf)
137                    .await
138                    .context("address read domain failed")?;
139
140                let port = u16::from_be_bytes([buf[len], buf[len + 1]]);
141                buf.truncate(len);
142
143                let addr = Address::DomainAddress(buf, port);
144                Ok((cmd, addr))
145            }
146            Self::ATYP_IPV6 => {
147                let mut buf = [0; 18];
148                r.read_exact(&mut buf)
149                    .await
150                    .context("address read ipv6 failed")?;
151
152                let addr = Ipv6Addr::new(
153                    u16::from_be_bytes([buf[0], buf[1]]),
154                    u16::from_be_bytes([buf[2], buf[3]]),
155                    u16::from_be_bytes([buf[4], buf[5]]),
156                    u16::from_be_bytes([buf[6], buf[7]]),
157                    u16::from_be_bytes([buf[8], buf[9]]),
158                    u16::from_be_bytes([buf[10], buf[11]]),
159                    u16::from_be_bytes([buf[12], buf[13]]),
160                    u16::from_be_bytes([buf[14], buf[15]]),
161                );
162
163                let port = u16::from_be_bytes([buf[16], buf[17]]);
164
165                let addr = Address::SocketAddress(SocketAddr::from((addr, port)));
166                Ok((cmd, addr))
167            }
168            atyp => Err(anyhow!("invalid type {atyp}")),
169        }
170    }
171
172    /// Serialize this request to the stream.
173    pub async fn write_to<W>(&self, w: &mut W) -> anyhow::Result<()>
174    where
175        W: AsyncWrite + Unpin,
176    {
177        let mut buf = BytesMut::with_capacity(self.serialized_len());
178        self.write_to_buf(&mut buf);
179        w.write_all(&buf).await.context("trojan Write buf failed")?;
180
181        Ok(())
182    }
183
184    /// Serialize this request into an existing byte buffer.
185    pub fn write_to_buf<B: BufMut>(&self, buf: &mut B) {
186        buf.put_slice(self.hash.as_bytes());
187        buf.put_slice(&CRLF);
188        buf.put_u8(u8::from(self.command));
189        self.write_to_buf_address(buf);
190        buf.put_slice(&CRLF);
191        if self.is_padding() {
192            Padding::default().write_to_buf(buf)
193        }
194    }
195    pub fn write_to_buf_address<B: BufMut>(&self, buf: &mut B) {
196        match &self.address {
197            Address::SocketAddress(SocketAddr::V4(addr)) => {
198                buf.put_u8(Self::ATYP_IPV4);
199                buf.put_slice(&addr.ip().octets());
200                buf.put_u16(addr.port());
201            }
202            Address::SocketAddress(SocketAddr::V6(addr)) => {
203                buf.put_u8(Self::ATYP_IPV6);
204                for seg in addr.ip().segments() {
205                    buf.put_u16(seg);
206                }
207                buf.put_u16(addr.port());
208            }
209            Address::DomainAddress(addr, port) => {
210                buf.put_u8(Self::ATYP_FQDN);
211                buf.put_u8(addr.len() as u8);
212                buf.put_slice(addr);
213                buf.put_u16(*port);
214            }
215        }
216    }
217
218    /// Return the serialized length of the Trojan request header.
219    pub fn serialized_len(&self) -> usize {
220        56 + 2 + 1 + self.address.serialized_len() + 2
221    }
222
223    /// Whether this request uses the Trojan padding command.
224    pub fn is_padding(&self) -> bool {
225        Command::Padding == self.command
226    }
227}
228
/// Trojan command code (the `CMD` byte of the SOCKS5-like request).
///
/// Wire values: `Connect` = `0x01`, `Bind` = `0x02`, `Associate` = `0x03`,
/// `Padding` = `0x04`.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum Command {
    Connect,
    Bind,
    Associate,
    Padding,
}

impl Command {
    const CONNECT: u8 = 0x01;
    const BIND: u8 = 0x02;
    const ASSOCIATE: u8 = 0x03;
    const PADDING: u8 = 0x04;
}

impl TryFrom<u8> for Command {
    /// The unrecognised byte is handed back unchanged.
    type Error = u8;

    /// Decode a wire byte; this is the inverse of `u8::from(Command)`.
    fn try_from(code: u8) -> Result<Self, Self::Error> {
        // Must list every variant of the enum.
        const ALL: [Command; 4] = [
            Command::Connect,
            Command::Bind,
            Command::Associate,
            Command::Padding,
        ];

        ALL.iter()
            .copied()
            .find(|&candidate| u8::from(candidate) == code)
            .ok_or(code)
    }
}

impl From<Command> for u8 {
    /// Encode a command as its wire byte.
    fn from(cmd: Command) -> Self {
        match cmd {
            Command::Connect => Command::CONNECT,
            Command::Bind => Command::BIND,
            Command::Associate => Command::ASSOCIATE,
            Command::Padding => Command::PADDING,
        }
    }
}
270
#[cfg(test)]
mod tests {
    use std::{
        io::Cursor,
        net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
    };

    use bytes::BytesMut;
    use socks5_proto::Address;
    use tokio::io::AsyncReadExt;

    use super::{Command, Request};
    use crate::stream::peekable::{AsyncPeek, PeekableStream};

    /// A 56-character stand-in for the hex password hash.
    fn test_hash() -> String {
        "a".repeat(56)
    }

    /// Serialize `request` and wrap the bytes in a readable cursor.
    fn encoded(request: &Request) -> Cursor<Vec<u8>> {
        let mut wire = BytesMut::new();
        request.write_to_buf(&mut wire);
        Cursor::new(wire.to_vec())
    }

    #[tokio::test]
    async fn peek_head_reads_hash_prefix_without_consuming_stream() {
        let mut wire = test_hash().into_bytes();
        wire.extend_from_slice(b"\r\nrest");
        let mut stream = PeekableStream::new(Cursor::new(wire));

        let head = Request::peek_head(&mut stream).await.unwrap();

        // Whatever was peeked must come back from `drain`; the remainder is
        // still readable from the stream afterwards.
        let drained = stream.drain().unwrap();
        let mut replay = Vec::new();
        stream.read_to_end(&mut replay).await.unwrap();

        assert_eq!(head, test_hash().into_bytes());
        assert_eq!(drained, head);
        assert_eq!(replay, b"\r\nrest");
    }

    #[tokio::test]
    async fn read_address_from_parses_ipv4_domain_and_ipv6() {
        // CMD = Connect, ATYP = IPv4, 127.0.0.1, port 443.
        let mut ipv4 = Cursor::new(vec![1, 1, 127, 0, 0, 1, 0x01, 0xbb]);
        let (cmd4, addr4) = Request::read_address_from(&mut ipv4).await.unwrap();

        // CMD = Padding, ATYP = FQDN, length 11, "example.com", port 80.
        let mut domain = Cursor::new(vec![
            4, 3, 11, b'e', b'x', b'a', b'm', b'p', b'l', b'e', b'.', b'c', b'o', b'm', 0, 80,
        ]);
        let (cmdd, addrd) = Request::read_address_from(&mut domain).await.unwrap();

        // CMD = Bind, ATYP = IPv6, 2001:db8::1, port 443.
        let mut ipv6 = Cursor::new(vec![
            2, 4, 0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0x01, 0xbb,
        ]);
        let (cmd6, addr6) = Request::read_address_from(&mut ipv6).await.unwrap();

        assert_eq!(cmd4, Command::Connect);
        assert_eq!(
            addr4,
            Address::SocketAddress(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443))
        );

        assert_eq!(cmdd, Command::Padding);
        assert_eq!(addrd, Address::DomainAddress(b"example.com".to_vec(), 80));

        assert_eq!(cmd6, Command::Bind);
        assert_eq!(
            addr6,
            Address::SocketAddress(SocketAddr::new(
                IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)),
                443
            ))
        );
    }

    #[tokio::test]
    async fn read_from_round_trips_non_padding_request() {
        let request = Request::new(
            test_hash(),
            Command::Associate,
            Address::DomainAddress(b"example.com".to_vec(), 8080),
        );

        let parsed = Request::read_from(&mut encoded(&request)).await.unwrap();

        assert_eq!(parsed.hash, request.hash);
        assert_eq!(parsed.command, request.command);
        assert_eq!(parsed.address, request.address);
    }

    #[tokio::test]
    async fn read_from_accepts_padding_request() {
        let request = Request::new(
            test_hash(),
            Command::Padding,
            Address::DomainAddress(b"example.com".to_vec(), 443),
        );

        let parsed = Request::read_from(&mut encoded(&request)).await.unwrap();

        assert_eq!(parsed.hash, request.hash);
        assert_eq!(parsed.command, Command::Padding);
        assert_eq!(parsed.address, request.address);
    }

    #[tokio::test]
    async fn read_from_rejects_short_hash_prefix() {
        let mut truncated = Cursor::new(vec![b'a'; 10]);

        let err = Request::read_from(&mut truncated).await.unwrap_err();

        assert!(err.to_string().contains("not Trojan"));
    }
}