use super::utils::send_message_over_tcp;
use surrealcs_kernel::messages::{
client::message::TransactionMessage,
server::{interface::ServerMessage, wrapper::WrappedServerMessage},
};
use tokio::{net::tcp::OwnedWriteHalf, sync::mpsc};
pub async fn close_connection(
ping_tx: mpsc::Sender<TransactionMessage>,
connection_id: String,
router_tx: mpsc::Sender<TransactionMessage>,
mut writer: OwnedWriteHalf,
) {
tracing::trace!("Closing the connection writer actor for: {}", connection_id);
match ping_tx.send(TransactionMessage::CloseConnection).await {
Ok(_) => {
tracing::info!("ping actor close message sent for: {}", connection_id);
}
Err(e) => {
tracing::error!("Error sending close connection message to the ping actor: {}", e);
}
}
match router_tx.send(TransactionMessage::CloseConnection).await {
Ok(_) => {
tracing::info!("router actor close message sent for: {}", connection_id);
}
Err(e) => {
tracing::error!("Error sending close connection message to the router actor: {}", e);
}
}
let wrapped_message =
WrappedServerMessage::new(0, ServerMessage::CloseConnection, connection_id.clone());
match send_message_over_tcp(&mut writer, &wrapped_message).await {
Ok(_) => {
println!("Connection closed");
tracing::info!("Connection {} writer actor closed", connection_id);
}
Err(e) => {
println!("Error sending close connection message to the server: {}", e);
tracing::error!("Error sending close connection message to the server: {}", e);
}
}
tracing::info!("Connection {} writer actor closed", connection_id);
}
#[cfg(test)]
mod tests {
    use super::*;
    use surrealcs_kernel::messages::server::interface::ServerMessage;
    use surrealcs_kernel::messages::{
        client::message::TransactionMessage,
        serialization::bincode_processes::message::deserialize_from_stream,
    };
    use tokio::net::{TcpListener, TcpStream};
    use tokio::time::{timeout, Duration};

    static CONNECTION_ID: &str = "1-1234567890";

    /// Panics unless an actor received `TransactionMessage::CloseConnection`.
    fn assert_close_transaction(received: TransactionMessage) {
        assert!(
            matches!(received, TransactionMessage::CloseConnection),
            "Wrong message type received should have been CloseConnection! but was {:?}",
            received
        );
    }

    /// `close_connection` must notify the router actor, the ping actor, and the server.
    #[tokio::test]
    async fn test_close_connection() {
        // Loopback TCP pair. The client write half is handed to `close_connection`, and the
        // server read half lets the test observe the bytes it wrote.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let (_client_reader, client_writer) =
            TcpStream::connect(addr).await.unwrap().into_split();
        let (server_stream, _) = listener.accept().await.unwrap();
        let (mut server_reader, _server_writer) = server_stream.into_split();

        let (router_tx, mut router_rx) = mpsc::channel::<TransactionMessage>(32);
        let (ping_tx, mut ping_rx) = mpsc::channel::<TransactionMessage>(32);

        close_connection(ping_tx, CONNECTION_ID.into(), router_tx, client_writer).await;

        assert_close_transaction(router_rx.recv().await.unwrap());
        assert_close_transaction(ping_rx.recv().await.unwrap());

        // The close frame was written before `close_connection` returned, so a short
        // timeout is enough to read it back.
        let read_frame =
            timeout(Duration::from_millis(10), deserialize_from_stream(&mut server_reader));
        let wrapped_message = read_frame.await.unwrap().unwrap();
        assert!(
            matches!(wrapped_message.message, ServerMessage::CloseConnection),
            "Wrong message type received should have been CloseConnection! but was {:?}",
            wrapped_message.message
        );
    }
}