use crossbeam::channel;
use std::thread;
/// An iterator adapter that drives an inner iterator on a background thread,
/// buffering produced items in a channel so the consumer can overlap its own
/// work with item production.
///
/// Iteration ends when the inner iterator is exhausted. Once this adapter is
/// dropped, the producer thread exits at its next attempt to send an item.
/// If the inner iterator panics, the panic is re-raised on the consuming
/// thread when it reaches the end of the stream, rather than being mistaken
/// for normal exhaustion.
pub struct PrefetchIterator<T> {
    receiver: channel::Receiver<T>,
    // `Some` until the producer thread has been joined after the stream ended.
    handle: Option<thread::JoinHandle<()>>,
}

impl<T> PrefetchIterator<T>
where
    T: Send + 'static,
{
    /// Starts prefetching `inner` on a background thread, buffering at most
    /// `buffer_size` items ahead of the consumer.
    ///
    /// A `buffer_size` of 0 creates a zero-capacity (rendezvous) channel: the
    /// producer may compute one item ahead but must wait for the consumer to
    /// take it before continuing.
    pub fn new<I>(inner: I, buffer_size: usize) -> Self
    where
        I: Iterator<Item = T> + Send + 'static,
    {
        let (sender, receiver) = channel::bounded(buffer_size);
        Self::spawn_producer(inner, sender, receiver)
    }

    /// Starts prefetching `inner` on a background thread with an unbounded
    /// buffer.
    ///
    /// The producer never waits for the consumer, so buffered memory grows
    /// with however far it runs ahead.
    pub fn new_unbounded<I>(inner: I) -> Self
    where
        I: Iterator<Item = T> + Send + 'static,
    {
        let (sender, receiver) = channel::unbounded();
        Self::spawn_producer(inner, sender, receiver)
    }

    /// Spawns the thread that drains `inner` into `sender`, and pairs its
    /// handle with `receiver`. Shared by both constructors.
    fn spawn_producer<I>(
        inner: I,
        sender: channel::Sender<T>,
        receiver: channel::Receiver<T>,
    ) -> Self
    where
        I: Iterator<Item = T> + Send + 'static,
    {
        let handle = thread::spawn(move || {
            for item in inner {
                // The receiver has been dropped: nobody will read further items.
                if sender.send(item).is_err() {
                    break;
                }
            }
            // `sender` is dropped when this closure returns (or unwinds). The
            // resulting disconnection is the only end-of-stream signal, so no
            // sentinel value ever occupies a buffer slot.
        });
        Self {
            receiver,
            handle: Some(handle),
        }
    }

    /// Returns the number of items currently buffered and ready to be taken
    /// without blocking.
    pub fn buffer_len(&self) -> usize {
        self.receiver.len()
    }

    /// Returns `true` if no items are currently buffered.
    pub fn buffer_is_empty(&self) -> bool {
        self.receiver.is_empty()
    }

    /// Returns the next item if one is already buffered, without blocking.
    ///
    /// Returns `None` both when the buffer is momentarily empty and when the
    /// stream has ended; use [`Iterator::next`] to wait for the next item.
    ///
    /// # Panics
    ///
    /// Re-raises the producer's panic if the inner iterator panicked.
    pub fn try_next(&mut self) -> Option<T> {
        match self.receiver.try_recv() {
            Ok(item) => Some(item),
            Err(channel::TryRecvError::Empty) => None,
            Err(channel::TryRecvError::Disconnected) => self.finish(),
        }
    }

    /// Called once the channel is empty and disconnected. Joins the producer
    /// thread (only the first time) and propagates its panic, if any.
    /// Returns `None` on normal completion.
    fn finish(&mut self) -> Option<T> {
        if let Some(handle) = self.handle.take() {
            if let Err(payload) = handle.join() {
                std::panic::resume_unwind(payload);
            }
        }
        None
    }
}

impl<T> Iterator for PrefetchIterator<T>
where
    T: Send + 'static,
{
    type Item = T;

    /// Blocks until the next item is available. Returns `None` once the
    /// inner iterator is exhausted.
    ///
    /// # Panics
    ///
    /// Re-raises the producer's panic if the inner iterator panicked.
    fn next(&mut self) -> Option<Self::Item> {
        match self.receiver.recv() {
            Ok(item) => Some(item),
            // `recv` fails only when the buffer is empty and the producer has
            // dropped its sender, i.e. the stream is over.
            Err(_) => self.finish(),
        }
    }
}

