use std::{collections::HashMap, net::{SocketAddr, ToSocketAddrs}, sync::Arc, time::{Duration, Instant}};
use bytes::BytesMut;
use proto::{ConnectionError, EcnCodepoint, Transmit};
use tracing::debug;
use udp::RecvMeta;
use crate::{udp_transmit, AsyncUdpSocket, Runtime};
/// Length of one rate-limiting window (10 ms).
///
/// Each forward connection's byte budget is expressed per this period.
const RATE_LIMIT_CYCLE: Duration = Duration::from_micros(10_000);
/// Resolves `upstream_addr` and binds a local UDP socket for talking to it.
///
/// `upstream_addr` is resolved with `ToSocketAddrs` (so it must be a
/// `host:port` / `ip:port` string) and only the first resolved address is used.
/// The local socket is bound to the unspecified address of the same family
/// (IPv4 or IPv6) on port 0, letting the OS choose an ephemeral port.
///
/// Returns the bound socket together with the resolved upstream address.
///
/// # Errors
/// Returns `ConnectionError::JlsForwardError` if resolution fails, resolution
/// yields no addresses, or the local bind fails.
pub(crate) fn bind_upstream_socket(
    upstream_addr: &str,
) -> Result<(std::net::UdpSocket, SocketAddr), ConnectionError> {
    let upstream_addr: SocketAddr = upstream_addr
        .to_socket_addrs()
        .map_err(|x| ConnectionError::JlsForwardError(x.to_string()))?
        .next()
        // Only build the error (and its String) when resolution came back empty.
        .ok_or_else(|| {
            ConnectionError::JlsForwardError("jls upstream domain name resolved failed".into())
        })?;
    // Match the upstream's address family; building the address directly avoids
    // a string parse + unwrap.
    let local_ip = if upstream_addr.is_ipv6() {
        std::net::IpAddr::V6(std::net::Ipv6Addr::UNSPECIFIED)
    } else {
        std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED)
    };
    let socket = std::net::UdpSocket::bind(SocketAddr::new(local_ip, 0))
        .map_err(|x| ConnectionError::JlsForwardError(x.to_string()))?;
    Ok((socket, upstream_addr))
}
pub(crate) fn insert_forward_conn(
jls_state: &mut JlsState,
runtime: &dyn Runtime,
trans: Vec<Transmit>,
response_buffer: &[u8],
upstream_addr: &str,
remote_addr: SocketAddr,
now: Instant) -> Result<(), ConnectionError> {
let (socket, upstream_addr) = bind_upstream_socket(upstream_addr)?;
debug!("new forward connection");
let udp_socket = runtime.wrap_udp_socket(socket).unwrap();
let recv_buf = vec![0; 32 * 32 * 1024]; let byte_per_cycle = jls_state.rate_bps.min(u64::MAX/RATE_LIMIT_CYCLE.as_millis() as u64) *
RATE_LIMIT_CYCLE.as_millis() as u64 /
(1000 * 8 );
let byte_per_cycle = byte_per_cycle.min(usize::MAX as u64) as usize;
let jls_conn = JlsForwardConnection {
upstream_socket: udp_socket.clone(),
upstream_addr: upstream_addr,
from_upstream: recv_buf.into(),
active_time: now.clone(),
send_limiter: JlsRateLimiter::new(RATE_LIMIT_CYCLE, byte_per_cycle as usize), recv_limiter: JlsRateLimiter::new(RATE_LIMIT_CYCLE, byte_per_cycle as usize), };
let mut pos = 0;
for mut trans in trans {
let size = trans.size;
trans.destination = upstream_addr;
respond(trans, &response_buffer[pos..], &*udp_socket);
pos += size;
}
jls_state
.upstream_connections
.insert(remote_addr, jls_conn);
Ok(())
}
/// Per-client state for a connection whose traffic is being forwarded to the
/// JLS upstream server.
#[derive(Debug)]
pub(crate) struct JlsForwardConnection {
/// Local UDP socket (bound by `bind_upstream_socket`) used to send the
/// client's datagrams to the upstream.
pub(crate) upstream_socket: Arc<dyn AsyncUdpSocket>,
/// Resolved upstream address; destination of every forwarded datagram.
pub(crate) upstream_addr: SocketAddr,
/// Zero-filled 1 MiB buffer allocated at creation.
/// NOTE(review): its reader is not in this section — presumably holds
/// datagrams received from the upstream; confirm.
pub(crate) from_upstream: Box<[u8]>,
/// Set at creation and refreshed each time a client datagram is forwarded.
/// NOTE(review): presumably consulted for idle expiry elsewhere — confirm.
pub(crate) active_time: Instant,
/// Limits the bytes forwarded from the client to the upstream.
pub(crate) send_limiter: JlsRateLimiter,
/// Created with the same budget as `send_limiter`; not used in this section.
/// NOTE(review): presumably limits the upstream -> client direction — confirm.
pub(crate) recv_limiter: JlsRateLimiter,
}
/// JLS forwarding state: the set of client addresses whose datagrams are being
/// forwarded upstream, plus the per-connection rate limit.
#[derive(Debug, Default)]
pub(crate) struct JlsState {
/// Forward connections keyed by the client's (remote) socket address.
pub(crate) upstream_connections: HashMap<SocketAddr, JlsForwardConnection>,
/// Per-connection forwarding rate in bits per second; divided by 8 and scaled
/// to RATE_LIMIT_CYCLE when a connection is created.
/// NOTE(review): the derived `Default` sets this to 0, which yields a zero
/// byte budget and makes the rate limiter reject every non-empty datagram —
/// confirm callers always use `JlsState::new`.
pub(crate) rate_bps: u64,
}
impl JlsState {
    /// Creates forwarding state with no connections and the given per-connection
    /// rate limit in bits per second.
    pub(crate) fn new(rate_bps: u64) -> Self {
        Self {
            upstream_connections: HashMap::new(),
            rate_bps,
        }
    }

