use std::collections::BTreeMap;
use crate::db::Emdb;
use crate::storage::{Op, OpRef, Storage};
#[cfg(feature = "ttl")]
use crate::ttl::{expires_from_ttl, is_expired, now_unix_millis, record_expires_at, Ttl};
use crate::ttl::{record_new, record_value, Record};
use crate::{Error, Result};
pub struct Transaction<'db> {
db: &'db Emdb,
pending: Vec<Op>,
overlay: BTreeMap<Vec<u8>, Option<Record>>,
}
impl<'db> Transaction<'db> {
pub(crate) fn new(db: &'db Emdb) -> Result<Self> {
Ok(Self {
db,
pending: Vec::new(),
overlay: BTreeMap::new(),
})
}
pub fn insert(&mut self, key: impl Into<Vec<u8>>, value: impl Into<Vec<u8>>) -> Result<()> {
#[cfg(feature = "ttl")]
{
self.insert_with_ttl(key, value, Ttl::Default)
}
#[cfg(not(feature = "ttl"))]
{
let key = key.into();
let value = value.into();
let _old = self
.overlay
.insert(key.clone(), Some(record_new(value.clone(), None)));
self.pending.push(Op::Insert {
key,
value,
expires_at: None,
});
Ok(())
}
}
pub fn get(&self, key: impl AsRef<[u8]>) -> Result<Option<Vec<u8>>> {
let key = key.as_ref();
if let Some(entry) = self.overlay.get(key) {
return self.visible_record(entry.as_ref());
}
let shard = self.db.shard_for(key)?;
let Some(record) = shard.get(key) else {
return Ok(None);
};
let cloned = record.clone();
drop(shard);
self.visible_record(Some(&cloned))
}
pub fn remove(&mut self, key: impl AsRef<[u8]>) -> Result<Option<Vec<u8>>> {
let key_vec = key.as_ref().to_vec();
let previous = self.get(key.as_ref())?;
let _old = self.overlay.insert(key_vec.clone(), None);
self.pending.push(Op::Remove { key: key_vec });
Ok(previous)
}
pub fn contains_key(&self, key: impl AsRef<[u8]>) -> Result<bool> {
Ok(self.get(key)?.is_some())
}
#[cfg(feature = "ttl")]
pub fn insert_with_ttl(
&mut self,
key: impl Into<Vec<u8>>,
value: impl Into<Vec<u8>>,
ttl: Ttl,
) -> Result<()> {
let key = key.into();
let value = value.into();
let now = now_unix_millis();
let expires_at = expires_from_ttl(ttl, self.db.inner.config.default_ttl, now)?;
let _old = self
.overlay
.insert(key.clone(), Some(record_new(value.clone(), expires_at)));
self.pending.push(Op::Insert {
key,
value,
expires_at,
});
Ok(())
}
pub(crate) fn commit(&mut self) -> Result<()> {
let writes = std::mem::take(&mut self.pending);
let updates = std::mem::take(&mut self.overlay);
if writes.is_empty() && updates.is_empty() {
return Ok(());
}
let op_count = u32::try_from(writes.len())
.map_err(|_overflow| Error::TransactionAborted("operation count overflow"))?;
let tx_id = self.db.next_tx_id()?;
let mut backend_guard = self.db.lock_backend()?;
if let Some(backend) = backend_guard.as_mut() {
backend.append(OpRef::BatchBegin { tx_id, op_count })?;
for op in &writes {
backend.append(OpRef::from(op))?;
}
backend.append(OpRef::BatchEnd { tx_id })?;
backend.set_last_tx_id(tx_id)?;
}
let mut shards = self.db.index().write_all()?;
for (key, maybe_record) in updates {
let shard_idx = crate::index::Index::shard_for_key(&key);
let shard = match shards.get_mut(shard_idx) {
Some(shard) => shard,
None => return Err(Error::TransactionAborted("shard index out of range")),
};
match maybe_record {
Some(record) => {
let _old = shard.insert(key, record);
}
None => {
let _old = shard.remove(&key);
}
}
}
drop(shards);
drop(backend_guard);
Ok(())
}
fn visible_record(&self, maybe_record: Option<&Record>) -> Result<Option<Vec<u8>>> {
let Some(record) = maybe_record else {
return Ok(None);
};
#[cfg(feature = "ttl")]
{
let now = now_unix_millis();
if is_expired(record_expires_at(record), now) {
return Ok(None);
}
}
Ok(Some(record_value(record).to_vec()))
}
}
#[cfg(test)]
mod tests {
    use crate::Emdb;

    /// Inserts made inside a successful transaction are visible afterwards.
    #[test]
    fn transaction_commit_applies_overlay() {
        let db = Emdb::open_in_memory();
        let outcome = db.transaction(|tx| {
            for (key, value) in [("a", "1"), ("b", "2")] {
                tx.insert(key, value)?;
            }
            Ok(())
        });
        assert!(outcome.is_ok());
        assert_eq!(db.get("a").ok().flatten(), Some(b"1".to_vec()));
        assert_eq!(db.get("b").ok().flatten(), Some(b"2".to_vec()));
    }

    /// A closure that returns an error leaves the database untouched.
    #[test]
    fn transaction_rollback_discards_overlay() {
        let db = Emdb::open_in_memory();
        let outcome = db.transaction::<_, ()>(|tx| {
            tx.insert("a", "1")?;
            Err(crate::Error::TransactionAborted("rollback"))
        });
        assert!(outcome.is_err());
        assert_eq!(db.get("a").ok(), Some(None));
    }

    /// `remove` reports the stored value, then the key reads as absent both
    /// inside the transaction and after commit.
    #[test]
    fn transaction_remove_reads_from_overlay() {
        let db = Emdb::open_in_memory();
        assert!(db.insert("a", "1").is_ok());
        let outcome = db.transaction(|tx| {
            let removed = tx.remove("a")?;
            assert_eq!(removed, Some(b"1".to_vec()));
            assert_eq!(tx.get("a").ok(), Some(None));
            Ok(())
        });
        assert!(outcome.is_ok());
        assert_eq!(db.get("a").ok(), Some(None));
    }
}