use std::collections::{HashMap, VecDeque};
use serde::Serialize;
use super::FindingType;
use crate::detect::Finding;
/// Tunable parameters for the cross-trace correlator.
#[derive(Debug, Clone)]
pub struct CorrelationConfig {
    /// Length of the sliding window, in milliseconds. Occurrences and pairs
    /// older than `now_ms - window_ms` are dropped on ingest.
    pub window_ms: u64,
    /// Maximum age, in milliseconds, of an earlier occurrence for it to be
    /// paired with a newly ingested finding.
    pub lag_threshold_ms: u64,
    /// Pairs observed fewer times than this are not reported.
    pub min_co_occurrences: u32,
    /// Pairs whose co-occurrence / source-total ratio is below this value are
    /// not reported.
    pub min_confidence: f64,
    /// Upper bound on the number of tracked pairs; the pairs with the lowest
    /// co-occurrence counts are evicted beyond it.
    pub max_tracked_pairs: usize,
}
impl Default for CorrelationConfig {
fn default() -> Self {
Self {
window_ms: 600_000,
lag_threshold_ms: 5_000,
min_co_occurrences: 5,
min_confidence: 0.7,
max_tracked_pairs: 10_000,
}
}
}
/// One side of a correlation: a kind of finding on a particular service and
/// template.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, serde::Deserialize)]
pub struct CorrelationEndpoint {
    /// Kind of finding that was observed.
    pub finding_type: FindingType,
    /// Service the finding was reported for.
    pub service: String,
    /// Pattern template copied from the finding.
    pub template: String,
}
/// A reported correlation: `target` findings tend to follow `source` findings
/// within the configured lag threshold.
#[derive(Debug, Clone, Serialize, serde::Deserialize)]
pub struct CrossTraceCorrelation {
    /// Endpoint that was observed first.
    pub source: CorrelationEndpoint,
    /// Endpoint that was observed afterwards.
    pub target: CorrelationEndpoint,
    /// Number of times the pair was observed together.
    pub co_occurrence_count: u32,
    /// Number of `source` occurrences currently inside the sliding window.
    pub source_total_occurrences: u32,
    /// `co_occurrence_count` divided by `source_total_occurrences`.
    pub confidence: f64,
    /// Median of the sampled source-to-target lags, in milliseconds.
    pub median_lag_ms: f64,
    /// Time of the first observation of the pair, as formatted by
    /// `crate::time::millis_to_iso8601`.
    pub first_seen: String,
    /// Time of the latest observation of the pair, as formatted by
    /// `crate::time::millis_to_iso8601`.
    pub last_seen: String,
}
/// Map key identifying an ordered (source, target) endpoint pair.
///
/// Endpoints are reference-counted so a key can share the allocation held by
/// the occurrence it was built from instead of copying the strings.
#[derive(Debug, Clone)]
struct PairKey {
    /// Endpoint observed first.
    source: std::sync::Arc<CorrelationEndpoint>,
    /// Endpoint observed afterwards.
    target: std::sync::Arc<CorrelationEndpoint>,
}
impl PartialEq for PairKey {
fn eq(&self, other: &Self) -> bool {
(std::sync::Arc::ptr_eq(&self.source, &other.source) || self.source == other.source)
&& (std::sync::Arc::ptr_eq(&self.target, &other.target) || self.target == other.target)
}
}
// Marker impl: `PairKey` equality is a full equivalence relation because it
// reduces to comparing `CorrelationEndpoint` values, which derive `Eq`.
impl Eq for PairKey {}
impl std::hash::Hash for PairKey {
    /// Feeds the source endpoint and then the target endpoint into `state`.
    ///
    /// Only endpoint values are hashed (never pointer addresses), which keeps
    /// the hash consistent with the value-based `PartialEq` implementation.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        std::hash::Hash::hash(&*self.source, state);
        std::hash::Hash::hash(&*self.target, state);
    }
}
/// Maximum number of lag samples kept per pair; once this many are stored,
/// further lags replace existing samples at random instead of growing the
/// buffer.
const MAX_LAG_SAMPLES: usize = 256;
/// Running statistics kept for one tracked endpoint pair.
struct PairState {
    /// How many times the pair has been observed together (saturating).
    co_occurrence_count: u32,
    /// Sampled source-to-target lags in milliseconds; bounded by
    /// `MAX_LAG_SAMPLES`.
    lags_ms: Vec<f64>,
    /// Total number of lags ever offered to the sampler.
    total_observations: u64,
    /// State of the pseudo-random generator that drives sample replacement.
    rng_state: u64,
    /// Timestamp (ms) at which the pair was first observed.
    first_seen_ms: u64,
    /// Timestamp (ms) at which the pair was most recently observed.
    last_seen_ms: u64,
}
impl PairState {
    /// Offers one lag observation to the bounded sample buffer.
    ///
    /// The first `MAX_LAG_SAMPLES` observations are stored directly. After
    /// that, a pseudo-random index in `0..total_observations` is drawn and the
    /// new lag overwrites the sample at that index only when the index falls
    /// inside the buffer, so later observations are kept with decreasing
    /// probability.
    fn record_lag(&mut self, lag_ms: f64) {
        self.total_observations = self.total_observations.saturating_add(1);

        let buffer_full = self.lags_ms.len() >= MAX_LAG_SAMPLES;
        if !buffer_full {
            self.lags_ms.push(lag_ms);
        } else {
            // `total_observations` is at least 1 here, so the modulo is safe.
            let slot = splitmix64(&mut self.rng_state) % self.total_observations;
            if slot < MAX_LAG_SAMPLES as u64 {
                self.lags_ms[slot as usize] = lag_ms;
            }
        }
    }
}
/// Advances `state` by one SplitMix64 step and returns the next 64-bit output.
///
/// `state` is incremented (with wrap-around) by a fixed odd constant and the
/// new value is scrambled with two xor-shift/multiply rounds and a final
/// xor-shift.
fn splitmix64(state: &mut u64) -> u64 {
    const INCREMENT: u64 = 0x9E37_79B9_7F4A_7C15;
    const MULTIPLIER_A: u64 = 0xBF58_476D_1CE4_E5B9;
    const MULTIPLIER_B: u64 = 0x94D0_49BB_1331_11EB;

    *state = state.wrapping_add(INCREMENT);
    let seed = *state;
    let round_a = (seed ^ (seed >> 30)).wrapping_mul(MULTIPLIER_A);
    let round_b = (round_a ^ (round_a >> 27)).wrapping_mul(MULTIPLIER_B);
    round_b ^ (round_b >> 31)
}
/// Computes a deterministic 64-bit FNV-1a style hash of an endpoint.
///
/// The finding-type name, the service and the template are absorbed in that
/// order. A marker byte (`0xFF`, then `0xFE`) is xor-ed into the running hash
/// between fields so that moving bytes from one field to the next changes the
/// result.
fn hash_endpoint(ep: &CorrelationEndpoint) -> u64 {
    const FNV_OFFSET: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x100_0000_01b3;

    // Xor each byte into the running hash, then multiply by the FNV prime.
    let absorb = |seed: u64, bytes: std::str::Bytes<'_>| -> u64 {
        bytes.fold(seed, |acc, byte| {
            (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
        })
    };

    let after_type = absorb(FNV_OFFSET, ep.finding_type.as_str().bytes());
    let after_service = absorb(after_type ^ 0xFF, ep.service.bytes());
    absorb(after_service ^ 0xFE, ep.template.bytes())
}
/// A single finding remembered inside the sliding window.
struct FindingOccurrence {
    /// Endpoint the finding maps to; shared with any pair keys built from it.
    endpoint: std::sync::Arc<CorrelationEndpoint>,
    /// Ingest time of the finding, in milliseconds.
    timestamp_ms: u64,
}
/// Detects pairs of findings on different services that repeatedly occur
/// close together in time.
pub struct CrossTraceCorrelator {
    /// Findings currently inside the sliding window, oldest first.
    occurrences: VecDeque<FindingOccurrence>,
    /// Statistics for every tracked (source, target) pair.
    pair_counts: HashMap<PairKey, PairState>,
    /// Number of in-window occurrences per endpoint.
    source_totals: HashMap<CorrelationEndpoint, u32>,
    /// Thresholds and limits that govern tracking and reporting.
    config: CorrelationConfig,
}
impl CrossTraceCorrelator {
    /// Creates an empty correlator governed by `config`.
    #[must_use]
    pub fn new(config: CorrelationConfig) -> Self {
        Self {
            occurrences: VecDeque::new(),
            pair_counts: HashMap::new(),
            source_totals: HashMap::new(),
            config,
        }
    }

