ombrac_client/endpoint/
socks.rs1use std::io;
2use std::net::SocketAddr;
3use std::sync::Arc;
4
5use bytes::Bytes;
6use socks_lib::io::{AsyncRead, AsyncReadExt, AsyncWrite};
7use socks_lib::v5::server::Handler;
8use socks_lib::v5::{Address as Socks5Address, Request, Response, Stream, UdpPacket};
9use tokio::net::UdpSocket;
10
11use ombrac_macros::{debug, error, info, warn};
12use ombrac_transport::{Connection, Initiator};
13
14use crate::client::Client;
15#[cfg(feature = "datagram")]
16use crate::client::UdpSession;
17
18pub struct CommandHandler<T, C>
19where
20 T: Initiator<Connection = C> + Send + Sync + 'static,
21 C: Connection + Send + Sync + 'static,
22{
23 client: Arc<Client<T, C>>,
24}
25
26impl<T, C> CommandHandler<T, C>
27where
28 T: Initiator<Connection = C> + Send + Sync + 'static,
29 C: Connection + Send + Sync + 'static,
30{
31 pub fn new(client: Arc<Client<T, C>>) -> Self {
32 Self { client }
33 }
34
35 async fn handle_connect(
36 &self,
37 address: Socks5Address,
38 mut stream: &mut Stream<impl AsyncRead + AsyncWrite + Unpin>,
39 ) -> io::Result<()> {
40 let dst_addr = util::socks_to_ombrac_addr(address)?;
41 let mut dest_stream = self.client.open_bidirectional(dst_addr.clone()).await?;
42 match ombrac_transport::io::copy_bidirectional(&mut stream, &mut dest_stream).await {
43 Ok(stats) => {
44 #[cfg(feature = "tracing")]
45 tracing::info!(
46 src_addr = stream.local_addr().to_string(),
47 dst_addr = dst_addr.to_string(),
48 send = stats.a_to_b_bytes,
49 recv = stats.b_to_a_bytes,
50 status = "ok",
51 "Connect"
52 );
53 }
54 Err((err, stats)) => {
55 #[cfg(feature = "tracing")]
56 tracing::error!(
57 src_addr = stream.local_addr().to_string(),
58 dst_addr = dst_addr.to_string(),
59 send = stats.a_to_b_bytes,
60 recv = stats.b_to_a_bytes,
61 status = "err",
62 error = %err,
63 "Connect"
64 );
65 return Err(err);
66 }
67 };
68
69 Ok(())
70 }
71
72 #[cfg(feature = "datagram")]
74 async fn handle_associate(
75 &self,
76 stream: &mut Stream<impl AsyncRead + AsyncWrite + Unpin + Send>,
77 ) -> io::Result<()> {
78 info!("SOCKS: Handling UDP ASSOCIATE from {}", stream.peer_addr());
79
80 let udp_session = self.client.open_associate();
81
82 let relay_socket = UdpSocket::bind("0.0.0.0:0").await?;
83 let relay_addr = SocketAddr::new(
84 stream.local_addr().ip(),
85 relay_socket.local_addr().unwrap().port(),
86 );
87 info!("SOCKS: UDP relay listening on {}", relay_addr);
88
89 let response_addr = Socks5Address::from(relay_addr);
90 stream
91 .write_response(&Response::Success(&response_addr))
92 .await?;
93
94 self.udp_relay_loop(stream, relay_socket, udp_session).await
96 }
97
98 #[cfg(feature = "datagram")]
104 async fn udp_relay_loop(
105 &self,
106 stream: &mut Stream<impl AsyncRead + AsyncWrite + Unpin>,
107 relay_socket: UdpSocket,
108 mut udp_session: UdpSession<T, C>,
109 ) -> io::Result<()> {
110 let mut client_udp_src: Option<SocketAddr> = None;
111 let mut buf = vec![0u8; 65535]; loop {
114 tokio::select! {
115 biased;
117
118 result = stream.read_u8() => {
121 match result {
122 Ok(0) | Err(_) => {
123 info!("SOCKS: TCP control connection for UDP associate closed. Ending session.");
124 return Ok(());
125 }
126 _ => {}
127 }
128 }
129
130 Some((data, from_addr)) = udp_session.recv_from() => {
131 if let Some(dest) = client_udp_src {
132 let socks_from_addr = util::ombrac_addr_to_socks(from_addr)?;
133 let udp_response = UdpPacket::un_frag(socks_from_addr, data);
134 relay_socket.send_to(&udp_response.to_bytes(), dest).await?;
135 } else {
136 warn!("SOCKS: Received packet from tunnel before client, discarding.");
137 }
138 }
139
140 result = relay_socket.recv_from(&mut buf) => {
141 let (len, src) = result?;
142 if client_udp_src.is_none() {
143 client_udp_src = Some(src);
144 info!("SOCKS: First UDP packet received from client {}", src);
145 }
146 let mut bytes = Bytes::copy_from_slice(&buf[..len]);
147 let udp_request = UdpPacket::from_bytes(&mut bytes)?;
148 let payload = udp_request.data;
149 let dest_addr = util::socks_to_ombrac_addr(udp_request.address)?;
150
151 udp_session.send_to(payload, dest_addr).await?;
152 }
153 }
154 }
155 }
156}
157
158impl<T, C> Handler for CommandHandler<T, C>
159where
160 T: Initiator<Connection = C> + Send + Sync + 'static,
161 C: Connection + Send + Sync + 'static,
162{
163 async fn handle<S>(&self, stream: &mut Stream<S>, request: Request) -> io::Result<()>
164 where
165 S: AsyncRead + AsyncWrite + Unpin + Send + Sync,
166 {
167 debug!("SOCKS Request: {:?}", request);
168
169 match request {
170 Request::Connect(address) => {
171 stream.write_response_unspecified().await?;
172
173 if let Err(err) = self.handle_connect(address.clone(), stream).await {
174 if err.kind() != io::ErrorKind::BrokenPipe
175 && err.kind() != io::ErrorKind::ConnectionReset
176 {
177 error!("SOCKS: Connect to {} failed: {}", address, err);
178 }
179 return Err(err);
180 }
181 }
182 #[cfg(feature = "datagram")]
183 Request::Associate(_) => {
184 if let Err(err) = self.handle_associate(stream).await {
185 if err.kind() != io::ErrorKind::BrokenPipe
186 && err.kind() != io::ErrorKind::ConnectionReset
187 {
188 error!(
189 "SOCKS: Associate from {} failed: {}",
190 stream.peer_addr(),
191 err
192 );
193 }
194 return Err(err);
195 }
196 }
197 _ => {
198 warn!("SOCKS: BIND command is not supported.");
199 stream.write_response_unsupported().await?;
200 }
201 }
202
203 Ok(())
204 }
205}
206
207mod util {
208 use ombrac::protocol::Address as OmbracAddress;
209 use socks_lib::v5::Address as Socks5Address;
210 use std::io;
211
212 pub(super) fn socks_to_ombrac_addr(addr: Socks5Address) -> io::Result<OmbracAddress> {
213 let result = match addr {
214 Socks5Address::IPv4(value) => OmbracAddress::SocketV4(value),
215 Socks5Address::IPv6(value) => OmbracAddress::SocketV6(value),
216 Socks5Address::Domain(domain, port) => {
217 OmbracAddress::Domain(domain.as_bytes().to_owned(), port)
218 }
219 };
220
221 Ok(result)
222 }
223
224 pub(super) fn ombrac_addr_to_socks(addr: OmbracAddress) -> io::Result<Socks5Address> {
225 let result = match addr {
226 OmbracAddress::SocketV4(sa) => Socks5Address::IPv4(sa),
227 OmbracAddress::SocketV6(sa) => Socks5Address::IPv6(sa),
228 OmbracAddress::Domain(domain_bytes, port) => {
229 Socks5Address::Domain(domain_bytes.try_into()?, port)
230 }
231 };
232
233 Ok(result)
234 }
235}