ombrac_client/endpoint/socks.rs

1use std::sync::Arc;
2
3use ombrac::Secret;
4use ombrac::client::Client;
5use ombrac_macros::{debug, info};
6use ombrac_transport::Initiator;
7use socks_lib::io::{self, AsyncRead, AsyncWrite};
8use socks_lib::v5::server::Handler;
9use socks_lib::v5::{Address as SocksAddress, Request, Stream};
10
11pub struct CommandHandler<I: Initiator> {
12    ombrac_client: Arc<Client<I>>,
13    secret: Secret,
14}
15
16impl<I: Initiator> CommandHandler<I> {
17    pub fn new(inner: Arc<Client<I>>, secret: Secret) -> Self {
18        Self {
19            ombrac_client: inner,
20            secret,
21        }
22    }
23
24    async fn handle_connect(
25        &self,
26        address: SocksAddress,
27        stream: &mut Stream<impl AsyncRead + AsyncWrite + Unpin>,
28    ) -> io::Result<(u64, u64)> {
29        let addr = util::socks_to_ombrac_addr(address)?;
30        let mut outbound = self.ombrac_client.connect(addr, self.secret).await?;
31        ombrac::io::util::copy_bidirectional(stream, &mut outbound).await
32    }
33}
34
#[cfg(feature = "datagram")]
mod datagram {
    //! SOCKS5 `UDP ASSOCIATE` support: relays datagrams between the local
    //! SOCKS client's UDP socket and an ombrac datagram association.

    use std::time::Duration;

    use ombrac::client::Datagram;
    use ombrac_transport::{Initiator, Unreliable};
    use socks_lib::v5::UdpPacket;
    use tokio::{net::UdpSocket, time::timeout};

    use super::*;

    /// How long a relay direction (or the wait for the first datagram) may
    /// stay silent before it stops.
    const IDLE_TIMEOUT: Duration = Duration::from_secs(30);
    /// Receive-buffer size for datagrams coming from the local client.
    /// NOTE(review): larger datagrams are truncated by the socket read.
    const DEFAULT_BUFFER_SIZE: usize = 1500;

    impl<I: Initiator> CommandHandler<I> {
        /// Serves one UDP association on `socket`.
        ///
        /// Opens an ombrac association, then waits up to [`IDLE_TIMEOUT`] for
        /// the first datagram from the client. If nothing arrives it returns
        /// `Ok(())`. Otherwise it forwards that datagram, connects `socket` to
        /// the sender's address (so only that client is served), and relays in
        /// both directions. Each direction stops after [`IDLE_TIMEOUT`] without
        /// traffic, and the call returns once both have stopped.
        ///
        /// # Errors
        /// Returns the first I/O or packet-conversion error from setup or from
        /// either relay direction; the other direction is cancelled at once.
        pub async fn handle_associate(&self, socket: UdpSocket) -> io::Result<()> {
            let outbound = self.ombrac_client.associate().await?;

            let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE];
            let (size, source_addr) =
                match timeout(IDLE_TIMEOUT, socket.recv_from(&mut buf)).await {
                    Ok(received) => received?,
                    // The client never sent anything: end the association quietly.
                    Err(_elapsed) => return Ok(()),
                };
            // Parse only the bytes actually received; the rest of `buf` is
            // zero padding and must not become part of the payload.
            let first_packet = UdpPacket::from_bytes(&mut (&buf[..size]))?;

            let packet = util::socks_to_ombrac_packet(first_packet, self.secret)?;
            outbound.send(packet).await?;
            // From here on, `recv`/`send` on `socket` only talk to this client.
            socket.connect(source_addr).await?;

            // Both directions run concurrently in this task. `try_join!` returns
            // on the first error (dropping the other relay), and dropping this
            // future cancels both relays instead of leaving detached tasks.
            tokio::try_join!(
                relay_remote_to_client(&outbound, &socket, IDLE_TIMEOUT),
                relay_client_to_remote(&outbound, &socket, self.secret, IDLE_TIMEOUT)
            )?;

