Skip to main content

datasynth_banking/
orchestrator.rs

1//! Banking data generation orchestrator.
2
3use std::path::Path;
4
5use crate::config::BankingConfig;
6use crate::generators::{
7    AccountGenerator, CounterpartyGenerator, CustomerGenerator, KycGenerator, TransactionGenerator,
8};
9use crate::labels::{
10    AccountLabel, CustomerLabel, EntityLabelExtractor, ExportedNarrative, NarrativeGenerator,
11    RelationshipLabel, RelationshipLabelExtractor, TransactionLabel, TransactionLabelExtractor,
12};
13use crate::models::{AmlScenario, BankAccount, BankTransaction, BankingCustomer, CounterpartyPool};
14use crate::typologies::TypologyInjector;
15
16/// Banking data generation orchestrator.
17///
18/// Coordinates the generation of:
19/// - Customers with KYC profiles
20/// - Accounts for customers
21/// - Transactions based on personas
22/// - AML typology injection
23/// - Ground truth labels
24pub struct BankingOrchestrator {
25    config: BankingConfig,
26    seed: u64,
27    /// Optional country pack for locale-aware customer data generation.
28    country_pack: Option<datasynth_core::CountryPack>,
29}
30
31/// Generated banking data result.
32#[derive(Debug)]
33pub struct BankingData {
34    /// Generated customers
35    pub customers: Vec<BankingCustomer>,
36    /// Generated accounts
37    pub accounts: Vec<BankAccount>,
38    /// Generated transactions
39    pub transactions: Vec<BankTransaction>,
40    /// Counterparty pool
41    pub counterparties: CounterpartyPool,
42    /// AML scenarios
43    pub scenarios: Vec<AmlScenario>,
44    /// Transaction labels
45    pub transaction_labels: Vec<TransactionLabel>,
46    /// Customer labels
47    pub customer_labels: Vec<CustomerLabel>,
48    /// Account labels
49    pub account_labels: Vec<AccountLabel>,
50    /// Relationship labels
51    pub relationship_labels: Vec<RelationshipLabel>,
52    /// Case narratives
53    pub narratives: Vec<ExportedNarrative>,
54    /// Generation statistics
55    pub stats: GenerationStats,
56}
57
/// Summary figures for a single generation run.
///
/// When filled in by `BankingOrchestrator::generate`, each `*_rate` field is the
/// matching count divided by the transaction count. The denominator is clamped
/// to at least 1, so an empty run reports `0.0` rather than NaN.
#[derive(Debug, Clone, Default)]
pub struct GenerationStats {
    /// How many customers were produced.
    pub customer_count: usize,
    /// How many accounts were produced.
    pub account_count: usize,
    /// How many transactions were produced.
    pub transaction_count: usize,
    /// Transactions whose `is_suspicious` flag is set.
    pub suspicious_count: usize,
    /// `suspicious_count` as a fraction of `transaction_count`.
    pub suspicious_rate: f64,
    /// Transactions whose `is_spoofed` flag is set.
    pub spoofed_count: usize,
    /// `spoofed_count` as a fraction of `transaction_count`.
    pub spoofed_rate: f64,
    /// How many AML scenarios were recorded.
    pub scenario_count: usize,
    /// Wall-clock time spent generating, in milliseconds.
    pub duration_ms: u64,
}
80
81impl BankingOrchestrator {
82    /// Create a new banking orchestrator.
83    pub fn new(config: BankingConfig, seed: u64) -> Self {
84        Self {
85            config,
86            seed,
87            country_pack: None,
88        }
89    }
90
91    /// Set the country pack for locale-aware customer data generation.
92    pub fn set_country_pack(&mut self, pack: datasynth_core::CountryPack) {
93        self.country_pack = Some(pack);
94    }
95
96    /// Generate all banking data.
97    pub fn generate(&self) -> BankingData {
98        let start = std::time::Instant::now();
99
100        // Phase 1: Generate counterparty pool
101        let mut counterparty_gen = CounterpartyGenerator::new(self.seed);
102        let counterparties = counterparty_gen.generate_pool(&self.config);
103
104        // Phase 2: Generate customers with KYC profiles
105        let mut customer_gen = CustomerGenerator::new(self.config.clone(), self.seed);
106        if let Some(ref pack) = self.country_pack {
107            customer_gen.set_country_pack(pack.clone());
108        }
109        let mut customers = customer_gen.generate_all();
110
111        // Phase 3: Generate KYC profiles
112        let mut kyc_gen = KycGenerator::new(self.seed);
113        for customer in &mut customers {
114            let profile = kyc_gen.generate_profile(customer, &self.config);
115            customer.kyc_profile = profile;
116        }
117
118        // Phase 4: Generate accounts for customers
119        let mut account_gen = AccountGenerator::new(self.config.clone(), self.seed);
120        let mut accounts = account_gen.generate_for_customers(&mut customers);
121
122        // Phase 5: Generate transactions
123        let mut txn_gen = TransactionGenerator::new(self.config.clone(), self.seed);
124        let mut transactions = txn_gen.generate_all(&customers, &mut accounts);
125
126        // Phase 6: Inject AML typologies
127        let mut typology_injector = TypologyInjector::new(self.config.clone(), self.seed);
128        typology_injector.inject(&mut customers, &mut accounts, &mut transactions);
129        let scenarios: Vec<AmlScenario> = typology_injector.get_scenarios().to_vec();
130
131        // Phase 7: Generate narratives
132        let mut narrative_gen = NarrativeGenerator::new(self.seed);
133        let narratives: Vec<ExportedNarrative> = scenarios
134            .iter()
135            .map(|s| {
136                let narrative = narrative_gen.generate(s);
137                ExportedNarrative::from_scenario(s, &narrative)
138            })
139            .collect();
140
141        // Phase 8: Extract labels
142        let transaction_labels = TransactionLabelExtractor::extract_with_features(&transactions);
143        let customer_labels = EntityLabelExtractor::extract_customers(&customers);
144        let account_labels = EntityLabelExtractor::extract_accounts(&accounts);
145        let relationship_labels = RelationshipLabelExtractor::extract_from_customers(&customers);
146
147        // Compute statistics
148        let suspicious_count = transactions.iter().filter(|t| t.is_suspicious).count();
149        let spoofed_count = transactions.iter().filter(|t| t.is_spoofed).count();
150
151        let stats = GenerationStats {
152            customer_count: customers.len(),
153            account_count: accounts.len(),
154            transaction_count: transactions.len(),
155            suspicious_count,
156            suspicious_rate: suspicious_count as f64 / transactions.len().max(1) as f64,
157            spoofed_count,
158            spoofed_rate: spoofed_count as f64 / transactions.len().max(1) as f64,
159            scenario_count: scenarios.len(),
160            duration_ms: start.elapsed().as_millis() as u64,
161        };
162
163        BankingData {
164            customers,
165            accounts,
166            transactions,
167            counterparties,
168            scenarios,
169            transaction_labels,
170            customer_labels,
171            account_labels,
172            relationship_labels,
173            narratives,
174            stats,
175        }
176    }
177
178    /// Write generated data to output directory.
179    pub fn write_output(&self, data: &BankingData, output_dir: &Path) -> std::io::Result<()> {
180        std::fs::create_dir_all(output_dir)?;
181
182        // Write customers
183        self.write_csv(&data.customers, &output_dir.join("banking_customers.csv"))?;
184
185        // Write accounts
186        self.write_csv(&data.accounts, &output_dir.join("banking_accounts.csv"))?;
187
188        // Write transactions
189        self.write_csv(
190            &data.transactions,
191            &output_dir.join("banking_transactions.csv"),
192        )?;
193
194        // Write labels
195        self.write_csv(
196            &data.transaction_labels,
197            &output_dir.join("transaction_labels.csv"),
198        )?;
199        self.write_csv(
200            &data.customer_labels,
201            &output_dir.join("customer_labels.csv"),
202        )?;
203        self.write_csv(&data.account_labels, &output_dir.join("account_labels.csv"))?;
204        self.write_csv(
205            &data.relationship_labels,
206            &output_dir.join("relationship_labels.csv"),
207        )?;
208
209        // Write narratives as JSON
210        self.write_json(&data.narratives, &output_dir.join("case_narratives.json"))?;
211
212        // Write counterparties
213        self.write_csv(
214            &data.counterparties.merchants,
215            &output_dir.join("merchants.csv"),
216        )?;
217        self.write_csv(
218            &data.counterparties.employers,
219            &output_dir.join("employers.csv"),
220        )?;
221
222        Ok(())
223    }
224
225    /// Write data to CSV file.
226    fn write_csv<T: serde::Serialize>(&self, data: &[T], path: &Path) -> std::io::Result<()> {
227        let mut writer = csv::Writer::from_path(path)?;
228        for item in data {
229            writer.serialize(item)?;
230        }
231        writer.flush()?;
232        Ok(())
233    }
234
235    /// Write data to JSON file.
236    fn write_json<T: serde::Serialize>(&self, data: &T, path: &Path) -> std::io::Result<()> {
237        let file = std::fs::File::create(path)?;
238        serde_json::to_writer_pretty(file, data)?;
239        Ok(())
240    }
241}
242
243/// Builder for BankingOrchestrator.
244pub struct BankingOrchestratorBuilder {
245    config: Option<BankingConfig>,
246    seed: u64,
247    country_pack: Option<datasynth_core::CountryPack>,
248}
249
250impl Default for BankingOrchestratorBuilder {
251    fn default() -> Self {
252        Self {
253            config: None,
254            seed: 42,
255            country_pack: None,
256        }
257    }
258}
259
260impl BankingOrchestratorBuilder {
261    /// Create a new builder.
262    pub fn new() -> Self {
263        Self::default()
264    }
265
266    /// Set the configuration.
267    pub fn config(mut self, config: BankingConfig) -> Self {
268        self.config = Some(config);
269        self
270    }
271
272    /// Set the random seed.
273    pub fn seed(mut self, seed: u64) -> Self {
274        self.seed = seed;
275        self
276    }
277
278    /// Set the country pack for locale-aware data generation.
279    pub fn country_pack(mut self, pack: datasynth_core::CountryPack) -> Self {
280        self.country_pack = Some(pack);
281        self
282    }
283
284    /// Build the orchestrator.
285    pub fn build(self) -> BankingOrchestrator {
286        let mut orch = BankingOrchestrator::new(self.config.unwrap_or_default(), self.seed);
287        if let Some(pack) = self.country_pack {
288            orch.set_country_pack(pack);
289        }
290        orch
291    }
292}
293
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_orchestrator_generation() {
        let config = BankingConfig::small();
        let orchestrator = BankingOrchestrator::new(config, 12345);

        let data = orchestrator.generate();

        assert!(!data.customers.is_empty());
        assert!(!data.accounts.is_empty());
        assert!(!data.transactions.is_empty());
        assert!(!data.transaction_labels.is_empty());
        assert!(!data.customer_labels.is_empty());

        // Stats should be populated
        assert!(data.stats.customer_count > 0);
        assert!(data.stats.transaction_count > 0);

        // Stats must agree with the collections they summarize.
        assert_eq!(data.stats.customer_count, data.customers.len());
        assert_eq!(data.stats.account_count, data.accounts.len());
        assert_eq!(data.stats.transaction_count, data.transactions.len());
        assert_eq!(data.stats.scenario_count, data.scenarios.len());
        assert!(data.stats.suspicious_count <= data.stats.transaction_count);
        assert!(data.stats.spoofed_count <= data.stats.transaction_count);
        assert!((0.0..=1.0).contains(&data.stats.suspicious_rate));
        assert!((0.0..=1.0).contains(&data.stats.spoofed_rate));

        // Exactly one narrative is produced per AML scenario.
        assert_eq!(data.narratives.len(), data.scenarios.len());
    }

    #[test]
    fn test_builder() {
        let orchestrator = BankingOrchestratorBuilder::new()
            .config(BankingConfig::small())
            .seed(12345)
            .build();

        let data = orchestrator.generate();
        assert!(!data.customers.is_empty());
        assert_eq!(data.stats.customer_count, data.customers.len());
    }
}