use chrono::{DateTime, Utc};
use std::hash::Hash;
use std::sync::Arc;
use crate::config::TripletRecipe;
use crate::data::DataRecord;
use crate::errors::SamplerError;
use crate::hash::stable_hash_with;
use crate::types::SourceId;
pub mod date_helpers;
pub mod file_corpus;
pub(crate) mod grouping;
/// Resume point produced by [`DataSource::refresh`] and handed back on the next call.
#[derive(Debug, Clone)]
pub struct SourceCursor {
    /// Timestamp attached to the refresh. The sources in this module store the newest
    /// `updated_at` they returned, falling back to the current time.
    pub last_seen: DateTime<Utc>,
    /// Source-defined position to resume from.
    pub revision: u64,
}
/// Result of one [`DataSource::refresh`] call.
#[derive(Debug, Clone)]
pub struct SourceSnapshot {
    /// Records produced by this refresh.
    pub records: Vec<DataRecord>,
    /// Cursor to pass to the next refresh.
    pub cursor: SourceCursor,
}
/// A refreshable provider of [`DataRecord`]s.
pub trait DataSource: Send + Sync {
    /// Identifier of this source.
    fn id(&self) -> &str;

    /// Fetches a batch of records and returns the cursor for the next call.
    ///
    /// The batch optionally resumes from `cursor` and is capped at `limit`.
    ///
    /// # Errors
    ///
    /// Implementation-defined; a [`SamplerError`] is returned when the batch
    /// cannot be produced.
    fn refresh(
        &self,
        cursor: Option<&SourceCursor>,
        limit: Option<usize>,
    ) -> Result<SourceSnapshot, SamplerError>;

    /// Record count the source reports about itself, if it has one.
    ///
    /// Defaults to `None`.
    fn reported_record_count(&self) -> Option<u128> {
        None
    }

    /// Triplet recipes suggested by this source; empty unless overridden.
    fn default_triplet_recipes(&self) -> Vec<TripletRecipe> {
        vec![]
    }
}
/// A source whose records can be fetched by position, for use with [`IndexablePager`].
pub trait IndexableSource: Send + Sync {
    /// Identifier of this source.
    fn id(&self) -> &str;

    /// Number of addressable positions, if known.
    ///
    /// [`IndexablePager::refresh`] requires `Some` and reports an error otherwise.
    fn len_hint(&self) -> Option<usize>;

    /// Record at position `idx`, or `Ok(None)` when the position holds nothing.
    ///
    /// The pager skips positions that return `Ok(None)`.
    fn record_at(&self, idx: usize) -> Result<Option<DataRecord>, SamplerError>;
}
/// Pages an [`IndexableSource`] in a shuffled order.
///
/// The order is derived from the source id and the record count.
pub struct IndexablePager {
    /// Keys the shuffle seed (see `seed_for`).
    source_id: SourceId,
}
impl IndexablePager {
    /// Creates a pager whose shuffle order is keyed by `source_id`.
    pub fn new(source_id: impl Into<SourceId>) -> Self {
        Self {
            source_id: source_id.into(),
        }
    }

    /// Pages through `source`, using `source.len_hint()` as the index range.
    ///
    /// See [`IndexablePager::refresh_with`] for the ordering and cursor semantics.
    ///
    /// # Errors
    ///
    /// Returns [`SamplerError::SourceInconsistent`] when `len_hint` is `None`.
    /// Also propagates any error from [`IndexableSource::record_at`].
    pub fn refresh(
        &self,
        source: &dyn IndexableSource,
        cursor: Option<&SourceCursor>,
        limit: Option<usize>,
    ) -> Result<SourceSnapshot, SamplerError> {
        let total = source
            .len_hint()
            .ok_or_else(|| SamplerError::SourceInconsistent {
                source_id: source.id().to_string(),
                details: "indexable source did not provide len_hint".into(),
            })?;
        self.refresh_with(total, cursor, limit, |idx| source.record_at(idx))
    }