    /// Forwards a datagram received from `meta.addr` to that client's upstream
    /// server, if a forward connection exists for the address.
    ///
    /// Returns `true` when the address has a forward connection. The datagram is
    /// consumed by JLS forwarding in that case, even if the rate limiter dropped
    /// it. Returns `false` when there is no such connection, leaving the datagram
    /// to the caller.
    pub(crate) fn handle_jls_forward(
        &mut self,
        buf: &BytesMut,
        meta: &RecvMeta,
        now: Instant,
    ) -> bool {
        match self.upstream_connections.get_mut(&meta.addr) {
            None => false,
            Some(conn) => {
                let trans = Transmit {
                    destination: conn.upstream_addr,
                    // Carry the received ECN marking over; an unrecognized value
                    // is dropped instead of panicking.
                    ecn: meta.ecn.and_then(|x| EcnCodepoint::from_bits(x as u8)),
                    segment_size: None,
                    size: buf.len(),
                    src_ip: None,
                };
                conn.active_time = now;
                tracing::trace!("jls forward to upstream {} bytes", trans.size);
                if !conn
                    .send_limiter
                    .try_send(buf, trans, &*conn.upstream_socket, now)
                {
                    tracing::trace!(
                        "jls forward to upstream rate limited, dropped {} bytes",
                        buf.len()
                    );
                }
                true
            }
        }
    }
}
/// Best-effort send of a single datagram.
///
/// The first `transmit.size` bytes of `response_buffer` are sent through
/// `socket`; any send error is discarded. Panics if `response_buffer` is shorter
/// than `transmit.size`.
fn respond(transmit: proto::Transmit, response_buffer: &[u8], socket: &dyn AsyncUdpSocket) {
    let payload = &response_buffer[..transmit.size];
    let packet = udp_transmit(&transmit, payload);
    // UDP forwarding is lossy anyway, so a failed send is intentionally ignored.
    let _ = socket.try_send(&packet);
}
/// Fixed-window byte-rate limiter: admits at most `bytes_per_cycle` bytes in each
/// `cycle_period` window.
#[derive(Debug)]
pub(crate) struct JlsRateLimiter {
/// Start of the current window.
last_cycle: Instant,
// `data_handled`: bytes admitted so far in the current window.
// `cycle_period`: length of each window.
data_handled: usize, cycle_period: Duration,
/// Maximum number of bytes admitted per window.
pub bytes_per_cycle: usize,
}
impl JlsRateLimiter {
    /// Creates a limiter admitting at most `bytes_per_cycle` bytes per
    /// `cycle_period`, with the first window starting at `Instant::now()`.
    pub(crate) fn new(cycle_period: Duration, bytes_per_cycle: usize) -> Self {
        Self {
            last_cycle: Instant::now(),
            data_handled: 0,
            cycle_period,
            bytes_per_cycle,
        }
    }

    /// Charges `data_size` bytes to the current window.
    ///
    /// Returns `true` and records the bytes if they fit in the remaining budget.
    /// Otherwise returns `false` and records nothing. Once `cycle_period` has
    /// elapsed since the window started, the counter resets and a new window
    /// starts at `now`.
    ///
    /// NOTE(review): a single `data_size` larger than `bytes_per_cycle` can never
    /// be admitted, so low rates (below roughly one MTU per cycle) drop every
    /// full-size datagram — confirm this is intended.
    pub(crate) fn should_send(&mut self, data_size: usize, now: Instant) -> bool {
        // Callers may pass a `now` captured before `new()` ran; treat that as
        // zero elapsed time instead of relying on `duration_since` semantics.
        if now.saturating_duration_since(self.last_cycle) >= self.cycle_period {
            self.data_handled = 0;
            self.last_cycle = now;
        }
        // `bytes_per_cycle` may be clamped to usize::MAX, so a plain `+` could
        // overflow; an overflowing total is over any budget.
        match self.data_handled.checked_add(data_size) {
            Some(total) if total <= self.bytes_per_cycle => {
                self.data_handled = total;
                true
            }
            _ => false,
        }
    }

    /// Sends `buf` through `socket` (via `respond`) if the limiter admits
    /// `buf.len()` bytes at `now`. Returns whether the datagram was handed to
    /// the socket; a rejected datagram is not queued.
    pub(crate) fn try_send(
        &mut self,
        buf: &[u8],
        trans: Transmit,
        socket: &dyn AsyncUdpSocket,
        now: Instant,
    ) -> bool {
        if self.should_send(buf.len(), now) {
            respond(trans, buf, socket);
            return true;
        }
        false
    }
}