use crate::connection_controller::ConnectionReceiver;
use crate::ordered_sender::OrderedMessageSender;
use nym_socks5_requests::{ConnectionId, SocketData};
use nym_task::connections::LaneQueueLengths;
use nym_task::ShutdownTracker;
use std::fmt::Debug;
use std::{sync::Arc, time::Duration};
use tokio::{net::TcpStream, sync::Notify};
mod inbound;
mod outbound;
/// Shutdown timeout: three seconds.
///
/// NOTE(review): not referenced in this file; presumably consumed by the
/// `inbound` / `outbound` submodules — confirm before changing.
const SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(3_000);
/// Keepalive interval: one minute.
///
/// NOTE(review): not referenced in this file; presumably consumed by the
/// `inbound` / `outbound` submodules — confirm before changing.
const KEEPALIVE_INTERVAL: Duration = Duration::from_millis(60_000);
/// A chunk of proxied bytes paired with a socket-closed flag.
#[derive(Debug)]
pub struct ProxyMessage {
    /// Payload bytes carried by this message.
    pub data: Vec<u8>,
    /// Socket-closed flag; populated from the second element of the tuple
    /// in the `From<(Vec<u8>, bool)>` conversion below.
    pub socket_closed: bool,
}

/// Builds a [`ProxyMessage`] from a `(data, socket_closed)` pair.
impl From<(Vec<u8>, bool)> for ProxyMessage {
    fn from((data, socket_closed): (Vec<u8>, bool)) -> Self {
        Self {
            data,
            socket_closed,
        }
    }
}
/// Bounded tokio mpsc sender carrying messages of type `S`; this is the type of
/// [`ProxyRunner`]'s `mix_sender`, which is cloned into the `OrderedMessageSender`.
pub type MixProxySender<S> = tokio::sync::mpsc::Sender<S>;
/// Receiving counterpart of [`MixProxySender`].
pub type MixProxyReader<S> = tokio::sync::mpsc::Receiver<S>;
/// Drives a single proxied TCP connection by splitting its socket into read and
/// write halves and running an inbound and an outbound task concurrently.
///
/// `socket` and `mix_receiver` are `Option`s only so that [`ProxyRunner::run`]
/// can move them into the spawned tasks and put them back once both tasks
/// finish. Outside of `run` they are always `Some`.
#[derive(Debug)]
pub struct ProxyRunner<S> {
    /// Handed to the outbound task, which returns it when it finishes.
    mix_receiver: Option<ConnectionReceiver>,
    /// Cloned into the `OrderedMessageSender` used by the inbound task.
    mix_sender: MixProxySender<S>,
    /// The proxied socket; split by `run` and reunited afterwards.
    socket: Option<TcpStream>,
    local_destination_address: String,
    remote_source_address: String,
    connection_id: ConnectionId,
    /// Forwarded to the inbound task.
    lane_queue_lengths: Option<LaneQueueLengths>,
    /// Forwarded to the inbound task; presumably the maximum plaintext payload
    /// that fits into one mix packet — TODO confirm.
    available_plaintext_per_mix_packet: usize,
    /// Spawns the inbound/outbound tasks and supplies their shutdown tokens.
    shutdown_tracker: ShutdownTracker,
}

