use super::recovery_process::interface::run_recovery_process;
use surrealcs_kernel::logging::messages::actors_client::writer::log_client_writer_message;
use surrealcs_kernel::messages::client::message::TransactionMessage;
use tokio::{net::tcp::OwnedWriteHalf, sync::mpsc};
use tracing::instrument;
use utils::send_message_over_tcp;
mod close_connection;
mod ping;
mod register;
mod utils;
#[instrument(level = "trace", target = "surrealcs::client::writer", skip_all)]
pub async fn writer_actor(
writer: OwnedWriteHalf,
mut rx: mpsc::Receiver<TransactionMessage>,
tx: mpsc::Sender<TransactionMessage>,
router_tx: mpsc::Sender<TransactionMessage>,
address: String,
connection_id: String,
ping_tx: mpsc::Sender<TransactionMessage>,
) {
let mut writer = writer;
while let Some(message) = rx.recv().await {
log_client_writer_message(&message);
match message {
TransactionMessage::Ping((client_id, con_id)) => {
let outcome = ping::ping(client_id, con_id, &mut writer).await;
if outcome.is_err() {
writer = run_recovery_process(
&mut rx,
tx.clone(),
router_tx.clone(),
address.clone(),
connection_id.clone(),
)
.await
.unwrap();
}
continue;
}
TransactionMessage::Register(sender) => {
register::register(sender, &router_tx).await;
continue;
}
TransactionMessage::TransactionOperation(wrapped_op) => {
let outcome = send_message_over_tcp(&mut writer, &wrapped_op).await;
if outcome.is_err() {
writer = run_recovery_process(
&mut rx,
tx.clone(),
router_tx.clone(),
address.clone(),
connection_id.clone(),
)
.await
.unwrap();
}
continue;
}
TransactionMessage::CloseConnection => {
close_connection::close_connection(
ping_tx,
connection_id.clone(),
router_tx.clone(),
writer,
)
.await;
break;
}
_ => {}
}
tracing::error!("the message that slipped through for the writer: {:?}", message)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use surrealcs_kernel::messages::serialization::bincode_processes::message::deserialize_from_stream;
    use surrealcs_kernel::messages::server::interface::ServerMessage;
    use surrealcs_kernel::messages::server::wrapper::WrappedServerMessage;
    use tokio::net::tcp::OwnedReadHalf;
    use tokio::net::{TcpListener, TcpStream};
    use tokio::task::JoinHandle;
    use tokio::time::{timeout, Duration};

    static CONNECTION_ID: &str = "1-1234567890";

    /// Live resources for one writer-actor test.
    ///
    /// Every field is held for the whole test. This ensures no socket half,
    /// listener or channel receiver is dropped early and changes what the actor
    /// observes.
    struct Harness {
        /// Sender feeding the actor's inbox (the actor holds a clone of it too).
        tx: mpsc::Sender<TransactionMessage>,
        /// Server-side read half; whatever the actor writes arrives here.
        server_reader: OwnedReadHalf,
        /// Receiving end of the channel handed to the actor as `router_tx`.
        router_rx: mpsc::Receiver<TransactionMessage>,
        /// Receiving end of the channel handed to the actor as `ping_tx`.
        ping_rx: mpsc::Receiver<TransactionMessage>,
        /// Handle of the spawned actor task.
        writer_handle: JoinHandle<()>,
        /// Kept alive so the listening socket stays open for the whole test.
        _listener: TcpListener,
        /// Kept alive so the client read half is not dropped mid-test.
        _client_reader: OwnedReadHalf,
        /// Kept alive so the server write half is not dropped mid-test.
        _server_writer: OwnedWriteHalf,
    }

    /// Connects a client/server socket pair on an ephemeral localhost port and
    /// spawns [`writer_actor`] over the client's write half.
    async fn spawn_writer() -> Harness {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let address = addr.to_string();
        let client_stream = TcpStream::connect(addr).await.unwrap();
        let (client_reader, client_writer) = client_stream.into_split();
        let (server_stream, _) = listener.accept().await.unwrap();
        let (server_reader, server_writer) = server_stream.into_split();
        let (router_tx, router_rx) = mpsc::channel::<TransactionMessage>(32);
        let (ping_tx, ping_rx) = mpsc::channel::<TransactionMessage>(32);
        let (tx, rx) = mpsc::channel::<TransactionMessage>(32);
        let actor_tx = tx.clone();
        let writer_handle = tokio::spawn(async move {
            writer_actor(
                client_writer,
                rx,
                actor_tx,
                router_tx,
                address,
                CONNECTION_ID.into(),
                ping_tx,
            )
            .await;
        });
        Harness {
            tx,
            server_reader,
            router_rx,
            ping_rx,
            writer_handle,
            _listener: listener,
            _client_reader: client_reader,
            _server_writer: server_writer,
        }
    }

    #[tokio::test]
    async fn test_writer_actor_send_operation() {
        let mut harness = spawn_writer().await;
        let message = WrappedServerMessage::new(20, ServerMessage::Ping(1), CONNECTION_ID.into());
        harness.tx.send(TransactionMessage::TransactionOperation(message)).await.unwrap();
        let server_message = deserialize_from_stream(&mut harness.server_reader).await.unwrap();
        match server_message.message {
            ServerMessage::Ping(_) => {}
            _ => panic!("Server message not as expected"),
        }
        assert_eq!(server_message.client_id, 20);
        assert_eq!(server_message.connection_id, CONNECTION_ID.to_string());
    }

    #[tokio::test]
    async fn test_writer_actor_send_ping() {
        let mut harness = spawn_writer().await;
        harness.tx.send(TransactionMessage::Ping((1, CONNECTION_ID.into()))).await.unwrap();
        let server_message = deserialize_from_stream(&mut harness.server_reader).await.unwrap();
        match server_message.message {
            ServerMessage::Ping(_) => {}
            _ => panic!("Server message not as expected"),
        }
    }

    #[tokio::test]
    async fn test_writer_actor_send_register() {
        let mut harness = spawn_writer().await;
        let (r_tx, _r_rx) = mpsc::channel::<TransactionMessage>(32);
        harness.tx.send(TransactionMessage::Register(r_tx)).await.unwrap();
        let result = harness.router_rx.recv().await.unwrap();
        match result {
            TransactionMessage::Register(_) => {}
            _ => panic!(
                "Wrong message type received should have been Register! but was {:?}",
                result
            ),
        }
    }

    #[tokio::test]
    async fn test_close_connection() {
        let mut harness = spawn_writer().await;
        harness.tx.send(TransactionMessage::CloseConnection).await.unwrap();
        // The actor must terminate promptly after being asked to close.
        timeout(Duration::from_secs(1), harness.writer_handle).await.unwrap().unwrap();

        let result = harness.router_rx.recv().await.unwrap();
        match result {
            TransactionMessage::CloseConnection => {}
            _ => panic!(
                "Wrong message type received should have been CloseConnection! but was {:?}",
                result
            ),
        }

        let result = harness.ping_rx.recv().await.unwrap();
        match result {
            TransactionMessage::CloseConnection => {}
            _ => panic!(
                "Wrong message type received should have been CloseConnection! but was {:?}",
                result
            ),
        }
    }
}