    /// Fetches records by walking a seeded pseudo-random permutation of `0..total`.
    ///
    /// Returns up to `limit` records, or up to `total` when `limit` is `None`,
    /// resuming from `cursor`.
    ///
    /// * **Visit bound.** Each call visits at most `total` indices, so a single
    ///   call never visits the same index twice. Indices for which `fetch`
    ///   returns `Ok(None)` are skipped but still count toward that bound.
    /// * **Cursor revision.** The returned `revision` is a position in the
    ///   permutation's power-of-two domain, which can exceed `total`. Passing it
    ///   back with the same `total` continues exactly where this call stopped.
    ///   A revision outside the domain (e.g. from a differently sized source)
    ///   restarts at position 0.
    /// * **Last seen.** `last_seen` is the newest `updated_at` among the returned
    ///   records, or the current time when nothing was returned.
    /// * **Empty source.** When `total == 0`, `fetch` is never called and the
    ///   cursor is reset to 0.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by `fetch`.
    pub fn refresh_with(
        &self,
        total: usize,
        cursor: Option<&SourceCursor>,
        limit: Option<usize>,
        mut fetch: impl FnMut(usize) -> Result<Option<DataRecord>, SamplerError>,
    ) -> Result<SourceSnapshot, SamplerError> {
        if total == 0 {
            return Ok(SourceSnapshot {
                records: Vec::new(),
                cursor: SourceCursor {
                    last_seen: Utc::now(),
                    revision: 0,
                },
            });
        }
        let max = limit.unwrap_or(total);
        let seed = Self::seed_for(&self.source_id, total);
        // The cursor is a permutation position, not an index into `0..total`.
        // Positions in `total..domain_size` are legitimate resume points, so
        // they must not be reset here.
        let start = cursor.map_or(0, |cursor| cursor.revision);
        let mut permutation = IndexPermutation::resume(total, seed, start);
        let mut records = Vec::with_capacity(max.min(total));
        for _ in 0..total {
            if records.len() >= max {
                break;
            }
            let idx = permutation.next();
            if let Some(record) = fetch(idx)? {
                records.push(record);
            }
        }
        let last_seen = records
            .iter()
            .map(|record| record.updated_at)
            .max()
            .unwrap_or_else(Utc::now);
        Ok(SourceSnapshot {
            records,
            cursor: SourceCursor {
                last_seen,
                revision: permutation.cursor() as u64,
            },
        })
    }

    /// Derives the shuffle seed from `source_id` and `total`.
    ///
    /// Because `total` is an input, the order changes whenever the record count changes.
    pub(crate) fn seed_for(source_id: &SourceId, total: usize) -> u64 {
        Self::stable_index_shuffle_key(source_id, 0)
            ^ Self::stable_index_shuffle_key(source_id, total)
    }

    /// Hashes the pair `(source_id, idx)` with the crate's `stable_hash_with`.
    fn stable_index_shuffle_key(source_id: &SourceId, idx: usize) -> u64 {
        stable_hash_with(|hasher| {
            source_id.hash(hasher);
            idx.hash(hasher);
        })
    }
}

/// Exposes an [`IndexableSource`] as a [`DataSource`].
///
/// Refreshes go through an [`IndexablePager`] keyed by the inner source's id.
pub struct IndexableAdapter<T: IndexableSource> {
    inner: T,
}

impl<T: IndexableSource> IndexableAdapter<T> {
    /// Wraps `inner`.
    pub fn new(inner: T) -> Self {
        Self { inner }
    }
}

impl<T: IndexableSource> DataSource for IndexableAdapter<T> {
    fn id(&self) -> &str {
        self.inner.id()
    }

    fn refresh(
        &self,
        cursor: Option<&SourceCursor>,
        limit: Option<usize>,
    ) -> Result<SourceSnapshot, SamplerError> {
        let pager = IndexablePager::new(self.inner.id());
        pager.refresh(&self.inner, cursor, limit)
    }
}

