use crate::error::{Error, Result};
use crate::schema::registry::ParsingContext;
use crate::storage::sstable::bti::encoder::ByteComparableEncoder;
use crate::types::{ComparatorType, Value};
use murmur3::murmur3_32;
use std::io::Cursor;
pub struct KeyDigestComputer {
encoder: ByteComparableEncoder,
}
impl KeyDigestComputer {
pub fn new() -> Self {
Self {
encoder: ByteComparableEncoder::new(),
}
}
pub fn compute_partition_key_digest(
&mut self,
partition_key_bytes: &[u8],
parsing_context: &ParsingContext,
) -> Result<Vec<u8>> {
let partition_values = self.parse_partition_key_bytes(
partition_key_bytes,
&parsing_context.partition_comparators,
)?;
let byte_comparable_key = self.encoder.encode_composite_key(&partition_values)?;
let mut cursor = Cursor::new(&byte_comparable_key);
let hash = murmur3_32(&mut cursor, 0)
.map_err(|e| Error::corruption(format!("Failed to compute Murmur3 hash: {}", e)))?;
Ok(hash.to_le_bytes().to_vec())
}
fn parse_partition_key_bytes(
&self,
key_bytes: &[u8],
partition_comparators: &[ComparatorType],
) -> Result<Vec<Value>> {
if partition_comparators.is_empty() {
return Err(Error::Schema(
"No partition key comparators provided".to_string(),
));
}
if partition_comparators.len() == 1 {
let value = self.parse_value_bytes(key_bytes, &partition_comparators[0])?;
return Ok(vec![value]);
}
let mut values = Vec::new();
let mut offset = 0;
for (index, comparator) in partition_comparators.iter().enumerate() {
if offset >= key_bytes.len() {
return Err(Error::corruption(
"Insufficient bytes for multi-component partition key".to_string(),
));
}
if offset + 2 > key_bytes.len() {
return Err(Error::corruption(
"Invalid component length in partition key".to_string(),
));
}
let component_len =
u16::from_be_bytes([key_bytes[offset], key_bytes[offset + 1]]) as usize;
offset += 2;
if offset + component_len > key_bytes.len() {
return Err(Error::corruption(
"Component length exceeds available bytes".to_string(),
));
}
let component_bytes = &key_bytes[offset..offset + component_len];
let value = self.parse_value_bytes(component_bytes, comparator)?;
values.push(value);
offset += component_len;
if offset >= key_bytes.len() {
return Err(Error::corruption(
"Missing end-of-component marker in multi-component partition key".to_string(),
));
}
if key_bytes[offset] != 0x00 {
return Err(Error::corruption(
"Invalid end-of-component marker in multi-component partition key".to_string(),
));
}
offset += 1;
if index + 1 == partition_comparators.len() && offset != key_bytes.len() {
return Err(Error::corruption(
"Unexpected trailing bytes in multi-component partition key".to_string(),
));
}
}
Ok(values)
}
fn parse_value_bytes(&self, bytes: &[u8], comparator: &ComparatorType) -> Result<Value> {
match comparator {
ComparatorType::Boolean => {
if bytes.len() != 1 {
return Err(Error::corruption("Invalid boolean bytes".to_string()));
}
Ok(Value::Boolean(bytes[0] != 0))
}
ComparatorType::TinyInt => {
if bytes.len() != 1 {
return Err(Error::corruption("Invalid tinyint bytes".to_string()));
}
Ok(Value::TinyInt(bytes[0] as i8))
}
ComparatorType::SmallInt => {
if bytes.len() != 2 {
return Err(Error::corruption("Invalid smallint bytes".to_string()));
}
let value = i16::from_be_bytes([bytes[0], bytes[1]]);
Ok(Value::SmallInt(value))
}
ComparatorType::Int => {
if bytes.len() != 4 {
return Err(Error::corruption("Invalid int bytes".to_string()));
}
let value = i32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
Ok(Value::Integer(value))
}
ComparatorType::BigInt => {
if bytes.len() != 8 {
return Err(Error::corruption("Invalid bigint bytes".to_string()));
}
let value = i64::from_be_bytes([
bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
]);
Ok(Value::BigInt(value))
}
ComparatorType::Counter => {
if bytes.len() != 8 {
return Err(Error::corruption("Invalid counter bytes".to_string()));
}
let value = i64::from_be_bytes([
bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
]);
Ok(Value::Counter(value))
}
ComparatorType::Float32 => {
if bytes.len() != 4 {
return Err(Error::corruption("Invalid float32 bytes".to_string()));
}
let bits = u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
let value = f32::from_bits(bits);
Ok(Value::Float32(value))
}
ComparatorType::Float => {
if bytes.len() != 8 {
return Err(Error::corruption("Invalid float bytes".to_string()));
}
let bits = u64::from_be_bytes([
bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
]);
let value = f64::from_bits(bits);
Ok(Value::Float(value))
}
ComparatorType::Text => {
let text = String::from_utf8(bytes.to_vec())
.map_err(|e| Error::corruption(format!("Invalid UTF-8 in text: {}", e)))?;
Ok(Value::Text(text))
}
ComparatorType::Blob => Ok(Value::Blob(bytes.to_vec())),
ComparatorType::Timestamp => {
if bytes.len() != 8 {
return Err(Error::corruption("Invalid timestamp bytes".to_string()));
}
let millis = i64::from_be_bytes([
bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
]);
Ok(Value::Timestamp(millis))
}
ComparatorType::Uuid => {
if bytes.len() != 16 {
return Err(Error::corruption("Invalid UUID bytes".to_string()));
}
let uuid_bytes: [u8; 16] = bytes
.try_into()
.map_err(|_| Error::invalid_format("Invalid UUID byte length"))?;
Ok(Value::Uuid(uuid_bytes))
}
ComparatorType::Date => {
if bytes.len() != 4 {
return Err(Error::corruption("Invalid date bytes".to_string()));
}
let stored = u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
let days_since_epoch = stored.wrapping_add(i32::MIN as u32) as i32;
Ok(Value::Date(days_since_epoch))
}
ComparatorType::List(_)
| ComparatorType::Set(_)
| ComparatorType::Map(_, _)
| ComparatorType::Tuple(_)
| ComparatorType::Udt { .. }
| ComparatorType::Frozen(_)
| ComparatorType::Custom(_)
| ComparatorType::Varint
| ComparatorType::Decimal
| ComparatorType::Duration
| ComparatorType::Json => {
log::warn!(
"Complex type {} in partition key - using blob fallback",
comparator.type_name()
);
Ok(Value::Blob(bytes.to_vec()))
}
}
}
pub fn compute_simple_digest(&self, partition_key: &[u8]) -> Result<Vec<u8>> {
let mut cursor = Cursor::new(partition_key);
let hash = murmur3_32(&mut cursor, 0)
.map_err(|e| Error::corruption(format!("Failed to compute Murmur3 hash: {}", e)))?;
Ok(hash.to_le_bytes().to_vec())
}
}
impl Default for KeyDigestComputer {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::{KeyColumn, TableSchema};
    use std::collections::HashMap;

