use std::collections::{HashMap, HashSet};
use std::sync::{Arc, RwLock};
use bm25::{Document, Language, SearchEngine, SearchEngineBuilder};
use indexmap::IndexMap;
use crate::constants::sampler::BM25_HARD_NEGATIVE_ROTATION_TOP_K;
use crate::constants::sampler::BM25_QUERY_TOKEN_LIMIT;
use crate::constants::sampler::BM25_SEARCH_TOP_K;
use crate::data::DataRecord;
use crate::splits::SplitLabel;
use crate::tokenizer::{Tokenizer, WhitespaceTokenizer};
use crate::types::{RecordId, SourceId};
use crate::utils::platform_newline;
use super::NegativeBackend;
/// Per-document bookkeeping stored alongside a BM25 index.
///
/// Entries live in `PerSourceBm25Index::meta`; the position of an entry in
/// that vector is the document id handed to the search engine, so a search
/// hit can be mapped back to its record.
struct Bm25RecordMeta {
/// Id of the record this indexed document was built from.
record_id: RecordId,
/// Split the record was assigned to when the index was built, if any.
split: Option<SplitLabel>,
}
/// BM25 index covering the records of a single source.
struct PerSourceBm25Index {
/// Record metadata, indexed by the `usize` document id used in `search_engine`.
meta: Vec<Bm25RecordMeta>,
/// Search engine built over the BM25 text of each record in `meta`.
search_engine: SearchEngine<usize>,
}
/// Negative-sampling backend that ranks candidates with per-source BM25 indexes.
pub struct Bm25Backend {
/// Cache of ranked candidate ids per anchor record id. Only filled for
/// lookups that use the anchor's own text (no explicit query text).
hard_negatives: RwLock<HashMap<RecordId, Vec<RecordId>>>,
/// One BM25 index per source id.
source_indexes: HashMap<SourceId, PerSourceBm25Index>,
/// Rotation cursor per `(anchor id, split)`: position of the next ranked
/// candidate to hand out for that anchor.
negative_cursors: RwLock<HashMap<(RecordId, SplitLabel), usize>>,
/// Token cap applied when building BM25 text for a record; `0` disables truncation.
max_window_tokens: usize,
/// Number of selection attempts made on a non-empty pool.
#[cfg(feature = "extended-metrics")]
bm25_selection_count: std::sync::atomic::AtomicU64,
/// Number of selections that fell back to a random pick from the pool.
#[cfg(feature = "extended-metrics")]
bm25_fallback_count: std::sync::atomic::AtomicU64,
}
impl Bm25Backend {
/// Creates a backend with no indexes, an empty candidate cache, no rotation
/// cursors, and a token window of `0` (no truncation).
pub fn new() -> Self {
    Self {
        hard_negatives: RwLock::default(),
        source_indexes: HashMap::default(),
        negative_cursors: RwLock::default(),
        max_window_tokens: 0,
        #[cfg(feature = "extended-metrics")]
        bm25_selection_count: std::sync::atomic::AtomicU64::default(),
        #[cfg(feature = "extended-metrics")]
        bm25_fallback_count: std::sync::atomic::AtomicU64::default(),
    }
}
/// Picks one negative for `anchor` out of `pool`.
///
/// The pool is intersected with the BM25 ranking for the anchor (see
/// `ranked_candidates`). If any pool member is ranked, the first
/// `BM25_HARD_NEGATIVE_ROTATION_TOP_K` (at least one) ranked members form a
/// rotation, and a per-`(anchor id, split)` cursor walks through it on
/// successive calls. Otherwise one pool member is drawn with `rng` from the
/// pool ordered by record id.
///
/// Returns `None` only when `pool` is empty. `fallback_used` is passed
/// through unchanged as the second tuple element.
fn select_hard_negative(
    &self,
    anchor: &DataRecord,
    anchor_split: SplitLabel,
    pool: &[Arc<DataRecord>],
    fallback_used: bool,
    anchor_query_text: Option<&str>,
    rng: &mut dyn rand::RngCore,
) -> Option<(Arc<DataRecord>, bool)> {
    if pool.is_empty() {
        return None;
    }
    #[cfg(feature = "extended-metrics")]
    {
        self.bm25_selection_count
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    }

    let pool_by_id: HashMap<&str, &Arc<DataRecord>> =
        pool.iter().map(|r| (r.id.as_str(), r)).collect();
    let candidate_ids = self.ranked_candidates(anchor, anchor_split, anchor_query_text);

    // The cursor never leaves `0..top_k`, so only the first `rotation_limit`
    // ranked pool members are reachable. Borrow just those instead of
    // cloning every matching `Arc`.
    let rotation_limit = BM25_HARD_NEGATIVE_ROTATION_TOP_K.max(1);
    let rotation: Vec<&Arc<DataRecord>> = candidate_ids
        .iter()
        .filter_map(|id| pool_by_id.get(id.as_str()).copied())
        .take(rotation_limit)
        .collect();

    if !rotation.is_empty() {
        let top_k = rotation.len();
        let selected = {
            // Keep the write lock only for the cursor update.
            let mut cursors = self.negative_cursors.write().unwrap();
            let cursor = cursors.entry((anchor.id.clone(), anchor_split)).or_insert(0);
            // A stored cursor can exceed the current rotation length if the
            // pool or ranking shrank since the previous call.
            if *cursor >= top_k {
                *cursor = 0;
            }
            let picked = rotation.get(*cursor).copied().cloned();
            *cursor = (*cursor + 1) % top_k;
            picked
        };
        return selected.map(|record| (record, fallback_used));
    }

    #[cfg(feature = "extended-metrics")]
    {
        self.bm25_fallback_count
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    }

    // No ranked candidate is in the pool. `pool` is non-empty here, so the
    // range below is never empty. Sorting by id makes the draw depend only on
    // the pool's contents and the RNG, not on the caller's ordering.
    let mut fallback: Vec<&Arc<DataRecord>> = pool.iter().collect();
    fallback.sort_by(|a, b| a.id.cmp(&b.id));
    let idx = {
        use rand::Rng as _;
        rng.random_range(0..fallback.len())
    };
    fallback
        .get(idx)
        .copied()
        .map(|record| (Arc::clone(record), fallback_used))
}
/// Returns candidate record ids for `anchor`, best BM25 match first.
///
/// * With `anchor_query_text == None` the query is the anchor's own BM25
///   text, and the result is cached per anchor id (an empty list is cached
///   too when the anchor's source has no index).
/// * With explicit query text the cache is neither read nor written, and the
///   query is cut to the first `BM25_QUERY_TOKEN_LIMIT` whitespace tokens.
///
/// The top `BM25_SEARCH_TOP_K` hits from the anchor's source index are kept
/// only if they are not the anchor itself and their recorded split equals
/// `anchor_split`. Results are ordered by descending score, ties broken by
/// ascending record id.
///
/// NOTE(review): the cache key is the anchor id alone, so a cached list is
/// reused regardless of `anchor_split`. This presumably relies on each record
/// belonging to one split — confirm against callers.
fn ranked_candidates(
    &self,
    anchor: &DataRecord,
    anchor_split: SplitLabel,
    anchor_query_text: Option<&str>,
) -> Vec<RecordId> {
    let cacheable = anchor_query_text.is_none();
    if cacheable {
        // The read guard is a temporary of this statement, so it is released
        // before any write lock is taken below.
        let cached = self.hard_negatives.read().unwrap().get(&anchor.id).cloned();
        if let Some(ids) = cached {
            return ids;
        }
    }

    let Some(index) = self.source_indexes.get(anchor.source.as_str()) else {
        if cacheable {
            self.hard_negatives
                .write()
                .unwrap()
                .insert(anchor.id.clone(), Vec::new());
        }
        return Vec::new();
    };

    // Deferred-initialised owners so the query can borrow from whichever one
    // ends up holding the text.
    let owned_text: String;
    let query_limited: String;
    let bm25_query_text: &str = if let Some(text) = anchor_query_text {
        let tokens: Vec<&str> = WhitespaceTokenizer.tokenize(text);
        if tokens.len() <= BM25_QUERY_TOKEN_LIMIT {
            text
        } else {
            query_limited = tokens[..BM25_QUERY_TOKEN_LIMIT].join(" ");
            &query_limited
        }
    } else {
        owned_text = record_bm25_text(anchor, self.max_window_tokens);
        &owned_text
    };

    let mut all_scored: Vec<(f32, RecordId)> = index
        .search_engine
        .search(bm25_query_text, BM25_SEARCH_TOP_K)
        .into_iter()
        .filter_map(|result| {
            let m = index.meta.get(result.document.id)?;
            if m.record_id == anchor.id || m.split != Some(anchor_split) {
                return None;
            }
            Some((result.score, m.record_id.clone()))
        })
        .collect();

    // `total_cmp` is a genuine total order over f32, including NaN.
    // `partial_cmp(..).unwrap_or(Equal)` is not, and `sort_by` may panic when
    // the comparator is inconsistent.
    all_scored.sort_by(|a, b| b.0.total_cmp(&a.0).then_with(|| a.1.cmp(&b.1)));

    let ranked: Vec<RecordId> = all_scored.into_iter().map(|(_, id)| id).collect();
    if cacheable {
        self.hard_negatives
            .write()
            .unwrap()
            .insert(anchor.id.clone(), ranked.clone());
    }
    ranked
}
/// Replaces the BM25 index of `source_id` with one built from `source_records`.
///
/// A source with fewer than two records gets no index; any existing index
/// for it is removed. Otherwise each record becomes one document whose id is
/// its position in `source_records`, and its split is looked up via
/// `split_fn` at build time.
fn rebuild_source_index(
    &mut self,
    source_id: &SourceId,
    source_records: &[&DataRecord],
    split_fn: &dyn Fn(&RecordId) -> Option<SplitLabel>,
) {
    if source_records.len() < 2 {
        self.source_indexes.remove(source_id);
        return;
    }

    let window = self.max_window_tokens;
    // Metadata and documents are produced pairwise so that a document id
    // always equals the index of its metadata entry.
    let (meta, docs): (Vec<Bm25RecordMeta>, Vec<Document<usize>>) = source_records
        .iter()
        .enumerate()
        .map(|(position, record)| {
            let split = split_fn(&record.id);
            let entry = Bm25RecordMeta {
                record_id: record.id.clone(),
                split,
            };
            let doc = Document {
                id: position,
                contents: record_bm25_text(record, window),
            };
            (entry, doc)
        })
        .unzip();

    let index = PerSourceBm25Index {
        meta,
        search_engine: SearchEngineBuilder::<usize>::with_documents(Language::English, docs)
            .build(),
    };
    self.source_indexes.insert(source_id.clone(), index);
}
}
impl NegativeBackend for Bm25Backend {
/// Chooses a negative for `anchor` from `pool` by delegating to
/// `select_hard_negative`; all arguments are forwarded unchanged.
fn choose_negative(
    &self,
    anchor: &DataRecord,
    anchor_split: SplitLabel,
    pool: Vec<Arc<DataRecord>>,
    fallback_used: bool,
    anchor_query_text: Option<&str>,
    rng: &mut dyn rand::RngCore,
) -> Option<(Arc<DataRecord>, bool)> {
    // The pool arrives owned, but selection only needs to borrow it.
    let candidates: &[Arc<DataRecord>] = pool.as_slice();
    self.select_hard_negative(
        anchor,
        anchor_split,
        candidates,
        fallback_used,
        anchor_query_text,
        rng,
    )
}
/// Drops every rotation cursor at the start of a sync.
fn on_sync_start(&mut self) {
    // `&mut self` already guarantees exclusive access, so the map is reached
    // without taking the lock; a poisoned lock still panics as before.
    let cursors = self.negative_cursors.get_mut().unwrap();
    cursors.clear();
}
/// Reacts to a refresh of the sources listed in `refreshed_source_ids`.
///
/// Does nothing when the list is empty. Otherwise it:
/// 1. stores the new `max_window_tokens`;
/// 2. drops cached rankings whose anchor is gone from `records` or belongs
///    to a refreshed source;
/// 3. rebuilds the index of every refreshed source from the current records;
/// 4. removes indexes for sources that no longer have any record.
fn on_records_refreshed(
    &mut self,
    records: &IndexMap<RecordId, Arc<DataRecord>>,
    max_window_tokens: usize,
    split_fn: &dyn Fn(&RecordId) -> Option<SplitLabel>,
    refreshed_source_ids: &[SourceId],
) {
    if refreshed_source_ids.is_empty() {
        return;
    }
    self.max_window_tokens = max_window_tokens;

    let refreshed: HashSet<&str> = refreshed_source_ids.iter().map(|s| s.as_str()).collect();
    // Exclusive access through `&mut self`: no need to take the lock.
    self.hard_negatives.get_mut().unwrap().retain(|anchor_id, _| {
        records
            .get(anchor_id)
            .is_some_and(|record| !refreshed.contains(record.source.as_str()))
    });

    // Group the live records by source once; reused for rebuilding and for
    // the final pruning step.
    let mut by_source: HashMap<&str, Vec<&DataRecord>> = HashMap::new();
    for record in records.values() {
        by_source
            .entry(record.source.as_str())
            .or_default()
            .push(record.as_ref());
    }

    for source_id in refreshed_source_ids {
        // A refreshed source without records yields an empty slice, which
        // makes `rebuild_source_index` drop its index.
        let members: &[&DataRecord] = by_source
            .get(source_id.as_str())
            .map(Vec::as_slice)
            .unwrap_or_default();
        self.rebuild_source_index(source_id, members, split_fn);
    }

    // The grouping map's keys are exactly the sources that still have records.
    self.source_indexes
        .retain(|source_id, _| by_source.contains_key(source_id.as_str()));
}
/// Drops rotation cursors and cached rankings whose anchor id is not in `valid_ids`.
fn prune_cursors(&mut self, valid_ids: &HashSet<RecordId>) {
    // `&mut self` gives exclusive access, so both maps are reached without locking.
    let cursors = self.negative_cursors.get_mut().unwrap();
    cursors.retain(|key, _| valid_ids.contains(&key.0));

    let cache = self.hard_negatives.get_mut().unwrap();
    cache.retain(|anchor_id, _| valid_ids.contains(anchor_id));
}
/// Reports whether no rotation cursor is currently stored.
fn cursors_empty(&self) -> bool {
    let cursors = self.negative_cursors.read().unwrap();
    cursors.is_empty()
}
/// Returns `(fallback_count, selection_count)` as currently recorded by the
/// metrics counters.
#[cfg(all(feature = "bm25-mining", feature = "extended-metrics"))]
fn bm25_fallback_stats(&self) -> (u64, u64) {
    use std::sync::atomic::Ordering::Relaxed;

    let fallbacks = self.bm25_fallback_count.load(Relaxed);
    let selections = self.bm25_selection_count.load(Relaxed);
    (fallbacks, selections)
}
/// Test-only: exposes the backend as `&mut dyn Any` (presumably so tests can
/// downcast a trait object back to `Bm25Backend` — confirm against callers).
#[cfg(test)]
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
}
/// Test-only accessors for the backend's internal state.
impl Bm25Backend {
    /// Runs `ranked_candidates` with no explicit query text (the cached path).
    #[cfg(test)]
    pub(in crate::sampler) fn ranked_candidates_pub(
        &self,
        anchor: &DataRecord,
        anchor_split: SplitLabel,
    ) -> Vec<RecordId> {
        let no_query_override: Option<&str> = None;
        self.ranked_candidates(anchor, anchor_split, no_query_override)
    }

