use crate::stack::{Stack, pop, push};
use crate::value::Value;
use may::net::{TcpListener, TcpStream};
use std::io::{Read, Write};
use std::sync::Mutex;
/// Upper bound on the number of slots a single registry will create.
const MAX_SOCKETS: usize = 10_000;
/// Upper bound, in bytes, on the data accumulated by one `patch_seq_tcp_read` call.
const MAX_READ_SIZE: usize = 1_048_576;

/// Slot table mapping small integer handles to sockets.
///
/// A handle is an index into `sockets`. Freeing a handle empties its slot and
/// records the index in `free_ids` so a later allocation can reuse it.
struct SocketRegistry<T> {
    /// One slot per handle ever issued. `None` marks a slot that is free or
    /// whose socket is temporarily checked out by a caller.
    sockets: Vec<Option<T>>,
    /// Indices emptied by `free`. They are reused most-recently-freed first.
    free_ids: Vec<usize>,
}

impl<T> SocketRegistry<T> {
    /// Creates an empty registry. It is `const` so it can initialise a `static`.
    const fn new() -> Self {
        SocketRegistry {
            sockets: Vec::new(),
            free_ids: Vec::new(),
        }
    }

    /// Stores `socket` and returns its handle.
    ///
    /// A previously freed handle is reused when one is available. Otherwise a
    /// new slot is appended.
    ///
    /// # Errors
    /// Returns `"Maximum socket limit reached"` when no freed handle exists and
    /// `MAX_SOCKETS` slots have already been created.
    fn allocate(&mut self, socket: T) -> Result<i64, &'static str> {
        match self.free_ids.pop() {
            Some(recycled) => {
                self.sockets[recycled] = Some(socket);
                Ok(recycled as i64)
            }
            None if self.sockets.len() < MAX_SOCKETS => {
                let fresh = self.sockets.len();
                self.sockets.push(Some(socket));
                Ok(fresh as i64)
            }
            None => Err("Maximum socket limit reached"),
        }
    }

    /// Returns the slot for handle `id`, or `None` if `id` was never issued.
    fn get_mut(&mut self, id: usize) -> Option<&mut Option<T>> {
        self.sockets.get_mut(id)
    }

    /// Empties the slot for handle `id`, dropping its socket, and makes the
    /// handle reusable.
    ///
    /// Unknown ids and already-empty slots are ignored. This means an index
    /// never enters `free_ids` twice.
    fn free(&mut self, id: usize) {
        let was_occupied = match self.sockets.get_mut(id) {
            Some(slot) => slot.take().is_some(),
            None => false,
        };
        if was_occupied {
            self.free_ids.push(id);
        }
    }
}
/// Process-wide table of listeners created by `patch_seq_tcp_listen`.
/// Listener handles are numbered independently of `STREAMS`, so the same
/// integer may name both a listener and a stream.
static LISTENERS: Mutex<SocketRegistry<TcpListener>> = Mutex::new(SocketRegistry::new());
/// Process-wide table of connected streams, populated by `patch_seq_tcp_accept`.
static STREAMS: Mutex<SocketRegistry<TcpStream>> = Mutex::new(SocketRegistry::new());
/// Binds a TCP listener on all IPv4 interfaces (`0.0.0.0`).
///
/// Stack effect: `( port -- listener_id ok )`.
/// - On success, pushes the new listener handle and `true`.
/// - Otherwise pushes `0` and `false`. This happens when `port` is not an
///   `Int` in `0..=65535`, the bind fails, or the listener table is full.
///
/// # Safety
/// `stack` must be a stack handle that `crate::stack::pop` / `push` accept,
/// with the port value on top.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn patch_seq_tcp_listen(stack: Stack) -> Stack {
    unsafe {
        let (stack, port_val) = pop(stack);

        // Only a valid port number gets as far as a bind attempt.
        let bound = match port_val {
            Value::Int(port) if (0..=65535).contains(&port) => {
                TcpListener::bind(&format!("0.0.0.0:{port}")).ok()
            }
            _ => None,
        };

        // If the table is full, the listener is dropped (closed) right here.
        let registered =
            bound.and_then(|listener| LISTENERS.lock().unwrap().allocate(listener).ok());

        match registered {
            Some(listener_id) => push(push(stack, Value::Int(listener_id)), Value::Bool(true)),
            None => push(push(stack, Value::Int(0)), Value::Bool(false)),
        }
    }
}
/// Accepts one incoming connection on a listener.
///
/// Stack effect: `( listener_id -- client_id ok )`.
/// - On success, pushes the new stream handle and `true`.
/// - Otherwise pushes `0` and `false`. This covers an invalid handle, an
///   accept error, or a full stream table.
///
/// The listener is checked out of `LISTENERS` while `accept` runs, so the
/// registry lock is not held during the wait. A concurrent accept on the same
/// handle therefore finds an empty slot and fails immediately.
///
/// # Safety
/// `stack` must be a stack handle that `crate::stack::pop` / `push` accept,
/// with the listener id on top.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn patch_seq_tcp_accept(stack: Stack) -> Stack {
    unsafe {
        let (stack, listener_id_val) = pop(stack);

        // `try_from` rejects negative ids. On 32-bit targets it also rejects
        // large ids that an `as usize` cast would silently truncate onto some
        // other, valid handle.
        let listener_id = match listener_id_val {
            Value::Int(id) => usize::try_from(id).ok(),
            _ => None,
        };
        let listener = listener_id
            .and_then(|id| LISTENERS.lock().unwrap().get_mut(id).and_then(Option::take));
        let (Some(listener_id), Some(listener)) = (listener_id, listener) else {
            let stack = push(stack, Value::Int(0));
            return push(stack, Value::Bool(false));
        };

        let accepted = listener.accept();

        // Return the listener whether or not accept succeeded. The slot cannot
        // have been handed out meanwhile: `SocketRegistry::free` ignores empty
        // slots, so a checked-out handle never reaches the free list.
        if let Some(slot) = LISTENERS.lock().unwrap().get_mut(listener_id) {
            *slot = Some(listener);
        }

        let client_id = accepted
            .ok()
            .and_then(|(stream, _addr)| STREAMS.lock().unwrap().allocate(stream).ok());

        match client_id {
            Some(client_id) => {
                let stack = push(stack, Value::Int(client_id));
                push(stack, Value::Bool(true))
            }
            None => {
                let stack = push(stack, Value::Int(0));
                push(stack, Value::Bool(false))
            }
        }
    }
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn patch_seq_tcp_read(stack: Stack) -> Stack {
unsafe {
let (stack, socket_id_val) = pop(stack);
let socket_id = match socket_id_val {
Value::Int(id) => id as usize,
_ => {
let stack = push(stack, Value::String("".into()));
return push(stack, Value::Bool(false));
}
};
let mut stream = {
let mut streams = STREAMS.lock().unwrap();
match streams.get_mut(socket_id).and_then(|opt| opt.take()) {
Some(s) => s,
None => {
let stack = push(stack, Value::String("".into()));
return push(stack, Value::Bool(false));
}
}
};
let mut buffer = Vec::new();
let mut chunk = [0u8; 4096];
let mut read_error = false;
loop {
if buffer.len() >= MAX_READ_SIZE {
read_error = true;
break;
}
match stream.read(&mut chunk) {
Ok(0) => {
break;
}
Ok(n) => {
let bytes_to_add = n.min(MAX_READ_SIZE.saturating_sub(buffer.len()));
buffer.extend_from_slice(&chunk[..bytes_to_add]);
if bytes_to_add < n {
break; }
break;
}
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
if buffer.is_empty() {
may::coroutine::yield_now();
continue;
}
break;
}
Err(_) => {
read_error = true;
break;
}
}
}
{
let mut streams = STREAMS.lock().unwrap();
if let Some(slot) = streams.get_mut(socket_id) {
*slot = Some(stream);
}
}
if read_error {
let stack = push(stack, Value::String("".into()));
return push(stack, Value::Bool(false));
}
match String::from_utf8(buffer) {
Ok(data) => {
let stack = push(stack, Value::String(data.into()));
push(stack, Value::Bool(true))
}
Err(_) => {
let stack = push(stack, Value::String("".into()));
push(stack, Value::Bool(false))
}
}
}
}
/// Writes a string to a stream and flushes it.
///
/// Stack effect: `( data socket_id -- ok )`. Pops the socket id, then the
/// string to send. Pushes `true` only if the whole string was written and
/// flushed.
///
/// Both operands are always consumed, so the resulting stack depth does not
/// depend on whether the arguments were valid. Previously, a non-`Int` socket
/// id returned early and left the data string on the stack.
///
/// # Safety
/// `stack` must be a stack handle that `crate::stack::pop` / `push` accept,
/// with the socket id on top and the data beneath it.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn patch_seq_tcp_write(stack: Stack) -> Stack {
    unsafe {
        // Pop both operands before validating either.
        let (stack, socket_id_val) = pop(stack);
        let (stack, data_val) = pop(stack);

        // `try_from` rejects negative ids and avoids 32-bit truncation onto
        // another valid handle.
        let socket_id = match socket_id_val {
            Value::Int(id) => usize::try_from(id).ok(),
            _ => None,
        };
        let (Some(socket_id), Value::String(data)) = (socket_id, data_val) else {
            return push(stack, Value::Bool(false));
        };

        // Check the stream out so the registry lock is not held during I/O.
        let stream = STREAMS
            .lock()
            .unwrap()
            .get_mut(socket_id)
            .and_then(Option::take);
        let Some(mut stream) = stream else {
            return push(stack, Value::Bool(false));
        };

        // Flush only after a successful `write_all`. Either error yields `false`.
        let outcome = stream
            .write_all(data.as_str().as_bytes())
            .and_then(|()| stream.flush());

        if let Some(slot) = STREAMS.lock().unwrap().get_mut(socket_id) {
            *slot = Some(stream);
        }

        push(stack, Value::Bool(outcome.is_ok()))
    }
}
/// Closes a stream and releases its handle for reuse.
///
/// Stack effect: `( socket_id -- ok )`. Pushes `true` if the handle named a
/// registered stream, which is dropped here; otherwise pushes `false`.
///
/// A stream that is currently checked out by an in-progress read or write
/// has an empty slot. Closing it therefore returns `false`, and the stream is
/// put back when that operation finishes.
///
/// # Safety
/// `stack` must be a stack handle that `crate::stack::pop` / `push` accept,
/// with the socket id on top.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn patch_seq_tcp_close(stack: Stack) -> Stack {
    unsafe {
        let (stack, socket_id_val) = pop(stack);

        // `try_from` rejects negative ids and avoids 32-bit truncation onto
        // another valid handle.
        let socket_id = match socket_id_val {
            Value::Int(id) => usize::try_from(id).ok(),
            _ => None,
        };

        let closed = socket_id.is_some_and(|id| {
            let mut streams = STREAMS.lock().unwrap();
            let occupied = matches!(streams.get_mut(id), Some(Some(_)));
            if occupied {
                streams.free(id);
            }
            occupied
        });

        push(stack, Value::Bool(closed))
    }
}
pub use patch_seq_tcp_accept as tcp_accept;
pub use patch_seq_tcp_close as tcp_close;
pub use patch_seq_tcp_listen as tcp_listen;
pub use patch_seq_tcp_read as tcp_read;
pub use patch_seq_tcp_write as tcp_write;
#[cfg(test)]
mod tests {
    use super::*;
    use crate::arithmetic::push_int;
    use crate::scheduler::scheduler_init;

