playit-agent-core 0.20.1

Contains the logic to create a playit.gg agent
Documentation
use std::{sync::{atomic::{AtomicU64, Ordering}, Arc}, time::Duration};
use tokio::time::Instant;

use playit_agent_proto::{control_messages::UdpChannelDetails, udp_proto::{UdpFlow, UDP_CHANNEL_ESTABLISH_ID}};
use tokio::sync::mpsc::{channel, Receiver, Sender};

use crate::{agent_control::{DualStackUdpSocket, PacketIO}, utils::now_milli};

use super::{packets::{Packet, Packets}, udp_errors::udp_errors};

/// Handle to a UDP tunnel channel.
///
/// All socket I/O happens in a background `Task` spawned by `UdpChannel::new`;
/// this handle talks to that task over bounded channels and reads establish
/// timestamps from state shared with it.
pub struct UdpChannel {
    /// Forwards session details to the task (see `update_session`).
    session_tx: Sender<UdpChannelDetails>,
    /// Outbound packets, tagged with their flow, for the task to send to the tunnel.
    send: Sender<(UdpFlow, Packet)>,
    /// Inbound packets received from the tunnel, with the flow footer already removed.
    recv: Receiver<(UdpFlow, Packet)>,
    /// Establish timestamps written by the task and read by this handle.
    shared: Arc<Shared>,
}

/// State shared between a `UdpChannel` handle and its background task.
///
/// Both fields hold `now_milli()` timestamps (treated as milliseconds by the
/// readers); the default of 0 means the event has not happened yet.
#[derive(Default)]
struct Shared {
    /// Set when a datagram whose footer is `UDP_CHANNEL_ESTABLISH_ID` arrives from the tunnel.
    establish_rx_epoch: AtomicU64,
    /// Set just before each establish token is sent to the tunnel.
    establish_tx_epoch: AtomicU64,
}

/// Background task that owns the UDP socket and services one `UdpChannel`.
struct Task {
    /// Socket used both for the establish token and for tunneled packets.
    socket: DualStackUdpSocket,
    /// Current tunnel session; `None` until the first session update arrives.
    session: Option<UdpChannelDetails>,
    /// Session updates from `UdpChannel::update_session`.
    session_rx: Receiver<UdpChannelDetails>,

    /// Source of receive buffers (via `allocate_wait`).
    packets: Packets,

    /// Outbound packets from `UdpChannel::send`.
    send_rx: Receiver<(UdpFlow, Packet)>,
    /// Delivers inbound packets to `UdpChannel::recv`.
    recv_tx: Sender<(UdpFlow, Packet)>,

    /// Establish timestamps, also read by the `UdpChannel` handle.
    shared: Arc<Shared>,
}

impl UdpChannel {
    pub async fn new(packets: Packets) -> Result<Self, std::io::Error> {
        let socket = DualStackUdpSocket::new().await?;

        let (session_tx, session_rx) = channel(32);

        let (send_tx, send_rx) = channel(1024);
        let (recv_tx, recv_rx) = channel(1024);

        let shared = Arc::new(Shared::default());

        tokio::spawn(Task {
            socket,
            session: None,
            session_rx,
            packets,
            send_rx,
            recv_tx,
            shared: shared.clone(),
        }.start());

        Ok(UdpChannel {
            session_tx,
            send: send_tx,
            recv: recv_rx,
            shared,
        })
    }

    pub fn time_since_established(&self) -> Option<Duration> {
        let ts = self.shared.establish_rx_epoch.load(Ordering::Acquire);
        if ts == 0 {
            return None;
        }
        let now = now_milli();
        Some(Duration::from_millis(now.max(ts) - ts))
    }

    pub fn time_since_establish_send(&self) -> Option<Duration> {
        let ts = self.shared.establish_tx_epoch.load(Ordering::Acquire);
        if ts == 0 {
            return None;
        }
        let now = now_milli();
        Some(Duration::from_millis(now.max(ts) - ts))
    }

    pub async fn update_session(&self, details: UdpChannelDetails) {
        self.session_tx.send(details).await.expect("task closed");
    }

    pub async fn send(&self, flow: UdpFlow, packet: Packet) {
        if self.send.send((flow, packet)).await.is_err() {
            panic!("UdpChannel task closed");
        }
    }

