use crate::{
AzothDb, AzothError, CanonicalReadTxn, CanonicalStore, CanonicalTxn, CommitInfo, EventId,
Result, TypedValue,
};
use azoth_lmdb::preflight_cache::{CachedValue, PreflightCache};
use std::collections::HashSet;
use std::sync::Arc;
/// Maximum number of keys a single transaction may declare up front.
pub const MAX_DECLARED_KEYS: usize = 10_000;

/// Reject transactions that declare more than [`MAX_DECLARED_KEYS`] keys.
///
/// # Errors
/// Returns [`AzothError::Config`] when `count` exceeds the limit.
fn check_key_limit(count: usize) -> Result<()> {
    if count <= MAX_DECLARED_KEYS {
        Ok(())
    } else {
        Err(AzothError::Config(format!(
            "Transaction declares {} keys, which exceeds the maximum of {}. \
             Consider batching operations or increasing MAX_DECLARED_KEYS.",
            count, MAX_DECLARED_KEYS
        )))
    }
}
/// Builder for a transaction that runs on a blocking thread via
/// [`tokio::task::spawn_blocking`], making it safe to drive from async code.
pub struct AsyncTransaction {
    // Shared database handle; moved into the blocking task on execute.
    db: Arc<AzothDb>,
    // Keys this transaction is allowed to read/write; enforced at access time.
    declared_keys: HashSet<Vec<u8>>,
    // Preflight checks run under the key locks before the write txn opens.
    validators: Vec<PreflightValidator>,
}
/// A boxed preflight validator; `Send + 'static` so it can cross into the
/// blocking task.
type PreflightValidator = Box<dyn FnOnce(&PreflightContext) -> Result<()> + Send + 'static>;
impl AsyncTransaction {
    /// Create a builder over the shared database handle with no declared keys
    /// and no validators.
    pub fn new(db: Arc<AzothDb>) -> Self {
        Self {
            db,
            declared_keys: HashSet::new(),
            validators: Vec::new(),
        }
    }

    /// Declare keys this transaction may touch (additive; duplicates are
    /// deduplicated by the set).
    pub fn keys(mut self, keys: Vec<Vec<u8>>) -> Self {
        self.declared_keys.extend(keys);
        self
    }

    /// Register a preflight validator. Validators run in registration order,
    /// under the key locks, before the write transaction is opened.
    pub fn validate<F>(mut self, f: F) -> Self
    where
        F: FnOnce(&PreflightContext) -> Result<()> + Send + 'static,
    {
        self.validators.push(Box::new(f));
        self
    }

    /// Execute the transaction on a blocking thread.
    ///
    /// Ordering (behavior-critical): key-limit check, per-key lock
    /// acquisition, preflight validators, write transaction + user closure,
    /// commit, then preflight-cache invalidation of the declared keys.
    ///
    /// # Errors
    /// Propagates validator/closure/store errors; a failed `spawn_blocking`
    /// join is wrapped in [`AzothError::Internal`].
    pub async fn execute<F>(self, f: F) -> Result<CommitInfo>
    where
        F: FnOnce(&mut TransactionContext<'_>) -> Result<()> + Send + 'static,
    {
        let db = self.db;
        let declared_keys = self.declared_keys;
        let validators = self.validators;
        // Fail fast before paying for a blocking task.
        check_key_limit(declared_keys.len())?;
        tokio::task::spawn_blocking(move || {
            // Hold locks on every declared key for the whole transaction;
            // `_locks` is a guard released on drop at the end of the closure.
            let lock_manager = db.canonical().lock_manager();
            let keys_vec: Vec<&[u8]> = declared_keys.iter().map(|k| k.as_slice()).collect();
            let _locks = lock_manager.acquire_keys(&keys_vec)?;
            let cache = db.canonical().preflight_cache();
            let ctx = PreflightContext::new(&db, cache, &declared_keys);
            for validator in validators {
                validator(&ctx)?;
            }
            let txn = db.canonical().write_txn()?;
            let mut update_ctx = TransactionContext {
                txn,
                declared_keys: &declared_keys,
                value_cache: std::cell::RefCell::new(std::collections::HashMap::new()),
            };
            f(&mut update_ctx)?;
            let commit_info = update_ctx.txn.commit()?;
            // Invalidate cached reads only after a successful commit; on any
            // earlier error the store is unchanged, so the cache is still valid.
            let keys_to_invalidate: Vec<Vec<u8>> = declared_keys.iter().cloned().collect();
            cache.invalidate_keys(&keys_to_invalidate);
            Ok(commit_info)
        })
        .await
        .map_err(|e| AzothError::Internal(format!("Transaction task failed: {}", e)))?
    }
}
/// Read-only context handed to preflight validators. Reads consult the shared
/// preflight cache first, falling back to a fresh read transaction.
pub struct PreflightContext<'a> {
    // Database to read from on a cache miss.
    db: &'a AzothDb,
    // Shared cache of recently read state values (positive and negative hits).
    cache: &'a Arc<PreflightCache>,
    // The only keys this context is allowed to read.
    declared_keys: &'a HashSet<Vec<u8>>,
}
impl<'a> PreflightContext<'a> {
    fn new(
        db: &'a AzothDb,
        cache: &'a Arc<PreflightCache>,
        declared_keys: &'a HashSet<Vec<u8>>,
    ) -> Self {
        Self {
            db,
            cache,
            declared_keys,
        }
    }

