surrealcs 0.4.4

The SurrealCS client code for SurrealDB
Documentation
//! The Ping actor for keeping the connection alive and alerting if the ping times out.
use nanoservices_utils::errors::{NanoServiceError, NanoServiceErrorStatus};
use tokio::sync::mpsc;
use tokio::time::{timeout, Duration};
use tracing::instrument;

use surrealcs_kernel::logging::messages::connections::ping::PingJourney;
use surrealcs_kernel::messages::client::message::TransactionMessage;

/// The tracing target attached to the log events emitted by the ping actor.
const TARGET: &str = "surrealcs::client::pinger";

/// Registers a ping actor for the connection and spawns it.
///
/// The sender half of `channels` is handed to the writer actor inside a
/// `TransactionMessage::Register` message. The receiver half is then read once for the
/// register response and, on success, moved into the spawned ping actor.
///
/// # Arguments
/// * `writer_tx`: The sender for the connection writer actor
/// * `connection_id`: The ID of the connection
/// * `channels`: The sender and receiver of the ping actor's internal channel
///
/// # Errors
/// Returns a `NanoServiceError` with status `Unknown` when the register message cannot be
/// sent, when the internal channel closes before a response arrives, or when the response
/// is anything other than `TransactionMessage::Registered`.
pub async fn ping_actor_constructor(
	writer_tx: mpsc::Sender<TransactionMessage>,
	connection_id: String,
	channels: (mpsc::Sender<TransactionMessage>, mpsc::Receiver<TransactionMessage>),
) -> Result<(), NanoServiceError> {
	let (internal_tx, mut internal_rx) = channels;

	// hand our sender over in the register message so responses can reach this actor
	writer_tx.send(TransactionMessage::Register(internal_tx)).await.map_err(|_| {
		NanoServiceError::new(
			"Failed to send register message for ping actor with connection router actor"
				.to_string(),
			NanoServiceErrorStatus::Unknown,
		)
	})?;

	// `None` here means every sender of the internal channel is gone
	let reply = internal_rx.recv().await.ok_or_else(|| {
		NanoServiceError::new(
			"Failed to recieve register response for ping actor with connection router actor"
				.to_string(),
			NanoServiceErrorStatus::Unknown,
		)
	})?;

	// only a `Registered` reply carries the address the actor pings with
	let router_address = match reply {
		TransactionMessage::Registered(index) => index,
		unexpected => {
			return Err(NanoServiceError::new(
				format!(
					"Got wrong response from connection router actor when registering the ping actor: {:?}",
					unexpected
				),
				NanoServiceErrorStatus::Unknown,
			));
		}
	};

	// registration succeeded, so let the actor run in the background
	tokio::spawn(async move { ping_actor(writer_tx, internal_rx, router_address, connection_id).await });
	Ok(())
}

/// Actor for pinging the connection every second to check if the connection is live.
///
/// # Notes
/// One ping actor is created for each connection
///
/// # Arguments
/// * `rx`: The reciever for the connection writer actor
#[instrument(level = "trace", target = "surrealcs::client::pinger", skip_all)]
pub async fn ping_actor(
	writer_tx: mpsc::Sender<TransactionMessage>,
	mut internal_rx: mpsc::Receiver<TransactionMessage>,
	router_address: usize,
	connection_id: String,
) {
	// Wait initially for a small amount of time
	let timeout_duration = Duration::from_secs(2);
	// Send ping messages continually in a loop
	loop {
		let message = TransactionMessage::Ping((router_address, connection_id.clone()));

		// Send a ping to the writer actor
		match writer_tx.send(message).await {
			Ok(_) => {
				tracing::trace!(target: TARGET, connection_id = %connection_id, "ping: {}", PingJourney::SentByActor.as_str());
			}
			Err(_) => {
				tracing::error!(target: TARGET, connection_id = %connection_id, "failed to send ping");
			}
		}

		// wait for a response from the writer actor or a kill signal
		match timeout(timeout_duration, internal_rx.recv()).await {
			Ok(Some(message)) => match message {
				TransactionMessage::CloseConnection => {
					tracing::info!(target: TARGET, connection_id = %connection_id, "ping actor shutting down: {}", connection_id);
					break;
				}
				_ => {
					tracing::trace!(target: TARGET, connection_id = %connection_id, "pong: {}", PingJourney::RecievedByActor.as_str())
				}
			},
			Ok(None) => {
				tracing::error!(target: TARGET, connection_id = %connection_id, "pong timeout");
			}
			Err(e) => {
				tracing::error!(target: TARGET, connection_id = %connection_id, "pong failure: {e:?}");
			}
		}
		tokio::time::sleep(std::time::Duration::from_secs(1)).await;
	}
}

#[cfg(test)]
mod tests {

	use super::*;
	use tokio::time::Duration;

	static CONNECTION_ID: &str = "1-1234567890";

	/// Mock writer actor that answers every `Register` message with `Registered(1)`.
	async fn writer_actor_mock(mut rx: mpsc::Receiver<TransactionMessage>) {
		while let Some(message) = rx.recv().await {
			if let TransactionMessage::Register(sender) = message {
				let response = TransactionMessage::Registered(1);
				sender.send(response).await.unwrap();
			}
		}
	}

	/// Asserts that `message` is a ping carrying `router_address`.
	fn assert_ping(message: TransactionMessage, router_address: usize) {
		match message {
			TransactionMessage::Ping(address) => {
				assert_eq!(address.0, router_address);
			}
			_ => {
				panic!("wrong message type, should be ping, got: {:?}", message);
			}
		}
	}