    /// Asks the OS for a currently unused TCP port.
    ///
    /// The probe socket is closed before returning, so another process could
    /// claim the port in the meantime. This is still far less collision-prone
    /// than a hard-coded port number.
    fn free_port() -> i64 {
        let probe = std::net::TcpListener::bind("0.0.0.0:0").expect("bind probe listener");
        i64::from(probe.local_addr().expect("probe local_addr").port())
    }

    #[test]
    fn test_tcp_listen() {
        unsafe {
            scheduler_init();
            let stack = crate::stack::alloc_test_stack();
            let stack = push_int(stack, 0);
            let stack = tcp_listen(stack);
            let (stack, success) = pop(stack);
            assert!(
                matches!(success, Value::Bool(true)),
                "tcp_listen should succeed"
            );
            let (_stack, result) = pop(stack);
            match result {
                Value::Int(listener_id) => {
                    assert!(listener_id >= 0, "Listener ID should be non-negative");
                }
                _ => panic!("Expected Int (listener_id), got {:?}", result),
            }
        }
    }

    #[test]
    fn test_tcp_listen_invalid_port_negative() {
        unsafe {
            scheduler_init();
            let stack = crate::stack::alloc_test_stack();
            let stack = push_int(stack, -1);
            let stack = tcp_listen(stack);
            let (stack, success) = pop(stack);
            assert!(
                matches!(success, Value::Bool(false)),
                "Invalid port should return false"
            );
            let (_stack, result) = pop(stack);
            assert!(
                matches!(result, Value::Int(0)),
                "Invalid port should return 0"
            );
        }
    }

