use crate::{Error, Result};
/// Record kind carried in the low bits of the tag byte: insert record.
pub(crate) const TAG_INSERT: u8 = 0x00;
/// Record kind carried in the low bits of the tag byte: remove record.
pub(crate) const TAG_REMOVE: u8 = 0x01;
/// Record kind carried in the low bits of the tag byte: namespace id -> name record.
pub(crate) const TAG_NAMESPACE_NAME: u8 = 0x02;
/// High bit of the tag byte; set when the record body is encrypted.
pub(crate) const TAG_ENCRYPTED_FLAG: u8 = 1 << 7;
/// Clears [`TAG_ENCRYPTED_FLAG`] from a tag byte, leaving only the record kind.
pub(crate) const TAG_KIND_MASK: u8 = !TAG_ENCRYPTED_FLAG;
/// Size in bytes of the nonce that follows the tag byte of an encrypted record.
pub(crate) const NONCE_LEN: usize = 12;
/// Size in bytes of the AEAD authentication tag (not the record-kind tag) that
/// an encrypted record's ciphertext must at least contain.
pub(crate) const TAG_LEN: usize = 16;
/// A decoded record whose byte fields borrow from the buffer it was decoded from.
///
/// Produced by `decode_payload` and the `decode_*_body` functions. It holds
/// only integers and shared slices, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum RecordView<'a> {
    /// Insert record: a key/value pair within namespace `ns_id`.
    Insert {
        /// Namespace identifier.
        ns_id: u32,
        /// Raw key bytes.
        key: &'a [u8],
        /// Raw value bytes.
        value: &'a [u8],
        /// Expiry value exactly as stored on the wire.
        // NOTE(review): unit/epoch and whether 0 means "never" are not
        // established in this module — confirm with callers.
        expires_at: u64,
    },
    /// Remove record: a key within namespace `ns_id`.
    Remove {
        /// Namespace identifier.
        ns_id: u32,
        /// Raw key bytes.
        key: &'a [u8],
    },
    /// Associates namespace `ns_id` with a name.
    NamespaceName {
        /// Namespace identifier.
        ns_id: u32,
        /// Raw name bytes; no text encoding is validated here.
        name: &'a [u8],
    },
}
/// Owned counterpart of [`RecordView`].
///
/// Returned by `decode_payload_encrypted`, whose decrypted plaintext buffer
/// does not outlive the call, so the decoded bytes must be copied out.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum OwnedRecord {
    /// Insert record: a key/value pair within namespace `ns_id`.
    Insert {
        /// Namespace identifier.
        ns_id: u32,
        /// Raw key bytes.
        key: Vec<u8>,
        /// Raw value bytes.
        value: Vec<u8>,
        /// Expiry value exactly as stored on the wire.
        // NOTE(review): unit/epoch are not established in this module — confirm.
        expires_at: u64,
    },
    /// Remove record: a key within namespace `ns_id`.
    Remove {
        /// Namespace identifier.
        ns_id: u32,
        /// Raw key bytes.
        key: Vec<u8>,
    },
    /// Associates namespace `ns_id` with a name.
    NamespaceName {
        /// Namespace identifier.
        ns_id: u32,
        /// Raw name bytes; no text encoding is validated here.
        name: Vec<u8>,
    },
}