    /// Decrements the in-window count for `endpoint` and removes the entry
    /// once it reaches zero, so the map does not accumulate dead endpoints.
    ///
    /// Takes the map rather than `&mut self` so it can be called while another
    /// field of the correlator is borrowed.
    fn decrement_source_total(
        source_totals: &mut HashMap<CorrelationEndpoint, u32>,
        endpoint: &CorrelationEndpoint,
    ) {
        if let Some(count) = source_totals.get_mut(endpoint) {
            *count = count.saturating_sub(1);
            if *count == 0 {
                source_totals.remove(endpoint);
            }
        }
    }

    /// Ingests a batch of findings, all stamped with `now_ms`.
    ///
    /// Occurrences and pairs that fell out of the sliding window are dropped
    /// first. Each finding is then paired with recent occurrences from other
    /// services, counted in the per-endpoint totals and appended to the
    /// window. Finally the number of tracked pairs is capped.
    ///
    /// A finding is paired only with occurrences recorded before it, so within
    /// one batch the order of `findings` decides which side is the source.
    pub fn ingest(&mut self, findings: &[Finding], now_ms: u64) {
        let cutoff = now_ms.saturating_sub(self.config.window_ms);
        self.evict_stale(cutoff);
        self.pair_counts
            .retain(|_, state| state.last_seen_ms >= cutoff);

        for finding in findings {
            let endpoint = std::sync::Arc::new(CorrelationEndpoint {
                finding_type: finding.finding_type.clone(),
                service: finding.service.clone(),
                template: finding.pattern.template.clone(),
            });
            // Pair against earlier occurrences before recording this one, so a
            // finding is never correlated with itself.
            self.record_co_occurrences(&endpoint, now_ms);
            *self.source_totals.entry((*endpoint).clone()).or_insert(0) += 1;
            self.occurrences.push_back(FindingOccurrence {
                endpoint,
                timestamp_ms: now_ms,
            });
        }

        self.enforce_pair_cap();
    }

