Skip to main content

trojan_proto/
lib.rs

1//! Trojan protocol parsing and serialization.
2//!
3//! This module provides zero-copy parsers for trojan request headers and UDP packets.
4//! It is intentionally minimal and DRY: address parsing is shared by request and UDP paths.
5
6use bytes::BytesMut;
7
/// Length in bytes of the hex-encoded password hash that opens every request.
pub const HASH_LEN: usize = 56;
/// Two-byte delimiter that follows the hash, the request address, and the UDP length field.
pub const CRLF: &[u8; 2] = &[b'\r', b'\n'];

/// Command byte for CONNECT requests.
pub const CMD_CONNECT: u8 = 0x01;
/// Command byte for UDP ASSOCIATE requests.
pub const CMD_UDP_ASSOCIATE: u8 = 0x03;
/// Mux command for the trojan-go multiplexing extension.
///
/// Note: `parse_request` only accepts [`CMD_CONNECT`] and [`CMD_UDP_ASSOCIATE`] and
/// reports this command as invalid.
pub const CMD_MUX: u8 = 0x7f;

/// Maximum UDP payload size (8 KiB, consistent with trojan-go).
///
/// Not enforced by `parse_udp_packet` or `write_udp_packet`, which only bound the
/// payload by its 16-bit length field; callers that want this limit must apply it.
pub const MAX_UDP_PAYLOAD: usize = 8192;
/// Maximum domain name length; the wire format stores it in a single length byte.
pub const MAX_DOMAIN_LEN: usize = 255;

/// Address-type byte for a 4-octet IPv4 address.
pub const ATYP_IPV4: u8 = 0x01;
/// Address-type byte for a length-prefixed domain name.
pub const ATYP_DOMAIN: u8 = 0x03;
/// Address-type byte for a 16-octet IPv6 address.
pub const ATYP_IPV6: u8 = 0x04;
24
/// Protocol violations detected while parsing requests or UDP frames.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParseError {
    /// A CRLF delimiter was expected but other bytes were found.
    InvalidCrlf,
    /// The request command is not one the parser accepts.
    InvalidCommand,
    /// The address-type byte is not IPv4, domain, or IPv6.
    InvalidAtyp,
    /// A domain address declared a length of zero.
    InvalidDomainLen,
    /// A domain address is not valid UTF-8.
    InvalidUtf8,
    /// The hash contains a byte that is not an ASCII hex digit (`0-9`, `a-f`, `A-F`;
    /// both letter cases are accepted).
    InvalidHashFormat,
}

impl std::fmt::Display for ParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            Self::InvalidCrlf => "expected CRLF delimiter",
            Self::InvalidCommand => "unsupported command",
            Self::InvalidAtyp => "unsupported address type",
            Self::InvalidDomainLen => "invalid domain length",
            Self::InvalidUtf8 => "domain is not valid UTF-8",
            Self::InvalidHashFormat => "password hash contains non-hex characters",
        };
        f.write_str(msg)
    }
}

impl std::error::Error for ParseError {}
35
/// Errors that can occur when writing protocol data.
///
/// The writers in this module validate their inputs before appending anything, so
/// when one of these is returned the destination buffer is left unchanged.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WriteError {
    /// Payload exceeds maximum allowed size (65535 bytes for UDP).
    PayloadTooLarge,
    /// Domain name exceeds maximum length (255 bytes).
    DomainTooLong,
    /// Hash must be exactly 56 bytes.
    InvalidHashLen,
}

impl std::fmt::Display for WriteError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            Self::PayloadTooLarge => "payload exceeds 65535 bytes",
            Self::DomainTooLong => "domain exceeds 255 bytes",
            Self::InvalidHashLen => "hash must be exactly 56 bytes",
        };
        f.write_str(msg)
    }
}

