use std::io::{self, Cursor};
use bytes::{BufMut, BytesMut};
use tokio::net::{ToSocketAddrs, UdpSocket};
use shadowsocks::relay::socks5::{Address, Error, UdpAssociateHeader};
use super::tcp_client::Socks5TcpClient;
/// A SOCKS5 UDP client that relays datagrams through a proxy's UDP ASSOCIATE endpoint.
///
/// Create one with [`Socks5UdpClient::bind`], then call [`Socks5UdpClient::associate`]
/// before sending or receiving. Until then, `send_to` and `recv_from` fail with an
/// "udp not associated" error.
pub struct Socks5UdpClient {
    /// Local UDP socket. After `associate` it is connected to the relay address the proxy returned.
    socket: UdpSocket,
    /// TCP control connection returned by the UDP ASSOCIATE request; `Some` once associated.
    /// Holding it keeps that connection alive for as long as this client exists (RFC 1928
    /// ends the UDP association when its TCP connection closes). The allow is kept because
    /// the value is held mostly for its lifetime rather than read.
    #[allow(dead_code)]
    assoc_client: Option<Socks5TcpClient>,
}

impl Socks5UdpClient {
    /// Binds a local UDP socket to `addrs`. The returned client is not yet associated.
    ///
    /// # Errors
    /// Returns any I/O error from binding the socket.
    pub async fn bind<A>(addrs: A) -> io::Result<Self>
    where
        A: ToSocketAddrs,
    {
        Ok(Self {
            socket: UdpSocket::bind(addrs).await?,
            assoc_client: None,
        })
    }

    /// Performs a SOCKS5 UDP ASSOCIATE with the proxy at `proxy` and connects the local
    /// UDP socket to the relay address the proxy replies with.
    ///
    /// The socket's own local address is passed to the proxy as the client address.
    ///
    /// # Errors
    /// Fails with "udp is associated" if the client is already associated. Also fails
    /// with any error from the TCP handshake or from connecting the UDP socket. On error
    /// the client remains unassociated.
    pub async fn associate<P>(&mut self, proxy: P) -> Result<(), Error>
    where
        P: ToSocketAddrs,
    {
        if self.assoc_client.is_some() {
            let err = io::Error::other("udp is associated");
            return Err(err.into());
        }

        let local_addr = self.socket.local_addr()?;
        let (assoc_client, proxy_addr) = Socks5TcpClient::udp_associate(local_addr, proxy).await?;

        // Connecting sets the default destination for `send` and restricts `recv` to
        // datagrams coming from the relay address.
        match proxy_addr {
            Address::SocketAddress(sa) => self.socket.connect(sa).await?,
            Address::DomainNameAddress(ref dname, port) => {
                self.socket.connect((dname.as_str(), port)).await?
            }
        }

        // Record the association only after the socket is connected. If `connect` fails,
        // the local `assoc_client` is dropped and the client stays unassociated.
        self.assoc_client = Some(assoc_client);
        Ok(())
    }

    /// Sends `buf` to `target` through the proxy, prefixed with a SOCKS5 UDP request
    /// header carrying fragment number `frag`.
    ///
    /// Returns the number of payload bytes sent: bytes written minus the header length.
    ///
    /// # Errors
    /// Fails with "udp not associated" before [`Self::associate`] succeeds, or with any
    /// send error.
    pub async fn send_to<A>(&self, frag: u8, buf: &[u8], target: A) -> Result<usize, Error>
    where
        A: Into<Address>,
    {
        self.check_associated()?;

        let header = UdpAssociateHeader::new(frag, target.into());
        let header_len = header.serialized_len();

        let mut send_buf = BytesMut::with_capacity(header_len + buf.len());
        header.write_to_buf(&mut send_buf);
        send_buf.put_slice(buf);

        let n = self.socket.send(&send_buf).await?;
        // Report payload bytes only. Saturate in case fewer bytes than the header were written.
        Ok(n.saturating_sub(header_len))
    }

    /// Receives one datagram from the proxy, parses its SOCKS5 UDP header, and moves the
    /// payload to the start of `recv_buf`.
    ///
    /// Returns `(payload_len, frag, source_address)`. The payload is `recv_buf[..payload_len]`.
    /// If `recv_buf` is smaller than the datagram, excess bytes may be discarded by the socket.
    ///
    /// # Errors
    /// Fails with "udp not associated" before [`Self::associate`] succeeds. Also fails on
    /// any receive error or if the header cannot be parsed.
    pub async fn recv_from(&self, recv_buf: &mut [u8]) -> Result<(usize, u8, Address), Error> {
        self.check_associated()?;

        let n = self.socket.recv(recv_buf).await?;
        let mut cur = Cursor::new(&recv_buf[..n]);
        let header = UdpAssociateHeader::read_from(&mut cur).await?;
        // The cursor spans only the first `n` bytes, so `pos <= n` and the cast cannot truncate.
        let pos = cur.position() as usize;

        // Shift only the received payload. Bytes past `n` are stale buffer contents and
        // must not be copied.
        recv_buf.copy_within(pos..n, 0);
        Ok((n - pos, header.frag, header.address))
    }

    /// Returns an error unless [`Self::associate`] has completed successfully.
    fn check_associated(&self) -> io::Result<()> {
        if self.assoc_client.is_none() {
            let err = io::Error::other("udp not associated");
            return Err(err);
        }
        Ok(())
    }
}