use std::collections::HashMap;
use serde::{Deserialize, Serialize};
/// Aggregate of correlation information: keyed correlation matrices,
/// pairwise correlations spanning two tables, and Gaussian copulas.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorrelationFingerprint {
/// Correlation matrices stored by string key (`add_matrix` calls this key `table`).
pub matrices: HashMap<String, CorrelationMatrix>,
/// Correlations between a column of one table and a column of another.
/// Defaults to empty when absent from the input and is left out of the
/// serialized output when empty.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub cross_table_correlations: Vec<CrossTableCorrelation>,
/// Gaussian copulas, in insertion order.
/// Defaults to empty when absent from the input and is left out of the
/// serialized output when empty.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub copulas: Vec<GaussianCopula>,
}
impl CorrelationFingerprint {
    /// Builds a fingerprint that holds no matrices, no cross-table
    /// correlations and no copulas.
    pub fn new() -> Self {
        let matrices = HashMap::new();
        let cross_table_correlations = Vec::new();
        let copulas = Vec::new();
        Self {
            matrices,
            cross_table_correlations,
            copulas,
        }
    }

    /// Stores `matrix` under the key `table`.
    ///
    /// A matrix previously stored under the same key is replaced and dropped.
    pub fn add_matrix(&mut self, table: impl Into<String>, matrix: CorrelationMatrix) {
        let key: String = table.into();
        // The displaced matrix (if any) is intentionally discarded.
        let _replaced = self.matrices.insert(key, matrix);
    }

    /// Appends `copula` after the copulas already stored.
    pub fn add_copula(&mut self, copula: GaussianCopula) {
        self.copulas.push(copula);
    }
}
impl Default for CorrelationFingerprint {
fn default() -> Self {
Self::new()
}
}
/// Symmetric correlation matrix over a list of columns, stored compactly.
///
/// Only the entries above the diagonal are kept; the diagonal is implied
/// (`get` reports 1.0 for it) and the lower triangle mirrors the upper one.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorrelationMatrix {
/// Column names; the position in this list is the row/column index.
pub columns: Vec<String>,
/// Entries strictly above the diagonal, flattened row by row:
/// (0,1), (0,2), ..., (0,n-1), (1,2), ... — `n * (n - 1) / 2` values for
/// `n` columns when built by the constructors in this file.
pub correlations: Vec<f64>,
/// Which correlation measure the values represent.
pub correlation_type: CorrelationType,
/// Sample size associated with the matrix; the constructors in this file
/// initialise it to 0 and leave it to the caller to fill in.
pub sample_size: u64,
}
impl CorrelationMatrix {
    /// Creates a matrix over `columns` with every off-diagonal correlation
    /// set to 0.0 and `sample_size` set to 0.
    ///
    /// An empty or single-element column list yields an empty `correlations`
    /// vector.
    pub fn new(columns: Vec<String>, correlation_type: CorrelationType) -> Self {
        let size = Self::pair_count(columns.len());
        Self {
            columns,
            correlations: vec![0.0; size],
            correlation_type,
            sample_size: 0,
        }
    }

    /// Number of unordered column pairs for `n` columns: `n * (n - 1) / 2`.
    ///
    /// Uses a saturating subtraction so that `n == 0` gives 0 instead of
    /// underflowing.
    fn pair_count(n: usize) -> usize {
        n * n.saturating_sub(1) / 2
    }

    /// Position of the off-diagonal pair `(i, j)` inside `correlations`.
    ///
    /// Returns `None` for diagonal pairs and for indices outside the column
    /// list. The order of `i` and `j` does not matter.
    fn pair_index(&self, i: usize, j: usize) -> Option<usize> {
        let n = self.columns.len();
        let (low, high) = if i < j { (i, j) } else { (j, i) };
        if low == high || high >= n {
            return None;
        }
        // Rows 0..low hold (n-1) + (n-2) + ... + (n-low) entries, which is
        // low * (2n - low - 1) / 2; the product is always even.
        Some(low * (2 * n - low - 1) / 2 + (high - low - 1))
    }

    /// Returns the correlation between columns `i` and `j`.
    ///
    /// The diagonal is always 1.0. Returns `None` when either index is out
    /// of range (including out-of-range diagonal lookups) or when the
    /// backing vector is too short to contain the pair.
    pub fn get(&self, i: usize, j: usize) -> Option<f64> {
        if i == j {
            return if i < self.columns.len() {
                Some(1.0)
            } else {
                None
            };
        }
        let idx = self.pair_index(i, j)?;
        self.correlations.get(idx).copied()
    }

