use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, OnceLock};
use std::thread::{self, JoinHandle};
use std::vec::IntoIter;
use crossbeam_channel::{bounded, Receiver, Sender};
use parking_lot::Mutex;
use sysinfo::System;
use super::CorpusReader;
/// Lazily-created, process-wide `sysinfo` handle shared behind a mutex.
static SYSTEM: OnceLock<Mutex<System>> = OnceLock::new();
/// Returns the shared [`System`] handle, constructing it on the first call.
fn get_system() -> &'static Mutex<System> {
    fn create() -> Mutex<System> {
        Mutex::new(System::new())
    }
    SYSTEM.get_or_init(create)
}
/// Tuning knobs for [`PrefetchingReader`].
#[derive(Clone, Debug)]
pub struct PrefetchConfig {
    /// Number of sentences the producer thread groups into one batch before sending it.
    pub batch_size: usize,
    /// Channel capacity in batches when `auto_tune` is `false`; clamped to `2..=64` when
    /// the reader is built.
    pub buffer_batches: usize,
    /// When `true`, the channel capacity is derived from available memory and
    /// `buffer_batches` is ignored.
    pub auto_tune: bool,
    /// Fraction of available memory the buffer may target when auto-tuning
    /// (`with_ram_fraction` clamps it to `0.01..=0.50`).
    pub ram_fraction: f64,
}
impl Default for PrefetchConfig {
    /// 10 000-sentence batches, 8 buffered batches, auto-tuning enabled, 10 % of
    /// available memory.
    fn default() -> Self {
        let batch_size = 10_000;
        let buffer_batches = 8;
        let auto_tune = true;
        let ram_fraction = 0.10;
        Self {
            batch_size,
            buffer_batches,
            auto_tune,
            ram_fraction,
        }
    }
}
impl PrefetchConfig {
    /// Creates a configuration holding the [`Default`] values.
    pub fn new() -> Self {
        Default::default()
    }
    /// Sets how many sentences are grouped into each prefetched batch.
    pub fn with_batch_size(self, size: usize) -> Self {
        Self {
            batch_size: size,
            ..self
        }
    }
    /// Sets the channel capacity (in batches) used when auto-tuning is disabled.
    pub fn with_buffer_batches(self, batches: usize) -> Self {
        Self {
            buffer_batches: batches,
            ..self
        }
    }
    /// Enables or disables memory-based sizing of the channel capacity.
    pub fn with_auto_tune(self, enabled: bool) -> Self {
        Self {
            auto_tune: enabled,
            ..self
        }
    }
    /// Sets the fraction of available memory targeted when auto-tuning, clamped to
    /// `0.01..=0.50`.
    pub fn with_ram_fraction(self, fraction: f64) -> Self {
        Self {
            ram_fraction: fraction.clamp(0.01, 0.50),
            ..self
        }
    }
    /// Channel capacity actually used: computed from available memory when auto-tuning,
    /// otherwise `buffer_batches` clamped to `2..=64`.
    fn effective_buffer_batches(&self) -> usize {
        match self.auto_tune {
            true => compute_buffer_batches(self.batch_size, self.ram_fraction),
            false => self.buffer_batches.clamp(2, 64),
        }
    }
}
/// Memory figure assumed when `sysinfo` reports a total of zero: 8 GiB, saturated to
/// `usize::MAX` on targets where that does not fit.
///
/// The arithmetic is done in `u64` so the constant also compiles on 32-bit targets, where
/// `8 * 1024 * 1024 * 1024` overflows `usize` and fails const evaluation.
const DEFAULT_FALLBACK_MEMORY: usize = {
    const EIGHT_GIB: u64 = 8 * 1024 * 1024 * 1024;
    if EIGHT_GIB > usize::MAX as u64 {
        usize::MAX
    } else {
        EIGHT_GIB as usize
    }
};
/// Returns total minus used memory as reported by `sysinfo`, or
/// `DEFAULT_FALLBACK_MEMORY` when the reported total is zero.
///
/// NOTE(review): assumes the pinned `sysinfo` version reports these figures in bytes
/// (as the function name implies) — confirm against the dependency version.
fn get_available_memory_bytes() -> usize {
    // Hold the shared lock only for the refresh and the two reads.
    let (total, used) = {
        let mut sys = get_system().lock();
        sys.refresh_memory();
        (sys.total_memory(), sys.used_memory())
    };
    if total == 0 {
        return DEFAULT_FALLBACK_MEMORY;
    }
    // Subtract in u64 and saturate into usize. Casting each value with `as` first would
    // silently truncate on 32-bit targets; after `min` the cast is lossless.
    total.saturating_sub(used).min(usize::MAX as u64) as usize
}
/// Derives a channel capacity (in batches) from currently available memory.
///
/// Targets `ram_fraction` of the bytes reported by `get_available_memory_bytes`, sizes
/// each batch as `batch_size` sentences of a fixed estimated size, and clamps the
/// resulting batch count to `2..=32`.
fn compute_buffer_batches(batch_size: usize, ram_fraction: f64) -> usize {
    // Heuristic per-sentence size used to turn a byte budget into a batch count; this is
    // an estimate, not a measurement.
    const ESTIMATED_BYTES_PER_SENTENCE: usize = 100;
    let available = get_available_memory_bytes();
    // Float-to-int `as` saturates (negative/NaN -> 0), so any fraction yields a valid count.
    let target_bytes = (available as f64 * ram_fraction) as usize;
    // Saturate so an enormous `batch_size` cannot overflow (a panic in debug builds and a
    // wrapped, meaningless value in release builds). `max(1)` guards `batch_size == 0`.
    let bytes_per_batch = batch_size
        .saturating_mul(ESTIMATED_BYTES_PER_SENTENCE)
        .max(1);
    (target_bytes / bytes_per_batch).clamp(2, 32)
}
/// Messages sent from the producer thread to the consuming iterator.
enum PrefetchMessage {
    /// A batch of sentences: normally `batch_size` long; the final batch may be shorter.
    Batch(Vec<String>),
    /// The reader was exhausted or a stop was requested; no further batches follow.
    Done,
    /// The producer panicked; carries the panic message.
    Error(String),
}
/// Sentence iterator that reads a [`CorpusReader`] on a background thread and receives
/// the sentences in batches over a bounded channel.
///
/// Unless [`PrefetchingReader::batches`] has taken over the channel, dropping the reader
/// sets the stop signal and joins the producer thread.
pub struct PrefetchingReader {
    /// Receiving end of the channel; `None` once `batches()` has taken it.
    rx: Option<Receiver<PrefetchMessage>>,
    /// Flag shared with the producer; set to request early termination.
    stop_signal: Arc<AtomicBool>,
    /// Producer thread handle, joined on drop or handed to the batch iterator.
    producer_handle: Option<JoinHandle<()>>,
    /// Remaining sentences of the batch currently being yielded.
    current_batch: Option<IntoIter<String>>,
    /// Latched once `Done`, a producer error, or a disconnected channel is observed.
    exhausted: bool,
    /// Number of `Batch` messages received so far.
    batches_received: usize,
    /// Number of sentences returned from `next()` so far.
    sentences_yielded: usize,
}
impl PrefetchingReader {
    /// Starts prefetching from `reader` using [`PrefetchConfig::default`].
    pub fn new<R>(reader: R) -> Self
    where
        R: CorpusReader + 'static,
    {
        Self::with_config(reader, PrefetchConfig::default())
    }

