surrealcs 0.4.4

The SurrealCS client code for SurrealDB
Documentation
//! Defines the helpers for (re)connecting to the server, e.g. after the connection has been lost.
use futures::Future;
use nanoservices_utils::errors::NanoServiceError;
use tokio::net::TcpStream;

/// Attempts to open a TCP connection to the server.
///
/// # Arguments
/// * `address`: The address of the server the connection is being made to
///
/// # Returns
/// The connected `TcpStream` if the connection is successful.
///
/// # Errors
/// Returns a `NanoServiceError` with status `Unknown`, wrapping the underlying I/O error, if the
/// connection cannot be established (for example when nothing is listening on `address`).
pub async fn attempt_tcp_connection(address: String) -> Result<TcpStream, NanoServiceError> {
	TcpStream::connect(address).await.map_err(|e| {
		NanoServiceError::new(
			format!("Error connecting to server: {:?}", e),
			nanoservices_utils::errors::NanoServiceErrorStatus::Unknown,
		)
	})
}

/// Creates connections to the server at a fixed address.
///
/// # Notes
/// This is a struct rather than an actor, to keep ownership of the created TCP stream simple.
/// As an actor, handling async messaging would be simpler, but sending the connected TCP stream
/// back over a channel would be more complex.
///
/// # Fields
/// * `address`: The address of the server the connection is being made to
#[derive(Debug, PartialEq)]
pub struct ConnectionCreator {
	pub address: String,
}

impl ConnectionCreator {
	/// Attempts to connect to the server at `self.address` using the supplied connection function.
	///
	/// The connection logic is injected so that callers can pass [`attempt_tcp_connection`] to get
	/// a real TCP stream, or any other async function (e.g. a mock in tests).
	///
	/// # Arguments
	/// * `connection_closure`: Called exactly once with a copy of `self.address`; the future it
	///   returns is awaited
	///
	/// # Returns
	/// Whatever the future produced by `connection_closure` resolves to: `Ok(Y)` (for example a
	/// `TcpStream`) on success, otherwise a `NanoServiceError`.
	///
	/// # Notes
	/// `F` is not required to be `Copy`: because the closure is only invoked once, capturing
	/// closures that own non-`Copy` data are accepted as well.
	pub async fn attempt_connection<F, Fut, Y>(
		&self,
		connection_closure: F,
	) -> Result<Y, NanoServiceError>
	where
		F: FnOnce(String) -> Fut + Send,
		Fut: Future<Output = Result<Y, NanoServiceError>> + Send,
	{
		connection_closure(self.address.clone()).await
	}
}

#[cfg(test)]
mod tests {

	use super::*;
	use std::sync::atomic::{AtomicUsize, Ordering};
	use tokio::net::TcpListener;

	/// Number of times the mocked connection has been attempted so far.
	static CONNECTION_ATTEMPT: AtomicUsize = AtomicUsize::new(0);

	/// Mocked connection attempt: only the fourth call (three prior attempts) succeeds.
	async fn check_connection(_address: String) -> Result<bool, NanoServiceError> {
		let previous_attempts = CONNECTION_ATTEMPT.fetch_add(1, Ordering::Relaxed);
		match previous_attempts {
			3 => Ok(true),
			_ => Err(NanoServiceError::new(
				"Error connecting to server".to_string(),
				nanoservices_utils::errors::NanoServiceErrorStatus::Unknown,
			)),
		}
	}

	/// Checks that an arbitrary future-returning function can drive the connection attempt.
	#[tokio::test]
	async fn test_connection_creator() {
		let connector = ConnectionCreator {
			address: "fake".to_string(),
		};

		// the mock refuses the first three attempts...
		for _ in 0..3 {
			let outcome = connector.attempt_connection(check_connection).await;
			assert!(outcome.is_err());
		}

		// ...and accepts the fourth
		let outcome = connector.attempt_connection(check_connection).await;
		assert!(outcome.is_ok());
	}

	/// Checks that connecting fails while no server listens and succeeds once one is back.
	#[tokio::test]
	async fn test_live_reconnection() {
		let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
		let addr = listener.local_addr().unwrap().to_string();

		// drop to simulate a dropped connection
		drop(listener);

		let connector = ConnectionCreator {
			address: addr.clone(),
		};
		assert!(connector.attempt_connection(attempt_tcp_connection).await.is_err());

		// reconnect to the server to check if the stream is returned
		let _listener = TcpListener::bind(addr).await.unwrap();
		assert!(connector.attempt_connection(attempt_tcp_connection).await.is_ok());
	}
}