// datasynth_core/distributions/source_conditional_pair.rs
use std::collections::HashMap;

use rand::distr::weighted::WeightedIndex;
use rand::prelude::*;
use rand_chacha::ChaCha8Rng;
use rand_distr::{Distribution, Gamma, LogNormal};
24#[derive(Debug, Clone)]
26pub struct SourcePool {
27 pub accounts: Vec<String>,
29 cumulative: Vec<f64>,
32}
33
34impl SourcePool {
35 pub fn new(
38 pool_size: usize,
39 all_accounts: &[String],
40 account_weights: &[f64],
41 alpha: f64,
42 rng: &mut ChaCha8Rng,
43 ) -> Self {
44 assert_eq!(
45 all_accounts.len(),
46 account_weights.len(),
47 "all_accounts and account_weights must align"
48 );
49 assert!(alpha > 0.0, "alpha must be > 0");
50
51 let widx = WeightedIndex::new(account_weights).expect("non-negative weights");
54 let mut chosen: Vec<String> = Vec::with_capacity(pool_size);
55 let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
56 let cap = (16 * pool_size.max(1)).max(64);
57 for _ in 0..cap {
58 if chosen.len() >= pool_size {
59 break;
60 }
61 let i = widx.sample(rng);
62 if seen.insert(all_accounts[i].clone()) {
63 chosen.push(all_accounts[i].clone());
64 }
65 }
66
67 let gamma = Gamma::new(alpha, 1.0).expect("alpha > 0");
69 let raw: Vec<f64> = (0..chosen.len())
70 .map(|_| gamma.sample(rng).max(1e-300))
71 .collect();
72 let total: f64 = raw.iter().sum();
73 let mut cumulative = Vec::with_capacity(raw.len());
74 let mut running = 0.0;
75 for r in raw {
76 running += r / total;
77 cumulative.push(running);
78 }
79 if let Some(last) = cumulative.last_mut() {
81 *last = 1.0;
82 }
83 Self {
84 accounts: chosen,
85 cumulative,
86 }
87 }
88
89 pub fn n(&self) -> usize {
91 self.accounts.len()
92 }
93
94 pub fn sample_one(&self, rng: &mut ChaCha8Rng) -> &str {
96 if self.accounts.is_empty() {
97 return "";
98 }
99 let u: f64 = rng.random();
100 let idx = self
101 .cumulative
102 .partition_point(|&c| c < u)
103 .min(self.accounts.len() - 1);
104 &self.accounts[idx]
105 }
106
107 pub fn sample_pair(&self, rng: &mut ChaCha8Rng) -> Option<(String, String)> {
111 if self.accounts.len() < 2 {
112 return None;
113 }
114 let d = self.sample_one(rng).to_string();
115 for _ in 0..16 {
118 let c = self.sample_one(rng);
119 if c != d {
120 return Some((d, c.to_string()));
121 }
122 }
123 let other = self
125 .accounts
126 .iter()
127 .find(|a| **a != d)
128 .expect("len() >= 2 was checked above");
129 Some((d, other.clone()))
130 }
131
132 pub fn normalised_entropy(&self) -> f64 {
135 if self.accounts.len() <= 1 {
136 return 0.0;
137 }
138 let n = self.accounts.len() as f64;
139 let mut prev = 0.0;
140 let mut h = 0.0;
141 for &c in &self.cumulative {
142 let p = c - prev;
143 prev = c;
144 if p > 0.0 {
145 h -= p * p.ln();
146 }
147 }
148 h / n.ln()
149 }
150}
151
152#[derive(Debug, Clone, Default)]
154pub struct SourceConditionalPairSampler {
155 pools: HashMap<String, SourcePool>,
156}
157
158impl SourceConditionalPairSampler {
159 pub fn new(
164 sources: &[String],
165 all_accounts: &[String],
166 account_weights: &[f64],
167 accts_per_source_target: usize,
168 alpha: f64,
169 rng: &mut ChaCha8Rng,
170 ) -> Self {
171 assert_eq!(all_accounts.len(), account_weights.len());
172 let jitter = LogNormal::new(0.0, 0.3).expect("sigma > 0");
173 let mut pools = HashMap::with_capacity(sources.len());
174 for s in sources {
175 let mult = jitter.sample(rng);
176 let n_s = ((accts_per_source_target as f64 * mult).round() as usize)
177 .max(2)
178 .min(all_accounts.len());
179 pools.insert(
180 s.clone(),
181 SourcePool::new(n_s, all_accounts, account_weights, alpha, rng),
182 );
183 }
184 Self { pools }
185 }
186
187 pub fn pool(&self, source: &str) -> Option<&SourcePool> {
189 self.pools.get(source)
190 }
191
192 pub fn ensure_pool(
198 &mut self,
199 source: &str,
200 all_accounts: &[String],
201 account_weights: &[f64],
202 accts_per_source_target: usize,
203 alpha: f64,
204 rng: &mut ChaCha8Rng,
205 ) -> bool {
206 if self.pools.contains_key(source) {
207 return false;
208 }
209 let jitter = LogNormal::new(0.0, 0.3).expect("sigma > 0");
210 let mult = jitter.sample(rng);
211 let n_s = ((accts_per_source_target as f64 * mult).round() as usize)
212 .max(2)
213 .min(all_accounts.len());
214 self.pools.insert(
215 source.to_string(),
216 SourcePool::new(n_s, all_accounts, account_weights, alpha, rng),
217 );
218 true
219 }
220
221 pub fn sample_pair(&self, source: &str, rng: &mut ChaCha8Rng) -> Option<(String, String)> {
225 self.pools.get(source).and_then(|p| p.sample_pair(rng))
226 }
227
228 pub fn is_empty(&self) -> bool {
229 self.pools.is_empty()
230 }
231 pub fn n_sources(&self) -> usize {
232 self.pools.len()
233 }
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;

    /// Produces `n` synthetic account codes (`ACC0000`, `ACC0001`, ...) paired with
    /// Zipf-like weights `1 / (i + 1)^1.2`, so lower-numbered accounts are heavier.
    fn synthetic_accounts(n: usize) -> (Vec<String>, Vec<f64>) {
        (0..n)
            .map(|i| (format!("ACC{i:04}"), 1.0 / ((i + 1) as f64).powf(1.2)))
            .unzip()
    }

    /// A small Dirichlet concentration should give a peaked (low-entropy) PMF.
    #[test]
    fn small_alpha_yields_concentrated_pmf() {
        let (accounts, weights) = synthetic_accounts(200);
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let pool = SourcePool::new(25, &accounts, &weights, 0.5, &mut rng);
        assert_eq!(pool.n(), 25);

        let h = pool.normalised_entropy();
        assert!(
            (0.45..=0.85).contains(&h),
            "expected concentrated entropy in [0.45, 0.85], got {h}"
        );
    }

    /// A large Dirichlet concentration should give a near-uniform PMF.
    #[test]
    fn large_alpha_yields_diffuse_pmf() {
        let (accounts, weights) = synthetic_accounts(200);
        let mut rng = ChaCha8Rng::seed_from_u64(7);
        let h = SourcePool::new(25, &accounts, &weights, 10.0, &mut rng).normalised_entropy();
        assert!(h > 0.9, "expected diffuse entropy > 0.9, got {h}");
    }

    /// Identical seeds must reproduce identical pools and PMFs.
    #[test]
    fn same_seed_same_pool() {
        let (accounts, weights) = synthetic_accounts(100);
        let build = |seed: u64| {
            SourcePool::new(
                20,
                &accounts,
                &weights,
                0.5,
                &mut ChaCha8Rng::seed_from_u64(seed),
            )
        };
        let (a, b) = (build(1), build(1));
        assert_eq!(a.accounts, b.accounts);
        a.cumulative
            .iter()
            .zip(b.cumulative.iter())
            .for_each(|(x, y)| assert!((x - y).abs() < 1e-12, "PMF mismatch: {x} vs {y}"));
    }

    /// Every sampled pair must consist of two different pool members.
    #[test]
    fn sample_pair_returns_distinct_accounts() {
        let (accounts, weights) = synthetic_accounts(50);
        let mut rng = ChaCha8Rng::seed_from_u64(3);
        let pool = SourcePool::new(10, &accounts, &weights, 0.7, &mut rng);
        for _ in 0..200 {
            let (first, second) = pool.sample_pair(&mut rng).expect("pool has 2+ accounts");
            assert_ne!(first, second);
            assert!(pool.accounts.contains(&first));
            assert!(pool.accounts.contains(&second));
        }
    }

    /// Different sources should end up with noticeably different pools.
    #[test]
    fn full_sampler_per_source_diversity() {
        let (accounts, weights) = synthetic_accounts(200);
        let sources: Vec<String> = (0..5).map(|i| format!("S{i}")).collect();
        let mut rng = ChaCha8Rng::seed_from_u64(99);
        let sampler =
            SourceConditionalPairSampler::new(&sources, &accounts, &weights, 25, 0.5, &mut rng);
        assert_eq!(sampler.n_sources(), 5);

        let first: std::collections::HashSet<_> =
            sampler.pool("S0").unwrap().accounts.iter().collect();
        let second: std::collections::HashSet<_> =
            sampler.pool("S1").unwrap().accounts.iter().collect();
        let shared = first.intersection(&second).count() as f64;
        let overlap = shared / first.len() as f64;
        assert!(
            overlap < 0.85,
            "pools too similar across sources: overlap={overlap:.2}"
        );
    }
}