#![allow(clippy::doc_markdown)]
use alloc::collections::BTreeMap;
use alloc::string::{String, ToString};
use alloc::vec::Vec;
/// Statistics for a single `(table, column)` pair, as stored in [`Statistics`].
#[derive(Debug, Clone, PartialEq)]
pub struct ColumnStats {
    /// Presumably the fraction of NULL values in the column (tests use `0.0`
    /// and `0.1`). Nothing in this file range-checks it — TODO confirm bounds.
    pub null_frac: f32,
    /// Number of distinct values. `estimate_n_distinct` computes such a count
    /// from a sorted sample; whether callers always use it is unverified here.
    pub n_distinct: u64,
    /// Histogram boundary values. `build_histogram` produces such a list from a
    /// sorted sample; presumably that is where this comes from — verify.
    pub histogram_bounds: Vec<String>,
}
/// In-memory statistics catalogue.
///
/// It holds three things:
/// - per-column [`ColumnStats`] keyed by `(table, column)`;
/// - per-table modified-row counters;
/// - a version number.
///
/// Both maps are `BTreeMap`s, so iteration (and therefore the serialized
/// payload) is ordered by key and does not depend on insertion order.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct Statistics {
    /// Column statistics keyed by `(table name, column name)`.
    inner: BTreeMap<(String, String), ColumnStats>,
    /// Modified-row counter per table.
    ///
    /// `record_modifications` increases it and `reset_modified` zeroes it.
    modified_since: BTreeMap<String, u64>,
    /// Version number.
    ///
    /// In this file it is only changed by `bump_version`. `serialize` does not
    /// write it, and `deserialize` always yields 0.
    version: u64,
}
/// Error produced while decoding a statistics payload (see `Statistics::deserialize`).
#[derive(Debug, PartialEq, Eq)]
pub enum StatisticsError {
    /// The payload is malformed.
    ///
    /// The string is a human-readable description of what was wrong
    /// (short read, bad UTF-8, duplicate key, trailing bytes, ...).
    Corrupt(String),
}

// A public error type should be printable for users and logs, not only via
// `Debug`. `core::fmt` keeps this usable from `no_std` code.
impl core::fmt::Display for StatisticsError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::Corrupt(msg) => write!(f, "corrupt statistics payload: {msg}"),
        }
    }
}
impl Statistics {
    /// Creates an empty catalogue: no column rows, no modification counters,
    /// version 0.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the number of `(table, column)` rows stored.
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// Returns `true` when no `(table, column)` rows are stored.
    ///
    /// Modification counters are not taken into account.
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Returns the statistics recorded for `table.column`, if any.
    pub fn get(&self, table: &str, column: &str) -> Option<&ColumnStats> {
        // The key type is an owned `(String, String)` tuple. It cannot be
        // borrowed as `(&str, &str)`, so an owned key is built for the lookup.
        self.inner.get(&(table.to_string(), column.to_string()))
    }

    /// Iterates over all rows in ascending `(table, column)` order.
    pub fn iter(&self) -> impl Iterator<Item = (&(String, String), &ColumnStats)> {
        self.inner.iter()
    }

    /// Inserts or replaces the statistics for `table.column`.
    pub fn set(&mut self, table: String, column: String, stats: ColumnStats) {
        self.inner.insert((table, column), stats);
    }

    /// Removes every column row belonging to `table`.
    ///
    /// The table's modification counter is left untouched.
    pub fn clear_table(&mut self, table: &str) {
        self.inner.retain(|(t, _), _| t != table);
    }

    /// Sets the modification counter for `table` to zero, creating it if absent.
    pub fn reset_modified(&mut self, table: &str) {
        self.modified_since.insert(table.to_string(), 0);
    }

    /// Adds `n` to the modification counter for `table`, saturating at `u64::MAX`.
    pub fn record_modifications(&mut self, table: &str, n: u64) {
        let entry = self.modified_since.entry(table.to_string()).or_default();
        *entry = entry.saturating_add(n);
    }

    /// Returns the modification counter for `table`, or 0 if none was recorded.
    pub fn modified_since_last_analyze(&self, table: &str) -> u64 {
        self.modified_since.get(table).copied().unwrap_or(0)
    }

    /// Returns the current version number.
    pub fn version(&self) -> u64 {
        self.version
    }