impl<S> ProxyRunner<S>
where
    S: Debug + Send + 'static,
{
    /// Creates a runner for an already established `socket`.
    ///
    /// No tasks are started until [`ProxyRunner::run`] is awaited.
    // Nine arguments exceed clippy's default `too_many_arguments` threshold; all
    // of them are required per-connection state, so the lint is silenced rather
    // than introducing a builder.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        socket: TcpStream,
        local_destination_address: String,
        remote_source_address: String,
        mix_receiver: ConnectionReceiver,
        mix_sender: MixProxySender<S>,
        available_plaintext_per_mix_packet: usize,
        connection_id: ConnectionId,
        lane_queue_lengths: Option<LaneQueueLengths>,
        shutdown_tracker: ShutdownTracker,
    ) -> Self {
        ProxyRunner {
            mix_receiver: Some(mix_receiver),
            mix_sender,
            socket: Some(socket),
            local_destination_address,
            remote_source_address,
            connection_id,
            lane_queue_lengths,
            available_plaintext_per_mix_packet,
            shutdown_tracker,
        }
    }

    /// Runs the inbound and outbound halves of the proxy until both complete.
    /// Afterwards it reassembles the socket and mix receiver and returns `self`,
    /// so the runner can be reused or unpacked with [`ProxyRunner::into_inner`].
    ///
    /// `adapter_fn` is passed to the `OrderedMessageSender`; presumably it wraps
    /// each `SocketData` into the channel message type `S` — confirm.
    ///
    /// # Panics
    ///
    /// Panics in either of these cases:
    /// - a spawned task's handle resolves to an error (e.g. the task panicked or
    ///   was cancelled), because the socket half and receiver it owned cannot be
    ///   recovered;
    /// - the two socket halves fail to reunite, which would indicate a broken
    ///   internal invariant.
    pub async fn run<F>(mut self, adapter_fn: F) -> Self
    where
        F: Fn(SocketData) -> S + Send + Sync + 'static,
    {
        let connection_id = self.connection_id;
        let (read_half, write_half) = self
            .socket
            .take()
            .expect("ProxyRunner socket is always present outside of `run`")
            .into_split();

        // Shared by both tasks; presumably lets one side signal the other to stop.
        let shutdown_notify = Arc::new(Notify::new());

        let ordered_sender = OrderedMessageSender::new(
            self.local_destination_address.clone(),
            self.remote_source_address.clone(),
            connection_id,
            self.mix_sender.clone(),
            adapter_fn,
        );

        let inbound_future = inbound::run_inbound(
            read_half,
            ordered_sender,
            connection_id,
            self.available_plaintext_per_mix_packet,
            Arc::clone(&shutdown_notify),
            self.lane_queue_lengths.clone(),
            self.shutdown_tracker.clone_shutdown_token(),
        );

        let outbound_future = outbound::run_outbound(
            write_half,
            self.local_destination_address.clone(),
            self.remote_source_address.clone(),
            self.mix_receiver
                .take()
                .expect("ProxyRunner mix receiver is always present outside of `run`"),
            connection_id,
            shutdown_notify,
            self.shutdown_tracker.clone_shutdown_token(),
        );

        let handle_inbound = self.shutdown_tracker.try_spawn_named(
            inbound_future,
            &format!(
                "Socks5Inbound::{}::{}",
                self.remote_source_address, connection_id
            ),
        );
        let handle_outbound = self.shutdown_tracker.try_spawn_named(
            outbound_future,
            &format!(
                "Socks5Outbound::{}::{}",
                self.remote_source_address, connection_id
            ),
        );

        let (inbound_res, outbound_res) =
            futures::future::join(handle_inbound, handle_outbound).await;

        // Each task hands back the resources it was given. Without both, `self`
        // cannot be rebuilt, so any failure is fatal. Report both outcomes so the
        // panic shows which side failed and why.
        // NOTE(review): if the tracker cancels tasks on shutdown, this still
        // panics (unchanged from the previous behaviour) — consider whether that
        // should be recoverable.
        let (read_half, (write_half, mix_receiver)) = match (inbound_res, outbound_res) {
            (Ok(read_half), Ok(outbound)) => (read_half, outbound),
            (inbound_res, outbound_res) => panic!(
                "socks5 proxy tasks for connection {} did not complete: inbound error: {:?}, outbound error: {:?}",
                connection_id,
                inbound_res.err(),
                outbound_res.err()
            ),
        };

        self.socket = Some(
            write_half
                .reunite(read_half)
                .expect("inbound and outbound halves were split from the same TcpStream"),
        );
        self.mix_receiver = Some(mix_receiver);
        self
    }

    /// Consumes the runner, returning the underlying socket and mix receiver.
    pub fn into_inner(self) -> (TcpStream, ConnectionReceiver) {
        let ProxyRunner {
            socket,
            mix_receiver,
            ..
        } = self;
        (
            socket.expect("ProxyRunner socket is always present outside of `run`"),
            mix_receiver.expect("ProxyRunner mix receiver is always present outside of `run`"),
        )
    }
}