    /// Pops occurrences older than `cutoff` from the front of the window and
    /// decrements their per-endpoint totals.
    ///
    /// Stops at the first occurrence that is still inside the window; this
    /// relies on occurrences being appended in timestamp order.
    fn evict_stale(&mut self, cutoff: u64) {
        while let Some(front) = self.occurrences.front() {
            if front.timestamp_ms >= cutoff {
                break;
            }
            if let Some(expired) = self.occurrences.pop_front() {
                Self::decrement_source_total(&mut self.source_totals, &expired.endpoint);
            }
        }
    }

    /// Pairs `endpoint` (the target) with every earlier occurrence on a
    /// different service that is at most `lag_threshold_ms` old.
    ///
    /// The window is walked newest-first and the walk stops at the first
    /// occurrence that is too old.
    fn record_co_occurrences(
        &mut self,
        endpoint: &std::sync::Arc<CorrelationEndpoint>,
        now_ms: u64,
    ) {
        for occ in self.occurrences.iter().rev() {
            let age = now_ms.saturating_sub(occ.timestamp_ms);
            if age > self.config.lag_threshold_ms {
                break;
            }
            // Only cross-service pairs are of interest.
            if occ.endpoint.service == endpoint.service {
                continue;
            }

            let key = PairKey {
                source: occ.endpoint.clone(),
                target: endpoint.clone(),
            };
            // `age` is bounded by `lag_threshold_ms`; losing precision above
            // 2^53 ms is acceptable for a lag statistic.
            #[allow(clippy::cast_precision_loss)]
            let lag = age as f64;

            let state = self.pair_counts.entry(key).or_insert_with(|| PairState {
                co_occurrence_count: 0,
                lags_ms: Vec::new(),
                total_observations: 0,
                // Seed the sampler from the pair identity and the time of
                // first sight so different pairs draw different sequences.
                rng_state: now_ms ^ (hash_endpoint(&occ.endpoint) << 17) ^ hash_endpoint(endpoint),
                first_seen_ms: now_ms,
                last_seen_ms: now_ms,
            });
            state.co_occurrence_count = state.co_occurrence_count.saturating_add(1);
            state.record_lag(lag);
            state.last_seen_ms = now_ms;
        }
    }