    /// Increments the version number, saturating at `u64::MAX`.
    pub fn bump_version(&mut self) {
        self.version = self.version.saturating_add(1);
    }

    /// Encodes the catalogue into a little-endian byte payload.
    ///
    /// Layout (every `str` is a `u16` byte length followed by UTF-8 bytes):
    ///
    /// ```text
    /// u16 row_count
    /// row_count x { str table, str column, f32 null_frac, u64 n_distinct,
    ///               u16 bound_count, bound_count x str bound }
    /// u16 counter_count
    /// counter_count x { str table, u64 modified_rows }
    /// ```
    ///
    /// Rows and counters are written in key order, so the output does not
    /// depend on insertion order. The version number is not written.
    ///
    /// # Panics
    ///
    /// Panics in either of these cases:
    /// - there are more than 65,535 rows, bounds in one row, or counters;
    /// - a table, column or bound string exceeds 65,535 bytes.
    pub fn serialize(&self) -> Vec<u8> {
        let mut out = Vec::with_capacity(2 + self.inner.len() * 32);
        let n = u16::try_from(self.inner.len()).expect("≤ 65,535 column-stats rows");
        out.extend_from_slice(&n.to_le_bytes());
        for ((table, col), stats) in &self.inner {
            write_str(&mut out, table);
            write_str(&mut out, col);
            out.extend_from_slice(&stats.null_frac.to_le_bytes());
            out.extend_from_slice(&stats.n_distinct.to_le_bytes());
            let nb =
                u16::try_from(stats.histogram_bounds.len()).expect("≤ 65,535 histogram bounds");
            out.extend_from_slice(&nb.to_le_bytes());
            for b in &stats.histogram_bounds {
                write_str(&mut out, b);
            }
        }
        let m = u16::try_from(self.modified_since.len()).expect("≤ 65,535 modified-row counters");
        out.extend_from_slice(&m.to_le_bytes());
        for (table, count) in &self.modified_since {
            write_str(&mut out, table);
            out.extend_from_slice(&count.to_le_bytes());
        }
        out
    }

    /// Decodes a payload produced by [`Statistics::serialize`].
    ///
    /// The returned catalogue always has version 0.
    ///
    /// # Errors
    ///
    /// Returns [`StatisticsError::Corrupt`] if any of these hold:
    /// - the payload is truncated;
    /// - a string is not valid UTF-8;
    /// - a `(table, column)` key or a counter's table appears twice;
    /// - bytes remain after the last counter.
    pub fn deserialize(buf: &[u8]) -> Result<Self, StatisticsError> {
        let mut p = 0usize;

        let n = read_u16(buf, &mut p)? as usize;
        let mut inner: BTreeMap<(String, String), ColumnStats> = BTreeMap::new();
        for _ in 0..n {
            let (table, col, stats) = Self::read_column_row(buf, &mut p)?;
            let key = (table, col);
            if inner.contains_key(&key) {
                let (table, col) = key;
                return Err(StatisticsError::Corrupt(alloc::format!(
                    "duplicate spg_statistic key ({table:?}, {col:?})"
                )));
            }
            inner.insert(key, stats);
        }

        let m = read_u16(buf, &mut p)? as usize;
        let mut modified_since: BTreeMap<String, u64> = BTreeMap::new();
        for _ in 0..m {
            let table = read_str(buf, &mut p)?;
            let count = read_u64(buf, &mut p)?;
            // `serialize` writes each table at most once because it reads from
            // a map. A repeat therefore means corruption. Silently keeping the
            // last value would hide it, so reject it like a duplicate column key.
            if modified_since.contains_key(&table) {
                return Err(StatisticsError::Corrupt(alloc::format!(
                    "duplicate modified-row counter for table {table:?}"
                )));
            }
            modified_since.insert(table, count);
        }

        if p != buf.len() {
            return Err(StatisticsError::Corrupt(alloc::format!(
                "trailing bytes in statistics payload: read {p}, len {}",
                buf.len()
            )));
        }
        Ok(Self {
            inner,
            modified_since,
            version: 0,
        })
    }

