use serde::{Deserialize, Serialize};
use crate::sketches::Sketch;
use crate::{Error, Result};
/// Two-dimensional equal-width histogram over `(a, b)` pairs.
///
/// Counts are kept per `(A bucket, B bucket)` cell, so the joint distribution of the two
/// columns is retained rather than assuming they are independent.
///
/// Cells are stored row-major: the count for A bucket `i` and B bucket `j` lives at
/// `cells[i * col_b_bins + j]`.
///
/// NOTE: the struct is encoded with `bincode` via the serde derives, which writes fields in
/// declaration order — reordering or retyping fields changes the serialized format.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct CorrelatedHistogram2D {
    /// Number of buckets along the A axis (always > 0 for a valid histogram).
    col_a_bins: u32,
    /// Number of buckets along the B axis (always > 0 for a valid histogram).
    col_b_bins: u32,
    /// Lower bound of the A range used for bucketing; 0.0 until set by `from_pairs`/`merge`.
    a_min: f64,
    /// Upper bound of the A range used for bucketing; 0.0 until set by `from_pairs`/`merge`.
    a_max: f64,
    /// Lower bound of the B range used for bucketing; 0.0 until set by `from_pairs`/`merge`.
    b_min: f64,
    /// Upper bound of the B range used for bucketing; 0.0 until set by `from_pairs`/`merge`.
    b_max: f64,
    /// Row-major cell counts, `col_a_bins * col_b_bins` entries long.
    cells: Vec<u64>,
    /// Number of pairs counted across all cells (saturating at `u64::MAX`).
    total: u64,
}
impl CorrelatedHistogram2D {
pub fn try_new(col_a_bins: usize, col_b_bins: usize) -> Result<Self> {
if col_a_bins == 0 || col_b_bins == 0 {
return Err(Error::InvalidSketch(
"CorrelatedHistogram2D bin counts must be > 0".into(),
));
}
if col_a_bins > u32::MAX as usize || col_b_bins > u32::MAX as usize {
return Err(Error::InvalidSketch(
"CorrelatedHistogram2D bin counts must fit in u32".into(),
));
}
let size = col_a_bins
.checked_mul(col_b_bins)
.ok_or_else(|| Error::InvalidSketch("CorrelatedHistogram2D size overflow".into()))?;
Ok(Self {
col_a_bins: col_a_bins as u32,
col_b_bins: col_b_bins as u32,
a_min: 0.0,
a_max: 0.0,
b_min: 0.0,
b_max: 0.0,
cells: vec![0u64; size],
total: 0,
})
}
pub fn new(col_a_bins: usize, col_b_bins: usize) -> Result<Self> {
Self::try_new(col_a_bins, col_b_bins)
}
fn validate(&self) -> Result<()> {
if self.col_a_bins == 0 || self.col_b_bins == 0 {
return Err(Error::InvalidSketch(
"CorrelatedHistogram2D decoded bins must both be > 0".into(),
));
}
let expected = (self.col_a_bins as usize)
.checked_mul(self.col_b_bins as usize)
.ok_or_else(|| {
Error::InvalidSketch("CorrelatedHistogram2D decoded cells overflow".into())
})?;
if self.cells.len() != expected {
return Err(Error::InvalidSketch(format!(
"CorrelatedHistogram2D cells.len() {} != col_a_bins*col_b_bins = {}",
self.cells.len(),
expected
)));
}
for (name, v) in [
("a_min", self.a_min),
("a_max", self.a_max),
("b_min", self.b_min),
("b_max", self.b_max),
] {
if v.is_nan() {
return Err(Error::InvalidSketch(format!(
"CorrelatedHistogram2D {name} is NaN"
)));
}
}
Ok(())
}
pub fn from_pairs(pairs: &[(f64, f64)], col_a_bins: usize, col_b_bins: usize) -> Result<Self> {
let mut h = Self::new(col_a_bins, col_b_bins)?;
if pairs.is_empty() {
return Ok(h);
}
let (mut amin, mut amax) = (f64::INFINITY, f64::NEG_INFINITY);
let (mut bmin, mut bmax) = (f64::INFINITY, f64::NEG_INFINITY);
for &(a, b) in pairs {
if !a.is_finite() || !b.is_finite() {
continue;
}
if a < amin {
amin = a;
}
if a > amax {
amax = a;
}
if b < bmin {
bmin = b;
}
if b > bmax {
bmax = b;
}
}
if !amin.is_finite() || !bmin.is_finite() {
return Ok(h);
}
h.a_min = amin;
h.a_max = amax;
h.b_min = bmin;
h.b_max = bmax;
for &(a, b) in pairs {
if !a.is_finite() || !b.is_finite() {
continue;
}
let i = Self::bucket_index(a, amin, amax, h.col_a_bins);
let j = Self::bucket_index(b, bmin, bmax, h.col_b_bins);
let pos = i * (h.col_b_bins as usize) + j;
h.cells[pos] = h.cells[pos].saturating_add(1);
h.total = h.total.saturating_add(1);
}
Ok(h)
}
fn bucket_index(v: f64, lo: f64, hi: f64, nbins: u32) -> usize {
let n = nbins as usize;
if n == 0 {
return 0;
}
if hi <= lo {
return 0;
}
if v <= lo {
return 0;
}
if v >= hi {
return n - 1;
}
let frac = (v - lo) / (hi - lo);
let idx = (frac * (n as f64)).floor() as usize;
idx.min(n - 1)
}
fn range_bucket_span(
q_lo: f64,
q_hi: f64,
col_min: f64,
col_max: f64,
nbins: u32,
) -> Option<(usize, usize)> {
if q_lo > q_hi {
return None;
}
if nbins == 0 {
return None;
}
if col_max < col_min {
return None;
}
if col_max <= col_min {
if q_lo <= col_min && q_hi >= col_min {
return Some((0, (nbins as usize) - 1));
}
return None;
}
if q_hi < col_min || q_lo > col_max {
return None;
}
let lo_idx = Self::bucket_index(q_lo, col_min, col_max, nbins);
let hi_idx = Self::bucket_index(q_hi, col_min, col_max, nbins);
Some((lo_idx, hi_idx))
}
pub fn estimate_range(&self, a_lo: f64, a_hi: f64, b_lo: f64, b_hi: f64) -> u64 {
if self.total == 0 {
return 0;
}
let a_span =
match Self::range_bucket_span(a_lo, a_hi, self.a_min, self.a_max, self.col_a_bins) {
Some(s) => s,
None => return 0,
};
let b_span =
match Self::range_bucket_span(b_lo, b_hi, self.b_min, self.b_max, self.col_b_bins) {
Some(s) => s,
None => return 0,
};
let bw = self.col_b_bins as usize;
let mut sum: u64 = 0;
for i in a_span.0..=a_span.1 {
let row = i * bw;
for j in b_span.0..=b_span.1 {
sum = sum.saturating_add(self.cells[row + j]);
}
}
sum
}
pub fn merge(&mut self, other: &Self) -> Result<()> {
if self.col_a_bins != other.col_a_bins || self.col_b_bins != other.col_b_bins {
return Err(Error::InvalidSketch(
"CorrelatedHistogram2D bin dimension mismatch in merge".into(),
));
}
if self.total == 0 {
self.a_min = other.a_min;
self.a_max = other.a_max;
self.b_min = other.b_min;
self.b_max = other.b_max;
} else if other.total != 0
&& (self.a_min != other.a_min
|| self.a_max != other.a_max
|| self.b_min != other.b_min
|| self.b_max != other.b_max)
{
return Err(Error::InvalidSketch(
"CorrelatedHistogram2D bin layout mismatch (min/max differ) in merge".into(),
));
}
for (a, b) in self.cells.iter_mut().zip(other.cells.iter()) {
*a = a.saturating_add(*b);
}
self.total = self.total.saturating_add(other.total);
Ok(())
}
pub fn cell_counts(&self) -> &[u64] {
&self.cells
}
pub fn col_a_bins(&self) -> usize {
self.col_a_bins as usize
}
pub fn col_b_bins(&self) -> usize {
self.col_b_bins as usize
}
pub fn total(&self) -> u64 {
self.total
}
pub fn a_range(&self) -> (f64, f64) {
(self.a_min, self.a_max)
}
pub fn b_range(&self) -> (f64, f64) {
(self.b_min, self.b_max)
}
}
impl Sketch for CorrelatedHistogram2D {
    /// Kind tag for this sketch type; the `-v1` suffix presumably versions the wire layout
    /// (NOTE(review): confirm how callers use `KIND`).
    const KIND: &'static str = "samkhya.correlated2d-v1";

