1use std::sync::Arc;
2use std::time::{Duration, Instant};
3
4use anyhow::Context;
5use async_trait::async_trait;
6use log::{debug, info};
7use socks5_proto::Address;
8use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
9use tokio::net::TcpStream;
10use tokio::time::timeout;
11
12use crate::address_list::{DirectList, ProxyList};
13use crate::proto::padding::Padding;
14use crate::proto::trojan;
15use crate::proto::trojan::Command;
16use crate::tls::make_server_name;
17use crate::tls::make_tls_connector;
18
19pub trait AsyncStream: AsyncRead + AsyncWrite + Unpin + Send {}
20impl<T: AsyncRead + AsyncWrite + Unpin + Send> AsyncStream for T {}
21
22#[async_trait]
23pub trait Dial: Send + Sync {
24 async fn dial(&self, addr: Address) -> anyhow::Result<Box<dyn AsyncStream>>;
25}
26
/// Dialer that connects straight to the target over plain TCP.
pub struct DirectDial {
    /// Upper bound on how long a single TCP connect may take.
    connect_timeout: Duration,
}

impl DirectDial {
    /// Builds a direct dialer whose connects give up after `connect_timeout`.
    pub fn new(connect_timeout: Duration) -> Self {
        DirectDial { connect_timeout }
    }
}
36
37#[async_trait]
38impl Dial for DirectDial {
39 async fn dial(&self, addr: Address) -> anyhow::Result<Box<dyn AsyncStream>> {
40 let stream: TcpStream = match addr {
41 Address::DomainAddress(domain, port) => {
42 let domain = String::from_utf8_lossy(&domain);
43 timeout(
44 self.connect_timeout,
45 TcpStream::connect((domain.as_ref(), port)),
46 )
47 .await
48 .context(format!("connect {}:{} timeout", domain, port))?
49 .context(format!("connect {}:{} failed", domain, port))
50 }
51 Address::SocketAddress(socket_addr) => {
52 timeout(self.connect_timeout, TcpStream::connect(socket_addr))
53 .await
54 .context(format!("connect {} timeout", socket_addr))?
55 .context(format!("connect {} failed", socket_addr))
56 }
57 }?;
58 Ok(Box::new(stream))
59 }
60}
61
/// Dialer that reaches targets through a Trojan server over TLS.
pub struct TrojanDial {
    /// Server address handed to `TcpStream::connect` (so a `host:port` string)
    /// and used to derive the TLS server name.
    remote_addr: String,
    /// Credential hash embedded in every Trojan request.
    hash: String,
    /// Passed to the TLS connector factory; presumably relaxes certificate
    /// verification when true — confirm against `make_tls_connector`.
    insecure: bool,
    /// When true, the padding command is requested and the server's padding
    /// reply is read before the stream is returned.
    padding: bool,
    /// Upper bound on the TCP connect to `remote_addr`.
    connect_timeout: Duration,
}

impl TrojanDial {
    /// Creates a Trojan dialer; all arguments are stored as-is.
    pub fn new(
        remote_addr: String,
        hash: String,
        insecure: bool,
        padding: bool,
        connect_timeout: Duration,
    ) -> Self {
        TrojanDial {
            connect_timeout,
            padding,
            insecure,
            hash,
            remote_addr,
        }
    }
}
87
88#[async_trait]
89impl Dial for TrojanDial {
90 async fn dial(&self, addr: Address) -> anyhow::Result<Box<dyn AsyncStream>> {
91 let remote_ts = timeout(self.connect_timeout, TcpStream::connect(&self.remote_addr))
92 .await
93 .context(format!("connect {} timeout", self.remote_addr))?
94 .context(format!("connect {} failed", self.remote_addr))?;
95
96 let server_name = make_server_name(self.remote_addr.as_str())?;
97 let mut remote_ts_ssl = make_tls_connector(self.insecure)
98 .connect(server_name, remote_ts)
99 .await
100 .context("trojan can't connect tls")?;
101
102 if self.padding {
103 let req = trojan::Request::new(self.hash.clone(), Command::Padding, addr);
104 req.write_to(&mut remote_ts_ssl).await?;
105 Padding::read_from(&mut remote_ts_ssl).await?;
106 } else {
107 let req = trojan::Request::new(self.hash.clone(), Command::Connect, addr);
108 req.write_to(&mut remote_ts_ssl).await?;
109 }
110
111 Ok(Box::new(remote_ts_ssl))
112 }
113}
114
#[cfg(feature = "websocket")]
/// Dialer that carries Trojan requests over a TLS WebSocket connection.
pub struct WebSocketDial {
    /// Target handed to `connect_async_tls_with_config`; presumably a
    /// `wss://` URL — confirm with the configuration that builds this.
    remote_addr: String,
    /// Credential hash embedded in every Trojan request.
    hash: String,
    /// Passed to the TLS client-config factory; presumably relaxes certificate
    /// verification when true — confirm against `make_tls_client_config`.
    insecure: bool,
    /// When true, the padding command is requested and one reply message is
    /// awaited before the stream is returned.
    padding: bool,
    /// Upper bound on establishing the WebSocket connection.
    connect_timeout: Duration,
}

