use std::sync::Arc;
use redb::{Database, ReadableTable, TableDefinition, WriteTransaction};
use serde::{Deserialize, Serialize};
/// redb table mapping `"<tenant>:<collection>:<field>"` keys to
/// MessagePack-encoded [`ColumnStats`] values.
const COLUMN_STATS: TableDefinition<&str, &[u8]> = TableDefinition::new("column_stats");
/// Running statistics for a single document field ("column"): row/null
/// counts, lexicographic min/max, and a HyperLogLog sketch for estimating
/// the number of distinct values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ColumnStats {
    /// Total values observed, including nulls.
    pub row_count: u64,
    /// Observed values that were neither absent nor JSON `null`.
    pub non_null_count: u64,
    /// Observed values that were absent or JSON `null`.
    pub null_count: u64,
    /// Cached HLL estimate of distinct non-null values (refreshed on observe).
    pub distinct_count: u64,
    /// Smallest value seen so far, compared lexicographically on the string form.
    pub min_value: Option<String>,
    /// Largest value seen so far, compared lexicographically on the string form.
    pub max_value: Option<String>,
    /// HyperLogLog registers; length is `DEFAULT_HLL_M`, each holds a max rank.
    pub hll_registers: Vec<u8>,
}
/// Number of HyperLogLog registers (m). Must remain a power of two: the
/// register index is computed as `hash & (m - 1)`.
pub(crate) const DEFAULT_HLL_M: usize = 256;
/// HLL precision p = log2(m): number of low hash bits used to select the
/// register (2^8 = 256, matching `DEFAULT_HLL_M`).
pub(crate) const DEFAULT_HLL_P: u32 = 8;
impl ColumnStats {
    /// Creates an empty accumulator: all counters zeroed and the HLL
    /// registers cleared.
    pub fn new() -> Self {
        Self {
            row_count: 0,
            non_null_count: 0,
            null_count: 0,
            distinct_count: 0,
            min_value: None,
            max_value: None,
            hll_registers: vec![0u8; DEFAULT_HLL_M],
        }
    }

    /// Folds one observed value into the statistics.
    ///
    /// `None` and JSON `null` both count as nulls. Any other value updates
    /// the non-null count, the min/max, and the HLL sketch. NOTE: min/max
    /// compare the *string* rendering lexicographically, so numeric fields
    /// get string ordering (e.g. "10" < "9").
    pub fn observe(&mut self, value: Option<&serde_json::Value>) {
        self.row_count += 1;
        match value {
            None | Some(serde_json::Value::Null) => {
                self.null_count += 1;
            }
            Some(val) => {
                self.non_null_count += 1;
                // Canonical string form: strings unquoted, everything else
                // via its JSON rendering.
                let val_str = match val {
                    serde_json::Value::String(s) => s.clone(),
                    other => other.to_string(),
                };
                match &self.min_value {
                    None => self.min_value = Some(val_str.clone()),
                    Some(min) if val_str < *min => self.min_value = Some(val_str.clone()),
                    _ => {}
                }
                match &self.max_value {
                    None => self.max_value = Some(val_str.clone()),
                    Some(max) if val_str > *max => self.max_value = Some(val_str.clone()),
                    _ => {}
                }
                // HLL update: the low P bits of the hash select a register;
                // the remaining bits determine the rank — the 1-based
                // position of the lowest set bit.
                let hash = Self::hash_value(&val_str);
                let register_idx = (hash as usize) & (DEFAULT_HLL_M - 1);
                let remaining = hash >> DEFAULT_HLL_P;
                let rank = if remaining == 0 {
                    // All (64 - P) usable bits are zero: rho(w) saturates at
                    // one past the window width (64 - P + 1), not 64 - P.
                    (64 - DEFAULT_HLL_P) as u8 + 1
                } else {
                    remaining.trailing_zeros() as u8 + 1
                };
                if rank > self.hll_registers[register_idx] {
                    self.hll_registers[register_idx] = rank;
                }
                // Keep the cached estimate fresh so serialized stats carry it.
                self.distinct_count = self.hll_estimate();
            }
        }
    }

