use std::sync::Arc;
use std::time::Duration;
use crate::client_wire::{build_inner_packet, build_random_mdh_packet, DEFAULT_MDH_LEN};
use crate::crypto::SessionKeys;
use crate::error::{Error, Result};
use crate::protocol::{ControlPayload, InnerType};
use tokio::net::UdpSocket;
use tokio::sync::mpsc;
use tokio::time;
/// Tuning knobs for [`run_upload_loop`].
#[derive(Debug, Clone)]
pub struct UploadConfig {
    /// Maximum number of additional, already-queued packets drained with `try_recv` after
    /// each packet received from the data channel, before returning to `select!`.
    pub burst_size: usize,
    /// Period of the keepalive timer; one keepalive is sent each time it fires.
    pub keepalive_interval: Duration,
}

impl Default for UploadConfig {
    /// Burst of 63 queued packets and a 25-second keepalive period.
    fn default() -> Self {
        Self {
            burst_size: 63,
            keepalive_interval: Duration::from_secs(25),
        }
    }
}
/// Turns plaintext traffic into wire-ready datagrams for [`run_upload_loop`].
///
/// Every method takes `&mut self`, so implementors may advance per-session state
/// (sequence numbers, counters, statistics) on each call.
pub trait PacketEncryptor: Send {
    /// Encrypts one data packet and returns the bytes to send.
    fn encrypt_data(&mut self, payload: &[u8]) -> Result<Vec<u8>>;

    /// Called by `run_upload_loop` after each data packet has been sent (or dropped on a
    /// transient send error); `payload_len` is that packet's plaintext length.
    fn on_data_sent(&mut self, payload_len: usize);

    /// Encrypts one control message and returns the bytes to send.
    fn encrypt_control(&mut self, payload: &ControlPayload) -> Result<Vec<u8>>;

    /// Builds and encrypts a keepalive datagram.
    fn encrypt_keepalive(&mut self) -> Result<Vec<u8>>;
}
/// [`PacketEncryptor`] that frames each payload with `build_inner_packet` and wraps it with
/// `build_random_mdh_packet` (passing `None` for its optional argument).
pub struct ZeroMdhEncryptor {
    /// Session key material handed to `build_random_mdh_packet`.
    keys: SessionKeys,
    /// MDH length forwarded to `build_random_mdh_packet`.
    mdh_len: usize,
    /// Sequence number stamped on the next inner packet; wraps on overflow.
    seq: u16,
    /// Counter passed by `&mut` to `build_random_mdh_packet`, which may advance it.
    counter: u64,
}
impl ZeroMdhEncryptor {
    /// Creates an encryptor that uses the crate's `DEFAULT_MDH_LEN`.
    ///
    /// `counter` and `seq` are the starting packet counter and inner sequence number.
    pub fn new(keys: SessionKeys, counter: u64, seq: u16) -> Self {
        Self::with_mdh_len(keys, counter, seq, DEFAULT_MDH_LEN)
    }

    /// Creates an encryptor with an explicit MDH length.
    ///
    /// `counter` and `seq` are the starting packet counter and inner sequence number.
    pub fn with_mdh_len(keys: SessionKeys, counter: u64, seq: u16, mdh_len: usize) -> Self {
        Self { keys, counter, seq, mdh_len }
    }
}
impl PacketEncryptor for ZeroMdhEncryptor {
fn encrypt_data(&mut self, payload: &[u8]) -> Result<Vec<u8>> {
let inner = build_inner_packet(InnerType::Data, self.seq, payload);
self.seq = self.seq.wrapping_add(1);
build_random_mdh_packet(&self.keys, &mut self.counter, &inner, None, self.mdh_len)
}
fn encrypt_control(&mut self, payload: &ControlPayload) -> Result<Vec<u8>> {
let bytes = payload.encode()?;
let inner = build_inner_packet(InnerType::Control, self.seq, &bytes);
self.seq = self.seq.wrapping_add(1);
build_random_mdh_packet(&self.keys, &mut self.counter, &inner, None, self.mdh_len)
}
fn encrypt_keepalive(&mut self) -> Result<Vec<u8>> {
let keepalive = ControlPayload::Keepalive.encode()?;
let inner = build_inner_packet(InnerType::Control, self.seq, &keepalive);
self.seq = self.seq.wrapping_add(1);
build_random_mdh_packet(&self.keys, &mut self.counter, &inner, None, self.mdh_len)
}
fn on_data_sent(&mut self, _payload_len: usize) {}
}
/// Returns `true` for send errors that the upload path drops the packet on instead of
/// failing: network/host unreachable, network down, address not available, or interrupted.
fn is_transient_send_error(err: &std::io::Error) -> bool {
    use std::io::ErrorKind;

    const TRANSIENT_KINDS: [ErrorKind; 5] = [
        ErrorKind::NetworkUnreachable,
        ErrorKind::HostUnreachable,
        ErrorKind::NetworkDown,
        ErrorKind::AddrNotAvailable,
        ErrorKind::Interrupted,
    ];

    TRANSIENT_KINDS.contains(&err.kind())
}
async fn send_tolerant(udp: &UdpSocket, data: &[u8]) -> Result<()> {
match udp.send(data).await {
Ok(_) => Ok(()),
Err(e) if is_transient_send_error(&e) => {
tracing::debug!("upload: transient send error (dropped packet): {e}");
Ok(())
}
Err(e) => Err(Error::Io(e)),
}
}
pub async fn run_upload_loop(
rx: &mut mpsc::Receiver<Vec<u8>>,
mut control_rx: Option<&mut mpsc::Receiver<ControlPayload>>,
udp: &Arc<UdpSocket>,
enc: &mut impl PacketEncryptor,
config: &UploadConfig,
) -> Result<()> {
let mut ka_interval = time::interval(config.keepalive_interval);
let mut data_packet_count: u64 = 0;
ka_interval.tick().await;
loop {
tokio::select! {
biased;
maybe_pkt = rx.recv() => {
let pkt_data = match maybe_pkt {
Some(p) => p,
None => return Err(Error::Channel("TUN->UDP channel closed".into())),
};
let encrypted = enc.encrypt_data(&pkt_data)?;
send_tolerant(udp, &encrypted).await?;
data_packet_count = data_packet_count.wrapping_add(1);
enc.on_data_sent(pkt_data.len());
for _ in 0..config.burst_size {
match rx.try_recv() {
Ok(pkt) => {
let encrypted = enc.encrypt_data(&pkt)?;
send_tolerant(udp, &encrypted).await?;
data_packet_count = data_packet_count.wrapping_add(1);
enc.on_data_sent(pkt.len());
}
Err(mpsc::error::TryRecvError::Empty) => break,
Err(mpsc::error::TryRecvError::Disconnected) => {
return Err(Error::Channel("TUN->UDP channel closed".into()));
}
}
}
}
_ = ka_interval.tick() => {
let encrypted = enc.encrypt_keepalive()?;
send_tolerant(udp, &encrypted).await?;
}
maybe_ctrl = async {
if let Some(crx) = control_rx.as_mut() {
crx.recv().await
} else {
std::future::pending().await
}
} => {
if let Some(payload) = maybe_ctrl {
let encrypted = enc.encrypt_control(&payload)?;
send_tolerant(udp, &encrypted).await?;
}
}
}
}
}