use crate::error::CodecError;
use crate::CODEC_SAMPLE_SIZE;
pub fn encode(values: &[f64]) -> Result<Vec<u8>, CodecError> {
let count = values.len() as u32;
if values.is_empty() {
let mut out = Vec::with_capacity(7);
out.extend_from_slice(&0u32.to_le_bytes());
out.push(0); out.extend_from_slice(&0u16.to_le_bytes());
return Ok(out);
}
let cut = find_best_cut(values);
let bits: Vec<u64> = values.iter().map(|v| v.to_bits()).collect();
let front_mask: u64 = if cut == 64 { 0 } else { u64::MAX << cut };
let tail_mask: u64 = if cut == 0 { 0 } else { (1u64 << cut) - 1 };
let tail_bytes_per_value = (cut as usize).div_ceil(8);
let fronts: Vec<u64> = bits.iter().map(|&b| (b & front_mask) >> cut).collect();
let mut dict: Vec<u64> = fronts.clone();
dict.sort_unstable();
dict.dedup();
let indices: Vec<u16> = fronts
.iter()
.map(|f| {
dict.binary_search(f)
.map(|idx| idx as u16)
.map_err(|_| CodecError::Corrupt {
detail: "ALP-RD front value missing from dictionary".into(),
})
})
.collect::<Result<_, _>>()?;
let dict_size = dict.len() as u16;
let use_u8_indices = dict.len() <= 256;
let mut out = Vec::with_capacity(
7 + dict.len() * 8
+ values.len() * if use_u8_indices { 1 } else { 2 }
+ values.len() * tail_bytes_per_value,
);
out.extend_from_slice(&count.to_le_bytes());
out.push(cut);
out.extend_from_slice(&dict_size.to_le_bytes());
for &entry in &dict {
out.extend_from_slice(&entry.to_le_bytes());
}
if use_u8_indices {
for &idx in &indices {
out.push(idx as u8);
}
} else {
for &idx in &indices {
out.extend_from_slice(&idx.to_le_bytes());
}
}
for &b in &bits {
let tail = b & tail_mask;
for byte_idx in 0..tail_bytes_per_value {
out.push((tail >> (byte_idx * 8)) as u8);
}
}
Ok(out)
}
pub fn decode(data: &[u8]) -> Result<Vec<f64>, CodecError> {
if data.len() < 7 {
return Err(CodecError::Truncated {
expected: 7,
actual: data.len(),
});
}
let count = u32::from_le_bytes([data[0], data[1], data[2], data[3]]) as usize;
let cut = data[4];
let dict_size = u16::from_le_bytes([data[5], data[6]]) as usize;
if count == 0 {
return Ok(Vec::new());
}
if cut > 64 {
return Err(CodecError::Corrupt {
detail: format!("invalid ALP-RD cut position: {cut}"),
});
}
let tail_bytes_per_value = (cut as usize).div_ceil(8);
let tail_mask: u64 = if cut == 0 { 0 } else { (1u64 << cut) - 1 };
let use_u8_indices = dict_size <= 256;
let mut pos = 7;
let dict_bytes = dict_size * 8;
if pos + dict_bytes > data.len() {
return Err(CodecError::Truncated {
expected: pos + dict_bytes,
actual: data.len(),
});
}
let mut dict = Vec::with_capacity(dict_size);
for _ in 0..dict_size {
dict.push(u64::from_le_bytes([
data[pos],
data[pos + 1],
data[pos + 2],
data[pos + 3],
data[pos + 4],
data[pos + 5],
data[pos + 6],
data[pos + 7],
]));
pos += 8;
}
let index_bytes = count * if use_u8_indices { 1 } else { 2 };
if pos + index_bytes > data.len() {
return Err(CodecError::Truncated {
expected: pos + index_bytes,
actual: data.len(),
});
}
let mut indices = Vec::with_capacity(count);
if use_u8_indices {
for i in 0..count {
indices.push(data[pos + i] as usize);
}
pos += count;
} else {
for i in 0..count {
let idx_pos = pos + i * 2;
indices.push(u16::from_le_bytes([data[idx_pos], data[idx_pos + 1]]) as usize);
}
pos += count * 2;
}
let tail_total = count * tail_bytes_per_value;
if pos + tail_total > data.len() {
return Err(CodecError::Truncated {
expected: pos + tail_total,
actual: data.len(),
});
}
let mut values = Vec::with_capacity(count);
for (i, &idx) in indices.iter().enumerate() {
if idx >= dict.len() {
return Err(CodecError::Corrupt {
detail: format!("ALP-RD dict index {idx} out of range (max {})", dict.len()),
});
}
let front = dict[idx] << cut;
let mut tail = 0u64;
let tail_pos = pos + i * tail_bytes_per_value;
for byte_idx in 0..tail_bytes_per_value {
tail |= (data[tail_pos + byte_idx] as u64) << (byte_idx * 8);
}
tail &= tail_mask;
values.push(f64::from_bits(front | tail));
}
Ok(values)
}
fn find_best_cut(values: &[f64]) -> u8 {
let sample_end = values.len().min(CODEC_SAMPLE_SIZE);
let sample = &values[..sample_end];
let bits: Vec<u64> = sample.iter().map(|v| v.to_bits()).collect();
let mut best_cut = 48u8;
let mut best_unique = usize::MAX;
for cut in 40..=56 {
let mut fronts: Vec<u64> = bits.iter().map(|&b| b >> cut).collect();
fronts.sort_unstable();
fronts.dedup();
if fronts.len() < best_unique {
best_unique = fronts.len();
best_cut = cut;
}
}
best_cut
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn empty_roundtrip() {
let encoded = encode(&[]).unwrap();
let decoded = decode(&encoded).unwrap();
assert!(decoded.is_empty());
}
#[test]
fn pi_multiples() {
let values: Vec<f64> = (1..1000).map(|i| std::f64::consts::PI * i as f64).collect();
let encoded = encode(&values).unwrap();
let decoded = decode(&encoded).unwrap();
for (i, (a, b)) in values.iter().zip(decoded.iter()).enumerate() {
assert_eq!(a.to_bits(), b.to_bits(), "mismatch at {i}");
}
}
#[test]
fn scientific_data() {
let values: Vec<f64> = (0..1000).map(|i| (i as f64 * 0.001).exp()).collect();
let encoded = encode(&values).unwrap();
let decoded = decode(&encoded).unwrap();
for (i, (a, b)) in values.iter().zip(decoded.iter()).enumerate() {
assert_eq!(a.to_bits(), b.to_bits(), "mismatch at {i}");
}
}
#[test]
fn compression_ratio() {
let values: Vec<f64> = (1..10_000)
.map(|i| std::f64::consts::E * i as f64 + (i as f64).sqrt())
.collect();
let encoded = encode(&values).unwrap();
let raw_size = values.len() * 8;
let ratio = raw_size as f64 / encoded.len() as f64;
assert!(
ratio >= 0.95,
"ALP-RD should not expand >5%, got {ratio:.2}x"
);
let decoded = decode(&encoded).unwrap();
for (a, b) in values.iter().zip(decoded.iter()) {
assert_eq!(a.to_bits(), b.to_bits());
}
}
#[test]
fn special_values() {
let values = vec![
0.0,
-0.0,
f64::INFINITY,
f64::NEG_INFINITY,
f64::NAN,
f64::MIN,
f64::MAX,
f64::MIN_POSITIVE,
];
let encoded = encode(&values).unwrap();
let decoded = decode(&encoded).unwrap();
for (i, (a, b)) in values.iter().zip(decoded.iter()).enumerate() {
assert_eq!(a.to_bits(), b.to_bits(), "mismatch at {i}");
}
}
#[test]
fn identical_values() {
let values = vec![42.0f64; 1000];
let encoded = encode(&values).unwrap();
let decoded = decode(&encoded).unwrap();
for (a, b) in values.iter().zip(decoded.iter()) {
assert_eq!(a.to_bits(), b.to_bits());
}
assert!(encoded.len() < values.len() * 8);
}
#[test]
fn truncated_errors() {
assert!(decode(&[]).is_err());
assert!(decode(&[1, 0, 0, 0, 48, 0]).is_err()); }
}