    /// Starts prefetching from `reader` using `config`.
    ///
    /// Spawns a producer thread named `prefetch-producer`. That thread sends batches of
    /// `config.batch_size` sentences over a channel whose capacity is the config's
    /// effective buffer size.
    ///
    /// # Panics
    ///
    /// Panics if the OS fails to spawn the producer thread, as `std::thread::spawn` would.
    pub fn with_config<R>(reader: R, config: PrefetchConfig) -> Self
    where
        R: CorpusReader + 'static,
    {
        let buffer_batches = config.effective_buffer_batches();
        let batch_size = config.batch_size;
        log::debug!(
            "PrefetchingReader: batch_size={}, buffer_batches={}, ram_fraction={:.1}%",
            batch_size,
            buffer_batches,
            config.ram_fraction * 100.0
        );
        let (tx, rx) = bounded::<PrefetchMessage>(buffer_batches);
        let stop_signal = Arc::new(AtomicBool::new(false));
        let stop_clone = Arc::clone(&stop_signal);
        // A named thread is easier to identify in panic messages, debuggers and profilers.
        let producer_handle = thread::Builder::new()
            .name("prefetch-producer".to_string())
            .spawn(move || producer_loop(reader, tx, stop_clone, batch_size))
            .expect("failed to spawn prefetch producer thread");
        Self {
            rx: Some(rx),
            stop_signal,
            producer_handle: Some(producer_handle),
            current_batch: None,
            exhausted: false,
            batches_received: 0,
            sentences_yielded: 0,
        }
    }

