Skip to main content

datasynth_core/distributions/
source_conditional_pair.rs

//! Source-conditional Dirichlet account-pair sampler (SOTA-8).
//!
//! For each source string, draws a Dirichlet-distributed PMF over a per-source
//! account pool. Round 0 (`FINDINGS §14`) showed the synthetic engine's
//! source-conditional structure is too uniform (entropy 0.97 vs corpus 0.68) and too
//! narrow (5 vs 23.5 accounts per source). This sampler closes both gaps
//! simultaneously: a configurable larger pool, drawn through a *concentrated* (low-α)
//! Dirichlet.
//!
//! Math: symmetric Dirichlet(α, …, α) is realised by `pᵢ = Gᵢ / Σⱼ Gⱼ` with each
//! `Gᵢ ~ Gamma(α, 1)`. Lower α ⇒ concentrated PMF. With α = 0.5 and `N_s = 25` the
//! expected normalised entropy is ≈ 0.65 — matching the corpus median of 0.68.
//!
//! This module is wired in by `je_generator` only when the
//! `transactions.source_conditional_account_pair.enabled` config flag is set
//! (default off — opt-in so existing users' synthetic streams stay byte-identical).
16
17use std::collections::HashMap;
18
19use rand::distr::weighted::WeightedIndex;
20use rand::prelude::*;
21use rand_chacha::ChaCha8Rng;
22use rand_distr::{Distribution, Gamma, LogNormal};
23
/// Account pool for a single source, paired with a sampled Dirichlet PMF.
///
/// Constructed by [`SourcePool::new`]; drawn from via [`SourcePool::sample_one`] and
/// [`SourcePool::sample_pair`].
#[derive(Debug, Clone)]
pub struct SourcePool {
    /// The accounts that belong to this source's pool; `n()` reports how many.
    pub accounts: Vec<String>,
    /// Running sum of the PMF over `accounts`, binary-searched when sampling
    /// (inverse-CDF lookup). The final entry is pinned to exactly `1.0`.
    cumulative: Vec<f64>,
}
33
34impl SourcePool {
35    /// Build a pool of `pool_size` accounts drawn from `all_accounts` weighted by
36    /// `account_weights` (deduplicated), with a symmetric Dirichlet(α) PMF over them.
37    pub fn new(
38        pool_size: usize,
39        all_accounts: &[String],
40        account_weights: &[f64],
41        alpha: f64,
42        rng: &mut ChaCha8Rng,
43    ) -> Self {
44        assert_eq!(
45            all_accounts.len(),
46            account_weights.len(),
47            "all_accounts and account_weights must align"
48        );
49        assert!(alpha > 0.0, "alpha must be > 0");
50
51        // Weighted sampling-with-replacement + dedup until we have `pool_size`
52        // distinct accounts (or run out of attempts on pathological weights).
53        let widx = WeightedIndex::new(account_weights).expect("non-negative weights");
54        let mut chosen: Vec<String> = Vec::with_capacity(pool_size);
55        let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
56        let cap = (16 * pool_size.max(1)).max(64);
57        for _ in 0..cap {
58            if chosen.len() >= pool_size {
59                break;
60            }
61            let i = widx.sample(rng);
62            if seen.insert(all_accounts[i].clone()) {
63                chosen.push(all_accounts[i].clone());
64            }
65        }
66
67        // Symmetric Dirichlet via Gamma normalisation.
68        let gamma = Gamma::new(alpha, 1.0).expect("alpha > 0");
69        let raw: Vec<f64> = (0..chosen.len())
70            .map(|_| gamma.sample(rng).max(1e-300))
71            .collect();
72        let total: f64 = raw.iter().sum();
73        let mut cumulative = Vec::with_capacity(raw.len());
74        let mut running = 0.0;
75        for r in raw {
76            running += r / total;
77            cumulative.push(running);
78        }
79        // Guard against fp drift on the upper bound.
80        if let Some(last) = cumulative.last_mut() {
81            *last = 1.0;
82        }
83        Self {
84            accounts: chosen,
85            cumulative,
86        }
87    }
88
89    /// Number of accounts in the pool.
90    pub fn n(&self) -> usize {
91        self.accounts.len()
92    }
93
94    /// Draw a single account from the PMF.
95    pub fn sample_one(&self, rng: &mut ChaCha8Rng) -> &str {
96        if self.accounts.is_empty() {
97            return "";
98        }
99        let u: f64 = rng.random();
100        let idx = self
101            .cumulative
102            .partition_point(|&c| c < u)
103            .min(self.accounts.len() - 1);
104        &self.accounts[idx]
105    }
106
107    /// Draw a `(debit_account, credit_account)` pair from the per-source PMF, with the
108    /// distinct-accounts constraint. Returns `None` if the pool has fewer than 2
109    /// accounts (the caller should fall back to the global picker).
110    pub fn sample_pair(&self, rng: &mut ChaCha8Rng) -> Option<(String, String)> {
111        if self.accounts.len() < 2 {
112            return None;
113        }
114        let d = self.sample_one(rng).to_string();
115        // Re-sample until the credit account differs. Typically 1 attempt; bounded to
116        // avoid infinite loops on pathological PMFs with one near-mass-1 component.
117        for _ in 0..16 {
118            let c = self.sample_one(rng);
119            if c != d {
120                return Some((d, c.to_string()));
121            }
122        }
123        // Deterministic fallback: pick any other account in the pool.
124        let other = self
125            .accounts
126            .iter()
127            .find(|a| **a != d)
128            .expect("len() >= 2 was checked above");
129        Some((d, other.clone()))
130    }
131
132    /// Normalised Shannon entropy of the PMF in `[0, 1]`. Useful for tests +
133    /// observability (e.g. comparing to the corpus's source-conditional entropy band).
134    pub fn normalised_entropy(&self) -> f64 {
135        if self.accounts.len() <= 1 {
136            return 0.0;
137        }
138        let n = self.accounts.len() as f64;
139        let mut prev = 0.0;
140        let mut h = 0.0;
141        for &c in &self.cumulative {
142            let p = c - prev;
143            prev = c;
144            if p > 0.0 {
145                h -= p * p.ln();
146            }
147        }
148        h / n.ln()
149    }
150}
151
152/// Top-level sampler — one `SourcePool` per source string.
153#[derive(Debug, Clone, Default)]
154pub struct SourceConditionalPairSampler {
155    pools: HashMap<String, SourcePool>,
156}
157
158impl SourceConditionalPairSampler {
159    /// Build a sampler for every source in `sources`. Each gets a pool of
160    /// approximately `accts_per_source_target` accounts (multiplied by a LogNormal(0,
161    /// 0.3) jitter so the per-source pool size has corpus-like variance), drawn from
162    /// `all_accounts` weighted by `account_weights`, with PMF ∼ Dir(α).
163    pub fn new(
164        sources: &[String],
165        all_accounts: &[String],
166        account_weights: &[f64],
167        accts_per_source_target: usize,
168        alpha: f64,
169        rng: &mut ChaCha8Rng,
170    ) -> Self {
171        assert_eq!(all_accounts.len(), account_weights.len());
172        let jitter = LogNormal::new(0.0, 0.3).expect("sigma > 0");
173        let mut pools = HashMap::with_capacity(sources.len());
174        for s in sources {
175            let mult = jitter.sample(rng);
176            let n_s = ((accts_per_source_target as f64 * mult).round() as usize)
177                .max(2)
178                .min(all_accounts.len());
179            pools.insert(
180                s.clone(),
181                SourcePool::new(n_s, all_accounts, account_weights, alpha, rng),
182            );
183        }
184        Self { pools }
185    }
186
187    /// Get the per-source pool (for diagnostics / tests).
188    pub fn pool(&self, source: &str) -> Option<&SourcePool> {
189        self.pools.get(source)
190    }
191
192    /// Lazy-add a per-source pool if one isn't already present. Returns `true` iff a
193    /// new pool was inserted; `false` if `source` was already pooled (no-op). Uses the
194    /// same LogNormal(0, 0.3) jitter on the pool size as `new`, so a sampler built up
195    /// one source at a time has the same distribution as one built with all sources
196    /// at once.
197    pub fn ensure_pool(
198        &mut self,
199        source: &str,
200        all_accounts: &[String],
201        account_weights: &[f64],
202        accts_per_source_target: usize,
203        alpha: f64,
204        rng: &mut ChaCha8Rng,
205    ) -> bool {
206        if self.pools.contains_key(source) {
207            return false;
208        }
209        let jitter = LogNormal::new(0.0, 0.3).expect("sigma > 0");
210        let mult = jitter.sample(rng);
211        let n_s = ((accts_per_source_target as f64 * mult).round() as usize)
212            .max(2)
213            .min(all_accounts.len());
214        self.pools.insert(
215            source.to_string(),
216            SourcePool::new(n_s, all_accounts, account_weights, alpha, rng),
217        );
218        true
219    }
220
221    /// Sample a `(debit_account, credit_account)` pair conditioned on `source`.
222    /// Returns `None` if the source isn't in the sampler — the caller should fall back
223    /// to the existing global account picker.
224    pub fn sample_pair(&self, source: &str, rng: &mut ChaCha8Rng) -> Option<(String, String)> {
225        self.pools.get(source).and_then(|p| p.sample_pair(rng))
226    }
227
228    pub fn is_empty(&self) -> bool {
229        self.pools.is_empty()
230    }
231    pub fn n_sources(&self) -> usize {
232        self.pools.len()
233    }
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;

    /// Builds `n` account ids (`ACC0000`, `ACC0001`, …) with power-law weights
    /// `1 / rank^1.2` — a stand-in for the existing account-Pareto in tests.
    fn synthetic_accounts(n: usize) -> (Vec<String>, Vec<f64>) {
        let mut accounts = Vec::with_capacity(n);
        let mut weights = Vec::with_capacity(n);
        for i in 0..n {
            accounts.push(format!("ACC{i:04}"));
            weights.push(1.0 / ((i + 1) as f64).powf(1.2));
        }
        (accounts, weights)
    }

    /// Collects a pool's accounts into a set for overlap checks.
    fn account_set(pool: &SourcePool) -> std::collections::HashSet<&String> {
        pool.accounts.iter().collect()
    }

    #[test]
    fn small_alpha_yields_concentrated_pmf() {
        let (accounts, weights) = synthetic_accounts(200);
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let pool = SourcePool::new(25, &accounts, &weights, 0.5, &mut rng);
        let entropy = pool.normalised_entropy();
        assert_eq!(pool.n(), 25);
        // α = 0.5, N = 25 ⇒ expected entropy ≈ 0.6–0.75; the band is widened because
        // this is a single draw.
        let band = 0.45..=0.85;
        assert!(
            band.contains(&entropy),
            "expected concentrated entropy in [0.45, 0.85], got {entropy}"
        );
    }

    #[test]
    fn large_alpha_yields_diffuse_pmf() {
        let (accounts, weights) = synthetic_accounts(200);
        let mut rng = ChaCha8Rng::seed_from_u64(7);
        let entropy = SourcePool::new(25, &accounts, &weights, 10.0, &mut rng).normalised_entropy();
        // α = 10, N = 25 ⇒ the PMF should be close to uniform.
        assert!(
            entropy > 0.9,
            "expected diffuse entropy > 0.9, got {entropy}"
        );
    }

    #[test]
    fn same_seed_same_pool() {
        let (accounts, weights) = synthetic_accounts(100);
        // Two pools built from independently seeded but identical RNGs.
        let build = |seed: u64| {
            let mut rng = ChaCha8Rng::seed_from_u64(seed);
            SourcePool::new(20, &accounts, &weights, 0.5, &mut rng)
        };
        let first = build(1);
        let second = build(1);
        assert_eq!(first.accounts, second.accounts);
        for (x, y) in first.cumulative.iter().zip(&second.cumulative) {
            assert!((x - y).abs() < 1e-12, "PMF mismatch: {x} vs {y}");
        }
    }

    #[test]
    fn sample_pair_returns_distinct_accounts() {
        let (accounts, weights) = synthetic_accounts(50);
        let mut rng = ChaCha8Rng::seed_from_u64(3);
        let pool = SourcePool::new(10, &accounts, &weights, 0.7, &mut rng);
        // Keep drawing from the same RNG stream that built the pool.
        for _ in 0..200 {
            let (debit, credit) = pool.sample_pair(&mut rng).expect("pool has 2+ accounts");
            assert_ne!(debit, credit);
            assert!(pool.accounts.contains(&debit));
            assert!(pool.accounts.contains(&credit));
        }
    }

    #[test]
    fn full_sampler_per_source_diversity() {
        let (accounts, weights) = synthetic_accounts(200);
        let sources = Vec::from(["S0", "S1", "S2", "S3", "S4"].map(String::from));
        let mut rng = ChaCha8Rng::seed_from_u64(99);
        let sampler =
            SourceConditionalPairSampler::new(&sources, &accounts, &weights, 25, 0.5, &mut rng);
        assert_eq!(sampler.n_sources(), 5);
        // Two sources should not end up with near-identical pools: with ~25 accounts
        // drawn (weighted) out of 200, the shared fraction is expected to stay well
        // below 1.
        let first = account_set(sampler.pool("S0").unwrap());
        let second = account_set(sampler.pool("S1").unwrap());
        let shared = first.intersection(&second).count() as f64;
        let overlap = shared / first.len() as f64;
        assert!(
            overlap < 0.85,
            "pools too similar across sources: overlap={overlap:.2}"
        );
    }
}