impl std::error::Error for WriteError {}
46
47/// Parse result for incremental parsing.
48///
49/// - `Complete(T)` - parsing succeeded, contains the parsed value.
50/// - `Incomplete(n)` - buffer too small; `n` is the **minimum total bytes** needed
51///   (not the additional bytes needed). Caller should accumulate more data and retry.
52/// - `Invalid(e)` - protocol violation, connection should be rejected or redirected.
53#[derive(Debug, Clone, PartialEq, Eq)]
54pub enum ParseResult<T> {
55    Complete(T),
56    Incomplete(usize),
57    Invalid(ParseError),
58}
59
/// Destination host as it appears on the wire; domain bytes borrow from the input.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HostRef<'a> {
    /// IPv4 address octets, in the order they appear on the wire.
    Ipv4([u8; 4]),
    /// IPv6 address octets, in the order they appear on the wire.
    Ipv6([u8; 16]),
    /// Domain name bytes without their length prefix.
    ///
    /// The parsers only produce non-empty, valid UTF-8 slices here. Values built
    /// directly by callers are not checked until they are written.
    Domain(&'a [u8]),
}

/// A destination host together with its port.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AddressRef<'a> {
    /// Destination host.
    pub host: HostRef<'a>,
    /// Destination port as a native integer (big-endian on the wire).
    pub port: u16,
}
72
73#[derive(Debug, Clone, PartialEq, Eq)]
74pub struct TrojanRequest<'a> {
75    pub hash: &'a [u8],
76    pub command: u8,
77    pub address: AddressRef<'a>,
78    pub header_len: usize,
79    pub payload: &'a [u8],
80}
81
82#[derive(Debug, Clone, PartialEq, Eq)]
83pub struct UdpPacket<'a> {
84    pub address: AddressRef<'a>,
85    pub length: usize,
86    pub packet_len: usize,
87    pub payload: &'a [u8],
88}
89
90/// Validates that the hash is a valid hex string (a-f, A-F, 0-9).
91#[inline]
92pub fn is_valid_hash(hash: &[u8]) -> bool {
93    hash.len() == HASH_LEN && hash.iter().all(|&b| b.is_ascii_hexdigit())
94}
95
96#[inline]
97pub fn parse_request(buf: &[u8]) -> ParseResult<TrojanRequest<'_>> {
98    if buf.len() < HASH_LEN {
99        return ParseResult::Incomplete(HASH_LEN);
100    }
101
102    let hash = &buf[..HASH_LEN];
103    if !is_valid_hash(hash) {
104        return ParseResult::Invalid(ParseError::InvalidHashFormat);
105    }
106    let mut offset = HASH_LEN;
107
108    if let Some(res) = expect_crlf(buf, offset) {
109        return res;
110    }
111    offset += 2;
112
113    if buf.len() < offset + 2 {
114        return ParseResult::Incomplete(offset + 2);
115    }
116    let command = buf[offset];
117    if command != CMD_CONNECT && command != CMD_UDP_ASSOCIATE {
118        return ParseResult::Invalid(ParseError::InvalidCommand);
119    }
120    let atyp = buf[offset + 1];
121    offset += 2;
122
123    let addr_res = parse_address(atyp, &buf[offset..]);
124    let (address, addr_len) = match addr_res {
125        ParseResult::Complete(v) => v,
126        ParseResult::Incomplete(n) => return ParseResult::Incomplete(offset + n),
127        ParseResult::Invalid(e) => return ParseResult::Invalid(e),
128    };
129    offset += addr_len;
130
131    if let Some(res) = expect_crlf(buf, offset) {
132        return res;
133    }
134    offset += 2;
135
136    ParseResult::Complete(TrojanRequest {
137        hash,
138        command,
139        address,
140        header_len: offset,
141        payload: &buf[offset..],
142    })
143}
144
145#[inline]
146pub fn parse_udp_packet(buf: &[u8]) -> ParseResult<UdpPacket<'_>> {
147    if buf.is_empty() {
148        return ParseResult::Incomplete(1);
149    }
150    let atyp = buf[0];
151    let addr_res = parse_address(atyp, &buf[1..]);
152    let (address, addr_len) = match addr_res {
153        ParseResult::Complete(v) => v,
154        ParseResult::Incomplete(n) => return ParseResult::Incomplete(1 + n),
155        ParseResult::Invalid(e) => return ParseResult::Invalid(e),
156    };
157
158    let mut offset = 1 + addr_len;
159    if buf.len() < offset + 2 {
160        return ParseResult::Incomplete(offset + 2);
161    }
162    let length = read_u16(&buf[offset..offset + 2]) as usize;
163    if buf.len() < offset + 4 {
164        return ParseResult::Incomplete(offset + 4);
165    }
166    if &buf[offset + 2..offset + 4] != CRLF {
167        return ParseResult::Invalid(ParseError::InvalidCrlf);
168    }
169    offset += 4;
170    if buf.len() < offset + length {
171        return ParseResult::Incomplete(offset + length);
172    }
173
174    ParseResult::Complete(UdpPacket {
175        address,
176        length,
177        packet_len: offset + length,
178        payload: &buf[offset..offset + length],
179    })
180}
181
182/// Writes a Trojan request header to the buffer.
183///
184/// # Errors
185/// - `InvalidHashLen` if hash is not exactly 56 bytes.
186/// - `DomainTooLong` if address contains a domain longer than 255 bytes.
187#[allow(clippy::cast_possible_truncation)]
188pub fn write_request_header(
189    buf: &mut BytesMut,
190    hash_hex: &[u8],
191    command: u8,
192    address: &AddressRef<'_>,
193) -> Result<(), WriteError> {
194    if hash_hex.len() != HASH_LEN {
195        return Err(WriteError::InvalidHashLen);
196    }
197    if let HostRef::Domain(d) = &address.host
198        && d.len() > MAX_DOMAIN_LEN
199    {
200        return Err(WriteError::DomainTooLong);
201    }
202    buf.extend_from_slice(hash_hex);
203    buf.extend_from_slice(CRLF);
204    buf.extend_from_slice(&[command, address_atyp(address)]);
205    write_address_unchecked(buf, address);
206    buf.extend_from_slice(CRLF);
207    Ok(())
208}
209
210/// Writes a UDP packet to the buffer.
211///
212/// # Errors
213/// - `PayloadTooLarge` if payload exceeds 65535 bytes.
214/// - `DomainTooLong` if address contains a domain longer than 255 bytes.
215#[allow(clippy::cast_possible_truncation)]
216pub fn write_udp_packet(
217    buf: &mut BytesMut,
218    address: &AddressRef<'_>,
219    payload: &[u8],
220) -> Result<(), WriteError> {
221    if payload.len() > u16::MAX as usize {
222        return Err(WriteError::PayloadTooLarge);
223    }
224    if let HostRef::Domain(d) = &address.host
225        && d.len() > MAX_DOMAIN_LEN
226    {
227        return Err(WriteError::DomainTooLong);
228    }
229    buf.extend_from_slice(&[address_atyp(address)]);
230    write_address_unchecked(buf, address);
231    buf.extend_from_slice(&(payload.len() as u16).to_be_bytes());
232    buf.extend_from_slice(CRLF);
233    buf.extend_from_slice(payload);
234    Ok(())
235}
236
237#[inline]
238fn expect_crlf<T>(buf: &[u8], offset: usize) -> Option<ParseResult<T>> {
239    if buf.len() < offset + 2 {
240        return Some(ParseResult::Incomplete(offset + 2));
241    }
242    if &buf[offset..offset + 2] != CRLF {
243        return Some(ParseResult::Invalid(ParseError::InvalidCrlf));
244    }
245    None
246}
247
248#[inline]
249fn parse_address<'a>(atyp: u8, buf: &'a [u8]) -> ParseResult<(AddressRef<'a>, usize)> {
250    match atyp {
251        ATYP_IPV4 => {
252            if buf.len() < 6 {
253                return ParseResult::Incomplete(6);
254            }
255            let host = HostRef::Ipv4([buf[0], buf[1], buf[2], buf[3]]);
256            let port = read_u16(&buf[4..6]);
257            ParseResult::Complete((AddressRef { host, port }, 6))
258        }
259        ATYP_DOMAIN => {
260            if buf.is_empty() {
261                return ParseResult::Incomplete(1);
262            }
263            let len = buf[0] as usize;
264            if len == 0 {
265                return ParseResult::Invalid(ParseError::InvalidDomainLen);
266            }
267            let need = 1 + len + 2;
268            if buf.len() < need {
269                return ParseResult::Incomplete(need);
270            }
271            let domain = &buf[1..1 + len];
272            if std::str::from_utf8(domain).is_err() {
273                return ParseResult::Invalid(ParseError::InvalidUtf8);
274            }
275            let port = read_u16(&buf[1 + len..1 + len + 2]);
276            ParseResult::Complete((
277                AddressRef {
278                    host: HostRef::Domain(domain),
279                    port,
280                },
281                need,
282            ))
283        }
284        ATYP_IPV6 => {
285            if buf.len() < 18 {
286                return ParseResult::Incomplete(18);
287            }
288            let mut ip = [0u8; 16];
289            ip.copy_from_slice(&buf[0..16]);
290            let port = read_u16(&buf[16..18]);
291            ParseResult::Complete((
292                AddressRef {
293                    host: HostRef::Ipv6(ip),
294                    port,
295                },
296                18,
297            ))
298        }
299        _ => ParseResult::Invalid(ParseError::InvalidAtyp),
300    }
301}
302
303/// Writes address without validation. Caller must ensure domain length <= 255.
304#[allow(clippy::cast_possible_truncation)]
305fn write_address_unchecked(buf: &mut BytesMut, address: &AddressRef<'_>) {
306    match address.host {
307        HostRef::Ipv4(ip) => {
308            buf.extend_from_slice(&ip);
309        }
310        HostRef::Ipv6(ip) => {
311            buf.extend_from_slice(&ip);
312        }
313        HostRef::Domain(domain) => {
314            debug_assert!(domain.len() <= MAX_DOMAIN_LEN);
315            buf.extend_from_slice(&[domain.len() as u8]);
316            buf.extend_from_slice(domain);
317        }
318    }
319    buf.extend_from_slice(&address.port.to_be_bytes());
320}
321
322#[inline]
323fn address_atyp(address: &AddressRef<'_>) -> u8 {
324    match address.host {
325        HostRef::Ipv4(_) => ATYP_IPV4,
326        HostRef::Ipv6(_) => ATYP_IPV6,
327        HostRef::Domain(_) => ATYP_DOMAIN,
328    }
329}
330
/// Decodes a big-endian `u16` from the first two bytes of `buf`.
///
/// Panics if `buf` holds fewer than two bytes. Debug builds report this through the
/// assertion message; release builds fail on the slice index.
#[inline]
fn read_u16(buf: &[u8]) -> u16 {
    debug_assert!(buf.len() >= 2, "read_u16 requires at least 2 bytes");
    let (hi, lo) = (buf[0], buf[1]);
    (u16::from(hi) << 8) | u16::from(lo)
}
336
#[cfg(test)]
mod tests {
    use super::*;

