triplets 0.14.0-alpha

Composable data sampling primitives for deterministic multi-source ML/AI training-data orchestration.
Documentation
use crate::config::{SamplerConfig, TextRecipe, TripletRecipe};
pub use crate::constants::heuristics::{
    EFFECTIVE_NEGATIVES_PER_ANCHOR, EFFECTIVE_POSITIVES_PER_ANCHOR,
};
use crate::splits::{SplitLabel, SplitRatios};

/// Aggregated capacity estimates for a single source/split scope.
///
/// Every counter is an approximation; `Default` yields all-zero totals.
#[derive(Clone, Copy, Debug, Default)]
pub struct CapacityTotals {
    /// Approximate number of raw triplet combinations.
    pub triplets: u128,
    /// Bounded, practical triplet estimate based on the effective p/k assumptions.
    pub effective_triplets: u128,
    /// Approximate number of pair combinations, derived from the triplet assumptions.
    pub pairs: u128,
    /// Approximate number of text samples available.
    pub text_samples: u128,
}

/// Estimate split-local capacity totals from record counts and recipe counts.
pub fn estimate_source_split_capacity_from_counts(
    source_records_in_split: u128,
    triplet_recipes: &[TripletRecipe],
    text_recipes: &[TextRecipe],
) -> CapacityTotals {
    let mut totals = CapacityTotals::default();

    for _recipe in triplet_recipes {
        let anchor_positive_pairs = source_records_in_split;
        let negative_count_per_anchor = source_records_in_split.saturating_sub(1);
        if anchor_positive_pairs == 0 || negative_count_per_anchor == 0 {
            continue;
        }
        totals.triplets += anchor_positive_pairs.saturating_mul(negative_count_per_anchor);
        totals.effective_triplets += source_records_in_split
            .saturating_mul(EFFECTIVE_POSITIVES_PER_ANCHOR)
            .saturating_mul(EFFECTIVE_NEGATIVES_PER_ANCHOR);
        totals.pairs += anchor_positive_pairs.saturating_mul(1 + negative_count_per_anchor);
    }

    for _recipe in text_recipes {
        totals.text_samples += source_records_in_split;
    }

    totals
}

/// Split `total` records into train/validation/test counts by ratio.
///
/// The train and validation counts are the floored products of `total` and their ratios;
/// whatever remains goes to test. Only `split.train` and `split.validation` are read.
///
/// The three returned counts always sum to exactly `total`: the train count is clamped to
/// `total`, and the validation count is clamped to what is left after train. This guards
/// against ratios that add up to more than 1 and against `u128 -> f64` rounding for totals
/// that `f64` cannot represent exactly.
///
/// Negative or NaN ratios produce a count of 0, because float-to-integer `as` casts
/// saturate and map NaN to zero.
pub fn split_counts_for_total(total: u128, split: SplitRatios) -> [(SplitLabel, u128); 3] {
    let total_as_float = total as f64;
    let floored_share = |ratio: f64| (total_as_float * ratio).floor() as u128;

    let train = floored_share(f64::from(split.train)).min(total);
    // `train <= total`, so this subtraction cannot underflow.
    let remaining_after_train = total - train;
    let validation = floored_share(f64::from(split.validation)).min(remaining_after_train);
    // `validation <= remaining_after_train`, so the remainder is well defined.
    let test = remaining_after_train - validation;

    [
        (SplitLabel::Train, train),
        (SplitLabel::Validation, validation),
        (SplitLabel::Test, test),
    ]
}

/// Render `value` in decimal with a comma between every group of three digits.
///
/// Grouping is counted from the least-significant digit, e.g. `1234567` becomes
/// `"1,234,567"`; values below 1000 are returned without any separator.
pub fn format_u128_with_commas(value: u128) -> String {
    let digits = value.to_string();
    let digit_count = digits.len();
    let mut grouped = String::with_capacity(digit_count + digit_count / 3);

    for (position, digit) in digits.chars().enumerate() {
        // A separator goes before any non-leading digit that starts a full group of three.
        let digits_remaining = digit_count - position;
        if position != 0 && digits_remaining % 3 == 0 {
            grouped.push(',');
        }
        grouped.push(digit);
    }

    grouped
}

/// Describe how often a source would replay relative to the longest source.
///
/// Returns `longest_records / source_records` rendered with two decimals and an `x`
/// suffix (for example `"2.25x"`), or `"n/a"` when either count is zero.
pub fn format_replay_factor(longest_records: u128, source_records: u128) -> String {
    match (longest_records, source_records) {
        (0, _) | (_, 0) => String::from("n/a"),
        (longest, source) => {
            let ratio = longest as f64 / source as f64;
            format!("{ratio:.2}x")
        }
    }
}

/// Pick the text recipes to use for a source.
///
/// Precedence, highest first:
/// 1. `config.text_recipes`, cloned as-is, when it is non-empty;
/// 2. recipes derived from `config.recipes` when that is non-empty;
/// 3. recipes derived from `source_triplet_recipes`.
pub fn resolve_text_recipes_for_source(
    config: &SamplerConfig,
    source_triplet_recipes: &[TripletRecipe],
) -> Vec<TextRecipe> {
    let has_explicit_text = !config.text_recipes.is_empty();
    let has_config_triplets = !config.recipes.is_empty();

    match (has_explicit_text, has_config_triplets) {
        (true, _) => config.text_recipes.clone(),
        (false, true) => build_derived_text_recipes(&config.recipes),
        (false, false) => build_derived_text_recipes(source_triplet_recipes),
    }
}