// After the channel disconnects, `recv` keeps failing and `finish` keeps
// returning `None`, so the adapter never resumes yielding items.
impl<T> std::iter::FusedIterator for PrefetchIterator<T> where T: Send + 'static {}
/// Extension methods that let any suitable iterator be prefetched on a
/// background thread using method syntax, e.g. `iter.prefetch(4)`.
pub trait PrefetchExt<T>: Iterator<Item = T> + Sized + Send + 'static
where
    T: Send + 'static,
{
    /// Consumes the iterator and prefetches it with a buffer of at most
    /// `buffer_size` items (see [`PrefetchIterator::new`]).
    fn prefetch(self, buffer_size: usize) -> PrefetchIterator<T> {
        PrefetchIterator::<T>::new(self, buffer_size)
    }

    /// Consumes the iterator and prefetches it with an unbounded buffer
    /// (see [`PrefetchIterator::new_unbounded`]).
    fn prefetch_unbounded(self) -> PrefetchIterator<T> {
        PrefetchIterator::<T>::new_unbounded(self)
    }
}
/// Every `Send + 'static` iterator whose items are `Send + 'static` gets the
/// [`PrefetchExt`] methods automatically.
impl<T: Send + 'static, I: Iterator<Item = T> + Send + 'static> PrefetchExt<T> for I {}
/// Declarative description of how an iterator should be prefetched.
///
/// When `unbounded` is `true`, `buffer_size` is not consulted by `apply`.
#[derive(Debug, Clone)]
pub struct PrefetchConfig {
    /// Maximum number of buffered items when `unbounded` is `false`.
    pub buffer_size: usize,
    /// Whether to use an unbounded buffer instead of `buffer_size`.
    pub unbounded: bool,
}

impl Default for PrefetchConfig {
    /// A bounded buffer holding two items.
    fn default() -> Self {
        let buffer_size = 2;
        Self {
            buffer_size,
            unbounded: false,
        }
    }
}
impl PrefetchConfig {
    /// Creates a configuration with the default settings (a bounded buffer
    /// of two items).
    pub fn new() -> Self {
        Default::default()
    }

    /// Switches to a bounded buffer holding at most `size` items, clearing
    /// any earlier request for an unbounded buffer.
    pub fn buffer_size(self, size: usize) -> Self {
        let mut cfg = self;
        cfg.unbounded = false;
        cfg.buffer_size = size;
        cfg
    }

    /// Switches to an unbounded buffer. The stored `buffer_size` is kept
    /// but not consulted by [`apply`](Self::apply).
    pub fn unbounded(self) -> Self {
        Self {
            unbounded: true,
            ..self
        }
    }

