Skip to main content

datasynth_generators/
priors_loader.rs

1//! Loads an SP2 industry-priors bundle and builds the SP3 sampler state.
2//!
3//! This module provides [`LoadedPriors`], which wraps the fully-built runtime
4//! samplers consumed by `je_generator` and friends.
5//!
6//! # Architecture note
7//!
8//! `datasynth-generators` cannot depend on `datasynth-fingerprint` directly
9//! (that crate depends on `datasynth-eval` which in turn depends on
10//! `datasynth-generators`, creating a package cycle).  Therefore:
11//!
12//! - `BehavioralPriors` and its sub-types live in
13//!   `datasynth_core::distributions::behavioral_priors` and are re-exported
14//!   from `datasynth-fingerprint`.
15//! - File-loading (`load_bundled` / `load_from_path`) is exposed here via a
16//!   thin `zip` + `serde_yaml` reader that is independent of the fingerprint
17//!   crate.
18//! - Callers that already hold a deserialized [`BehavioralPriors`] (e.g. from
19//!   the fingerprint SDK) can use [`LoadedPriors::from_priors`] directly.
20
21use std::collections::HashMap;
22use std::io::{Read, Seek};
23use std::path::{Path, PathBuf};
24
25use rand::Rng;
26use thiserror::Error;
27
28use datasynth_core::distributions::behavioral_priors::{
29    CoaSemanticPrior, LinesPerJePrior, PerSourceAmountPrior, PerSourceAttributePrior,
30    PerSourceRolePrior, PostingLagPrior, ReferenceFormatPrior, SourceMixPrior, TbAnchorPrior,
31    UserPersonaPrior,
32};
33use datasynth_core::distributions::text_taxonomy::TextTaxonomyPrior;
34use datasynth_core::distributions::{
35    behavioral_priors::BehavioralPriors, BipartiteFanoutSampler, ConditionalIETSampler,
36    CrossEntityMotifSampler, MultiSegmentActiveWindow, SourceActiveWindow, SourceIetState,
37};
38
/// Name of the entry inside a `.dsf` ZIP archive that holds the behavioral
/// section.
const BEHAVIORAL_YAML_KEY: &str = "behavioral.yaml";
41
42#[derive(Debug, Error)]
43pub enum PriorsLoadError {
44    #[error("priors bundle not found at {0}")]
45    NotFound(PathBuf),
46    #[error("bundle has no behavioral section")]
47    MissingBehavioral,
48    #[error("bundle industry mismatch: bundle={bundle}, requested={requested}")]
49    IndustryMismatch { bundle: String, requested: String },
50    #[error("fingerprint read error: {0}")]
51    Read(String),
52}
53
54/// Conventional resource directory for committed industry-priors bundles.
55pub fn bundled_priors_dir() -> PathBuf {
56    PathBuf::from(env!("CARGO_MANIFEST_DIR"))
57        .join("resources")
58        .join("priors")
59}
60
61/// Resolve the bundled `.dsf` path for an industry slug.
62pub fn bundled_priors_path(industry: &str) -> PathBuf {
63    bundled_priors_dir().join(format!("industry_priors_{industry}.dsf"))
64}
65
66/// Fully-built runtime priors consumed by `je_generator` and friends.
67#[derive(Clone)]
68pub struct LoadedPriors {
69    pub industry: String,
70    pub bundle_path: PathBuf,
71    pub source_mix: SourceMixPrior,
72    pub iet_sampler: ConditionalIETSampler,
73    pub lines_per_je: LinesPerJePrior,
74    pub active_window: SourceActiveWindow,
75    /// SP3.2 — when `Some`, supersedes `active_window` for `is_active` checks.
76    pub multi_segment_window: Option<MultiSegmentActiveWindow>,
77    pub fanout_samplers: HashMap<String, BipartiteFanoutSampler>,
78    pub posting_lag: Option<PostingLagPrior>,
79    /// SP3.3 — cross-entity motif sampler. None when bundle has no entity_clusters.
80    pub cross_entity_motifs: Option<CrossEntityMotifSampler>,
81    /// SP3.7 — per-source conditional attribute distributions.  When `Some`,
82    /// downstream attribute sampling (GL account, cost center, profit center)
83    /// is constrained to the values characteristic of the just-drawn source
84    /// code rather than the marginal distribution over all sources.
85    pub per_source_attribute: Option<PerSourceAttributePrior>,
86    /// SP3.12 — TP motif sampler. Built from `tp_entity_clusters`; biases the
87    /// TP draw toward cluster-mates of recently-emitted TPs on the same source
88    /// to build triangle structure in the TP co-occurrence graph.
89    pub tp_motif_sampler: Option<CrossEntityMotifSampler>,
90    /// SP4.7 — Per-source reference-string format templates.  When `Some`, the
91    /// JE generator calls `sample_reference` to produce a reference string that
92    /// matches the corpus format pattern for the current source code.
93    pub reference_formats: Option<ReferenceFormatPrior>,
94    /// SP4.2 — CoA semantic content extracted from corpus CoA parquet
95    /// files.  When `Some`, the CoA generator overwrites generic account
96    /// descriptions with corpus names and ISO 21378 hierarchy values.
97    pub coa_semantic: Option<CoaSemanticPrior>,
98    /// SP4.5 — Per-user behavioral patterns (source mix, hourly density,
99    /// weekday density, volume share).  When `Some` and `has_data()` is true,
100    /// `je_generator` biases `created_by` and `created_at` toward the
101    /// characteristic patterns of each user.
102    ///
103    /// `None` or empty: generator falls back to the internal user pool.
104    pub user_personas: Option<UserPersonaPrior>,
105    /// SP4.3 — Per-(source, gl_prefix) log-normal amount parameters.  When
106    /// `Some`, `je_generator` draws JE total-amounts from the source-conditional
107    /// distribution rather than the global `AmountSampler` marginal.  Fraud
108    /// entries bypass this path to preserve fraud-pattern semantics.
109    ///
110    /// `None` means the bundle was built before SP4.3 or had too few rows.
111    pub source_amount_conditionals: Option<PerSourceAmountPrior>,
112    /// SP4.6 — Per-(source, line_role) GL account conditional.
113    ///
114    /// When `Some`, callers use `sample_gl_for_source_role(source, "DR"|"CR")` to
115    /// draw a GL account that respects the debit/credit line role for the given SAP
116    /// document type.  Falls back to `sample_attribute_for_source` → default when
117    /// the pair is missing.
118    ///
119    /// `None` means the bundle was built before SP4.6 or had too few rows.
120    pub source_role_gl: Option<PerSourceRolePrior>,
121    /// SP4.1 — Trial-balance anchor prior.
122    ///
123    /// When `Some` and `has_data()` is true, the `RunningBalanceTracker` can
124    /// use these industry-median per-account targets to generate periodic
125    /// drift-correction entries that keep the synthetic balance sheet shaped
126    /// like a corpus balance sheet.
127    ///
128    /// `None` means the bundle was built before SP4.1 (current committed
129    /// bundles) or had no TB data — the balance tracker operates in its
130    /// existing free-drift mode in that case.
131    pub tb_anchor: Option<TbAnchorPrior>,
132    /// SP6 — corpus text taxonomy. When `Some`, provides
133    /// `(source, account-class)` line pools, `source` header pools, and
134    /// per-account CoA description templates.
135    pub text_taxonomy: Option<TextTaxonomyPrior>,
136}
137
138impl LoadedPriors {
139    /// Load the bundled `.dsf` for `industry` from the crate's
140    /// `resources/priors/` directory.
141    pub fn load_bundled<R: Rng>(
142        industry: &str,
143        rng: &mut R,
144        period_days: i64,
145    ) -> Result<Self, PriorsLoadError> {
146        Self::load_from_path(
147            &bundled_priors_path(industry),
148            rng,
149            period_days,
150            Some(industry),
151        )
152    }
153
154    /// Load a `.dsf` bundle from an arbitrary path.
155    pub fn load_from_path<R: Rng>(
156        path: &Path,
157        rng: &mut R,
158        period_days: i64,
159        expected_industry: Option<&str>,
160    ) -> Result<Self, PriorsLoadError> {
161        if !path.exists() {
162            return Err(PriorsLoadError::NotFound(path.to_path_buf()));
163        }
164        let file = std::fs::File::open(path).map_err(|e| PriorsLoadError::Read(e.to_string()))?;
165        let bp = read_behavioral_from_dsf(file)
166            .map_err(PriorsLoadError::Read)?
167            .ok_or(PriorsLoadError::MissingBehavioral)?;
168        if let Some(want) = expected_industry {
169            if bp.industry != want {
170                return Err(PriorsLoadError::IndustryMismatch {
171                    bundle: bp.industry,
172                    requested: want.to_string(),
173                });
174            }
175        }
176        Self::from_priors(bp, path.to_path_buf(), rng, period_days)
177    }
178
179    /// Build `LoadedPriors` from an already-deserialised [`BehavioralPriors`].
180    ///
181    /// Use this when you already have a `BehavioralPriors` from the
182    /// `datasynth-fingerprint` SDK and don't need file I/O here.
183    pub fn from_priors<R: Rng>(
184        bp: BehavioralPriors,
185        bundle_path: PathBuf,
186        rng: &mut R,
187        period_days: i64,
188    ) -> Result<Self, PriorsLoadError> {
189        let mut per_source_states: HashMap<String, SourceIetState> = HashMap::new();
190        for (src, summ) in &bp.per_source_iet.by_source {
191            per_source_states.insert(
192                src.clone(),
193                SourceIetState {
194                    cdf_values: summ.empirical_cdf_days.values.clone(),
195                    cdf_probabilities: summ.empirical_cdf_days.probabilities.clone(),
196                    lag1_autocorr: summ.lag1_autocorr,
197                    last_iet_days: None,
198                },
199            );
200        }
201        let fallback = SourceIetState {
202            cdf_values: vec![1.0],
203            cdf_probabilities: vec![1.0],
204            lag1_autocorr: 0.0,
205            last_iet_days: None,
206        };
207        let iet_sampler = ConditionalIETSampler::from_state_map(per_source_states, fallback);
208
209        let lifetime_hist = bp.active_lifetime.overall.clone();
210        let sources: Vec<String> = bp.source_mix.probabilities.keys().cloned().collect();
211        let active_window = SourceActiveWindow::build(
212            &sources,
213            period_days,
214            |r| lifetime_hist.sample_bucket(r) as i64,
215            rng,
216        );
217
218        let multi_segment_window = bp.active_segments.as_ref().map(|prior| {
219            let lifetime_hist = bp.active_lifetime.overall.clone();
220            MultiSegmentActiveWindow::build_from_prior(
221                &sources,
222                period_days,
223                prior,
224                |r| lifetime_hist.sample_bucket(r) as i64,
225                rng,
226            )
227        });
228
229        let mut fanout_samplers: HashMap<String, BipartiteFanoutSampler> = HashMap::new();
230        for (attr, hist) in &bp.fanout.by_attribute {
231            const N_POOL: usize = 256;
232            let targets: Vec<u32> = (0..N_POOL).map(|_| hist.sample_bucket(rng)).collect();
233            let attr_prefix = attr.clone();
234            let sampler = BipartiteFanoutSampler::new_with_targets(targets, move |i| {
235                format!("{attr_prefix}-{i:04}")
236            });
237            fanout_samplers.insert(attr.clone(), sampler);
238        }
239
240        let cross_entity_motifs = bp
241            .entity_clusters
242            .as_ref()
243            .map(CrossEntityMotifSampler::from_prior);
244
245        let tp_motif_sampler = bp
246            .tp_entity_clusters
247            .as_ref()
248            .map(CrossEntityMotifSampler::from_prior);
249
250        let per_source_attribute = bp.per_source_attribute.clone();
251        let reference_formats = bp.reference_formats.clone();
252        let coa_semantic = bp.coa_semantic.clone();
253        // SP4.5 — carry user_personas through from the bundle.  When the bundle
254        // was built without user-column data (the current corpus), this is either
255        // None or Some(empty-stub).  The generator guards with `has_data()`.
256        let user_personas = bp.user_personas.clone();
257        // SP4.3 — carry source_amount_conditionals through from the bundle.
258        // Old bundles (pre-SP4.3) will have None here; generators fall back to
259        // the existing AmountSampler.
260        let source_amount_conditionals = bp.source_amount_conditionals.clone();
261        // SP4.6 — carry source_role_gl_conditionals through from the bundle.
262        // Old bundles (pre-SP4.6) will have None here; generators fall back to
263        // sample_attribute_for_source then to default GL accounts.
264        let source_role_gl = bp.source_role_gl_conditionals.clone();
265        // SP4.1 — carry tb_anchor through from the bundle.  Old bundles (pre-SP4.1)
266        // will have None here; the balance tracker operates in free-drift mode.
267        let tb_anchor = bp.tb_anchor.clone();
268        // SP6 — carry text_taxonomy through. Old bundles (pre-SP6) have None;
269        // generators fall back to the DescriptionGenerator.
270        let text_taxonomy = bp.text_taxonomy.clone();
271
272        Ok(LoadedPriors {
273            industry: bp.industry.clone(),
274            bundle_path,
275            source_mix: bp.source_mix,
276            iet_sampler,
277            lines_per_je: bp.lines_per_je,
278            active_window,
279            multi_segment_window,
280            fanout_samplers,
281            posting_lag: bp.posting_lag,
282            cross_entity_motifs,
283            per_source_attribute,
284            tp_motif_sampler,
285            reference_formats,
286            coa_semantic,
287            user_personas,
288            source_amount_conditionals,
289            source_role_gl,
290            tb_anchor,
291            text_taxonomy,
292        })
293    }
294}
295
296impl LoadedPriors {
297    /// SP3.7 — Try to sample `attribute` value conditional on the given
298    /// `source` code.  Returns `None` when either the prior is absent,
299    /// the source isn't represented in the prior, the attribute isn't
300    /// present for that source, or the conditional distribution is empty.
301    ///
302    /// Caller falls back to the marginal sampler in that case.
303    pub fn sample_attribute_for_source<R: rand::Rng>(
304        &self,
305        source: &str,
306        attribute: &str,
307        rng: &mut R,
308    ) -> Option<String> {
309        self.per_source_attribute
310            .as_ref()?
311            .conditional(source, attribute)?
312            .sample(rng)
313    }
314
315    /// SP4.5 — Sample a user ID likely to post the given `source` code from the
316    /// user-persona prior.
317    ///
318    /// Returns `None` when:
319    /// - No `user_personas` prior was loaded, OR
320    /// - The prior is empty (no user-column data; typical for current corpus), OR
321    /// - No user has a non-zero weight for `source`.
322    ///
323    /// Callers fall back to the internal user pool (`select_user`) in all of these
324    /// cases — the prior is purely additive.
325    pub fn sample_user_for_source<R: rand::Rng>(
326        &self,
327        source: &str,
328        rng: &mut R,
329    ) -> Option<String> {
330        self.user_personas
331            .as_ref()
332            .filter(|up| up.has_data())
333            .and_then(|up| up.sample_user_for_source(source, rng))
334    }
335
336    /// SP4.5 — Given a `user_id` from the prior, sample an `(hour, weekday)` pair
337    /// from the user's characteristic density.
338    ///
339    /// Returns `None` when the prior is absent, empty, or the user ID is unknown.
340    /// `hour` ∈ 0..24, `weekday` ∈ 0..7 (Monday = 0).
341    pub fn sample_timestamp_for_user<R: rand::Rng>(
342        &self,
343        user_id: &str,
344        rng: &mut R,
345    ) -> Option<(u32, u32)> {
346        self.user_personas
347            .as_ref()
348            .filter(|up| up.has_data())
349            .and_then(|up| up.sample_timestamp_for_user(user_id, rng))
350    }
351
352    /// SP4.3 — Sample a JE total-amount magnitude for `source` (and optionally
353    /// `gl_prefix`) from the per-(source, gl_prefix) log-normal prior.
354    ///
355    /// Lookup strategy:
356    /// 1. Try `(source, gl_prefix)` when `gl_prefix` is non-empty.
357    /// 2. Fall back to the source-marginal.
358    /// 3. Return `None` when the prior is absent or the source isn't represented.
359    ///
360    /// Callers fall back to the existing `AmountSampler` when `None` is returned.
361    /// Fraud entries should bypass this helper entirely — the caller is responsible
362    /// for that guard.
363    pub fn sample_amount_for_source<R: rand::Rng>(
364        &self,
365        source: &str,
366        gl_prefix: &str,
367        rng: &mut R,
368    ) -> Option<f64> {
369        let p = self.source_amount_conditionals.as_ref()?;
370        // Try (source, gl_prefix) first.
371        if !gl_prefix.is_empty() {
372            if let Some(per_class) = p.by_source_and_class.get(source) {
373                if let Some(params) = per_class.get(gl_prefix) {
374                    return Some(params.sample(rng));
375                }
376            }
377        }
378        // Fall back to source-marginal.
379        p.by_source.get(source).map(|params| params.sample(rng))
380    }
381
382    /// SP4.6 — Sample a GL account conditioned on `(source, role)` where `role`
383    /// is `"DR"` or `"CR"`.
384    ///
385    /// Returns `None` when:
386    /// - No `source_role_gl` prior was loaded (old bundle), OR
387    /// - The `(source, role)` pair isn't represented (sparse corpus), OR
388    /// - The distribution is empty.
389    ///
390    /// Callers must fall back to `sample_attribute_for_source(source, "gl_account", ...)`
391    /// and ultimately to a hard-coded default GL when `None` is returned.
392    pub fn sample_gl_for_source_role<R: rand::Rng>(
393        &self,
394        source: &str,
395        role: &str,
396        rng: &mut R,
397    ) -> Option<String> {
398        self.source_role_gl
399            .as_ref()?
400            .conditional(source, role)?
401            .sample(rng)
402    }
403
404    /// SP4.7 — Sample a reference string for `source` from the reference-format
405    /// prior.  Returns `None` when the prior is absent or the source has no
406    /// templates (caller falls back to the existing `format!(...)` template).
407    pub fn sample_reference<R: rand::Rng>(&self, source: &str, rng: &mut R) -> Option<String> {
408        let rf = self.reference_formats.as_ref()?;
409        let templates = rf.by_source.get(source)?;
410        if templates.is_empty() {
411            return None;
412        }
413        // Weighted sample by probability mass.
414        let total: f64 = templates.iter().map(|t| t.probability).sum();
415        if total <= 0.0 {
416            return None;
417        }
418        use rand::RngExt;
419        let r: f64 = rng.random_range(0.0..total);
420        let mut cum = 0.0;
421        for t in templates {
422            cum += t.probability;
423            if r <= cum {
424                return Some(fill_reference_template(&t.template, rng));
425            }
426        }
427        // Floating-point rounding: return last template.
428        templates
429            .last()
430            .map(|t| fill_reference_template(&t.template, rng))
431    }
432
433    /// SP6 — Sample a line-text string for `(source, account_class)` from the
434    /// text-taxonomy prior, filling placeholders via `resolver`.
435    ///
436    /// Lookup cascade:
437    /// 1. `line_pools["SOURCE|CLASS"]`
438    /// 2. `line_pools["SOURCE|_unknown_"]`
439    /// 3. `header_pools["SOURCE"]` (last resort — source-level vocabulary)
440    ///
441    /// Returns `None` only when the prior is absent or the source has no pools
442    /// at any cascade tier — the caller then falls back to the `DescriptionGenerator`.
443    pub fn sample_line_template<R: rand::Rng>(
444        &self,
445        source: &str,
446        account_class: &str,
447        resolver: &mut dyn datasynth_core::distributions::text_taxonomy::PlaceholderResolver,
448        rng: &mut R,
449    ) -> Option<String> {
450        let tx = self.text_taxonomy.as_ref()?;
451        let class_key = TextTaxonomyPrior::line_key(source, account_class);
452        let unknown_key = TextTaxonomyPrior::line_key(source, TextTaxonomyPrior::UNKNOWN_CLASS);
453        let pool = tx
454            .line_pools
455            .get(&class_key)
456            .or_else(|| tx.line_pools.get(&unknown_key))
457            .or_else(|| tx.header_pools.get(source))?;
458        sample_pool_filled(pool, resolver, rng)
459    }
460
461    /// SP6 — Sample a header-text string for `source` from the text-taxonomy
462    /// prior. Returns `None` when absent / no pool for the source.
463    pub fn sample_header_template<R: rand::Rng>(
464        &self,
465        source: &str,
466        resolver: &mut dyn datasynth_core::distributions::text_taxonomy::PlaceholderResolver,
467        rng: &mut R,
468    ) -> Option<String> {
469        let tx = self.text_taxonomy.as_ref()?;
470        let pool = tx.header_pools.get(source)?;
471        sample_pool_filled(pool, resolver, rng)
472    }
473
474    /// SP6 — Fill the CoA description template for `account_no`. Returns `None`
475    /// when the prior is absent or the account has no template.
476    pub fn sample_coa_description<R: rand::Rng>(
477        &self,
478        account_no: &str,
479        resolver: &mut dyn datasynth_core::distributions::text_taxonomy::PlaceholderResolver,
480        rng: &mut R,
481    ) -> Option<String> {
482        let tx = self.text_taxonomy.as_ref()?;
483        let entry = tx.coa_pools.get(account_no)?;
484        Some(
485            datasynth_core::distributions::text_taxonomy::PlaceholderGrammar::fill(
486                &entry.template,
487                resolver,
488                rng,
489            ),
490        )
491    }
492}
493
494/// SP6 — Weighted-pick a `TemplateEntry` from a `TemplatePool` and fill it.
495fn sample_pool_filled<R: rand::Rng>(
496    pool: &datasynth_core::distributions::text_taxonomy::TemplatePool,
497    resolver: &mut dyn datasynth_core::distributions::text_taxonomy::PlaceholderResolver,
498    rng: &mut R,
499) -> Option<String> {
500    use datasynth_core::distributions::text_taxonomy::PlaceholderGrammar;
501    use rand::RngExt;
502    if pool.templates.is_empty() {
503        return None;
504    }
505    let total: f64 = pool.templates.iter().map(|t| t.probability).sum();
506    if total <= 0.0 {
507        return None;
508    }
509    let r: f64 = rng.random_range(0.0..total);
510    let mut cum = 0.0;
511    for t in &pool.templates {
512        cum += t.probability;
513        if r <= cum {
514            return Some(PlaceholderGrammar::fill(&t.template, resolver, rng));
515        }
516    }
517    pool.templates
518        .last()
519        .map(|t| PlaceholderGrammar::fill(&t.template, resolver, rng))
520}
521
522/// Fill a reference format template by replacing `{N digits}` and `{N alpha}`
523/// placeholders with random strings of the indicated length and character class.
524/// Fixed characters in the template are reproduced verbatim.
525fn fill_reference_template<R: rand::Rng>(template: &str, rng: &mut R) -> String {
526    use rand::RngExt;
527    if template.is_empty() {
528        return String::new();
529    }
530    let mut result = String::with_capacity(template.len() * 2);
531    let mut chars = template.char_indices().peekable();
532
533    while let Some((i, ch)) = chars.next() {
534        if ch == '{' {
535            // Find the closing '}'
536            let rest = &template[i..];
537            if let Some(close_offset) = rest.find('}') {
538                let inner = &rest[1..close_offset];
539                // Advance the iterator past the closing '}'
540                let end_byte = i + close_offset + 1;
541                // Consume chars up to end_byte
542                while chars.peek().map(|(j, _)| *j < end_byte).unwrap_or(false) {
543                    chars.next();
544                }
545                if let Some((count, kind)) = parse_ref_placeholder(inner) {
546                    match kind {
547                        RefPlaceholderKind::Digits => {
548                            for _ in 0..count {
549                                result.push(char::from(b'0' + rng.random_range(0u8..10)));
550                            }
551                        }
552                        RefPlaceholderKind::Alpha => {
553                            for _ in 0..count {
554                                result.push(char::from(b'A' + rng.random_range(0u8..26)));
555                            }
556                        }
557                    }
558                } else {
559                    result.push('{');
560                    result.push_str(inner);
561                    result.push('}');
562                }
563            } else {
564                result.push(ch);
565            }
566        } else {
567            result.push(ch);
568        }
569    }
570    result
571}
572
/// Character class requested by a reference-template placeholder.
enum RefPlaceholderKind {
    /// Placeholder spelled `N digits`.
    Digits,
    /// Placeholder spelled `N alpha`.
    Alpha,
}

