use super::ValRef;
use crate::{
OpCode,
cc::{
cc::ConcurrencyControl,
context::{CCNode, Context, TxOutcome},
group::TxnState,
wal::{WalDel, WalPut, WalReplace},
},
index::tree::{Iter, Tree},
map::flow::ForegroundWritePermit,
types::data::{Key, Record, Ver},
utils::{
Handle, NULL_CMD,
data::Position,
observe::{
CounterMetric, EventKind, HistogramMetric, LATENCY_SAMPLE_SHIFT, ObserveEvent,
observe_elapsed, sampled_instant,
},
},
};
use crossbeam_epoch::Guard;
use std::cell::{Cell, UnsafeCell};
use std::ops::RangeBounds;
use std::sync::atomic::Ordering::Relaxed;
/// Point lookup of `k` in `tree` for a reader whose snapshot timestamp is `start_ts`.
///
/// The probe key is versioned `(start_ts, NULL_CMD)`. The closure handed to
/// `Tree::traverse` forwards each `(txid, record_gid)` pair to `cc.is_visible_to`
/// on behalf of `group_id`.
///
/// # Errors
/// Propagates whatever `Tree::traverse` reports.
fn get_impl<K: AsRef<[u8]>>(
    ctx: &Context,
    cc: &ConcurrencyControl,
    tree: &Tree,
    group_id: u8,
    start_ts: u64,
    k: K,
) -> Result<ValRef, OpCode> {
    let raw = k.as_ref();
    #[cfg(feature = "extra_check")]
    assert!(!raw.is_empty(), "key must be non-empty");
    // Keep the epoch pinned for as long as the traversal runs.
    let guard = crossbeam_epoch::pin();
    let probe = Key::new(raw, Ver::new(start_ts, NULL_CMD));
    Ok(tree.traverse(&guard, probe, |txid, record_gid| {
        cc.is_visible_to(ctx, group_id, record_gid, start_ts, txid)
    })?)
}
fn seek_impl<'a, K>(
cc: &'a ConcurrencyControl,
tree: &'a Tree,
group_id: u8,
start_ts: u64,
prefix: K,
) -> Iter<'a>
where
K: AsRef<[u8]>,
{
let b = prefix.as_ref();
#[cfg(feature = "extra_check")]
assert!(!b.is_empty(), "prefix can't be empty");
let upper = prefix_upper_exclusive(b);
if let Some(ref upper) = upper {
tree.range(b..upper.as_slice(), move |ctx, txid, record_gid| {
cc.is_visible_to(ctx, group_id, record_gid, start_ts, txid)
})
} else {
tree.range(b.., move |ctx, txid, record_gid| {
cc.is_visible_to(ctx, group_id, record_gid, start_ts, txid)
})
}
}
/// Builds an iterator over `range` in `tree` for a reader with snapshot timestamp `start_ts`.
///
/// `range` is handed to `Tree::range` untouched. The accompanying closure forwards
/// `(ctx, txid, record_gid)` to `cc.is_visible_to` together with `group_id` and `start_ts`.
fn range_impl<'a, K, R>(
cc: &'a ConcurrencyControl,
tree: &'a Tree,
group_id: u8,
start_ts: u64,
range: R,
) -> Iter<'a>
where
K: AsRef<[u8]>,
R: RangeBounds<K>,
{
tree.range(range, move |ctx, txid, record_gid| {
cc.is_visible_to(ctx, group_id, record_gid, start_ts, txid)
})
}
/// Computes the exclusive upper bound of the key range sharing `prefix`.
///
/// The bound is `prefix` cut after its last byte that is not `0xff`, with that byte
/// incremented. Returns `None` when no such byte exists (empty input or only `0xff`
/// bytes), meaning the range has no finite upper bound.
fn prefix_upper_exclusive(prefix: &[u8]) -> Option<Vec<u8>> {
    // Rightmost byte that can be bumped without overflowing.
    let pivot = prefix.iter().rposition(|&byte| byte != u8::MAX)?;
    let mut upper = prefix[..=pivot].to_vec();
    upper[pivot] += 1;
    Some(upper)
}
/// A read-write transaction bound to a single `Tree`.
///
/// Built by `TxnKV::new`. It ends either through `commit` or, when dropped while
/// still open, through the abort path in its `Drop` implementation.
pub struct TxnKV<'a> {
/// Shared context: timestamp oracle, groups, observer and abort bookkeeping.
ctx: &'a Context,
/// Per-transaction state; mutated through `&self` via `state_mut`, hence the `UnsafeCell`.
state: UnsafeCell<TxnState>,
/// Tree that every read and write of this transaction goes to.
tree: &'a Tree,
/// Cached `tree.bucket_id()`, passed to WAL records and observer events.
bucket_id: u64,
/// Set once the transaction has been committed or aborted.
is_end: Cell<bool>,
/// Copy of the `max_ckpt_per_txn` option; `should_abort` rejects the transaction once the
/// group's checkpoint counter has advanced by at least this much since the transaction began.
limit: usize,
}
/// Why the newest version of a key could not serve as the base of a write.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum FailCause {
/// The newest version was written by a transaction recorded as aborted; the write paths
/// respond by calling `clean_aborted` and retrying.
Aborted,
/// The newest version is not visible to this transaction and its writer is not recorded
/// as aborted; reported to the caller as `OpCode::AbortTx` via `conflict_abort`.
Conflict,
}
impl<'a> TxnKV<'a> {
/// Starts a read-write transaction on `tree`.
///
/// Allocates a start timestamp and a group id, writes a begin record to that group's
/// log and registers the transaction in the group's `active_txns`.
///
/// # Errors
/// Returns the error from `record_begin`. Before returning, `leave_inflight` is called
/// on the group and the bucket's transaction reference taken here is released.
pub(crate) fn new(ctx: &'a Context, tree: &'a Tree) -> Result<Self, OpCode> {
let start_ts = ctx.alloc_oracle();
// NOTE(review): presumably `next_group_id` also marks the group as in-flight, since the
// error path below and `Drop` both call `leave_inflight` — confirm.
let gid = ctx.next_group_id();
let g = ctx.group(gid);
// Baseline for the checkpoint budget enforced by `should_abort`.
let start_ckpt = g.ckpt_cnt.load(Relaxed);
let mut state = TxnState::new(gid, start_ts, start_ckpt);
let bucket_id = tree.bucket_id();
let max_ckpt_per_txn = tree.store.opt.max_ckpt_per_txn;
// Balanced by `dec_txn_ref` in the error path below or in `Drop`.
tree.bucket.state.inc_txn_ref();
let begin_res = {
// The log lock is held while the begin record is written and the transaction is
// registered as active; it is released at the end of this block.
let mut log = g.logging.lock();
log.record_begin(start_ts).map(|lsn| {
state.begin_lsn = lsn;
state.prev_lsn = lsn;
g.active_txns.insert(start_ts, state.prev_lsn);
})
};
if let Err(e) = begin_res {
// No `TxnKV` is constructed, so `Drop` will not run: undo the bookkeeping by hand.
g.leave_inflight();
tree.bucket.state.dec_txn_ref();
return Err(e);
}
ctx.opt.observer.counter(CounterMetric::TxnBegin, 1);
g.cc.commit_tree.compact(ctx);
Ok(Self {
ctx,
state: UnsafeCell::new(state),
tree,
bucket_id,
is_end: Cell::new(false),
limit: max_ckpt_per_txn,
})
}
/// Gatekeeper run before every write and before commit.
///
/// Yields `OpCode::AbortTx` when the transaction has already ended, or when the
/// group's checkpoint counter has moved `limit` or more past the value sampled at
/// transaction start.
fn should_abort(&self) -> Result<(), OpCode> {
    let state = self.state_ref();
    let group = self.ctx.group(state.group_id);
    // Evaluated lazily so the counter is only read while the transaction is still open.
    let over_budget = || group.ckpt_cnt.load(Relaxed) - state.start_ckpt >= self.limit;
    if self.is_end.get() || over_budget() {
        Err(OpCode::AbortTx)
    } else {
        Ok(())
    }
}
/// Shared view of the per-transaction state stored in the `UnsafeCell`.
#[inline]
fn state_ref(&self) -> &TxnState {
// SAFETY: NOTE(review) — sound only while no `&mut TxnState` obtained from `state_mut`
// is live at the same time; nothing in this function enforces that, confirm at call sites.
unsafe { &*self.state.get() }
}
/// Mutable view of the per-transaction state, handed out from `&self`.
#[inline]
// Interior mutability through `UnsafeCell` is intentional here, hence the lint exemption.
#[allow(clippy::mut_from_ref)]
fn state_mut(&self) -> &mut TxnState {
// SAFETY: NOTE(review) — sound only while no other reference from `state_ref` or
// `state_mut` is live at the same time; nothing in this function enforces that,
// confirm at call sites.
unsafe { &mut *self.state.get() }
}
/// Adds `delta` to `metric` on the observer configured in `ctx.opt`.
#[inline]
fn observe_counter(&self, metric: CounterMetric, delta: u64) {
self.ctx.opt.observer.counter(metric, delta);
}
/// Forwards `event` to the observer configured in `ctx.opt`.
#[inline]
fn observe_event(&self, event: ObserveEvent) {
self.ctx.opt.observer.event(event);
}
/// Asks the bucket for a foreground write permit sized by `estimated_bytes`.
///
/// Callers bind the returned permit to a local for the duration of the tree update.
#[inline]
fn before_write_budget(&self, estimated_bytes: usize) -> ForegroundWritePermit {
self.tree
.bucket
.before_foreground_write(estimated_bytes as u64)
}
/// Reports a write-write conflict to the observer and returns `OpCode::AbortTx`.
///
/// Bumps `CounterMetric::TxnConflictAbort` and emits an `EventKind::TxnConflictAbort`
/// event carrying this transaction's bucket id and the given `txid`.
#[inline]
fn conflict_abort(&self, txid: u64) -> OpCode {
self.observe_counter(CounterMetric::TxnConflictAbort, 1);
self.observe_event(ObserveEvent {
kind: EventKind::TxnConflictAbort,
bucket_id: self.bucket_id,
txid,
// Not used for this event kind; filled with zeroes.
file_id: 0,
value: 0,
});
OpCode::AbortTx
}
/// Maps a `FailCause` to the error code returned from a write closure.
///
/// Both causes yield `OpCode::AbortTx`; only `Conflict` is also reported to the
/// observer (through `conflict_abort`, with `start_ts` as the event's `txid`).
#[inline]
fn write_abort(&self, start_ts: u64, cause: FailCause) -> OpCode {
match cause {
FailCause::Aborted => OpCode::AbortTx,
FailCause::Conflict => self.conflict_abort(start_ts),
}
}
/// Asks the concurrency control of group `group_id` whether a record written by `txid`
/// in group `record_gid` is visible to a transaction that started at `start_ts`.
#[inline]
fn is_visible_for_write(
&self,
group_id: usize,
start_ts: u64,
txid: u64,
record_gid: u8,
) -> bool {
self.ctx.group(group_id).cc.is_visible_to(
self.ctx,
// NOTE(review): narrowing cast; assumes group ids always fit in a `u8` — confirm.
group_id as u8,
record_gid,
start_ts,
txid,
)
}
/// Decides whether the newest stored version of a key may serve as the base of a write.
///
/// * no stored version: `Ok(None)`;
/// * newest version visible to this transaction: `Ok(Some(value))` (a clone of it);
/// * newest version written by a transaction recorded as aborted: `Err(FailCause::Aborted)`;
/// * anything else: `Err(FailCause::Conflict)`.
fn resolve_latest_for_write(
    &self,
    opt: &Option<(Key, ValRef)>,
    state: &TxnState,
) -> Result<Option<ValRef>, FailCause> {
    match opt {
        None => Ok(None),
        Some((head_key, head_val)) => {
            let writer = head_key.txid;
            if self.is_visible_for_write(
                state.group_id,
                state.start_ts,
                writer,
                head_val.group_id(),
            ) {
                Ok(Some(head_val.clone()))
            } else if self.ctx.is_aborted(writer)
                && self.ctx.get_aborted(writer) == Some(TxOutcome::Aborted)
            {
                Err(FailCause::Aborted)
            } else {
                Err(FailCause::Conflict)
            }
        }
    }
}
/// Tries to remove the newest version of `raw` if an aborted transaction wrote it.
///
/// Returns `Ok(true)` only when `remove_aborted_head` reports that it removed the
/// version. A missing key, a lookup or removal that answers `OpCode::Again`, and a
/// head whose writer is not recorded as aborted all yield `Ok(false)`. Every other
/// error is passed through.
fn clean_aborted(&self, g: &Guard, raw: &[u8]) -> Result<bool, OpCode> {
    // Probe with the largest possible version to land on the newest stored one.
    let probe = Key::new(raw, Ver::new(u64::MAX, u32::MAX));
    let head_txid = match self.tree.get(g, probe) {
        Ok((head, _)) => head.ver().txid,
        Err(OpCode::NotFound | OpCode::Again) => return Ok(false),
        Err(e) => return Err(e),
    };
    if self.ctx.get_aborted(head_txid) != Some(TxOutcome::Aborted) {
        return Ok(false);
    }
    let outcome = self.tree.remove_aborted_head(g, raw, head_txid);
    // The guard is flushed after a successful removal and after an `Again` answer.
    if matches!(outcome, Ok(true) | Err(OpCode::Again)) {
        g.flush();
    }
    match outcome {
        Ok(removed) => Ok(removed),
        Err(OpCode::Again) => Ok(false),
        Err(e) => Err(e),
    }
}
/// Shared write loop behind `put`, `update` and `upsert`.
///
/// Each attempt checks `should_abort`, takes the next command id, and calls
/// `Tree::update` with a normal record for `k`/`v`. The callback `f` receives the
/// newest stored version, this attempt's `Ver`, the transaction state and a
/// `FailCause` slot to fill when it rejects the write.
///
/// When the update fails with `OpCode::AbortTx` and `f` reported
/// `FailCause::Aborted`, the aborted head is cleaned up via `clean_aborted` and the
/// attempt is repeated. Every other result is returned to the caller unchanged.
fn modify<F>(
    &self,
    k: &[u8],
    v: &[u8],
    estimated_bytes: usize,
    mut f: F,
) -> Result<Option<ValRef>, OpCode>
where
    F: FnMut(
        &Option<(Key, ValRef)>,
        Ver,
        &mut TxnState,
        &mut FailCause,
    ) -> Result<(u8, Position), OpCode>,
{
    #[cfg(feature = "extra_check")]
    assert!(!k.as_ref().is_empty(), "key must be non-empty");
    loop {
        self.should_abort()?;
        let guard = crossbeam_epoch::pin();
        let state = self.state_mut();
        let (start_ts, gid) = (state.start_ts, state.group_id);
        // Every attempt, retries included, consumes a fresh command id.
        let cmd_id = state.cmd_id;
        state.cmd_id += 1;
        let key = Key::new(k, Ver::new(start_ts, cmd_id));
        let val = Record::normal(gid as u8, v);
        let _write_permit = self.before_write_budget(estimated_bytes);
        // Stays `Conflict` unless the callback overrides it.
        let mut cause = FailCause::Conflict;
        let outcome = self.tree.update(&guard, key, val, |opt| {
            f(opt, *key.ver(), state, &mut cause)
        });
        let blocked_by_aborted_head =
            matches!(outcome, Err(OpCode::AbortTx)) && cause == FailCause::Aborted;
        if !blocked_by_aborted_head {
            return outcome;
        }
        let _ = self.clean_aborted(&guard, k)?;
    }
}
/// Insert-only write: stores `v` under `k` unless a live value is already visible.
///
/// Fails with `OpCode::Exist` when the version visible to this transaction is a put.
/// A visible delete marker, or no version at all, lets the insert through. The WAL
/// record is written at most once per call, tracked through `logged`.
fn put_impl(&self, k: &[u8], v: &[u8], logged: &mut bool) -> Result<(), OpCode> {
    let estimated = k.len().saturating_add(v.len());
    let applied = self.modify(k, v, estimated, |opt, ver, state, abort_cause| {
        let gid = state.group_id;
        let group = self.ctx.group(gid);
        let current = self.resolve_latest_for_write(opt, state).map_err(|cause| {
            // Let `modify` know why the write was refused.
            *abort_cause = cause;
            self.write_abort(state.start_ts, cause)
        })?;
        if current.is_some_and(|cur| cur.is_put()) {
            return Err(OpCode::Exist);
        }
        if !*logged {
            *logged = true;
            state.modified = true;
            let mut log = group.logging.lock();
            state.prev_lsn = log.record_update(
                &Key::new(k, ver),
                WalPut::new(v.len()),
                v,
                state.prev_lsn,
                self.bucket_id,
            )?;
        }
        Ok((gid as u8, state.prev_lsn))
    });
    applied.map(|_| ())
}
/// Replace-only write: overwrites the live value of `k` with `v` and returns the old value.
///
/// # Errors
/// * `OpCode::NotFound` when no version is visible to this transaction, or the visible
///   version is a delete marker.
/// * `OpCode::AbortTx` when the newest version conflicts (see `resolve_latest_for_write`).
/// * Any error from writing the WAL record.
///
/// The WAL record is written at most once per call, tracked through `logged`.
fn update_impl(&self, k: &[u8], v: &[u8], logged: &mut bool) -> Result<ValRef, OpCode> {
    // Old and new value are both accounted for, hence the doubled value length.
    let estimated = k.len().saturating_add(v.len().saturating_mul(2));
    let mut old_visible = None;
    self.modify(k, v, estimated, |opt, ver, state, abort_cause| {
        let gid = state.group_id;
        let g = self.ctx.group(gid);
        let current = match self.resolve_latest_for_write(opt, state) {
            Ok(current) => current,
            Err(cause) => {
                *abort_cause = cause;
                return Err(self.write_abort(state.start_ts, cause));
            }
        };
        // Both "nothing visible" and "visible delete marker" mean there is nothing to replace.
        let Some(current) = current.filter(|cur| !cur.is_del()) else {
            return Err(OpCode::NotFound);
        };
        // `current` is owned and not needed afterwards: move it instead of cloning.
        old_visible = Some(current);
        if !*logged {
            state.modified = true;
            *logged = true;
            let mut log = g.logging.lock();
            let new_pos = log.record_update(
                &Key::new(k, ver),
                WalReplace::new(v.len()),
                v,
                state.prev_lsn,
                self.bucket_id,
            )?;
            state.prev_lsn = new_pos;
        }
        Ok((gid as u8, state.prev_lsn))
    })
    .map(|_| old_visible.expect("visible value must exist for update"))
}
/// Inserts `v` under `k`.
///
/// Thin wrapper over `put_impl` with a fresh "already logged" flag; see `put_impl` for
/// the exact rules (notably `OpCode::Exist` when a live value is already visible).
pub fn put<K, V>(&self, k: K, v: V) -> Result<(), OpCode>
where
K: AsRef<[u8]>,
V: AsRef<[u8]>,
{
let mut logged = false;
self.put_impl(k.as_ref(), v.as_ref(), &mut logged)
}
/// Replaces the live value of `k` with `v` and returns the value that was visible before.
///
/// Thin wrapper over `update_impl` with a fresh "already logged" flag; see `update_impl`
/// for the exact rules (notably `OpCode::NotFound` when there is nothing to replace).
pub fn update<K, V>(&self, k: K, v: V) -> Result<ValRef, OpCode>
where
K: AsRef<[u8]>,
V: AsRef<[u8]>,
{
let mut logged = false;
self.update_impl(k.as_ref(), v.as_ref(), &mut logged)
}
/// Writes `v` under `k` whether or not a version already exists.
///
/// Returns the version that was visible to this transaction before the write, or
/// `None` when there was none. No `is_del` filter is applied, so a visible delete
/// marker comes back as `Some` as well.
///
/// The WAL record is a `WalReplace` when a visible version existed and a `WalPut`
/// otherwise; it is written at most once per call.
///
/// # Errors
/// `OpCode::AbortTx` on conflict or when `should_abort` rejects the transaction, plus
/// any error from writing the WAL record.
pub fn upsert<K, V>(&self, k: K, v: V) -> Result<Option<ValRef>, OpCode>
where
    K: AsRef<[u8]>,
    V: AsRef<[u8]>,
{
    let mut logged = false;
    let mut old_visible = None;
    let (k, v) = (k.as_ref(), v.as_ref());
    // Old and new value are both accounted for, hence the doubled value length.
    let estimated = k.len().saturating_add(v.len().saturating_mul(2));
    self.modify(k, v, estimated, |opt, ver, state, abort_cause| {
        let gid = state.group_id;
        let g = self.ctx.group(gid);
        let current = match self.resolve_latest_for_write(opt, state) {
            Ok(current) => current,
            Err(cause) => {
                *abort_cause = cause;
                return Err(self.write_abort(state.start_ts, cause));
            }
        };
        // Remember which WAL record kind applies, then move the value out without a clone.
        let replaces = current.is_some();
        old_visible = current;
        if !logged {
            state.modified = true;
            logged = true;
            let mut log = g.logging.lock();
            let key = Key::new(k, ver);
            let new_pos = if replaces {
                log.record_update(
                    &key,
                    WalReplace::new(v.len()),
                    v,
                    state.prev_lsn,
                    self.bucket_id,
                )?
            } else {
                log.record_update(
                    &key,
                    WalPut::new(v.len()),
                    v,
                    state.prev_lsn,
                    self.bucket_id,
                )?
            };
            state.prev_lsn = new_pos;
        }
        Ok((gid as u8, state.prev_lsn))
    })
    .map(|_| old_visible)
}
/// Deletes `k` by writing a remove record and returns the value that was visible before.
///
/// Runs the same retry loop as `modify`: when the newest version belongs to an aborted
/// transaction it is cleaned up via `clean_aborted` and the attempt is repeated with a
/// fresh command id.
///
/// # Errors
/// * `OpCode::NotFound` when no version is visible or the visible one is already a
///   delete marker.
/// * `OpCode::AbortTx` on conflict or when `should_abort` rejects the transaction.
/// * Any error from writing the WAL record or from `clean_aborted`.
pub fn del<T>(&self, k: T) -> Result<ValRef, OpCode>
where
    T: AsRef<[u8]>,
{
    loop {
        self.should_abort()?;
        let state = self.state_mut();
        let gid = state.group_id;
        let start_ts = state.start_ts;
        // Every attempt, retries included, consumes a fresh command id.
        let cmd_id = state.cmd_id;
        state.cmd_id += 1;
        let key = Key::new(k.as_ref(), Ver::new(start_ts, cmd_id));
        let val = Record::remove(gid as u8);
        let mut logged = false;
        let mut old_visible = None;
        let guard = crossbeam_epoch::pin();
        let _write_permit = self.before_write_budget(key.raw.len());
        let mut abort_cause = FailCause::Conflict;
        let res = self.tree.update(&guard, key, val, |opt| {
            let group = self.ctx.group(gid);
            let visible = self.resolve_latest_for_write(opt, state).map_err(|cause| {
                abort_cause = cause;
                self.write_abort(start_ts, cause)
            })?;
            // Only a live (non-deleted) visible value can be deleted.
            let live = match visible {
                Some(cur) if !cur.is_del() => cur,
                _ => return Err(OpCode::NotFound),
            };
            old_visible = Some(live);
            if !logged {
                logged = true;
                state.modified = true;
                let mut log = group.logging.lock();
                state.prev_lsn = log.record_update(
                    &key,
                    WalDel::new(),
                    [].as_slice(),
                    state.prev_lsn,
                    self.bucket_id,
                )?;
            }
            Ok((gid as u8, state.prev_lsn))
        });
        let blocked_by_aborted_head =
            matches!(res, Err(OpCode::AbortTx)) && abort_cause == FailCause::Aborted;
        if blocked_by_aborted_head {
            let _ = self.clean_aborted(&guard, key.raw)?;
            continue;
        }
        return res.map(|_| old_visible.expect("visible value must exist for delete"));
    }
}
/// Commits the transaction, consuming it.
///
/// A transaction that wrote nothing only appends a commit record. One that did write
/// also allocates a commit timestamp, syncs the log, and publishes the
/// `(start_ts, commit_ts)` pair to the group's commit tree. Both kinds are then
/// removed from `active_txns`, marked as ended, and reported to the observer.
///
/// # Errors
/// `OpCode::AbortTx` when `should_abort` rejects the transaction; otherwise whatever
/// the log (or an enabled failpoint) returns. On error `is_end` stays unset, so
/// dropping `self` on the way out runs the abort path of `Drop`.
pub fn commit(self) -> Result<(), OpCode> {
    self.should_abort()?;
    let state = self.state_ref();
    let commit_started = sampled_instant(state.start_ts, LATENCY_SAMPLE_SHIFT);
    let g = self.ctx.group(state.group_id);
    #[cfg(feature = "failpoints")]
    crate::utils::failpoint::check("mace_txn_commit_begin")?;

    if state.modified {
        let commit_ts = self.ctx.alloc_oracle();
        {
            // The log lock covers the commit record and the sync, nothing more.
            let mut log = g.logging.lock();
            log.record_commit(state.start_ts)?;
            #[cfg(feature = "failpoints")]
            crate::utils::failpoint::check("mace_txn_commit_after_record_commit")?;
            log.sync(false)?;
            #[cfg(feature = "failpoints")]
            crate::utils::failpoint::check("mace_txn_commit_after_wal_sync")?;
        }
        g.cc.commit_tree.append(state.start_ts, commit_ts);
        g.cc.latest_cts.store(commit_ts, Relaxed);
        g.cc.collect_wmk(self.ctx);
    } else {
        // Nothing was written: a commit record is appended without a sync.
        let mut log = g.logging.lock();
        log.record_commit(state.start_ts)?;
    }

    // Shared epilogue for both kinds of commit.
    g.active_txns.remove(&state.start_ts);
    self.is_end.set(true);
    self.observe_counter(CounterMetric::TxnCommit, 1);
    observe_elapsed(
        self.ctx.opt.observer.as_ref(),
        HistogramMetric::TxnCommitMicros,
        commit_started,
    );
    Ok(())
}
/// Reads `k` at this transaction's start timestamp.
///
/// Delegates to `get_impl` with the concurrency control of the transaction's group;
/// its result is returned as-is.
#[inline]
pub fn get<K>(&self, k: K) -> Result<ValRef, OpCode>
where
K: AsRef<[u8]>,
{
let state = self.state_ref();
let group_id = state.group_id;
get_impl(
self.ctx,
&self.ctx.group(group_id).cc,
self.tree,
group_id as u8,
state.start_ts,
k,
)
}
/// Iterates over the keys starting with `prefix`, read at this transaction's start timestamp.
///
/// Delegates to `seek_impl` with the concurrency control of the transaction's group.
#[inline]
pub fn seek<K>(&self, prefix: K) -> Iter<'_>
where
K: AsRef<[u8]>,
{
let state = self.state_ref();
let group_id = state.group_id;
seek_impl(
&self.ctx.group(group_id).cc,
self.tree,
group_id as u8,
state.start_ts,
prefix,
)
}
/// Iterates over the keys within `range`, read at this transaction's start timestamp.
///
/// Delegates to `range_impl` with the concurrency control of the transaction's group.
#[inline]
pub fn range<K, R>(&self, range: R) -> Iter<'_>
where
K: AsRef<[u8]>,
R: RangeBounds<K>,
{
let state = self.state_ref();
let group_id = state.group_id;
range_impl(
&self.ctx.group(group_id).cc,
self.tree,
group_id as u8,
state.start_ts,
range,
)
}
}
impl Drop for TxnKV<'_> {
fn drop(&mut self) {
let group_id = self.state_ref().group_id;
if !self.is_end.get() {
let state = self.state_ref();
let grp = self.ctx.group(state.group_id);
let modified = state.modified;
let mut log = grp.logging.lock();
log.record_abort(state.start_ts)
.inspect_err(|e| {
log::error!("can't record abort, {:?}", e);
})
.expect("can't fail");
if modified {
log.sync(false)
.inspect_err(|e| {
log::error!("can't sync abort chain before enqueue, {:?}", e);
})
.expect("can't fail");
}
drop(log);
if modified {
self.ctx.add_aborted(state.start_ts);
self.ctx.enqueue_abort_clean(
state.start_ts,
self.bucket_id,
state.group_id as u8,
state.prev_lsn,
state.begin_lsn.file_id,
);
} else {
}
self.observe_counter(CounterMetric::TxnAbort, 1);
grp.active_txns.remove(&state.start_ts);
self.is_end.set(true);
}
self.ctx.group(group_id).leave_inflight();
self.tree.bucket.state.dec_txn_ref();
}
}
/// A read-only view over one `Tree`, offering `get`, `seek` and `range`.
///
/// All reads use the `start_ts` of the concurrency-control node allocated in `new`;
/// the node is handed back to the context when the view is dropped.
pub struct TxnView<'a> {
/// Shared context; owner of the concurrency-control node below.
ctx: &'a Context,
/// Node obtained from `ctx.alloc_cc()`; supplies `start_ts` and the visibility checks.
cc: Handle<CCNode>,
/// Always `u8::MAX` for views.
/// NOTE(review): presumably a sentinel that matches no writer group — confirm.
group_id: u8,
/// Tree that every read of this view goes to.
tree: &'a Tree,
}
impl<'a> TxnView<'a> {
/// Creates a view on `tree`, allocating a concurrency-control node from `ctx`.
///
/// As written this never returns `Err`; the `Result` return type is part of the signature only.
pub(crate) fn new(ctx: &'a Context, tree: &'a Tree) -> Result<Self, OpCode> {
let cc = ctx.alloc_cc();
Ok(Self {
ctx,
cc,
group_id: u8::MAX,
tree,
})
}
/// Reads `k` at the view's `start_ts`; delegates to `get_impl`.
#[inline]
pub fn get<K: AsRef<[u8]>>(&self, k: K) -> Result<ValRef, OpCode> {
get_impl(
self.ctx,
&self.cc,
self.tree,
self.group_id,
self.cc.start_ts,
k,
)
}
/// Iterates over the keys starting with `prefix` at the view's `start_ts`; delegates to `seek_impl`.
#[inline]
pub fn seek<K>(&self, prefix: K) -> Iter<'_>
where
K: AsRef<[u8]>,
{
seek_impl(&self.cc, self.tree, self.group_id, self.cc.start_ts, prefix)
}
/// Iterates over the keys within `range` at the view's `start_ts`; delegates to `range_impl`.
#[inline]
pub fn range<K, R>(&self, range: R) -> Iter<'_>
where
K: AsRef<[u8]>,
R: RangeBounds<K>,
{
range_impl(&self.cc, self.tree, self.group_id, self.cc.start_ts, range)
}
}
impl Drop for TxnView<'_> {
/// Hands the concurrency-control node allocated in `TxnView::new` back to the context.
fn drop(&mut self) {
self.ctx.free_cc(self.cc);
}
}
#[cfg(test)]
mod test {
use super::prefix_upper_exclusive;
use crate::{Mace, OpCode, Options, RandomPath};
// Runs the end-to-end scenario in `txnkv_impl` and fails on any returned `OpCode`.
#[test]
fn txnkv() {
txnkv_impl().unwrap();
}
// Covers the plain increment, carry across trailing 0xff bytes, and the unbounded case.
#[test]
fn prefix_upper_exclusive_handles_carry() {
assert_eq!(
prefix_upper_exclusive(&[0x61, 0x62, 0x63]),
Some(vec![0x61, 0x62, 0x64])
);
assert_eq!(
prefix_upper_exclusive(&[0x61, 0xff, 0xff]),
Some(vec![0x62])
);
assert_eq!(prefix_upper_exclusive(&[0xff]), None);
assert_eq!(prefix_upper_exclusive(&[0xff, 0xff]), None);
}
// Exercises put/del/update/upsert and checks what later transactions and views see
// after the writer either commits or is dropped without committing.
fn txnkv_impl() -> Result<(), OpCode> {
let path = RandomPath::tmp();
let _ = std::fs::remove_dir_all(&*path);
let opt = Options::new(&*path).validate().unwrap();
let mace = Mace::new(opt)?;
let (k1, k2) = ("beast".as_bytes(), "senpai".as_bytes());
let (v1, v2) = ("114514".as_bytes(), "1919810".as_bytes());
let db = mace.new_bucket("default")?;
// Committed: k1 is put then deleted, k2 stays.
let kv = db.begin()?;
kv.put(k1, v1).expect("can't put");
kv.put(k2, v2).expect("can't put");
let r = kv.del(k1).expect("can't del");
assert_eq!(r.slice(), v1);
kv.commit()?;
// Dropped without commit: the delete of k2 must not stick.
let kv = db.begin()?;
let r = kv.get(k1);
assert!(r.is_err());
let r = kv.get(k2).expect("can't get");
assert_eq!(r.slice(), v2);
let r = kv.del(k2).expect("can't del");
assert_eq!(r.slice(), v2);
drop(kv);
// k2 is still there after the dropped transaction; delete it for real this time.
let kv = db.begin()?;
let r = kv.get(k1);
assert!(r.is_err());
let r = kv.del(k2).expect("can't del");
assert_eq!(r.slice(), v2);
let r = kv.del(k2);
assert!(r.is_err());
kv.commit()?;
let kv = db.begin()?;
let r = kv.get(k1);
assert!(r.is_err());
let r = kv.get(k2);
assert!(r.is_err());
kv.commit()?;
// An update in a dropped transaction leaves the committed value in place.
{
let kv = db.begin()?;
kv.put("1", "10")?;
kv.commit()?;
let kv = db.begin()?;
kv.update("1", "11").expect("can't replace");
drop(kv);
let view = db.view()?;
let x = view.get("1").expect("can't get");
assert_eq!(x.slice(), "10".as_bytes());
}
// put + update + del inside one dropped transaction: the own update is readable
// inside the transaction, nothing is readable from a later view.
{
let kv = db.begin()?;
kv.put("2", "20")?;
kv.update("2", "21")?;
let r = kv.get("2").unwrap();
assert_eq!(r.slice(), "21".as_bytes());
kv.del("2")?;
drop(kv);
let view = db.view()?;
let x = view.get("2");
assert!(x.is_err());
}
// Same as the update case above, but through upsert.
{
let kv = db.begin()?;
kv.put("11", "10")?;
kv.commit()?;
let kv = db.begin()?;
kv.upsert("11", "11").expect("can't replace");
drop(kv);
let view = db.view()?;
let x = view.get("11").expect("can't get");
assert_eq!(x.slice(), "10".as_bytes());
}
// put + upsert + del inside one dropped transaction.
{
let kv = db.begin()?;
kv.put("22", "20")?;
kv.upsert("22", "21")?;
let r = kv.get("22").unwrap();
assert_eq!(r.slice(), "21".as_bytes());
kv.del("22")?;
drop(kv);
let view = db.view()?;
let x = view.get("22");
assert!(x.is_err());
}
// After a committed delete, update fails but upsert succeeds.
{
let kv = db.begin()?;
kv.put("elder", "+1s")?;
kv.del("elder")?;
kv.commit()?;
let kv = db.begin()?;
let r = kv.update("elder", "mo");
assert!(r.is_err());
kv.upsert("elder", "mo")?;
kv.commit()?;
let view = db.view()?;
assert_eq!(view.get("elder").unwrap().slice(), "mo".as_bytes());
}
drop(db);
drop(mace);
Ok(())
}
// Runs the scenario in `cross_long_txn_impl` and fails on any returned `OpCode`.
#[test]
fn cross_long_txn() {
cross_long_txn_impl().unwrap();
}
// A view opened before a writer commits must keep reading the pre-commit value, even
// though the writer also inserts `consolidate_threshold` extra keys before committing.
fn cross_long_txn_impl() -> Result<(), OpCode> {
let path = RandomPath::new();
let mut opt = Options::new(&*path);
let consolidate_threshold = 256;
opt.tmp_store = true;
opt.split_elems = consolidate_threshold * 2;
opt.consolidate_threshold = consolidate_threshold;
let mace = Mace::new(opt.validate().unwrap())?;
let db = mace.new_bucket("default")?;
let kv = db.begin()?;
kv.put("foo", "bar")?;
kv.commit()?;
// Opened after "bar" was committed and before the writer below starts.
let view = db.view()?;
let kv = db.begin()?;
kv.update("foo", "bar1")?;
kv.update("foo", "bar2")?;
for i in 0..consolidate_threshold {
let x = format!("key_{i}");
kv.put(&x, &x)?;
}
let r = kv.get("foo")?;
assert_eq!(r.slice(), "bar2".as_bytes());
kv.commit()?;
let v = view.get("foo")?;
assert_eq!(v.slice(), "bar".as_bytes());
drop(view);
drop(db);
drop(mace);
Ok(())
}
}