use crate::data::SectionRole;
use crate::splits::{SplitLabel, SplitRatios};
use std::borrow::Cow;
use std::sync::Arc;
/// Settings for the denoiser preprocessor.
///
/// `SamplerConfig::with_denoiser` passes a value of this type to
/// `DenoiserPreprocessor::new`. How each field is interpreted is decided by that
/// preprocessor, which is defined outside this file.
#[derive(Debug, Clone)]
pub struct DenoiserConfig {
    /// On/off flag for the denoiser; the `Default` impl sets it to `false`.
    pub enabled: bool,
    /// Digit-ratio threshold; the `Default` impl uses `0.35`.
    ///
    /// NOTE(review): presumably text whose share of digit characters exceeds this
    /// value is treated as noise — confirm against the preprocessor.
    pub max_digit_ratio: f32,
    /// Markdown-stripping flag; the `Default` impl sets it to `true`.
    ///
    /// NOTE(review): presumably toggles removal of markdown syntax — confirm.
    pub strip_markdown: bool,
}
impl Default for DenoiserConfig {
fn default() -> Self {
Self {
enabled: false,
max_digit_ratio: 0.35,
strip_markdown: true,
}
}
}
/// Chunking parameters together with the text preprocessors registered for them.
///
/// `Clone` and `Debug` are implemented by hand elsewhere in this file; the
/// `Debug` output reports only how many preprocessors are registered.
///
/// NOTE(review): the field descriptions below are inferred from the field names;
/// the code that consumes them is not part of this file — confirm before relying
/// on them.
pub struct ChunkingStrategy {
    /// Presumably the maximum size of one chunk window, in tokens.
    pub max_window_tokens: usize,
    /// Presumably the overlap sizes (in tokens) to use between windows.
    pub overlap_tokens: Vec<usize>,
    /// Presumably the weight applied to a summary-fallback chunk.
    pub summary_fallback_weight: f32,
    /// Presumably the token budget of a summary-fallback chunk.
    pub summary_fallback_tokens: usize,
    /// Presumably the lower bound applied to chunk weights.
    pub chunk_weight_floor: f32,
    /// Registered preprocessors, shared via `Arc`. Appended to by
    /// `register_preprocessor` and `SamplerConfig::with_denoiser`.
    pub(crate) preprocessors: Vec<Arc<dyn crate::preprocessor::TextPreprocessor>>,
}
impl Default for ChunkingStrategy {
fn default() -> Self {
Self {
max_window_tokens: 1024,
overlap_tokens: vec![64],
summary_fallback_weight: 0.35,
summary_fallback_tokens: 512,
chunk_weight_floor: 0.1,
preprocessors: Vec::new(),
}
}
}
impl Clone for ChunkingStrategy {
fn clone(&self) -> Self {
Self {
max_window_tokens: self.max_window_tokens,
overlap_tokens: self.overlap_tokens.clone(),
summary_fallback_weight: self.summary_fallback_weight,
summary_fallback_tokens: self.summary_fallback_tokens,
chunk_weight_floor: self.chunk_weight_floor,
preprocessors: self.preprocessors.clone(),
}
}
}
impl std::fmt::Debug for ChunkingStrategy {
    /// Formats every numeric field as-is and summarises `preprocessors` as
    /// `"<count> registered"` instead of formatting the trait objects.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let registered = self.preprocessors.len();
        let mut builder = f.debug_struct("ChunkingStrategy");
        builder.field("max_window_tokens", &self.max_window_tokens);
        builder.field("overlap_tokens", &self.overlap_tokens);
        builder.field("summary_fallback_weight", &self.summary_fallback_weight);
        builder.field("summary_fallback_tokens", &self.summary_fallback_tokens);
        builder.field("chunk_weight_floor", &self.chunk_weight_floor);
        builder.field("preprocessors", &format_args!("{} registered", registered));
        builder.finish()
    }
}
impl ChunkingStrategy {
    /// Wraps `p` in an `Arc`, appends it to the registered preprocessors, and
    /// returns `self` so further calls can be chained.
    pub fn register_preprocessor(
        &mut self,
        p: impl crate::preprocessor::TextPreprocessor + 'static,
    ) -> &mut Self {
        let shared: Arc<dyn crate::preprocessor::TextPreprocessor> = Arc::new(p);
        self.preprocessors.push(shared);
        self
    }