    /// Builds a parsing context whose partition comparators are `partition_comparators`.
    ///
    /// The embedded schema always declares a single `int` partition key. The parser under
    /// test only reads `partition_comparators`, so the schema is just a placeholder.
    fn create_test_parsing_context(partition_comparators: Vec<ComparatorType>) -> ParsingContext {
        let schema = TableSchema {
            keyspace: "test".to_string(),
            table: "table".to_string(),
            partition_keys: vec![KeyColumn {
                name: "pk".to_string(),
                data_type: "int".to_string(),
                position: 0,
            }],
            clustering_keys: vec![],
            columns: vec![],
            comments: HashMap::new(),
        };
        ParsingContext {
            schema,
            partition_comparators,
            clustering_comparators: vec![],
            column_comparators: HashMap::new(),
        }
    }

    /// Frames each component as `<u16 BE length><bytes><0x00>`, the layout the parser
    /// accepts for multi-component partition keys.
    fn composite_key(components: &[&[u8]]) -> Vec<u8> {
        let mut key = Vec::new();
        for component in components {
            let len = u16::try_from(component.len()).expect("test component fits in a u16");
            key.extend_from_slice(&len.to_be_bytes());
            key.extend_from_slice(component);
            key.push(0x00);
        }
        key
    }

    /// Composite key of int 42 followed by text "hello".
    fn int_text_key() -> Vec<u8> {
        let int_component: &[u8] = &[0x00, 0x00, 0x00, 0x2A];
        let text_component: &[u8] = b"hello";
        composite_key(&[int_component, text_component])
    }

    fn int_text_context() -> ParsingContext {
        create_test_parsing_context(vec![ComparatorType::Int, ComparatorType::Text])
    }

    #[test]
    fn test_single_component_int_key() {
        let mut computer = KeyDigestComputer::new();
        let context = create_test_parsing_context(vec![ComparatorType::Int]);
        let key_bytes = [0x00, 0x00, 0x00, 0x2A];
        let digest = computer
            .compute_partition_key_digest(&key_bytes, &context)
            .unwrap();
        assert_eq!(digest.len(), 4);
        let digest2 = computer
            .compute_partition_key_digest(&key_bytes, &context)
            .unwrap();
        assert_eq!(digest, digest2);
    }

    #[test]
    fn test_single_component_text_key() {
        let mut computer = KeyDigestComputer::new();
        let context = create_test_parsing_context(vec![ComparatorType::Text]);
        let digest = computer
            .compute_partition_key_digest(b"hello", &context)
            .unwrap();
        assert_eq!(digest.len(), 4);
    }

