1use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
2
3use anyhow::{Context, anyhow};
4use bytes::{BufMut, BytesMut};
5use socks5_proto::Address;
6use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
7
8use crate::{CRLF, stream::peekable::AsyncPeek};
9
10use super::padding::Padding;
11
12#[derive(Clone, Debug)]
31pub struct Request {
33 pub hash: String,
34 pub command: Command,
35 pub address: Address,
36}
37
38impl Request {
39 const ATYP_IPV4: u8 = 0x01;
40 const ATYP_FQDN: u8 = 0x03;
41 const ATYP_IPV6: u8 = 0x04;
42
43 pub const fn new(hash: String, command: Command, address: Address) -> Self {
45 Self {
46 hash,
47 command,
48 address,
49 }
50 }
51
52 pub async fn peek_head<T>(r: &mut T) -> anyhow::Result<Vec<u8>>
54 where
55 T: AsyncRead + AsyncPeek + Unpin,
56 {
57 let mut buf = Vec::new();
58 for _i in 0..56 {
59 let b1 = r.peek_u8().await.context("trojan peek u8 failed")?;
60 if b1 == b'\r' {
61 let b2 = r.peek_u8().await.context("trojan peek u8 failed")?;
62 if b2 == b'\n' {
63 buf.push(b1);
64 buf.push(b2);
65 break;
66 }
67 } else {
68 buf.push(b1);
69 }
70 }
71
72 Ok(buf)
73 }
74
75 pub async fn read_from<R>(r: &mut R) -> anyhow::Result<Self>
77 where
78 R: AsyncRead + Unpin,
79 {
80 let mut buf: [u8; 56] = [0; 56];
81 let len = r
82 .read(&mut buf[..])
83 .await
84 .context("trojan read hash failed")?;
85 if len != 56 {
86 return Err(anyhow!("the Request not Trojan"));
87 }
88
89 let hash = String::from_utf8_lossy(&buf[..]).to_string();
90
91 let _crlf = r.read_u16().await?;
92
93 let (cmd, addr) = Self::read_address_from(r)
94 .await
95 .context("trojan read Address failed")?;
96
97 let _crlf = r.read_u16().await?;
98
99 if let Command::Padding = cmd {
100 let _padding = Padding::read_from(r)
101 .await
102 .context("trojan read padding failed")?;
103 }
104
105 Ok(Self::new(hash, cmd, addr))
106 }
107
108 pub async fn read_address_from<R>(r: &mut R) -> anyhow::Result<(Command, Address)>
110 where
111 R: AsyncRead + Unpin,
112 {
113 let cmd = r.read_u8().await.context("address read cmd failed")?;
114 let cmd = Command::try_from(cmd).map_err(|cmd| anyhow!("Unknown cmd {cmd}"))?;
115
116 let atyp = r.read_u8().await.context("address read atyp failed")?;
117
118 match atyp {
119 Self::ATYP_IPV4 => {
120 let mut buf = [0; 6];
121 r.read_exact(&mut buf)
122 .await
123 .context("address read ipv4 failed")?;
124
125 let addr = Ipv4Addr::new(buf[0], buf[1], buf[2], buf[3]);
126
127 let port = u16::from_be_bytes([buf[4], buf[5]]);
128
129 let addr = Address::SocketAddress(SocketAddr::from((addr, port)));
130 Ok((cmd, addr))
131 }
132 Self::ATYP_FQDN => {
133 let len = r.read_u8().await? as usize;
134
135 let mut buf = vec![0; len + 2];
136 r.read_exact(&mut buf)
137 .await
138 .context("address read domain failed")?;
139
140 let port = u16::from_be_bytes([buf[len], buf[len + 1]]);
141 buf.truncate(len);
142
143 let addr = Address::DomainAddress(buf, port);
144 Ok((cmd, addr))
145 }
146 Self::ATYP_IPV6 => {
147 let mut buf = [0; 18];
148 r.read_exact(&mut buf)
149 .await
150 .context("address read ipv6 failed")?;
151
152 let addr = Ipv6Addr::new(
153 u16::from_be_bytes([buf[0], buf[1]]),
154 u16::from_be_bytes([buf[2], buf[3]]),
155 u16::from_be_bytes([buf[4], buf[5]]),
156 u16::from_be_bytes([buf[6], buf[7]]),
157 u16::from_be_bytes([buf[8], buf[9]]),
158 u16::from_be_bytes([buf[10], buf[11]]),
159 u16::from_be_bytes([buf[12], buf[13]]),
160 u16::from_be_bytes([buf[14], buf[15]]),
161 );
162
163 let port = u16::from_be_bytes([buf[16], buf[17]]);
164
165 let addr = Address::SocketAddress(SocketAddr::from((addr, port)));
166 Ok((cmd, addr))
167 }
168 atyp => Err(anyhow!("invalid type {atyp}")),
169 }
170 }
171
172 pub async fn write_to<W>(&self, w: &mut W) -> anyhow::Result<()>
174 where
175 W: AsyncWrite + Unpin,
176 {
177 let mut buf = BytesMut::with_capacity(self.serialized_len());
178 self.write_to_buf(&mut buf);
179 w.write_all(&buf).await.context("trojan Write buf failed")?;
180
181 Ok(())
182 }
183
184 pub fn write_to_buf<B: BufMut>(&self, buf: &mut B) {
186 buf.put_slice(self.hash.as_bytes());
187 buf.put_slice(&CRLF);
188 buf.put_u8(u8::from(self.command));
189 self.write_to_buf_address(buf);
190 buf.put_slice(&CRLF);
191 if self.is_padding() {
192 Padding::default().write_to_buf(buf)
193 }
194 }
195 pub fn write_to_buf_address<B: BufMut>(&self, buf: &mut B) {
196 match &self.address {
197 Address::SocketAddress(SocketAddr::V4(addr)) => {
198 buf.put_u8(Self::ATYP_IPV4);
199 buf.put_slice(&addr.ip().octets());
200 buf.put_u16(addr.port());
201 }
202 Address::SocketAddress(SocketAddr::V6(addr)) => {
203 buf.put_u8(Self::ATYP_IPV6);
204 for seg in addr.ip().segments() {
205 buf.put_u16(seg);
206 }
207 buf.put_u16(addr.port());
208 }
209 Address::DomainAddress(addr, port) => {
210 buf.put_u8(Self::ATYP_FQDN);
211 buf.put_u8(addr.len() as u8);
212 buf.put_slice(addr);
213 buf.put_u16(*port);
214 }
215 }
216 }
217
218 pub fn serialized_len(&self) -> usize {
220 56 + 2 + 1 + self.address.serialized_len() + 2
221 }
222
223 pub fn is_padding(&self) -> bool {
225 Command::Padding == self.command
226 }
227}
228
/// Command carried in a request header.
///
/// Variants are declared in wire-code order (`0x01`..=`0x04`). The derived
/// ordering follows that declaration order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Command {
    /// Wire code `0x01`.
    Connect,
    /// Wire code `0x02`.
    Bind,
    /// Wire code `0x03`.
    Associate,
    /// Wire code `0x04`; a padding section follows the header.
    Padding,
}
238
239impl Command {
240 const CONNECT: u8 = 0x01;
241 const BIND: u8 = 0x02;
242 const ASSOCIATE: u8 = 0x03;
243 const PADDING: u8 = 0x04;
244}
245
246impl TryFrom<u8> for Command {
247 type Error = u8;
248
249 fn try_from(code: u8) -> Result<Self, Self::Error> {
250 match code {
251 Self::CONNECT => Ok(Self::Connect),
252 Self::BIND => Ok(Self::Bind),
253 Self::ASSOCIATE => Ok(Self::Associate),
254 Self::PADDING => Ok(Self::Padding),
255 code => Err(code),
256 }
257 }
258}
259
260impl From<Command> for u8 {
261 fn from(cmd: Command) -> Self {
262 match cmd {
263 Command::Connect => Command::CONNECT,
264 Command::Bind => Command::BIND,
265 Command::Associate => Command::ASSOCIATE,
266 Command::Padding => Command::PADDING,
267 }
268 }
269}
270
#[cfg(test)]
mod tests {
    use std::{
        io::Cursor,
        net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
    };