    /// Reject access to any key that was not declared up front.
    fn check_key_declared(&self, key: &[u8]) -> Result<()> {
        if !self.declared_keys.contains(key) {
            return Err(AzothError::UndeclaredKeyAccess {
                key: String::from_utf8_lossy(key).to_string(),
            });
        }
        Ok(())
    }

    /// Read a declared key, failing if it does not exist.
    ///
    /// # Errors
    /// - [`AzothError::UndeclaredKeyAccess`] if `key` was not declared.
    /// - [`AzothError::InvalidState`] if the key has no value.
    pub fn get(&self, key: &[u8]) -> Result<TypedValue> {
        // Delegate so the cache/read-txn logic lives only in `get_opt`.
        self.get_opt(key)?
            .ok_or_else(|| AzothError::InvalidState("Key does not exist".into()))
    }

    /// Read a declared key, returning `None` if it does not exist.
    ///
    /// Checks the shared preflight cache first (including cached negative
    /// results); on a miss, reads the store and populates the cache.
    pub fn get_opt(&self, key: &[u8]) -> Result<Option<TypedValue>> {
        self.check_key_declared(key)?;
        if let Some(cached) = self.cache.get(key) {
            match cached {
                CachedValue::Some(bytes) => return Ok(Some(TypedValue::from_bytes(&bytes)?)),
                CachedValue::None => return Ok(None),
            }
        }
        let txn = self.db.canonical().read_txn()?;
        match txn.get_state(key)? {
            Some(bytes) => {
                self.cache
                    .insert(key.to_vec(), CachedValue::Some(bytes.clone()));
                Ok(Some(TypedValue::from_bytes(&bytes)?))
            }
            None => {
                self.cache.insert(key.to_vec(), CachedValue::None);
                Ok(None)
            }
        }
    }

    /// Check whether a declared key exists, without decoding its value.
    pub fn exists(&self, key: &[u8]) -> Result<bool> {
        self.check_key_declared(key)?;
        if let Some(cached) = self.cache.get(key) {
            match cached {
                CachedValue::Some(_) => return Ok(true),
                CachedValue::None => return Ok(false),
            }
        }
        let txn = self.db.canonical().read_txn()?;
        match txn.get_state(key)? {
            Some(bytes) => {
                self.cache.insert(key.to_vec(), CachedValue::Some(bytes));
                Ok(true)
            }
            None => {
                self.cache.insert(key.to_vec(), CachedValue::None);
                Ok(false)
            }
        }
    }
}
/// Mutable context passed to the transaction body; wraps the live write
/// transaction and enforces declared-key access.
pub struct TransactionContext<'a> {
    // The underlying canonical-store write transaction.
    txn: <crate::LmdbCanonicalStore as CanonicalStore>::Txn<'a>,
    // Keys this transaction may read or write.
    declared_keys: &'a HashSet<Vec<u8>>,
    // Per-transaction cache of decoded values; RefCell so `get` can memoize
    // through a shared reference.
    value_cache: std::cell::RefCell<std::collections::HashMap<Vec<u8>, TypedValue>>,
}
impl<'a> TransactionContext<'a> {
    /// Reject access to any key that was not declared up front.
    fn check_key_declared(&self, key: &[u8]) -> Result<()> {
        if !self.declared_keys.contains(key) {
            return Err(AzothError::UndeclaredKeyAccess {
                key: String::from_utf8_lossy(key).to_string(),
            });
        }
        Ok(())
    }

