use crate::{
helper_functions,
TIMEOUT_SECONDS,
};
use anyhow::Result;
use demikernel::{
demi_sgarray_t,
runtime::types::{
demi_opcode_t,
demi_qresult_t,
},
LibOS,
QDesc,
QToken,
};
use std::{
collections::{
HashMap,
HashSet,
},
net::SocketAddr,
};
/// Address-family argument handed to `LibOS::socket()`; on Windows the value comes from WinSock.
#[cfg(target_os = "windows")]
pub const AF_INET: i32 = windows::Win32::Networking::WinSock::AF_INET.0 as i32;
/// Socket-type argument handed to `LibOS::socket()`; on Windows the value comes from WinSock.
#[cfg(target_os = "windows")]
pub const SOCK_STREAM: i32 = windows::Win32::Networking::WinSock::SOCK_STREAM.0 as i32;
/// Address-family argument handed to `LibOS::socket()`; on Linux the value comes from `libc`.
#[cfg(target_os = "linux")]
pub const AF_INET: i32 = libc::AF_INET;
/// Socket-type argument handed to `LibOS::socket()`; on Linux the value comes from `libc`.
// NOTE: only Windows and Linux are covered; no definition exists for any other target OS.
#[cfg(target_os = "linux")]
pub const SOCK_STREAM: i32 = libc::SOCK_STREAM;
/// A TCP server that accepts client connections and keeps track of them until they are closed.
pub struct TcpServer {
/// LibOS instance through which every socket operation of this server is issued.
libos: LibOS,
/// Queue descriptor of the socket created and bound in `new()`; used to listen and accept.
sockqd: QDesc,
/// Queue descriptors of clients that were accepted and have not been terminated yet.
connected_client_qds: HashSet<QDesc>,
/// Tokens of outstanding operations, in the order they are handed to `wait_any()`.
pending_qtokens: Vec<QToken>,
/// Maps each outstanding token to the queue descriptor the operation was issued on.
qtokens_to_qdesc_map: HashMap<QToken, QDesc>,
/// Number of client connections accepted so far.
num_accepted_clients: usize,
/// Number of client connections closed so far.
num_closed_clients: usize,
/// Set once `run()` finishes successfully; when set, `Drop` skips its socket cleanup.
has_test_passed: bool,
}
impl TcpServer {
/// Creates a server whose TCP socket is bound to `local_socket_addr`.
///
/// The socket is only created and bound here; listening starts in `run()` or
/// `run_close_sockets_on_accept()`.
///
/// # Errors
/// Returns an error if creating or binding the socket fails.
pub fn new(mut libos: LibOS, local_socket_addr: SocketAddr) -> Result<Self> {
    let listening_qd: QDesc = libos.socket(AF_INET, SOCK_STREAM, 0)?;
    libos.bind(listening_qd, local_socket_addr)?;
    println!("Listening to: {:?}", local_socket_addr);

    // Start with no clients, no outstanding operations and zeroed counters.
    let server: Self = Self {
        libos,
        sockqd: listening_qd,
        connected_client_qds: HashSet::new(),
        pending_qtokens: Vec::new(),
        qtokens_to_qdesc_map: HashMap::new(),
        num_accepted_clients: 0,
        num_closed_clients: 0,
        has_test_passed: false,
    };
    Ok(server)
}
/// Serves clients until `nclients` connections have been closed, then closes the listening socket.
///
/// After listening and issuing a first accept, completed operations are handled one at a time:
/// - accept: the new client is registered, a pop is issued on it, and a new accept is issued;
/// - pop: the first segment must be empty (treated here as the client closing the connection),
///   after which the connection is torn down;
/// - failure: accepted only when `helper_functions::is_closed()` approves the error code, after
///   which the connection is torn down.
///
/// When `nclients` is `None` the loop has no exit condition and only ends through an error.
///
/// # Errors
/// Propagates failures of the underlying LibOS calls, and fails on an unexpected opcode or on a
/// failed operation that is not reported as a closed connection.
///
/// # Panics
/// Panics if a pop completes with a non-empty first segment, if a terminated client still had
/// pending operations, or if clients are still registered once enough connections were closed.
pub fn run(&mut self, nclients: Option<usize>) -> Result<()> {
    // `nclients` doubles as the second argument of `listen()`; 512 is used when it is absent.
    let listen_arg: usize = nclients.unwrap_or(512);
    self.libos.listen(self.sockqd, listen_arg)?;
    self.issue_accept()?;

    // Keep serving until the expected number of clients have come and gone.
    while !matches!(nclients, Some(expected) if self.num_closed_clients >= expected) {
        let (completed_index, result): (usize, demi_qresult_t) =
            self.libos.wait_any(&self.pending_qtokens, Some(TIMEOUT_SECONDS))?;
        self.mark_completed_operation(completed_index)?;

        match result.qr_opcode {
            demi_opcode_t::DEMI_OPC_ACCEPT => {
                // SAFETY: the opcode is DEMI_OPC_ACCEPT, so `ares` is presumably the member of
                // `qr_value` that was written -- TODO confirm against the LibOS contract.
                let client_qd: QDesc = unsafe { result.qr_value.ares.qd.into() };
                self.handle_accept_completion(client_qd)?;
                self.issue_accept()?;
            },
            demi_opcode_t::DEMI_OPC_POP => {
                let client_qd: QDesc = result.qr_qd.into();
                // SAFETY: the opcode is DEMI_OPC_POP, so `sga` is presumably the member of
                // `qr_value` that was written -- TODO confirm against the LibOS contract.
                let sga: demi_sgarray_t = unsafe { result.qr_value.sga };
                let first_segment_len: usize = sga.sga_segs[0].sgaseg_len as usize;
                assert_eq!(
                    first_segment_len, 0,
                    "client must have had closed the connection, but it has not"
                );
                self.libos.sgafree(sga)?;
                let leftover_qts: Vec<QToken> = self.handle_connection_termination(client_qd)?;
                assert!(
                    leftover_qts.is_empty(),
                    "client should not have any pending operations, but it has"
                );
            },
            demi_opcode_t::DEMI_OPC_FAILED => {
                let client_qd: QDesc = result.qr_qd.into();
                // Any failure other than a closed connection aborts the run.
                anyhow::ensure!(
                    helper_functions::is_closed(result.qr_ret),
                    "client should have had terminated the connection, but it has not: error={:?}",
                    result.qr_ret
                );
                let _: Vec<QToken> = self.handle_connection_termination(client_qd)?;
            },
            _ => {
                anyhow::bail!("unexpected result")
            },
        }
    }

    // The loop only falls through once enough clients were closed; none may remain registered.
    assert_eq!(
        self.connected_client_qds.len(),
        0,
        "there should be no clients connected, but there are"
    );

    helper_functions::close_and_wait(&mut self.libos, self.sockqd)?;
    // Tells `Drop` that every socket was already closed.
    self.has_test_passed = true;
    Ok(())
}
/// Accepts connections and closes each one immediately, until `nclients` have been closed.
///
/// Accepted sockets are never registered as connected clients and no pop is issued on them. Any
/// completion other than an accept is treated as an error. When `nclients` is `None` the loop has
/// no exit condition and only ends through an error.
///
/// # Errors
/// Propagates failures of the underlying LibOS calls and fails on an unexpected opcode.
///
/// # Panics
/// Panics if clients are registered as connected once enough connections were closed.
pub fn run_close_sockets_on_accept(&mut self, nclients: Option<usize>) -> Result<()> {
    // `nclients` doubles as the second argument of `listen()`; 512 is used when it is absent.
    self.libos.listen(self.sockqd, nclients.unwrap_or(512))?;
    self.issue_accept()?;

    loop {
        // Stop once the expected number of clients have been accepted and closed.
        if let Some(nclients) = nclients {
            if self.num_closed_clients >= nclients {
                const ERR_MSG: &str = "there should be no clients connected, but there are";
                assert_eq!(self.connected_client_qds.len(), 0, "{}", ERR_MSG);
                break;
            }
        }

        let (index, qr): (usize, demi_qresult_t) =
            self.libos.wait_any(&self.pending_qtokens, Some(TIMEOUT_SECONDS))?;
        self.mark_completed_operation(index)?;

        match qr.qr_opcode {
            demi_opcode_t::DEMI_OPC_ACCEPT => {
                // SAFETY: the opcode is DEMI_OPC_ACCEPT, so `ares` is presumably the member of
                // `qr_value` that was written -- TODO confirm against the LibOS contract.
                let qd: QDesc = unsafe { qr.qr_value.ares.qd.into() };
                self.num_accepted_clients += 1;
                println!("{} clients accepted, closing socket", self.num_accepted_clients);
                helper_functions::close_and_wait(&mut self.libos, qd)?;
                self.num_closed_clients += 1;
                self.issue_accept()?;
            },
            _ => {
                anyhow::bail!("unexpected result")
            },
        }
    }

    helper_functions::close_and_wait(&mut self.libos, self.sockqd)?;
    // The listening socket is closed now. Flag success (as `run()` does) so that `Drop` does not
    // try to close it a second time and report a bogus failure.
    self.has_test_passed = true;
    Ok(())
}
/// Records `qd` as a connected client.
///
/// # Panics
/// Panics if `qd` is already registered.
fn register_client(&mut self, qd: QDesc) {
    // `HashSet::insert` returns `false` when the value was already present.
    let newly_registered: bool = self.connected_client_qds.insert(qd);
    assert!(newly_registered, "client is already registered and it shouldn't be");
}
/// Removes `qd` from the set of connected clients.
///
/// # Panics
/// Panics if `qd` is not registered.
fn unregister_client(&mut self, qd: QDesc) {
    // `HashSet::remove` returns `false` when the value was not present.
    let was_registered: bool = self.connected_client_qds.remove(&qd);
    assert!(was_registered, "client isn't registered and it should be");
}
/// Stops tracking every outstanding operation that was issued on `qd`.
///
/// Entries for `qd` are removed from the token-to-descriptor map, and the matching tokens are
/// removed from the pending list. Returns the tokens taken out of the pending list, in the order
/// they appeared there.
fn cancel_pending_operations(&mut self, qd: QDesc) -> Vec<QToken> {
    // Pass 1: drop the map entries owned by `qd`, remembering which tokens they were.
    let mut owned_by_qd: HashSet<QToken> = HashSet::new();
    self.qtokens_to_qdesc_map.retain(|qt, owner| {
        if *owner == qd {
            owned_by_qd.insert(*qt);
            false
        } else {
            true
        }
    });

    // Pass 2: pull those tokens out of the pending list, preserving their relative order.
    let mut dropped_qts: Vec<QToken> = Vec::new();
    self.pending_qtokens.retain(|qt| {
        if owned_by_qd.contains(qt) {
            dropped_qts.push(*qt);
            false
        } else {
            true
        }
    });

    dropped_qts
}
/// Stops tracking the operation stored at position `index` of the pending-token list.
///
/// # Errors
/// Fails with "unregistered queue token" if the token has no entry in the token-to-descriptor
/// map.
///
/// # Panics
/// Panics if `index` is out of bounds for the pending-token list (behavior of `Vec::remove`).
fn mark_completed_operation(&mut self, index: usize) -> Result<()> {
    let qt: QToken = self.pending_qtokens.remove(index);
    // `ok_or_else` defers building the error value to the failure path, instead of constructing
    // (and discarding) one on every successful completion.
    self.qtokens_to_qdesc_map
        .remove(&qt)
        .ok_or_else(|| anyhow::anyhow!("unregistered queue token"))?;
    Ok(())
}
fn issue_accept(&mut self) -> Result<()> {
let qt: QToken = self.libos.accept(self.sockqd)?;
self.qtokens_to_qdesc_map.insert(qt, self.sockqd);
self.pending_qtokens.push(qt);
Ok(())
}
/// Issues a pop on `qd` and starts tracking the resulting token.
///
/// # Errors
/// Returns an error if the pop cannot be issued; nothing is tracked in that case.
fn issue_pop(&mut self, qd: QDesc) -> Result<()> {
    let pop_qt: QToken = self.libos.pop(qd, None)?;

    // Track the token both in the wait list and in the token-to-descriptor map.
    self.pending_qtokens.push(pop_qt);
    self.qtokens_to_qdesc_map.insert(pop_qt, qd);
    Ok(())
}
/// Handles a newly accepted client: registers it, issues a pop on it and counts it.
///
/// The accepted-clients counter is only bumped after the pop was issued successfully.
///
/// # Errors
/// Returns an error if the pop cannot be issued.
///
/// # Panics
/// Panics if `qd` is already registered as a connected client.
fn handle_accept_completion(&mut self, qd: QDesc) -> Result<()> {
    self.register_client(qd);
    self.issue_pop(qd)?;

    let accepted_so_far: usize = self.num_accepted_clients + 1;
    self.num_accepted_clients = accepted_so_far;
    println!("{} clients accepted", accepted_so_far);
    Ok(())
}
/// Tears down the connection on `qd`: drops its tracked operations, unregisters it, closes it
/// and counts it.
///
/// Returns the tokens of the operations that were still pending on `qd`. The closed-clients
/// counter is only bumped after the socket was closed successfully.
///
/// # Errors
/// Returns an error if closing the socket fails.
///
/// # Panics
/// Panics if `qd` is not registered as a connected client.
fn handle_connection_termination(&mut self, qd: QDesc) -> Result<Vec<QToken>> {
    let dropped_qts: Vec<QToken> = self.cancel_pending_operations(qd);
    self.unregister_client(qd);
    helper_functions::close_and_wait(&mut self.libos, qd)?;

    let closed_so_far: usize = self.num_closed_clients + 1;
    self.num_closed_clients = closed_so_far;
    println!("{} clients closed", closed_so_far);
    Ok(dropped_qts)
}
}
/// Best-effort cleanup for runs that did not finish successfully.
impl Drop for TcpServer {
    /// Closes every still-registered client socket and the listening socket.
    ///
    /// Nothing is done when `has_test_passed` is set. Close failures are only logged, since
    /// `drop` cannot report errors.
    fn drop(&mut self) {
        if self.has_test_passed {
            return;
        }

        // Drain the set in place: `libos` is a separate field, so both can be borrowed at the
        // same time and no copy of the set is needed.
        for qd in self.connected_client_qds.drain() {
            if let Err(e) = helper_functions::close_and_wait(&mut self.libos, qd) {
                println!("ERROR: close() failed (error={:?}", e);
                println!("WARN: leaking qd={:?}", qd);
            }
        }

        if let Err(e) = helper_functions::close_and_wait(&mut self.libos, self.sockqd) {
            println!("ERROR: close() failed (error={:?}", e);
            println!("WARN: leaking qd={:?}", self.sockqd);
        }
    }
}