use super::bridge::SendToTransactionClientActor;
use super::interface::{Any, InProgress, NotStarted, Transaction};
use crate::router::send_to_router;
use crate::transactions::writer_bridge::establish_transaction_connection;
use nanoservices_utils::errors::NanoServiceError;
use std::future::Future;
use std::marker::PhantomData;
use surrealcs_kernel::{
logging::messages::connections::id::get_connection_index,
logging::messages::transactions::id::create_id,
messages::client::router::RouterMessage,
messages::server::interface::{ServerMessage, ServerTransactionMessage},
};
impl Transaction<NotStarted> {
    /// Creates a new, not-yet-started transaction, using [`send_to_router`] to obtain a
    /// connection from the router.
    ///
    /// This is a thin wrapper over [`Transaction::internal_new`].
    ///
    /// # Errors
    ///
    /// Returns any [`NanoServiceError`] raised while establishing the transaction
    /// connection.
    pub async fn new() -> Result<Transaction<NotStarted>, NanoServiceError> {
        Transaction::internal_new(send_to_router).await
    }

    /// Creates a new, not-yet-started transaction, using `closure` to talk to the router.
    ///
    /// `closure` is passed to `establish_transaction_connection`. That call yields:
    /// - the sender and receiver stored on the transaction,
    /// - the client ID,
    /// - the connection ID.
    ///
    /// The transaction ID is built from the connection's index and the client ID. No
    /// server ID is assigned yet.
    ///
    /// Taking the router call as a parameter lets tests substitute a mock for
    /// [`send_to_router`].
    ///
    /// # Errors
    ///
    /// Returns any [`NanoServiceError`] produced by `establish_transaction_connection`
    /// (including errors returned by `closure`, if it propagates them — TODO confirm).
    pub async fn internal_new<F, Fut>(
        closure: F,
    ) -> Result<Transaction<NotStarted>, NanoServiceError>
    where
        F: FnOnce(RouterMessage) -> Fut,
        Fut: Future<Output = Result<RouterMessage, NanoServiceError>>,
    {
        let (return_tx, client_id, h_rx, connection_id) =
            establish_transaction_connection(closure).await?;
        let connection_index = get_connection_index(&connection_id);
        let transaction_id = create_id(connection_index, client_id);
        Ok(Transaction {
            client_id,
            server_id: None,
            receiver: h_rx,
            sender: return_tx,
            connection_id,
            transaction_id,
            state: PhantomData,
        })
    }

    /// Erases the type-level state, converting this transaction into a
    /// `Transaction<Any>`.
    ///
    /// Every field is carried over unchanged; only the `state` marker changes.
    // The dead-code lint is silenced because this conversion may have no callers in
    // some builds; it is kept as an explicit escape hatch from the type-state API.
    #[allow(dead_code)]
    pub fn into_any(self) -> Transaction<Any> {
        Transaction {
            client_id: self.client_id,
            server_id: self.server_id,
            receiver: self.receiver,
            sender: self.sender,
            connection_id: self.connection_id,
            transaction_id: self.transaction_id,
            state: PhantomData,
        }
    }