    /// Requests that the producer stop reading new sentences.
    ///
    /// Batches already queued in the channel are still delivered before iteration ends.
    /// A producer blocked on a full channel only notices the request once room is made.
    pub fn stop(&self) {
        self.stop_signal.store(true, Ordering::Release);
    }

    /// Returns `true` once a stop has been requested.
    pub fn is_stopped(&self) -> bool {
        self.stop_signal.load(Ordering::Acquire)
    }

    /// Number of batches received from the producer so far.
    pub fn batches_received(&self) -> usize {
        self.batches_received
    }

    /// Number of sentences yielded by this iterator so far.
    pub fn sentences_yielded(&self) -> usize {
        self.sentences_yielded
    }

    /// Converts this reader into an iterator over whole batches (`Vec<String>`).
    ///
    /// The channel, the producer thread and the batch counter move to the returned
    /// [`PrefetchBatchIterator`].
    ///
    /// Sentences still pending in a batch that was partially consumed through
    /// [`Iterator::next`] cannot be handed over and are discarded, with a logged
    /// warning. Call this before pulling any sentences to avoid losing data.
    ///
    /// # Panics
    ///
    /// Panics if the receiver was already taken. This is not reachable through the
    /// public API, because this method consumes `self`.
    pub fn batches(mut self) -> PrefetchBatchIterator {
        let discarded = self.current_batch.take().map_or(0, |rest| rest.len());
        if discarded > 0 {
            log::warn!(
                "PrefetchingReader::batches: discarding {} unread sentence(s) from a partially consumed batch",
                discarded
            );
        }
        let rx = self.rx.take().expect("PrefetchingReader already consumed");
        let producer_handle = self.producer_handle.take();
        // With `rx` now `None`, dropping `self` at the end of this method leaves the
        // producer running for the batch iterator.
        PrefetchBatchIterator {
            rx,
            stop_signal: Arc::clone(&self.stop_signal),
            producer_handle,
            // If the stream already ended, don't block on the channel again.
            exhausted: self.exhausted,
            batches_received: self.batches_received,
        }
    }