    /// Returns a copy of the cached ranking for `anchor_id`, if one exists.
    #[cfg(test)]
    pub(in crate::sampler) fn hard_negatives_get(
        &self,
        anchor_id: &RecordId,
    ) -> Option<Vec<RecordId>> {
        let cache = self.hard_negatives.read().unwrap();
        cache.get(anchor_id).cloned()
    }

    /// Lists the record ids held in every source index, sources visited in
    /// ascending source-id order. Returns `None` when there is no index at all.
    #[cfg(test)]
    pub(in crate::sampler) fn index_meta_record_ids(&self) -> Option<Vec<RecordId>> {
        if self.source_indexes.is_empty() {
            return None;
        }
        let mut by_source: Vec<(&SourceId, &PerSourceBm25Index)> =
            self.source_indexes.iter().collect();
        by_source.sort_by(|left, right| left.0.cmp(right.0));
        let ids = by_source
            .into_iter()
            .flat_map(|(_, index)| index.meta.iter().map(|m| m.record_id.clone()))
            .collect();
        Some(ids)
    }

    /// Number of stored rotation cursors.
    #[cfg(test)]
    pub(in crate::sampler) fn negative_cursors_len(&self) -> usize {
        let cursors = self.negative_cursors.read().unwrap();
        cursors.len()
    }

