use std::collections::BTreeMap;
use crate::db::Emdb;
use crate::storage::Op;
#[cfg(feature = "ttl")]
use crate::ttl::{expires_from_ttl, is_expired, now_unix_millis, record_expires_at, Ttl};
use crate::ttl::{record_new, record_value, Record};
use crate::{Error, Result};
/// A set of staged writes against an [`Emdb`].
///
/// Writes are buffered in two places: `pending`, the ordered list of log
/// operations, and `overlay`, the latest staged state for each key. Reads made
/// through the transaction consult the overlay before the database's committed
/// storage. Nothing is written to the backend or to committed storage until
/// `commit` runs.
pub struct Transaction<'db> {
/// The database this transaction reads from and commits into; borrowed
/// exclusively for the lifetime of the transaction.
pub(crate) db: &'db mut Emdb,
/// Staged operations in call order, appended to the backend between
/// `BatchBegin`/`BatchEnd` markers on commit.
pending: Vec<Op>,
/// Latest staged state per key: `Some(record)` for a staged insert, `None`
/// for a staged removal.
overlay: BTreeMap<Vec<u8>, Option<Record>>,
}
impl<'db> Transaction<'db> {
    /// Start an empty transaction bound to `db`.
    ///
    /// No operations are staged and nothing is written until `commit` runs.
    pub(crate) fn new(db: &'db mut Emdb) -> Self {
        Self {
            db,
            pending: Vec::new(),
            overlay: BTreeMap::new(),
        }
    }

    /// Stage an insert of `value` under `key`.
    ///
    /// The write is visible to [`Self::get`] on this transaction immediately.
    /// It reaches the backend and the database's committed storage only when
    /// the transaction is committed.
    ///
    /// With the `ttl` feature enabled, this is `insert_with_ttl` with
    /// `Ttl::Default`. Without it, the record is staged with no expiry.
    ///
    /// # Errors
    ///
    /// With the `ttl` feature, propagates any error from resolving the expiry
    /// (see `insert_with_ttl`). Without it, this never fails.
    pub fn insert(&mut self, key: impl Into<Vec<u8>>, value: impl Into<Vec<u8>>) -> Result<()> {
        #[cfg(feature = "ttl")]
        {
            self.insert_with_ttl(key, value, Ttl::Default)
        }
        #[cfg(not(feature = "ttl"))]
        {
            let key = key.into();
            let value = value.into();
            // The overlay serves reads inside the transaction, and `pending`
            // feeds the log on commit, so each needs its own copy.
            let _old = self
                .overlay
                .insert(key.clone(), Some(record_new(value.clone(), None)));
            self.pending.push(Op::Insert {
                key,
                value,
                expires_at: None,
            });
            Ok(())
        }
    }

    /// Look up the value currently visible for `key` in this transaction.
    ///
    /// Staged writes take precedence. A key staged for insertion yields the
    /// staged value, and a key staged for removal yields `None`, whatever
    /// committed storage holds. Keys this transaction has not touched are read
    /// from the database's committed storage. With the `ttl` feature enabled,
    /// an expired record is reported as `None`.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok`.
    pub fn get(&self, key: impl AsRef<[u8]>) -> Result<Option<Vec<u8>>> {
        let key = key.as_ref();
        // An overlay hit, including a staged removal (`None`), shadows
        // committed storage.
        match self.overlay.get(key) {
            Some(staged) => self.visible_record(staged.as_ref()),
            None => self.visible_record(self.db.storage.get(key)),
        }
    }

    /// Stage removal of `key`, returning the value that was visible before.
    ///
    /// The returned value is exactly what [`Self::get`] reported before the
    /// call. Afterwards, `get` on this transaction returns `None` for `key`.
    ///
    /// An `Op::Remove` is staged even when no value was visible (for example,
    /// when the key is absent or, with `ttl`, expired). On commit, the key is
    /// therefore removed from committed storage whenever it is present there.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::get`] (currently none).
    pub fn remove(&mut self, key: impl AsRef<[u8]>) -> Result<Option<Vec<u8>>> {
        let key = key.as_ref();
        // Reuse `get` so the read-visibility rules live in one place.
        let previous = self.get(key)?;
        let key = key.to_vec();
        let _old = self.overlay.insert(key.clone(), None);
        self.pending.push(Op::Remove { key });
        Ok(previous)
    }

    /// Return whether `key` currently has a visible value in this transaction.
    ///
    /// Equivalent to `self.get(key)?.is_some()`, so staged writes and (with
    /// `ttl`) expiry are taken into account.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::get`].
    pub fn contains_key(&self, key: impl AsRef<[u8]>) -> Result<bool> {
        Ok(self.get(key)?.is_some())
    }

    /// Stage an insert of `value` under `key` with an explicit time-to-live.
    ///
    /// The expiry is computed once, at staging time, by `expires_from_ttl`.
    /// Its inputs are `ttl`, the database's `default_ttl`, and the current
    /// time from `now_unix_millis`. The same expiry is stored in the overlay
    /// record and in the staged `Op::Insert`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `expires_from_ttl`.
    #[cfg(feature = "ttl")]
    pub fn insert_with_ttl(
        &mut self,
        key: impl Into<Vec<u8>>,
        value: impl Into<Vec<u8>>,
        ttl: Ttl,
    ) -> Result<()> {
        let key = key.into();
        let value = value.into();
        let now = now_unix_millis();
        let expires_at = expires_from_ttl(ttl, self.db.default_ttl, now)?;
        let _old = self
            .overlay
            .insert(key.clone(), Some(record_new(value.clone(), expires_at)));
        self.pending.push(Op::Insert {
            key,
            value,
            expires_at,
        });
        Ok(())
    }