impl OwnedRecord {
    /// Returns the namespace id; every variant carries one.
    #[must_use]
    pub(crate) fn ns_id(&self) -> u32 {
        match self {
            Self::Insert { ns_id, .. }
            | Self::Remove { ns_id, .. }
            | Self::NamespaceName { ns_id, .. } => *ns_id,
        }
    }
}
/// Appends `value` to `buf` as four little-endian bytes.
#[inline]
pub(crate) fn write_u32(buf: &mut Vec<u8>, value: u32) {
    let le_bytes: [u8; 4] = value.to_le_bytes();
    buf.extend(le_bytes.iter());
}
/// Appends `value` to `buf` as eight little-endian bytes.
#[inline]
pub(crate) fn write_u64(buf: &mut Vec<u8>, value: u64) {
    let le_bytes: [u8; 8] = value.to_le_bytes();
    buf.extend(le_bytes.iter());
}
#[inline]
pub(crate) fn read_u32(bytes: &[u8], offset: usize) -> Result<u32> {
if offset + 4 > bytes.len() {
return Err(Error::Corrupted {
offset: offset as u64,
reason: "u32 read past end of buffer",
});
}
let mut buf = [0_u8; 4];
buf.copy_from_slice(&bytes[offset..offset + 4]);
Ok(u32::from_le_bytes(buf))
}
#[inline]
pub(crate) fn read_u64(bytes: &[u8], offset: usize) -> Result<u64> {
if offset + 8 > bytes.len() {
return Err(Error::Corrupted {
offset: offset as u64,
reason: "u64 read past end of buffer",
});
}
let mut buf = [0_u8; 8];
buf.copy_from_slice(&bytes[offset..offset + 8]);
Ok(u64::from_le_bytes(buf))
}
/// Appends the body of an insert record (everything after the tag byte) to `out`.
///
/// Layout, all integers little-endian:
/// `ns_id: u32 | key_len: u32 | key | value_len: u32 | value | expires_at: u64`,
/// which is the layout `decode_insert_body` reads.
///
/// # Panics
///
/// Panics if `key` or `value` is longer than `u32::MAX` bytes. The length
/// prefixes are `u32` on the wire, and an unchecked `as` cast would silently
/// truncate them, writing a record whose prefix disagrees with the bytes that
/// follow. The check runs before anything is appended, so `out` is untouched
/// on panic.
pub(crate) fn encode_insert_body(
    out: &mut Vec<u8>,
    ns_id: u32,
    key: &[u8],
    value: &[u8],
    expires_at: u64,
) {
    assert!(
        key.len() <= u32::MAX as usize,
        "insert key longer than u32::MAX bytes"
    );
    assert!(
        value.len() <= u32::MAX as usize,
        "insert value longer than u32::MAX bytes"
    );
    write_u32(out, ns_id);
    // Both length casts below are lossless after the asserts above.
    write_u32(out, key.len() as u32);
    out.extend_from_slice(key);
    write_u32(out, value.len() as u32);
    out.extend_from_slice(value);
    write_u64(out, expires_at);
}
/// Appends the body of a remove record (everything after the tag byte) to `out`.
///
/// Layout, all integers little-endian: `ns_id: u32 | key_len: u32 | key`,
/// which is the layout `decode_remove_body` reads.
///
/// # Panics
///
/// Panics if `key` is longer than `u32::MAX` bytes. An unchecked `as` cast
/// would silently truncate the `u32` length prefix and produce a corrupt
/// record. The check runs before anything is appended to `out`.
pub(crate) fn encode_remove_body(out: &mut Vec<u8>, ns_id: u32, key: &[u8]) {
    assert!(
        key.len() <= u32::MAX as usize,
        "remove key longer than u32::MAX bytes"
    );
    write_u32(out, ns_id);
    // Lossless after the assert above.
    write_u32(out, key.len() as u32);
    out.extend_from_slice(key);
}
/// Appends the body of a namespace-name record (everything after the tag
/// byte) to `out`.
///
/// Layout, all integers little-endian: `ns_id: u32 | name_len: u32 | name`,
/// which is the layout `decode_namespace_name_body` reads.
///
/// # Panics
///
/// Panics if `name` is longer than `u32::MAX` bytes. An unchecked `as` cast
/// would silently truncate the `u32` length prefix and produce a corrupt
/// record. The check runs before anything is appended to `out`.
pub(crate) fn encode_namespace_name_body(out: &mut Vec<u8>, ns_id: u32, name: &[u8]) {
    assert!(
        name.len() <= u32::MAX as usize,
        "namespace name longer than u32::MAX bytes"
    );
    write_u32(out, ns_id);
    // Lossless after the assert above.
    write_u32(out, name.len() as u32);
    out.extend_from_slice(name);
}
/// Decodes the body of an insert record (the payload bytes after the tag).
///
/// Expected layout, all integers little-endian:
/// `ns_id: u32 | key_len: u32 | key | value_len: u32 | value | expires_at: u64`.
/// Bytes after `expires_at` are not inspected. `key` and `value` in the
/// returned view borrow from `body`.
///
/// # Errors
///
/// Returns [`Error::Corrupted`] if a fixed-width field is cut short or a
/// length prefix runs past the end of `body`.
pub(crate) fn decode_insert_body(body: &[u8]) -> Result<RecordView<'_>> {
    let ns_id = read_u32(body, 0)?;
    let key_len = read_u32(body, 4)? as usize;
    // `checked_add`: on 32-bit targets a corrupt length prefix can overflow
    // `usize`, which would otherwise panic (debug) or wrap past this check.
    let key_end = match 8_usize.checked_add(key_len) {
        Some(end) if end <= body.len() => end,
        _ => {
            return Err(Error::Corrupted {
                offset: 8,
                reason: "insert body truncated mid-key",
            })
        }
    };
    let key = &body[8..key_end];
    let value_len = read_u32(body, key_end)? as usize;
    // `read_u32` succeeded, so `key_end + 4 <= body.len()`: no overflow here.
    let value_start = key_end + 4;
    let value_end = match value_start.checked_add(value_len) {
        Some(end) if end <= body.len() => end,
        _ => {
            return Err(Error::Corrupted {
                offset: value_start as u64,
                reason: "insert body truncated mid-value",
            })
        }
    };
    let value = &body[value_start..value_end];
    let expires_at = read_u64(body, value_end)?;
    Ok(RecordView::Insert {
        ns_id,
        key,
        value,
        expires_at,
    })
}
/// Decodes the body of a remove record (the payload bytes after the tag).
///
/// Expected layout, all integers little-endian: `ns_id: u32 | key_len: u32 | key`.
/// Bytes after the key are not inspected. `key` borrows from `body`.
///
/// # Errors
///
/// Returns [`Error::Corrupted`] if the header is cut short or the key length
/// runs past the end of `body`.
pub(crate) fn decode_remove_body(body: &[u8]) -> Result<RecordView<'_>> {
    let ns_id = read_u32(body, 0)?;
    let key_len = read_u32(body, 4)? as usize;
    // `checked_add`: on 32-bit targets a corrupt length prefix can overflow
    // `usize`, which would otherwise panic (debug) or wrap past this check.
    let key_end = match 8_usize.checked_add(key_len) {
        Some(end) if end <= body.len() => end,
        _ => {
            return Err(Error::Corrupted {
                offset: 8,
                reason: "remove body truncated mid-key",
            })
        }
    };
    let key = &body[8..key_end];
    Ok(RecordView::Remove { ns_id, key })
}
/// Decodes the body of a namespace-name record (the payload bytes after the tag).
///
/// Expected layout, all integers little-endian: `ns_id: u32 | name_len: u32 | name`.
/// Bytes after the name are not inspected. `name` borrows from `body` and is
/// not validated as text.
///
/// # Errors
///
/// Returns [`Error::Corrupted`] if the header is cut short or the name length
/// runs past the end of `body`.
pub(crate) fn decode_namespace_name_body(body: &[u8]) -> Result<RecordView<'_>> {
    let ns_id = read_u32(body, 0)?;
    let name_len = read_u32(body, 4)? as usize;
    // `checked_add`: on 32-bit targets a corrupt length prefix can overflow
    // `usize`, which would otherwise panic (debug) or wrap past this check.
    let name_end = match 8_usize.checked_add(name_len) {
        Some(end) if end <= body.len() => end,
        _ => {
            return Err(Error::Corrupted {
                offset: 8,
                reason: "namespace-name body truncated mid-name",
            })
        }
    };
    let name = &body[8..name_end];
    Ok(RecordView::NamespaceName { ns_id, name })
}
pub(crate) fn decode_payload(payload: &[u8]) -> Result<RecordView<'_>> {
if payload.is_empty() {
return Err(Error::Corrupted {
offset: 0,
reason: "empty record payload",
});
}
let tag = payload[0];
if (tag & TAG_ENCRYPTED_FLAG) != 0 {
return Err(Error::Corrupted {
offset: 0,
reason: "encrypted record passed to plaintext decoder",
});
}
let body = &payload[1..];
match tag & TAG_KIND_MASK {
TAG_INSERT => decode_insert_body(body),
TAG_REMOVE => decode_remove_body(body),
TAG_NAMESPACE_NAME => decode_namespace_name_body(body),
unknown => Err(Error::Corrupted {
offset: 0,
reason: kind_error_for(unknown),
}),
}
}
pub(crate) fn decode_payload_encrypted<F>(payload: &[u8], decrypt: F) -> Result<OwnedRecord>
where
F: FnOnce(&[u8; NONCE_LEN], &[u8]) -> Result<Vec<u8>>,
{
if payload.len() < 1 + NONCE_LEN + TAG_LEN {
return Err(Error::Corrupted {
offset: 0,
reason: "encrypted payload shorter than nonce + AEAD tag",
});
}
let tag = payload[0];
if (tag & TAG_ENCRYPTED_FLAG) == 0 {
return Err(Error::Corrupted {
offset: 0,
reason: "plaintext record passed to encrypted decoder",
});
}
let kind = tag & TAG_KIND_MASK;
let mut nonce = [0_u8; NONCE_LEN];
nonce.copy_from_slice(&payload[1..1 + NONCE_LEN]);
let ciphertext = &payload[1 + NONCE_LEN..];
let plaintext = decrypt(&nonce, ciphertext)?;
match kind {
TAG_INSERT => match decode_insert_body(&plaintext)? {
RecordView::Insert {
ns_id,
key,
value,
expires_at,
} => Ok(OwnedRecord::Insert {
ns_id,
key: key.to_vec(),
value: value.to_vec(),
expires_at,
}),
_ => Err(Error::Corrupted {
offset: 0,
reason: "encrypted body shape mismatched its tag",
}),
},
TAG_REMOVE => match decode_remove_body(&plaintext)? {
RecordView::Remove { ns_id, key } => Ok(OwnedRecord::Remove {
ns_id,
key: key.to_vec(),
}),
_ => Err(Error::Corrupted {
offset: 0,
reason: "encrypted body shape mismatched its tag",
}),
},
TAG_NAMESPACE_NAME => match decode_namespace_name_body(&plaintext)? {
RecordView::NamespaceName { ns_id, name } => Ok(OwnedRecord::NamespaceName {
ns_id,
name: name.to_vec(),
}),
_ => Err(Error::Corrupted {
offset: 0,
reason: "encrypted body shape mismatched its tag",
}),
},
unknown => Err(Error::Corrupted {
offset: 0,
reason: kind_error_for(unknown),
}),
}
}
/// Reads the little-endian `u32` payload length stored in the four bytes
/// immediately before `payload_start` in `bytes`.
///
/// # Errors
///
/// Returns [`Error::Corrupted`] (reporting `payload_start`) in two cases:
/// - `payload_start < 4`, leaving no room for the length field before it;
/// - `payload_start` lies past the end of `bytes`.
pub(crate) fn payload_len_at(bytes: &[u8], payload_start: usize) -> Result<usize> {
    let geometry_error = if payload_start < 4 {
        Some("payload_start within frame header")
    } else if payload_start > bytes.len() {
        Some("payload_start past buffer end")
    } else {
        None
    };
    if let Some(reason) = geometry_error {
        return Err(Error::Corrupted {
            offset: payload_start as u64,
            reason,
        });
    }
    let len_field_offset = payload_start - 4;
    read_u32(bytes, len_field_offset).map(|len| len as usize)
}
pub(crate) fn payload_at(bytes: &[u8], payload_start: usize) -> Result<&[u8]> {
let len = payload_len_at(bytes, payload_start)?;
let end = payload_start.checked_add(len).ok_or(Error::Corrupted {
offset: payload_start as u64,
reason: "payload_start + length overflowed",
})?;
if end > bytes.len() {
return Err(Error::Corrupted {
offset: payload_start as u64,
reason: "payload extends past buffer end",
});
}
Ok(&bytes[payload_start..end])
}
/// Returns the error reason used when a tag's kind bits match no known record kind.
///
/// The kind value is currently ignored; the reason is a fixed `&'static str`.
#[inline]
fn kind_error_for(_unknown_kind: u8) -> &'static str {
    const UNKNOWN_KIND_REASON: &str = "unknown record tag kind";
    UNKNOWN_KIND_REASON
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn insert_body_round_trips() {
        let mut body = Vec::new();
        encode_insert_body(&mut body, 7, b"key-bytes", b"value-bytes", 12345);
        match decode_insert_body(&body).expect("decode") {
            RecordView::Insert {
                ns_id,
                key,
                value,
                expires_at,
            } => {
                assert_eq!(ns_id, 7);
                assert_eq!(key, b"key-bytes");
                assert_eq!(value, b"value-bytes");
                assert_eq!(expires_at, 12345);
            }
            _ => panic!("expected Insert"),
        }
    }

    #[test]
    fn remove_body_round_trips() {
        let mut body = Vec::new();
        encode_remove_body(&mut body, 9, b"gone");
        match decode_remove_body(&body).expect("decode") {
            RecordView::Remove { ns_id, key } => {
                assert_eq!(ns_id, 9);
                assert_eq!(key, b"gone");
            }
            _ => panic!("expected Remove"),
        }
    }

    #[test]
    fn namespace_name_body_round_trips() {
        let mut body = Vec::new();
        encode_namespace_name_body(&mut body, 4, b"users");
        match decode_namespace_name_body(&body).expect("decode") {
            RecordView::NamespaceName { ns_id, name } => {
                assert_eq!(ns_id, 4);
                assert_eq!(name, b"users");
            }
            _ => panic!("expected NamespaceName"),
        }
    }

    #[test]
    fn payload_round_trips_via_decode_payload() {
        let mut payload = vec![TAG_INSERT];
        encode_insert_body(&mut payload, 0, b"k", b"v", 0);
        match decode_payload(&payload).expect("decode") {
            RecordView::Insert {
                ns_id,
                key,
                value,
                expires_at,
            } => {
                assert_eq!(ns_id, 0);
                assert_eq!(key, b"k");
                assert_eq!(value, b"v");
                assert_eq!(expires_at, 0);
            }
            _ => panic!("expected Insert"),
        }
    }

    #[test]
    fn truncated_bodies_error() {
        // Insert layout here: 4 (ns) + 4 (key_len) + 6 (key) + 4 (value_len)
        // + 5 (value) + 8 (expires_at) = 31 bytes. The cuts land inside the
        // key, the value_len, the value, and expires_at respectively.
        let mut insert = Vec::new();
        encode_insert_body(&mut insert, 1, b"abcdef", b"value", 5);
        for &cut in &[10_usize, 16, 20, insert.len() - 1] {
            assert!(
                matches!(decode_insert_body(&insert[..cut]), Err(Error::Corrupted { .. })),
                "insert cut at {}",
                cut
            );
        }

        let mut remove = Vec::new();
        encode_remove_body(&mut remove, 1, b"abcdef");
        for &cut in &[0_usize, 6, 10, remove.len() - 1] {
            assert!(
                matches!(decode_remove_body(&remove[..cut]), Err(Error::Corrupted { .. })),
                "remove cut at {}",
                cut
            );
        }

        let mut named = Vec::new();
        encode_namespace_name_body(&mut named, 1, b"abcdef");
        for &cut in &[0_usize, 6, 10, named.len() - 1] {
            assert!(
                matches!(
                    decode_namespace_name_body(&named[..cut]),
                    Err(Error::Corrupted { .. })
                ),
                "namespace-name cut at {}",
                cut
            );
        }
    }

    #[test]
    fn oversized_length_prefix_errors() {
        // A length prefix far larger than the remaining bytes must be reported
        // as corruption rather than panicking on an out-of-range slice.
        let mut body = Vec::new();
        write_u32(&mut body, 0);
        write_u32(&mut body, u32::MAX);
        assert!(matches!(decode_insert_body(&body), Err(Error::Corrupted { .. })));
        assert!(matches!(decode_remove_body(&body), Err(Error::Corrupted { .. })));
        assert!(matches!(
            decode_namespace_name_body(&body),
            Err(Error::Corrupted { .. })
        ));
    }

    #[test]
    fn empty_payload_errors() {
        let result = decode_payload(&[]);
        assert!(matches!(result, Err(Error::Corrupted { .. })));
    }

    #[test]
    fn unknown_tag_errors() {
        let result = decode_payload(&[0x42_u8, 0, 0, 0, 0, 0, 0, 0, 0]);
        assert!(matches!(result, Err(Error::Corrupted { .. })));
    }

    #[test]
    fn encrypted_tag_to_plaintext_decoder_errors() {
        let payload = vec![TAG_INSERT | TAG_ENCRYPTED_FLAG];
        let result = decode_payload(&payload);
        assert!(matches!(result, Err(Error::Corrupted { .. })));
    }

    #[test]
    fn encrypted_insert_decodes_via_decrypt_closure() {
        let mut body = Vec::new();
        encode_insert_body(&mut body, 3, b"key", b"value", 99);
        let mut payload = vec![TAG_INSERT | TAG_ENCRYPTED_FLAG];
        payload.extend_from_slice(&[7_u8; NONCE_LEN]);
        payload.extend_from_slice(&[0xEE_u8; TAG_LEN]);
        // Stand-in for real decryption: check the split of nonce/ciphertext,
        // then hand back a known plaintext body.
        let record = decode_payload_encrypted(&payload, |nonce, ciphertext| {
            assert_eq!(nonce, &[7_u8; NONCE_LEN]);
            assert_eq!(ciphertext, &[0xEE_u8; TAG_LEN][..]);
            Ok(body)
        })
        .expect("decode");
        assert_eq!(record.ns_id(), 3);
        match record {
            OwnedRecord::Insert {
                ns_id,
                key,
                value,
                expires_at,
            } => {
                assert_eq!(ns_id, 3);
                assert_eq!(key, b"key");
                assert_eq!(value, b"value");
                assert_eq!(expires_at, 99);
            }
            other => panic!("expected Insert, got {:?}", other),
        }
    }

    #[test]
    fn short_encrypted_payload_errors() {
        let payload = vec![TAG_INSERT | TAG_ENCRYPTED_FLAG; NONCE_LEN + TAG_LEN];
        let result =
            decode_payload_encrypted(&payload, |_, _| unreachable!("decrypt must not run"));
        assert!(matches!(result, Err(Error::Corrupted { .. })));
    }

    #[test]
    fn plaintext_tag_to_encrypted_decoder_errors() {
        let payload = vec![TAG_INSERT; 1 + NONCE_LEN + TAG_LEN];
        let result =
            decode_payload_encrypted(&payload, |_, _| unreachable!("decrypt must not run"));
        assert!(matches!(result, Err(Error::Corrupted { .. })));
    }

    #[test]
    fn encrypted_unknown_kind_errors() {
        let mut payload = vec![0x42 | TAG_ENCRYPTED_FLAG];
        payload.extend_from_slice(&[0_u8; NONCE_LEN + TAG_LEN]);
        let result = decode_payload_encrypted(&payload, |_, _| Ok(Vec::new()));
        assert!(matches!(result, Err(Error::Corrupted { .. })));
    }

    #[test]
    fn payload_at_handles_basic_geometry() {
        // Frame used here: a 4-byte marker, the little-endian payload length,
        // the payload, then 4 trailing bytes. Only the length field and the
        // payload itself are read by `payload_at` / `payload_len_at`.
        let mut frame = Vec::new();
        frame.extend_from_slice(&0x4653_5901_u32.to_be_bytes());
        frame.extend_from_slice(&5_u32.to_le_bytes());
        frame.extend_from_slice(b"hello");
        frame.extend_from_slice(&0_u32.to_le_bytes());
        let payload_start = 8;
        let payload = payload_at(&frame, payload_start).expect("payload_at");
        assert_eq!(payload, b"hello");
        assert_eq!(payload_len_at(&frame, payload_start).expect("len"), 5);
    }

    #[test]
    fn payload_at_rejects_bad_geometry() {
        let mut frame = Vec::new();
        frame.extend_from_slice(&10_u32.to_le_bytes());
        frame.extend_from_slice(b"short");
        // Length prefix claims 10 bytes but only 5 follow.
        assert!(matches!(payload_at(&frame, 4), Err(Error::Corrupted { .. })));
        // No room for the 4-byte length prefix before the payload.
        assert!(matches!(payload_at(&frame, 3), Err(Error::Corrupted { .. })));
        // Start beyond the end of the buffer.
        assert!(matches!(
            payload_len_at(&frame, frame.len() + 1),
            Err(Error::Corrupted { .. })
        ));
    }
}