/// Seeded permutation of `0..total`.
///
/// It is built from an affine bijection on a power-of-two domain that covers
/// `total`; outputs `>= total` are skipped (cycle-walking).
pub(crate) struct IndexPermutation {
    /// Number of valid outputs; every yielded value is `< total`.
    total: u64,
    /// log2 of `domain_size` (at least 1).
    domain_bits: u32,
    /// `2^domain_bits`: the smallest power of two `>= total`, but at least 2.
    /// This is also the period of the position sequence.
    domain_size: u64,
    /// Selects the multiplier and offset of the affine map.
    seed: u64,
    /// Next domain position to evaluate; only its value modulo `domain_size` matters.
    counter: u64,
}

impl IndexPermutation {
    /// Creates a permutation of `0..total` keyed by `seed`, starting at domain position `counter`.
    ///
    /// `total` must be non-zero (the pager guarantees this). Otherwise `total - 1`
    /// underflows, and `next` could never find an in-range value.
    /// NOTE(review): `total > 2^63` would need a 2^64 domain, which `1u64 << 64`
    /// cannot represent. This is unreachable for realistic index counts.
    fn new(total: usize, seed: u64, counter: u64) -> Self {
        let total_u64 = total as u64;
        let domain_bits = (64 - (total_u64 - 1).leading_zeros()).max(1);
        let domain_size = 1u64 << domain_bits;
        Self {
            total: total_u64,
            domain_bits,
            domain_size,
            seed,
            counter,
        }
    }

    /// Like [`IndexPermutation::new`], but `position` is a value previously returned by `cursor`.
    ///
    /// Positions inside the domain resume as-is; anything outside restarts at 0.
    fn resume(total: usize, seed: u64, position: u64) -> Self {
        let mut permutation = Self::new(total, seed, 0);
        if position < permutation.domain_size {
            permutation.counter = position;
        }
        permutation
    }

    /// Returns the next index in `0..total`, skipping domain values `>= total`.
    ///
    /// Requires `total > 0`; otherwise this never returns.
    fn next(&mut self) -> usize {
        loop {
            let v =
                Self::permute_bits(self.counter % self.domain_size, self.domain_bits, self.seed);
            self.counter = self.counter.wrapping_add(1);
            if v < self.total {
                return v as usize;
            }
        }
    }

    /// Current domain position, normalised into `0..domain_size`.
    ///
    /// The reduction must be modulo `domain_size` (the period of the position
    /// sequence), not modulo `total`. Cycle-walking advances the counter past
    /// skipped values, so the counter can legitimately sit in `total..domain_size`.
    /// Reducing it modulo `total` would resume from an unrelated position, which
    /// re-serves some records and never reaches others.
    fn cursor(&self) -> usize {
        (self.counter % self.domain_size) as usize
    }

    /// Computes the affine map `(a * value + b) mod 2^bits`, with `a` and `b` taken from `seed`.
    fn permute_bits(value: u64, bits: u32, seed: u64) -> u64 {
        if bits == 0 {
            return 0;
        }
        let mask = if bits == 64 {
            u64::MAX
        } else {
            (1u64 << bits) - 1
        };
        // An odd multiplier is invertible modulo 2^bits, which makes the map a
        // bijection on the domain. Because `bits >= 1`, bit 0 survives the mask,
        // so `a` is odd and never zero.
        let a = (seed | 1) & mask;
        let b = (seed >> 1) & mask;
        a.wrapping_mul(value).wrapping_add(b) & mask
    }
}
/// A [`DataSource`] backed by a fixed list of records held in memory.
pub struct InMemorySource {
    id: SourceId,
    records: Arc<Vec<DataRecord>>,
}