    fn sample_hash() -> [u8; HASH_LEN] {
        [b'a'; HASH_LEN]
    }

    /// Serializes a request header for `addr` using `sample_hash()`.
    fn request_bytes(command: u8, addr: &AddressRef<'_>) -> BytesMut {
        let mut buf = BytesMut::new();
        write_request_header(&mut buf, &sample_hash(), command, addr).unwrap();
        buf
    }

    #[test]
    fn test_is_valid_hash() {
        // Valid lowercase hex
        assert!(is_valid_hash(&[b'a'; HASH_LEN]));
        assert!(is_valid_hash(
            b"0123456789abcdef0123456789abcdef0123456789abcdef01234567"
        ));

        // Uppercase should be accepted
        assert!(is_valid_hash(
            b"0123456789ABCDEF0123456789abcdef0123456789abcdef01234567"
        ));

        // Invalid: wrong length
        assert!(!is_valid_hash(&[b'a'; HASH_LEN - 1]));
        assert!(!is_valid_hash(&[b'a'; HASH_LEN + 1]));

        // Invalid: non-hex characters
        let mut invalid = [b'a'; HASH_LEN];
        invalid[0] = b'g';
        assert!(!is_valid_hash(&invalid));
    }

    #[test]
    fn parse_request_connect_ipv4() {
        let addr = AddressRef {
            host: HostRef::Ipv4([1, 2, 3, 4]),
            port: 443,
        };
        let mut buf = BytesMut::new();
        write_request_header(&mut buf, &sample_hash(), CMD_CONNECT, &addr).unwrap();
        buf.extend_from_slice(b"hello");

        let res = parse_request(&buf);
        match res {
            ParseResult::Complete(req) => {
                assert_eq!(req.command, CMD_CONNECT);
                assert_eq!(req.address, addr);
                assert_eq!(req.payload, b"hello");
            }
            _ => panic!("unexpected parse result: {:?}", res),
        }
    }

