use std::{sync::Arc, time::Duration};
use borsh::BorshDeserialize;
use futures::Future;
use qos_core::{
io::{IOError, Listener, StreamPool},
server::{PermittedStream, SocketServer, SocketServerError},
};
use tokio::sync::Semaphore;
use crate::{
error::QosNetError, proxy_connection::ProxyConnection, proxy_msg::ProxyMsg,
};
/// Deadline applied by `Proxy::run` to an entire client session
/// (request handling and the bidirectional copy that follows).
const PROXY_CLIENT_TIMEOUT: Duration = Duration::from_millis(10_000);
/// Number of bytes in one mebibyte.
const MEGABYTE: usize = 1 << 20;
/// Largest request, in bytes, that `Proxy::process_req` will attempt to decode.
const MAX_ENCODED_MSG_LEN: usize = MEGABYTE * 128;
/// State for a single proxy session: the stream accepted from the client and,
/// once a connect request succeeds, the outbound TCP connection it is piped to.
pub struct Proxy {
/// Outbound TCP connection. `None` until a `ConnectByNameRequest` or
/// `ConnectByIpRequest` succeeds; set back to `None` after streaming ends or
/// when the connect response cannot be sent.
tcp_connection: Option<ProxyConnection>,
/// Stream accepted from the listener in `accept_loop_proxy`; requests are read
/// from it and responses written back to it.
sock_stream: PermittedStream,
}
impl Proxy {
    /// Creates a proxy session for `sock_stream` with no outbound TCP
    /// connection yet.
    pub fn new(sock_stream: PermittedStream) -> Self {
        Self { tcp_connection: None, sock_stream }
    }

    /// Opens an outbound connection to `hostname:port` via
    /// `ProxyConnection::new_from_name`, using `dns_resolvers` and `dns_port`
    /// for name resolution.
    ///
    /// On success the connection is stored in `self.tcp_connection` and a
    /// `ConnectResponse` carrying the connection's IP is returned. On failure
    /// the error is logged and returned as `ProxyMsg::ProxyError`;
    /// `self.tcp_connection` is left unchanged.
    async fn connect_by_name(
        &mut self,
        hostname: String,
        port: u16,
        dns_resolvers: Vec<String>,
        dns_port: u16,
    ) -> ProxyMsg {
        // `hostname` is still needed for the log line below, so it is cloned;
        // `dns_resolvers` is not used again and can simply be moved.
        match ProxyConnection::new_from_name(
            hostname.clone(),
            port,
            dns_resolvers,
            dns_port,
        )
        .await
        {
            Ok(conn) => {
                let remote_ip = conn.ip.clone();
                self.tcp_connection = Some(conn);
                println!("Connection to {hostname}@{remote_ip} established");
                ProxyMsg::ConnectResponse { remote_ip }
            }
            Err(e) => {
                println!("error while establishing connection: {e:?}");
                ProxyMsg::ProxyError(e)
            }
        }
    }

    /// Opens an outbound connection to `ip:port` via
    /// `ProxyConnection::new_from_ip`.
    ///
    /// Success and failure are handled exactly as in `connect_by_name`.
    async fn connect_by_ip(&mut self, ip: String, port: u16) -> ProxyMsg {
        // `ip` is cloned because the success log line still needs it.
        match ProxyConnection::new_from_ip(ip.clone(), port).await {
            Ok(conn) => {
                let remote_ip = conn.ip.clone();
                self.tcp_connection = Some(conn);
                println!("Connection to {ip} established");
                ProxyMsg::ConnectResponse { remote_ip }
            }
            Err(e) => {
                println!("error while establishing connection: {e:?}");
                ProxyMsg::ProxyError(e)
            }
        }
    }

    /// Decodes one borsh-encoded `ProxyMsg` request, handles it, and returns
    /// the borsh-encoded response.
    ///
    /// - Requests longer than `MAX_ENCODED_MSG_LEN` are answered with
    ///   `QosNetError::OversizedPayload` without being decoded.
    /// - `StatusRequest` is answered with `StatusResponse(0)`.
    /// - `ConnectByNameRequest` / `ConnectByIpRequest` attempt the connection
    ///   and answer with its result.
    /// - An incoming `ProxyError` is echoed back unchanged.
    /// - Any other variant, or bytes that fail to decode, are answered with
    ///   `QosNetError::InvalidMsg`.
    async fn process_req(&mut self, req_bytes: Vec<u8>) -> Vec<u8> {
        let resp = if req_bytes.len() > MAX_ENCODED_MSG_LEN {
            ProxyMsg::ProxyError(QosNetError::OversizedPayload)
        } else {
            match ProxyMsg::try_from_slice(&req_bytes) {
                Ok(ProxyMsg::StatusRequest) => ProxyMsg::StatusResponse(0),
                Ok(ProxyMsg::ConnectByNameRequest {
                    hostname,
                    port,
                    dns_resolvers,
                    dns_port,
                }) => {
                    self.connect_by_name(hostname, port, dns_resolvers, dns_port)
                        .await
                }
                Ok(ProxyMsg::ConnectByIpRequest { ip, port }) => {
                    self.connect_by_ip(ip, port).await
                }
                Ok(ProxyMsg::ProxyError(err)) => ProxyMsg::ProxyError(err),
                Ok(_) | Err(_) => ProxyMsg::ProxyError(QosNetError::InvalidMsg),
            }
        };

        // Single encode point for every response; encoding into an in-memory
        // buffer is treated as infallible.
        borsh::to_vec(&resp).expect("ProxyMsg can always be serialized. qed.")
    }
}
impl Proxy {
    /// Serves one client session on `sock_stream`, bounded by
    /// `PROXY_CLIENT_TIMEOUT`.
    ///
    /// The deadline covers the whole of `connect_and_stream`, including the
    /// bidirectional copy, so a session still streaming when it expires is cut
    /// off. When that happens, `tcp_connection` is not cleared here.
    ///
    /// # Errors
    /// - `IOError::UnexpectedProxyConnection` if a TCP connection is already held.
    /// - `IOError::RecvTimeout` if the session does not finish in time.
    /// - Any error returned by `connect_and_stream`.
    async fn run(&mut self) -> Result<(), IOError> {
        if self.tcp_connection.is_some() {
            return Err(IOError::UnexpectedProxyConnection);
        }

        let session = self.connect_and_stream();
        tokio::time::timeout(PROXY_CLIENT_TIMEOUT, session)
            .await
            .unwrap_or_else(|elapsed| {
                eprintln!("proxy timeout: {elapsed}");
                Err(IOError::RecvTimeout)
            })
    }

