use crate::data::SectionRole;
use crate::splits::{SplitLabel, SplitRatios};
use std::borrow::Cow;
/// Token-window chunking parameters used by the sampler.
///
/// Every field is public; the values listed below are those produced by
/// [`ChunkingStrategy::default`].
#[derive(Debug, Clone)]
pub struct ChunkingStrategy {
    /// Maximum number of tokens per window. Default: `1024`.
    pub max_window_tokens: usize,
    /// Overlap sizes, in tokens. It is a `Vec`, so several values can be
    /// configured (NOTE(review): confirm how the chunker consumes multiple
    /// entries). Default: `[64]`.
    pub overlap_tokens: Vec<usize>,
    /// Weight associated with the summary fallback (semantics are defined by
    /// the consumer — TODO confirm). Default: `0.35`.
    pub summary_fallback_weight: f32,
    /// Token count associated with the summary fallback. Default: `512`.
    pub summary_fallback_tokens: usize,
    /// Floor value for chunk weights (presumably a lower bound — confirm
    /// against the consumer). Default: `0.1`.
    pub chunk_weight_floor: f32,
}

impl Default for ChunkingStrategy {
    /// Builds the baseline strategy: `max_window_tokens = 1024`,
    /// `overlap_tokens = [64]`, `summary_fallback_weight = 0.35`,
    /// `summary_fallback_tokens = 512`, `chunk_weight_floor = 0.1`.
    fn default() -> Self {
        let overlap_tokens = vec![64];
        ChunkingStrategy {
            max_window_tokens: 1024,
            overlap_tokens,
            summary_fallback_weight: 0.35,
            summary_fallback_tokens: 512,
            chunk_weight_floor: 0.1,
        }
    }
}
/// A weighted recipe describing how to build one anchor/positive/negative
/// triplet from [`Selector`]s and a [`NegativeStrategy`].
#[derive(Debug, Clone)]
pub struct TripletRecipe {
    /// Recipe identifier. Default: empty string.
    pub name: Cow<'static, str>,
    /// Selector for the anchor. Default: [`Selector::Random`].
    pub anchor: Selector,
    /// Selector for the positive. Default: [`Selector::Random`].
    pub positive_selector: Selector,
    /// Selector for the negative. Default: [`Selector::Random`].
    pub negative_selector: Selector,
    /// Strategy for producing the negative. Default:
    /// [`NegativeStrategy::WrongArticle`].
    pub negative_strategy: NegativeStrategy,
    /// Recipe weight (presumably a relative sampling weight — confirm).
    /// Default: `1.0`.
    pub weight: f32,
    /// Optional instruction text attached to the recipe. Default: `None`.
    pub instruction: Option<Cow<'static, str>>,
    /// Whether anchor and positive may be the same. Default: `false`.
    pub allow_same_anchor_positive: bool,
}

impl Default for TripletRecipe {
    /// An unnamed recipe using `Selector::Random` for all three slots,
    /// `NegativeStrategy::WrongArticle`, weight `1.0`, no instruction, and
    /// `allow_same_anchor_positive = false`.
    fn default() -> Self {
        TripletRecipe {
            name: Cow::Borrowed(""),
            anchor: Selector::Random,
            positive_selector: Selector::Random,
            negative_selector: Selector::Random,
            negative_strategy: NegativeStrategy::WrongArticle,
            weight: 1.0,
            instruction: None,
            allow_same_anchor_positive: false,
        }
    }
}
/// Rule for choosing the text that fills a recipe slot. It is used by the
/// anchor/positive/negative fields of `TripletRecipe` and by
/// `TextRecipe::selector`.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum Selector {
    /// Select by section role.
    Role(SectionRole),
    /// Select by paragraph index (NOTE(review): zero- vs one-based is not
    /// visible here — confirm against the sampler).
    Paragraph(usize),
    /// Select by a signed temporal offset; negative values are allowed.
    TemporalOffset(i32),
    /// Select at random.
    Random,
}
#[derive(Clone, Debug)]
pub struct TextRecipe {
pub name: Cow<'static, str>,
pub selector: Selector,
pub weight: f32,
pub instruction: Option<Cow<'static, str>>,
}
/// Strategy used to produce the negative member of a triplet.
///
/// The variant names describe the kind of mismatch a negative represents
/// (NOTE(review): exact semantics live in the sampler — confirm there).
///
/// This is a field-less enum, so it is cheap to copy. It derives
/// `PartialEq`/`Eq`, like [`Selector`], so values can be compared with `==`,
/// and `Hash`, so they can be used as map/set keys.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum NegativeStrategy {
    WrongPublicationDate,
    WrongArticle,
    QuestionAnswerMismatch,
}
/// Top-level sampler configuration. It combines chunking parameters, triplet
/// and text recipes, and split settings.
#[derive(Debug, Clone)]
pub struct SamplerConfig {
    /// Seed value (presumably for the sampler's RNG — confirm). Default: `42`.
    pub seed: u64,
    /// Batch size. Default: `128`.
    pub batch_size: usize,
    /// Maximum number of records to ingest. Default: `2048`.
    pub ingestion_max_records: usize,
    /// Chunking parameters. Default: [`ChunkingStrategy::default`].
    pub chunking: ChunkingStrategy,
    /// Triplet recipes. Default: empty.
    pub recipes: Vec<TripletRecipe>,
    /// Single-text recipes. Default: empty.
    pub text_recipes: Vec<TextRecipe>,
    /// Split ratios. Default: `SplitRatios::default()`.
    pub split: SplitRatios,
    /// Split labels the sampler may draw from. Default: `[SplitLabel::Train]`.
    pub allowed_splits: Vec<SplitLabel>,
}

