use crate::stack::{Stack, pop, push};
use crate::value::Value;
use may::net::UdpSocket;
use std::sync::{Arc, Mutex};
/// Upper bound on the number of slots the socket registry will ever create.
const MAX_SOCKETS: usize = 10_000;

/// Size in bytes of the buffer allocated for a single datagram read.
const MAX_READ_SIZE: usize = 65_536;

/// Slot table that maps small integer ids to shared socket handles.
///
/// An id is an index into `sockets`. Ids that have been released are recycled
/// (most recently released first) before the table is allowed to grow.
struct SocketRegistry<T> {
    /// One slot per id ever issued; `None` marks an id that has been released.
    sockets: Vec<Option<Arc<T>>>,
    /// Released ids waiting to be handed out again, used as a LIFO stack.
    free_ids: Vec<usize>,
}

impl<T> SocketRegistry<T> {
    /// Creates an empty registry. `const` so it can initialise a `static`.
    const fn new() -> Self {
        Self {
            sockets: Vec::new(),
            free_ids: Vec::new(),
        }
    }

    /// Stores `socket` and returns the id under which it can be looked up.
    ///
    /// A previously released id is reused when one is available; otherwise a
    /// new slot is appended.
    ///
    /// # Errors
    ///
    /// Returns `Err` when no released id is available and the table already
    /// holds `MAX_SOCKETS` slots.
    fn allocate(&mut self, socket: T) -> Result<i64, &'static str> {
        let shared = Arc::new(socket);
        match self.free_ids.pop() {
            Some(recycled) => {
                self.sockets[recycled] = Some(shared);
                Ok(recycled as i64)
            }
            None if self.sockets.len() < MAX_SOCKETS => {
                self.sockets.push(Some(shared));
                Ok((self.sockets.len() - 1) as i64)
            }
            None => Err("Maximum socket limit reached"),
        }
    }

    /// Returns another handle to the socket stored under `id`, or `None` if
    /// the id is out of range or has been released.
    fn checkout(&self, id: usize) -> Option<Arc<T>> {
        self.sockets.get(id)?.as_ref().map(Arc::clone)
    }

    /// Releases `id`, dropping the registry's handle to the socket.
    ///
    /// Returns `true` if the id referred to a live socket, `false` if it was
    /// out of range or already released (so an id is never queued for reuse
    /// twice).
    fn free(&mut self, id: usize) -> bool {
        match self.sockets.get_mut(id).and_then(Option::take) {
            Some(_released) => {
                self.free_ids.push(id);
                true
            }
            None => false,
        }
    }
}
/// Process-wide table of open UDP sockets used by the `patch_seq_udp_*` entry points.
/// The socket ids exchanged with callers are indices into this registry.
static SOCKETS: Mutex<SocketRegistry<UdpSocket>> = Mutex::new(SocketRegistry::new());
/// Bind a UDP socket on `0.0.0.0`: stack effect `( port -- socket_id bound_port Bool )`.
///
/// Pops the requested port (`Int`, 0–65535). On success pushes the registry id of
/// the new socket, the port reported by the socket's local address, and
/// `Bool(true)`. On any failure (wrong argument type, port out of range, bind
/// error, local address unavailable, registry full) pushes `Int(0)`, `Int(0)`,
/// `Bool(false)` instead.
///
/// # Safety
///
/// `stack` must be a value that `pop` and `push` can legally operate on, with at
/// least one value available to pop.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn patch_seq_udp_bind(stack: Stack) -> Stack {
    unsafe {
        let (stack, requested) = pop(stack);
        let Value::Int(port) = requested else {
            return push_bind_failure(stack);
        };
        if !matches!(port, 0..=65535) {
            return push_bind_failure(stack);
        }

        let bind_addr = format!("0.0.0.0:{port}");
        let Ok(socket) = UdpSocket::bind(&bind_addr) else {
            return push_bind_failure(stack);
        };
        // Report the port the socket actually ended up on rather than echoing the request.
        let Ok(local) = socket.local_addr() else {
            return push_bind_failure(stack);
        };
        let bound_port = i64::from(local.port());

        let mut registry = SOCKETS.lock().unwrap();
        let Ok(socket_id) = registry.allocate(socket) else {
            return push_bind_failure(stack);
        };
        let stack = push(stack, Value::Int(socket_id));
        let stack = push(stack, Value::Int(bound_port));
        push(stack, Value::Bool(true))
    }
}
/// Pushes the failure result of a bind: `Int(0)`, `Int(0)`, `Bool(false)`.
///
/// The two zeroes stand in for the socket id and bound port so that the number
/// and types of pushed values match the success case.
///
/// # Safety
///
/// `stack` must be a value that `push` can legally operate on.
unsafe fn push_bind_failure(stack: Stack) -> Stack {
    unsafe {
        let with_id = push(stack, Value::Int(0));
        let with_port = push(with_id, Value::Int(0));
        push(with_port, Value::Bool(false))
    }
}
/// Send one datagram: stack effect `( bytes host port socket_id -- Bool )`.
///
/// Pops, from the top of the stack down: the socket id (`Int`, non-negative),
/// the destination port (`Int`, 0–65535), the destination host (`String`) and
/// the payload (`String`, sent as its raw bytes). Pushes `Bool(true)` when the
/// send call succeeds and `Bool(false)` on any validation, lookup or I/O failure.
///
/// All four arguments are consumed on every path, including failures, so the
/// stack depth after the call never depends on whether the arguments were valid.
///
/// # Safety
///
/// `stack` must be a value that `pop` and `push` can legally operate on, with at
/// least four values available to pop.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn patch_seq_udp_send_to(stack: Stack) -> Stack {
    unsafe {
        // Pop every argument before validating any of them. Returning early
        // between pops would leave the remaining arguments stranded under the
        // pushed result and unbalance the caller's stack.
        let (stack, socket_val) = pop(stack);
        let (stack, port_val) = pop(stack);
        let (stack, host_val) = pop(stack);
        let (stack, bytes_val) = pop(stack);

        let socket_id = match socket_val {
            Value::Int(id) if id >= 0 => id as usize,
            _ => return push(stack, Value::Bool(false)),
        };
        let port = match port_val {
            Value::Int(p) if (0..=65535).contains(&p) => p,
            _ => return push(stack, Value::Bool(false)),
        };
        let host = match host_val {
            Value::String(s) => s,
            _ => return push(stack, Value::Bool(false)),
        };
        let bytes = match bytes_val {
            Value::String(s) => s,
            _ => return push(stack, Value::Bool(false)),
        };

        // Hold the registry lock only long enough to clone the handle, so the
        // send itself runs without blocking other socket operations.
        let socket = {
            let sockets = SOCKETS.lock().unwrap();
            match sockets.checkout(socket_id) {
                Some(s) => s,
                None => return push(stack, Value::Bool(false)),
            }
        };

        let addr = format!("{}:{}", host.as_str_or_empty(), port);
        let result = socket.send_to(bytes.as_bytes(), &addr);
        push(stack, Value::Bool(result.is_ok()))
    }
}
/// Receive one datagram: stack effect `( socket_id -- bytes host port Bool )`.
///
/// Pops the socket id (`Int`, non-negative). On success pushes the received
/// payload as a `String` value built from the raw bytes, the sender's IP address
/// as a `String`, the sender's port as an `Int`, and `Bool(true)`. On any failure
/// (wrong argument type, negative or unknown id, receive error) pushes two empty
/// strings, `Int(0)` and `Bool(false)`.
///
/// At most `MAX_READ_SIZE` bytes are read per call.
///
/// # Safety
///
/// `stack` must be a value that `pop` and `push` can legally operate on, with at
/// least one value available to pop.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn patch_seq_udp_receive_from(stack: Stack) -> Stack {
    unsafe {
        let (stack, handle) = pop(stack);
        let Value::Int(raw_id) = handle else {
            return push_receive_failure(stack);
        };
        if raw_id < 0 {
            return push_receive_failure(stack);
        }

        // Clone the handle inside a short scope so the registry lock is released
        // before the receive call below.
        let checked_out = {
            let registry = SOCKETS.lock().unwrap();
            registry.checkout(raw_id as usize)
        };
        let Some(socket) = checked_out else {
            return push_receive_failure(stack);
        };

        let mut datagram = vec![0u8; MAX_READ_SIZE];
        let Ok((received, sender)) = socket.recv_from(&mut datagram) else {
            return push_receive_failure(stack);
        };
        // Keep only the bytes that were actually received.
        datagram.truncate(received);

        let stack = push(stack, Value::String(crate::seqstring::global_bytes(datagram)));
        let stack = push(stack, Value::String(sender.ip().to_string().into()));
        let stack = push(stack, Value::Int(i64::from(sender.port())));
        push(stack, Value::Bool(true))
    }
}
/// Pushes the failure result of a receive: two empty strings, `Int(0)`, `Bool(false)`.
///
/// The placeholders stand in for the payload, sender host and sender port so
/// that the number and types of pushed values match the success case.
///
/// # Safety
///
/// `stack` must be a value that `push` can legally operate on.
unsafe fn push_receive_failure(stack: Stack) -> Stack {
    unsafe {
        let with_payload = push(stack, Value::String("".into()));
        let with_host = push(with_payload, Value::String("".into()));
        let with_port = push(with_host, Value::Int(0));
        push(with_port, Value::Bool(false))
    }
}
/// Close a socket: stack effect `( socket_id -- Bool )`.
///
/// Pops the socket id and releases its registry slot. Pushes `Bool(true)` if the
/// id referred to a registered socket, and `Bool(false)` if the argument was not
/// an `Int`, was negative, or did not name a live socket (including one that was
/// already closed).
///
/// Releasing the slot drops only the registry's own handle; an operation that
/// checked the socket out earlier still holds its own clone of the handle.
///
/// # Safety
///
/// `stack` must be a value that `pop` and `push` can legally operate on, with at
/// least one value available to pop.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn patch_seq_udp_close(stack: Stack) -> Stack {
    unsafe {
        let (stack, handle) = pop(stack);
        let Value::Int(raw_id) = handle else {
            return push(stack, Value::Bool(false));
        };
        if raw_id < 0 {
            return push(stack, Value::Bool(false));
        }

        let mut registry = SOCKETS.lock().unwrap();
        let was_open = registry.free(raw_id as usize);
        push(stack, Value::Bool(was_open))
    }
}
// Shorter aliases for the exported `patch_seq_udp_*` entry points.
pub use patch_seq_udp_bind as udp_bind;
pub use patch_seq_udp_close as udp_close;
pub use patch_seq_udp_receive_from as udp_receive_from;
pub use patch_seq_udp_send_to as udp_send_to;
// Tests live in a separate `tests` module that is compiled only for test builds.
#[cfg(test)]
mod tests;