    use bytes::BytesMut;
    use socks5_proto::Address;
    use tokio::io::AsyncReadExt;

    use super::{Command, Request};
    use crate::stream::peekable::{AsyncPeek, PeekableStream};

    /// 56-character stand-in for a real hash.
    fn test_hash() -> String {
        std::iter::repeat('a').take(56).collect()
    }

    /// Encodes `request` with `write_to_buf` and wraps it in an in-memory reader.
    fn encoded(request: &Request) -> Cursor<Vec<u8>> {
        let mut wire = BytesMut::new();
        request.write_to_buf(&mut wire);
        Cursor::new(wire.to_vec())
    }

    #[tokio::test]
    async fn peek_head_reads_hash_prefix_without_consuming_stream() {
        let hash = test_hash();
        let mut payload = hash.clone().into_bytes();
        payload.extend_from_slice(b"\r\nrest");
        let mut stream = PeekableStream::new(Cursor::new(payload));

        let head = Request::peek_head(&mut stream).await.unwrap();
        assert_eq!(head, hash.into_bytes());

        let drained = stream.drain().unwrap();
        assert_eq!(drained, head);

        let mut replay = Vec::new();
        stream.read_to_end(&mut replay).await.unwrap();
        assert_eq!(replay, b"\r\nrest");
    }

