ombrac_client/endpoint/
socks.rs

1use std::io;
2use std::net::SocketAddr;
3use std::sync::Arc;
4
5use bytes::Bytes;
6use socks_lib::io::{AsyncRead, AsyncReadExt, AsyncWrite};
7use socks_lib::v5::server::Handler;
8use socks_lib::v5::{Address as Socks5Address, Request, Response, Stream, UdpPacket};
9use tokio::net::UdpSocket;
10
11use ombrac_macros::{debug, error, info, warn};
12use ombrac_transport::{Connection, Initiator};
13
14use crate::client::Client;
15#[cfg(feature = "datagram")]
16use crate::client::UdpSession;
17
/// SOCKS5 command handler that serves requests through a shared ombrac
/// [`Client`].
///
/// `T` is the transport initiator and `C` is the connection type that
/// initiator produces.
pub struct CommandHandler<T, C>
where
    T: Initiator<Connection = C> + Send + Sync + 'static,
    C: Connection + Send + Sync + 'static,
{
    /// Tunnel client used to open bidirectional streams and UDP sessions.
    client: Arc<Client<T, C>>,
}
25
26impl<T, C> CommandHandler<T, C>
27where
28    T: Initiator<Connection = C> + Send + Sync + 'static,
29    C: Connection + Send + Sync + 'static,
30{
31    pub fn new(client: Arc<Client<T, C>>) -> Self {
32        Self { client }
33    }
34
35    async fn handle_connect(
36        &self,
37        address: Socks5Address,
38        mut stream: &mut Stream<impl AsyncRead + AsyncWrite + Unpin>,
39    ) -> io::Result<()> {
40        let dst_addr = util::socks_to_ombrac_addr(address)?;
41        let mut dest_stream = self.client.open_bidirectional(dst_addr.clone()).await?;
42        match ombrac_transport::io::copy_bidirectional(&mut stream, &mut dest_stream).await {
43            Ok(stats) => {
44                #[cfg(feature = "tracing")]
45                tracing::info!(
46                    src_addr = stream.local_addr().to_string(),
47                    dst_addr = dst_addr.to_string(),
48                    send = stats.a_to_b_bytes,
49                    recv = stats.b_to_a_bytes,
50                    status = "ok",
51                    "Connect"
52                );
53            }
54            Err((err, stats)) => {
55                #[cfg(feature = "tracing")]
56                tracing::error!(
57                    src_addr = stream.local_addr().to_string(),
58                    dst_addr = dst_addr.to_string(),
59                    send = stats.a_to_b_bytes,
60                    recv = stats.b_to_a_bytes,
61                    status = "err",
62                    error = %err,
63                    "Connect"
64                );
65                return Err(err);
66            }
67        };
68
69        Ok(())
70    }
71
72    /// Handles the SOCKS5 UDP ASSOCIATE command.
73    #[cfg(feature = "datagram")]
74    async fn handle_associate(
75        &self,
76        stream: &mut Stream<impl AsyncRead + AsyncWrite + Unpin + Send>,
77    ) -> io::Result<()> {
78        info!("SOCKS: Handling UDP ASSOCIATE from {}", stream.peer_addr());
79
80        let udp_session = self.client.open_associate();
81
82        let relay_socket = UdpSocket::bind("0.0.0.0:0").await?;
83        let relay_addr = SocketAddr::new(
84            stream.local_addr().ip(),
85            relay_socket.local_addr().unwrap().port(),
86        );
87        info!("SOCKS: UDP relay listening on {}", relay_addr);
88
89        let response_addr = Socks5Address::from(relay_addr);
90        stream
91            .write_response(&Response::Success(&response_addr))
92            .await?;
93
94        // 进入转发循环
95        self.udp_relay_loop(stream, relay_socket, udp_session).await
96    }
97
98    /// The main relay loop for a UDP association.
99    ///
100    /// This loop concurrently handles two data flows:
101    /// - SOCKS Client -> Relay Socket -> ombrac Tunnel -> Destination
102    /// - Destination -> ombrac Tunnel -> Relay Socket -> SOCKS Client
103    #[cfg(feature = "datagram")]
104    async fn udp_relay_loop(
105        &self,
106        stream: &mut Stream<impl AsyncRead + AsyncWrite + Unpin>,
107        relay_socket: UdpSocket,
108        mut udp_session: UdpSession<T, C>,
109    ) -> io::Result<()> {
110        let mut client_udp_src: Option<SocketAddr> = None;
111        let mut buf = vec![0u8; 65535]; // Max UDP packet size
112
113        loop {
114            tokio::select! {
115                // biased; 优先检查控制连接是否关闭
116                biased;
117
118                // 1. 检查 TCP 控制连接是否已关闭。
119                // 如果是,则关联结束,我们应该退出循环。
120                result = stream.read_u8() => {
121                    match result {
122                        Ok(0) | Err(_) => {
123                            info!("SOCKS: TCP control connection for UDP associate closed. Ending session.");
124                            return Ok(());
125                        }
126                        _ => {}
127                    }
128                }
129
130                Some((data, from_addr)) = udp_session.recv_from() => {
131                    if let Some(dest) = client_udp_src {
132                        let socks_from_addr = util::ombrac_addr_to_socks(from_addr)?;
133                        let udp_response = UdpPacket::un_frag(socks_from_addr, data);
134                        relay_socket.send_to(&udp_response.to_bytes(), dest).await?;
135                    } else {
136                        warn!("SOCKS: Received packet from tunnel before client, discarding.");
137                    }
138                }
139
140                result = relay_socket.recv_from(&mut buf) => {
141                    let (len, src) = result?;
142                    if client_udp_src.is_none() {
143                        client_udp_src = Some(src);
144                        info!("SOCKS: First UDP packet received from client {}", src);
145                    }
146                    let mut bytes = Bytes::copy_from_slice(&buf[..len]);
147                    let udp_request = UdpPacket::from_bytes(&mut bytes)?;
148                    let payload = udp_request.data;
149                    let dest_addr = util::socks_to_ombrac_addr(udp_request.address)?;
150
151                    udp_session.send_to(payload, dest_addr).await?;
152                }
153            }
154        }
155    }
156}
157
158impl<T, C> Handler for CommandHandler<T, C>
159where
160    T: Initiator<Connection = C> + Send + Sync + 'static,
161    C: Connection + Send + Sync + 'static,
162{
163    async fn handle<S>(&self, stream: &mut Stream<S>, request: Request) -> io::Result<()>
164    where
165        S: AsyncRead + AsyncWrite + Unpin + Send + Sync,
166    {
167        debug!("SOCKS Request: {:?}", request);
168
169        match request {
170            Request::Connect(address) => {
171                stream.write_response_unspecified().await?;
172
173                if let Err(err) = self.handle_connect(address.clone(), stream).await {
174                    if err.kind() != io::ErrorKind::BrokenPipe
175                        && err.kind() != io::ErrorKind::ConnectionReset
176                    {
177                        error!("SOCKS: Connect to {} failed: {}", address, err);
178                    }
179                    return Err(err);
180                }
181            }
182            #[cfg(feature = "datagram")]
183            Request::Associate(_) => {
184                if let Err(err) = self.handle_associate(stream).await {
185                    if err.kind() != io::ErrorKind::BrokenPipe
186                        && err.kind() != io::ErrorKind::ConnectionReset
187                    {
188                        error!(
189                            "SOCKS: Associate from {} failed: {}",
190                            stream.peer_addr(),
191                            err
192                        );
193                    }
194                    return Err(err);
195                }
196            }
197            _ => {
198                warn!("SOCKS: BIND command is not supported.");
199                stream.write_response_unsupported().await?;
200            }
201        }
202
203        Ok(())
204    }
205}
206
207mod util {
208    use ombrac::protocol::Address as OmbracAddress;
209    use socks_lib::v5::Address as Socks5Address;
210    use std::io;
211
212    pub(super) fn socks_to_ombrac_addr(addr: Socks5Address) -> io::Result<OmbracAddress> {
213        let result = match addr {
214            Socks5Address::IPv4(value) => OmbracAddress::SocketV4(value),
215            Socks5Address::IPv6(value) => OmbracAddress::SocketV6(value),
216            Socks5Address::Domain(domain, port) => {
217                OmbracAddress::Domain(domain.as_bytes().to_owned(), port)
218            }
219        };
220
221        Ok(result)
222    }
223
224    pub(super) fn ombrac_addr_to_socks(addr: OmbracAddress) -> io::Result<Socks5Address> {
225        let result = match addr {
226            OmbracAddress::SocketV4(sa) => Socks5Address::IPv4(sa),
227            OmbracAddress::SocketV6(sa) => Socks5Address::IPv6(sa),
228            OmbracAddress::Domain(domain_bytes, port) => {
229                Socks5Address::Domain(domain_bytes.try_into()?, port)
230            }
231        };
232
233        Ok(result)
234    }
235}