    #[test]
    fn test_tcp_listen_invalid_port_too_high() {
        unsafe {
            scheduler_init();
            let stack = crate::stack::alloc_test_stack();
            let stack = push_int(stack, 65536);
            let stack = tcp_listen(stack);
            let (stack, success) = pop(stack);
            assert!(
                matches!(success, Value::Bool(false)),
                "Invalid port should return false"
            );
            let (_stack, result) = pop(stack);
            assert!(
                matches!(result, Value::Int(0)),
                "Invalid port should return 0"
            );
        }
    }

    #[test]
    fn test_tcp_port_range_valid() {
        unsafe {
            scheduler_init();
            // Port 0: the OS picks an ephemeral port.
            let stack = push_int(crate::stack::alloc_test_stack(), 0);
            let stack = tcp_listen(stack);
            let (stack, success) = pop(stack);
            assert!(matches!(success, Value::Bool(true)));
            let (_, result) = pop(stack);
            assert!(matches!(result, Value::Int(_)));

            // An explicit non-zero port, chosen to be free rather than hard-coded.
            let stack = push(crate::stack::alloc_test_stack(), Value::Int(free_port()));
            let stack = tcp_listen(stack);
            let (stack, success) = pop(stack);
            assert!(matches!(success, Value::Bool(true)));
            let (_, result) = pop(stack);
            assert!(matches!(result, Value::Int(_)));
        }
    }

