use std::net::IpAddr;
use std::sync::Arc;
use std::time::Duration;
use bytes::Bytes;
use futures::stream::FuturesUnordered;
use futures::StreamExt;
use governor::Jitter;
use ipnet::IpNet;
use quinn::Connection;
use tokio::sync::mpsc::{Receiver, Sender};
use tracing::{debug, info};
use crate::identity;
use crate::server::address_pool::AddressPoolManager;
use crate::server::session::BandwidthLimiter;
use crate::users::UsersFile;
use quincy::config::ServerProtocolConfig;
use quincy::ip_assignment::{self, IpAssignment};
use quincy::network::packet::Packet;
use quincy::utils::tasks::abort_all;
use quincy::{QuincyError, Result};
/// Timeout handed to `ip_assignment::send_ip_assignment` when delivering a client its
/// address assignment (10 seconds).
const IP_ASSIGNMENT_TIMEOUT: Duration = Duration::from_millis(10_000);
/// Typestate marker: the connection has been accepted but the peer is not identified yet.
pub struct New;
/// Typestate: the peer has been identified, but no tunnel address has been assigned yet.
pub struct Identified {
    /// Username returned by `identity::identify_peer` for this peer.
    pub username: String,
}
/// Typestate: the peer has an address allocated from the pool and has been sent its
/// IP assignment.
pub struct Assigned {
    /// Username the peer was identified as.
    pub username: String,
    /// Address (with prefix) allocated to the client from the address pool.
    pub client_address: IpNet,
}
/// A client QUIC connection on the server side, tracked through the typestates
/// `New` -> `Identified` -> `Assigned`; each transition consumes the previous state.
pub struct QuincyConnection<S> {
    /// Underlying QUIC connection to the client.
    connection: Connection,
    /// Queue that packets received from the client are forwarded into (the `tun_queue`
    /// passed to `QuincyConnection::new`).
    ingress_queue: Sender<Packet>,
    /// Data specific to the current typestate.
    state: S,
}
impl QuincyConnection<New> {
    /// Wraps a freshly accepted QUIC connection in the initial `New` state.
    ///
    /// `tun_queue` becomes the connection's ingress queue, i.e. where packets received from
    /// the client will later be forwarded.
    pub fn new(connection: Connection, tun_queue: Sender<Packet>) -> Self {
        let state = New;
        Self {
            state,
            ingress_queue: tun_queue,
            connection,
        }
    }

