use std::cmp::Ordering;
use bytes::Bytes;
use crate::types::{
schema::{ColumnType, TableSchema},
sequence::{OpType, SeqNum, SEQNUM_MAX},
value::FieldValue,
MeruError, Result,
};
/// An internal key: the order-preserving encoding of a row's primary-key values
/// followed by an 8-byte tag that carries the sequence number and operation type.
///
/// Equality and ordering of `InternalKey` are defined purely on the encoded bytes
/// (see the `PartialEq` / `Ord` impls in this file); the decoded fields are kept
/// alongside for convenient access.
#[derive(Clone, Debug)]
pub struct InternalKey {
/// Complete encoded form: encoded primary-key fields followed by the 8-byte tag.
encoded: Bytes,
/// Sequence number stored (inverted) in the upper 56 bits of the tag.
pub seq: SeqNum,
/// Operation type stored in the low byte of the tag.
pub op_type: OpType,
/// Primary-key values this key was encoded from, or decoded into.
pk_values: Vec<FieldValue>,
}
impl InternalKey {
pub fn encode(
pk_values: &[FieldValue],
seq: SeqNum,
op_type: OpType,
schema: &TableSchema,
) -> Result<Self> {
let mut buf = Vec::with_capacity(64);
encode_pk_fields(pk_values, schema, &mut buf)?;
encode_tag(seq, op_type, &mut buf)?;
Ok(Self {
encoded: Bytes::from(buf),
seq,
op_type,
pk_values: pk_values.to_vec(),
})
}
pub fn seek_latest(pk_values: &[FieldValue], schema: &TableSchema) -> Result<Self> {
Self::encode(pk_values, SEQNUM_MAX, OpType::Put, schema)
}
#[inline]
pub fn as_bytes(&self) -> &[u8] {
&self.encoded
}
pub fn pk_values(&self) -> &[FieldValue] {
&self.pk_values
}
pub fn decode(raw: &[u8], schema: &TableSchema) -> Result<Self> {
if raw.len() < 8 {
return Err(MeruError::Corruption("internal key too short".into()));
}
let (pk_bytes, tag_bytes) = raw.split_at(raw.len() - 8);
let tag = u64::from_be_bytes(tag_bytes.try_into().unwrap());
let inverted_seq = tag >> 8;
let op_byte = (tag & 0xFF) as u8;
if inverted_seq > SEQNUM_MAX.0 {
return Err(MeruError::Corruption(format!(
"inverted_seq {inverted_seq} exceeds SEQNUM_MAX ({})",
SEQNUM_MAX.0
)));
}
let seq = SeqNum(SEQNUM_MAX.0 - inverted_seq);
let op_type = match op_byte {
0x00 => OpType::Delete,
0x01 => OpType::Put,
_ => {
return Err(MeruError::Corruption(format!(
"unknown op_type {op_byte:#x}"
)))
}
};
let pk_values = decode_pk_fields(pk_bytes, schema)?;
Ok(Self {
encoded: Bytes::copy_from_slice(raw),
seq,
op_type,
pk_values,
})
}
pub fn user_key_bytes(&self) -> &[u8] {
&self.encoded[..self.encoded.len() - 8]
}
pub fn encode_user_key(pk_values: &[FieldValue], schema: &TableSchema) -> Result<Vec<u8>> {
let mut buf = Vec::with_capacity(64);
encode_pk_fields(pk_values, schema, &mut buf)?;
Ok(buf)
}
}
impl PartialEq for InternalKey {
    /// Two keys are equal exactly when their encoded byte strings are equal.
    fn eq(&self, other: &Self) -> bool {
        self.as_bytes() == other.as_bytes()
    }
}
// Equality is plain byte-string equality on the encoded key, so it is reflexive and `Eq` holds.
impl Eq for InternalKey {}
impl PartialOrd for InternalKey {
    /// Always `Some`: the ordering is total and delegates to [`Ord::cmp`].
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
impl Ord for InternalKey {
    /// Orders keys by lexicographic comparison of their encoded bytes.
    #[inline]
    fn cmp(&self, other: &Self) -> Ordering {
        self.as_bytes().cmp(other.as_bytes())
    }
}
/// Appends the encoding of every primary-key value to `buf`, in schema order.
///
/// `values[i]` is encoded with the column type of the `i`-th primary-key column
/// of `schema`.
///
/// # Errors
///
/// Returns `MeruError::InvalidArgument` when the number of values differs from the
/// number of primary-key columns, and propagates any per-field encoding error.
fn encode_pk_fields(values: &[FieldValue], schema: &TableSchema, buf: &mut Vec<u8>) -> Result<()> {
    let expected = schema.primary_key.len();
    let actual = values.len();
    if actual != expected {
        return Err(MeruError::InvalidArgument(format!(
            "expected {} PK values, got {}",
            expected, actual
        )));
    }

    values
        .iter()
        .zip(schema.primary_key.iter())
        .try_for_each(|(value, &col_idx)| {
            encode_field(value, &schema.columns[col_idx].col_type, buf)
        })
}
fn encode_field(val: &FieldValue, col_type: &ColumnType, buf: &mut Vec<u8>) -> Result<()> {
match (val, col_type) {
(FieldValue::Boolean(b), ColumnType::Boolean) => {
buf.push(u8::from(*b));
}
(FieldValue::Int32(v), ColumnType::Int32) => {
buf.extend_from_slice(&((*v as u32) ^ 0x8000_0000_u32).to_be_bytes());
}
(FieldValue::Int64(v), ColumnType::Int64) => {
buf.extend_from_slice(&((*v as u64) ^ 0x8000_0000_0000_0000_u64).to_be_bytes());
}
(FieldValue::Float(v), ColumnType::Float) => {
if v.is_nan() {
return Err(MeruError::InvalidArgument(
"NaN is not allowed in primary key columns (Float): \
IEEE 754 NaN has multiple bit representations, \
producing non-deterministic key encoding"
.into(),
));
}
buf.extend_from_slice(&order_preserving_f32(*v));
}
(FieldValue::Double(v), ColumnType::Double) => {
if v.is_nan() {
return Err(MeruError::InvalidArgument(
"NaN is not allowed in primary key columns (Double): \
IEEE 754 NaN has multiple bit representations, \
producing non-deterministic key encoding"
.into(),
));
}
buf.extend_from_slice(&order_preserving_f64(*v));
}
(FieldValue::Bytes(b), ColumnType::FixedLenByteArray(n)) => {
if b.len() != *n as usize {
return Err(MeruError::SchemaMismatch(format!(
"FixedLenByteArray({n}): got {} bytes",
b.len()
)));
}
buf.extend_from_slice(b);
}
(FieldValue::Bytes(b), ColumnType::ByteArray) => {
escape_byte_array(b, buf);
}
_ => {
return Err(MeruError::SchemaMismatch(format!(
"field value type mismatch with column type {col_type:?}"
)));
}
}
Ok(())
}
/// Appends `bytes` to `buf` in escaped, self-terminating form.
///
/// Every `0x00` in the input is written as `0x00 0xFF`; all other bytes are
/// copied unchanged. The sequence is closed with the terminator `0x00 0x00`.
#[inline]
fn escape_byte_array(bytes: &[u8], buf: &mut Vec<u8>) {
    for &byte in bytes {
        match byte {
            0x00 => buf.extend_from_slice(&[0x00, 0xFF]),
            plain => buf.push(plain),
        }
    }
    // Terminator: cannot be confused with an escaped zero (0x00 0xFF).
    buf.extend_from_slice(&[0x00, 0x00]);
}
/// Encodes `v` as 4 big-endian bytes whose unsigned byte order follows the
/// numeric order of the float.
///
/// Values with the sign bit set have all bits inverted; values with the sign bit
/// clear have only the sign bit flipped.
#[inline]
fn order_preserving_f32(v: f32) -> [u8; 4] {
    const SIGN_BIT: u32 = 1 << 31;
    let raw = v.to_bits();
    // XOR with all-ones is a full bitwise NOT; XOR with SIGN_BIT flips only the sign.
    let mask = if raw & SIGN_BIT != 0 { u32::MAX } else { SIGN_BIT };
    (raw ^ mask).to_be_bytes()
}
/// Encodes `v` as 8 big-endian bytes whose unsigned byte order follows the
/// numeric order of the float.
///
/// Values with the sign bit set have all bits inverted; values with the sign bit
/// clear have only the sign bit flipped.
#[inline]
fn order_preserving_f64(v: f64) -> [u8; 8] {
    const SIGN_BIT: u64 = 1 << 63;
    let raw = v.to_bits();
    // XOR with all-ones is a full bitwise NOT; XOR with SIGN_BIT flips only the sign.
    let mask = if raw & SIGN_BIT != 0 { u64::MAX } else { SIGN_BIT };
    (raw ^ mask).to_be_bytes()
}
/// Appends the 8-byte big-endian tag for `seq` / `op_type` to `buf`.
///
/// The tag holds `SEQNUM_MAX - seq` in its upper 56 bits and the op type in its
/// lowest byte, so for equal user keys a higher sequence number produces a
/// smaller tag.
///
/// # Errors
///
/// Returns `MeruError::InvalidArgument` when `seq` is greater than `SEQNUM_MAX`.
fn encode_tag(seq: SeqNum, op_type: OpType, buf: &mut Vec<u8>) -> Result<()> {
    // `checked_sub` yields `None` exactly when `seq` exceeds `SEQNUM_MAX`.
    let inverted = SEQNUM_MAX.0.checked_sub(seq.0).ok_or_else(|| {
        MeruError::InvalidArgument(format!(
            "sequence number {} exceeds SEQNUM_MAX ({})",
            seq.0, SEQNUM_MAX.0
        ))
    })?;
    let tag = (inverted << 8) | (op_type as u64);
    buf.extend_from_slice(&tag.to_be_bytes());
    Ok(())
}
/// Decodes every primary-key column of `schema` from `pk_bytes`, in order.
///
/// # Errors
///
/// Propagates any per-field decoding error, and returns `MeruError::Corruption`
/// when bytes remain after the last primary-key field has been decoded.
fn decode_pk_fields(pk_bytes: &[u8], schema: &TableSchema) -> Result<Vec<FieldValue>> {
    let mut cursor = 0usize;
    let mut decoded = Vec::with_capacity(schema.primary_key.len());

    let pk_col_types = schema
        .primary_key
        .iter()
        .map(|&col_idx| &schema.columns[col_idx].col_type);
    for col_type in pk_col_types {
        let (value, used) = decode_field(&pk_bytes[cursor..], col_type)?;
        decoded.push(value);
        cursor += used;
    }

    // The fields must account for the whole user key.
    let leftover = pk_bytes.len() - cursor;
    if leftover != 0 {
        return Err(MeruError::Corruption(format!(
            "{} leftover bytes after decoding all PK fields",
            leftover
        )));
    }
    Ok(decoded)
}
/// Decodes one key component of type `col_type` from the front of `bytes`.
///
/// Returns the decoded value together with the number of bytes it occupied, so
/// the caller can advance to the next field. This is the inverse of
/// `encode_field`:
/// - `Boolean`: one byte; any non-zero byte decodes to `true`.
/// - `Int32` / `Int64`: big-endian with the sign bit flipped back.
/// - `Float` / `Double`: inverse of `order_preserving_f32` / `order_preserving_f64`.
/// - `FixedLenByteArray(n)`: the next `n` raw bytes.
/// - `ByteArray`: an escaped, terminated sequence (see `unescape_byte_array`).
///
/// # Errors
///
/// Returns `MeruError::Corruption` when `bytes` is too short for a fixed-width
/// field or when an escaped byte array is malformed or unterminated.
fn decode_field(bytes: &[u8], col_type: &ColumnType) -> Result<(FieldValue, usize)> {
    match col_type {
        ColumnType::Boolean => {
            ensure_len(bytes, 1, "boolean")?;
            Ok((FieldValue::Boolean(bytes[0] != 0x00), 1))
        }
        ColumnType::Int32 => {
            ensure_len(bytes, 4, "int32")?;
            let u = u32::from_be_bytes(bytes[..4].try_into().unwrap()) ^ 0x8000_0000;
            Ok((FieldValue::Int32(u as i32), 4))
        }
        ColumnType::Int64 => {
            ensure_len(bytes, 8, "int64")?;
            let u = u64::from_be_bytes(bytes[..8].try_into().unwrap()) ^ 0x8000_0000_0000_0000;
            Ok((FieldValue::Int64(u as i64), 8))
        }
        ColumnType::Float => {
            ensure_len(bytes, 4, "float")?;
            let bits = u32::from_be_bytes(bytes[..4].try_into().unwrap());
            // The encoder flips only the sign bit of non-negative values (so their
            // encoded top bit is 1) and inverts every bit of negative values (so
            // their encoded top bit is 0). Undo whichever transform was applied.
            let orig = if bits >> 31 == 1 {
                bits ^ 0x8000_0000
            } else {
                !bits
            };
            Ok((FieldValue::Float(f32::from_bits(orig)), 4))
        }
        ColumnType::Double => {
            ensure_len(bytes, 8, "double")?;
            let bits = u64::from_be_bytes(bytes[..8].try_into().unwrap());
            // Same scheme as `Float`: encoded top bit 1 means the original was
            // non-negative (sign bit flipped), 0 means negative (all bits inverted).
            let orig = if bits >> 63 == 1 {
                bits ^ 0x8000_0000_0000_0000
            } else {
                !bits
            };
            Ok((FieldValue::Double(f64::from_bits(orig)), 8))
        }
        ColumnType::FixedLenByteArray(n) => {
            let n = *n as usize;
            ensure_len(bytes, n, "fixed-len byte array")?;
            Ok((FieldValue::Bytes(Bytes::copy_from_slice(&bytes[..n])), n))
        }
        ColumnType::ByteArray => {
            let (val, consumed) = unescape_byte_array(bytes)?;
            Ok((FieldValue::Bytes(Bytes::from(val)), consumed))
        }
    }
}
/// Checks that `bytes` holds at least `required` bytes for the named `field`.
///
/// # Errors
///
/// Returns `MeruError::Corruption` describing the shortfall when `bytes` is shorter
/// than `required`.
fn ensure_len(bytes: &[u8], required: usize, field: &str) -> Result<()> {
    let available = bytes.len();
    if available >= required {
        return Ok(());
    }
    Err(MeruError::Corruption(format!(
        "truncated {field} field: need {required}, have {}",
        available
    )))
}
/// Decodes an escaped, terminated byte array from the front of `bytes`.
///
/// Reverses `escape_byte_array`: `0x00 0xFF` yields a literal `0x00`, `0x00 0x00`
/// ends the array, and every other byte is taken verbatim. Returns the decoded
/// bytes and the number of input bytes consumed, terminator included.
///
/// # Errors
///
/// Returns `MeruError::Corruption` when the input ends before a terminator is
/// seen, when it ends right after a lone `0x00`, or when `0x00` is followed by a
/// byte other than `0xFF` or `0x00`.
fn unescape_byte_array(bytes: &[u8]) -> Result<(Vec<u8>, usize)> {
    let mut decoded = Vec::new();
    let mut pos = 0;
    loop {
        // `pos` never exceeds `bytes.len()`, so this slice is always in range.
        match bytes[pos..] {
            [] => {
                return Err(MeruError::Corruption(
                    "unterminated escaped byte array".into(),
                ));
            }
            [0x00] => {
                return Err(MeruError::Corruption(
                    "truncated escape/terminator sequence".into(),
                ));
            }
            [0x00, 0xFF, ..] => {
                decoded.push(0x00);
                pos += 2;
            }
            [0x00, 0x00, ..] => return Ok((decoded, pos + 2)),
            [0x00, other, ..] => {
                return Err(MeruError::Corruption(format!(
                    "invalid byte sequence 0x00 followed by 0x{other:02X} \
                     (expected 0xFF for escape or 0x00 for terminator)"
                )));
            }
            [plain, ..] => {
                decoded.push(plain);
                pos += 1;
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{
        schema::{ColumnDef, ColumnType, TableSchema},
        sequence::{OpType, SeqNum, SEQNUM_MAX},
        value::FieldValue,
    };

    // ---- fixtures ---------------------------------------------------------

    /// Column definition with the given name, type and nullability; everything
    /// else is left at its default.
    fn column(name: &'static str, col_type: ColumnType, nullable: bool) -> ColumnDef {
        ColumnDef {
            name: name.into(),
            col_type,
            nullable,
            ..Default::default()
        }
    }

    /// Table "t" with a single non-nullable column that is also the primary key.
    fn single_pk_schema(name: &'static str, col_type: ColumnType) -> TableSchema {
        TableSchema {
            table_name: "t".into(),
            columns: vec![column(name, col_type, false)],
            primary_key: vec![0],
            ..Default::default()
        }
    }

    fn int64_schema() -> TableSchema {
        single_pk_schema("id", ColumnType::Int64)
    }

    fn bytearray_schema() -> TableSchema {
        single_pk_schema("k", ColumnType::ByteArray)
    }

    fn float_schema() -> TableSchema {
        single_pk_schema("f", ColumnType::Float)
    }

    fn double_schema() -> TableSchema {
        single_pk_schema("d", ColumnType::Double)
    }

    /// Table "t" with PK (a: Int32, b: ByteArray) and a nullable value column.
    fn composite_schema() -> TableSchema {
        TableSchema {
            table_name: "t".into(),
            columns: vec![
                column("a", ColumnType::Int32, false),
                column("b", ColumnType::ByteArray, false),
                column("v", ColumnType::ByteArray, true),
            ],
            primary_key: vec![0, 1],
            ..Default::default()
        }
    }

    /// `Put` key for a single Int64 primary key.
    fn int64_put(value: i64, seq: SeqNum, schema: &TableSchema) -> InternalKey {
        InternalKey::encode(&[FieldValue::Int64(value)], seq, OpType::Put, schema).unwrap()
    }

    /// `Put` key at sequence 0 for a single ByteArray primary key.
    fn bytes_put(raw: &'static [u8], schema: &TableSchema) -> InternalKey {
        InternalKey::encode(
            &[FieldValue::Bytes(Bytes::from_static(raw))],
            SeqNum(0),
            OpType::Put,
            schema,
        )
        .unwrap()
    }

    /// User-key bytes for a single ByteArray primary key.
    fn bytes_user_key(raw: &'static [u8], schema: &TableSchema) -> Vec<u8> {
        InternalKey::encode_user_key(&[FieldValue::Bytes(Bytes::from_static(raw))], schema)
            .unwrap()
    }

    /// Unwraps the payload of a `FieldValue::Bytes`.
    fn expect_bytes(value: &FieldValue) -> &[u8] {
        match value {
            FieldValue::Bytes(b) => &b[..],
            _ => panic!("expected Bytes"),
        }
    }

    /// Asserts that `err` is an `InvalidArgument` whose message mentions NaN.
    fn assert_nan_rejected(err: MeruError) {
        match err {
            MeruError::InvalidArgument(msg) => assert!(msg.contains("NaN"), "msg: {msg}"),
            other => panic!("expected InvalidArgument, got {other:?}"),
        }
    }

    // ---- round trips ------------------------------------------------------

    #[test]
    fn roundtrip_int64() {
        let schema = int64_schema();
        let key = int64_put(42, SeqNum(100), &schema);
        let decoded = InternalKey::decode(key.as_bytes(), &schema).unwrap();
        assert_eq!(decoded.seq, SeqNum(100));
        assert_eq!(decoded.op_type, OpType::Put);
        assert_eq!(decoded.pk_values()[0], FieldValue::Int64(42));
    }

    #[test]
    fn roundtrip_negative_int64() {
        let schema = int64_schema();
        let key = InternalKey::encode(
            &[FieldValue::Int64(-1_000_000)],
            SeqNum(1),
            OpType::Delete,
            &schema,
        )
        .unwrap();
        let decoded = InternalKey::decode(key.as_bytes(), &schema).unwrap();
        assert_eq!(decoded.pk_values()[0], FieldValue::Int64(-1_000_000));
        assert_eq!(decoded.op_type, OpType::Delete);
    }

    #[test]
    fn roundtrip_bytearray_with_nulls() {
        let schema = bytearray_schema();
        let raw = Bytes::from(vec![0x61u8, 0x00, 0xFF, 0x00, 0x62]);
        let key = InternalKey::encode(
            &[FieldValue::Bytes(raw.clone())],
            SeqNum(7),
            OpType::Put,
            &schema,
        )
        .unwrap();
        let decoded = InternalKey::decode(key.as_bytes(), &schema).unwrap();
        assert_eq!(expect_bytes(&decoded.pk_values()[0]), &raw[..]);
    }

    #[test]
    fn roundtrip_composite() {
        let schema = composite_schema();
        let pk = [
            FieldValue::Int32(-5),
            FieldValue::Bytes(Bytes::from("hello\x00world")),
        ];
        let key = InternalKey::encode(&pk, SeqNum(99), OpType::Put, &schema).unwrap();
        let decoded = InternalKey::decode(key.as_bytes(), &schema).unwrap();
        assert_eq!(decoded.pk_values()[0], FieldValue::Int32(-5));
        assert_eq!(expect_bytes(&decoded.pk_values()[1]), b"hello\x00world");
    }

    // ---- ordering ---------------------------------------------------------

    #[test]
    fn newer_seq_sorts_first() {
        let schema = int64_schema();
        let k_old = int64_put(1, SeqNum(1), &schema);
        let k_new = int64_put(1, SeqNum(100), &schema);
        assert!(
            k_new < k_old,
            "newer seq must sort before older for same PK"
        );
    }

    #[test]
    fn seek_latest_sorts_first() {
        let schema = int64_schema();
        let seek = InternalKey::seek_latest(&[FieldValue::Int64(1)], &schema).unwrap();
        let real = int64_put(1, SeqNum(999_999), &schema);
        assert!(seek <= real);
    }

    #[test]
    fn pk_ascending_order() {
        let schema = int64_schema();
        let k1 = int64_put(1, SeqNum(0), &schema);
        let k2 = int64_put(2, SeqNum(0), &schema);
        assert!(k1 < k2);
    }

    #[test]
    fn negative_before_positive_int64() {
        let schema = int64_schema();
        let neg = int64_put(-1, SeqNum(0), &schema);
        let pos = int64_put(1, SeqNum(0), &schema);
        assert!(neg < pos);
    }

    #[test]
    fn i64_min_before_zero_before_max() {
        let schema = int64_schema();
        let kmin = int64_put(i64::MIN, SeqNum(0), &schema);
        let kzero = int64_put(0, SeqNum(0), &schema);
        let kmax = int64_put(i64::MAX, SeqNum(0), &schema);
        assert!(kmin < kzero && kzero < kmax);
    }

    #[test]
    fn bytearray_lexicographic_order() {
        let schema = bytearray_schema();
        let ka = bytes_put(b"abc", &schema);
        let kb = bytes_put(b"abd", &schema);
        let kc = bytes_put(b"abcd", &schema);
        assert!(ka < kb);
        assert!(ka < kc);
    }

    #[test]
    fn bytearray_empty_and_null_keys_distinct_and_ordered() {
        let schema = bytearray_schema();
        let k_empty = bytes_user_key(&[], &schema);
        let k_null1 = bytes_user_key(&[0u8], &schema);
        let k_null2 = bytes_user_key(&[0u8, 0u8], &schema);
        let k_one = bytes_user_key(&[0x01u8], &schema);
        let k_null1_one = bytes_user_key(&[0u8, 0x01u8], &schema);

        // Exact encodings.
        assert_eq!(k_empty, vec![0x00, 0x00]);
        assert_eq!(k_null1, vec![0x00, 0xFF, 0x00, 0x00]);
        assert_eq!(k_null2, vec![0x00, 0xFF, 0x00, 0xFF, 0x00, 0x00]);
        assert_eq!(k_one, vec![0x01, 0x00, 0x00]);
        assert_eq!(k_null1_one, vec![0x00, 0xFF, 0x01, 0x00, 0x00]);

        // Distinct inputs never collide.
        assert_ne!(k_empty, k_null1);
        assert_ne!(k_null1, k_null2);
        assert_ne!(k_null1, k_null1_one);

        // Encoded order follows the lexicographic order of the raw inputs.
        assert!(k_empty < k_null1);
        assert!(k_null1 < k_null2);
        assert!(k_null2 < k_null1_one);
        assert!(k_null1_one < k_one);
    }

    #[test]
    fn bytearray_escape_unescape_roundtrip() {
        let cases: &[&[u8]] = &[
            &[],
            &[0x00],
            &[0x00, 0x00],
            &[0x00, 0x01],
            &[0x01, 0x00],
            &[0xFF],
            &[0x00, 0xFF],
            &[0xFF, 0x00],
            &[0x00, 0xFF, 0x00],
            b"hello",
            b"hello\0world",
        ];
        for case in cases {
            let mut buf = Vec::new();
            escape_byte_array(case, &mut buf);
            let (decoded, consumed) = unescape_byte_array(&buf).unwrap();
            assert_eq!(
                decoded.as_slice(),
                *case,
                "roundtrip failed for {case:?} (encoded={buf:?})"
            );
            assert_eq!(
                consumed,
                buf.len(),
                "consumed={consumed} but buf.len()={} for {case:?}",
                buf.len()
            );
        }
    }

    // ---- floating point ---------------------------------------------------

    #[test]
    fn float_order_neg_before_pos() {
        let schema = float_schema();
        let neg = InternalKey::encode(&[FieldValue::Float(-1.0)], SeqNum(0), OpType::Put, &schema)
            .unwrap();
        let pos = InternalKey::encode(&[FieldValue::Float(1.0)], SeqNum(0), OpType::Put, &schema)
            .unwrap();
        assert!(neg < pos);
    }

    #[test]
    fn nan_float_pk_rejected() {
        let schema = float_schema();
        let err = InternalKey::encode(
            &[FieldValue::Float(f32::NAN)],
            SeqNum(1),
            OpType::Put,
            &schema,
        )
        .unwrap_err();
        assert_nan_rejected(err);
    }

    #[test]
    fn nan_double_pk_rejected() {
        let schema = double_schema();
        let err = InternalKey::encode(
            &[FieldValue::Double(f64::NAN)],
            SeqNum(1),
            OpType::Put,
            &schema,
        )
        .unwrap_err();
        assert_nan_rejected(err);
    }

    #[test]
    fn non_nan_floats_encode_fine() {
        let schema = double_schema();
        let samples = [
            0.0_f64,
            -0.0,
            1.0,
            -1.0,
            f64::INFINITY,
            f64::NEG_INFINITY,
            f64::MIN_POSITIVE,
        ];
        for v in samples {
            InternalKey::encode(&[FieldValue::Double(v)], SeqNum(1), OpType::Put, &schema)
                .unwrap_or_else(|e| panic!("encoding {v} should succeed: {e}"));
        }
    }

    // ---- sequence numbers -------------------------------------------------

    #[test]
    fn seqnum_max_roundtrip() {
        let schema = int64_schema();
        let key = int64_put(0, SEQNUM_MAX, &schema);
        let decoded = InternalKey::decode(key.as_bytes(), &schema).unwrap();
        assert_eq!(decoded.seq, SEQNUM_MAX);
    }
}