use std::sync::Arc;
use std::time::Duration;
use futures::future::join_all;
use tokio::task::JoinHandle;
use tokio::time;
use tokio_util::sync::CancellationToken;
use tracing::{debug, info, warn};
use super::{InboundEvent, InboundRx, OutboundEvent, OutboundRx, Peer, Session};
use crate::device::{Endpoint, Transport};
use crate::noise::handshake::IncomingInitiation;
use crate::noise::protocol::{
self, CookieReply, HandshakeResponse, TransportData, COOKIE_REPLY_PACKET_SIZE,
HANDSHAKE_RESPONSE_PACKET_SIZE,
};
use crate::Tun;
/// Owner of the background tasks that serve a single peer.
///
/// Created by [`PeerHandle::spawn`]; holds the cancellation token whose child
/// tokens were handed to the tasks, together with the tasks' join handles.
pub(crate) struct PeerHandle {
    /// Parent token; cancelling it cancels every child token given to the tasks.
    token: CancellationToken,
    /// Join handles of the spawned loops, in spawn order.
    handles: Vec<JoinHandle<()>>,
}
impl PeerHandle {
pub fn spawn<T, I>(
token: CancellationToken,
peer: Arc<Peer<T, I>>,
inbound: InboundRx<I>,
outbound: OutboundRx,
) -> Self
where
T: Tun + 'static,
I: Transport,
{
let handshake_loop = tokio::spawn(loop_handshake(token.child_token(), Arc::clone(&peer)));
let inbound_loop = tokio::spawn(loop_inbound(
token.child_token(),
Arc::clone(&peer),
inbound,
));
let outbound_loop = tokio::spawn(loop_outbound(
token.child_token(),
Arc::clone(&peer),
outbound,
));
Self {
token,
handles: vec![handshake_loop, inbound_loop, outbound_loop],
}
}
pub async fn cancel(mut self, timeout: Duration) {
self.token.cancel();
let handles = self.handles.drain(..).collect::<Vec<_>>();
let abort_handles = handles.iter().map(|h| h.abort_handle()).collect::<Vec<_>>();
if let Err(e) = tokio::time::timeout(timeout, join_all(handles)).await {
warn!(
"failed to cancel peer tasks in {}ms: {}",
timeout.as_millis(),
e
);
for handle in abort_handles {
handle.abort();
}
}
}
}
/// Dropping the handle requests cancellation of the peer tasks.
///
/// Only the token is cancelled here; the tasks are not awaited or aborted.
impl Drop for PeerHandle {
    fn drop(&mut self) {
        self.token.cancel();
    }
}
async fn loop_handshake<T, I>(token: CancellationToken, peer: Arc<Peer<T, I>>)
where
T: Tun + 'static,
I: Transport,
{
debug!("Handshake loop for {peer} is UP");
while !token.is_cancelled() {
if peer.monitor.can_handshake() {
info!("initiating handshake");
let packet = {
let (next, packet) = peer.handshake.write().unwrap().initiate();
let mut sessions = peer.sessions.write().unwrap();
sessions.prepare_uninit(next);
packet
};
peer.send_outbound(&packet).await; peer.monitor.handshake().initiated();
}
time::sleep_until(peer.monitor.handshake().will_initiate_in().into()).await;
}
debug!("Handshake loop for {peer} is DOWN");
}
async fn loop_outbound<T, I>(token: CancellationToken, peer: Arc<Peer<T, I>>, mut rx: OutboundRx)
where
T: Tun + 'static,
I: Transport,
{
debug!("Outbound loop for {peer} is UP");
loop {
tokio::select! {
_ = token.cancelled() => break,
_ = time::sleep_until(peer.monitor.keepalive().next_attempt_in(peer.monitor.traffic()).into()) => {
peer.keepalive().await;
}
event = rx.recv() => {
match event {
Some(OutboundEvent::Data(data)) => {
tick_outbound(Arc::clone(&peer), data).await;
}
None => break,
}
}
}
}
debug!("Outbound loop for {peer} is DOWN");
}
/// Encrypts `data` with the peer's current session and sends it.
///
/// Does nothing when the peer has no current session. An encryption failure
/// is logged and the payload is dropped. After sending, the size of the
/// encoded packet is recorded as outbound traffic.
async fn tick_outbound<T, I>(peer: Arc<Peer<T, I>>, data: Vec<u8>)
where
    T: Tun + 'static,
    I: Transport,
{
    // The read guard is released at the end of this statement, before any
    // `.await`.
    let current = { peer.sessions.read().unwrap().current().clone() };
    let session = match current {
        Some(session) => session,
        None => return,
    };

    let packet = match session.encrypt_data(&data) {
        Ok(packet) => packet,
        Err(e) => {
            warn!("failed to encrypt packet: {}", e);
            return;
        }
    };

    let buf = packet.to_bytes();
    peer.send_outbound(&buf).await;
    peer.monitor.traffic().outbound(buf.len());
}
async fn loop_inbound<T, I>(token: CancellationToken, peer: Arc<Peer<T, I>>, mut rx: InboundRx<I>)
where
T: Tun + 'static,
I: Transport,
{
debug!("Inbound loop for {peer} is UP");
loop {
tokio::select! {
() = token.cancelled() => break,
event = rx.recv() => {
match event {
Some(event) => tick_inbound(Arc::clone(&peer), event).await,
None => break,
}
}
}
}
debug!("Inbound loop for {peer} is DOWN");
}
/// Routes a single inbound event to the matching handler in [`inbound`].
///
/// `peer` is owned by this function and is not used after dispatch, so it is
/// moved into the selected handler rather than cloned again.
async fn tick_inbound<T, I>(peer: Arc<Peer<T, I>>, event: InboundEvent<I>)
where
    T: Tun + 'static,
    I: Transport,
{
    match event {
        InboundEvent::HanshakeInitiation {
            endpoint,
            initiation,
        } => inbound::handle_handshake_initiation(peer, endpoint, initiation).await,
        InboundEvent::HandshakeResponse {
            endpoint,
            packet,
            session,
        } => inbound::handle_handshake_response(peer, endpoint, packet, session).await,
        InboundEvent::CookieReply {
            endpoint,
            packet,
            session,
        } => inbound::handle_cookie_reply(peer, endpoint, packet, session).await,
        InboundEvent::TransportData {
            endpoint,
            packet,
            session,
        } => inbound::handle_transport_data(peer, endpoint, packet, session).await,
    }
}
mod inbound {
use super::*;
use tracing::error;
pub(super) async fn handle_handshake_initiation<T, I>(
peer: Arc<Peer<T, I>>,
endpoint: Endpoint<I>,
initiation: IncomingInitiation,
) where
T: Tun + 'static,
I: Transport,
{
peer.monitor
.traffic()
.inbound(protocol::HANDSHAKE_INITIATION_PACKET_SIZE);
let ret = {
let mut handshake = peer.handshake.write().unwrap();
handshake.respond(&initiation)
};
match ret {
Ok((session, packet)) => {
{
let mut sessions = peer.sessions.write().unwrap();
sessions.prepare_next(session);
}
peer.update_endpoint(endpoint.clone());
endpoint.send(&packet).await.unwrap();
peer.monitor.handshake().initiated();
}
Err(e) => debug!("failed to respond to handshake initiation: {e}"),
}
}
pub(super) async fn handle_handshake_response<T, I>(
peer: Arc<Peer<T, I>>,
endpoint: Endpoint<I>,
packet: HandshakeResponse,
_session: Session,
) where
T: Tun + 'static,
I: Transport,
{
peer.monitor
.traffic()
.inbound(HANDSHAKE_RESPONSE_PACKET_SIZE);
let ret = {
let mut handshake = peer.handshake.write().unwrap();
handshake.finalize(&packet)
};
match ret {
Ok(session) => {
let ret = {
let mut sessions = peer.sessions.write().unwrap();
sessions.complete_uninit(session)
};
if !ret {
debug!("failed to complete handshake, session not found");
return;
}
peer.monitor.handshake().completed();
info!("handshake completed");
peer.update_endpoint(endpoint);
peer.stage_outbound(vec![]).await; }
Err(e) => debug!("failed to finalize handshake: {e}"),
}
}
pub(super) async fn handle_cookie_reply<T, I>(
peer: Arc<Peer<T, I>>,
_endpoint: Endpoint<I>,
_packet: CookieReply,
_session: Session,
) where
T: Tun + 'static,
I: Transport,
{
peer.monitor.traffic().inbound(COOKIE_REPLY_PACKET_SIZE);
}
pub(super) async fn handle_transport_data<T, I>(
peer: Arc<Peer<T, I>>,
endpoint: Endpoint<I>,
packet: TransportData,
session: Session,
) where
T: Tun + 'static,
I: Transport,
{
peer.monitor.traffic().inbound(packet.packet_len());
{
let mut sessions = peer.sessions.write().unwrap();
if sessions.complete_next(session.clone()) {
info!("handshake completed");
peer.monitor.handshake().completed();
}
}
if !session.can_accept(packet.counter) {
debug!("dropping packet due to replay");
return;
}
peer.update_endpoint(endpoint);
match session.decrypt_data(&packet) {
Ok(data) => {
if data.is_empty() {
return;
}
debug!("recv data from peer and try to send it to TUN");
if let Err(e) = peer.tun.send(&data).await {
error!("{peer} failed to send data to tun: {e}");
}
session.aceept(packet.counter);
}
Err(e) => debug!("failed to decrypt packet: {e}"),
}
}
}