    /// Write the staged operations to the backend and apply them to the
    /// database's committed storage.
    ///
    /// Steps, in order:
    /// 1. Allocate `tx_id = last_tx_id + 1` and count the staged operations.
    /// 2. Append `BatchBegin { tx_id, op_count }` to the backend, then every
    ///    staged op in call order, then `BatchEnd { tx_id }`.
    /// 3. Record `tx_id` through the backend's `set_last_tx_id`.
    /// 4. Apply the overlay (the final state per key) to in-memory storage and
    ///    bump the in-memory `last_tx_id`.
    ///
    /// An empty transaction still writes a `BatchBegin`/`BatchEnd` pair and
    /// consumes a transaction id.
    ///
    /// # Errors
    ///
    /// - `Error::TransactionAborted("transaction id overflow")` if
    ///   `last_tx_id` is already at its maximum value.
    /// - `Error::TransactionAborted("operation count overflow")` if more than
    ///   `u32::MAX` operations were staged.
    /// - Any error returned by the backend's `append` or `set_last_tx_id`.
    ///
    /// A backend failure can happen after `BatchBegin` was appended. In that
    /// case the staged op list has already been drained, and part or all of
    /// the batch may already be in the log. In every failure case, in-memory
    /// storage and `last_tx_id` are left unchanged, so a later commit would
    /// reuse the same `tx_id`.
    /// NOTE(review): whether recovery discards a batch that lacks `BatchEnd`
    /// depends on the backend; confirm.
    pub(crate) fn commit(&mut self) -> Result<()> {
        let tx_id = self
            .db
            .last_tx_id
            .checked_add(1)
            .ok_or(Error::TransactionAborted("transaction id overflow"))?;
        let op_count = u32::try_from(self.pending.len())
            .map_err(|_overflow| Error::TransactionAborted("operation count overflow"))?;

        // Log first: the batch is framed by begin/end markers carrying `tx_id`.
        self.db
            .backend
            .append(&Op::BatchBegin { tx_id, op_count })?;
        for op in std::mem::take(&mut self.pending) {
            self.db.backend.append(&op)?;
        }
        self.db.backend.append(&Op::BatchEnd { tx_id })?;
        self.db.backend.set_last_tx_id(tx_id)?;

        // Only after the backend accepted the whole batch is in-memory state
        // touched.
        for (key, staged) in std::mem::take(&mut self.overlay) {
            match staged {
                Some(record) => {
                    let _old = self.db.storage.insert(key, record);
                }
                None => {
                    let _old = self.db.storage.remove(&key);
                }
            }
        }
        self.db.last_tx_id = tx_id;
        Ok(())
    }

    /// Convert an optional record into the value a reader should see.
    ///
    /// `None` stays `None`. With the `ttl` feature enabled, a record whose
    /// expiry has passed (checked against the current time) also maps to
    /// `None`. Otherwise the record's value is copied out. Always returns `Ok`.
    fn visible_record(&self, maybe_record: Option<&Record>) -> Result<Option<Vec<u8>>> {
        let Some(record) = maybe_record else {
            return Ok(None);
        };
        #[cfg(feature = "ttl")]
        {
            let now = now_unix_millis();
            if is_expired(record_expires_at(record), now) {
                return Ok(None);
            }
        }
        Ok(Some(record_value(record).to_vec()))
    }
}
#[cfg(test)]
mod tests {
    use crate::Emdb;

    /// Inserts staged in a successful transaction are readable afterwards.
    #[test]
    fn transaction_commit_applies_overlay() {
        let mut db = Emdb::open_in_memory();
        assert!(db
            .transaction(|tx| {
                tx.insert("a", "1")?;
                tx.insert("b", "2")?;
                Ok(())
            })
            .is_ok());
        assert!(matches!(db.get("a"), Ok(Some(v)) if v == b"1"));
        assert!(matches!(db.get("b"), Ok(Some(v)) if v == b"2"));
    }

    /// A closure that returns an error leaves its staged insert unapplied.
    #[test]
    fn transaction_rollback_discards_overlay() {
        let mut db = Emdb::open_in_memory();
        let outcome = db.transaction::<_, ()>(|tx| {
            tx.insert("a", "1")?;
            Err(crate::Error::TransactionAborted("rollback"))
        });
        assert!(outcome.is_err());
        assert!(matches!(db.get("a"), Ok(None)));
    }

    /// `remove` reports the committed value and hides the key for the rest
    /// of the transaction. The removal persists after commit.
    #[test]
    fn transaction_remove_reads_from_overlay() {
        let mut db = Emdb::open_in_memory();
        assert!(db.insert("a", "1").is_ok());
        assert!(db
            .transaction(|tx| {
                let previous = tx.remove("a")?;
                assert!(matches!(previous, Some(v) if v == b"1"));
                assert!(matches!(tx.get("a"), Ok(None)));
                Ok(())
            })
            .is_ok());
        assert!(matches!(db.get("a"), Ok(None)));
    }
}