use crate::client::AmateRSClient;
use crate::error::{Result, SdkError};
use amaters_core::{CipherBlob, Key, Query};
use std::sync::Arc;
/// Lifecycle phase of a [`Transaction`].
///
/// A transaction starts out `Active`; `commit` and `rollback` move it to
/// `Committed` / `RolledBack`, after which every state-checked method fails.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TransactionState {
/// Operations may still be staged, read back, committed or rolled back.
Active,
/// `commit` completed successfully.
Committed,
/// `rollback` was called and the staged operations were cleared.
RolledBack,
}
/// A single write staged inside a [`Transaction`] buffer.
///
/// On commit each variant is turned into the matching `Query` variant for the
/// transaction's collection.
#[derive(Debug, Clone)]
enum TransactionOp {
/// Store `value` under `key`.
Set { key: Key, value: CipherBlob },
/// Remove whatever is stored under `key`.
Delete { key: Key },
}
/// Client-side write buffer for a single collection.
///
/// `set` / `delete` only record operations locally; `commit` sends them to the
/// client as one batch, and `rollback` throws them away.
pub struct Transaction {
/// Name of the collection every staged operation targets.
collection: String,
/// Staged operations, in the order they were issued.
ops: Vec<TransactionOp>,
/// Client used for read-through lookups and for executing the batch.
client: Arc<AmateRSClient>,
/// Current lifecycle phase; only `Active` permits further use.
state: TransactionState,
}
impl Transaction {
    /// Creates an active transaction with an empty operation buffer that
    /// targets `collection` through `client`.
    pub fn new(client: Arc<AmateRSClient>, collection: impl Into<String>) -> Self {
        Self {
            collection: collection.into(),
            ops: Vec::new(),
            client,
            state: TransactionState::Active,
        }
    }

    /// Stages a write of `value` under `key`. Nothing is sent until `commit`.
    ///
    /// # Errors
    ///
    /// Returns [`SdkError::InvalidState`] if the transaction is no longer active.
    pub fn set(&mut self, key: Key, value: CipherBlob) -> Result<()> {
        self.ensure_active()?;
        self.ops.push(TransactionOp::Set { key, value });
        Ok(())
    }

    /// Stages a deletion of `key`. Nothing is sent until `commit`.
    ///
    /// # Errors
    ///
    /// Returns [`SdkError::InvalidState`] if the transaction is no longer active.
    pub fn delete(&mut self, key: Key) -> Result<()> {
        self.ensure_active()?;
        self.ops.push(TransactionOp::Delete { key });
        Ok(())
    }

    /// Reads `key` as seen from inside this transaction.
    ///
    /// The staged operations are scanned newest-first, so the most recent
    /// `set` yields its value and the most recent `delete` yields `None`.
    /// Only when no staged operation touches `key` is the client queried.
    ///
    /// # Errors
    ///
    /// Returns [`SdkError::InvalidState`] if the transaction is no longer
    /// active, or whatever error the client's `get` reports.
    pub async fn get(&self, key: &Key) -> Result<Option<CipherBlob>> {
        self.ensure_active()?;
        for op in self.ops.iter().rev() {
            match op {
                TransactionOp::Set { key: k, value } if k == key => {
                    return Ok(Some(value.clone()));
                }
                TransactionOp::Delete { key: k } if k == key => {
                    return Ok(None);
                }
                // Staged operation on some other key: keep looking.
                TransactionOp::Set { .. } | TransactionOp::Delete { .. } => {}
            }
        }
        self.client.get(&self.collection, key).await
    }

    /// Sends every staged operation to the client as a single batch and marks
    /// the transaction as committed.
    ///
    /// An empty transaction commits without contacting the client.
    ///
    /// If the batch fails, the transaction stays active and the staged
    /// operations are kept, so the caller may retry `commit` or `rollback`.
    ///
    /// # Errors
    ///
    /// Returns [`SdkError::InvalidState`] if the transaction is no longer
    /// active, or whatever error the client's `execute_batch` reports.
    pub async fn commit(&mut self) -> Result<()> {
        self.ensure_active()?;
        if !self.ops.is_empty() {
            // Build the batch from clones instead of draining the buffer: the
            // queries are moved into `execute_batch`, so draining up front
            // would lose the staged operations whenever the call fails and a
            // retry would then "succeed" without sending anything.
            let queries: Vec<Query> = self
                .ops
                .iter()
                .map(|op| match op {
                    TransactionOp::Set { key, value } => Query::Set {
                        collection: self.collection.clone(),
                        key: key.clone(),
                        value: value.clone(),
                    },
                    TransactionOp::Delete { key } => Query::Delete {
                        collection: self.collection.clone(),
                        key: key.clone(),
                    },
                })
                .collect();
            self.client.execute_batch(queries).await?;
            // Only forget the staged operations once the batch went through.
            self.ops.clear();
        }
        self.state = TransactionState::Committed;
        Ok(())
    }

    /// Discards every staged operation and marks the transaction as rolled back.
    ///
    /// # Errors
    ///
    /// Returns [`SdkError::InvalidState`] if the transaction is no longer active.
    pub fn rollback(&mut self) -> Result<()> {
        self.ensure_active()?;
        self.ops.clear();
        self.state = TransactionState::RolledBack;
        Ok(())
    }

    /// Number of operations currently staged in the buffer.
    pub fn pending_ops(&self) -> usize {
        self.ops.len()
    }

    /// Whether the transaction has been neither committed nor rolled back.
    pub fn is_active(&self) -> bool {
        self.state == TransactionState::Active
    }

    /// Name of the collection this transaction targets.
    pub fn collection(&self) -> &str {
        &self.collection
    }

