// borer_core/dial.rs

1use std::sync::Arc;
2use std::time::{Duration, Instant};
3
4use anyhow::Context;
5use async_trait::async_trait;
6use log::{debug, info};
7use socks5_proto::Address;
8use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
9use tokio::net::TcpStream;
10use tokio::time::timeout;
11
12use crate::address_list::{DirectList, ProxyList};
13use crate::proto::padding::Padding;
14use crate::proto::trojan;
15use crate::proto::trojan::Command;
16use crate::tls::make_server_name;
17use crate::tls::make_tls_connector;
18
19pub trait AsyncStream: AsyncRead + AsyncWrite + Unpin + Send {}
20impl<T: AsyncRead + AsyncWrite + Unpin + Send> AsyncStream for T {}
21
22#[async_trait]
23pub trait Dial: Send + Sync {
24    async fn dial(&self, addr: Address) -> anyhow::Result<Box<dyn AsyncStream>>;
25}
26
27pub struct DirectDial {
28    connect_timeout: Duration,
29}
30
31impl DirectDial {
32    pub fn new(connect_timeout: Duration) -> Self {
33        Self { connect_timeout }
34    }
35}
36
37#[async_trait]
38impl Dial for DirectDial {
39    async fn dial(&self, addr: Address) -> anyhow::Result<Box<dyn AsyncStream>> {
40        let stream: TcpStream = match addr {
41            Address::DomainAddress(domain, port) => {
42                let domain = String::from_utf8_lossy(&domain);
43                timeout(
44                    self.connect_timeout,
45                    TcpStream::connect((domain.as_ref(), port)),
46                )
47                .await
48                .context(format!("connect {}:{} timeout", domain, port))?
49                .context(format!("connect {}:{} failed", domain, port))
50            }
51            Address::SocketAddress(socket_addr) => {
52                timeout(self.connect_timeout, TcpStream::connect(socket_addr))
53                    .await
54                    .context(format!("connect {} timeout", socket_addr))?
55                    .context(format!("connect {} failed", socket_addr))
56            }
57        }?;
58        Ok(Box::new(stream))
59    }
60}
61
62pub struct TrojanDial {
63    remote_addr: String,
64    hash: String,
65    insecure: bool,
66    padding: bool,
67    connect_timeout: Duration,
68}
69
70impl TrojanDial {
71    pub fn new(
72        remote_addr: String,
73        hash: String,
74        insecure: bool,
75        padding: bool,
76        connect_timeout: Duration,
77    ) -> Self {
78        Self {
79            remote_addr,
80            hash,
81            insecure,
82            padding,
83            connect_timeout,
84        }
85    }
86}
87
88#[async_trait]
89impl Dial for TrojanDial {
90    async fn dial(&self, addr: Address) -> anyhow::Result<Box<dyn AsyncStream>> {
91        let remote_ts = timeout(self.connect_timeout, TcpStream::connect(&self.remote_addr))
92            .await
93            .context(format!("connect {} timeout", self.remote_addr))?
94            .context(format!("connect {} failed", self.remote_addr))?;
95
96        let server_name = make_server_name(self.remote_addr.as_str())?;
97        let mut remote_ts_ssl = make_tls_connector(self.insecure)
98            .connect(server_name, remote_ts)
99            .await
100            .context("trojan can't connect tls")?;
101
102        if self.padding {
103            let req = trojan::Request::new(self.hash.clone(), Command::Padding, addr);
104            req.write_to(&mut remote_ts_ssl).await?;
105            Padding::read_from(&mut remote_ts_ssl).await?;
106        } else {
107            let req = trojan::Request::new(self.hash.clone(), Command::Connect, addr);
108            req.write_to(&mut remote_ts_ssl).await?;
109        }
110
111        Ok(Box::new(remote_ts_ssl))
112    }
113}
114
/// Dialer that carries Trojan requests over a TLS WebSocket connection.
#[cfg(feature = "websocket")]
pub struct WebSocketDial {
    /// Passed to `connect_async_tls_with_config`; presumably a `wss://` URL — TODO confirm.
    remote_addr: String,
    /// Value placed into every `trojan::Request`.
    hash: String,
    /// Forwarded to `make_tls_client_config`.
    insecure: bool,
    /// When set, requests use `Command::Padding` and one reply message is read back.
    padding: bool,
    /// Bound on the whole WebSocket connect (TCP + TLS + upgrade).
    connect_timeout: Duration,
}