            Ok(())
        }
    }

    /// Forwards datagrams received on the ombrac association to the local
    /// client, re-encoded as SOCKS5 UDP packets.
    ///
    /// Returns `Ok(())` once no datagram has arrived for `idle_timeout`.
    async fn relay_remote_to_client<U>(
        datagram: &Datagram<U>,
        udp_socket: &UdpSocket,
        idle_timeout: Duration,
    ) -> io::Result<()>
    where
        U: Unreliable,
    {
        loop {
            let ombrac_packet = match timeout(idle_timeout, datagram.recv()).await {
                Ok(received) => received?,
                Err(_elapsed) => break,
            };

            let socks_packet = util::ombrac_to_socks_packet(ombrac_packet)?;
            udp_socket.send(&socks_packet.to_bytes()).await?;
        }
        Ok(())
    }

    /// Forwards SOCKS5 UDP packets from the local client to the ombrac
    /// association, tagging each one with `session_secret`.
    ///
    /// Returns `Ok(())` once the client has been silent for `idle_timeout`.
    async fn relay_client_to_remote<U>(
        datagram: &Datagram<U>,
        udp_socket: &UdpSocket,
        session_secret: Secret,
        idle_timeout: Duration,
    ) -> io::Result<()>
    where
        U: Unreliable,
    {
        let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE];
        loop {
            let n = match timeout(idle_timeout, udp_socket.recv(&mut buf)).await {
                Ok(received) => received?,
                Err(_elapsed) => break,
            };

            let socks_packet = UdpPacket::from_bytes(&mut (&buf[..n]))?;
            let packet = util::socks_to_ombrac_packet(socks_packet, session_secret)?;

            datagram.send(packet).await?;
        }
        Ok(())
    }
}
136
137impl<I: Initiator> Handler for CommandHandler<I> {
138    async fn handle<T>(&self, stream: &mut Stream<T>, request: Request) -> io::Result<()>
139    where
140        T: AsyncRead + AsyncWrite + Unpin + Send + Sync,
141    {
142        debug!("SOCKS Request: {:?}", request);
143
144        match &request {
145            Request::Connect(address) => {
146                stream.write_response_unspecified().await?;
147
148                match self.handle_connect(address.clone(), stream).await {
149                    Ok(_copy) => {
150                        info!(
151                            "{} Connect {}, Send: {}, Recv: {}",
152                            stream.peer_addr(),
153                            address,
154                            _copy.0,
155                            _copy.1
156                        );
157                    }
158                    Err(err) => return Err(err),
159                }
160            }
161            #[cfg(feature = "datagram")]
162            Request::Associate(_addr) => {
163                use socks_lib::v5::Response;
164                use tokio::net::UdpSocket;
165
166                let socket = UdpSocket::bind("0.0.0.0:0").await?;
167                let addr = SocksAddress::from(socket.local_addr()?);
168
169                stream.write_response(&Response::Success(&addr)).await?;
170
171                self.handle_associate(socket).await?;
172            }
173            _ => {
174                stream.write_response_unsupported().await?;
175            }
176        }
177
178        Ok(())
179    }
180}
181
182mod util {
183    use std::io;
184
185    use ombrac::address::{Address as OmbracAddress, Domain as OmbracDoamin};
186    #[cfg(feature = "datagram")]
187    use ombrac::{Secret as OmbracSecret, associate::Associate as OmbracPacket};
188    use socks_lib::v5::Address as Socks5Address;
189    #[cfg(feature = "datagram")]
190    use socks_lib::v5::UdpPacket as Socks5Packet;
191
192    #[inline]
193    #[cfg(feature = "datagram")]
194    pub(super) fn socks_to_ombrac_packet(
195        packet: Socks5Packet,
196        secret: OmbracSecret,
197    ) -> io::Result<OmbracPacket> {
198        let addr = socks_to_ombrac_addr(packet.address)?;
199        let data = packet.data;
200
201        Ok(OmbracPacket::with(secret, addr, data))
202    }
203
204    #[inline]
205    pub(super) fn socks_to_ombrac_addr(addr: Socks5Address) -> io::Result<OmbracAddress> {
206        let result = match addr {
207            Socks5Address::IPv4(value) => OmbracAddress::IPv4(value),
208            Socks5Address::IPv6(value) => OmbracAddress::IPv6(value),
209            Socks5Address::Domain(domain, port) => OmbracAddress::Domain(
210                OmbracDoamin::from_bytes(domain.as_bytes().to_owned())?,
211                port,
212            ),
213        };
214
215        Ok(result)
216    }
217
218    #[inline]
219    #[cfg(feature = "datagram")]
220    pub(super) fn ombrac_to_socks_packet(packet: OmbracPacket) -> io::Result<Socks5Packet> {
221        let addr = match packet.address {
222            OmbracAddress::IPv4(value) => Socks5Address::IPv4(value),
223            OmbracAddress::IPv6(value) => Socks5Address::IPv6(value),
224            OmbracAddress::Domain(domain, port) => {
225                Socks5Address::Domain(domain.to_bytes().try_into()?, port)
226            }
227        };
228        let data = packet.data;
229
230        Ok(Socks5Packet::un_frag(addr, data))
231    }
232}