    /// Borrows the registered preprocessors in registration order.
    pub fn preprocessors(&self) -> &[Arc<dyn crate::preprocessor::TextPreprocessor>] {
        self.preprocessors.as_slice()
    }
}
/// Describes one triplet recipe: a name, three selectors, a negative strategy,
/// and a weight.
///
/// NOTE(review): the field descriptions below are inferred from the field names;
/// the sampler that consumes recipes is not part of this file — confirm.
#[derive(Debug, Clone)]
pub struct TripletRecipe {
    /// Recipe name; either a static string or an owned one.
    pub name: Cow<'static, str>,
    /// Selector for the anchor element.
    pub anchor: Selector,
    /// Selector for the positive element.
    pub positive_selector: Selector,
    /// Selector for the negative element.
    pub negative_selector: Selector,
    /// Strategy used for the negative element.
    pub negative_strategy: NegativeStrategy,
    /// Presumably the relative sampling weight of this recipe; `Default` uses `1.0`.
    pub weight: f32,
    /// Optional instruction text attached to the recipe.
    pub instruction: Option<Cow<'static, str>>,
    /// Presumably allows anchor and positive to be the same item; `Default` sets
    /// it to `false`.
    pub allow_same_anchor_positive: bool,
}
impl Default for TripletRecipe {
fn default() -> Self {
Self {
name: "".into(),
anchor: Selector::Random,
positive_selector: Selector::Random,
negative_selector: Selector::Random,
negative_strategy: NegativeStrategy::WrongArticle,
weight: 1.0,
instruction: None,
allow_same_anchor_positive: false,
}
}
}
/// Ways a recipe can pick a piece of a record.
///
/// NOTE(review): the per-variant descriptions are inferred from the variant
/// names and payload types; the resolution logic is not part of this file —
/// confirm against the sampler.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Selector {
    /// Carries a `SectionRole`; presumably picks the section with that role.
    Role(SectionRole),
    /// Carries a `usize`; presumably a paragraph index.
    Paragraph(usize),
    /// Carries a signed `i32`; presumably an offset relative to the record in time.
    TemporalOffset(i32),
    /// No payload; presumably picks at random.
    Random,
}
/// Describes one text recipe: a name, a single selector, and a weight.
///
/// NOTE(review): the field descriptions are inferred from the field names; the
/// consuming code is not part of this file — confirm.
#[derive(Debug, Clone)]
pub struct TextRecipe {
    /// Recipe name; either a static string or an owned one.
    pub name: Cow<'static, str>,
    /// Selector for the text to use.
    pub selector: Selector,
    /// Presumably the relative sampling weight of this recipe.
    pub weight: f32,
    /// Optional instruction text attached to the recipe.
    pub instruction: Option<Cow<'static, str>>,
}
/// Strategies for choosing the negative element of a triplet.
///
/// Derives `PartialEq`/`Eq` (like the sibling `Selector` enum) so values can be
/// compared with `==` and `assert_eq!` instead of only through `matches!`.
///
/// NOTE(review): the per-variant descriptions are inferred from the variant
/// names; the sampling logic is not part of this file — confirm.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum NegativeStrategy {
    /// Presumably a negative with a different publication date.
    WrongPublicationDate,
    /// Presumably a negative from a different article; used by
    /// `TripletRecipe::default`.
    WrongArticle,
    /// Presumably a negative formed by mismatching a question and an answer.
    QuestionAnswerMismatch,
}
/// Top-level sampler configuration: seed, sizes, chunking parameters, recipes,
/// and split settings.
///
/// NOTE(review): the field descriptions are inferred from the field names and
/// types; the sampler itself is not part of this file — confirm.
#[derive(Debug, Clone)]
pub struct SamplerConfig {
    /// Presumably the RNG seed; `Default` uses `42`.
    pub seed: u64,
    /// Presumably the number of samples per batch; `Default` uses `128`.
    pub batch_size: usize,
    /// Presumably a cap on ingested records; `Default` uses `2048`.
    pub ingestion_max_records: usize,
    /// Chunking parameters, including any registered preprocessors.
    pub chunking: ChunkingStrategy,
    /// Triplet recipes; empty by default.
    pub recipes: Vec<TripletRecipe>,
    /// Text recipes; empty by default.
    pub text_recipes: Vec<TextRecipe>,
    /// Split ratios.
    pub split: SplitRatios,
    /// Split labels the sampler may use; `Default` allows only `SplitLabel::Train`.
    pub allowed_splits: Vec<SplitLabel>,
}
impl Default for SamplerConfig {
fn default() -> Self {
Self {
seed: 42,
batch_size: 128,
ingestion_max_records: 2048,
chunking: ChunkingStrategy::default(),
recipes: Vec::new(),
text_recipes: Vec::new(),
split: SplitRatios::default(),
allowed_splits: vec![SplitLabel::Train],
}
}
}
impl SamplerConfig {
pub fn with_denoiser(mut self, config: DenoiserConfig) -> Self {
use crate::preprocessor::backends::denoiser_preprocessor::DenoiserPreprocessor;
self.chunking
.preprocessors
.push(Arc::new(DenoiserPreprocessor::new(config)));
self
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the literal defaults of `ChunkingStrategy`.
    #[test]
    fn chunking_strategy_defaults_are_stable() {
        let ChunkingStrategy {
            max_window_tokens,
            overlap_tokens,
            summary_fallback_weight,
            summary_fallback_tokens,
            chunk_weight_floor,
            ..
        } = ChunkingStrategy::default();

        assert_eq!(max_window_tokens, 1024);
        assert_eq!(overlap_tokens, vec![64]);
        assert_eq!(summary_fallback_weight, 0.35);
        assert_eq!(summary_fallback_tokens, 512);
        assert_eq!(chunk_weight_floor, 0.1);
    }

    /// Pins the literal defaults of `SamplerConfig`, including the nested
    /// chunking window size.
    #[test]
    fn sampler_config_defaults_are_expected() {
        let defaults = SamplerConfig::default();

        assert_eq!(defaults.seed, 42);
        assert_eq!(defaults.batch_size, 128);
        assert_eq!(defaults.ingestion_max_records, 2048);
        assert!(defaults.recipes.is_empty());
        assert!(defaults.text_recipes.is_empty());
        assert_eq!(defaults.allowed_splits, vec![SplitLabel::Train]);
        assert_eq!(defaults.chunking.max_window_tokens, 1024);
    }

    /// Every `Selector` variant can be built and keeps its payload.
    #[test]
    fn selector_variants_can_be_constructed() {
        let by_role = Selector::Role(SectionRole::Anchor);
        let by_paragraph = Selector::Paragraph(3);
        let by_offset = Selector::TemporalOffset(-2);
        let at_random = Selector::Random;

        assert_eq!(by_role, Selector::Role(SectionRole::Anchor));
        assert_eq!(by_paragraph, Selector::Paragraph(3));
        assert_eq!(by_offset, Selector::TemporalOffset(-2));
        assert_eq!(at_random, Selector::Random);
    }

    /// Pins the values produced by `TripletRecipe::default`.
    #[test]
    fn triplet_recipe_default_is_expected() {
        let TripletRecipe {
            name,
            anchor,
            positive_selector,
            negative_selector,
            negative_strategy,
            weight,
            instruction,
            allow_same_anchor_positive,
        } = TripletRecipe::default();

        assert_eq!(name.as_ref(), "");
        assert_eq!(anchor, Selector::Random);
        assert_eq!(positive_selector, Selector::Random);
        assert_eq!(negative_selector, Selector::Random);
        assert!(matches!(negative_strategy, NegativeStrategy::WrongArticle));
        assert_eq!(weight, 1.0);
        assert!(instruction.is_none());
        assert!(!allow_same_anchor_positive);
    }
}