    #[test]
    fn parse_request_invalid_hash() {
        let addr = AddressRef {
            host: HostRef::Ipv4([1, 2, 3, 4]),
            port: 443,
        };
        let mut buf = BytesMut::new();
        // Use non-hex which is invalid
        let mut invalid_hash = [b'a'; HASH_LEN];
        invalid_hash[0] = b'g';
        write_request_header(&mut buf, &invalid_hash, CMD_CONNECT, &addr).unwrap();

        let res = parse_request(&buf);
        assert_eq!(res, ParseResult::Invalid(ParseError::InvalidHashFormat));
    }

    #[test]
    fn parse_request_incomplete() {
        let data = vec![b'a'; HASH_LEN - 1];
        assert_eq!(parse_request(&data), ParseResult::Incomplete(HASH_LEN));
    }

    #[test]
    fn parse_request_roundtrip_domain_and_ipv6() {
        let cases = [
            (
                AddressRef {
                    host: HostRef::Domain(b"example.com"),
                    port: 8080,
                },
                CMD_CONNECT,
            ),
            (
                AddressRef {
                    host: HostRef::Ipv6([
                        0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
                    ]),
                    port: 53,
                },
                CMD_UDP_ASSOCIATE,
            ),
        ];
        for (addr, command) in cases {
            let mut buf = request_bytes(command, &addr);
            let header_len = buf.len();
            buf.extend_from_slice(b"data");
            match parse_request(&buf) {
                ParseResult::Complete(req) => {
                    assert_eq!(req.hash, &sample_hash()[..]);
                    assert_eq!(req.command, command);
                    assert_eq!(req.address, addr);
                    assert_eq!(req.header_len, header_len);
                    assert_eq!(req.payload, b"data");
                }
                other => panic!("unexpected parse result: {other:?}"),
            }
        }
    }