    /// Blocks for the next producer message and returns the batch it carries.
    ///
    /// Returns `None` and latches `exhausted` on `Done`, on a producer error (which is
    /// logged), or when the channel is disconnected.
    fn receive_batch(&mut self) -> Option<Vec<String>> {
        if self.exhausted {
            return None;
        }
        let rx = self.rx.as_ref()?;
        match rx.recv() {
            Ok(PrefetchMessage::Batch(batch)) => {
                self.batches_received += 1;
                Some(batch)
            }
            Ok(PrefetchMessage::Done) => {
                self.exhausted = true;
                None
            }
            Ok(PrefetchMessage::Error(e)) => {
                log::error!("Prefetch producer error: {}", e);
                self.exhausted = true;
                None
            }
            Err(_) => {
                self.exhausted = true;
                None
            }
        }
    }
}
impl Iterator for PrefetchingReader {
type Item = String;
fn next(&mut self) -> Option<Self::Item> {
loop {
if let Some(ref mut batch_iter) = self.current_batch {
if let Some(sentence) = batch_iter.next() {
self.sentences_yielded += 1;
return Some(sentence);
}
}
match self.receive_batch() {
Some(batch) => {
self.current_batch = Some(batch.into_iter());
}
None => {
return None;
}
}
}
}
}
impl Drop for PrefetchingReader {
fn drop(&mut self) {
if self.rx.is_none() {
return;
}
self.stop_signal.store(true, Ordering::Release);
if let Some(ref rx) = self.rx {
while rx.try_recv().is_ok() {}
}
if let Some(handle) = self.producer_handle.take() {
let _ = handle.join();
}
}
}
/// Iterator over whole prefetched batches, created by [`PrefetchingReader::batches`].
///
/// Dropping it sets the stop signal and joins the producer thread.
pub struct PrefetchBatchIterator {
    /// Receiving end of the producer channel.
    rx: Receiver<PrefetchMessage>,
    /// Flag shared with the producer; set to request early termination.
    stop_signal: Arc<AtomicBool>,
    /// Producer thread handle, joined on drop.
    producer_handle: Option<JoinHandle<()>>,
    /// Latched once `Done`, a producer error, or a disconnected channel is observed.
    exhausted: bool,
    /// Batches received, including those counted by the originating `PrefetchingReader`.
    batches_received: usize,
}
impl PrefetchBatchIterator {
    /// Requests that the producer thread stop reading new sentences.
    ///
    /// Batches already queued in the channel are still returned by [`Iterator::next`]
    /// before iteration ends.
    pub fn stop(&self) {
        self.stop_signal.store(true, Ordering::Release);
    }

    /// Returns `true` once a stop has been requested, mirroring
    /// `PrefetchingReader::is_stopped`.
    pub fn is_stopped(&self) -> bool {
        self.stop_signal.load(Ordering::Acquire)
    }