    /// Evicts the pairs with the lowest co-occurrence counts until at most
    /// `max_tracked_pairs` remain.
    fn enforce_pair_cap(&mut self) {
        let cap = self.config.max_tracked_pairs;
        let tracked = self.pair_counts.len();
        if tracked <= cap {
            return;
        }
        if cap == 0 {
            // Everything must go. `select_nth_unstable_by_key` panics when the
            // index equals the slice length, so handle this case directly.
            self.pair_counts.clear();
            return;
        }

        let to_remove = tracked - cap;
        let mut keys: Vec<(PairKey, u32)> = self
            .pair_counts
            .iter()
            .map(|(k, v)| (k.clone(), v.co_occurrence_count))
            .collect();
        // Partition so that the `to_remove` smallest counts come first; a full
        // sort is not needed. `to_remove < keys.len()` holds because `cap > 0`.
        let (lowest, _, _) = keys.select_nth_unstable_by_key(to_remove, |(_, c)| *c);
        for (key, _) in lowest.iter() {
            self.pair_counts.remove(key);
        }
    }

    /// Returns every tracked pair that meets both the minimum co-occurrence
    /// count and the minimum confidence.
    ///
    /// Confidence is the pair's co-occurrence count divided by the number of
    /// in-window occurrences of its source (treated as 1 when the source is
    /// no longer in the window). Co-occurrence counts are never decremented
    /// and one source occurrence can pair with several targets, so the ratio
    /// is not capped and may exceed 1.0.
    ///
    /// The result follows hash-map iteration order, which is unspecified.
    #[must_use]
    pub fn active_correlations(&self) -> Vec<CrossTraceCorrelation> {
        self.pair_counts
            .iter()
            .filter_map(|(key, state)| {
                if state.co_occurrence_count < self.config.min_co_occurrences {
                    return None;
                }
                let source_total = self
                    .source_totals
                    .get(key.source.as_ref())
                    .copied()
                    .unwrap_or(1);
                // `max(1)` guards the division against a zero denominator.
                let confidence =
                    f64::from(state.co_occurrence_count) / f64::from(source_total.max(1));
                if confidence < self.config.min_confidence {
                    return None;
                }
                let median_lag = median(&state.lags_ms);
                Some(CrossTraceCorrelation {
                    source: (*key.source).clone(),
                    target: (*key.target).clone(),
                    co_occurrence_count: state.co_occurrence_count,
                    source_total_occurrences: source_total,
                    confidence,
                    median_lag_ms: median_lag,
                    first_seen: crate::time::millis_to_iso8601(state.first_seen_ms),
                    last_seen: crate::time::millis_to_iso8601(state.last_seen_ms),
                })
            })
            .collect()
    }
}
/// Returns the median of `values`, or `0.0` for an empty slice.
///
/// For an even number of values the result is the midpoint of the two middle
/// elements. The input is not modified; a sorted copy is made.
///
/// Values are ordered with [`f64::total_cmp`], which is a total order even
/// when NaN is present. The result therefore never depends on the order of
/// the input, and the sort is never handed an inconsistent comparator.
fn median(values: &[f64]) -> f64 {
    if values.is_empty() {
        return 0.0;
    }
    let mut sorted = values.to_vec();
    sorted.sort_unstable_by(f64::total_cmp);
    let mid = sorted.len() / 2;
    if sorted.len().is_multiple_of(2) {
        f64::midpoint(sorted[mid - 1], sorted[mid])
    } else {
        sorted[mid]
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a warning-severity finding for `service` with the given finding
    /// type and pattern template; every other field is a fixed placeholder.
    fn make_finding(service: &str, finding_type: FindingType, template: &str) -> Finding {
        Finding {
            finding_type,
            severity: crate::detect::Severity::Warning,
            trace_id: format!("trace-{service}"),
            service: service.to_string(),
            source_endpoint: "POST /api/test".to_string(),
            pattern: crate::detect::Pattern {
                template: template.to_string(),
                occurrences: 5,
                window_ms: 200,
                distinct_params: 5,
            },
            suggestion: "batch".to_string(),
            first_timestamp: "2025-07-10T14:32:01.000Z".to_string(),
            last_timestamp: "2025-07-10T14:32:01.200Z".to_string(),
            green_impact: None,
            confidence: crate::detect::Confidence::default(),
            code_location: None,
        }
    }

    // A finding on one service followed 2 s later by a finding on another
    // service, repeated five times, must be reported as source -> target.
    #[test]
    fn detects_simple_a_then_b_pattern() {
        let mut correlator = CrossTraceCorrelator::new(CorrelationConfig {
            min_co_occurrences: 2,
            min_confidence: 0.5,
            lag_threshold_ms: 5_000,
            ..Default::default()
        });
        for i in 0..5 {
            let t = 1_000_000 + i * 10_000;
            let fa = make_finding("order-svc", FindingType::NPlusOneSql, "SELECT * FROM t");
            correlator.ingest(&[fa], t);
            let fb = make_finding("payment-svc", FindingType::PoolSaturation, "payment-svc");
            correlator.ingest(&[fb], t + 2_000);
        }
        let correlations = correlator.active_correlations();
        assert!(
            !correlations.is_empty(),
            "expected at least one correlation"
        );
        let c = &correlations[0];
        assert_eq!(c.source.service, "order-svc");
        assert_eq!(c.target.service, "payment-svc");
        assert!(c.co_occurrence_count >= 2);
        assert!(c.confidence > 0.0);
    }

    // Findings on the same service are never paired, even in the same batch.
    #[test]
    fn same_service_not_correlated() {
        let mut correlator = CrossTraceCorrelator::new(CorrelationConfig {
            min_co_occurrences: 2,
            min_confidence: 0.1,
            ..Default::default()
        });
        for i in 0..5 {
            let t = 1_000_000 + i * 10_000;
            let fa = make_finding("order-svc", FindingType::NPlusOneSql, "SELECT * FROM t");
            let fb = make_finding("order-svc", FindingType::RedundantSql, "SELECT * FROM t");
            correlator.ingest(&[fa, fb], t);
        }
        let correlations = correlator.active_correlations();
        assert!(
            correlations.is_empty(),
            "same-service findings should not be correlated"
        );
    }

    // After a jump far past the window, only the newly ingested finding may
    // remain; everything derived from the stale findings must be gone.
    #[test]
    fn eviction_removes_stale_entries() {
        let mut correlator = CrossTraceCorrelator::new(CorrelationConfig {
            window_ms: 10_000,
            min_co_occurrences: 1,
            min_confidence: 0.1,
            ..Default::default()
        });
        let fa = make_finding("order-svc", FindingType::NPlusOneSql, "SELECT 1");
        correlator.ingest(&[fa], 1_000);
        let fb = make_finding("payment-svc", FindingType::PoolSaturation, "payment-svc");
        correlator.ingest(&[fb], 2_000);
        let fa2 = make_finding("other-svc", FindingType::SlowSql, "SELECT 2");
        correlator.ingest(&[fa2], 100_000);
        // Exact checks: a `<= 2` bound would also pass if only one of the two
        // stale occurrences had been evicted.
        assert_eq!(
            correlator.occurrences.len(),
            1,
            "stale entries should be evicted"
        );
        assert_eq!(
            correlator
                .occurrences
                .front()
                .map(|occ| occ.endpoint.service.as_str()),
            Some("other-svc"),
            "only the fresh occurrence should remain"
        );
        assert!(
            correlator.pair_counts.is_empty(),
            "pairs last seen outside the window should be dropped"
        );
        assert_eq!(
            correlator.source_totals.len(),
            1,
            "source totals of evicted occurrences should be removed"
        );
    }

    // Twenty services on each side create far more pairs than the cap allows.
    #[test]
    fn max_tracked_pairs_enforced() {
        let mut correlator = CrossTraceCorrelator::new(CorrelationConfig {
            max_tracked_pairs: 5,
            lag_threshold_ms: 100_000,
            min_co_occurrences: 1,
            min_confidence: 0.0,
            ..Default::default()
        });
        for i in 0..20 {
            let fa = make_finding(
                &format!("svc-a-{i}"),
                FindingType::NPlusOneSql,
                &format!("tpl-{i}"),
            );
            correlator.ingest(&[fa], 1000);
            let fb = make_finding(
                &format!("svc-b-{i}"),
                FindingType::RedundantSql,
                &format!("tpl-{i}"),
            );
            correlator.ingest(&[fb], 1001);
        }
        assert!(
            correlator.pair_counts.len() <= 5,
            "pair count should be capped at max_tracked_pairs"
        );
    }

    // The target follows the source in only 2 of 10 cases, well below the
    // 0.9 confidence floor.
    #[test]
    fn low_confidence_filtered_out() {
        let mut correlator = CrossTraceCorrelator::new(CorrelationConfig {
            min_co_occurrences: 1,
            min_confidence: 0.9,
            lag_threshold_ms: 5_000,
            ..Default::default()
        });
        for i in 0..10 {
            let t = 1_000_000 + i * 10_000;
            let fa = make_finding("order-svc", FindingType::NPlusOneSql, "SELECT * FROM t");
            correlator.ingest(&[fa], t);
            if i < 2 {
                let fb = make_finding("payment-svc", FindingType::PoolSaturation, "payment-svc");
                correlator.ingest(&[fb], t + 1_000);
            }
        }
        let correlations = correlator.active_correlations();
        assert!(
            correlations.is_empty(),
            "low confidence pairs should be filtered"
        );
    }

    // A 9 s gap exceeds the 1 s lag threshold, so no pair is recorded.
    #[test]
    fn delay_exceeding_lag_threshold_not_counted() {
        let mut correlator = CrossTraceCorrelator::new(CorrelationConfig {
            lag_threshold_ms: 1_000,
            min_co_occurrences: 1,
            min_confidence: 0.1,
            ..Default::default()
        });
        let fa = make_finding("order-svc", FindingType::NPlusOneSql, "SELECT 1");
        correlator.ingest(&[fa], 1_000);
        let fb = make_finding("payment-svc", FindingType::PoolSaturation, "payment-svc");
        correlator.ingest(&[fb], 10_000);
        let correlations = correlator.active_correlations();
        assert!(
            correlations.is_empty(),
            "findings outside lag threshold should not be correlated"
        );
    }

    // The lag buffer must stay bounded while the observation counter keeps
    // counting every hit.
    #[test]
    fn lags_ms_bounded_by_reservoir_cap() {
        let mut correlator = CrossTraceCorrelator::new(CorrelationConfig {
            min_co_occurrences: 1,
            min_confidence: 0.1,
            lag_threshold_ms: 10_000,
            window_ms: 10_000_000,
            ..Default::default()
        });
        let total = MAX_LAG_SAMPLES * 10;
        for i in 0..total {
            let t = 1_000_000 + i as u64 * 10;
            let fa = make_finding("order-svc", FindingType::NPlusOneSql, "SELECT 1");
            correlator.ingest(&[fa], t);
            let fb = make_finding("payment-svc", FindingType::PoolSaturation, "payment-svc");
            correlator.ingest(&[fb], t + 1);
        }
        assert!(
            !correlator.pair_counts.is_empty(),
            "expected at least one tracked pair"
        );
        for state in correlator.pair_counts.values() {
            assert!(
                state.lags_ms.len() <= MAX_LAG_SAMPLES,
                "lags_ms must be bounded: got {}",
                state.lags_ms.len()
            );
            assert!(
                state.total_observations > MAX_LAG_SAMPLES as u64,
                "total_observations should track every hit, got {}",
                state.total_observations
            );
        }
    }

    // Feeds an increasing ramp and checks that the retained samples still
    // resemble the whole stream in mean and spread.
    #[test]
    fn reservoir_continues_to_sample_after_many_observations() {
        let mut state = PairState {
            co_occurrence_count: 0,
            lags_ms: Vec::new(),
            total_observations: 0,
            rng_state: 0x1234_5678_9ABC_DEF0,
            first_seen_ms: 0,
            last_seen_ms: 0,
        };
        let n = MAX_LAG_SAMPLES * 20;
        for i in 0..n {
            state.record_lag(i as f64);
        }
        let mean: f64 = state.lags_ms.iter().sum::<f64>() / state.lags_ms.len() as f64;
        let expected_mean = (n - 1) as f64 / 2.0;
        let tolerance = expected_mean * 0.10;
        assert!(
            (mean - expected_mean).abs() < tolerance,
            "reservoir mean {mean} should be within {tolerance} of {expected_mean} \
             (a frozen/biased reservoir would produce a much lower mean)"
        );
        let variance: f64 = state
            .lags_ms
            .iter()
            .map(|&x| (x - mean).powi(2))
            .sum::<f64>()
            / state.lags_ms.len() as f64;
        let pop_variance = (n as f64).powi(2) / 12.0;
        assert!(
            variance > pop_variance * 0.25,
            "reservoir variance {variance} should be at least 25% of population \
             variance {pop_variance}; a frozen reservoir would be orders of \
             magnitude below this"
        );
    }

    // Once the first finding leaves the window, its total must disappear and
    // only the fresh endpoint may be counted.
    #[test]
    fn source_totals_rebuilt_from_window_on_each_ingest() {
        let mut correlator = CrossTraceCorrelator::new(CorrelationConfig {
            window_ms: 1_000,
            min_co_occurrences: 1,
            min_confidence: 0.1,
            ..Default::default()
        });
        let fa = make_finding("order-svc", FindingType::NPlusOneSql, "SELECT 1");
        correlator.ingest(&[fa], 1_000);
        assert_eq!(correlator.source_totals.len(), 1);
        let fb = make_finding("other-svc", FindingType::NPlusOneSql, "SELECT 2");
        correlator.ingest(&[fb], 10_000);
        // A bare `len() <= 1` cannot tell which entry survived, so check the
        // surviving key as well.
        assert_eq!(
            correlator.source_totals.len(),
            1,
            "source_totals should not retain stale entries"
        );
        assert!(
            correlator
                .source_totals
                .keys()
                .all(|ep| ep.service == "other-svc"),
            "only the in-window endpoint should be counted"
        );
    }

    // A correlation must survive a JSON round trip unchanged.
    #[test]
    fn correlation_serde_roundtrip() {
        let c = CrossTraceCorrelation {
            source: CorrelationEndpoint {
                finding_type: FindingType::NPlusOneSql,
                service: "order-svc".to_string(),
                template: "SELECT * FROM t".to_string(),
            },
            target: CorrelationEndpoint {
                finding_type: FindingType::PoolSaturation,
                service: "payment-svc".to_string(),
                template: "payment-svc".to_string(),
            },
            co_occurrence_count: 12,
            source_total_occurrences: 15,
            confidence: 0.8,
            median_lag_ms: 1200.0,
            first_seen: "2025-07-10T14:32:00.000Z".to_string(),
            last_seen: "2025-07-10T14:42:00.000Z".to_string(),
        };
        let json = serde_json::to_string(&c).unwrap();
        let back: CrossTraceCorrelation = serde_json::from_str(&json).unwrap();
        assert_eq!(back.co_occurrence_count, 12);
        assert_eq!(back.source.service, "order-svc");
        assert_eq!(back.target.service, "payment-svc");
        assert!((back.confidence - 0.8).abs() < f64::EPSILON);
    }
}