    /// Estimates the number of distinct values from the HLL registers using
    /// the bias-corrected harmonic mean, switching to linear counting for
    /// small cardinalities (raw estimate <= 2.5 * m) while any register is
    /// still empty.
    fn hll_estimate(&self) -> u64 {
        let m = self.hll_registers.len() as f64;
        // Bias-correction constant alpha_m (formula valid for m >= 128).
        let alpha = 0.7213 / (1.0 + 1.079 / m);
        let raw: f64 = alpha * m * m
            / self
                .hll_registers
                .iter()
                .map(|&r| 2.0_f64.powi(-(r as i32)))
                .sum::<f64>();
        if raw <= 2.5 * m {
            let zeros = self.hll_registers.iter().filter(|&&r| r == 0).count() as f64;
            if zeros > 0.0 {
                // Linear counting: m * ln(m / V), V = number of empty registers.
                (m * (m / zeros).ln()) as u64
            } else {
                raw as u64
            }
        } else {
            raw as u64
        }
    }

    /// FNV-1a 64-bit hash of `s`; feeds the HLL sketch.
    fn hash_value(s: &str) -> u64 {
        let mut hash: u64 = 0xcbf29ce484222325; // FNV-1a 64-bit offset basis
        for byte in s.as_bytes() {
            hash ^= *byte as u64;
            hash = hash.wrapping_mul(0x100000001b3); // FNV-1a 64-bit prime
        }
        hash
    }

    /// Estimated selectivity of an equality predicate on this column:
    /// 1 / NDV, or 1.0 before anything has been observed.
    pub fn eq_selectivity(&self) -> f64 {
        if self.distinct_count == 0 {
            1.0
        } else {
            1.0 / self.distinct_count as f64
        }
    }