    /// Sets the correlation between columns `i` and `j` (in either order).
    ///
    /// The call is a no-op for diagonal entries, for out-of-range indices,
    /// and when the backing vector is too short to contain the pair.
    pub fn set(&mut self, i: usize, j: usize, value: f64) {
        let slot = self
            .pair_index(i, j)
            .and_then(|idx| self.correlations.get_mut(idx));
        if let Some(slot) = slot {
            *slot = value;
        }
    }

    /// Looks up a correlation by column names.
    ///
    /// Returns `None` if either name is not present in `columns`. When a
    /// name occurs more than once, its first occurrence is used.
    pub fn get_by_name(&self, col1: &str, col2: &str) -> Option<f64> {
        let i = self.columns.iter().position(|c| c == col1)?;
        let j = self.columns.iter().position(|c| c == col2)?;
        self.get(i, j)
    }

    /// Expands the compact storage into a full `n x n` matrix.
    ///
    /// The diagonal is 1.0; any pair missing from the backing vector is
    /// reported as 0.0.
    pub fn to_full_matrix(&self) -> Vec<Vec<f64>> {
        let n = self.columns.len();
        (0..n)
            .map(|i| (0..n).map(|j| self.get(i, j).unwrap_or(0.0)).collect())
            .collect()
    }

    /// Builds a compact matrix from the upper triangle of a full matrix.
    ///
    /// Only entries strictly above the diagonal of `matrix` are read; the
    /// diagonal and lower triangle are ignored. `sample_size` is set to 0.
    ///
    /// # Panics
    ///
    /// Panics if `matrix` lacks any of the rows `0..n-1`, or if one of those
    /// rows has fewer than `n` entries, where `n = columns.len()`.
    pub fn from_full_matrix(
        columns: Vec<String>,
        matrix: &[Vec<f64>],
        correlation_type: CorrelationType,
    ) -> Self {
        let n = columns.len();
        let mut correlations = Vec::with_capacity(Self::pair_count(n));
        // The last row contributes nothing above the diagonal, so it is
        // never touched.
        for i in 0..n.saturating_sub(1) {
            correlations.extend_from_slice(&matrix[i][i + 1..n]);
        }
        Self {
            columns,
            correlations,
            correlation_type,
            sample_size: 0,
        }
    }
}
/// Correlation measure used to produce a correlation value.
///
/// Variants are serialized under snake_case names (for example
/// `cramers_v`, `point_biserial`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CorrelationType {
/// Pearson correlation.
Pearson,
/// Spearman correlation.
Spearman,
/// Kendall correlation.
Kendall,
/// Cramér's V.
CramersV,
/// Eta — presumably the correlation ratio; confirm with the producer of these values.
Eta,
/// Point-biserial correlation.
PointBiserial,
}
/// A single correlation value between a column of one table and a column of
/// another table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrossTableCorrelation {
/// Name of the first table.
pub table1: String,
/// Column name paired with `table1`.
pub column1: String,
/// Name of the second table.
pub table2: String,
/// Column name paired with `table2`.
pub column2: String,
/// The correlation value.
pub correlation: f64,
/// Which correlation measure `correlation` represents.
pub correlation_type: CorrelationType,
/// Sample size associated with the value — presumably the number of
/// paired rows it was computed from; confirm against the producer.
pub sample_size: u64,
/// Optional join-key description; left out of the serialized output when `None`.
/// NOTE(review): presumably the columns used to join `table1` and `table2` — confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub join_key: Option<JoinKey>,
}
/// Pair of column names describing a join key.
///
/// NOTE(review): presumably `column1` belongs to the first table and
/// `column2` to the second table of the enclosing `CrossTableCorrelation` —
/// confirm against the code that builds it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JoinKey {
/// First column of the key pair.
pub column1: String,
/// Second column of the key pair.
pub column2: String,
}
/// Gaussian copula description: a set of columns, a flattened correlation
/// matrix, and per-column empirical marginal CDFs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GaussianCopula {
/// Name of this copula.
pub name: String,
/// Name of the table the copula refers to.
pub table: String,
/// Column names; `dimensions()` reports the length of this list.
pub columns: Vec<String>,
/// Flattened square matrix; `new` allocates `n * n` entries for `n` columns.
pub correlation_matrix: Vec<f64>,
/// Empirical marginal CDFs; empty after `new`.
/// NOTE(review): presumably one entry per column, in `columns` order — confirm.
pub marginal_cdfs: Vec<EmpiricalCdf>,
}
impl GaussianCopula {
    /// Creates a copula over `columns` with an identity correlation matrix
    /// and no marginal CDFs.
    ///
    /// The matrix is a flattened `n x n` identity (1.0 on the diagonal, 0.0
    /// elsewhere), i.e. the columns start out uncorrelated. An all-ones
    /// matrix would describe perfectly correlated columns and is singular,
    /// so it is not a usable default.
    pub fn new(name: impl Into<String>, table: impl Into<String>, columns: Vec<String>) -> Self {
        let n = columns.len();
        let mut correlation_matrix = vec![0.0; n * n];
        // In a flattened n x n matrix the diagonal entries are n + 1 apart.
        for diagonal in correlation_matrix.iter_mut().step_by(n + 1) {
            *diagonal = 1.0;
        }
        Self {
            name: name.into(),
            table: table.into(),
            columns,
            correlation_matrix,
            marginal_cdfs: Vec::new(),
        }
    }