    #[test]
    fn parse_request_accepts_max_len_domain() {
        let domain = vec![b'a'; MAX_DOMAIN_LEN];
        let addr = AddressRef {
            host: HostRef::Domain(&domain),
            port: 1,
        };
        let buf = request_bytes(CMD_CONNECT, &addr);
        match parse_request(&buf) {
            ParseResult::Complete(req) => {
                assert_eq!(req.address, addr);
                assert_eq!(req.header_len, buf.len());
                assert!(req.payload.is_empty());
            }
            other => panic!("unexpected parse result: {other:?}"),
        }
    }

    #[test]
    fn parse_request_incomplete_reports_total_len() {
        let addr = AddressRef {
            host: HostRef::Domain(b"example.com"),
            port: 443,
        };
        let buf = request_bytes(CMD_CONNECT, &addr);
        let full = buf.len();
        // Every strict prefix of a valid header must ask for more bytes, and the
        // requested total must never overshoot the real header length.
        for cut in 0..full {
            match parse_request(&buf[..cut]) {
                ParseResult::Incomplete(n) => {
                    assert!(n > cut && n <= full, "cut={cut}: Incomplete({n}) out of range");
                }
                other => panic!("cut={cut}: unexpected parse result: {other:?}"),
            }
        }
        assert!(matches!(parse_request(&buf), ParseResult::Complete(_)));
    }