    /// Encodes the histogram with `bincode`, converting any encoder error into the crate
    /// error type.
    fn to_bytes(&self) -> Result<Vec<u8>> {
        Ok(bincode::serialize(self)?)
    }

    /// Decodes a histogram with `bincode` and then runs its invariant checks, so a payload
    /// that parses but is structurally inconsistent is rejected rather than returned.
    fn from_bytes(bytes: &[u8]) -> Result<Self> {
        let decoded: Self = bincode::deserialize(bytes)?;
        decoded.validate().map(|()| decoded)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Field-for-field mirror of `CorrelatedHistogram2D`'s serialized layout, used to
    /// hand-craft malformed payloads that the public constructors never produce.
    #[derive(serde::Serialize)]
    struct Wire {
        col_a_bins: u32,
        col_b_bins: u32,
        a_min: f64,
        a_max: f64,
        b_min: f64,
        b_max: f64,
        cells: Vec<u64>,
        total: u64,
    }

    impl Wire {
        /// Encodes the mirror with `bincode::serialize`, the same call `to_bytes` makes.
        fn encode(&self) -> Vec<u8> {
            bincode::serialize(self).unwrap()
        }
    }

    #[test]
    fn new_rejects_zero_bins() {
        assert!(CorrelatedHistogram2D::new(0, 16).is_err());
        assert!(CorrelatedHistogram2D::new(16, 0).is_err());
    }

    #[test]
    fn full_range_covers_total_input_count() {
        let pairs: Vec<(f64, f64)> = (0..1000).map(|i| (i as f64, (i % 50) as f64)).collect();
        let h = CorrelatedHistogram2D::from_pairs(&pairs, 16, 16).unwrap();
        assert_eq!(h.total(), 1000);
        let (a_lo, a_hi) = h.a_range();
        let (b_lo, b_hi) = h.b_range();
        assert_eq!(h.estimate_range(a_lo, a_hi, b_lo, b_hi), 1000);
        assert_eq!(
            h.estimate_range(a_lo - 100.0, a_hi + 100.0, b_lo - 100.0, b_hi + 100.0),
            1000
        );
    }

    #[test]
    fn widening_either_dimension_never_decreases() {
        let pairs: Vec<(f64, f64)> = (0..500)
            .map(|i| ((i as f64) * 0.7, ((i * 3) % 73) as f64))
            .collect();
        let h = CorrelatedHistogram2D::from_pairs(&pairs, 16, 16).unwrap();
        let base = h.estimate_range(50.0, 150.0, 10.0, 40.0);
        let wider_a_lo = h.estimate_range(20.0, 150.0, 10.0, 40.0);
        let wider_a_hi = h.estimate_range(50.0, 200.0, 10.0, 40.0);
        let wider_b_lo = h.estimate_range(50.0, 150.0, 0.0, 40.0);
        let wider_b_hi = h.estimate_range(50.0, 150.0, 10.0, 80.0);
        assert!(wider_a_lo >= base);
        assert!(wider_a_hi >= base);
        assert!(wider_b_lo >= base);
        assert!(wider_b_hi >= base);
    }

    #[test]
    fn empty_range_returns_zero() {
        let pairs: Vec<(f64, f64)> = (0..100).map(|i| (i as f64, i as f64)).collect();
        let h = CorrelatedHistogram2D::from_pairs(&pairs, 8, 8).unwrap();
        assert_eq!(h.estimate_range(50.0, 10.0, 0.0, 100.0), 0);
        assert_eq!(h.estimate_range(200.0, 300.0, 0.0, 100.0), 0);
        assert_eq!(h.estimate_range(0.0, 100.0, 500.0, 600.0), 0);
    }

    #[test]
    fn empty_pairs_handled() {
        let h = CorrelatedHistogram2D::from_pairs(&[], 8, 8).unwrap();
        assert_eq!(h.total(), 0);
        assert_eq!(h.estimate_range(0.0, 100.0, 0.0, 100.0), 0);
    }

    #[test]
    fn non_finite_pairs_are_skipped() {
        let pairs = vec![
            (f64::NAN, 1.0),
            (1.0, f64::INFINITY),
            (2.0, 3.0),
            (4.0, 5.0),
        ];
        let h = CorrelatedHistogram2D::from_pairs(&pairs, 4, 4).unwrap();
        assert_eq!(h.total(), 2);
        assert_eq!(h.a_range(), (2.0, 4.0));
        assert_eq!(h.b_range(), (3.0, 5.0));
        assert_eq!(h.estimate_range(2.0, 4.0, 3.0, 5.0), 2);

        let all_bad =
            CorrelatedHistogram2D::from_pairs(&[(f64::NAN, f64::NAN), (f64::INFINITY, 0.0)], 4, 4)
                .unwrap();
        assert_eq!(all_bad.total(), 0);
        assert_eq!(all_bad.a_range(), (0.0, 0.0));
    }

    #[test]
    fn round_trip_preserves_cells() {
        let pairs: Vec<(f64, f64)> = (0..400)
            .map(|i| ((i % 20) as f64, (i / 20) as f64))
            .collect();
        let h = CorrelatedHistogram2D::from_pairs(&pairs, 8, 8).unwrap();
        let bytes = h.to_bytes().unwrap();
        let h2 = CorrelatedHistogram2D::from_bytes(&bytes).unwrap();
        assert_eq!(h.cells, h2.cells);
        assert_eq!(h.total, h2.total);
        assert_eq!(h.col_a_bins, h2.col_a_bins);
        assert_eq!(h.col_b_bins, h2.col_b_bins);
        assert_eq!(h.a_min, h2.a_min);
        assert_eq!(h.a_max, h2.a_max);
        assert_eq!(h.b_min, h2.b_min);
        assert_eq!(h.b_max, h2.b_max);
    }

    #[test]
    fn merge_combines_compatible_grids() {
        let pairs_a: Vec<(f64, f64)> = (0..100).map(|i| (i as f64, (i % 10) as f64)).collect();
        let pairs_b: Vec<(f64, f64)> = (0..100).map(|i| (i as f64, (i % 10) as f64)).collect();
        let mut h1 = CorrelatedHistogram2D::from_pairs(&pairs_a, 8, 8).unwrap();
        let h2 = CorrelatedHistogram2D::from_pairs(&pairs_b, 8, 8).unwrap();
        h1.merge(&h2).unwrap();
        assert_eq!(h1.total(), 200);
    }

    #[test]
    fn merge_into_empty_adopts_other_layout() {
        let pairs: Vec<(f64, f64)> = (0..50).map(|i| (i as f64, (i % 7) as f64)).collect();
        let src = CorrelatedHistogram2D::from_pairs(&pairs, 8, 8).unwrap();
        let mut dst = CorrelatedHistogram2D::new(8, 8).unwrap();
        dst.merge(&src).unwrap();
        assert_eq!(dst.total(), src.total());
        assert_eq!(dst.a_range(), src.a_range());
        assert_eq!(dst.b_range(), src.b_range());
        assert_eq!(dst.cell_counts(), src.cell_counts());
    }

    #[test]
    fn merge_dimension_mismatch_errors() {
        let h_a = CorrelatedHistogram2D::new(8, 8).unwrap();
        let mut h_b = CorrelatedHistogram2D::new(16, 8).unwrap();
        assert!(h_b.merge(&h_a).is_err());
    }

    #[test]
    fn merge_layout_mismatch_errors_and_leaves_self_untouched() {
        let wide: Vec<(f64, f64)> = (0..100).map(|i| (i as f64, i as f64)).collect();
        let narrow: Vec<(f64, f64)> = (0..50).map(|i| (i as f64, i as f64)).collect();
        let mut h1 = CorrelatedHistogram2D::from_pairs(&wide, 8, 8).unwrap();
        let h2 = CorrelatedHistogram2D::from_pairs(&narrow, 8, 8).unwrap();
        let before = h1.cell_counts().to_vec();
        assert!(h1.merge(&h2).is_err());
        assert_eq!(h1.total(), 100);
        assert_eq!(h1.cell_counts(), before.as_slice());
    }

    #[test]
    fn diagonal_correlation_tighter_than_independent() {
        let n: u64 = 4000;
        let pairs: Vec<(f64, f64)> = (0..n).map(|i| (i as f64, (i % 4) as f64)).collect();
        let h = CorrelatedHistogram2D::from_pairs(&pairs, 16, 4).unwrap();
        assert_eq!(h.total(), n);
        let a_lo = 0.0;
        let a_hi = (n as f64) / 4.0 - 1.0;
        let b_lo = 0.0;
        let b_hi = 0.0;
        let est = h.estimate_range(a_lo, a_hi, b_lo, b_hi);
        let (full_a_lo, full_a_hi) = h.a_range();
        let (full_b_lo, full_b_hi) = h.b_range();
        let marg_a = h.estimate_range(a_lo, a_hi, full_b_lo, full_b_hi) as f64;
        let marg_b = h.estimate_range(full_a_lo, full_a_hi, b_lo, b_hi) as f64;
        let independent = marg_a * marg_b / (n as f64);
        assert!(
            (est as f64) <= independent * 1.5 + 1.0,
            "2D est {est} exceeded independent estimate {independent}"
        );
        assert!(
            (est as f64) < marg_a,
            "joint estimate {est} should be strictly less than marginal-A {marg_a}"
        );
    }

    #[test]
    fn try_new_rejects_each_zero_dimension() {
        assert!(CorrelatedHistogram2D::try_new(0, 0).is_err());
        assert!(CorrelatedHistogram2D::try_new(0, 8).is_err());
        assert!(CorrelatedHistogram2D::try_new(8, 0).is_err());
    }

    #[test]
    fn try_new_accepts_valid_dimensions() {
        let h = CorrelatedHistogram2D::try_new(4, 6).unwrap();
        assert_eq!(h.col_a_bins(), 4);
        assert_eq!(h.col_b_bins(), 6);
        assert_eq!(h.cell_counts().len(), 24);
    }

    #[test]
    fn from_bytes_rejects_all_zero_payload() {
        for n in [4usize, 16, 64, 256, 1024, 4096] {
            let zeros = vec![0u8; n];
            assert!(
                CorrelatedHistogram2D::from_bytes(&zeros).is_err(),
                "all-zero len {n} accepted by from_bytes"
            );
        }
    }

    #[test]
    fn from_bytes_rejects_cell_length_mismatch() {
        let bytes = Wire {
            col_a_bins: 4,
            col_b_bins: 4,
            a_min: 0.0,
            a_max: 1.0,
            b_min: 0.0,
            b_max: 1.0,
            cells: vec![0u64; 7],
            total: 0,
        }
        .encode();
        assert!(
            CorrelatedHistogram2D::from_bytes(&bytes).is_err(),
            "cell-length mismatch accepted"
        );
    }

    #[test]
    fn from_bytes_rejects_nan_min_or_max() {
        let bytes = Wire {
            col_a_bins: 2,
            col_b_bins: 2,
            a_min: f64::NAN,
            a_max: 1.0,
            b_min: 0.0,
            b_max: 1.0,
            cells: vec![0u64; 4],
            total: 0,
        }
        .encode();
        assert!(
            CorrelatedHistogram2D::from_bytes(&bytes).is_err(),
            "NaN a_min accepted"
        );
    }

    #[test]
    fn from_bytes_accepts_valid_payload() {
        let pairs: Vec<(f64, f64)> = (0..100).map(|i| (i as f64, i as f64)).collect();
        let h = CorrelatedHistogram2D::from_pairs(&pairs, 4, 4).unwrap();
        let bytes = h.to_bytes().unwrap();
        let decoded = CorrelatedHistogram2D::from_bytes(&bytes).unwrap();
        assert_eq!(h.total(), decoded.total());
    }

    #[test]
    fn cell_counts_row_major_layout() {
        let pairs = vec![(0.0, 0.0), (3.0, 3.0)];
        let h = CorrelatedHistogram2D::from_pairs(&pairs, 4, 4).unwrap();
        let cells = h.cell_counts();
        assert_eq!(cells.len(), 16);
        assert_eq!(cells[0], 1);
        assert_eq!(cells[15], 1);
        let touched: u64 = cells.iter().sum();
        assert_eq!(touched, 2);
    }
}