use super::engine::KvEngine;
use super::engine_helpers::{expiry_key, table_key};
use super::entry::NO_EXPIRY;
use super::hash_table::KvHashTable;
/// Outcome of [`KvEngine::cas`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CasResult {
    /// `true` when the comparison matched and the new value was written.
    pub success: bool,
    /// The value stored under the key before the operation (`None` if the key was
    /// absent). Returned on both success and failure.
    pub current_value: Option<Vec<u8>>,
}
/// Failure modes of the atomic numeric operations (`incr`, `incr_float`).
#[derive(Debug)]
pub enum AtomicError {
    /// The stored value could not be interpreted as the requested numeric type.
    TypeMismatch {
        /// Human-readable description of the mismatch.
        detail: String,
    },
    /// The stored value or the computed result does not fit the target numeric type,
    /// or a float computation produced NaN or infinity.
    Overflow,
}
impl KvEngine {
pub fn incr(
&mut self,
tenant_id: u64,
collection: &str,
key: &[u8],
delta: i64,
ttl_ms: u64,
now_ms: u64,
) -> Result<i64, AtomicError> {
let tkey = table_key(tenant_id, collection);
let table = self.ensure_table(tkey, tenant_id, collection);
let current = table.get(key, now_ms).map(|v| v.to_vec());
let old_i64 = match ¤t {
None => 0i64,
Some(bytes) => decode_msgpack_i64(bytes)?,
};
let new_i64 = old_i64.checked_add(delta).ok_or(AtomicError::Overflow)?;
let new_bytes = if let Some(ref cur) = current
&& let Ok(nodedb_types::Value::Object(mut map)) = nodedb_types::value_from_msgpack(cur)
&& map.len() > 1
{
let mut updated = false;
for (k, v) in map.iter_mut() {
if k == "key" {
continue;
}
if matches!(
v,
nodedb_types::Value::Integer(_) | nodedb_types::Value::Float(_)
) {
*v = nodedb_types::Value::Integer(new_i64);
updated = true;
break;
}
}
if updated {
nodedb_types::value_to_msgpack(&nodedb_types::Value::Object(map))
.unwrap_or_else(|_| zerompk::to_msgpack_vec(&new_i64).expect("i64 serializes"))
} else {
zerompk::to_msgpack_vec(&new_i64).expect("i64 always serializes")
}
} else {
zerompk::to_msgpack_vec(&new_i64).expect("i64 always serializes")
};
self.atomic_put(
tenant_id,
collection,
tkey,
key,
&new_bytes,
ttl_ms,
now_ms,
current.is_none(),
);
Ok(new_i64)
}
pub fn incr_float(
&mut self,
tenant_id: u64,
collection: &str,
key: &[u8],
delta: f64,
now_ms: u64,
) -> Result<f64, AtomicError> {
let tkey = table_key(tenant_id, collection);
let table = self.ensure_table(tkey, tenant_id, collection);
let current = table.get(key, now_ms).map(|v| v.to_vec());
let old_f64 = match ¤t {
None => 0.0f64,
Some(bytes) => decode_msgpack_f64(bytes)?,
};
let new_f64 = old_f64 + delta;
if new_f64.is_nan() || new_f64.is_infinite() {
return Err(AtomicError::Overflow);
}
let new_bytes = zerompk::to_msgpack_vec(&new_f64).expect("f64 always serializes");
self.atomic_put(
tenant_id,
collection,
tkey,
key,
&new_bytes,
0,
now_ms,
current.is_none(),
);
Ok(new_f64)
}
pub fn cas(
&mut self,
tenant_id: u64,
collection: &str,
key: &[u8],
expected: &[u8],
new_value: &[u8],
now_ms: u64,
) -> CasResult {
let tkey = table_key(tenant_id, collection);
let table = self.ensure_table(tkey, tenant_id, collection);
let current = table.get(key, now_ms).map(|v| v.to_vec());
let matches = match ¤t {
None => expected.is_empty(),
Some(v) => {
if v.as_slice() == expected {
true
} else if let Ok(nodedb_types::Value::Object(map)) =
nodedb_types::value_from_msgpack(v)
{
let expected_str = String::from_utf8_lossy(expected);
map.iter().any(|(k, val)| {
k != "key"
&& matches!(val, nodedb_types::Value::String(s) if s == expected_str.as_ref())
})
} else {
false
}
}
};
if matches {
let write_bytes = if let Some(ref cur) = current
&& let Ok(nodedb_types::Value::Object(mut map)) =
nodedb_types::value_from_msgpack(cur)
&& map.len() > 1
{
let new_str = String::from_utf8_lossy(new_value).to_string();
let mut updated = false;
for (k, v) in map.iter_mut() {
if k == "key" {
continue;
}
if matches!(v, nodedb_types::Value::String(_)) {
*v = nodedb_types::Value::String(new_str.clone());
updated = true;
break;
}
}
if updated {
nodedb_types::value_to_msgpack(&nodedb_types::Value::Object(map))
.unwrap_or_else(|_| new_value.to_vec())
} else {
new_value.to_vec()
}
} else {
new_value.to_vec()
};
self.atomic_put(
tenant_id,
collection,
tkey,
key,
&write_bytes,
0,
now_ms,
current.is_none(),
);
CasResult {
success: true,
current_value: current,
}
} else {
CasResult {
success: false,
current_value: current,
}
}
}
pub fn getset(
&mut self,
tenant_id: u64,
collection: &str,
key: &[u8],
new_value: &[u8],
now_ms: u64,
) -> Option<Vec<u8>> {
let tkey = table_key(tenant_id, collection);
let table = self.ensure_table(tkey, tenant_id, collection);
let old = table.get(key, now_ms).map(|v| v.to_vec());
let write_bytes = if let Some(ref cur) = old
&& let Ok(nodedb_types::Value::Object(mut map)) = nodedb_types::value_from_msgpack(cur)
&& map.len() > 1
{
let new_str = String::from_utf8_lossy(new_value).to_string();
let mut updated = false;
for (k, v) in map.iter_mut() {
if k == "key" {
continue;
}
if matches!(v, nodedb_types::Value::String(_)) {
*v = nodedb_types::Value::String(new_str.clone());
updated = true;
break;
}
}
if updated {
nodedb_types::value_to_msgpack(&nodedb_types::Value::Object(map))
.unwrap_or_else(|_| new_value.to_vec())
} else {
new_value.to_vec()
}
} else {
new_value.to_vec()
};
self.atomic_put(
tenant_id,
collection,
tkey,
key,
&write_bytes,
0,
now_ms,
old.is_none(),
);
old
}
fn ensure_table(&mut self, tkey: u64, tenant_id: u64, collection: &str) -> &mut KvHashTable {
if !self.tables.contains_key(&tkey) {
self.hash_to_tenant.entry(tkey).or_insert(tenant_id);
self.hash_to_collection
.entry(tkey)
.or_insert_with(|| collection.to_string());
self.tables.entry(tkey).or_insert_with(|| {
KvHashTable::new(
self.default_capacity,
self.load_factor_threshold,
self.rehash_batch_size,
self.inline_threshold,
)
});
}
self.tables.get_mut(&tkey).expect("just ensured")
}
#[allow(clippy::too_many_arguments)]
fn atomic_put(
&mut self,
tenant_id: u64,
collection: &str,
tkey: u64,
key: &[u8],
value: &[u8],
ttl_ms: u64,
now_ms: u64,
is_new_key: bool,
) {
let old_meta = if is_new_key {
None
} else {
self.tables.get(&tkey).and_then(|t| t.get_entry_meta(key))
};
let expire_at = if ttl_ms > 0 {
now_ms + ttl_ms
} else if let Some(ref meta) = old_meta {
meta.expire_at_ms
} else {
NO_EXPIRY
};
if let Some(ref meta) = old_meta
&& meta.has_ttl
{
let composite = expiry_key(tenant_id, collection, key);
self.expiry.cancel(&composite, meta.expire_at_ms);
}
let old_fields =
if !is_new_key && self.indexes.get(&tkey).is_some_and(|idx| !idx.is_empty()) {
self.tables
.get(&tkey)
.and_then(|t| t.get(key, now_ms))
.map(|old_val| {
super::engine_helpers::extract_all_field_values_from_msgpack(old_val)
})
} else {
None
};
let table = self.tables.get_mut(&tkey).expect("table ensured");
table.put(key, value, expire_at, nodedb_types::Surrogate::ZERO);
if expire_at != NO_EXPIRY {
let composite = expiry_key(tenant_id, collection, key);
self.expiry.insert(composite, expire_at);
}
if self.indexes.get(&tkey).is_some_and(|idx| !idx.is_empty()) {
let old_refs: Option<Vec<(&str, &[u8])>> = old_fields.as_ref().map(|fields| {
fields
.iter()
.map(|(k, v)| (k.as_str(), v.as_slice()))
.collect()
});
let new_fields = super::engine_helpers::extract_all_field_values_from_msgpack(value);
let new_refs: Vec<(&str, &[u8])> = new_fields
.iter()
.map(|(k, v)| (k.as_str(), v.as_slice()))
.collect();
if let Some(idx_set) = self.indexes.get_mut(&tkey) {
idx_set.on_put(key, &new_refs, old_refs.as_deref());
}
}
}
}
/// Decodes a stored value as an `i64` for [`KvEngine::incr`].
///
/// The following encodings are tried in order:
/// 1. a raw msgpack signed integer;
/// 2. a raw unsigned integer;
/// 3. a raw float with no fractional part;
/// 4. an object row, whose first non-`"key"` field holding an integer or a whole-number
///    float supplies the value. Fractional floats in the row are skipped.
///
/// # Errors
/// * [`AtomicError::Overflow`] for an unsigned integer above `i64::MAX`, or a whole-number
///   float outside the `i64` range.
/// * [`AtomicError::TypeMismatch`] when no integral value can be found.
fn decode_msgpack_i64(bytes: &[u8]) -> Result<i64, AtomicError> {
    /// Converts a whole-number float to `i64`, rejecting values outside `[-2^63, 2^63)`.
    // `i64::MAX as f64` rounds up to exactly 2^63, which is already out of range; an
    // inclusive `<= i64::MAX as f64` check would let it through and `as` would saturate.
    fn whole_to_i64(f: f64) -> Result<i64, AtomicError> {
        const TWO_POW_63: f64 = 9_223_372_036_854_775_808.0;
        if (-TWO_POW_63..TWO_POW_63).contains(&f) {
            Ok(f as i64)
        } else {
            Err(AtomicError::Overflow)
        }
    }

    if let Ok(v) = zerompk::from_msgpack::<i64>(bytes) {
        return Ok(v);
    }
    if let Ok(v) = zerompk::from_msgpack::<u64>(bytes) {
        return i64::try_from(v).map_err(|_| AtomicError::Overflow);
    }
    if let Ok(v) = zerompk::from_msgpack::<f64>(bytes)
        && v.fract() == 0.0
    {
        return whole_to_i64(v);
    }
    if let Ok(nodedb_types::Value::Object(map)) = nodedb_types::value_from_msgpack(bytes) {
        for (k, v) in &map {
            if k == "key" {
                continue;
            }
            match v {
                nodedb_types::Value::Integer(i) => return Ok(*i),
                nodedb_types::Value::Float(f) if f.fract() == 0.0 => return whole_to_i64(*f),
                _ => {}
            }
        }
    }
    Err(AtomicError::TypeMismatch {
        detail: "value is not an integer".into(),
    })
}
/// Decodes a stored value as an `f64` for `incr_float`.
///
/// The following encodings are tried in order:
/// 1. a raw msgpack float;
/// 2. a raw signed integer;
/// 3. a raw unsigned integer;
/// 4. an object row, whose first non-`"key"` float or integer field supplies the value.
///
/// Integers are converted with `as f64`.
///
/// # Errors
/// [`AtomicError::TypeMismatch`] when none of these encodings yields a number.
fn decode_msgpack_f64(bytes: &[u8]) -> Result<f64, AtomicError> {
    let scalar = zerompk::from_msgpack::<f64>(bytes)
        .ok()
        .or_else(|| zerompk::from_msgpack::<i64>(bytes).ok().map(|n| n as f64))
        .or_else(|| zerompk::from_msgpack::<u64>(bytes).ok().map(|n| n as f64));
    if let Some(number) = scalar {
        return Ok(number);
    }

    let from_object = match nodedb_types::value_from_msgpack(bytes) {
        Ok(nodedb_types::Value::Object(map)) => map.iter().find_map(|(name, field)| {
            if name == "key" {
                return None;
            }
            match field {
                nodedb_types::Value::Float(x) => Some(*x),
                nodedb_types::Value::Integer(n) => Some(*n as f64),
                _ => None,
            }
        }),
        _ => None,
    };

    from_object.ok_or_else(|| AtomicError::TypeMismatch {
        detail: "value is not numeric".into(),
    })
}
#[cfg(test)]
mod tests {
    use nodedb_types::Surrogate;
    use super::*;
    fn make_engine() -> KvEngine {
        KvEngine::new(1000, 16, 0.75, 4, 64, 1000, 1024)
    }
    #[test]
    fn incr_new_key() {
        let mut engine = make_engine();
        let result = engine.incr(1, "counters", b"hits", 10, 0, 1000);
        assert_eq!(result.unwrap(), 10);
    }
    #[test]
    fn incr_existing_key() {
        let mut engine = make_engine();
        engine.incr(1, "counters", b"hits", 10, 0, 1000).unwrap();
        let result = engine.incr(1, "counters", b"hits", 5, 0, 1000);
        assert_eq!(result.unwrap(), 15);
    }
    #[test]
    fn incr_negative_delta() {
        let mut engine = make_engine();
        engine.incr(1, "counters", b"gold", 100, 0, 1000).unwrap();
        let result = engine.incr(1, "counters", b"gold", -30, 0, 1000);
        assert_eq!(result.unwrap(), 70);
    }
    #[test]
    fn incr_overflow() {
        let mut engine = make_engine();
        let bytes = zerompk::to_msgpack_vec(&i64::MAX).unwrap();
        engine.put(1, "counters", b"max", &bytes, 0, 1000, Surrogate::ZERO);
        let result = engine.incr(1, "counters", b"max", 1, 0, 1000);
        assert!(matches!(result, Err(AtomicError::Overflow)));
    }
    #[test]
    fn incr_type_mismatch() {
        let mut engine = make_engine();
        let bytes = zerompk::to_msgpack_vec(&"hello").unwrap();
        engine.put(1, "counters", b"str", &bytes, 0, 1000, Surrogate::ZERO);
        let result = engine.incr(1, "counters", b"str", 1, 0, 1000);
        assert!(matches!(result, Err(AtomicError::TypeMismatch { .. })));
    }
    #[test]
    fn incr_with_ttl_new_key() {
        let mut engine = make_engine();
        engine
            .incr(1, "counters", b"daily", 1, 86_400_000, 1000)
            .unwrap();
        let ttl = engine.get_ttl_ms(1, "counters", b"daily", 1000);
        assert!(ttl.is_some());
        assert!(ttl.unwrap() > 0);
    }
    #[test]
    fn incr_preserves_ttl_when_zero() {
        let mut engine = make_engine();
        let bytes = zerompk::to_msgpack_vec(&50i64).unwrap();
        engine.put(1, "counters", b"temp", &bytes, 5000, 1000, Surrogate::ZERO);
        engine.incr(1, "counters", b"temp", 10, 0, 1000).unwrap();
        let ttl = engine.get_ttl_ms(1, "counters", b"temp", 1000);
        assert!(ttl.is_some());
        assert!(ttl.unwrap() > 0);
    }
    #[test]
    fn incr_float_new_key() {
        let mut engine = make_engine();
        let result = engine.incr_float(1, "scores", b"dmg", 3.125, 1000);
        assert!((result.unwrap() - 3.125).abs() < f64::EPSILON);
    }
    #[test]
    fn incr_float_existing() {
        let mut engine = make_engine();
        engine.incr_float(1, "scores", b"dmg", 3.0, 1000).unwrap();
        let result = engine.incr_float(1, "scores", b"dmg", 1.5, 1000);
        assert!((result.unwrap() - 4.5).abs() < f64::EPSILON);
    }
    // An integer stored value is promoted to f64 before adding.
    #[test]
    fn incr_float_on_integer_value() {
        let mut engine = make_engine();
        let bytes = zerompk::to_msgpack_vec(&5i64).unwrap();
        engine.put(1, "scores", b"n", &bytes, 0, 1000, Surrogate::ZERO);
        let result = engine.incr_float(1, "scores", b"n", 0.5, 1000);
        assert!((result.unwrap() - 5.5).abs() < f64::EPSILON);
    }
    #[test]
    fn incr_float_type_mismatch() {
        let mut engine = make_engine();
        let bytes = zerompk::to_msgpack_vec(&"hello").unwrap();
        engine.put(1, "scores", b"str", &bytes, 0, 1000, Surrogate::ZERO);
        let result = engine.incr_float(1, "scores", b"str", 1.0, 1000);
        assert!(matches!(result, Err(AtomicError::TypeMismatch { .. })));
    }
    #[test]
    fn incr_float_infinity_rejected() {
        let mut engine = make_engine();
        let bytes = zerompk::to_msgpack_vec(&f64::MAX).unwrap();
        engine.put(1, "scores", b"big", &bytes, 0, 1000, Surrogate::ZERO);
        let result = engine.incr_float(1, "scores", b"big", f64::MAX, 1000);
        assert!(matches!(result, Err(AtomicError::Overflow)));
    }
    #[test]
    fn cas_create_if_not_exists() {
        let mut engine = make_engine();
        let result = engine.cas(1, "state", b"player1", b"", b"idle", 1000);
        assert!(result.success);
        assert!(result.current_value.is_none());
        let val = engine.get(1, "state", b"player1", 1000);
        assert_eq!(val.as_deref(), Some(b"idle".as_slice()));
    }
    // A non-empty `expected` never matches an absent key, and nothing is written.
    #[test]
    fn cas_missing_key_nonempty_expected_fails() {
        let mut engine = make_engine();
        let result = engine.cas(1, "state", b"ghost", b"idle", b"in_match", 1000);
        assert!(!result.success);
        assert!(result.current_value.is_none());
        assert!(engine.get(1, "state", b"ghost", 1000).is_none());
    }
    #[test]
    fn cas_success() {
        let mut engine = make_engine();
        engine.put(1, "state", b"p1", b"idle", 0, 1000, Surrogate::ZERO);
        let result = engine.cas(1, "state", b"p1", b"idle", b"in_match", 1000);
        assert!(result.success);
        assert_eq!(result.current_value.as_deref(), Some(b"idle".as_slice()));
        let val = engine.get(1, "state", b"p1", 1000);
        assert_eq!(val.as_deref(), Some(b"in_match".as_slice()));
    }
    #[test]
    fn cas_failure() {
        let mut engine = make_engine();
        engine.put(1, "state", b"p1", b"fighting", 0, 1000, Surrogate::ZERO);
        let result = engine.cas(1, "state", b"p1", b"idle", b"in_match", 1000);
        assert!(!result.success);
        assert_eq!(
            result.current_value.as_deref(),
            Some(b"fighting".as_slice())
        );
        let val = engine.get(1, "state", b"p1", 1000);
        assert_eq!(val.as_deref(), Some(b"fighting".as_slice()));
    }
    #[test]
    fn getset_new_key() {
        let mut engine = make_engine();
        let old = engine.getset(1, "session", b"tok", b"new-token", 1000);
        assert!(old.is_none());
        let val = engine.get(1, "session", b"tok", 1000);
        assert_eq!(val.as_deref(), Some(b"new-token".as_slice()));
    }
    #[test]
    fn getset_existing_key() {
        let mut engine = make_engine();
        engine.put(1, "session", b"tok", b"old-token", 0, 1000, Surrogate::ZERO);
        let old = engine.getset(1, "session", b"tok", b"new-token", 1000);
        assert_eq!(old.as_deref(), Some(b"old-token".as_slice()));
        let val = engine.get(1, "session", b"tok", 1000);
        assert_eq!(val.as_deref(), Some(b"new-token".as_slice()));
    }
    // GETSET writes with `ttl_ms == 0`, so an existing TTL must survive the overwrite.
    #[test]
    fn getset_preserves_ttl() {
        let mut engine = make_engine();
        engine.put(1, "session", b"tok", b"old-token", 5000, 1000, Surrogate::ZERO);
        engine.getset(1, "session", b"tok", b"new-token", 1000);
        let ttl = engine.get_ttl_ms(1, "session", b"tok", 1000);
        assert!(ttl.is_some());
        assert!(ttl.unwrap() > 0);
    }
}