    /// Begins the transaction.
    ///
    /// `operation` is wrapped in `ServerMessage::BeginTransaction` and sent through the
    /// `T` actor bridge.
    ///
    /// On success, returns two things:
    /// - the message returned by the bridge for `operation`,
    /// - this same transaction, re-typed as `Transaction<InProgress>`.
    ///
    /// # Errors
    ///
    /// Returns any [`NanoServiceError`] raised by `T::send_to_transaction_client_actor`.
    /// In that case the not-started transaction is dropped.
    pub async fn begin<T: SendToTransactionClientActor>(
        mut self,
        operation: ServerTransactionMessage,
    ) -> Result<(ServerTransactionMessage, Transaction<InProgress>), NanoServiceError> {
        let message = ServerMessage::BeginTransaction(operation);
        let response = T::send_to_transaction_client_actor(&mut self, message).await?;
        // `self` is owned and the mutable borrow above has ended, so every field
        // (including `sender`) can be moved into the in-progress state. Cloning
        // `sender` here would only bump and then drop a refcount.
        let in_progress_transaction = Transaction {
            client_id: self.client_id,
            server_id: self.server_id,
            receiver: self.receiver,
            sender: self.sender,
            connection_id: self.connection_id,
            transaction_id: self.transaction_id,
            state: PhantomData,
        };
        Ok((response, in_progress_transaction))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generate_mock_handle;
    use std::future::Future;
    use surrealcs_kernel::messages::client::message::TransactionMessage;
    use surrealcs_kernel::messages::server::kv_operations::{MessageGet, ResponseGet};
    use tokio::sync::mpsc;

    /// Checks that `internal_new` does the following:
    /// - registers with the writer handed back by the router closure,
    /// - adopts the client ID that writer assigns,
    /// - leaves the server ID unset.
    #[tokio::test]
    async fn test_internal_new() {
        /// Minimal stand-in for the connection writer.
        ///
        /// It answers the first `Register` with `Registered(1)`, then echoes the next
        /// `TransactionOperation` back to the registered sender. It panics on any other
        /// sequence of messages.
        async fn test_writer(mut rx: mpsc::Receiver<TransactionMessage>) {
            match rx.recv().await.unwrap() {
                TransactionMessage::Register(sender) => {
                    let message = TransactionMessage::Registered(1);
                    sender.send(message).await.unwrap();
                    match rx.recv().await.unwrap() {
                        TransactionMessage::TransactionOperation(wrapper) => {
                            let message = TransactionMessage::TransactionOperation(wrapper);
                            sender.send(message).await.unwrap();
                        }
                        _ => {
                            panic!("should have gotten a transaction operation second");
                        }
                    }
                }
                _ => {
                    panic!("should have gotten the register first");
                }
            }
        }

        let (writer_tx, writer_rx) = mpsc::channel::<TransactionMessage>(32);
        // Router stand-in: ignores the request and hands back the mock writer's channel.
        // NOTE(review): the `0` is passed through as the second tuple element of
        // `ReturnConnection` — its meaning is defined outside this file.
        let closure =
            |_message: RouterMessage| async { Ok(RouterMessage::ReturnConnection((writer_tx, 0))) };
        tokio::spawn(test_writer(writer_rx));

        let transaction = Transaction::<NotStarted>::internal_new(closure).await.unwrap();
        assert_eq!(transaction.client_id, 1);
        assert_eq!(transaction.server_id, None);
    }

    /// Checks that `begin` does the following:
    /// - sends the operation wrapped in `ServerMessage::BeginTransaction`,
    /// - returns the message produced by the actor bridge.
    #[tokio::test]
    async fn test_begin() {
        generate_mock_handle! {
            struct MockHandle;
            match ServerMessage::BeginTransaction(message) => {
                match message {
                    ServerTransactionMessage::Get(message) => {
                        assert_eq!(message.key, b"keys".to_vec());
                    }
                    _ => {
                        panic!("message not as expected: {:?}", message);
                    }
                }
            }
            return Ok(ServerTransactionMessage::ResponseGet(
                ResponseGet { value: Some(b"value".to_vec()) }
            ));
        }

        let (tx, _rx) = mpsc::channel(32);
        let (_h_tx, h_rx) = mpsc::channel(32);
        let transaction = Transaction::<NotStarted> {
            client_id: 1,
            server_id: None,
            receiver: h_rx,
            sender: tx,
            connection_id: "connection_id".to_string(),
            transaction_id: "transaction_id".to_string(),
            state: PhantomData,
        };
        let message = ServerTransactionMessage::Get(MessageGet {
            key: b"keys".to_vec(),
            version: None,
        });

        let (outcome, _transaction) = transaction.begin::<MockHandle>(message).await.unwrap();
        let response = match outcome {
            ServerTransactionMessage::ResponseGet(response) => response,
            _ => panic!("wrong message type"),
        };
        assert_eq!(response.value, Some(b"value".to_vec()));
    }
}