surrealcs 0.4.4

The SurrealCS client code for SurrealDB
Documentation
//! Defines the writer actor that writes messages to the TCP stream and thus to the server.
//!
//! # Routes
//! The writer actor listens for messages directly from client actors and supports the following routes:
//!
//! ## Register
//! The writer actor directs the register message to the router actor to be registered with the following
//! loop:
//! ```bash
//! [actor] -> (Sender) -> [writer] -> (Sender) -> [router] -> (ID) -> [actor]
//! ```
//!
//! ## TransactionOperation
//! The writer actor sends the transaction operation message to the server with the following loop:
//! ```bash
//! [actor] -> (TransactionOperation) -> [writer] -> (WrappedServerMessage) -> [server]
//! ```
//!
//! ## Ping
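//! The writer actor sends the ping message to the server with the following loop:
//! ```bash
//! [actor] -> (Ping) -> [writer] -> (Ping) -> [server]
//! ```
//!
//! ## CloseConnection
//! The writer actor closes the connection by notifying the ping and router actors, and then stops:
//! ```bash
//! [actor] -> (CloseConnection) -> [writer] -> (CloseConnection) -> [ping] + [router]
//! ```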
use super::recovery_process::interface::run_recovery_process;
use surrealcs_kernel::logging::messages::actors_client::writer::log_client_writer_message;
use surrealcs_kernel::messages::client::message::TransactionMessage;
use tokio::{net::tcp::OwnedWriteHalf, sync::mpsc};
use tracing::instrument;
use utils::send_message_over_tcp;

mod close_connection;
mod ping;
mod register;
mod utils;

/// Accepts messages, serializes them, and sends them to the server via TCP.
///
/// # Arguments
/// * `writer`: the writer to the TCP stream
/// * `rx`: The receiver to accept messages to the actor
/// * `router_tx`: The sender to the router actor
#[instrument(level = "trace", target = "surrealcs::client::writer", skip_all)]
pub async fn writer_actor(
	writer: OwnedWriteHalf,
	mut rx: mpsc::Receiver<TransactionMessage>,
	tx: mpsc::Sender<TransactionMessage>,
	router_tx: mpsc::Sender<TransactionMessage>,
	address: String,
	connection_id: String,
	ping_tx: mpsc::Sender<TransactionMessage>,
) {
	let mut writer = writer;

	while let Some(message) = rx.recv().await {
		log_client_writer_message(&message);

		// we still want to control what is sent through to the server hence the extra steps
		match message {
			TransactionMessage::Ping((client_id, con_id)) => {
				let outcome = ping::ping(client_id, con_id, &mut writer).await;
				if outcome.is_err() {
					// this cannot be abstracted out as it is replacing the writer with a new one in the recovery process
					writer = run_recovery_process(
						&mut rx,
						tx.clone(),
						router_tx.clone(),
						address.clone(),
						connection_id.clone(),
					)
					.await
					.unwrap();
				}
				continue;
			}
			TransactionMessage::Register(sender) => {
				register::register(sender, &router_tx).await;
				continue;
			}
			TransactionMessage::TransactionOperation(wrapped_op) => {
				let outcome = send_message_over_tcp(&mut writer, &wrapped_op).await;
				if outcome.is_err() {
					// this cannot be abstracted out as it is replacing the writer with a new one in the recovery process
					writer = run_recovery_process(
						&mut rx,
						tx.clone(),
						router_tx.clone(),
						address.clone(),
						connection_id.clone(),
					)
					.await
					.unwrap();
				}
				continue;
			}
			TransactionMessage::CloseConnection => {
				close_connection::close_connection(
					ping_tx,
					connection_id.clone(),
					router_tx.clone(),
					writer,
				)
				.await;
				break;
			}
			_ => {}
		}
		tracing::error!("the message that slipped through for the writer: {:?}", message)
	}
}

#[cfg(test)]
mod tests {
	use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
	use tokio::net::{TcpListener, TcpStream};
	use tokio::task::JoinHandle;

	use surrealcs_kernel::messages::serialization::bincode_processes::message::deserialize_from_stream;
	use surrealcs_kernel::messages::server::interface::ServerMessage;
	use surrealcs_kernel::messages::server::wrapper::WrappedServerMessage;
	use tokio::time::{timeout, Duration};

	use super::*;

	static CONNECTION_ID: &str = "1-1234567890";

	/// Upper bound for every await in these tests, so a regression fails instead of hanging.
	const TEST_TIMEOUT: Duration = Duration::from_secs(1);

	/// A writer actor running on the client end of a loopback TCP connection, plus the handles
	/// a test needs to drive it and observe what it forwards.
	struct Harness {
		/// Sender into the writer actor's inbox.
		tx: mpsc::Sender<TransactionMessage>,
		/// Receives whatever the writer forwards to the router actor.
		router_rx: mpsc::Receiver<TransactionMessage>,
		/// Receives whatever the writer forwards to the ping actor.
		ping_rx: mpsc::Receiver<TransactionMessage>,
		/// Server end of the connection; reads what the writer sent over TCP.
		server_reader: OwnedReadHalf,
		/// Join handle of the spawned writer actor task.
		writer_handle: JoinHandle<()>,
		// The remaining handles are held but never used. Keeping them alive keeps the listener
		// and both ends of the connection open for the whole test.
		_listener: TcpListener,
		_client_reader: OwnedReadHalf,
		_server_writer: OwnedWriteHalf,
	}

	/// Opens a loopback TCP connection, spawns a writer actor on its client write half, and
	/// returns the handles needed to drive and observe it.
	async fn spawn_writer() -> Harness {
		let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
		let addr = listener.local_addr().unwrap();

		let client_stream = TcpStream::connect(addr).await.unwrap();
		let (client_reader, client_writer) = client_stream.into_split();

		let (server_stream, _) = listener.accept().await.unwrap();
		let (server_reader, server_writer) = server_stream.into_split();

		let (router_tx, router_rx) = mpsc::channel::<TransactionMessage>(32);
		let (ping_tx, ping_rx) = mpsc::channel::<TransactionMessage>(32);
		let (tx, rx) = mpsc::channel::<TransactionMessage>(32);

		// the actor is also handed a sender into its own inbox, which it passes to recovery
		let actor_tx = tx.clone();
		let address = addr.to_string();
		let writer_handle = tokio::spawn(async move {
			writer_actor(
				client_writer,
				rx,
				actor_tx,
				router_tx,
				address,
				CONNECTION_ID.into(),
				ping_tx,
			)
			.await;
		});

		Harness {
			tx,
			router_rx,
			ping_rx,
			server_reader,
			writer_handle,
			_listener: listener,
			_client_reader: client_reader,
			_server_writer: server_writer,
		}
	}

	/// Waits for the next message on `rx`, failing the test instead of hanging if nothing
	/// arrives within [`TEST_TIMEOUT`].
	async fn recv_forwarded(rx: &mut mpsc::Receiver<TransactionMessage>) -> TransactionMessage {
		timeout(TEST_TIMEOUT, rx.recv())
			.await
			.expect("timed out waiting for a forwarded message")
			.expect("channel closed before a message was forwarded")
	}

	#[tokio::test]
	async fn test_writer_actor_send_operation() {
		let mut harness = spawn_writer().await;

		let message = WrappedServerMessage::new(20, ServerMessage::Ping(1), CONNECTION_ID.into());
		harness.tx.send(TransactionMessage::TransactionOperation(message)).await.unwrap();

		// the wrapped operation is written to the server as-is
		let server_message =
			timeout(TEST_TIMEOUT, deserialize_from_stream(&mut harness.server_reader))
				.await
				.expect("timed out waiting for the operation to reach the server")
				.unwrap();
		assert!(
			matches!(server_message.message, ServerMessage::Ping(_)),
			"Server message not as expected"
		);
		assert_eq!(server_message.client_id, 20);
		assert_eq!(server_message.connection_id, CONNECTION_ID.to_string());
	}

	#[tokio::test]
	async fn test_writer_actor_send_ping() {
		let mut harness = spawn_writer().await;

		harness.tx.send(TransactionMessage::Ping((1, CONNECTION_ID.into()))).await.unwrap();

		let server_message =
			timeout(TEST_TIMEOUT, deserialize_from_stream(&mut harness.server_reader))
				.await
				.expect("timed out waiting for the ping to reach the server")
				.unwrap();
		assert!(
			matches!(server_message.message, ServerMessage::Ping(_)),
			"Server message not as expected"
		);
	}

	#[tokio::test]
	async fn test_writer_actor_send_register() {
		let mut harness = spawn_writer().await;

		let (r_tx, _r_rx) = mpsc::channel::<TransactionMessage>(32);
		harness.tx.send(TransactionMessage::Register(r_tx)).await.unwrap();

		let result = recv_forwarded(&mut harness.router_rx).await;
		assert!(
			matches!(result, TransactionMessage::Register(_)),
			"Wrong message type received should have been Register! but was {:?}",
			result
		);
	}

	#[tokio::test]
	async fn test_close_connection() {
		let mut harness = spawn_writer().await;

		// shutdown the writer actor
		harness.tx.send(TransactionMessage::CloseConnection).await.unwrap();

		// assert that the writer has shutdown
		timeout(TEST_TIMEOUT, harness.writer_handle).await.unwrap().unwrap();

		// assert that the router has received the close connection message
		let result = recv_forwarded(&mut harness.router_rx).await;
		assert!(
			matches!(result, TransactionMessage::CloseConnection),
			"Wrong message type received should have been CloseConnection! but was {:?}",
			result
		);

		// assert that the ping actor has received the close connection message
		let result = recv_forwarded(&mut harness.ping_rx).await;
		assert!(
			matches!(result, TransactionMessage::CloseConnection),
			"Wrong message type received should have been CloseConnection! but was {:?}",
			result
		);
	}
}