    pub async fn recv(&mut self) -> (UdpFlow, Packet) {
        self.recv.recv().await.expect("UdpChannel task closed")
    }
}

impl Task {
    async fn start(mut self) {
        let mut packet = self.packets.allocate_wait().await;
        let mut last_establish_send = Instant::now();

        loop {
            let next_send = if self.session.is_none() {
                Instant::now() + Duration::from_secs(5)
            } else {
                let since_establish = Duration::from_millis({
                    let now = now_milli();
                    let last_rx = self.shared.establish_rx_epoch.load(Ordering::Relaxed);
                    now.max(last_rx) - last_rx
                });

                const MIN_INTERVAL: Duration = Duration::from_secs(3);
                const MAX_ESTABLISH_WAIT: Duration = Duration::from_secs(15);
                const SEND_INTERVAL: Duration = Duration::from_secs(10);

                if MAX_ESTABLISH_WAIT <= since_establish {
                    last_establish_send + MIN_INTERVAL
                } else {
                    last_establish_send + SEND_INTERVAL
                }
            };

            let recv_res = tokio::select! {
                _ = tokio::time::sleep_until(next_send) => {
                    last_establish_send = Instant::now();
                    self.send_establish().await;
                    continue;
                }
                session_res = self.session_rx.recv() => {
                    let Some(details) = session_res else { break };
                    self.handle_session(details).await;
                    continue;
                }
                to_send_res = self.send_rx.recv() => {
                    let Some((flow, to_send)) = to_send_res else { break };
                    self.send(flow, to_send).await;
                    continue;
                }
                recv_res = self.socket.recv_from(packet.full_slice_mut()) => recv_res,
            };

            let Ok((bytes, source)) = recv_res else {
                udp_errors().recv_io_error.inc();
                tokio::time::sleep(Duration::from_millis(20)).await;
                continue;
            };

            let Some(session) = self.session.as_ref() else {
                udp_errors().recv_with_no_session.inc();
                return;
            };

            if session.tunnel_addr != source {
                udp_errors().recv_source_no_match.inc();
                return;
            }

            packet.set_len(bytes).expect("failed to update packet len");
            let flow = match UdpFlow::from_tail(packet.as_ref()) {
                Ok(flow) => flow,
                Err(Some(footer)) if footer == UDP_CHANNEL_ESTABLISH_ID => {
                    self.shared.establish_rx_epoch.store(now_milli(), Ordering::Release);
                    continue;
                }
                Err(id) => {
                    if id.is_none() {
                        udp_errors().recv_too_small.inc();
                    } else {
                        udp_errors().recv_invalid_footer_id.inc();
                    }
                    continue;
                }
            };

            packet.set_len(bytes - flow.footer_len())
                .expect("failed to remove udp footer");

            if self.recv_tx.send((flow, packet)).await.is_err() {
                break;
            }
            packet = self.packets.allocate_wait().await
        }
    }

    async fn handle_session(&mut self, details: UdpChannelDetails) {
        let should_send = match self.session.take() {
            None => true,
            Some(old) => {
                if old != details {
                    true
                } else {
                    5_000 < now_milli() - self.shared.establish_rx_epoch.load(Ordering::Relaxed)
                }
            }
        };

        self.session = Some(details);
        if should_send {
            self.send_establish().await;
        }
    }

    async fn send_establish(&self) {
        let Some(session) = self.session.as_ref() else {
            udp_errors().establish_no_session.inc();
            return;
        };

        self.shared.establish_tx_epoch.store(now_milli(), Ordering::Release);

        if self.socket.send_to(&session.token[..], session.tunnel_addr).await.is_err() {
            udp_errors().establish_send_io_error.inc();
        }
    }

    async fn send(&self, flow: UdpFlow, mut packet: Packet) {
        let Some(session) = self.session.as_ref() else {
            udp_errors().no_session_send_fail.inc();
            return;
        };

        let og_len = packet.len();
        let remaining = &mut packet.full_slice_mut()[og_len..];
        if !flow.write_to(remaining) {
            udp_errors().tail_append_fail.inc();
            return;
        }

        packet.set_len(og_len + flow.footer_len())
            .expect("should be able to update packet len");

        if self.socket.send_to(packet.as_ref(), session.tunnel_addr).await.is_err() {
            udp_errors().send_io_error.inc();
        }
    }
}