    #[test]
    fn parse_request_rejects_invalid_fields() {
        let addr = AddressRef {
            host: HostRef::Ipv4([1, 2, 3, 4]),
            port: 443,
        };
        let good = request_bytes(CMD_CONNECT, &addr).to_vec();

        // (byte index to corrupt, replacement byte, expected error)
        let cases = [
            (HASH_LEN, b'x', ParseError::InvalidCrlf),
            (HASH_LEN + 2, 0x02, ParseError::InvalidCommand),
            (HASH_LEN + 3, 0x02, ParseError::InvalidAtyp),
            (good.len() - 1, b'x', ParseError::InvalidCrlf),
        ];
        for (index, byte, expected) in cases {
            let mut bad = good.clone();
            bad[index] = byte;
            assert_eq!(
                parse_request(&bad),
                ParseResult::Invalid(expected),
                "index {index}"
            );
        }
    }

    #[test]
    fn parse_request_rejects_empty_and_non_utf8_domain() {
        let mut empty = sample_hash().to_vec();
        empty.extend_from_slice(CRLF);
        empty.extend_from_slice(&[CMD_CONNECT, ATYP_DOMAIN, 0]);
        assert_eq!(
            parse_request(&empty),
            ParseResult::Invalid(ParseError::InvalidDomainLen)
        );

        let mut non_utf8 = sample_hash().to_vec();
        non_utf8.extend_from_slice(CRLF);
        non_utf8.extend_from_slice(&[CMD_CONNECT, ATYP_DOMAIN, 2, 0xff, 0xfe, 0x01, 0xbb]);
        non_utf8.extend_from_slice(CRLF);
        assert_eq!(
            parse_request(&non_utf8),
            ParseResult::Invalid(ParseError::InvalidUtf8)
        );
    }

    #[test]
    fn parse_udp_packet_ipv4() {
        let addr = AddressRef {
            host: HostRef::Ipv4([8, 8, 8, 8]),
            port: 53,
        };
        let mut buf = BytesMut::new();
        write_udp_packet(&mut buf, &addr, b"ping").unwrap();
        let res = parse_udp_packet(&buf);
        match res {
            ParseResult::Complete(pkt) => {
                assert_eq!(pkt.address, addr);
                assert_eq!(pkt.payload, b"ping");
            }
            _ => panic!("unexpected parse result: {:?}", res),
        }
    }

    #[test]
    fn parse_udp_packet_roundtrip_domain_and_ipv6() {
        let addrs = [
            AddressRef {
                host: HostRef::Domain(b"dns.example"),
                port: 53,
            },
            AddressRef {
                host: HostRef::Ipv6([0; 16]),
                port: 5353,
            },
        ];
        for addr in addrs {
            let mut buf = BytesMut::new();
            write_udp_packet(&mut buf, &addr, b"query").unwrap();
            let frame_len = buf.len();
            // Trailing bytes belong to the next frame and must be ignored.
            buf.extend_from_slice(b"next");
            match parse_udp_packet(&buf) {
                ParseResult::Complete(pkt) => {
                    assert_eq!(pkt.address, addr);
                    assert_eq!(pkt.length, 5);
                    assert_eq!(pkt.packet_len, frame_len);
                    assert_eq!(pkt.payload, b"query");
                }
                other => panic!("unexpected parse result: {other:?}"),
            }
        }
    }

    #[test]
    fn parse_udp_packet_empty_payload() {
        let addr = AddressRef {
            host: HostRef::Ipv4([8, 8, 8, 8]),
            port: 53,
        };
        let mut buf = BytesMut::new();
        write_udp_packet(&mut buf, &addr, b"").unwrap();
        match parse_udp_packet(&buf) {
            ParseResult::Complete(pkt) => {
                assert_eq!(pkt.length, 0);
                assert!(pkt.payload.is_empty());
                assert_eq!(pkt.packet_len, buf.len());
            }
            other => panic!("unexpected parse result: {other:?}"),
        }
    }