/// Derive three text recipes (anchor, positive, negative) from every triplet recipe.
///
/// For a triplet recipe named `base`, the output contains `base_anchor`, `base_positive`
/// and `base_negative`, in that order, each using a clone of the matching selector.
/// The derived weight is the triplet recipe's weight raised to at least `0.0001`, and
/// no instruction is attached.
pub fn build_derived_text_recipes(recipes: &[TripletRecipe]) -> Vec<TextRecipe> {
    recipes
        .iter()
        .flat_map(|recipe| {
            let base = recipe.name.as_ref();
            let weight = recipe.weight.max(0.0001);
            // Role suffix paired with the selector it draws from, in output order.
            [
                ("anchor", &recipe.anchor),
                ("positive", &recipe.positive_selector),
                ("negative", &recipe.negative_selector),
            ]
            .map(|(suffix, selector)| TextRecipe {
                name: format!("{base}_{suffix}").into(),
                selector: selector.clone(),
                weight,
                instruction: None,
            })
        })
        .collect()
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{NegativeStrategy, Selector};
    use crate::data::SectionRole;

    /// Fixture: a triplet recipe with an anchor-role anchor and context-role
    /// positive/negative selectors.
    fn sample_triplet_recipe(name: &str, weight: f32) -> TripletRecipe {
        TripletRecipe {
            name: name.to_owned().into(),
            weight,
            anchor: Selector::Role(SectionRole::Anchor),
            positive_selector: Selector::Role(SectionRole::Context),
            negative_selector: Selector::Role(SectionRole::Context),
            negative_strategy: NegativeStrategy::WrongArticle,
            allow_same_anchor_positive: false,
            instruction: None,
        }
    }

    /// The three split counts must add back up to the requested total.
    #[test]
    fn split_counts_sum_to_total() {
        let ratios = SplitRatios {
            train: 0.8,
            validation: 0.1,
            test: 0.1,
        };
        let per_split = split_counts_for_total(101, ratios);
        let combined: u128 = per_split.iter().map(|entry| entry.1).sum();
        assert_eq!(combined, 101);
    }

    /// Comma grouping and replay-factor rendering produce fixed, known strings.
    #[test]
    fn formatting_helpers_are_stable() {
        for (value, expected) in [(0_u128, "0"), (1_234_567, "1,234,567")] {
            assert_eq!(format_u128_with_commas(value), expected);
        }

        assert_eq!(format_replay_factor(0, 10), "n/a");
        assert_eq!(format_replay_factor(10, 0), "n/a");
        assert_eq!(format_replay_factor(9, 4), "2.25x");
    }

    /// Zero records give zero capacity; a handful of records give positive capacity.
    #[test]
    fn estimate_capacity_handles_zero_and_nonzero_inputs() {
        let triplet_recipes = vec![sample_triplet_recipe("r1", 1.0)];
        let derived_text = build_derived_text_recipes(&triplet_recipes);

        let empty =
            estimate_source_split_capacity_from_counts(0, &triplet_recipes, &derived_text);
        assert_eq!(empty.triplets, 0);
        assert_eq!(empty.pairs, 0);
        assert_eq!(empty.text_samples, 0);

        let populated =
            estimate_source_split_capacity_from_counts(4, &triplet_recipes, &derived_text);
        assert!(populated.triplets > 0);
        assert!(populated.pairs > 0);
        // One triplet recipe derives three text recipes; four records each => 12.
        assert_eq!(populated.text_samples, 12);
    }

    /// Explicit text recipes win; otherwise config triplet recipes drive derivation.
    #[test]
    fn resolve_text_recipes_prefers_config_then_derived() {
        let source_recipes = vec![sample_triplet_recipe("source", 1.0)];

        // Case 1: explicit text recipes in the config are returned untouched.
        let explicit = TextRecipe {
            name: "override".into(),
            selector: Selector::Role(SectionRole::Context),
            weight: 1.0,
            instruction: None,
        };
        let explicit_config = SamplerConfig {
            text_recipes: vec![explicit.clone()],
            ..SamplerConfig::default()
        };
        let from_explicit = resolve_text_recipes_for_source(&explicit_config, &source_recipes);
        assert_eq!(from_explicit.len(), 1);
        assert_eq!(from_explicit[0].name, explicit.name);

        // Case 2: config triplet recipes take precedence over the source's own recipes.
        let triplet_config = SamplerConfig {
            recipes: vec![sample_triplet_recipe("cfg", 0.5)],
            ..SamplerConfig::default()
        };
        let from_triplets = resolve_text_recipes_for_source(&triplet_config, &source_recipes);
        assert_eq!(from_triplets.len(), 3);
        assert!(from_triplets
            .iter()
            .all(|recipe| recipe.name.starts_with("cfg_")));
    }
}