    #[test]
    fn test_socket_id_reuse_after_close() {
        let mut registry: SocketRegistry<u8> = SocketRegistry::new();
        let first = registry.allocate(1).expect("first allocation");
        let second = registry.allocate(2).expect("second allocation");
        assert_ne!(first, second);

        registry.free(first as usize);
        assert_eq!(registry.allocate(3), Ok(first), "freed id should be reused");
        assert_eq!(
            registry.get_mut(first as usize).and_then(|slot| slot.take()),
            Some(3)
        );
    }

    #[test]
    fn test_tcp_read_invalid_socket_id() {
        unsafe {
            scheduler_init();
            let stack = push_int(crate::stack::alloc_test_stack(), 9999);
            let stack = tcp_read(stack);
            let (stack, success) = pop(stack);
            assert!(
                matches!(success, Value::Bool(false)),
                "Invalid socket should return false"
            );
            let (_stack, result) = pop(stack);
            match result {
                Value::String(s) => assert_eq!(s.as_str(), ""),
                _ => panic!("Expected empty string"),
            }
        }
    }

    #[test]
    fn test_tcp_write_invalid_socket_id() {
        unsafe {
            scheduler_init();
            let stack = push(
                crate::stack::alloc_test_stack(),
                Value::String("test".into()),
            );
            let stack = push_int(stack, 9999);
            let stack = tcp_write(stack);
            let (_stack, success) = pop(stack);
            assert!(
                matches!(success, Value::Bool(false)),
                "Invalid socket should return false"
            );
        }
    }

    #[test]
    fn test_tcp_close_idempotent() {
        unsafe {
            scheduler_init();
            // Closing an unknown handle fails, and repeating it fails the same way.
            for _ in 0..2 {
                let stack = push_int(crate::stack::alloc_test_stack(), 9999);
                let stack = tcp_close(stack);
                let (_stack, success) = pop(stack);
                assert!(
                    matches!(success, Value::Bool(false)),
                    "Invalid socket close should return false"
                );
            }
        }

        // At the registry level, freeing the same handle twice is a no-op. It
        // must not put the id on the free list twice.
        let mut registry: SocketRegistry<u8> = SocketRegistry::new();
        let id = registry.allocate(7).expect("allocation");
        registry.free(id as usize);
        registry.free(id as usize);
        assert_eq!(registry.allocate(8), Ok(id));
        assert_ne!(registry.allocate(9), Ok(id), "id must not be handed out twice");
    }

    #[test]
    fn test_socket_registry_capacity() {
        assert_eq!(MAX_SOCKETS, 10_000);
        let mut registry: SocketRegistry<()> = SocketRegistry::new();
        for expected in 0..MAX_SOCKETS {
            assert_eq!(registry.allocate(()), Ok(expected as i64));
        }
        assert_eq!(registry.allocate(()), Err("Maximum socket limit reached"));
        registry.free(42);
        assert_eq!(
            registry.allocate(()),
            Ok(42),
            "a freed slot is reusable even at the limit"
        );
    }

    #[test]
    fn test_max_read_size_limit() {
        assert_eq!(MAX_READ_SIZE, 1_048_576);
    }
}