    /// Read a declared key, failing if it does not exist.
    ///
    /// # Errors
    /// - [`AzothError::UndeclaredKeyAccess`] if `key` was not declared.
    /// - [`AzothError::InvalidState`] if the key has no value.
    pub fn get(&self, key: &[u8]) -> Result<TypedValue> {
        // Delegate so the cache/txn read logic lives only in `get_opt`.
        self.get_opt(key)?
            .ok_or_else(|| AzothError::InvalidState("Key does not exist".into()))
    }

    /// Read a declared key, returning `None` if it does not exist.
    ///
    /// Decoded values are memoized in the per-transaction cache; `set` writes
    /// through it and `delete` evicts, so reads always see this txn's writes.
    pub fn get_opt(&self, key: &[u8]) -> Result<Option<TypedValue>> {
        self.check_key_declared(key)?;
        {
            // Scoped so the borrow is released before the miss path below.
            let cache = self.value_cache.borrow();
            if let Some(cached) = cache.get(key) {
                return Ok(Some(cached.clone()));
            }
        }
        match self.txn.get_state(key)? {
            Some(bytes) => {
                let value = TypedValue::from_bytes(&bytes)?;
                self.value_cache
                    .borrow_mut()
                    .insert(key.to_vec(), value.clone());
                Ok(Some(value))
            }
            None => Ok(None),
        }
    }

    /// Write a declared key, updating both the store and the value cache.
    pub fn set(&mut self, key: &[u8], value: &TypedValue) -> Result<()> {
        self.check_key_declared(key)?;
        let bytes = value.to_bytes()?;
        self.txn.put_state(key, &bytes)?;
        self.value_cache
            .borrow_mut()
            .insert(key.to_vec(), value.clone());
        Ok(())
    }

    /// Delete a declared key and evict it from the value cache.
    pub fn delete(&mut self, key: &[u8]) -> Result<()> {
        self.check_key_declared(key)?;
        self.txn.del_state(key)?;
        self.value_cache.borrow_mut().remove(key);
        Ok(())
    }

    /// Check whether a declared key exists in the store.
    pub fn exists(&self, key: &[u8]) -> Result<bool> {
        self.check_key_declared(key)?;
        Ok(self.txn.get_state(key)?.is_some())
    }

    /// Read-modify-write a declared key: `f` receives the current value (or
    /// `None`) and returns the new value to store.
    pub fn update<F>(&mut self, key: &[u8], f: F) -> Result<()>
    where
        F: FnOnce(Option<TypedValue>) -> Result<TypedValue>,
    {
        self.check_key_declared(key)?;
        let old = self.get_opt(key)?;
        let new = f(old)?;
        self.set(key, &new)
    }

    /// Append a single `"type:json"` event to the log, returning its id.
    pub fn log<T: serde::Serialize>(&mut self, event_type: &str, payload: &T) -> Result<EventId> {
        let json =
            serde_json::to_string(payload).map_err(|e| AzothError::Serialization(e.to_string()))?;
        let event = format!("{}:{}", event_type, json);
        self.txn.append_event(event.as_bytes())
    }

    /// Append several `"type:json"` events atomically, returning the first
    /// and last event ids.
    pub fn log_many<T: serde::Serialize>(
        &mut self,
        events: &[(&str, T)],
    ) -> Result<(EventId, EventId)> {
        let encoded: Vec<Vec<u8>> = events
            .iter()
            .map(|(event_type, payload)| {
                let json = serde_json::to_string(payload)
                    .map_err(|e| AzothError::Serialization(e.to_string()))?;
                Ok(format!("{}:{}", event_type, json).into_bytes())
            })
            .collect::<Result<Vec<Vec<u8>>>>()?;
        self.txn.append_events(&encoded)
    }

    /// Append a raw byte event to the log.
    pub fn log_bytes(&mut self, event: &[u8]) -> Result<EventId> {
        self.txn.append_event(event)
    }