impl Default for SamplerConfig {
    /// Seed `42`, batch size `128`, at most `2048` ingested records, default
    /// chunking and split ratios, no recipes, and only the `Train` split
    /// allowed.
    fn default() -> Self {
        // Nested defaults are built in the same order as the field list.
        let chunking = ChunkingStrategy::default();
        let split = SplitRatios::default();
        SamplerConfig {
            seed: 42,
            batch_size: 128,
            ingestion_max_records: 2048,
            chunking,
            recipes: vec![],
            text_recipes: vec![],
            split,
            allowed_splits: vec![SplitLabel::Train],
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn chunking_strategy_defaults_are_stable() {
        let cfg = ChunkingStrategy::default();
        assert_eq!(cfg.max_window_tokens, 1024);
        assert_eq!(cfg.overlap_tokens, vec![64]);
        assert_eq!(cfg.summary_fallback_weight, 0.35);
        assert_eq!(cfg.summary_fallback_tokens, 512);
        assert_eq!(cfg.chunk_weight_floor, 0.1);
    }

    #[test]
    fn sampler_config_defaults_are_expected() {
        let cfg = SamplerConfig::default();
        assert_eq!(cfg.seed, 42);
        assert_eq!(cfg.batch_size, 128);
        assert_eq!(cfg.ingestion_max_records, 2048);
        assert!(cfg.recipes.is_empty());
        assert!(cfg.text_recipes.is_empty());
        assert_eq!(cfg.allowed_splits, vec![SplitLabel::Train]);
        assert_eq!(cfg.chunking.max_window_tokens, 1024);
    }

    #[test]
    fn sampler_config_default_chunking_matches_chunking_default() {
        // The nested chunking section should equal ChunkingStrategy::default()
        // in every field, not just max_window_tokens.
        let cfg = SamplerConfig::default();
        let expected = ChunkingStrategy::default();
        assert_eq!(cfg.chunking.max_window_tokens, expected.max_window_tokens);
        assert_eq!(cfg.chunking.overlap_tokens, expected.overlap_tokens);
        assert_eq!(
            cfg.chunking.summary_fallback_weight,
            expected.summary_fallback_weight
        );
        assert_eq!(
            cfg.chunking.summary_fallback_tokens,
            expected.summary_fallback_tokens
        );
        assert_eq!(cfg.chunking.chunk_weight_floor, expected.chunk_weight_floor);
    }

    #[test]
    fn selector_variants_can_be_constructed() {
        let role = Selector::Role(SectionRole::Anchor);
        let paragraph = Selector::Paragraph(3);
        let temporal = Selector::TemporalOffset(-2);
        let random = Selector::Random;
        assert!(matches!(role, Selector::Role(SectionRole::Anchor)));
        assert!(matches!(paragraph, Selector::Paragraph(3)));
        assert!(matches!(temporal, Selector::TemporalOffset(-2)));
        assert!(matches!(random, Selector::Random));
    }

    #[test]
    fn selector_equality_is_by_value() {
        // Selector derives PartialEq/Eq. Check that equality depends on both
        // the variant and its payload, rather than only re-matching patterns.
        assert_eq!(
            Selector::Role(SectionRole::Anchor),
            Selector::Role(SectionRole::Anchor)
        );
        assert_eq!(Selector::Paragraph(3), Selector::Paragraph(3));
        assert_ne!(Selector::Paragraph(3), Selector::Paragraph(4));
        assert_eq!(Selector::TemporalOffset(-2), Selector::TemporalOffset(-2));
        assert_ne!(Selector::TemporalOffset(-2), Selector::TemporalOffset(2));
        assert_ne!(Selector::Random, Selector::Paragraph(0));
        assert_ne!(Selector::Role(SectionRole::Anchor), Selector::Random);
        let original = Selector::Paragraph(7);
        assert_eq!(original.clone(), original);
    }

    #[test]
    fn triplet_recipe_default_is_expected() {
        let recipe = TripletRecipe::default();
        assert_eq!(recipe.name.as_ref(), "");
        assert!(matches!(recipe.anchor, Selector::Random));
        assert!(matches!(recipe.positive_selector, Selector::Random));
        assert!(matches!(recipe.negative_selector, Selector::Random));
        assert!(matches!(
            recipe.negative_strategy,
            NegativeStrategy::WrongArticle
        ));
        assert_eq!(recipe.weight, 1.0);
        assert!(recipe.instruction.is_none());
        assert!(!recipe.allow_same_anchor_positive);
    }
}