#[cfg(feature = "websocket")]
impl WebSocketDial {
    /// Creates a WebSocket dialer for `remote_addr`.
    pub fn new(
        remote_addr: String,
        hash: String,
        insecure: bool,
        padding: bool,
        connect_timeout: Duration,
    ) -> Self {
        Self {
            remote_addr,
            hash,
            insecure,
            padding,
            connect_timeout,
        }
    }
}

#[cfg(feature = "websocket")]
#[async_trait]
impl Dial for WebSocketDial {
    /// Opens a WebSocket to `remote_addr` and sends a Trojan request for `addr`.
    ///
    /// The request is sent as one binary message. In padding mode, one reply message
    /// is then read and discarded. If the socket closes or errors before that reply
    /// arrives, the dial fails instead of returning a dead stream.
    async fn dial(&self, addr: Address) -> anyhow::Result<Box<dyn AsyncStream>> {
        use crate::stream::websocket::WebSocketCopyStream;
        use crate::tls::make_tls_client_config;
        use bytes::BytesMut;
        use futures::SinkExt;
        use futures::StreamExt;
        use tokio_tungstenite::connect_async_tls_with_config;
        use tokio_tungstenite::tungstenite::Message;

        let tls = tokio_tungstenite::Connector::Rustls(Arc::new(make_tls_client_config(
            self.insecure,
        )));
        let (mut ws, _) = timeout(
            self.connect_timeout,
            connect_async_tls_with_config(&self.remote_addr, None, false, Some(tls)),
        )
        .await
        .with_context(|| format!("websocket connect {} timeout", self.remote_addr))?
        .with_context(|| format!("websocket connect {} failed", self.remote_addr))?;

        let command = if self.padding {
            Command::Padding
        } else {
            Command::Connect
        };
        let mut buf = BytesMut::new();
        let req = trojan::Request::new(self.hash.clone(), command, addr);
        req.write_to_buf(&mut buf);
        // `SinkExt::send` flushes the sink before it completes; no extra flush needed.
        ws.send(Message::Binary(buf.freeze()))
            .await
            .context("websocket can't send")?;

        if self.padding {
            ws.next()
                .await
                .context("websocket closed before padding response")?
                .context("websocket can't read padding response")?;
        }
        Ok(Box::new(WebSocketCopyStream::new(ws)))
    }
}
191
192const EXTRA_TIMEOUT_MS: u64 = 200;
193
194pub struct SmartDial {
195    direct: Box<dyn Dial>,
196    proxy: Box<dyn Dial>,
197    proxy_list: Arc<ProxyList>,
198    direct_list: Arc<DirectList>,
199    connect_timeout: Duration,
200}
201
202impl SmartDial {
203    pub fn new(
204        direct: Box<dyn Dial>,
205        proxy: Box<dyn Dial>,
206        proxy_list: Arc<ProxyList>,
207        direct_list: Arc<DirectList>,
208        connect_timeout: Duration,
209    ) -> Self {
210        Self {
211            direct,
212            proxy,
213            proxy_list,
214            direct_list,
215            connect_timeout,
216        }
217    }
218
219    async fn handle_proxy_result(
220        &self,
221        proxy_res: anyhow::Result<Box<dyn AsyncStream>>,
222        addr: &Address,
223        elapsed: Duration,
224    ) -> anyhow::Result<Box<dyn AsyncStream>> {
225        match proxy_res {
226            Ok(mut stream) => {
227                let remaining = self
228                    .connect_timeout
229                    .saturating_add(Duration::from_millis(EXTRA_TIMEOUT_MS))
230                    .saturating_sub(elapsed);
231                let check_timeout = remaining.min(Duration::from_millis(EXTRA_TIMEOUT_MS));
232
233                if !is_stream_closed(&mut stream, check_timeout).await {
234                    self.proxy_list.add_address(addr);
235                    info!(
236                        "Proxy Connect to: {addr} : record [{:.3}s]",
237                        elapsed.as_secs_f64()
238                    );
239                } else {
240                    info!(
241                        "Proxy Connect to: {addr} : unrecord [{:.3}s]",
242                        elapsed.as_secs_f64()
243                    );
244                }
245                Ok(stream)
246            }
247            Err(e) => Err(e),
248        }
249    }
250}
251
252#[async_trait]
253impl Dial for SmartDial {
254    async fn dial(&self, addr: Address) -> anyhow::Result<Box<dyn AsyncStream>> {
255        if self.direct_list.contains_address(&addr) {
256            debug!("Address {:?} is in direct list, using direct dial", addr);
257            match self.direct.dial(addr.clone()).await {
258                Ok(stream) => {
259                    info!("Direct Connect to: {addr}");
260                    return Ok(stream);
261                }
262                Err(e) => {
263                    debug!(
264                        "Direct dial failed for {:?}: {}, removing from direct list",
265                        addr, e
266                    );
267                    self.direct_list.remove_address(&addr);
268                }
269            }
270        }
271
272        if self.proxy_list.contains_address(&addr) {
273            debug!("Address {:?} is in proxy list, using proxy dial", addr);
274            info!("Proxy Connect to: {addr}");
275            return self.proxy.dial(addr).await;
276        }
277
278        let start = Instant::now();
279        let direct_fut = self.direct.dial(addr.clone());
280        let proxy_fut = self.proxy.dial(addr.clone());
281
282        tokio::pin!(direct_fut);
283        tokio::pin!(proxy_fut);
284
285        tokio::select! {
286            direct_res = &mut direct_fut => {
287                match direct_res {
288                    Ok(stream) => {
289                        self.direct_list.add_address(&addr);
290                        info!("Direct Connect to: {addr}");
291                        Ok(stream)
292                    },
293                    Err(e) => {
294                        debug!("Direct dial failed for {:?}: {}, using proxy", addr, e);
295                        let proxy_res = proxy_fut.await;
296                        self.handle_proxy_result(proxy_res, &addr, start.elapsed()).await
297                    }
298                }
299            }
300            proxy_res = &mut proxy_fut => {
301                let direct_res = direct_fut.await;
302                match direct_res {
303                    Ok(stream) => {
304                        self.direct_list.add_address(&addr);
305                        info!("Direct Connect to: {addr}");
306                        Ok(stream)
307                    }
308                    Err(e) => {
309                        debug!("Direct dial failed for {:?}: {}, using proxy", addr, e);
310                        self.handle_proxy_result(proxy_res, &addr, start.elapsed()).await
311                    }
312                }
313            }
314        }
315    }
316}
317
318pub async fn is_stream_closed<R: AsyncRead + Unpin>(
319    stream: &mut R,
320    timeout_duration: Duration,
321) -> bool {
322    match timeout(timeout_duration, stream.read(&mut [])).await {
323        Ok(Ok(_)) => false,
324        Ok(Err(e)) => !matches!(e.kind(), std::io::ErrorKind::WouldBlock),
325        Err(_) => false,
326    }
327}
328
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use async_trait::async_trait;
    use socks5_proto::Address;
    use tempfile::NamedTempFile;
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        net::TcpListener,
        time::Duration,
    };

    use super::{AsyncStream, Dial, DirectDial, SmartDial};
    use crate::address_list::{DirectList, ProxyList};

    /// Connect timeout handed to every dialer under test.
    const TIMEOUT: Duration = Duration::from_secs(3);

    /// Scripted dialer: sleeps for `delay`, then either yields an in-memory stream
    /// or fails with "mock dial failed".
    struct MockDial {
        succeed: bool,
        delay: Duration,
    }

    impl MockDial {
        fn with_delay(delay: Duration, succeed: bool) -> Self {
            Self { succeed, delay }
        }

        fn succeed() -> Self {
            Self::with_delay(Duration::ZERO, true)
        }

        fn fail() -> Self {
            Self::with_delay(Duration::ZERO, false)
        }

        /// Slow failure, standing in for a direct connect that times out.
        fn with_timeout_error(delay: Duration) -> Self {
            Self::with_delay(delay, false)
        }
    }

    #[async_trait]
    impl Dial for MockDial {
        async fn dial(&self, _addr: Address) -> anyhow::Result<Box<dyn AsyncStream>> {
            tokio::time::sleep(self.delay).await;
            if !self.succeed {
                return Err(anyhow::anyhow!("mock dial failed"));
            }
            // Keep one half of the duplex pair; the other half is dropped here.
            let (half, _) = tokio::io::duplex(64);
            Ok(Box::new(half) as Box<dyn AsyncStream>)
        }
    }

    /// Routing lists backed by temp files that live as long as this value.
    struct Lists {
        proxy: Arc<ProxyList>,
        direct: Arc<DirectList>,
        _proxy_file: NamedTempFile,
        _direct_file: NamedTempFile,
    }

    /// Builds fresh routing lists, optionally seeding each backing file first.
    fn make_lists(proxy_seed: Option<&str>, direct_seed: Option<&str>) -> Lists {
        let proxy_file = NamedTempFile::new().unwrap();
        let direct_file = NamedTempFile::new().unwrap();
        if let Some(seed) = proxy_seed {
            std::fs::write(proxy_file.path(), seed).unwrap();
        }
        if let Some(seed) = direct_seed {
            std::fs::write(direct_file.path(), seed).unwrap();
        }
        let proxy = Arc::new(ProxyList::new(proxy_file.path()));
        let direct = Arc::new(DirectList::new(direct_file.path()));
        Lists {
            proxy,
            direct,
            _proxy_file: proxy_file,
            _direct_file: direct_file,
        }
    }

    /// Wires two mock dialers and the given lists into a `SmartDial`.
    fn build_smart(direct: MockDial, proxy: MockDial, lists: &Lists) -> SmartDial {
        SmartDial::new(
            Box::new(direct),
            Box::new(proxy),
            Arc::clone(&lists.proxy),
            Arc::clone(&lists.direct),
            TIMEOUT,
        )
    }

    /// Domain target on port 443.
    fn domain(name: &str) -> Address {
        Address::DomainAddress(name.as_bytes().to_vec(), 443)
    }

    #[tokio::test]
    async fn direct_dial_connects_to_local_listener() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let target = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (mut conn, _) = listener.accept().await.unwrap();
            conn.write_all(b"ok").await.unwrap();
        });

        let mut client = DirectDial::new(TIMEOUT)
            .dial(Address::SocketAddress(target))
            .await
            .unwrap();
        let mut received = [0u8; 2];
        AsyncReadExt::read_exact(&mut client, &mut received)
            .await
            .unwrap();

        server.await.unwrap();
        assert_eq!(&received, b"ok");
    }

    #[tokio::test]
    async fn direct_dial_returns_error_for_unreachable_port() {
        // Bind and immediately release a port so nothing is listening on it.
        let target = {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            listener.local_addr().unwrap()
        };

        let outcome = DirectDial::new(TIMEOUT)
            .dial(Address::SocketAddress(target))
            .await;

        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn smart_dial_uses_proxy_when_domain_in_list() {
        let lists = make_lists(Some("blocked.com\n"), None);
        let dialer = build_smart(MockDial::succeed(), MockDial::succeed(), &lists);

        let outcome = dialer.dial(domain("blocked.com")).await;

        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn smart_dial_uses_direct_when_succeeds_first() {
        let lists = make_lists(None, None);
        let dialer = build_smart(
            MockDial::with_delay(Duration::from_millis(10), true),
            MockDial::with_delay(Duration::from_millis(100), true),
            &lists,
        );

        let outcome = dialer.dial(domain("fast-direct.com")).await;

        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn smart_dial_adds_domain_when_direct_times_out_and_proxy_succeeds() {
        let lists = make_lists(None, None);
        let dialer = build_smart(
            MockDial::with_timeout_error(Duration::from_millis(100)),
            MockDial::with_delay(Duration::from_millis(10), true),
            &lists,
        );

        let outcome = dialer.dial(domain("slow-direct.com")).await;

        assert!(outcome.is_ok());
        assert!(lists.proxy.contains_address(&domain("slow-direct.com")));
    }

    #[tokio::test]
    async fn smart_dial_returns_error_when_both_fail() {
        let lists = make_lists(None, None);
        let dialer = build_smart(MockDial::fail(), MockDial::fail(), &lists);

        let outcome = dialer.dial(domain("both-fail.com")).await;

        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn smart_dial_uses_direct_when_domain_in_direct_list() {
        let lists = make_lists(None, Some("direct.com\n"));
        let dialer = build_smart(MockDial::succeed(), MockDial::fail(), &lists);

        let outcome = dialer.dial(domain("direct.com")).await;

        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn smart_dial_removes_from_direct_list_on_failure() {
        let lists = make_lists(None, Some("direct-fail.com\n"));
        let dialer = build_smart(MockDial::fail(), MockDial::succeed(), &lists);

        let outcome = dialer.dial(domain("direct-fail.com")).await;

        assert!(outcome.is_ok());
        assert!(!lists.direct.contains_address(&domain("direct-fail.com")));
    }

    #[tokio::test]
    async fn smart_dial_adds_to_direct_list_on_success() {
        let lists = make_lists(None, None);
        let dialer = build_smart(
            MockDial::with_delay(Duration::from_millis(10), true),
            MockDial::with_delay(Duration::from_millis(100), true),
            &lists,
        );

        let outcome = dialer.dial(domain("new-direct.com")).await;

        assert!(outcome.is_ok());
        assert!(lists.direct.contains_address(&domain("new-direct.com")));
    }
}