    /// Whether no rotation cursor is stored.
    #[cfg(test)]
    pub(in crate::sampler) fn negative_cursors_is_empty(&self) -> bool {
        let cursors = self.negative_cursors.read().unwrap();
        cursors.is_empty()
    }

    /// Sets the rotation cursor for `key` to `value`, replacing any previous one.
    #[cfg(test)]
    pub(in crate::sampler) fn negative_cursors_insert(
        &self,
        key: (RecordId, SplitLabel),
        value: usize,
    ) {
        let mut cursors = self.negative_cursors.write().unwrap();
        cursors.insert(key, value);
    }
}
/// Builds the text used to index or query `record` with BM25.
///
/// Every section contributes its heading (when present and not blank) and
/// its text (when not blank), each followed by the platform newline. If that
/// yields nothing but whitespace, the record id is returned instead.
///
/// `max_tokens == 0` disables truncation. Otherwise, when the text has more
/// than `max_tokens` whitespace tokens, only the first `max_tokens` are kept
/// and re-joined with single spaces, so the original line structure is lost
/// in the truncated form.
pub(in crate::sampler) fn record_bm25_text(record: &DataRecord, max_tokens: usize) -> String {
    let mut out = String::new();
    for section in &record.sections {
        if let Some(heading) = &section.heading
            && !heading.trim().is_empty()
        {
            out.push_str(heading);
            out.push_str(platform_newline());
        }
        if !section.text.trim().is_empty() {
            out.push_str(&section.text);
            out.push_str(platform_newline());
        }
    }

    if out.trim().is_empty() {
        // Nothing indexable in the sections: fall back to the id.
        return record.id.clone();
    }
    if max_tokens == 0 {
        return out;
    }

    let mut tokens: Vec<&str> = WhitespaceTokenizer.tokenize(&out);
    if tokens.len() <= max_tokens {
        return out;
    }
    tokens.truncate(max_tokens);
    tokens.join(" ")
}