	/// Asserts that `result` is an error whose message equals `expected`.
	fn assert_error_message(result: Result<(), NanoServiceError>, expected: &str) {
		match result {
			Err(e) => {
				assert_eq!(e.message, expected.to_string());
			}
			_ => {
				panic!("wrong result, should be error");
			}
		}
	}

	#[tokio::test]
	async fn test_ping_actor_constructor_ok() {
		let (tx, rx) = mpsc::channel::<TransactionMessage>(32);
		tokio::spawn(async move {
			writer_actor_mock(rx).await;
		});
		let (internal_tx, internal_rx) = mpsc::channel::<TransactionMessage>(32);
		let result =
			ping_actor_constructor(tx, CONNECTION_ID.into(), (internal_tx, internal_rx)).await;
		assert!(result.is_ok());
	}

	#[tokio::test]
	async fn test_ping_actor_constructor_err_sending_register() {
		// Dropping the receiver up front closes the writer channel, so sending the register
		// message fails deterministically without relying on task scheduling
		let (tx, rx) = mpsc::channel::<TransactionMessage>(32);
		drop(rx);
		let (internal_tx, internal_rx) = mpsc::channel::<TransactionMessage>(32);
		let result =
			ping_actor_constructor(tx, CONNECTION_ID.into(), (internal_tx, internal_rx)).await;

		assert_error_message(
			result,
			"Failed to send register message for ping actor with connection router actor",
		);
	}

	#[tokio::test]
	async fn test_ping_actor_constructor_err_recieving_register_response() {
		// Consumes the register message and drops it without replying, which drops the only
		// sender of the internal channel
		async fn internal_writer_actor_mock(mut rx: mpsc::Receiver<TransactionMessage>) {
			rx.recv().await;
		}
		let (tx, rx) = mpsc::channel::<TransactionMessage>(32);
		tokio::spawn(async move {
			internal_writer_actor_mock(rx).await;
		});
		let (internal_tx, internal_rx) = mpsc::channel::<TransactionMessage>(32);
		let result =
			ping_actor_constructor(tx, CONNECTION_ID.into(), (internal_tx, internal_rx)).await;

		assert_error_message(
			result,
			"Failed to recieve register response for ping actor with connection router actor",
		);
	}

	#[tokio::test]
	async fn test_ping_actor_constructor_err_wrong_response() {
		// Replies to the register message with a response the constructor does not accept
		async fn internal_writer_actor_mock(mut rx: mpsc::Receiver<TransactionMessage>) {
			while let Some(message) = rx.recv().await {
				if let TransactionMessage::Register(sender) = message {
					let response = TransactionMessage::Unregistered;
					sender.send(response).await.unwrap();
				}
			}
		}
		let (tx, rx) = mpsc::channel::<TransactionMessage>(32);
		tokio::spawn(async move {
			internal_writer_actor_mock(rx).await;
		});
		let (internal_tx, internal_rx) = mpsc::channel::<TransactionMessage>(32);
		let result =
			ping_actor_constructor(tx, CONNECTION_ID.into(), (internal_tx, internal_rx)).await;

		assert_error_message(
			result,
			"Got wrong response from connection router actor when registering the ping actor: Unregistered",
		);
	}

	#[tokio::test]
	async fn test_ping_actor_ok() {
		let (writer_tx, mut writer_rx) = mpsc::channel::<TransactionMessage>(32);
		let (internal_tx, internal_rx) = mpsc::channel::<TransactionMessage>(32);
		let router_address = 1;
		tokio::spawn(async move {
			ping_actor(writer_tx, internal_rx, router_address, CONNECTION_ID.into()).await;
		});

		assert_ping(writer_rx.recv().await.unwrap(), router_address);

		// answer the first ping, then wait (without blocking the runtime) for the next one
		internal_tx.send(TransactionMessage::Ping((1, CONNECTION_ID.into()))).await.unwrap();
		let message = timeout(Duration::from_secs(3), writer_rx.recv()).await.unwrap().unwrap();
		assert_ping(message, router_address);
	}

	#[tokio::test]
	async fn test_ping_shutdown() {
		let (writer_tx, mut writer_rx) = mpsc::channel::<TransactionMessage>(32);
		let (internal_tx, internal_rx) = mpsc::channel::<TransactionMessage>(32);
		let router_address = 1;
		let ping_handle = tokio::spawn(async move {
			ping_actor(writer_tx, internal_rx, router_address, CONNECTION_ID.into()).await;
		});

		assert_ping(writer_rx.recv().await.unwrap(), router_address);
		internal_tx.send(TransactionMessage::CloseConnection).await.unwrap();

		timeout(Duration::from_secs(1), ping_handle).await.unwrap().unwrap();
	}

	#[tokio::test]
	async fn test_ping_shutdown_with_dropped_writer() {
		let (writer_tx, mut writer_rx) = mpsc::channel::<TransactionMessage>(32);
		let (internal_tx, internal_rx) = mpsc::channel::<TransactionMessage>(32);
		let router_address = 1;
		let ping_handle = tokio::spawn(async move {
			ping_actor(writer_tx, internal_rx, router_address, CONNECTION_ID.into()).await;
		});

		assert_ping(writer_rx.recv().await.unwrap(), router_address);
		// the actor must still honour the close signal after the writer has gone away
		drop(writer_rx);
		internal_tx.send(TransactionMessage::CloseConnection).await.unwrap();

		timeout(Duration::from_secs(1), ping_handle).await.unwrap().unwrap();
	}
}