    /// Identifies the remote peer and advances the connection to the `Identified` state.
    ///
    /// # Errors
    /// Propagates any error returned by `identity::identify_peer`; the connection is
    /// consumed (dropped) in that case.
    pub fn identify(
        self,
        protocol: &ServerProtocolConfig,
        users: &UsersFile,
    ) -> Result<QuincyConnection<Identified>> {
        let username = identity::identify_peer(&self.connection, protocol, users)?;

        // Identification succeeded: move the transport pieces into the next state.
        let QuincyConnection {
            connection,
            ingress_queue,
            state: New,
        } = self;

        Ok(QuincyConnection {
            connection,
            ingress_queue,
            state: Identified { username },
        })
    }
}
impl QuincyConnection<Identified> {
#[allow(dead_code)]
pub fn username(&self) -> &str {
&self.state.username
}
pub async fn assign_ip(
self,
address_pool: &AddressPoolManager,
server_address: IpNet,
) -> Result<QuincyConnection<Assigned>> {
let client_address = address_pool
.allocate_address(&self.state.username)
.ok_or(quincy::error::AuthError::AddressPoolExhausted)?;
let assignment = IpAssignment {
client_address,
server_address,
};
if let Err(e) =
ip_assignment::send_ip_assignment(&self.connection, &assignment, IP_ASSIGNMENT_TIMEOUT)
.await
{
address_pool.release_address(&self.state.username, &client_address.addr());
return Err(e);
}
info!(
"Connection established: user = {}, client address = {}, remote address = {}",
self.state.username,
client_address.addr(),
self.connection.remote_address().ip(),
);
Ok(QuincyConnection {
connection: self.connection,
ingress_queue: self.ingress_queue,
state: Assigned {
username: self.state.username,
client_address,
},
})
}
}
impl QuincyConnection<Assigned> {
pub fn username(&self) -> &str {
&self.state.username
}
pub fn client_address(&self) -> IpNet {
self.state.client_address
}
pub async fn run(
self,
egress_queue: Receiver<Bytes>,
rate_limiter: Option<Arc<BandwidthLimiter>>,
#[cfg(feature = "metrics")] metrics_interval: Duration,
) -> (Self, QuincyError) {
let client_address = self.state.client_address.addr();
let mut tasks = FuturesUnordered::new();
tasks.extend([
tokio::spawn(Self::process_outgoing_data(
self.connection.clone(),
egress_queue,
rate_limiter.clone(),
)),
tokio::spawn(Self::process_incoming_data(
self.connection.clone(),
self.ingress_queue.clone(),
client_address,
rate_limiter,
)),
]);
#[cfg(feature = "metrics")]
tasks.push(tokio::spawn(Self::report_metrics(
self.connection.clone(),
metrics_interval,
self.state.username.clone(),
self.state.client_address.addr(),
)));
let res = tasks
.next()
.await
.expect("tasks is not empty")
.expect("task is joinable");
let _ = abort_all(tasks).await;
match res {
Err(e) => (self, e),
Ok(()) => (
self,
QuincyError::system("Connection task exited unexpectedly"),
),
}
}
async fn process_outgoing_data(
connection: Connection,
mut egress_queue: Receiver<Bytes>,
rate_limiter: Option<Arc<BandwidthLimiter>>,
) -> Result<()> {
loop {
let data = egress_queue
.recv()
.await
.ok_or(QuincyError::system("Egress queue has been closed"))?;
if let Some(ref limiter) = rate_limiter {
let tokens = (data.len() as u32 / 1024)
.max(1)
.try_into()
.expect("token amount is always non-zero");
let _ = limiter
.until_n_ready_with_jitter(tokens, Jitter::up_to(Duration::from_millis(5)))
.await;
}
connection.send_datagram(data)?;
}
}
async fn process_incoming_data(
connection: Connection,
ingress_queue: Sender<Packet>,
client_address: IpAddr,
rate_limiter: Option<Arc<BandwidthLimiter>>,
) -> Result<()> {
loop {
let packet: Packet = connection.read_datagram().await?.into();
let source_address = match packet.source() {
Ok(source) => source,
Err(err) => {
debug!("Dropping packet: unable to parse source IP from header due to {err}");
continue;
}
};
if source_address != client_address {
debug!(
"Dropping packet: source IP {source_address} does not match assigned address {client_address}"
);
continue;
}
if let Some(ref limiter) = rate_limiter {
let tokens = (packet.len() as u32 / 1024)
.max(1)
.try_into()
.expect("token amount is always non-zero");
let _ = limiter
.until_n_ready_with_jitter(tokens, Jitter::up_to(Duration::from_millis(5)))
.await;
}
ingress_queue.send(packet).await?;
}
}
#[cfg(feature = "metrics")]
async fn report_metrics(
connection: Connection,
reporting_interval: Duration,
username: String,
client_ip: IpAddr,
) -> Result<()> {
use metrics::{counter, gauge};
let connected_at = std::time::Instant::now();
let mut interval = tokio::time::interval(reporting_interval);
let labels = [("user", username), ("connection", client_ip.to_string())];
loop {
interval.tick().await;
let stats = connection.stats();
counter!("quincy_bytes_tx_total", &labels).absolute(stats.udp_tx.bytes);
counter!("quincy_bytes_rx_total", &labels).absolute(stats.udp_rx.bytes);
counter!("quincy_datagrams_tx_total", &labels).absolute(stats.udp_tx.datagrams);
counter!("quincy_datagrams_rx_total", &labels).absolute(stats.udp_rx.datagrams);
gauge!("quincy_connection_rtt_seconds", &labels).set(stats.path.rtt.as_secs_f64());
gauge!("quincy_connection_duration_seconds", &labels)
.set(connected_at.elapsed().as_secs_f64());
}
}
}