surrealcs 0.4.4

The SurrealCS client code for SurrealDB
Documentation
//! Defines the functionality around the `NotStarted` transaction state.
use super::bridge::SendToTransactionClientActor;
use super::interface::{Any, InProgress, NotStarted, Transaction};
use crate::router::send_to_router;
use crate::transactions::writer_bridge::establish_transaction_connection;
use nanoservices_utils::errors::NanoServiceError;
use std::future::Future;
use std::marker::PhantomData;
use surrealcs_kernel::{
	logging::messages::connections::id::get_connection_index,
	logging::messages::transactions::id::create_id,
	messages::client::router::RouterMessage,
	messages::server::interface::{ServerMessage, ServerTransactionMessage},
};

impl Transaction<NotStarted> {
	/// Constructs a new transaction using the standard implementation.
	///
	/// The connection is obtained by routing through [`send_to_router`].
	///
	/// # Notes
	/// This function is not unit tested as it is a simple wrapper around
	/// the internal constructor which is unit tested.
	///
	/// # Errors
	/// Propagates any error returned by the internal constructor.
	///
	/// # Returns
	/// The constructed `Transaction` struct in the `NotStarted` state
	pub async fn new() -> Result<Transaction<NotStarted>, NanoServiceError> {
		Transaction::internal_new(send_to_router).await
	}

	/// The internal constructor for the `Transaction` struct.
	///
	/// # Notes
	/// The server is not contacted with this constructor; it merely
	/// registers the transaction with the router of a connection in
	/// the client and creates a transaction actor in the client to
	/// listen for messages.
	///
	/// This interface is a little more in-depth because the caller supplies
	/// the routing closure (which allows a fake router to be injected in
	/// tests). If you are just using the standard implementation, use the
	/// `new` method.
	///
	/// # Arguments
	/// * `closure`: A closure that enables us to send a message to the main router
	///
	/// # Errors
	/// Propagates any error returned by `establish_transaction_connection`.
	///
	/// # Returns
	/// The constructed `Transaction` struct in the `NotStarted` state
	pub async fn internal_new<F, Fut>(
		closure: F,
	) -> Result<Transaction<NotStarted>, NanoServiceError>
	where
		F: FnOnce(RouterMessage) -> Fut,
		Fut: Future<Output = Result<RouterMessage, NanoServiceError>>,
	{
		let (return_tx, client_id, h_rx, connection_id) =
			establish_transaction_connection(closure).await?;

		// the transaction ID is derived from the connection's index and the
		// client ID assigned during registration
		let connection_index = get_connection_index(&connection_id);
		let transaction_id = create_id(connection_index, client_id);

		Ok(Transaction {
			client_id,
			// no server-side transaction exists yet, so there is no server ID
			server_id: None,
			receiver: h_rx,
			sender: return_tx,
			connection_id,
			transaction_id,
			state: PhantomData,
		})
	}

	/// Converts the transaction into an `Any` transaction.
	///
	/// All fields are moved across unchanged; only the type-level state differs.
	///
	/// # Returns
	/// the transaction as an `Any` transaction
	// `dead_code` is allowed because this conversion may have no callers
	// inside the crate itself.
	#[allow(dead_code)]
	pub fn into_any(self) -> Transaction<Any> {
		Transaction {
			client_id: self.client_id,
			server_id: self.server_id,
			receiver: self.receiver,
			sender: self.sender,
			connection_id: self.connection_id,
			transaction_id: self.transaction_id,
			state: PhantomData,
		}
	}