    /// Snapshot all state key/value pairs visible to this transaction.
    pub fn iter_state(&self) -> Result<Vec<(Vec<u8>, Vec<u8>)>> {
        self.txn.iter_state()
    }
}
/// Synchronous transaction builder borrowing the database. Use from blocking
/// code; from async code prefer `AsyncTransaction`.
#[allow(clippy::type_complexity)]
pub struct Transaction<'a> {
    db: &'a AzothDb,
    // Keys declared up front; all reads/writes are restricted to these.
    declared_keys: HashSet<Vec<u8>>,
    // Preflight validators run under the key locks before the write txn opens.
    validators: Vec<Box<dyn FnOnce(&PreflightContext) -> Result<()> + 'a>>,
}
impl<'a> Transaction<'a> {
    /// Create a builder with no declared keys and no validators.
    pub fn new(db: &'a AzothDb) -> Self {
        Self {
            db,
            declared_keys: HashSet::new(),
            validators: Vec::new(),
        }
    }

    /// Declare keys this transaction may touch (additive).
    pub fn keys(mut self, keys: Vec<Vec<u8>>) -> Self {
        self.declared_keys.extend(keys);
        self
    }

    /// Register a preflight validator; runs in registration order, under the
    /// key locks, before the write transaction is opened.
    pub fn validate<F>(mut self, f: F) -> Self
    where
        F: FnOnce(&PreflightContext) -> Result<()> + 'a,
    {
        self.validators.push(Box::new(f));
        self
    }

    /// Alias for [`Transaction::validate`].
    pub fn preflight<F>(self, f: F) -> Self
    where
        F: FnOnce(&PreflightContext) -> Result<()> + 'a,
    {
        self.validate(f)
    }

    /// Declare `key` and validate its (optional) value in preflight.
    pub fn require<F>(mut self, key: Vec<u8>, validator: F) -> Self
    where
        F: FnOnce(Option<TypedValue>) -> Result<()> + 'a,
    {
        self.declared_keys.insert(key.clone());
        self.validators.push(Box::new(move |ctx| {
            let value = ctx.get_opt(&key)?;
            validator(value)
        }));
        self
    }

    /// Declare `key` and require that it exists at preflight time.
    pub fn require_exists(mut self, key: Vec<u8>) -> Self {
        self.declared_keys.insert(key.clone());
        self.preflight(move |ctx| {
            if !ctx.exists(&key)? {
                return Err(AzothError::PreflightFailed(format!(
                    "Key {:?} must exist",
                    String::from_utf8_lossy(&key)
                )));
            }
            Ok(())
        })
    }

    /// Declare `key` and require its i64 value to be at least `min`.
    pub fn require_min(mut self, key: Vec<u8>, min: i64) -> Self {
        self.declared_keys.insert(key.clone());
        self.preflight(move |ctx| {
            let value = ctx.get(&key)?.as_i64()?;
            if value < min {
                return Err(AzothError::PreflightFailed(format!(
                    "Value {} < minimum {}",
                    value, min
                )));
            }
            Ok(())
        })
    }

    /// Declare `key` and require its i64 value to be at most `max`.
    pub fn require_max(mut self, key: Vec<u8>, max: i64) -> Self {
        self.declared_keys.insert(key.clone());
        self.preflight(move |ctx| {
            let value = ctx.get(&key)?.as_i64()?;
            if value > max {
                return Err(AzothError::PreflightFailed(format!(
                    "Value {} > maximum {}",
                    value, max
                )));
            }
            Ok(())
        })
    }