    /// Number of batches received from the producer, including any already counted by
    /// the `PrefetchingReader` this iterator was created from.
    pub fn batches_received(&self) -> usize {
        self.batches_received
    }
}
impl Iterator for PrefetchBatchIterator {
type Item = Vec<String>;
fn next(&mut self) -> Option<Self::Item> {
if self.exhausted {
return None;
}
match self.rx.recv() {
Ok(PrefetchMessage::Batch(batch)) => {
self.batches_received += 1;
Some(batch)
}
Ok(PrefetchMessage::Done) => {
self.exhausted = true;
None
}
Ok(PrefetchMessage::Error(e)) => {
log::error!("Prefetch producer error: {}", e);
self.exhausted = true;
None
}
Err(_) => {
self.exhausted = true;
None
}
}
}
}
impl Drop for PrefetchBatchIterator {
fn drop(&mut self) {
self.stop_signal.store(true, Ordering::Release);
while self.rx.try_recv().is_ok() {}
if let Some(handle) = self.producer_handle.take() {
let _ = handle.join();
}
}
}
/// Body of the producer thread.
///
/// Reads every sentence from `reader`, sends them over `tx` in batches of `batch_size`,
/// and finishes with [`PrefetchMessage::Done`]:
/// - once `stop_signal` is set, reading stops at the next sentence, the partially
///   filled batch is dropped, and `Done` is still sent;
/// - if sending a full batch fails (receiver gone), the producer returns immediately;
/// - a panic while reading is caught and reported as [`PrefetchMessage::Error`].
fn producer_loop<R: CorpusReader>(
    reader: R,
    tx: Sender<PrefetchMessage>,
    stop_signal: Arc<AtomicBool>,
    batch_size: usize,
) {
    let stop_requested = || stop_signal.load(Ordering::Acquire);

    // AssertUnwindSafe: after a panic only `tx` is touched again, to report the failure.
    let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
        let mut pending: Vec<String> = Vec::with_capacity(batch_size);
        for sentence in reader.sentences() {
            if stop_requested() {
                log::debug!("Prefetch producer: stop signal received");
                break;
            }
            pending.push(sentence);
            if pending.len() < batch_size {
                continue;
            }
            let full = std::mem::take(&mut pending);
            if tx.send(PrefetchMessage::Batch(full)).is_err() {
                log::debug!("Prefetch producer: receiver dropped");
                return;
            }
            pending.reserve_exact(batch_size);
        }
        let flush_leftover = !pending.is_empty() && !stop_requested();
        if flush_leftover {
            let _ = tx.send(PrefetchMessage::Batch(pending));
        }
        let _ = tx.send(PrefetchMessage::Done);
    }));

    let Err(payload) = outcome else {
        return;
    };
    let msg = payload
        .downcast_ref::<&str>()
        .map(|s| s.to_string())
        .or_else(|| payload.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| "Unknown panic in prefetch producer".to_string());
    log::error!("Prefetch producer panicked: {}", msg);
    let _ = tx.send(PrefetchMessage::Error(msg));
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    /// In-memory reader that replays a fixed list of sentences.
    struct MockReader {
        sentences: Arc<Mutex<Vec<String>>>,
    }

    impl MockReader {
        fn new(sentences: Vec<String>) -> Self {
            Self {
                sentences: Arc::new(Mutex::new(sentences)),
            }
        }
    }

    impl CorpusReader for MockReader {
        fn documents(&self) -> Box<dyn Iterator<Item = crate::corpus::Document> + Send + '_> {
            Box::new(std::iter::empty())
        }
        fn sentences(&self) -> Box<dyn Iterator<Item = String> + Send + '_> {
            let sentences = self.sentences.lock().unwrap().clone();
            Box::new(sentences.into_iter())
        }
    }

    /// Reader whose sentence iterator panics on its third item, to exercise the
    /// producer's `catch_unwind` path.
    struct PanickingReader;

    impl CorpusReader for PanickingReader {
        fn documents(&self) -> Box<dyn Iterator<Item = crate::corpus::Document> + Send + '_> {
            Box::new(std::iter::empty())
        }
        fn sentences(&self) -> Box<dyn Iterator<Item = String> + Send + '_> {
            Box::new((0..3).map(|i| {
                if i == 2 {
                    panic!("simulated reader failure");
                }
                format!("Sentence {}", i)
            }))
        }
    }

    #[test]
    fn test_prefetch_config_defaults() {
        let config = PrefetchConfig::default();
        assert_eq!(config.batch_size, 10_000);
        assert!(config.auto_tune);
        assert!((config.ram_fraction - 0.10).abs() < f64::EPSILON);
    }

    #[test]
    fn test_prefetch_config_builder() {
        let config = PrefetchConfig::new()
            .with_batch_size(5_000)
            .with_auto_tune(false)
            .with_buffer_batches(4)
            .with_ram_fraction(0.20);
        assert_eq!(config.batch_size, 5_000);
        assert!(!config.auto_tune);
        assert_eq!(config.buffer_batches, 4);
        assert!((config.ram_fraction - 0.20).abs() < f64::EPSILON);
    }

    #[test]
    fn test_ram_fraction_clamping() {
        let config = PrefetchConfig::new().with_ram_fraction(0.90);
        assert!((config.ram_fraction - 0.50).abs() < f64::EPSILON);
        let config = PrefetchConfig::new().with_ram_fraction(0.001);
        assert!((config.ram_fraction - 0.01).abs() < f64::EPSILON);
    }

    #[test]
    fn test_prefetch_empty_corpus() {
        let reader = MockReader::new(vec![]);
        let prefetch = PrefetchingReader::new(reader);
        let sentences: Vec<String> = prefetch.collect();
        assert!(sentences.is_empty());
    }

    #[test]
    fn test_prefetch_small_corpus() {
        let input = vec![
            "Hello world.".to_string(),
            "This is a test.".to_string(),
            "Rust is great.".to_string(),
        ];
        let reader = MockReader::new(input.clone());
        let config = PrefetchConfig::new()
            .with_batch_size(2)
            .with_auto_tune(false)
            .with_buffer_batches(2);
        let prefetch = PrefetchingReader::with_config(reader, config);
        let output: Vec<String> = prefetch.collect();
        assert_eq!(output, input);
    }

    #[test]
    fn test_prefetch_batch_iterator() {
        let input: Vec<String> = (0..100).map(|i| format!("Sentence {}", i)).collect();
        let reader = MockReader::new(input.clone());
        let config = PrefetchConfig::new()
            .with_batch_size(25)
            .with_auto_tune(false)
            .with_buffer_batches(2);
        let prefetch = PrefetchingReader::with_config(reader, config);
        let batches: Vec<Vec<String>> = prefetch.batches().collect();
        assert_eq!(batches.len(), 4);
        for batch in &batches {
            assert_eq!(batch.len(), 25);
        }
        let flattened: Vec<String> = batches.into_iter().flatten().collect();
        assert_eq!(flattened, input);
    }

    #[test]
    fn test_prefetch_partial_batch() {
        let input: Vec<String> = (0..7).map(|i| format!("Sentence {}", i)).collect();
        let reader = MockReader::new(input.clone());
        let config = PrefetchConfig::new()
            .with_batch_size(3)
            .with_auto_tune(false)
            .with_buffer_batches(2);
        let prefetch = PrefetchingReader::with_config(reader, config);
        let batches: Vec<Vec<String>> = prefetch.batches().collect();
        assert_eq!(batches.len(), 3);
        assert_eq!(batches[0].len(), 3);
        assert_eq!(batches[1].len(), 3);
        assert_eq!(batches[2].len(), 1);
    }

    #[test]
    fn test_prefetch_early_stop() {
        let input: Vec<String> = (0..1000).map(|i| format!("Sentence {}", i)).collect();
        let reader = MockReader::new(input);
        let config = PrefetchConfig::new()
            .with_batch_size(100)
            .with_auto_tune(false)
            .with_buffer_batches(2);
        let prefetch = PrefetchingReader::with_config(reader, config);
        let output: Vec<String> = prefetch.take(50).collect();
        assert_eq!(output.len(), 50);
    }

    #[test]
    fn test_prefetch_drop_no_hang() {
        let input: Vec<String> = (0..10_000).map(|i| format!("Sentence {}", i)).collect();
        let reader = MockReader::new(input);
        let config = PrefetchConfig::new()
            .with_batch_size(100)
            .with_auto_tune(false)
            .with_buffer_batches(2);
        let prefetch = PrefetchingReader::with_config(reader, config);
        drop(prefetch);
    }

    #[test]
    fn test_batch_iterator_drop_no_hang() {
        // Small batches and a tiny buffer keep the producer blocked on a full channel
        // while the iterator is dropped mid-stream.
        let input: Vec<String> = (0..10_000).map(|i| format!("Sentence {}", i)).collect();
        let reader = MockReader::new(input);
        let config = PrefetchConfig::new()
            .with_batch_size(10)
            .with_auto_tune(false)
            .with_buffer_batches(2);
        let mut batches = PrefetchingReader::with_config(reader, config).batches();
        assert_eq!(batches.next().map(|batch| batch.len()), Some(10));
        assert_eq!(batches.batches_received(), 1);
        drop(batches);
    }

    #[test]
    fn test_prefetch_producer_panic_ends_stream() {
        // The reader panics before a full batch is assembled, so no sentences arrive and
        // the iterator must terminate instead of hanging.
        let config = PrefetchConfig::new()
            .with_batch_size(10)
            .with_auto_tune(false)
            .with_buffer_batches(2);
        let prefetch = PrefetchingReader::with_config(PanickingReader, config);
        let output: Vec<String> = prefetch.collect();
        assert!(output.is_empty());
    }

    #[test]
    fn test_get_available_memory() {
        // Compare against sysinfo's own total rather than fixed bounds. Fixed bounds are
        // flaky on memory-constrained CI runners and on hosts with more than 1 TiB of RAM.
        let total = {
            let mut sys = System::new();
            sys.refresh_memory();
            sys.total_memory()
        };
        let memory = get_available_memory_bytes();
        if total == 0 {
            assert_eq!(memory, DEFAULT_FALLBACK_MEMORY);
        } else {
            assert!(memory as u64 <= total);
        }
    }

    #[test]
    fn test_compute_buffer_batches() {
        let batches = compute_buffer_batches(10_000, 0.10);
        assert!(batches >= 2);
        assert!(batches <= 32);
    }
}