    /// Fixed heuristic selectivity for range predicates (the classic 1/3
    /// rule-of-thumb used by many optimizers when no histogram exists).
    pub fn range_selectivity(&self) -> f64 {
        0.33
    }
}
impl Default for ColumnStats {
fn default() -> Self {
Self::new()
}
}
/// Persistent per-column statistics store backed by a shared redb database.
pub struct StatsStore {
    // Shared handle to the key-value database; stats live in `column_stats`.
    db: Arc<Database>,
}
impl StatsStore {
pub fn open(db: Arc<Database>) -> crate::Result<Self> {
let write_txn = db.begin_write().map_err(|e| crate::Error::Storage {
engine: "stats".into(),
detail: format!("open write txn: {e}"),
})?;
{
let _ = write_txn.open_table(COLUMN_STATS);
}
write_txn.commit().map_err(|e| crate::Error::Storage {
engine: "stats".into(),
detail: format!("commit: {e}"),
})?;
Ok(Self { db })
}
pub fn get(
&self,
tenant_id: u32,
collection: &str,
field: &str,
) -> crate::Result<Option<ColumnStats>> {
let key = format!("{tenant_id}:{collection}:{field}");
let read_txn = self.db.begin_read().map_err(|e| crate::Error::Storage {
engine: "stats".into(),
detail: format!("read txn: {e}"),
})?;
let table = read_txn
.open_table(COLUMN_STATS)
.map_err(|e| crate::Error::Storage {
engine: "stats".into(),
detail: format!("open table: {e}"),
})?;
match table.get(key.as_str()) {
Ok(Some(guard)) => {
let bytes = guard.value();
let stats: ColumnStats =
rmp_serde::from_slice(bytes).map_err(|e| crate::Error::Storage {
engine: "stats".into(),
detail: format!("deserialize: {e}"),
})?;
Ok(Some(stats))
}
Ok(None) => Ok(None),
Err(e) => Err(crate::Error::Storage {
engine: "stats".into(),
detail: format!("get: {e}"),
}),
}
}
pub fn put(
&self,
tenant_id: u32,
collection: &str,
field: &str,
stats: &ColumnStats,
) -> crate::Result<()> {
let key = format!("{tenant_id}:{collection}:{field}");
let bytes = rmp_serde::to_vec_named(stats).map_err(|e| crate::Error::Storage {
engine: "stats".into(),
detail: format!("serialize: {e}"),
})?;
let write_txn = self.db.begin_write().map_err(|e| crate::Error::Storage {
engine: "stats".into(),
detail: format!("write txn: {e}"),
})?;
{
let mut table =
write_txn
.open_table(COLUMN_STATS)
.map_err(|e| crate::Error::Storage {
engine: "stats".into(),
detail: format!("open table: {e}"),
})?;
table
.insert(key.as_str(), bytes.as_slice())
.map_err(|e| crate::Error::Storage {
engine: "stats".into(),
detail: format!("insert: {e}"),
})?;
}
write_txn.commit().map_err(|e| crate::Error::Storage {
engine: "stats".into(),
detail: format!("commit: {e}"),
})?;
Ok(())
}
pub fn observe_document(
&self,
tenant_id: u32,
collection: &str,
doc: &serde_json::Value,
) -> crate::Result<()> {
if let Some(obj) = doc.as_object() {
for (field, value) in obj {
let mut stats = self.get(tenant_id, collection, field)?.unwrap_or_default();
stats.observe(Some(value));
self.put(tenant_id, collection, field, &stats)?;
}
}
Ok(())
}
pub fn observe_document_in_txn(
&self,
txn: &WriteTransaction,
tenant_id: u32,
collection: &str,
doc: &serde_json::Value,
) -> crate::Result<()> {
let Some(obj) = doc.as_object() else {
return Ok(());
};
if obj.is_empty() {
return Ok(());
}
let mut table = txn
.open_table(COLUMN_STATS)
.map_err(|e| crate::Error::Storage {
engine: "stats".into(),
detail: format!("open table: {e}"),
})?;
for (field, value) in obj {
let key = format!("{tenant_id}:{collection}:{field}");
let mut stats: ColumnStats = table
.get(key.as_str())
.ok()
.flatten()
.and_then(|guard| rmp_serde::from_slice(guard.value()).ok())
.unwrap_or_default();
stats.observe(Some(value));
let bytes = rmp_serde::to_vec_named(&stats).map_err(|e| crate::Error::Storage {
engine: "stats".into(),
detail: format!("serialize: {e}"),
})?;
table
.insert(key.as_str(), bytes.as_slice())
.map_err(|e| crate::Error::Storage {
engine: "stats".into(),
detail: format!("insert: {e}"),
})?;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The HLL estimate for 1000 distinct strings should land inside a
    /// generous error band around the true cardinality.
    #[test]
    fn hll_cardinality_estimate() {
        let mut stats = ColumnStats::new();
        for i in 0..1000 {
            let value = serde_json::Value::String(format!("value_{i}"));
            stats.observe(Some(&value));
        }
        let estimate = stats.distinct_count;
        assert!(estimate > 700, "too low: {estimate}");
        assert!(estimate < 1400, "too high: {estimate}");
    }

    /// Min/max use lexicographic string order.
    #[test]
    fn min_max_tracking() {
        let mut stats = ColumnStats::new();
        for name in ["charlie", "alice", "bob"] {
            let value = serde_json::Value::String(name.to_string());
            stats.observe(Some(&value));
        }
        assert_eq!(stats.min_value.as_deref(), Some("alice"));
        assert_eq!(stats.max_value.as_deref(), Some("charlie"));
        assert_eq!(stats.non_null_count, 3);
        assert_eq!(stats.null_count, 0);
    }

    /// Both an absent value and JSON `null` count as nulls.
    #[test]
    fn null_tracking() {
        let mut stats = ColumnStats::new();
        stats.observe(None);
        stats.observe(Some(&serde_json::Value::Null));
        let value = serde_json::Value::String("val".into());
        stats.observe(Some(&value));
        assert_eq!(stats.null_count, 2);
        assert_eq!(stats.non_null_count, 1);
        assert_eq!(stats.row_count, 3);
    }

    /// Equality selectivity should be roughly 1 / NDV.
    #[test]
    fn eq_selectivity() {
        let mut stats = ColumnStats::new();
        for i in 0..100 {
            let value = serde_json::Value::String(format!("v{i}"));
            stats.observe(Some(&value));
        }
        let sel = stats.eq_selectivity();
        assert!(sel > 0.005 && sel < 0.02, "selectivity: {sel}");
    }

    /// Stats survive a put/get round trip through the redb-backed store.
    #[test]
    fn stats_store_roundtrip() {
        let dir = tempfile::tempdir().unwrap();
        let db = Database::create(dir.path().join("stats.redb")).unwrap();
        let store = StatsStore::open(Arc::new(db)).unwrap();

        let mut stats = ColumnStats::new();
        let value = serde_json::Value::String("hello".into());
        stats.observe(Some(&value));
        store.put(1, "users", "name", &stats).unwrap();

        let loaded = store.get(1, "users", "name").unwrap().unwrap();
        assert_eq!(loaded.row_count, 1);
        assert_eq!(loaded.non_null_count, 1);
        assert_eq!(loaded.min_value, Some("hello".to_string()));
    }
}