#[cfg(feature = "websocket")]
impl WebSocketDial {
    /// Creates a WebSocket dialer; all arguments are stored as-is.
    pub fn new(
        remote_addr: String,
        hash: String,
        insecure: bool,
        padding: bool,
        connect_timeout: Duration,
    ) -> Self {
        WebSocketDial {
            connect_timeout,
            padding,
            insecure,
            hash,
            remote_addr,
        }
    }
}
142
#[cfg(feature = "websocket")]
#[async_trait]
impl Dial for WebSocketDial {
    /// Opens a TLS WebSocket to `remote_addr` and sends a Trojan request for `addr`
    /// as a single binary message.
    ///
    /// With `padding` enabled the request uses `Command::Padding` and one reply
    /// message is awaited before the stream is returned.
    ///
    /// # Errors
    /// Fails if the WebSocket connect times out or errors, the request cannot be
    /// sent, or (padding mode) the connection errors or closes before the reply
    /// arrives.
    async fn dial(&self, addr: Address) -> anyhow::Result<Box<dyn AsyncStream>> {
        use crate::stream::websocket::WebSocketCopyStream;
        use crate::tls::make_tls_client_config;
        use bytes::BytesMut;
        use futures::SinkExt;
        use futures::StreamExt;
        use tokio_tungstenite::connect_async_tls_with_config;
        use tokio_tungstenite::tungstenite::Message;

        let connector =
            tokio_tungstenite::Connector::Rustls(Arc::new(make_tls_client_config(self.insecure)));
        let (mut ws, _) = timeout(
            self.connect_timeout,
            connect_async_tls_with_config(&self.remote_addr, None, false, Some(connector)),
        )
        .await
        .with_context(|| format!("websocket connect {} timeout", self.remote_addr))?
        .with_context(|| format!("websocket connect {} failed", self.remote_addr))?;

        let command = if self.padding {
            Command::Padding
        } else {
            Command::Connect
        };
        let mut buf = BytesMut::new();
        let req = trojan::Request::new(self.hash.clone(), command, addr);
        req.write_to_buf(&mut buf);

        // `SinkExt::send` only resolves once the item has been flushed, so no
        // separate `flush()` call is needed afterwards.
        ws.send(Message::Binary(buf.freeze()))
            .await
            .context("websocket can't send")?;

        if self.padding {
            // Mirror `TrojanDial`: a failed or closed connection during the padding
            // handshake is an error, not a usable stream.
            match ws.next().await {
                Some(Ok(_)) => {}
                Some(Err(e)) => return Err(e).context("websocket can't read padding"),
                None => anyhow::bail!("websocket closed before padding response"),
            }
        }

        Ok(Box::new(WebSocketCopyStream::new(ws)))
    }
}
191
192const EXTRA_TIMEOUT_MS: u64 = 200;
193
194pub struct SmartDial {
195 direct: Box<dyn Dial>,
196 proxy: Box<dyn Dial>,
197 proxy_list: Arc<ProxyList>,
198 direct_list: Arc<DirectList>,
199 connect_timeout: Duration,
200}
201
202impl SmartDial {
203 pub fn new(
204 direct: Box<dyn Dial>,
205 proxy: Box<dyn Dial>,
206 proxy_list: Arc<ProxyList>,
207 direct_list: Arc<DirectList>,
208 connect_timeout: Duration,
209 ) -> Self {
210 Self {
211 direct,
212 proxy,
213 proxy_list,
214 direct_list,
215 connect_timeout,
216 }
217 }
218
219 async fn handle_proxy_result(
220 &self,
221 proxy_res: anyhow::Result<Box<dyn AsyncStream>>,
222 addr: &Address,
223 elapsed: Duration,
224 ) -> anyhow::Result<Box<dyn AsyncStream>> {
225 match proxy_res {
226 Ok(mut stream) => {
227 let remaining = self
228 .connect_timeout
229 .saturating_add(Duration::from_millis(EXTRA_TIMEOUT_MS))
230 .saturating_sub(elapsed);
231 let check_timeout = remaining.min(Duration::from_millis(EXTRA_TIMEOUT_MS));
232
233 if !is_stream_closed(&mut stream, check_timeout).await {
234 self.proxy_list.add_address(addr);
235 info!(
236 "Proxy Connect to: {addr} : record [{:.3}s]",
237 elapsed.as_secs_f64()
238 );
239 } else {
240 info!(
241 "Proxy Connect to: {addr} : unrecord [{:.3}s]",
242 elapsed.as_secs_f64()
243 );
244 }
245 Ok(stream)
246 }
247 Err(e) => Err(e),
248 }
249 }
250}
251
252#[async_trait]
253impl Dial for SmartDial {
254 async fn dial(&self, addr: Address) -> anyhow::Result<Box<dyn AsyncStream>> {
255 if self.direct_list.contains_address(&addr) {
256 debug!("Address {:?} is in direct list, using direct dial", addr);
257 match self.direct.dial(addr.clone()).await {
258 Ok(stream) => {
259 info!("Direct Connect to: {addr}");
260 return Ok(stream);
261 }
262 Err(e) => {
263 debug!(
264 "Direct dial failed for {:?}: {}, removing from direct list",
265 addr, e
266 );
267 self.direct_list.remove_address(&addr);
268 }
269 }
270 }
271
272 if self.proxy_list.contains_address(&addr) {
273 debug!("Address {:?} is in proxy list, using proxy dial", addr);
274 info!("Proxy Connect to: {addr}");
275 return self.proxy.dial(addr).await;
276 }
277
278 let start = Instant::now();
279 let direct_fut = self.direct.dial(addr.clone());
280 let proxy_fut = self.proxy.dial(addr.clone());
281
282 tokio::pin!(direct_fut);
283 tokio::pin!(proxy_fut);
284
285 tokio::select! {
286 direct_res = &mut direct_fut => {
287 match direct_res {
288 Ok(stream) => {
289 self.direct_list.add_address(&addr);
290 info!("Direct Connect to: {addr}");
291 Ok(stream)
292 },
293 Err(e) => {
294 debug!("Direct dial failed for {:?}: {}, using proxy", addr, e);
295 let proxy_res = proxy_fut.await;
296 self.handle_proxy_result(proxy_res, &addr, start.elapsed()).await
297 }
298 }
299 }
300 proxy_res = &mut proxy_fut => {
301 let direct_res = direct_fut.await;
302 match direct_res {
303 Ok(stream) => {
304 self.direct_list.add_address(&addr);
305 info!("Direct Connect to: {addr}");
306 Ok(stream)
307 }
308 Err(e) => {
309 debug!("Direct dial failed for {:?}: {}, using proxy", addr, e);
310 self.handle_proxy_result(proxy_res, &addr, start.elapsed()).await
311 }
312 }
313 }
314 }
315 }
316}
317
318pub async fn is_stream_closed<R: AsyncRead + Unpin>(
319 stream: &mut R,
320 timeout_duration: Duration,
321) -> bool {
322 match timeout(timeout_duration, stream.read(&mut [])).await {
323 Ok(Ok(_)) => false,
324 Ok(Err(e)) => !matches!(e.kind(), std::io::ErrorKind::WouldBlock),
325 Err(_) => false,
326 }
327}
328
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    use async_trait::async_trait;
    use socks5_proto::Address;
    use tempfile::NamedTempFile;
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        net::TcpListener,
        time::Duration,
    };

    use super::{AsyncStream, Dial, DirectDial, SmartDial};
    use crate::address_list::{DirectList, ProxyList};

    /// Scripted `Dial` double: sleeps for `delay`, then yields an in-memory
    /// stream or fails. Every invocation is counted in `calls`, so tests can
    /// assert which underlying dialer `SmartDial` actually used.
    struct MockDial {
        succeed: bool,
        delay: Duration,
        calls: Arc<AtomicUsize>,
    }

    impl MockDial {
        fn new(succeed: bool, delay: Duration) -> Self {
            Self {
                succeed,
                delay,
                calls: Arc::new(AtomicUsize::new(0)),
            }
        }

        fn succeed() -> Self {
            Self::new(true, Duration::ZERO)
        }

        fn fail() -> Self {
            Self::new(false, Duration::ZERO)
        }

        fn with_delay(delay: Duration, succeed: bool) -> Self {
            Self::new(succeed, delay)
        }

        /// A dial that only fails after `delay`, modelling a connect timeout.
        fn with_timeout_error(delay: Duration) -> Self {
            Self::new(false, delay)
        }

        /// Shared handle to the call counter; grab it before boxing the mock.
        fn call_counter(&self) -> Arc<AtomicUsize> {
            Arc::clone(&self.calls)
        }
    }

    #[async_trait]
    impl Dial for MockDial {
        async fn dial(&self, _addr: Address) -> anyhow::Result<Box<dyn AsyncStream>> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            tokio::time::sleep(self.delay).await;
            if self.succeed {
                // The peer half of the duplex is dropped right away. `is_stream_closed`
                // still treats this stream as open because its zero-length read succeeds.
                Ok(Box::new(tokio::io::duplex(64).0) as Box<dyn AsyncStream>)
            } else {
                Err(anyhow::anyhow!("mock dial failed"))
            }
        }
    }

    #[tokio::test]
    async fn direct_dial_connects_to_local_listener() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let accept_task = tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            stream.write_all(b"ok").await.unwrap();
        });

        let mut stream = DirectDial::new(Duration::from_secs(3))
            .dial(Address::SocketAddress(addr))
            .await
            .unwrap();
        let mut buf = [0u8; 2];
        AsyncReadExt::read_exact(&mut stream, &mut buf)
            .await
            .unwrap();

        accept_task.await.unwrap();
        assert_eq!(&buf, b"ok");
    }

    #[tokio::test]
    async fn direct_dial_returns_error_for_unreachable_port() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        drop(listener);

        let result = DirectDial::new(Duration::from_secs(3))
            .dial(Address::SocketAddress(addr))
            .await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn smart_dial_uses_proxy_when_domain_in_list() {
        let proxy_temp_file = NamedTempFile::new().unwrap();
        std::fs::write(proxy_temp_file.path(), "blocked.com\n").unwrap();

        let direct_temp_file = NamedTempFile::new().unwrap();

        let proxy_list = Arc::new(ProxyList::new(proxy_temp_file.path()));
        let direct_list = Arc::new(DirectList::new(direct_temp_file.path()));
        let direct = MockDial::succeed();
        let proxy = MockDial::succeed();
        let direct_calls = direct.call_counter();
        let proxy_calls = proxy.call_counter();

        let smart_dial = SmartDial::new(
            Box::new(direct),
            Box::new(proxy),
            proxy_list,
            direct_list,
            Duration::from_secs(3),
        );

        let result = smart_dial
            .dial(Address::DomainAddress(b"blocked.com".to_vec(), 443))
            .await;

        assert!(result.is_ok());
        // Both mocks succeed, so only the call counts prove the proxy-list
        // shortcut was taken instead of the race.
        assert_eq!(proxy_calls.load(Ordering::SeqCst), 1);
        assert_eq!(direct_calls.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn smart_dial_uses_direct_when_succeeds_first() {
        let proxy_temp_file = NamedTempFile::new().unwrap();
        let direct_temp_file = NamedTempFile::new().unwrap();
        let proxy_list = Arc::new(ProxyList::new(proxy_temp_file.path()));
        let direct_list = Arc::new(DirectList::new(direct_temp_file.path()));

        let direct = MockDial::with_delay(Duration::from_millis(10), true);
        let proxy = MockDial::with_delay(Duration::from_millis(100), true);

        let smart_dial = SmartDial::new(
            Box::new(direct),
            Box::new(proxy),
            proxy_list,
            direct_list,
            Duration::from_secs(3),
        );

        let result = smart_dial
            .dial(Address::DomainAddress(b"fast-direct.com".to_vec(), 443))
            .await;

        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn smart_dial_adds_domain_when_direct_times_out_and_proxy_succeeds() {
        let proxy_temp_file = NamedTempFile::new().unwrap();
        let direct_temp_file = NamedTempFile::new().unwrap();
        let proxy_list = Arc::new(ProxyList::new(proxy_temp_file.path()));
        let direct_list = Arc::new(DirectList::new(direct_temp_file.path()));

        let direct = MockDial::with_timeout_error(Duration::from_millis(100));
        let proxy = MockDial::with_delay(Duration::from_millis(10), true);

        let smart_dial = SmartDial::new(
            Box::new(direct),
            Box::new(proxy),
            proxy_list.clone(),
            direct_list,
            Duration::from_secs(3),
        );

        let result = smart_dial
            .dial(Address::DomainAddress(b"slow-direct.com".to_vec(), 443))
            .await;

        assert!(result.is_ok());
        assert!(
            proxy_list.contains_address(&Address::DomainAddress(b"slow-direct.com".to_vec(), 443))
        );
    }

    #[tokio::test]
    async fn smart_dial_returns_error_when_both_fail() {
        let proxy_temp_file = NamedTempFile::new().unwrap();
        let direct_temp_file = NamedTempFile::new().unwrap();
        let proxy_list = Arc::new(ProxyList::new(proxy_temp_file.path()));
        let direct_list = Arc::new(DirectList::new(direct_temp_file.path()));

        let direct = MockDial::fail();
        let proxy = MockDial::fail();

        let smart_dial = SmartDial::new(
            Box::new(direct),
            Box::new(proxy),
            proxy_list,
            direct_list,
            Duration::from_secs(3),
        );

        let result = smart_dial
            .dial(Address::DomainAddress(b"both-fail.com".to_vec(), 443))
            .await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn smart_dial_uses_direct_when_domain_in_direct_list() {
        let proxy_temp_file = NamedTempFile::new().unwrap();
        let direct_temp_file = NamedTempFile::new().unwrap();
        std::fs::write(direct_temp_file.path(), "direct.com\n").unwrap();

        let proxy_list = Arc::new(ProxyList::new(proxy_temp_file.path()));
        let direct_list = Arc::new(DirectList::new(direct_temp_file.path()));
        let direct = MockDial::succeed();
        let proxy = MockDial::fail();
        let direct_calls = direct.call_counter();
        let proxy_calls = proxy.call_counter();

        let smart_dial = SmartDial::new(
            Box::new(direct),
            Box::new(proxy),
            proxy_list,
            direct_list,
            Duration::from_secs(3),
        );

        let result = smart_dial
            .dial(Address::DomainAddress(b"direct.com".to_vec(), 443))
            .await;

        assert!(result.is_ok());
        // The race would also succeed here, so check the proxy was never tried.
        assert_eq!(direct_calls.load(Ordering::SeqCst), 1);
        assert_eq!(proxy_calls.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn smart_dial_removes_from_direct_list_on_failure() {
        let proxy_temp_file = NamedTempFile::new().unwrap();
        let direct_temp_file = NamedTempFile::new().unwrap();
        std::fs::write(direct_temp_file.path(), "direct-fail.com\n").unwrap();

        let proxy_list = Arc::new(ProxyList::new(proxy_temp_file.path()));
        let direct_list = Arc::new(DirectList::new(direct_temp_file.path()));
        let direct = MockDial::fail();
        let proxy = MockDial::succeed();

        let smart_dial = SmartDial::new(
            Box::new(direct),
            Box::new(proxy),
            proxy_list,
            direct_list.clone(),
            Duration::from_secs(3),
        );

        let result = smart_dial
            .dial(Address::DomainAddress(b"direct-fail.com".to_vec(), 443))
            .await;

        assert!(result.is_ok());
        assert!(
            !direct_list
                .contains_address(&Address::DomainAddress(b"direct-fail.com".to_vec(), 443))
        );
    }

    #[tokio::test]
    async fn smart_dial_adds_to_direct_list_on_success() {
        let proxy_temp_file = NamedTempFile::new().unwrap();
        let direct_temp_file = NamedTempFile::new().unwrap();
        let proxy_list = Arc::new(ProxyList::new(proxy_temp_file.path()));
        let direct_list = Arc::new(DirectList::new(direct_temp_file.path()));

        let direct = MockDial::with_delay(Duration::from_millis(10), true);
        let proxy = MockDial::with_delay(Duration::from_millis(100), true);

        let smart_dial = SmartDial::new(
            Box::new(direct),
            Box::new(proxy),
            proxy_list,
            direct_list.clone(),
            Duration::from_secs(3),
        );

        let result = smart_dial
            .dial(Address::DomainAddress(b"new-direct.com".to_vec(), 443))
            .await;

        assert!(result.is_ok());
        assert!(
            direct_list.contains_address(&Address::DomainAddress(b"new-direct.com".to_vec(), 443))
        );
    }
}