use tokio::sync::mpsc;
use tokio::task::JoinHandle;
use tracing::instrument;
use surrealcs_kernel::{
allocator::Allocator, logging::messages::actors_client::router::log_client_router_message,
messages::client::message::TransactionMessage,
};
mod deregister;
mod register;
mod transaction_operation;
#[instrument(name = "client_router_actor", skip(rx))]
pub async fn router_actor(
mut rx: mpsc::Receiver<TransactionMessage>,
address: String,
connection_id: String,
reader_handle: JoinHandle<()>,
) {
let mut allocator = Allocator::new();
while let Some(message) = rx.recv().await {
log_client_router_message(&message);
match message {
TransactionMessage::Register(sender) => {
register::register(sender, &mut allocator).await;
continue;
}
TransactionMessage::Deregister(index) => {
deregister::deregister(index, &mut allocator).await;
continue;
}
TransactionMessage::TransactionOperation(unwrapped_message) => {
transaction_operation::send_transaction_operation(
unwrapped_message,
&mut allocator,
)
.await;
continue;
}
TransactionMessage::CloseConnection => {
reader_handle.abort();
tracing::trace!("shutting down router actor for connection: {}", connection_id);
break;
}
_ => {}
}
tracing::error!("the message that slipped through for the router: {:?}", message);
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use surrealcs_kernel::messages::server::interface::ServerMessage;
    use surrealcs_kernel::messages::server::wrapper::WrappedServerMessage;
    use tokio::sync::oneshot;
    use tokio::time::{timeout, Duration};

    static CONNECTION_ID: &str = "1-1234567890";

    /// Dummy peer address; the router only records it on its tracing span.
    static ADDRESS: &str = "127.0.0.1:8080";

    /// Upper bound on how long a test waits for something to happen (or not happen).
    const WAIT: Duration = Duration::from_secs(1);

    /// Spawns a router actor that owns `reader_handle`.
    ///
    /// Returns the router's inbox sender together with the router task's handle.
    fn spawn_router_with_reader(
        reader_handle: JoinHandle<()>,
    ) -> (mpsc::Sender<TransactionMessage>, JoinHandle<()>) {
        let (tx, rx) = mpsc::channel::<TransactionMessage>(32);
        let router_handle = tokio::spawn(async move {
            router_actor(rx, ADDRESS.to_string(), CONNECTION_ID.into(), reader_handle).await;
        });
        (tx, router_handle)
    }

    /// Spawns a router actor whose simulated reader task completes immediately.
    ///
    /// Keep the returned sender alive for as long as the router is needed. Once it is
    /// dropped, the router's inbox closes and the actor stops.
    fn spawn_router() -> mpsc::Sender<TransactionMessage> {
        let reader_handle = tokio::spawn(async {});
        let (tx, _router_handle) = spawn_router_with_reader(reader_handle);
        tx
    }

    /// Creates the channel a simulated transaction client registers with the router.
    fn client_channel() -> (mpsc::Sender<TransactionMessage>, mpsc::Receiver<TransactionMessage>)
    {
        mpsc::channel::<TransactionMessage>(32)
    }

    /// Receives the next message for a client, failing the test if its channel closed.
    async fn next_message(rx: &mut mpsc::Receiver<TransactionMessage>) -> TransactionMessage {
        rx.recv().await.expect("client channel closed before a response arrived")
    }

    /// Asserts that nothing arrives on `rx` within [`WAIT`].
    ///
    /// The caller must keep a sender for `rx` alive. Otherwise `recv` returns `None`
    /// immediately and the assertion fails.
    async fn assert_no_response(rx: &mut mpsc::Receiver<TransactionMessage>, action: &str) {
        let outcome = timeout(WAIT, rx.recv()).await;
        assert!(
            outcome.is_err(),
            "{} should not return a response, got {:?}",
            action,
            outcome
        );
    }

    /// Registers three clients, deregisters the middle one, and routes an operation to the
    /// first one.
    #[tokio::test]
    async fn test_ok() {
        let tx = spawn_router();
        let (tx_1, mut rx_1) = client_channel();
        let (tx_2, mut rx_2) = client_channel();
        let (tx_3, mut rx_3) = client_channel();
        tx.send(TransactionMessage::Register(tx_1.clone())).await.unwrap();
        tx.send(TransactionMessage::Register(tx_2.clone())).await.unwrap();
        tx.send(TransactionMessage::Register(tx_3.clone())).await.unwrap();

        let response = next_message(&mut rx_1).await;
        assert!(
            matches!(response, TransactionMessage::Registered(0)),
            "unexpected response: {:?}",
            response
        );
        let response = next_message(&mut rx_2).await;
        assert!(
            matches!(response, TransactionMessage::Registered(1)),
            "unexpected response: {:?}",
            response
        );
        let response = next_message(&mut rx_3).await;
        assert!(
            matches!(response, TransactionMessage::Registered(2)),
            "unexpected response: {:?}",
            response
        );

        tx.send(TransactionMessage::Deregister(1)).await.unwrap();
        let response = next_message(&mut rx_2).await;
        assert!(
            matches!(response, TransactionMessage::Unregistered),
            "unexpected response: {:?}",
            response
        );

        tx.send(TransactionMessage::TransactionOperation(WrappedServerMessage::new(
            0,
            ServerMessage::Ping(0),
            CONNECTION_ID.into(),
        )))
        .await
        .unwrap();
        match next_message(&mut rx_1).await {
            TransactionMessage::TransactionOperation(wrapped_message) => {
                assert_eq!(wrapped_message.client_id, 0);
                assert_eq!(wrapped_message.message, ServerMessage::Ping(0));
            }
            other => panic!("unexpected response: {:?}", other),
        }
    }

    /// Deregistering an index that was never allocated must not produce a response.
    #[tokio::test]
    async fn test_out_of_bounds_unregister() {
        let tx = spawn_router();
        // `_tx_1` keeps the client channel open so the timeout, not a closed channel, decides.
        let (_tx_1, mut rx_1) = client_channel();
        tx.send(TransactionMessage::Deregister(0)).await.unwrap();
        assert_no_response(&mut rx_1, "an unregister that is out of bounds").await;
    }

    /// Deregistering the same index twice only answers the first request.
    #[tokio::test]
    async fn test_already_deallocated_unregister() {
        let tx = spawn_router();
        let (tx_1, mut rx_1) = client_channel();
        tx.send(TransactionMessage::Register(tx_1.clone())).await.unwrap();
        let response = next_message(&mut rx_1).await;
        assert!(
            matches!(response, TransactionMessage::Registered(0)),
            "unexpected response from the register: {:?}",
            response
        );

        tx.send(TransactionMessage::Deregister(0)).await.unwrap();
        let response = next_message(&mut rx_1).await;
        assert!(
            matches!(response, TransactionMessage::Unregistered),
            "unexpected response from the deregister: {:?}",
            response
        );

        tx.send(TransactionMessage::Deregister(0)).await.unwrap();
        assert_no_response(&mut rx_1, "an unregister of an already deallocated index").await;
    }

    /// Deregistering one client leaves the other registered clients routable.
    #[tokio::test]
    async fn test_register_deregister_ok() {
        let tx = spawn_router();
        let (tx_1, mut rx_1) = client_channel();
        let (tx_2, mut rx_2) = client_channel();
        tx.send(TransactionMessage::Register(tx_1.clone())).await.unwrap();
        tx.send(TransactionMessage::Register(tx_2.clone())).await.unwrap();

        let response = next_message(&mut rx_1).await;
        assert!(
            matches!(response, TransactionMessage::Registered(0)),
            "unexpected response from the register: {:?}",
            response
        );
        let response = next_message(&mut rx_2).await;
        assert!(
            matches!(response, TransactionMessage::Registered(1)),
            "unexpected response from the register: {:?}",
            response
        );

        tx.send(TransactionMessage::Deregister(0)).await.unwrap();
        let response = next_message(&mut rx_1).await;
        assert!(
            matches!(response, TransactionMessage::Unregistered),
            "unexpected response from the deregister: {:?}",
            response
        );

        tx.send(TransactionMessage::TransactionOperation(WrappedServerMessage::new(
            1,
            ServerMessage::Ping(1),
            CONNECTION_ID.into(),
        )))
        .await
        .unwrap();
        match next_message(&mut rx_2).await {
            TransactionMessage::TransactionOperation(wrapped_message) => {
                assert_eq!(wrapped_message.client_id, 1);
                assert_eq!(wrapped_message.message, ServerMessage::Ping(1));
            }
            other => panic!("unexpected response to the remaining client: {:?}", other),
        }
    }

    /// `CloseConnection` stops the router and aborts the reader task it owns.
    #[tokio::test]
    async fn test_shutdown_not_called() {
        // Simulated reader: answers every probe with `true` until its inbox closes or it is
        // aborted (aborting drops `wr_rx`, so later probes fail to send).
        let (wr_tx, mut wr_rx) = mpsc::channel::<Option<oneshot::Sender<bool>>>(32);
        let reader_handle = tokio::spawn(async move {
            while let Some(tx) = wr_rx.recv().await {
                tx.unwrap().send(true).unwrap();
            }
        });
        let (tx, router_handle) = spawn_router_with_reader(reader_handle);

        // The reader is alive before shutdown.
        let (one_tx, one_rx) = oneshot::channel();
        wr_tx.send(Some(one_tx)).await.unwrap();
        assert!(one_rx.await.unwrap());

        tx.send(TransactionMessage::CloseConnection).await.unwrap();
        timeout(WAIT, router_handle).await.unwrap().unwrap();

        let (two_tx, _two_rx) = oneshot::channel();
        assert!(
            wr_tx.send(Some(two_tx)).await.is_err(),
            "the reader task should have shutdown"
        );
    }
}