/// Parse the text between `{` and `}` of a reference-template placeholder.
///
/// Accepts `<count>digits` or `<count>alpha`, where whitespace around the
/// whole text and around the count is ignored and `<count>` must parse as a
/// `usize`.  Returns the count together with the character class, or `None`
/// for anything else.
fn parse_ref_placeholder(inner: &str) -> Option<(usize, RefPlaceholderKind)> {
    let spec = inner.trim();
    let (count_text, kind) = match spec.strip_suffix("digits") {
        Some(prefix) => (prefix, RefPlaceholderKind::Digits),
        None => (spec.strip_suffix("alpha")?, RefPlaceholderKind::Alpha),
    };
    let count = count_text.trim().parse::<usize>().ok()?;
    Some((count, kind))
}
590
591// ---------------------------------------------------------------------------
592// Internal: minimal .dsf reader (ZIP + YAML)
593// ---------------------------------------------------------------------------
594
595/// Read and deserialize the `behavioral.yaml` component from a `.dsf` archive.
596/// Returns `None` if the archive does not contain a behavioral section.
597fn read_behavioral_from_dsf<R: Read + Seek>(reader: R) -> Result<Option<BehavioralPriors>, String> {
598    let mut archive = zip::ZipArchive::new(reader).map_err(|e| format!("zip open: {e}"))?;
599    for i in 0..archive.len() {
600        let mut entry = archive
601            .by_index(i)
602            .map_err(|e| format!("zip entry {i}: {e}"))?;
603        if entry.name() == BEHAVIORAL_YAML_KEY {
604            let mut buf = String::new();
605            entry
606                .read_to_string(&mut buf)
607                .map_err(|e| format!("read {BEHAVIORAL_YAML_KEY}: {e}"))?;
608            let bp: BehavioralPriors = serde_yaml::from_str(&buf)
609                .map_err(|e| format!("deserialize behavioral.yaml: {e}"))?;
610            return Ok(Some(bp));
611        }
612    }
613    Ok(None)
614}
615
616#[cfg(test)]
617mod tests {
618    use super::*;
619    use rand::SeedableRng;
620    use rand_chacha::ChaCha8Rng;
621
622    #[test]
623    fn bundled_priors_path_known() {
624        let p = bundled_priors_path("health");
625        assert!(p.ends_with("industry_priors_health.dsf"));
626    }
627
628    #[test]
629    fn load_bundled_health_actually_works() {
630        let p = bundled_priors_path("health");
631        if !p.exists() {
632            eprintln!("skipping: {} not present", p.display());
633            return;
634        }
635        let mut rng = ChaCha8Rng::seed_from_u64(42);
636        let priors =
637            LoadedPriors::load_bundled("health", &mut rng, 365).expect("load_bundled health");
638        assert_eq!(priors.industry, "health");
639        assert!(!priors.source_mix.probabilities.is_empty());
640        assert!(priors.fanout_samplers.contains_key("GLAccount"));
641    }
642
643    #[test]
644    fn load_from_path_not_found() {
645        let mut rng = ChaCha8Rng::seed_from_u64(42);
646        let result =
647            LoadedPriors::load_from_path(Path::new("/nonexistent.dsf"), &mut rng, 365, None);
648        assert!(result.is_err());
649        assert!(matches!(
650            result.err().expect("expected err"),
651            PriorsLoadError::NotFound(_)
652        ));
653    }
654
655    // ---- SP4.6 tests -------------------------------------------------------
656
657    use datasynth_core::distributions::behavioral_priors::{
658        BehavioralPriors, CategoricalDistribution, PerSourceRolePrior,
659    };
660    use std::collections::BTreeMap;
661
662    fn make_kr_role_prior() -> PerSourceRolePrior {
663        let mut dr_counts = BTreeMap::new();
664        dr_counts.insert("6000".to_string(), 100usize);
665        dr_counts.insert("6100".to_string(), 50usize);
666        let mut cr_counts = BTreeMap::new();
667        cr_counts.insert("2000".to_string(), 150usize);
668
669        let mut role_map = BTreeMap::new();
670        role_map.insert(
671            "DR".to_string(),
672            CategoricalDistribution::from_counts(dr_counts),
673        );
674        role_map.insert(
675            "CR".to_string(),
676            CategoricalDistribution::from_counts(cr_counts),
677        );
678
679        let mut by_source_and_role = BTreeMap::new();
680        by_source_and_role.insert("KR".to_string(), role_map);
681        PerSourceRolePrior { by_source_and_role }
682    }
683
684    fn minimal_bp_with_role_prior(role_prior: PerSourceRolePrior) -> BehavioralPriors {
685        use datasynth_core::distributions::behavioral_priors::*;
686        BehavioralPriors {
687            schema_version: BehavioralPriors::SCHEMA_VERSION,
688            generator_version: "test".to_string(),
689            industry: "test".to_string(),
690            n_client_inputs: 1,
691            n_rows_aggregated: 1000,
692            source_mix: SourceMixPrior::default(),
693            per_source_iet: PerSourceIetPrior::default(),
694            lines_per_je: LinesPerJePrior::default(),
695            active_lifetime: ActiveLifetimePrior::default(),
696            fanout: FanoutPrior::default(),
697            posting_lag: None,
698            active_segments: None,
699            entity_clusters: None,
700            per_source_attribute: None,
701            tp_entity_clusters: None,
702            coa_semantic: None,
703            reference_formats: None,
704            text_taxonomy: None,
705            user_personas: None,
706            source_amount_conditionals: None,
707            source_role_gl_conditionals: Some(role_prior),
708            tb_anchor: None,
709        }
710    }
711
712    /// `sample_gl_for_source_role` returns only DR-class accounts when role="DR".
713    #[test]
714    fn sp4_6_sample_gl_for_source_role_dr_returns_expense_accounts() {
715        let bp = minimal_bp_with_role_prior(make_kr_role_prior());
716        let mut rng = ChaCha8Rng::seed_from_u64(42);
717        let priors = LoadedPriors::from_priors(bp, std::path::PathBuf::from("test"), &mut rng, 365)
718            .expect("from_priors");
719
720        let mut rng2 = ChaCha8Rng::seed_from_u64(77);
721        for _ in 0..100 {
722            let v = priors.sample_gl_for_source_role("KR", "DR", &mut rng2);
723            assert!(v.is_some(), "should return Some for KR/DR");
724            let v = v.unwrap();
725            assert!(
726                v == "6000" || v == "6100",
727                "DR draw must be expense account, got {v}"
728            );
729        }
730    }
731
732    /// `sample_gl_for_source_role` returns `None` when the prior is absent.
733    #[test]
734    fn sp4_6_sample_gl_for_source_role_returns_none_when_prior_absent() {
735        use datasynth_core::distributions::behavioral_priors::*;
736        let bp = BehavioralPriors {
737            schema_version: BehavioralPriors::SCHEMA_VERSION,
738            generator_version: "test".to_string(),
739            industry: "test".to_string(),
740            n_client_inputs: 0,
741            n_rows_aggregated: 0,
742            source_mix: SourceMixPrior::default(),
743            per_source_iet: PerSourceIetPrior::default(),
744            lines_per_je: LinesPerJePrior::default(),
745            active_lifetime: ActiveLifetimePrior::default(),
746            fanout: FanoutPrior::default(),
747            posting_lag: None,
748            active_segments: None,
749            entity_clusters: None,
750            per_source_attribute: None,
751            tp_entity_clusters: None,
752            coa_semantic: None,
753            reference_formats: None,
754            text_taxonomy: None,
755            user_personas: None,
756            source_amount_conditionals: None,
757            source_role_gl_conditionals: None, // no role prior
758            tb_anchor: None,
759        };
760        let mut rng = ChaCha8Rng::seed_from_u64(42);
761        let priors = LoadedPriors::from_priors(bp, std::path::PathBuf::from("test"), &mut rng, 365)
762            .expect("from_priors");
763
764        let result = priors.sample_gl_for_source_role("KR", "DR", &mut rng);
765        assert!(
766            result.is_none(),
767            "must return None when source_role_gl is absent"
768        );
769    }
770
771    // ---- SP6 tests -----------------------------------------------------------
772
773    use datasynth_core::distributions::text_taxonomy::{
774        TemplateEntry, TemplatePool, TextTaxonomyPrior,
775    };
776
777    fn bp_with_text_taxonomy() -> BehavioralPriors {
778        let mut tx = TextTaxonomyPrior::default();
779        tx.line_pools.insert(
780            "KR|A.B".to_string(),
781            TemplatePool {
782                templates: vec![TemplateEntry {
783                    template: "Rechnung Eingang".to_string(),
784                    probability: 1.0,
785                    synthetic_example: "Rechnung Eingang".to_string(),
786                }],
787                n: 50,
788            },
789        );
790        tx.line_pools.insert(
791            "KR|_unknown_".to_string(),
792            TemplatePool {
793                templates: vec![TemplateEntry {
794                    template: "Diverse".to_string(),
795                    probability: 1.0,
796                    synthetic_example: "Diverse".to_string(),
797                }],
798                n: 20,
799            },
800        );
801        tx.header_pools.insert(
802            "KR".to_string(),
803            TemplatePool {
804                templates: vec![TemplateEntry {
805                    template: "Monatsabschluss".to_string(),
806                    probability: 1.0,
807                    synthetic_example: "Monatsabschluss".to_string(),
808                }],
809                n: 30,
810            },
811        );
812        tx.coa_pools.insert(
813            "0000204000".to_string(),
814            TemplateEntry {
815                template: "Kreditoren".to_string(),
816                probability: 1.0,
817                synthetic_example: "Kreditoren".to_string(),
818            },
819        );
820        let mut bp = minimal_bp_with_role_prior(make_kr_role_prior());
821        bp.text_taxonomy = Some(tx);
822        bp
823    }
824
825    #[test]
826    fn sample_line_template_keyed_on_source_and_class() {
827        let mut rng = ChaCha8Rng::seed_from_u64(1);
828        let priors =
829            LoadedPriors::from_priors(bp_with_text_taxonomy(), PathBuf::from("t"), &mut rng, 365)
830                .expect("from_priors");
831        let mut resolver = datasynth_core::distributions::text_taxonomy::SyntheticExampleResolver;
832        let mut r2 = ChaCha8Rng::seed_from_u64(2);
833        // exact (source,class) hit
834        let v = priors.sample_line_template("KR", "A.B", &mut resolver, &mut r2);
835        assert_eq!(v, Some("Rechnung Eingang".to_string()));
836        // unknown class -> cascade to KR|_unknown_
837        let v = priors.sample_line_template("KR", "Z.Z", &mut resolver, &mut r2);
838        assert_eq!(v, Some("Diverse".to_string()));
839        // unknown source -> None (caller falls back)
840        let v = priors.sample_line_template("ZZ", "A.B", &mut resolver, &mut r2);
841        assert_eq!(v, None);
842    }
843
844    /// Cascade tier 3: when a source has NO `_unknown_` line pool but DOES
845    /// have a header pool, `sample_line_template` must fall through to the
846    /// header vocabulary. Uses a fixture without `RV|_unknown_` to force
847    /// tier 2 to miss.
848    #[test]
849    fn sample_line_template_falls_through_to_header_pool_at_tier_3() {
850        use datasynth_core::distributions::text_taxonomy::{
851            TemplateEntry, TemplatePool, TextTaxonomyPrior,
852        };
853        let mut tx = TextTaxonomyPrior::default();
854        // RV has only a header pool — no line pools at all.
855        tx.header_pools.insert(
856            "RV".to_string(),
857            TemplatePool {
858                templates: vec![TemplateEntry {
859                    template: "Header-only fallback".to_string(),
860                    probability: 1.0,
861                    synthetic_example: "Header-only fallback".to_string(),
862                }],
863                n: 5,
864            },
865        );
866        let mut bp = minimal_bp_with_role_prior(make_kr_role_prior());
867        bp.text_taxonomy = Some(tx);
868
869        let mut rng = ChaCha8Rng::seed_from_u64(1);
870        let priors =
871            LoadedPriors::from_priors(bp, PathBuf::from("t"), &mut rng, 365).expect("from_priors");
872        let mut resolver = datasynth_core::distributions::text_taxonomy::SyntheticExampleResolver;
873        let mut r2 = ChaCha8Rng::seed_from_u64(2);
874        // tier 1 miss (no RV|A.B), tier 2 miss (no RV|_unknown_),
875        // tier 3 hits (header_pools["RV"])
876        let v = priors.sample_line_template("RV", "A.B", &mut resolver, &mut r2);
877        assert_eq!(v, Some("Header-only fallback".to_string()));
878    }
879
880    #[test]
881    fn sample_coa_description_hits_account() {
882        let mut rng = ChaCha8Rng::seed_from_u64(1);
883        let priors =
884            LoadedPriors::from_priors(bp_with_text_taxonomy(), PathBuf::from("t"), &mut rng, 365)
885                .expect("from_priors");
886        let mut resolver = datasynth_core::distributions::text_taxonomy::SyntheticExampleResolver;
887        let mut r2 = ChaCha8Rng::seed_from_u64(2);
888        assert_eq!(
889            priors.sample_coa_description("0000204000", &mut resolver, &mut r2),
890            Some("Kreditoren".to_string())
891        );
892        assert_eq!(
893            priors.sample_coa_description("9999999999", &mut resolver, &mut r2),
894            None
895        );
896    }
897}