    /// Reads one column-statistics row starting at `*p` and advances `*p` past it.
    ///
    /// Returns the table name, column name and decoded stats.
    fn read_column_row(
        buf: &[u8],
        p: &mut usize,
    ) -> Result<(String, String, ColumnStats), StatisticsError> {
        let table = read_str(buf, p)?;
        let col = read_str(buf, p)?;
        let null_frac_bytes = read_bytes(buf, p, 4)?;
        let null_frac = f32::from_le_bytes(
            null_frac_bytes
                .try_into()
                .map_err(|_| StatisticsError::Corrupt("null_frac slice".to_string()))?,
        );
        let n_distinct = read_u64(buf, p)?;
        let nb = read_u16(buf, p)? as usize;
        // `nb` comes straight from the (untrusted) payload. Every bound needs at
        // least its 2-byte length prefix, so never preallocate more slots than
        // the remaining bytes could possibly fill.
        let max_possible = buf.len().saturating_sub(*p) / 2;
        let mut bounds = Vec::with_capacity(nb.min(max_possible));
        for _ in 0..nb {
            bounds.push(read_str(buf, p)?);
        }
        Ok((
            table,
            col,
            ColumnStats {
                null_frac,
                n_distinct,
                histogram_bounds: bounds,
            },
        ))
    }
}
/// Number of buckets produced by [`build_histogram`].
///
/// A full histogram therefore carries `NUM_BUCKETS + 1` boundary values.
pub const NUM_BUCKETS: usize = 100;

