use crate::tpch_cli::progress::TableProgress;
use futures::StreamExt;
use log::debug;
use std::collections::VecDeque;
use std::io;
use std::sync::{Arc, Mutex};
use tokio::task::JoinSet;
/// A unit of work that produces one chunk of output bytes.
///
/// Sources are moved into spawned generation tasks, hence the `Send` bound.
pub trait Source: Send {
    /// Produces the header bytes, using `buffer` as output storage, and returns it.
    ///
    /// `generate_in_chunks_with_progress` calls this once, on the first source only, with
    /// an empty `Vec`. The result is written before any chunk produced by [`Source::create`].
    fn header(&self, buffer: Vec<u8>) -> Vec<u8>;

    /// Consumes the source, fills `buffer` with its generated data and returns it.
    ///
    /// The buffer handed in is empty but may already carry pre-reserved (recycled)
    /// capacity. The returned bytes are written to the sink as a single chunk.
    fn create(self, buffer: Vec<u8>) -> Vec<u8>;
}
/// Destination for generated bytes.
///
/// The sink is moved into a blocking writer task, hence the `Send` bound. It receives the
/// header first, then every chunk in source order, and finally [`Sink::flush`].
pub trait Sink: Send {
    /// Finishes writing, consuming the sink. Called once after the last chunk.
    fn flush(self) -> Result<(), io::Error>;

    /// Writes one buffer (the header or one generated chunk).
    ///
    /// Returning an error stops the writer; that error is propagated to the caller of the
    /// generation function.
    fn sink(&mut self, buffer: &[u8]) -> Result<(), io::Error>;
}
/// Generates every source in `sources` and writes the output, in source order, to `sink`.
///
/// Up to `num_threads` chunks are generated concurrently. Progress is reported through a
/// default-constructed `TableProgress`. This is a convenience wrapper around
/// `generate_in_chunks_with_progress`; see it for ordering and error-handling details.
///
/// # Errors
///
/// Returns the first I/O error reported by `sink`.
pub async fn generate_in_chunks<G, I, S>(
    sink: S,
    sources: I,
    num_threads: usize,
) -> Result<(), io::Error>
where
    G: Source + 'static,
    I: Iterator<Item = G>,
    S: Sink + 'static,
{
    let progress = TableProgress::default();
    generate_in_chunks_with_progress(sink, sources, num_threads, progress).await
}
/// Generates data from `sources` concurrently and writes it, in source order, to `sink`.
/// One progress unit is reported per chunk written.
///
/// The header comes from calling [`Source::header`] on the first source. It is written
/// exactly once, before any chunk; if `sources` is empty, nothing is written at all.
///
/// Up to `num_threads` chunks are generated concurrently. The resulting buffers go through
/// a bounded channel to a single blocking writer task, so output order always matches
/// source order. Written buffers are handed back to an internal [`BufferRecycler`] for reuse.
///
/// A `num_threads` of `0` is treated as `1`.
///
/// # Errors
///
/// Returns the first error reported by [`Sink::sink`] or [`Sink::flush`]. Once the writer
/// has stopped because of an error, no further chunks are sent to it.
///
/// # Panics
///
/// Panics if a generation task or the writer task panics.
pub(crate) async fn generate_in_chunks_with_progress<G, I, S>(
    mut sink: S,
    sources: I,
    num_threads: usize,
    progress: TableProgress,
) -> Result<(), io::Error>
where
    G: Source + 'static,
    I: Iterator<Item = G>,
    S: Sink + 'static,
{
    // `tokio::sync::mpsc::channel` panics on a capacity of 0, and `buffered(0)` would never
    // yield anything, so always allow at least one chunk in flight.
    let num_threads = num_threads.max(1);
    let recycler = BufferRecycler::new();
    let mut sources = sources.peekable();
    debug!("Using {num_threads} threads");

    // No sources: do not even write a header.
    let Some(first) = sources.peek() else {
        return Ok(());
    };
    let header = first.header(Vec::new());

    let sources_and_recyclers = sources.map(|generator| (generator, recycler.clone()));
    let (tx, mut rx) = tokio::sync::mpsc::channel(num_threads);

    let mut stream = futures::stream::iter(sources_and_recyclers)
        .map(async |(source, recycler)| {
            let buffer = recycler.new_buffer(1024 * 1024 * 8);
            // Generate on a separately spawned task so several chunks can be produced
            // concurrently while this future only waits for the result.
            let mut join_set = JoinSet::new();
            join_set.spawn(async move { source.create(buffer) });
            join_set
                .join_next()
                .await
                .expect("had one item")
                .expect("join_next join is infallible unless task panics")
        })
        // `buffered` (unlike `buffer_unordered`) yields results in source order.
        .buffered(num_threads)
        .map(async |buffer| {
            if let Err(e) = tx.send(buffer).await {
                debug!("Error sending buffer to writer: {e}");
            }
        });

    let captured_recycler = recycler.clone();
    let writer_task = tokio::task::spawn_blocking(move || {
        sink.sink(&header)?;
        while let Some(buffer) = rx.blocking_recv() {
            sink.sink(&buffer)?;
            captured_recycler.return_buffer(buffer);
            progress.increment_output_unit();
        }
        sink.flush()
    });

    while let Some(write_task) = stream.next().await {
        // The writer loop only ends before the channel closes if the sink returned an error
        // (or the task panicked); there is no point generating more data in that case.
        if writer_task.is_finished() {
            debug!("writer task is done early, stopping writer");
            break;
        }
        write_task.await;
    }

    // Drop the stream (whose send closure borrows `tx`) and then `tx` itself to close the
    // channel, so the writer's `blocking_recv` returns `None` and the sink gets flushed.
    drop(stream);
    drop(tx);
    debug!("waiting for writer task to complete");
    writer_task.await.expect("writer task panicked")
}
/// Shared pool of reusable byte buffers.
///
/// Cloning a `BufferRecycler` is cheap: all clones refer to the same mutex-protected queue,
/// so buffers returned by one clone can be handed out by another.
#[derive(Debug, Clone)]
struct BufferRecycler {
    buffers: Arc<Mutex<VecDeque<Vec<u8>>>>,
}