    /// Guard shared by all state-changing and reading methods: succeeds only
    /// while the transaction is still active.
    fn ensure_active(&self) -> Result<()> {
        if self.is_active() {
            Ok(())
        } else {
            Err(SdkError::InvalidState(
                "transaction already committed or rolled back".to_string(),
            ))
        }
    }
}
impl Drop for Transaction {
    /// Emits a warning when an active transaction is dropped while it still
    /// holds staged operations, since those operations are never sent.
    fn drop(&mut self) {
        let pending = self.ops.len();
        // Nothing to report for finished transactions or empty buffers.
        if pending == 0 || self.state != TransactionState::Active {
            return;
        }
        tracing::warn!(
            pending_ops = pending,
            collection = %self.collection,
            "Transaction dropped with uncommitted operation(s) — changes discarded",
        );
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ClientConfig;
    use amaters_core::{CipherBlob, Key};

    /// Fragment of the warning logged when a transaction with staged
    /// operations is dropped without being committed or rolled back.
    const DROP_WARNING: &str = "Transaction dropped with uncommitted operation(s)";

    /// Builds a client via `new_offline`, as used by every test here.
    fn offline_client() -> Arc<AmateRSClient> {
        let config = ClientConfig::new("http://127.0.0.1:50051");
        Arc::new(AmateRSClient::new_offline(config))
    }

    /// Asserts that `err` is the invalid-state error raised by finished transactions.
    fn assert_invalid_state(err: &SdkError) {
        assert!(
            matches!(err, SdkError::InvalidState(_)),
            "expected InvalidState, got: {err}"
        );
    }

    #[test]
    fn test_transaction_rollback_clears_buffer() {
        let mut tx = Transaction::new(offline_client(), "users");
        let key = Key::from_str("k1");
        let val = CipherBlob::new(vec![1, 2, 3]);
        tx.set(key, val).expect("set should succeed on active tx");
        assert_eq!(tx.pending_ops(), 1);
        tx.rollback().expect("rollback should succeed on active tx");
        assert_eq!(tx.pending_ops(), 0);
        assert!(!tx.is_active());
    }

    #[tokio::test]
    async fn test_transaction_double_commit_returns_error() {
        let mut tx = Transaction::new(offline_client(), "users");
        // An empty transaction commits without touching the client.
        tx.commit().await.expect("first commit should succeed");
        let err = tx
            .commit()
            .await
            .expect_err("second commit should return Err");
        assert_invalid_state(&err);
    }

    #[tokio::test]
    async fn test_transaction_commit_then_rollback_is_error() {
        let mut tx = Transaction::new(offline_client(), "users");
        tx.commit().await.expect("commit should succeed");
        let err = tx
            .rollback()
            .expect_err("rollback after commit should return Err");
        assert_invalid_state(&err);
    }

    #[tokio::test]
    async fn test_transaction_rollback_then_commit_is_error() {
        let mut tx = Transaction::new(offline_client(), "users");
        tx.rollback().expect("rollback should succeed on active tx");
        let err = tx
            .commit()
            .await
            .expect_err("commit after rollback should return Err");
        assert_invalid_state(&err);
    }

    #[tokio::test]
    async fn test_transaction_read_sees_local_set() {
        let mut tx = Transaction::new(offline_client(), "users");
        let key = Key::from_str("local_key");
        let val = CipherBlob::new(vec![10, 20, 30]);
        tx.set(key.clone(), val.clone())
            .expect("set should succeed");
        let result = tx.get(&key).await.expect("get should succeed (local hit)");
        assert_eq!(
            result.as_ref().map(|b| b.to_vec()),
            Some(val.to_vec()),
            "get should return the locally staged value"
        );
    }

    #[tokio::test]
    async fn test_transaction_read_sees_local_delete_as_none() {
        let mut tx = Transaction::new(offline_client(), "users");
        let key = Key::from_str("will_delete");
        let val = CipherBlob::new(vec![1]);
        tx.set(key.clone(), val).expect("set should succeed");
        tx.delete(key.clone()).expect("delete should succeed");
        let result = tx
            .get(&key)
            .await
            .expect("get should succeed (local delete hit)");
        assert!(
            result.is_none(),
            "locally deleted key should appear as None"
        );
    }

    #[tokio::test]
    async fn test_transaction_read_last_write_wins() {
        let mut tx = Transaction::new(offline_client(), "users");
        let key = Key::from_str("overwritten");
        let v1 = CipherBlob::new(vec![1]);
        let v2 = CipherBlob::new(vec![2]);
        tx.set(key.clone(), v1).expect("first set");
        tx.set(key.clone(), v2.clone()).expect("second set");
        let result = tx.get(&key).await.expect("get should succeed");
        assert_eq!(
            result.as_ref().map(|b| b.to_vec()),
            Some(v2.to_vec()),
            "last write should win"
        );
    }

    // Captures logs so the "no warn" part of the test name is actually checked.
    #[tracing_test::traced_test]
    #[test]
    fn test_transaction_empty_drop_no_warn() {
        let tx = Transaction::new(offline_client(), "noop");
        drop(tx);
        assert!(
            !logs_contain(DROP_WARNING),
            "dropping an empty transaction must not warn"
        );
    }

    #[tracing_test::traced_test]
    #[test]
    fn test_transaction_drop_warns_uncommitted() {
        let mut tx = Transaction::new(offline_client(), "events");
        let key = Key::from_str("pending");
        let val = CipherBlob::new(vec![0xFF]);
        tx.set(key, val).expect("set should succeed");
        drop(tx);
        assert!(
            logs_contain(DROP_WARNING),
            "expected a tracing warn about uncommitted ops"
        );
    }
}