impl InMemorySource {
    /// Wraps `records` under the source identifier `id`.
    pub fn new(id: impl Into<SourceId>, records: Vec<DataRecord>) -> Self {
        let id = id.into();
        let records = Arc::new(records);
        Self { id, records }
    }
}
impl DataSource for InMemorySource {
fn id(&self) -> &str {
&self.id
}
fn refresh(
&self,
cursor: Option<&SourceCursor>,
limit: Option<usize>,
) -> Result<SourceSnapshot, SamplerError> {
let records = &*self.records;
let total = records.len();
let mut start = cursor.map(|cursor| cursor.revision as usize).unwrap_or(0);
if total > 0 && start >= total {
start = 0;
}
let max = limit.unwrap_or(total);
let mut filtered = Vec::new();
for idx in 0..total {
if filtered.len() >= max {
break;
}
let pos = (start + idx) % total;
filtered.push(records[pos].clone());
}
let last_seen = filtered
.iter()
.map(|record| record.updated_at)
.max()
.unwrap_or_else(Utc::now);
let next_start = if total == 0 {
0
} else {
(start + filtered.len()) % total
};
Ok(SourceSnapshot {
records: filtered,
cursor: SourceCursor {
last_seen,
revision: next_start as u64,
},
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::{QualityScore, RecordSection, SectionRole};
    use crate::types::RecordId;

    /// [`IndexableSource`] that synthesises `count` identical records named `record_{idx}`.
    struct IndexableStub {
        id: SourceId,
        count: usize,
    }

    impl IndexableStub {
        fn new(id: &str, count: usize) -> Self {
            let id = id.to_string();
            Self { id, count }
        }
    }

    impl IndexableSource for IndexableStub {
        fn id(&self) -> &str {
            &self.id
        }

        fn len_hint(&self) -> Option<usize> {
            Some(self.count)
        }

        fn record_at(&self, idx: usize) -> Result<Option<DataRecord>, SamplerError> {
            if idx < self.count {
                let now = Utc::now();
                let anchor = RecordSection {
                    role: SectionRole::Anchor,
                    heading: None,
                    text: "stub".into(),
                    sentences: vec!["stub".into()],
                };
                Ok(Some(DataRecord {
                    id: format!("record_{idx}"),
                    source: self.id.clone(),
                    created_at: now,
                    updated_at: now,
                    quality: QualityScore { trust: 1.0 },
                    taxonomy: Vec::new(),
                    sections: vec![anchor],
                    meta_prefix: None,
                }))
            } else {
                Ok(None)
            }
        }
    }

    /// Parses the numeric suffix of a stub record id (`record_{idx}`).
    fn stub_index(id: &str) -> usize {
        id.strip_prefix("record_")
            .unwrap()
            .parse::<usize>()
            .unwrap()
    }

    #[test]
    fn indexable_adapter_pages_in_stable_order() {
        let adapter = IndexableAdapter::new(IndexableStub::new("stub", 6));
        let full_ids: Vec<RecordId> = adapter
            .refresh(None, None)
            .unwrap()
            .records
            .into_iter()
            .map(|record| record.id)
            .collect();

        let mut paged: Vec<RecordId> = Vec::new();
        let mut cursor: Option<SourceCursor> = None;
        for _ in 0..3 {
            let page = adapter.refresh(cursor.as_ref(), Some(2)).unwrap();
            paged.extend(page.records.into_iter().map(|record| record.id));
            cursor = Some(page.cursor);
        }
        assert_eq!(paged, full_ids);
    }

    #[test]
    fn indexable_paging_spans_multiple_regimes() {
        let total = 256usize;
        let domain_bits = 64 - (total as u64 - 1).leading_zeros();
        let mask = (1u64 << domain_bits) - 1;
        // Reject ids whose affine multiplier is +1 or -1 modulo the domain.
        // Those walk the indices contiguously and could never show any spread.
        let source_id = (0..512)
            .map(|n| format!("regime_test_{n}"))
            .find(|candidate| {
                let multiplier = (IndexablePager::seed_for(candidate, total) | 1) & mask;
                multiplier != 1 && multiplier != mask
            })
            .unwrap();

        let adapter = IndexableAdapter::new(IndexableStub::new(&source_id, total));
        let snapshot = adapter.refresh(None, Some(64)).unwrap();
        let indices: Vec<usize> = snapshot
            .records
            .iter()
            .map(|record| stub_index(&record.id))
            .collect();
        let min_idx = indices.iter().copied().min().unwrap();
        let max_idx = indices.iter().copied().max().unwrap();
        assert!(
            max_idx - min_idx >= total / 2,
            "expected spread across the index space, got min={min_idx} max={max_idx}"
        );
    }
}