impl BufferRecycler {
    /// Creates an empty pool.
    fn new() -> Self {
        Self {
            buffers: Arc::new(Mutex::new(VecDeque::new())),
        }
    }

    /// Returns an empty buffer with capacity for at least `size` bytes.
    ///
    /// The oldest pooled buffer is reused when one is available. Otherwise a new buffer is
    /// allocated.
    fn new_buffer(&self, size: usize) -> Vec<u8> {
        // Pop under the lock, but do the (possibly reallocating) reserve after the guard
        // has been dropped at the end of this statement.
        let recycled = self.buffers.lock().unwrap().pop_front();
        match recycled {
            Some(mut buffer) => {
                buffer.clear();
                // `reserve` is relative to `len()`, which is 0 after `clear`, so this
                // guarantees `capacity() >= size`. It is a no-op when the buffer is already
                // big enough.
                buffer.reserve(size);
                buffer
            }
            None => Vec::with_capacity(size),
        }
    }

    /// Puts `buffer` back into the pool so a later `new_buffer` call can reuse its
    /// allocation.
    fn return_buffer(&self, buffer: Vec<u8>) {
        self.buffers.lock().unwrap().push_back(buffer);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tpch_cli::progress::{ProgressTracker, RunProgress};
    use crate::tpch_cli::Table;
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Progress tracker that sums every unit it is told about, regardless of table.
    #[derive(Debug)]
    struct CountingProgress {
        increments: AtomicU64,
    }

    impl ProgressTracker for CountingProgress {
        fn increment(&self, _table: Table, units: u64) {
            self.increments.fetch_add(units, Ordering::Relaxed);
        }
    }

    /// Source that emits a fixed header and a fixed data payload.
    struct TestSource {
        header: &'static [u8],
        data: &'static [u8],
    }

    impl Source for TestSource {
        fn create(self, mut buffer: Vec<u8>) -> Vec<u8> {
            buffer.extend_from_slice(self.data);
            buffer
        }

        fn header(&self, mut buffer: Vec<u8>) -> Vec<u8> {
            buffer.extend_from_slice(self.header);
            buffer
        }
    }

    /// Sink that records a copy of every write so the test can check ordering.
    struct CapturingSink {
        writes: Arc<Mutex<Vec<Vec<u8>>>>,
    }

    impl Sink for CapturingSink {
        fn sink(&mut self, buffer: &[u8]) -> Result<(), io::Error> {
            let copy = buffer.to_vec();
            self.writes.lock().unwrap().push(copy);
            Ok(())
        }

        fn flush(self) -> Result<(), io::Error> {
            Ok(())
        }
    }

    #[tokio::test]
    async fn progress_counts_generated_chunks_not_header() {
        let captured: Arc<Mutex<Vec<Vec<u8>>>> = Arc::default();
        let counter = Arc::new(CountingProgress {
            increments: AtomicU64::new(0),
        });
        let tracker: Arc<dyn ProgressTracker> = counter.clone();
        let table_progress = RunProgress::with_tracker(tracker).for_table(Table::Region);

        // Every source carries the same header; only the first source's header is written.
        let sources: Vec<TestSource> = ["row-1\n", "row-2\n"]
            .iter()
            .map(|&row| TestSource {
                header: b"header\n",
                data: row.as_bytes(),
            })
            .collect();

        let sink = CapturingSink {
            writes: Arc::clone(&captured),
        };
        generate_in_chunks_with_progress(sink, sources.into_iter(), 1, table_progress)
            .await
            .unwrap();

        // Two data chunks means two progress units; the header is not counted.
        assert_eq!(counter.increments.load(Ordering::Relaxed), 2);

        let expected: Vec<Vec<u8>> = ["header\n", "row-1\n", "row-2\n"]
            .iter()
            .map(|chunk| chunk.as_bytes().to_vec())
            .collect();
        assert_eq!(*captured.lock().unwrap(), expected);
    }
}