/// Picks histogram boundary values from an already-sorted sample.
///
/// - Samples with at most `NUM_BUCKETS + 1` values (including the empty sample)
///   are returned verbatim.
/// - Larger samples yield exactly `NUM_BUCKETS + 1` values, taken at evenly
///   spaced positions. The first and last elements are always included.
pub fn build_histogram(sorted_values: &[String]) -> Vec<String> {
    let len = sorted_values.len();
    if len <= NUM_BUCKETS + 1 {
        // Nothing to thin out; this also covers the empty sample.
        return sorted_values.to_vec();
    }

    let last_index = (len - 1) as u64;
    let buckets = NUM_BUCKETS as u64;
    (0..=buckets)
        .map(|step| {
            // Integer arithmetic in u64 keeps `step * last_index` from overflowing.
            let pos = (step * last_index / buckets) as usize;
            sorted_values[pos].clone()
        })
        .collect()
}
/// Counts the distinct values in a sorted sample.
///
/// Equal values are adjacent in sorted input. Every change between
/// neighbouring elements therefore starts a new distinct value, and the first
/// element always starts one. An empty sample has 0 distinct values.
pub fn estimate_n_distinct(sorted_values: &[String]) -> u64 {
    match sorted_values.len() {
        0 => 0,
        _ => {
            let boundaries = sorted_values
                .windows(2)
                .filter(|pair| pair[0] != pair[1])
                .count();
            boundaries as u64 + 1
        }
    }
}
/// Appends `s` to `out` as a little-endian `u16` byte-length prefix followed by
/// its UTF-8 bytes.
///
/// # Panics
///
/// Panics if `s` is longer than 65,535 bytes.
fn write_str(out: &mut Vec<u8>, s: &str) {
    let bytes = s.as_bytes();
    let len_prefix = u16::try_from(bytes.len())
        .expect("table / column / bound names ≤ 65,535 bytes")
        .to_le_bytes();
    out.extend_from_slice(&len_prefix);
    out.extend_from_slice(bytes);
}
/// Borrows the next `n` bytes of `buf` starting at cursor `*p` and advances the
/// cursor past them.
///
/// # Errors
///
/// Returns `StatisticsError::Corrupt` if fewer than `n` bytes remain. The
/// cursor is not moved in that case.
fn read_bytes<'a>(buf: &'a [u8], p: &mut usize, n: usize) -> Result<&'a [u8], StatisticsError> {
    let start = *p;
    let end = start + n;
    match buf.get(start..end) {
        Some(chunk) => {
            *p = end;
            Ok(chunk)
        }
        None => Err(StatisticsError::Corrupt(alloc::format!(
            "short read ({n} bytes)"
        ))),
    }
}
/// Reads a little-endian `u16` at cursor `*p` and advances the cursor by two bytes.
fn read_u16(buf: &[u8], p: &mut usize) -> Result<u16, StatisticsError> {
    // `read_bytes` hands back exactly 2 bytes, so this conversion cannot fail
    // in practice; it is still mapped to an error rather than unwrapped.
    let raw: [u8; 2] = read_bytes(buf, p, 2)?
        .try_into()
        .map_err(|_| StatisticsError::Corrupt("u16 slice".to_string()))?;
    Ok(u16::from_le_bytes(raw))
}
/// Reads a little-endian `u64` at cursor `*p` and advances the cursor by eight bytes.
fn read_u64(buf: &[u8], p: &mut usize) -> Result<u64, StatisticsError> {
    // `read_bytes` hands back exactly 8 bytes, so this conversion cannot fail
    // in practice; it is still mapped to an error rather than unwrapped.
    let raw: [u8; 8] = read_bytes(buf, p, 8)?
        .try_into()
        .map_err(|_| StatisticsError::Corrupt("u64 slice".to_string()))?;
    Ok(u64::from_le_bytes(raw))
}
/// Reads a `u16`-length-prefixed UTF-8 string at cursor `*p` and advances the
/// cursor past it.
///
/// # Errors
///
/// Returns `StatisticsError::Corrupt` on a short read or invalid UTF-8. In the
/// UTF-8 case the cursor has already moved past the string bytes.
fn read_str(buf: &[u8], p: &mut usize) -> Result<String, StatisticsError> {
    let len = usize::from(read_u16(buf, p)?);
    let raw = read_bytes(buf, p, len)?;
    match core::str::from_utf8(raw) {
        Ok(text) => Ok(text.to_string()),
        Err(e) => Err(StatisticsError::Corrupt(alloc::format!(
            "non-UTF-8 str: {e}"
        ))),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ColumnStats` from borrowed bound strings.
    fn mk_cs(null_frac: f32, n_distinct: u64, bounds: &[&str]) -> ColumnStats {
        ColumnStats {
            null_frac,
            n_distinct,
            histogram_bounds: bounds.iter().map(|s| s.to_string()).collect(),
        }
    }

    #[test]
    fn empty_roundtrips() {
        let s = Statistics::new();
        let bytes = s.serialize();
        let s2 = Statistics::deserialize(&bytes).unwrap();
        assert_eq!(s, s2);
    }

    #[test]
    fn single_column_roundtrips() {
        let mut s = Statistics::new();
        s.set(
            "users".into(),
            "id".into(),
            mk_cs(0.0, 1000, &["1", "500", "1000"]),
        );
        let s2 = Statistics::deserialize(&s.serialize()).unwrap();
        assert_eq!(s, s2);
        let got = s2.get("users", "id").unwrap();
        assert_eq!(got.n_distinct, 1000);
        assert_eq!(got.histogram_bounds.len(), 3);
    }

    #[test]
    fn multi_column_roundtrips_with_modified_counter() {
        let mut s = Statistics::new();
        s.set(
            "users".into(),
            "id".into(),
            mk_cs(0.0, 100, &["1", "50", "100"]),
        );
        s.set(
            "users".into(),
            "name".into(),
            mk_cs(0.1, 99, &["alice", "bob", "zoe"]),
        );
        s.record_modifications("users", 17);
        let s2 = Statistics::deserialize(&s.serialize()).unwrap();
        assert_eq!(s, s2);
        assert_eq!(s2.modified_since_last_analyze("users"), 17);
    }

    #[test]
    fn histogram_bounds_count_is_101_for_100_buckets() {
        let vals: Vec<String> = (0..1000).map(|i| alloc::format!("{i:04}")).collect();
        let bounds = build_histogram(&vals);
        assert_eq!(bounds.len(), 101);
        assert_eq!(bounds.first().unwrap(), "0000");
        assert_eq!(bounds.last().unwrap(), "0999");
    }

    #[test]
    fn deterministic_serialise_independent_of_insert_order() {
        let mut s1 = Statistics::new();
        s1.set("z".into(), "c1".into(), mk_cs(0.0, 1, &["x"]));
        s1.set("a".into(), "c2".into(), mk_cs(0.0, 1, &["y"]));
        let mut s2 = Statistics::new();
        s2.set("a".into(), "c2".into(), mk_cs(0.0, 1, &["y"]));
        s2.set("z".into(), "c1".into(), mk_cs(0.0, 1, &["x"]));
        assert_eq!(s1.serialize(), s2.serialize());
    }

    #[test]
    fn n_distinct_estimator_within_5pct_on_uniform_corpus() {
        let mut vals: Vec<String> = Vec::with_capacity(10000);
        for i in 0..10000 {
            vals.push(alloc::format!("v{}", i % 100));
        }
        vals.sort();
        let est = estimate_n_distinct(&vals);
        assert_eq!(est, 100);
    }

    #[test]
    fn clear_table_drops_only_target_rows() {
        let mut s = Statistics::new();
        s.set("a".into(), "c1".into(), mk_cs(0.0, 1, &["x"]));
        s.set("a".into(), "c2".into(), mk_cs(0.0, 1, &["y"]));
        s.set("b".into(), "c1".into(), mk_cs(0.0, 1, &["z"]));
        s.clear_table("a");
        assert_eq!(s.len(), 1);
        assert!(s.get("a", "c1").is_none());
        assert!(s.get("b", "c1").is_some());
    }

    #[test]
    fn corrupt_short_read_errors() {
        let buf = 1u16.to_le_bytes();
        let err = Statistics::deserialize(&buf).unwrap_err();
        assert!(matches!(err, StatisticsError::Corrupt(_)));
    }

    #[test]
    fn corrupt_trailing_bytes_errors() {
        // A valid empty payload plus one stray byte must be rejected.
        let mut bytes = Statistics::new().serialize();
        bytes.push(0);
        let err = Statistics::deserialize(&bytes).unwrap_err();
        assert!(matches!(err, StatisticsError::Corrupt(_)));
    }

    #[test]
    fn corrupt_duplicate_column_key_errors() {
        // Two rows with the same `(table, column)` key; `serialize` can never
        // produce this because it reads from a map.
        let mut buf = Vec::new();
        buf.extend_from_slice(&2u16.to_le_bytes());
        for _ in 0..2 {
            write_str(&mut buf, "t");
            write_str(&mut buf, "c");
            buf.extend_from_slice(&0.0f32.to_le_bytes());
            buf.extend_from_slice(&1u64.to_le_bytes());
            buf.extend_from_slice(&0u16.to_le_bytes());
        }
        buf.extend_from_slice(&0u16.to_le_bytes());
        let err = Statistics::deserialize(&buf).unwrap_err();
        assert!(matches!(err, StatisticsError::Corrupt(_)));
    }

    #[test]
    fn corrupt_non_utf8_name_errors() {
        // One row whose table name is the single invalid UTF-8 byte 0xFF.
        let mut buf = Vec::new();
        buf.extend_from_slice(&1u16.to_le_bytes());
        buf.extend_from_slice(&1u16.to_le_bytes());
        buf.push(0xFF);
        let err = Statistics::deserialize(&buf).unwrap_err();
        assert!(matches!(err, StatisticsError::Corrupt(_)));
    }

    #[test]
    fn version_is_not_persisted() {
        let mut s = Statistics::new();
        s.bump_version();
        assert_eq!(s.version(), 1);
        let s2 = Statistics::deserialize(&s.serialize()).unwrap();
        assert_eq!(s2.version(), 0);
    }

    #[test]
    fn build_histogram_passthrough_when_sample_is_small() {
        let vals: Vec<String> = (0..5).map(|i| alloc::format!("v{i}")).collect();
        let bounds = build_histogram(&vals);
        assert_eq!(bounds.len(), 5);
        assert_eq!(bounds[0], "v0");
        assert_eq!(bounds[4], "v4");
    }

    #[test]
    fn build_histogram_boundary_sizes() {
        // Exactly NUM_BUCKETS + 1 values: still returned verbatim.
        let exact: Vec<String> = (0..=NUM_BUCKETS).map(|i| alloc::format!("{i:03}")).collect();
        assert_eq!(build_histogram(&exact), exact);

        // One more value: thinned to NUM_BUCKETS + 1 bounds, keeping both ends.
        let one_more: Vec<String> =
            (0..NUM_BUCKETS + 2).map(|i| alloc::format!("{i:03}")).collect();
        let bounds = build_histogram(&one_more);
        assert_eq!(bounds.len(), NUM_BUCKETS + 1);
        assert_eq!(bounds.first(), one_more.first());
        assert_eq!(bounds.last(), one_more.last());
    }
}