    /// Starts prefetching `iter` according to this configuration.
    pub fn apply<I, T>(self, iter: I) -> PrefetchIterator<T>
    where
        I: Iterator<Item = T> + Send + 'static,
        T: Send + 'static,
    {
        match self {
            Self { unbounded: true, .. } => PrefetchIterator::new_unbounded(iter),
            Self { buffer_size, .. } => PrefetchIterator::new(iter, buffer_size),
        }
    }
}
pub mod utils {
use super::*;
pub fn optimal_prefetch<I, T>(
iter: I,
expected_item_processing_time: u64,
) -> PrefetchIterator<T>
where
I: Iterator<Item = T> + Send + 'static,
T: Send + 'static,
{
let buffer_size = if expected_item_processing_time > 100 {
2 } else if expected_item_processing_time > 10 {
4 } else {
8 };
PrefetchIterator::new(iter, buffer_size)
}
pub fn cpu_bound_prefetch<I, T>(iter: I) -> PrefetchIterator<T>
where
I: Iterator<Item = T> + Send + 'static,
T: Send + 'static,
{
PrefetchIterator::new(iter, 2)
}
pub fn io_bound_prefetch<I, T>(iter: I) -> PrefetchIterator<T>
where
I: Iterator<Item = T> + Send + 'static,
T: Send + 'static,
{
PrefetchIterator::new(iter, 8)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{Duration, Instant};

    /// Upper bound on how long a test waits for the producer thread to make
    /// progress. It is generous so that slow or heavily loaded machines pass.
    const PRODUCER_TIMEOUT: Duration = Duration::from_secs(5);

    /// Polls `probe` until it yields `Some`, failing the test after
    /// `PRODUCER_TIMEOUT`.
    ///
    /// Used instead of a fixed sleep, which fails whenever the producer thread
    /// is scheduled later than the sleep assumed.
    fn wait_for<R>(mut probe: impl FnMut() -> Option<R>) -> R {
        let deadline = Instant::now() + PRODUCER_TIMEOUT;
        loop {
            if let Some(value) = probe() {
                return value;
            }
            assert!(
                Instant::now() < deadline,
                "producer thread made no progress within {:?}",
                PRODUCER_TIMEOUT
            );
            std::thread::sleep(Duration::from_millis(1));
        }
    }

    /// Asserts that `iter` yields exactly `expected`, in order, and then ends.
    fn assert_yields(mut iter: PrefetchIterator<i32>, expected: &[i32]) {
        for &want in expected {
            assert_eq!(iter.next(), Some(want));
        }
        assert_eq!(iter.next(), None);
    }

    #[test]
    fn test_prefetch_iterator_basic() {
        let data = vec![1, 2, 3, 4, 5];
        assert_yields(PrefetchIterator::new(data.into_iter(), 2), &[1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_prefetch_ext_trait() {
        let data = vec![1, 2, 3, 4, 5];
        assert_yields(data.into_iter().prefetch(2), &[1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_prefetch_unbounded() {
        let data = vec![1, 2, 3, 4, 5];
        assert_yields(data.into_iter().prefetch_unbounded(), &[1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_prefetch_config() {
        let config = PrefetchConfig::new().buffer_size(4);
        assert_eq!(config.buffer_size, 4);
        assert!(!config.unbounded);
        let config = PrefetchConfig::new().unbounded();
        assert!(config.unbounded);
    }

    #[test]
    fn test_prefetch_config_apply() {
        let data = vec![1, 2, 3, 4, 5];
        let config = PrefetchConfig::new().buffer_size(3);
        assert_yields(config.apply(data.clone().into_iter()), &[1, 2, 3, 4, 5]);
        let config = PrefetchConfig::new().unbounded();
        assert_yields(config.apply(data.into_iter()), &[1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_try_next() {
        let data = vec![1, 2, 3];
        let mut prefetch_iter = PrefetchIterator::new(data.into_iter(), 2);
        // The first item to become available must be the first source item.
        assert_eq!(wait_for(|| prefetch_iter.try_next()), 1);
    }

    #[test]
    fn test_buffer_status() {
        let data = vec![1, 2, 3, 4, 5];
        let prefetch_iter = PrefetchIterator::new(data.into_iter(), 3);
        wait_for(|| (!prefetch_iter.buffer_is_empty()).then_some(()));
        let buffered = prefetch_iter.buffer_len();
        assert!(buffered > 0);
        // A bounded buffer never holds more than its capacity.
        assert!(buffered <= 3);
    }

    #[test]
    fn test_zero_buffer_size() {
        // A zero-capacity channel is a rendezvous; every item is still delivered.
        let collected: Vec<i32> = (0..5).prefetch(0).collect();
        assert_eq!(collected, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_preserves_order() {
        let collected: Vec<u32> = (0..1_000).prefetch(4).collect();
        let expected: Vec<u32> = (0..1_000).collect();
        assert_eq!(collected, expected);
    }

    #[test]
    fn test_stays_exhausted() {
        let mut prefetch_iter = vec![1].into_iter().prefetch(2);
        assert_eq!(prefetch_iter.next(), Some(1));
        assert_eq!(prefetch_iter.next(), None);
        assert_eq!(prefetch_iter.next(), None);
    }

    #[test]
    fn test_early_drop_with_infinite_source() {
        // Dropping the adapter must not hang even though the source never
        // ends: the producer stops at its next failed send.
        let mut prefetch_iter = (0u64..).prefetch(2);
        assert_eq!(prefetch_iter.next(), Some(0));
        drop(prefetch_iter);
    }

    #[test]
    fn test_utils_optimal_prefetch() {
        // Exercise every threshold branch of the buffer-size heuristic.
        for processing_time in [0, 10, 50, 100, 1_000] {
            let data = vec![1, 2, 3, 4, 5];
            assert_yields(
                utils::optimal_prefetch(data.into_iter(), processing_time),
                &[1, 2, 3, 4, 5],
            );
        }
    }

    #[test]
    fn test_utils_cpu_bound_prefetch() {
        let data = vec![1, 2, 3, 4, 5];
        assert_yields(utils::cpu_bound_prefetch(data.into_iter()), &[1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_utils_io_bound_prefetch() {
        let data = vec![1, 2, 3, 4, 5];
        assert_yields(utils::io_bound_prefetch(data.into_iter()), &[1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_empty_iterator() {
        let data: Vec<i32> = vec![];
        let mut prefetch_iter = data.into_iter().prefetch(2);
        assert_eq!(prefetch_iter.next(), None);
    }

    #[test]
    fn test_prefetch_performance() {
        // Smoke check only: three 10 ms items should arrive well under 100 ms.
        let slow_iter = (0..10).map(|x| {
            std::thread::sleep(Duration::from_millis(10));
            x
        });
        let start = Instant::now();
        let mut prefetch_iter = slow_iter.prefetch(3);
        assert_eq!(prefetch_iter.next(), Some(0));
        assert_eq!(prefetch_iter.next(), Some(1));
        assert_eq!(prefetch_iter.next(), Some(2));
        let elapsed = start.elapsed();
        assert!(elapsed < Duration::from_millis(100));
    }
}