use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::UdpSocket;
use tokio::sync::Mutex;
use tracing::{debug, error, info, warn};
use crate::frame::{encode_frame, try_decode_frame};
use crate::types::{Flags, Frame, FrameType, Header, VstpError, VSTP_VERSION};
use crate::udp::reassembly::{extract_fragment_info, ReassemblyManager, MAX_DATAGRAM_SIZE};
/// Tunable options for a VSTP UDP server.
///
/// Construct with [`Default::default`] and override individual fields with struct-update
/// syntax, e.g. `UdpServerConfig { use_crc: false, ..Default::default() }`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UdpServerConfig {
    /// Presumably toggles CRC checksums on frames (default `true`).
    /// NOTE(review): not read by the `VstpUdpServer` code in this file — confirm where it applies.
    pub use_crc: bool,
    /// Presumably enables fragmentation support (default `true`).
    /// NOTE(review): not read by the `VstpUdpServer` code in this file — confirm where it applies.
    pub allow_frag: bool,
    /// Intended cap on concurrent reassembly sessions (default `1000`).
    /// NOTE(review): not forwarded to the reassembly manager by the code in this file.
    pub max_reassembly_sessions: usize,
}

impl Default for UdpServerConfig {
    /// CRC on, fragmentation allowed, at most 1000 reassembly sessions.
    fn default() -> Self {
        Self {
            use_crc: true,
            allow_frag: true,
            max_reassembly_sessions: 1000,
        }
    }
}
/// A VSTP endpoint that receives and sends frames over one bound UDP socket.
///
/// Received datagrams are decoded into [`Frame`]s; fragments are passed, together with the
/// sender's address, to a [`ReassemblyManager`] until a complete payload is available.
pub struct VstpUdpServer {
/// The bound socket, used both for receiving and for sending (including ACKs).
socket: UdpSocket,
/// Server options supplied at bind time.
/// NOTE(review): no method in this file reads this field, so its settings do not
/// currently affect behavior here.
config: UdpServerConfig,
/// Fragment reassembly state; fragments are added with the sender's address.
reassembly: ReassemblyManager,
/// Counter initialised to 1 at construction.
/// NOTE(review): not read anywhere in this file — presumably reserved for session
/// bookkeeping; confirm or remove.
next_session_id: Arc<Mutex<u128>>,
}
impl VstpUdpServer {
pub async fn bind(addr: &str) -> Result<Self, VstpError> {
let socket = UdpSocket::bind(addr).await?;
info!("VSTP UDP server bound to {}", addr);
Ok(Self {
socket,
config: UdpServerConfig::default(),
reassembly: ReassemblyManager::new(),
next_session_id: Arc::new(Mutex::new(1)),
})
}
pub async fn bind_with_config(addr: &str, config: UdpServerConfig) -> Result<Self, VstpError> {
let socket = UdpSocket::bind(addr).await?;
info!("VSTP UDP server bound to {} with custom config", addr);
Ok(Self {
socket,
config,
reassembly: ReassemblyManager::new(),
next_session_id: Arc::new(Mutex::new(1)),
})
}
pub fn local_addr(&self) -> Result<SocketAddr, VstpError> {
self.socket.local_addr().map_err(|e| VstpError::Io(e))
}
pub async fn send(&self, frame: Frame, dest: SocketAddr) -> Result<(), VstpError> {
let encoded = encode_frame(&frame)?;
self.socket.send_to(&encoded, dest).await?;
Ok(())
}
pub async fn recv(&self) -> Result<(Frame, SocketAddr), VstpError> {
let mut buf = vec![0u8; MAX_DATAGRAM_SIZE * 2];
loop {
let (len, from_addr) = self.socket.recv_from(&mut buf).await?;
let data = &buf[..len];
debug!("Received {} bytes from {}", len, from_addr);
let mut buf = bytes::BytesMut::from(data);
match try_decode_frame(&mut buf, 65536) {
Ok(Some(frame)) => {
if let Some(fragment) = extract_fragment_info(&frame) {
if let Some(assembled_data) = self.reassembly.add_fragment(from_addr, fragment).await? {
let mut complete_frame = frame;
complete_frame.payload = assembled_data;
complete_frame.headers.retain(|h| {
h.key != b"frag-id" && h.key != b"frag-index" && h.key != b"frag-total"
});
if complete_frame.flags.contains(Flags::REQ_ACK) {
if let Some(msg_id) = self.extract_msg_id(&complete_frame) {
let _ = self.send_ack(msg_id, from_addr).await;
}
}
return Ok((complete_frame, from_addr));
}
continue;
} else {
if frame.flags.contains(Flags::REQ_ACK) {
if let Some(msg_id) = self.extract_msg_id(&frame) {
let _ = self.send_ack(msg_id, from_addr).await;
}
}
return Ok((frame, from_addr));
}
}
Ok(None) => continue, Err(_) => continue, }
}
}
fn extract_msg_id(&self, frame: &Frame) -> Option<u64> {
for header in &frame.headers {
if header.key == b"msg-id" {
if let Ok(msg_id) = std::str::from_utf8(&header.value).ok()?.parse::<u64>() {
return Some(msg_id);
}
}
}
None
}
async fn send_ack(&self, msg_id: u64, dest: SocketAddr) -> Result<(), VstpError> {
let ack_frame = Frame {
version: VSTP_VERSION,
typ: FrameType::Ack,
flags: Flags::empty(),
headers: vec![Header {
key: b"msg-id".to_vec(),
value: msg_id.to_string().into_bytes(),
}],
payload: Vec::new(),
};
self.send(ack_frame, dest).await
}
pub async fn reassembly_session_count(&self) -> usize {
self.reassembly.session_count().await
}
}