use std::{io, net::SocketAddr, sync::Arc};
use once_cell::sync::Lazy;
use tokio::{
io::{AsyncRead, AsyncWrite},
net::TcpStream,
};
use crate::{
config::{ServerAddr, ServerConfig, ServerUserManager},
context::SharedContext,
crypto::CipherKind,
net::{AcceptOpts, TcpListener},
relay::tcprelay::proxy_stream::server::ProxyServerStream,
};
pub struct ProxyListener {
listener: TcpListener,
method: CipherKind,
key: Box<[u8]>,
context: SharedContext,
user_manager: Option<Arc<ServerUserManager>>,
}
static DEFAULT_ACCEPT_OPTS: Lazy<AcceptOpts> = Lazy::new(Default::default);
impl ProxyListener {
pub async fn bind(context: SharedContext, svr_cfg: &ServerConfig) -> io::Result<ProxyListener> {
ProxyListener::bind_with_opts(context, svr_cfg, DEFAULT_ACCEPT_OPTS.clone()).await
}
pub async fn bind_with_opts(
context: SharedContext,
svr_cfg: &ServerConfig,
accept_opts: AcceptOpts,
) -> io::Result<ProxyListener> {
let listener = match svr_cfg.external_addr() {
ServerAddr::SocketAddr(sa) => TcpListener::bind_with_opts(sa, accept_opts).await?,
ServerAddr::DomainName(domain, port) => {
lookup_then!(&context, domain, *port, |addr| {
TcpListener::bind_with_opts(&addr, accept_opts.clone()).await
})?
.1
}
};
Ok(ProxyListener::from_listener(context, listener, svr_cfg))
}
pub fn from_listener(context: SharedContext, listener: TcpListener, svr_cfg: &ServerConfig) -> ProxyListener {
ProxyListener {
listener,
method: svr_cfg.method(),
key: svr_cfg.key().to_vec().into_boxed_slice(),
context,
user_manager: svr_cfg.clone_user_manager(),
}
}
#[inline]
pub async fn accept(&self) -> io::Result<(ProxyServerStream<TcpStream>, SocketAddr)> {
self.accept_map(|s| s).await
}
pub async fn accept_map<F, S>(&self, map_fn: F) -> io::Result<(ProxyServerStream<S>, SocketAddr)>
where
F: FnOnce(TcpStream) -> S,
S: AsyncRead + AsyncWrite + Unpin,
{
let (stream, peer_addr) = self.listener.accept().await?;
let stream = map_fn(stream);
let stream = ProxyServerStream::from_stream(
self.context.clone(),
stream,
self.method,
&self.key,
self.user_manager.clone(),
);
Ok((stream, peer_addr))
}
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.listener.local_addr()
}
pub fn get_ref(&self) -> &TcpListener {
&self.listener
}
pub fn into_inner(self) -> TcpListener {
self.listener
}
}