	/// Starts the transaction by sending the initial transaction operation to the server.
	///
	/// # Arguments
	/// * `operation`: the initial transaction operation to send to the server
	///
	/// # Errors
	/// Propagates any error returned by `T::send_to_transaction_client_actor`;
	/// in that case the `NotStarted` transaction is dropped.
	///
	/// # Returns
	/// A tuple containing the response from the server and the in progress transaction
	pub async fn begin<T: SendToTransactionClientActor>(
		mut self,
		operation: ServerTransactionMessage,
	) -> Result<(ServerTransactionMessage, Transaction<InProgress>), NanoServiceError> {
		let message = ServerMessage::BeginTransaction(operation);
		// `self` is lent mutably to the sender; any state it updates
		// (presumably `server_id`) is carried into the in-progress transaction below.
		let response = T::send_to_transaction_client_actor(&mut self, message).await?;

		// `self` is consumed here, so every field (including the sender) is
		// moved into the `InProgress` transaction rather than cloned.
		let in_progress_transaction = Transaction {
			client_id: self.client_id,
			server_id: self.server_id,
			receiver: self.receiver,
			sender: self.sender,
			connection_id: self.connection_id,
			transaction_id: self.transaction_id,
			state: PhantomData,
		};
		Ok((response, in_progress_transaction))
	}
}

#[cfg(test)]
mod tests {

	use super::*;
	use crate::generate_mock_handle;
	use std::future::Future;
	use surrealcs_kernel::messages::client::message::TransactionMessage;
	use surrealcs_kernel::messages::server::kv_operations::{MessageGet, ResponseGet};
	use tokio::sync::mpsc;

	#[tokio::test]
	async fn test_internal_new() {
		// mimics a connection writer that the transaction registers with
		async fn test_writer(mut rx: mpsc::Receiver<TransactionMessage>) {
			// answer the registration with a fixed client ID of 1
			match rx.recv().await.unwrap() {
				TransactionMessage::Register(sender) => {
					let message = TransactionMessage::Registered(1);
					sender.send(message).await.unwrap();

					// echo any transaction operation straight back to the sender
					match rx.recv().await.unwrap() {
						TransactionMessage::TransactionOperation(wrapper) => {
							let message = TransactionMessage::TransactionOperation(wrapper);
							sender.send(message).await.unwrap();
						}
						_ => {
							panic!("should have gotten a transaction operation second");
						}
					}
				}
				_ => {
					panic!("should have gotten the register first");
				}
			}
		}

		// define the writer channel
		let (writer_tx, writer_rx) = mpsc::channel::<TransactionMessage>(32);

		// create a fake main router closure that returns the writer
		let closure =
			|_message: RouterMessage| async { Ok(RouterMessage::ReturnConnection((writer_tx, 0))) };

		// spawn the fake writer so it can answer the registration sent during construction
		tokio::spawn(async move {
			test_writer(writer_rx).await;
		});

		let transaction = Transaction::<NotStarted>::internal_new(closure).await.unwrap();

		// the client ID comes from the writer's `Registered(1)` reply
		assert_eq!(transaction.client_id, 1);
		assert_eq!(transaction.server_id, None);
	}

	#[tokio::test]
	async fn test_begin() {
		// mock handle that checks the begin message and replies with a fixed value
		generate_mock_handle! {
			struct MockHandle;
			match ServerMessage::BeginTransaction(message) => {
				match message {
					ServerTransactionMessage::Get(message) => {
						assert_eq!(message.key, b"keys".to_vec());
					}
					_ => {
						panic!("message not as expected: {:?}", message);
					}
				}
			}
			return Ok(ServerTransactionMessage::ResponseGet(
				ResponseGet { value: Some(b"value".to_vec()) }
			));
		}

		// define the channel for the actor
		let (tx, _rx) = mpsc::channel(32);

		// define the channel for the handle
		let (_h_tx, h_rx) = mpsc::channel(32);

		let transaction = Transaction::<NotStarted> {
			client_id: 1,
			server_id: None,
			receiver: h_rx,
			sender: tx,
			connection_id: "connection_id".to_string(),
			transaction_id: "transaction_id".to_string(),
			state: PhantomData,
		};
		let message = ServerTransactionMessage::Get(MessageGet {
			key: b"keys".to_vec(),
			version: None,
		});

		let (outcome, _in_progress) = transaction.begin::<MockHandle>(message).await.unwrap();

		let response = match outcome {
			ServerTransactionMessage::ResponseGet(response) => response,
			_ => panic!("wrong message type"),
		};
		assert_eq!(response.value, Some(b"value".to_vec()));
	}
}