    #[tokio::test]
    async fn read_address_from_parses_ipv4_domain_and_ipv6() {
        let cases = vec![
            (
                vec![1, 1, 127, 0, 0, 1, 0x01, 0xbb],
                Command::Connect,
                Address::SocketAddress(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443)),
            ),
            (
                vec![
                    4, 3, 11, b'e', b'x', b'a', b'm', b'p', b'l', b'e', b'.', b'c', b'o', b'm', 0,
                    80,
                ],
                Command::Padding,
                Address::DomainAddress(b"example.com".to_vec(), 80),
            ),
            (
                vec![
                    2, 4, 0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0x01, 0xbb,
                ],
                Command::Bind,
                Address::SocketAddress(SocketAddr::new(
                    IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)),
                    443,
                )),
            ),
        ];

        for (wire, want_cmd, want_addr) in cases {
            let (cmd, addr) = Request::read_address_from(&mut Cursor::new(wire))
                .await
                .unwrap();
            assert_eq!(cmd, want_cmd);
            assert_eq!(addr, want_addr);
        }
    }

    #[tokio::test]
    async fn read_from_round_trips_non_padding_request() {
        let original = Request::new(
            test_hash(),
            Command::Associate,
            Address::DomainAddress(b"example.com".to_vec(), 8080),
        );

        let decoded = Request::read_from(&mut encoded(&original)).await.unwrap();

        assert_eq!(decoded.hash, original.hash);
        assert_eq!(decoded.command, original.command);
        assert_eq!(decoded.address, original.address);
    }

    #[tokio::test]
    async fn read_from_accepts_padding_request() {
        let original = Request::new(
            test_hash(),
            Command::Padding,
            Address::DomainAddress(b"example.com".to_vec(), 443),
        );

        let decoded = Request::read_from(&mut encoded(&original)).await.unwrap();

        assert_eq!(decoded.hash, original.hash);
        assert_eq!(decoded.command, Command::Padding);
        assert_eq!(decoded.address, original.address);
    }

    #[tokio::test]
    async fn read_from_rejects_short_hash_prefix() {
        let mut short = Cursor::new(vec![b'a'; 10]);

        let message = Request::read_from(&mut short)
            .await
            .unwrap_err()
            .to_string();

        assert!(message.contains("not Trojan"));
    }
}
382}