    #[test]
    fn parse_udp_packet_incomplete_reports_total_len() {
        let addr = AddressRef {
            host: HostRef::Ipv4([8, 8, 8, 8]),
            port: 53,
        };
        let mut buf = BytesMut::new();
        write_udp_packet(&mut buf, &addr, b"ping").unwrap();
        let full = buf.len();
        for cut in 0..full {
            match parse_udp_packet(&buf[..cut]) {
                ParseResult::Incomplete(n) => {
                    assert!(n > cut && n <= full, "cut={cut}: Incomplete({n}) out of range");
                }
                other => panic!("cut={cut}: unexpected parse result: {other:?}"),
            }
        }
    }

    #[test]
    fn parse_udp_packet_rejects_invalid_fields() {
        let addr = AddressRef {
            host: HostRef::Ipv4([8, 8, 8, 8]),
            port: 53,
        };
        let mut buf = BytesMut::new();
        write_udp_packet(&mut buf, &addr, b"ping").unwrap();
        let mut bytes = buf.to_vec();
        // ATYP(1) + IPv4(4) + port(2) + length(2) precede the CRLF.
        bytes[1 + 4 + 2 + 2] = b'x';
        assert_eq!(
            parse_udp_packet(&bytes),
            ParseResult::Invalid(ParseError::InvalidCrlf)
        );

        assert_eq!(
            parse_udp_packet(&[0x02, 0, 0]),
            ParseResult::Invalid(ParseError::InvalidAtyp)
        );
    }

    #[test]
    fn write_udp_packet_payload_too_large() {
        let addr = AddressRef {
            host: HostRef::Ipv4([8, 8, 8, 8]),
            port: 53,
        };
        let mut buf = BytesMut::new();
        let large_payload = vec![0u8; u16::MAX as usize + 1];
        let res = write_udp_packet(&mut buf, &addr, &large_payload);
        assert_eq!(res, Err(WriteError::PayloadTooLarge));
    }

    #[test]
    fn write_request_header_domain_too_long() {
        let long_domain = vec![b'a'; 256];
        let addr = AddressRef {
            host: HostRef::Domain(&long_domain),
            port: 443,
        };
        let mut buf = BytesMut::new();
        let res = write_request_header(&mut buf, &sample_hash(), CMD_CONNECT, &addr);
        assert_eq!(res, Err(WriteError::DomainTooLong));
    }

    #[test]
    fn write_request_header_invalid_hash_len() {
        let addr = AddressRef {
            host: HostRef::Ipv4([1, 2, 3, 4]),
            port: 443,
        };
        let mut buf = BytesMut::new();
        let short_hash = [b'a'; HASH_LEN - 1];
        let res = write_request_header(&mut buf, &short_hash, CMD_CONNECT, &addr);
        assert_eq!(res, Err(WriteError::InvalidHashLen));
    }

    #[test]
    fn write_errors_leave_buffer_untouched() {
        let long_domain = vec![b'a'; MAX_DOMAIN_LEN + 1];
        let long_addr = AddressRef {
            host: HostRef::Domain(&long_domain),
            port: 53,
        };
        let ok_addr = AddressRef {
            host: HostRef::Ipv4([1, 1, 1, 1]),
            port: 53,
        };
        let mut buf = BytesMut::new();
        assert_eq!(
            write_udp_packet(&mut buf, &long_addr, b"x"),
            Err(WriteError::DomainTooLong)
        );
        assert_eq!(
            write_request_header(&mut buf, &sample_hash(), CMD_CONNECT, &long_addr),
            Err(WriteError::DomainTooLong)
        );
        assert_eq!(
            write_request_header(&mut buf, &[b'a'; HASH_LEN + 1], CMD_CONNECT, &ok_addr),
            Err(WriteError::InvalidHashLen)
        );
        assert!(buf.is_empty());
    }
}