    /// Execute the transaction, guarding against misuse from async contexts.
    ///
    /// Blocking on key locks inside an async task can deadlock the runtime,
    /// so this panics in debug builds (and logs in release) when called with
    /// a Tokio runtime context active. The actual work is delegated to
    /// [`Transaction::execute_blocking`] so the two paths cannot drift apart.
    pub fn execute<F>(self, f: F) -> Result<CommitInfo>
    where
        F: for<'b> FnOnce(&mut TransactionContext<'b>) -> Result<()>,
    {
        if tokio::runtime::Handle::try_current().is_ok() {
            let msg = "Transaction::execute() called from async context! \
                This can cause deadlocks. Use AsyncTransaction instead.";
            #[cfg(debug_assertions)]
            {
                panic!("{}", msg);
            }
            #[cfg(not(debug_assertions))]
            {
                tracing::error!("{}", msg);
            }
        }
        self.execute_blocking(f)
    }

    /// Execute the transaction without the async-context guard.
    ///
    /// Ordering (behavior-critical): key-limit check, per-key lock
    /// acquisition, preflight validators, write transaction + user closure,
    /// commit, then preflight-cache invalidation of the declared keys.
    pub fn execute_blocking<F>(self, f: F) -> Result<CommitInfo>
    where
        F: for<'b> FnOnce(&mut TransactionContext<'b>) -> Result<()>,
    {
        check_key_limit(self.declared_keys.len())?;
        // Hold locks on every declared key until this function returns.
        let lock_manager = self.db.canonical().lock_manager();
        let keys_vec: Vec<&[u8]> = self.declared_keys.iter().map(|k| k.as_slice()).collect();
        let _locks = lock_manager.acquire_keys(&keys_vec)?;
        let cache = self.db.canonical().preflight_cache();
        let ctx = PreflightContext::new(self.db, cache, &self.declared_keys);
        for validator in self.validators {
            validator(&ctx)?;
        }
        let txn = self.db.canonical().write_txn()?;
        let mut update_ctx = TransactionContext {
            txn,
            declared_keys: &self.declared_keys,
            value_cache: std::cell::RefCell::new(std::collections::HashMap::new()),
        };
        f(&mut update_ctx)?;
        let commit_info = update_ctx.txn.commit()?;
        // Invalidate cached reads only after a successful commit; on any
        // earlier error the store is unchanged, so the cache is still valid.
        let keys_to_invalidate: Vec<Vec<u8>> = self.declared_keys.iter().cloned().collect();
        cache.invalidate_keys(&keys_to_invalidate);
        Ok(commit_info)
    }
}
/// Convenience wrapper: run a keyed transaction on a blocking thread.
///
/// Uses `execute_blocking` rather than `execute`: Tokio's `spawn_blocking`
/// threads still run inside the runtime context, so `Handle::try_current()`
/// succeeds there and `execute()`'s async-context guard would panic in debug
/// builds (and log a spurious error in release) even though blocking here is
/// perfectly safe.
///
/// # Errors
/// Propagates transaction errors; a failed `spawn_blocking` join is wrapped
/// in [`AzothError::Internal`].
pub async fn execute_transaction_async<F>(
    db: Arc<AzothDb>,
    keys: Vec<Vec<u8>>,
    f: F,
) -> Result<CommitInfo>
where
    F: for<'b> FnOnce(&mut TransactionContext<'b>) -> Result<()> + Send + 'static,
{
    tokio::task::spawn_blocking(move || Transaction::new(&db).keys(keys).execute_blocking(f))
        .await
        .map_err(|e| AzothError::Internal(format!("Transaction task failed: {}", e)))?
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Open a fresh database in a temp dir; the `TempDir` guard keeps the
    /// directory alive for the test's duration.
    fn create_test_db() -> (AzothDb, TempDir) {
        let dir = tempfile::tempdir().unwrap();
        let db = AzothDb::open(dir.path()).unwrap();
        (db, dir)
    }

    /// Seed one i64 value through a raw write transaction.
    fn seed_i64(db: &AzothDb, key: &[u8], value: i64) {
        let mut txn = db.canonical().write_txn().unwrap();
        txn.put_state(key, &TypedValue::I64(value).to_bytes().unwrap())
            .unwrap();
        txn.commit().unwrap();
    }

    /// Read back an i64 value through a read transaction.
    fn read_i64(db: &AzothDb, key: &[u8]) -> i64 {
        let txn = db.canonical().read_txn().unwrap();
        let bytes = txn.get_state(key).unwrap().unwrap();
        TypedValue::from_bytes(&bytes).unwrap().as_i64().unwrap()
    }

    #[test]
    fn test_basic_transaction() {
        let (db, _temp) = create_test_db();
        seed_i64(&db, b"balance", 100);
        let result = Transaction::new(&db)
            .keys(vec![b"balance".to_vec()])
            .preflight(|ctx| {
                assert_eq!(ctx.get(b"balance")?.as_i64()?, 100);
                Ok(())
            })
            .execute(|ctx| {
                let balance = ctx.get(b"balance")?.as_i64()?;
                ctx.set(b"balance", &TypedValue::I64(balance - 50))?;
                ctx.log_bytes(b"withdraw:50")?;
                Ok(())
            });
        if let Err(e) = &result {
            eprintln!("Transaction failed: {:?}", e);
        }
        assert!(result.is_ok());
        assert_eq!(read_i64(&db, b"balance"), 50);
    }

    #[test]
    fn test_preflight_failure() {
        let (db, _temp) = create_test_db();
        seed_i64(&db, b"balance", 10);
        let result = Transaction::new(&db)
            .require_min(b"balance".to_vec(), 50)
            .execute(|ctx| {
                ctx.set(b"balance", &TypedValue::I64(0))?;
                Ok(())
            });
        assert!(result.is_err());
        // The failed preflight must leave the stored balance untouched.
        assert_eq!(read_i64(&db, b"balance"), 10);
    }

    #[test]
    fn test_multiple_constraints() {
        let (db, _temp) = create_test_db();
        seed_i64(&db, b"balance", 75);
        let result = Transaction::new(&db)
            .require_exists(b"balance".to_vec())
            .require_min(b"balance".to_vec(), 50)
            .require_max(b"balance".to_vec(), 100)
            .execute(|ctx| {
                ctx.set(b"balance", &TypedValue::I64(60))?;
                Ok(())
            });
        assert!(result.is_ok());
    }

    #[test]
    fn test_multiple_events() {
        let (db, _temp) = create_test_db();
        let result = Transaction::new(&db).execute(|ctx| {
            ctx.log_bytes(b"event1")?;
            ctx.log_bytes(b"event2")?;
            ctx.log_bytes(b"event3")?;
            Ok(())
        });
        assert_eq!(result.unwrap().events_written, 3);
    }

    #[test]
    fn test_undeclared_key_access_in_preflight() {
        let (db, _temp) = create_test_db();
        seed_i64(&db, b"key1", 1);
        seed_i64(&db, b"key2", 2);
        // Only key1 is declared; reading key2 in preflight must fail.
        let result = Transaction::new(&db)
            .keys(vec![b"key1".to_vec()])
            .validate(|ctx| {
                let _ = ctx.get(b"key1")?;
                let _ = ctx.get(b"key2")?;
                Ok(())
            })
            .execute(|_ctx| Ok(()));
        assert!(matches!(
            result,
            Err(AzothError::UndeclaredKeyAccess { .. })
        ));
    }

    #[test]
    fn test_undeclared_key_access_in_execute() {
        let (db, _temp) = create_test_db();
        seed_i64(&db, b"key1", 1);
        seed_i64(&db, b"key2", 2);
        // Only key1 is declared; writing key2 in the body must fail.
        let result = Transaction::new(&db)
            .keys(vec![b"key1".to_vec()])
            .execute(|ctx| {
                let _ = ctx.get(b"key1")?;
                ctx.set(b"key2", &TypedValue::I64(99))?;
                Ok(())
            });
        assert!(matches!(
            result,
            Err(AzothError::UndeclaredKeyAccess { .. })
        ));
    }

    #[test]
    fn test_require_auto_declares_key() {
        let (db, _temp) = create_test_db();
        seed_i64(&db, b"balance", 100);
        // require_min must declare "balance" implicitly — no .keys() call.
        let result = Transaction::new(&db)
            .require_min(b"balance".to_vec(), 50)
            .execute(|ctx| {
                ctx.set(b"balance", &TypedValue::I64(75))?;
                Ok(())
            });
        assert!(result.is_ok());
    }

    #[test]
    fn test_multi_key_transaction() {
        let (db, _temp) = create_test_db();
        seed_i64(&db, b"account_a", 1000);
        seed_i64(&db, b"account_b", 500);
        let result = Transaction::new(&db)
            .keys(vec![b"account_b".to_vec(), b"account_a".to_vec()])
            .validate(|ctx| {
                let balance = ctx.get(b"account_a")?.as_i64()?;
                if balance < 100 {
                    return Err(AzothError::PreflightFailed("Insufficient funds".into()));
                }
                Ok(())
            })
            .execute(|ctx| {
                let a = ctx.get(b"account_a")?.as_i64()?;
                let b = ctx.get(b"account_b")?.as_i64()?;
                ctx.set(b"account_a", &TypedValue::I64(a - 100))?;
                ctx.set(b"account_b", &TypedValue::I64(b + 100))?;
                ctx.log(
                    "transfer",
                    &serde_json::json!({"from": "a", "to": "b", "amount": 100}),
                )?;
                Ok(())
            });
        assert!(result.is_ok());
        assert_eq!(read_i64(&db, b"account_a"), 900);
        assert_eq!(read_i64(&db, b"account_b"), 600);
    }
}