    #[test]
    fn test_multi_component_key() {
        let mut computer = KeyDigestComputer::new();
        let digest = computer
            .compute_partition_key_digest(&int_text_key(), &int_text_context())
            .unwrap();
        assert_eq!(digest.len(), 4);
    }

    #[test]
    fn test_multi_component_key_rejects_missing_final_separator() {
        let mut computer = KeyDigestComputer::new();
        let mut key_bytes = int_text_key();
        key_bytes.pop();
        let err = computer
            .compute_partition_key_digest(&key_bytes, &int_text_context())
            .expect_err("missing final separator must be rejected");
        assert!(
            err.to_string().contains("Missing end-of-component marker"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_multi_component_key_rejects_invalid_separator() {
        let mut computer = KeyDigestComputer::new();
        let mut key_bytes = int_text_key();
        // The first marker follows the 2-byte length and the 4-byte int payload.
        key_bytes[6] = 0x01;
        let err = computer
            .compute_partition_key_digest(&key_bytes, &int_text_context())
            .expect_err("non-zero separator must be rejected");
        assert!(
            err.to_string().contains("Invalid end-of-component marker"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_multi_component_key_rejects_trailing_bytes() {
        let mut computer = KeyDigestComputer::new();
        let mut key_bytes = int_text_key();
        key_bytes.push(0xFF);
        let err = computer
            .compute_partition_key_digest(&key_bytes, &int_text_context())
            .expect_err("trailing bytes must be rejected");
        assert!(
            err.to_string().contains("Unexpected trailing bytes"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_multi_component_key_rejects_overlong_component() {
        let mut computer = KeyDigestComputer::new();
        let mut key_bytes = int_text_key();
        // Claim the first component is 255 bytes long.
        key_bytes[1] = 0xFF;
        let err = computer
            .compute_partition_key_digest(&key_bytes, &int_text_context())
            .expect_err("overlong component must be rejected");
        assert!(
            err.to_string().contains("Component length exceeds available bytes"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_empty_comparators_rejected() {
        let mut computer = KeyDigestComputer::new();
        let context = create_test_parsing_context(vec![]);
        assert!(computer
            .compute_partition_key_digest(&[0x01], &context)
            .is_err());
    }

    #[test]
    fn test_date_decoding_removes_offset() {
        let computer = KeyDigestComputer::new();
        let epoch = computer
            .parse_value_bytes(&[0x80, 0x00, 0x00, 0x00], &ComparatorType::Date)
            .unwrap();
        assert!(matches!(epoch, Value::Date(0)));
        let day_before = computer
            .parse_value_bytes(&[0x7F, 0xFF, 0xFF, 0xFF], &ComparatorType::Date)
            .unwrap();
        assert!(matches!(day_before, Value::Date(-1)));
    }

    #[test]
    fn test_fixed_width_length_mismatch_rejected() {
        let computer = KeyDigestComputer::new();
        let err = computer
            .parse_value_bytes(&[0x00, 0x01], &ComparatorType::Int)
            .err()
            .expect("short int must be rejected");
        assert!(
            err.to_string().contains("Invalid int bytes"),
            "unexpected error: {err}"
        );
        let uuid = computer
            .parse_value_bytes(&[0x11; 16], &ComparatorType::Uuid)
            .unwrap();
        assert!(matches!(uuid, Value::Uuid(bytes) if bytes == [0x11; 16]));
    }

    #[test]
    fn test_simple_digest_fallback() -> Result<()> {
        let computer = KeyDigestComputer::new();
        let key_bytes = b"test_key";
        let digest = computer.compute_simple_digest(key_bytes)?;
        assert_eq!(digest.len(), 4);
        let digest2 = computer.compute_simple_digest(key_bytes)?;
        assert_eq!(digest, digest2);
        Ok(())
    }

    #[test]
    fn test_simple_digest_known_vectors() -> Result<()> {
        let computer = KeyDigestComputer::new();
        // Murmur3 x86 32-bit, seed 0: the empty input hashes to 0.
        assert_eq!(computer.compute_simple_digest(b"")?, vec![0, 0, 0, 0]);
        // Reference value for "foo" with seed 0.
        assert_eq!(
            computer.compute_simple_digest(b"foo")?,
            4_138_058_784u32.to_le_bytes().to_vec()
        );
        Ok(())
    }

    #[test]
    fn test_byte_ordering_equivalence() {
        let mut computer = KeyDigestComputer::new();
        let context = create_test_parsing_context(vec![ComparatorType::Int]);
        let key1_bytes = [0x00, 0x00, 0x00, 0x01];
        let key2_bytes = [0x00, 0x00, 0x00, 0x02];
        let digest1 = computer
            .compute_partition_key_digest(&key1_bytes, &context)
            .unwrap();
        let digest2 = computer
            .compute_partition_key_digest(&key2_bytes, &context)
            .unwrap();
        assert_ne!(digest1, digest2);
    }
}