    /// Number of columns (dimensions) covered by the copula.
    pub fn dimensions(&self) -> usize {
        self.columns.len()
    }
}
/// Empirical cumulative distribution function of one column, stored as
/// parallel vectors of values and cumulative probabilities.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmpiricalCdf {
/// Name of the column this distribution describes.
pub column: String,
/// Sample values; `cdf` searches this vector with a binary search, so it
/// is expected to be in ascending order.
pub values: Vec<f64>,
/// Cumulative probability paired with the value at the same index;
/// `from_sorted_values` sets entry `i` to `(i + 1) / n`.
pub probabilities: Vec<f64>,
}
impl EmpiricalCdf {
    /// Builds an empirical CDF from values that are already sorted in
    /// ascending order.
    ///
    /// The value at index `i` is paired with cumulative probability
    /// `(i + 1) / n`. The input is neither sorted nor validated here. An
    /// empty input produces an empty distribution.
    pub fn from_sorted_values(column: impl Into<String>, values: Vec<f64>) -> Self {
        let n = values.len();
        let probabilities = (1..=n).map(|rank| rank as f64 / n as f64).collect();
        Self {
            column: column.into(),
            values,
            probabilities,
        }
    }

    /// Linear interpolation of `y` at `t` on the segment `(t0, y0) – (t1, y1)`.
    fn lerp(t: f64, (t0, y0): (f64, f64), (t1, y1): (f64, f64)) -> f64 {
        y0 + (y1 - y0) * (t - t0) / (t1 - t0)
    }

    /// Evaluates the CDF at `x`.
    ///
    /// Returns 0.0 below the smallest value (or when there are no values)
    /// and 1.0 above the largest. When `x` equals a stored value, the
    /// probability of the *last* equal value is returned, so ties yield the
    /// full cumulative mass. Between two stored values the result is
    /// linearly interpolated. Comparisons use `f64::total_cmp`.
    ///
    /// # Panics
    ///
    /// May panic if `probabilities` is shorter than `values`.
    pub fn cdf(&self, x: f64) -> f64 {
        // Count of stored values <= x; a plain binary search could land on
        // any element of a run of duplicates.
        let upper = self.values.partition_point(|v| v.total_cmp(&x).is_le());
        if upper == 0 {
            return 0.0;
        }
        let last_le = upper - 1;
        if self.values[last_le].total_cmp(&x).is_eq() {
            return self.probabilities[last_le];
        }
        if upper >= self.values.len() {
            return 1.0;
        }
        Self::lerp(
            x,
            (self.values[last_le], self.probabilities[last_le]),
            (self.values[upper], self.probabilities[upper]),
        )
    }

    /// Evaluates the inverse CDF at probability `p`.
    ///
    /// Returns the smallest value for `p <= 0` or for `p` below the first
    /// stored probability, the largest value for `p >= 1`, and linearly
    /// interpolates between neighbouring points otherwise. An empty
    /// distribution yields 0.0 for every `p`.
    ///
    /// # Panics
    ///
    /// May panic if `values` is non-empty but shorter than `probabilities`.
    pub fn quantile(&self, p: f64) -> f64 {
        let (lowest, highest) = match (self.values.first(), self.values.last()) {
            (Some(&lo), Some(&hi)) => (lo, hi),
            // No data: match the 0.0 fallback of the boundary cases
            // instead of indexing into an empty vector.
            _ => return 0.0,
        };
        if p <= 0.0 {
            return lowest;
        }
        if p >= 1.0 {
            return highest;
        }
        match self.probabilities.binary_search_by(|q| q.total_cmp(&p)) {
            Ok(i) => self.values[i],
            Err(0) => lowest,
            Err(i) if i >= self.probabilities.len() => highest,
            Err(i) => Self::lerp(
                p,
                (self.probabilities[i - 1], self.values[i - 1]),
                (self.probabilities[i], self.values[i]),
            ),
        }
    }
}