    /// Reads one request from `sock_stream`, replies to it, and then pipes
    /// bytes between the client stream and the outbound TCP connection until
    /// the copy finishes.
    ///
    /// `tcp_connection` is reset to `None` after streaming, or if the reply
    /// cannot be sent.
    ///
    /// # Errors
    /// - Errors from receiving the request or sending the reply.
    /// - `IOError::MissingProxyConnection` when the request did not leave a TCP
    ///   connection behind (e.g. a status request or a failed connect).
    /// - I/O errors from the bidirectional copy.
    async fn connect_and_stream(&mut self) -> Result<(), IOError> {
        let request = self.sock_stream.recv().await?;
        let reply = self.process_req(request).await;

        if let Err(send_err) = self.sock_stream.send(&reply).await {
            // The client never got the reply, so discard any connection it opened.
            self.tcp_connection = None;
            return Err(send_err);
        }

        let Some(remote) = self.tcp_connection.as_mut() else {
            return Err(IOError::MissingProxyConnection);
        };

        let copied = tokio::io::copy_bidirectional(
            &mut self.sock_stream.stream(),
            &mut remote.tcp_stream,
        )
        .await;
        self.tcp_connection = None;
        copied.map(|_| ()).map_err(IOError::from)
    }
}
/// Starts a proxy server over a pool of sockets.
pub trait ProxyServer {
/// Begins listening on the sockets in `pool` and serving proxy sessions,
/// resolving to the boxed server or a `SocketServerError`.
///
/// `max_connections` is a connection limit handed to the implementation; the
/// `SocketServer` implementation in this file passes it to each listener's
/// accept loop. The returned future is `Send`.
fn listen_proxy(
pool: StreamPool,
max_connections: usize,
) -> impl Future<Output = Result<Box<Self>, SocketServerError>> + Send;
}
impl ProxyServer for SocketServer {
    /// Spawns one `accept_loop_proxy` task per listener produced by
    /// `pool.listen()`, forwarding `max_connections` to each loop.
    ///
    /// # Errors
    /// Returns the `SocketServerError` from `pool.listen()`; in that case no
    /// tasks are spawned.
    async fn listen_proxy(
        pool: StreamPool,
        max_connections: usize,
    ) -> Result<Box<Self>, SocketServerError> {
        println!("`SocketServer` proxy listening on pool size {}", pool.len());

        let tasks = pool
            .listen()?
            .into_iter()
            .map(move |listener| {
                tokio::spawn(async move {
                    accept_loop_proxy(listener, max_connections).await
                })
            })
            .collect();

        Ok(Box::new(Self { pool, tasks, max_connections }))
    }
}
/// Accepts connections on `listener` indefinitely, serving each one with a
/// fresh [`Proxy`] on its own detached tokio task.
///
/// A semaphore sized `max_connections` is created here and handed to every
/// `PermittedStream::accept` call. Because it is created per call, each
/// listener gets its own independent limit. Presumably `accept` holds a
/// permit for the life of the stream; confirm in qos_core.
///
/// # Errors
/// Returns the first error from `PermittedStream::accept`, which ends the loop.
async fn accept_loop_proxy(
    listener: Listener,
    max_connections: usize,
) -> Result<(), SocketServerError> {
    let connections = Arc::new(Semaphore::const_new(max_connections));
    loop {
        println!("Proxy::accept_loop_proxy accepting connection");
        let stream =
            PermittedStream::accept(&listener, Arc::clone(&connections)).await?;
        println!("Proxy::accept_loop_proxy new connection accepted");

        // Each session runs exactly once; its outcome is only logged.
        tokio::task::spawn(async move {
            let mut proxy = Proxy::new(stream);
            match proxy.run().await {
                Ok(()) => println!("Proxy::run done"),
                // Deliberately not logged: a closed connection is treated as a
                // normal end of session.
                Err(IOError::RecvConnectionClosed) => {}
                // The session is not retried, so the log must not claim a rerun.
                Err(err) => eprintln!("Error on proxy run {err:?}"),
            }
        });
    }
}