use std::cmp;
use std::collections::VecDeque;
use std::io::{Cursor, Read, Write};
use std::mem::swap;
use std::sync::mpsc::{Receiver, Sender, channel};
use std::sync::{Arc, Mutex};
use byteorder::WriteBytesExt;
use super::simple_threadpool::LeptonThreadPool;
use crate::lepton_error::{AddContext, ExitCode, Result};
use crate::{helpers::*, lepton_error::err_exit_code, structs::partial_buffer::PartialBuffer};
/// Traffic sent from a per-stream writer (or the demultiplexer) to the other side of a stream
/// channel.
enum Message {
    /// The stream with the given thread id has ended.
    Eof(usize),
    /// A block of bytes belonging to the stream with the given thread id.
    WriteBlock(usize, Vec<u8>),
}

/// Per-stream sink handed to each processor by `multiplex_write`.
///
/// Bytes are accumulated into blocks of at most `WRITE_BUFFER_SIZE` bytes; every full block (and
/// the final partial one, on `flush`) is sent over `sender` tagged with `thread_id`.
pub struct MultiplexWriter {
    thread_id: usize,
    sender: Sender<Message>,
    buffer: Vec<u8>,
}

/// Maximum size of one multiplexed block (the wire format encodes `len - 1` in 16 bits).
const WRITE_BUFFER_SIZE: usize = 65536;

impl Write for MultiplexWriter {
    /// Appends `buf`, shipping a block every time the internal buffer reaches `WRITE_BUFFER_SIZE`.
    ///
    /// Always consumes all of `buf` on success; fails only if shipping a full block fails.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let mut remaining = buf;
        while !remaining.is_empty() {
            let room = WRITE_BUFFER_SIZE - self.buffer.len();
            let (chunk, rest) = remaining.split_at(cmp::min(room, remaining.len()));
            self.buffer.extend_from_slice(chunk);
            if self.buffer.len() == WRITE_BUFFER_SIZE {
                self.flush()?;
            }
            remaining = rest;
        }
        Ok(buf.len())
    }

    /// Sends the buffered bytes (if any) as one block. Empty blocks are never sent.
    ///
    /// # Errors
    /// Returns `ErrorKind::BrokenPipe` if the receiving side has been dropped, instead of
    /// panicking the worker thread.
    fn flush(&mut self) -> std::io::Result<()> {
        if self.buffer.is_empty() {
            return Ok(());
        }
        let mut block = Vec::with_capacity(WRITE_BUFFER_SIZE);
        swap(&mut block, &mut self.buffer);
        self.sender
            .send(Message::WriteBlock(self.thread_id, block))
            .map_err(|_| {
                std::io::Error::new(
                    std::io::ErrorKind::BrokenPipe,
                    "multiplex receiver has been dropped",
                )
            })
    }
}
/// Collects the outcomes of tasks handed to a thread pool, one channel per task, in the order
/// the tasks were created.
struct ThreadResults<RESULT> {
    /// Receiving ends, in the same order as the `send_results` calls that created them.
    results: Vec<Receiver<Result<RESULT>>>,
}

impl<RESULT> ThreadResults<RESULT> {
    fn new() -> Self {
        Self {
            results: Vec::new(),
        }
    }

    /// Wraps `f` into a closure ready to be run on a thread pool.
    ///
    /// When invoked, the closure runs `f` through `catch_unwind_result` (so a panic in `f` is
    /// reported as an `Err`) and sends the outcome on a fresh channel whose receiving end is
    /// remembered here for `receive_results`.
    fn send_results<T: FnOnce() -> Result<RESULT> + Send + 'static>(
        &mut self,
        f: T,
    ) -> impl FnOnce() + use<RESULT, T> {
        let (outcome_tx, outcome_rx) = channel();
        self.results.push(outcome_rx);
        move || {
            // If the collector has been dropped nobody is waiting for this outcome, so a failed
            // send is deliberately ignored.
            let _ = outcome_tx.send(catch_unwind_result(f));
        }
    }

    /// Waits for every registered task and returns their values in creation order.
    ///
    /// All receivers are drained even after a failure, so every task is waited for. An error
    /// reported by a task replaces any earlier error (the last one wins). A task that dropped its
    /// channel without reporting is turned into an error only if nothing else has failed yet.
    fn receive_results(&mut self) -> Result<Vec<RESULT>> {
        let mut successes = Vec::new();
        let mut failure = None;
        for receiver in self.results.drain(..) {
            match receiver.recv() {
                Ok(Ok(value)) => successes.push(value),
                Ok(Err(task_error)) => failure = Some(task_error),
                Err(disconnected) => {
                    if failure.is_none() {
                        failure = Some(disconnected.into());
                    }
                }
            }
        }
        match failure {
            Some(error) => Err(error),
            None => Ok(successes),
        }
    }
}
pub fn multiplex_write<WRITE, FN, RESULT>(
writer: &mut WRITE,
num_threads: usize,
thread_pool: &dyn LeptonThreadPool,
processor: FN,
) -> Result<Vec<RESULT>>
where
WRITE: Write,
FN: Fn(&mut MultiplexWriter, usize) -> Result<RESULT> + Send + Sync + 'static,
RESULT: Send + 'static,
{
let mut thread_results = ThreadResults::new();
let mut packet_receivers = Vec::new();
let arc_processor = Arc::new(Box::new(processor));
for thread_id in 0..num_threads {
let (tx, rx) = channel();
let mut thread_writer = MultiplexWriter {
thread_id: thread_id,
sender: tx,
buffer: Vec::with_capacity(WRITE_BUFFER_SIZE),
};
let processor_clone = arc_processor.clone();
let f = thread_results.send_results(move || {
let r = processor_clone(&mut thread_writer, thread_id)?;
thread_writer.flush().context()?;
thread_writer
.sender
.send(Message::Eof(thread_id))
.context()?;
Ok(r)
});
thread_pool.run(Box::new(f));
packet_receivers.push(rx);
}
let mut current_thread_writer = 0;
loop {
match packet_receivers[current_thread_writer].recv() {
Ok(Message::WriteBlock(thread_id, b)) => {
let tid = thread_id as u8;
let l = b.len() - 1;
if l == 4095 || l == 16383 || l == 65535 {
writer.write_u8(tid | ((l.ilog2() as u8 >> 1) - 4) << 4)?;
} else {
writer.write_u8(tid)?;
writer.write_u8((l & 0xff) as u8)?;
writer.write_u8(((l >> 8) & 0xff) as u8)?;
}
writer.write_all(&b[..])?;
current_thread_writer = (current_thread_writer + 1) % packet_receivers.len();
}
Ok(Message::Eof(_)) | Err(_) => {
packet_receivers.remove(current_thread_writer);
if packet_receivers.len() == 0 {
break;
}
current_thread_writer = current_thread_writer % packet_receivers.len();
}
}
}
thread_results.receive_results()
}
/// Read side of one demultiplexed stream, handed to that stream's processor.
///
/// Blocks arrive over `receiver` as `Message::WriteBlock`. `Message::Eof` marks the end of the
/// stream, after which `read` returns `Ok(0)`.
pub struct MultiplexReader {
    /// Stream this reader belongs to; used only to sanity-check incoming blocks in debug builds.
    thread_id: usize,
    receiver: Receiver<Message>,
    /// The block currently being consumed.
    current_buffer: Cursor<Vec<u8>>,
    /// Set once `Message::Eof` has been received.
    end_of_file: bool,
}

impl Read for MultiplexReader {
    /// Fast path: serve from the current block and fall back to `read_slow` only when it yields
    /// nothing.
    #[inline(always)]
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let amount_read = self.current_buffer.read(buf)?;
        if amount_read > 0 {
            return Ok(amount_read);
        }
        self.read_slow(buf)
    }
}

impl MultiplexReader {
    /// Slow path: waits on the channel for the next block.
    ///
    /// Returns `Ok(0)` for a zero-length `buf` or once `Eof` has been seen. Returns an
    /// `io::Error` if the sending side disconnected without sending `Eof`.
    #[cold]
    #[inline(never)]
    fn read_slow(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        // A zero-length read can never make progress. Without this guard, every incoming block
        // would be installed and immediately replaced, silently discarding the stream until Eof.
        if buf.is_empty() {
            return Ok(0);
        }
        while !self.end_of_file {
            let amount_read = self.current_buffer.read(buf)?;
            if amount_read > 0 {
                return Ok(amount_read);
            }
            match self.receiver.recv() {
                Ok(Message::Eof(_tid)) => {
                    self.end_of_file = true;
                }
                Ok(Message::WriteBlock(tid, block)) => {
                    debug_assert_eq!(
                        tid, self.thread_id,
                        "incoming thread must be equal to processing thread"
                    );
                    self.current_buffer = Cursor::new(block);
                }
                Err(e) => {
                    return Err(std::io::Error::new(std::io::ErrorKind::Other, e));
                }
            }
        }
        Ok(0)
    }
}
/// Demultiplexing side of `multiplex_write`: parses block headers out of incoming data and
/// routes each block to the processor running for its stream.
pub struct MultiplexReaderState<RESULT> {
    /// One channel per stream, indexed by thread id.
    sender_channels: Vec<Sender<Message>>,
    /// Collects each processor's outcome.
    result_receiver: ThreadResults<RESULT>,
    /// Forwarded unchanged to `PartialBuffer::take_n` / `take`.
    /// NOTE(review): exact meaning is defined by PartialBuffer; confirm there.
    retention_bytes: usize,
    /// Parser position; persists across `process_buffer` calls.
    current_state: State,
}

/// Position of the header/body parser within the multiplexed byte stream.
enum State {
    /// Expecting the one-byte header (thread id in bits 0-3, size flags in bits 4-5).
    StartBlock,
    /// Header carried no size flags; expecting two length bytes. Holds the thread id.
    U16Length(u8),
    /// Expecting a block body: (thread id, body length in bytes).
    Block(u8, usize),
}

impl<RESULT> MultiplexReaderState<RESULT> {
    /// Creates one processor job per stream and starts up to
    /// `min(num_threads, max_processor_threads)` pool workers that drain the job queue.
    ///
    /// Each job builds a [`MultiplexReader`] over its stream's channel and calls
    /// `processor(thread_id, reader)`. Its outcome is collected by [`Self::complete`].
    pub fn new<FN>(
        num_threads: usize,
        thread_pool: &dyn LeptonThreadPool,
        retention_bytes: usize,
        max_processor_threads: usize,
        processor: FN,
    ) -> MultiplexReaderState<RESULT>
    where
        FN: Fn(usize, &mut MultiplexReader) -> Result<RESULT> + Send + Sync + 'static,
        RESULT: Send + 'static,
    {
        let shared_processor = Arc::new(Box::new(processor));
        let mut sender_channels = Vec::new();
        let mut pending_jobs = VecDeque::new();
        let mut result_receiver = ThreadResults::new();

        for thread_id in 0..num_threads {
            let (block_tx, block_rx) = channel::<Message>();
            sender_channels.push(block_tx);

            let job_processor = Arc::clone(&shared_processor);
            let job = result_receiver.send_results(move || {
                let mut stream_reader = MultiplexReader {
                    thread_id,
                    receiver: block_rx,
                    current_buffer: Cursor::new(Vec::new()),
                    end_of_file: false,
                };
                job_processor(thread_id, &mut stream_reader)
            });
            pending_jobs.push_back(job);
        }

        let job_queue = Arc::new(Mutex::new(pending_jobs));
        let worker_count = num_threads.min(max_processor_threads);
        for _ in 0..worker_count {
            let queue = Arc::clone(&job_queue);
            thread_pool.run(Box::new(move || loop {
                // Pop in its own statement so the mutex guard is released before the job runs.
                let next_job = queue.lock().unwrap().pop_front();
                match next_job {
                    Some(job) => job(),
                    None => break,
                }
            }));
        }

        MultiplexReaderState {
            sender_channels,
            result_receiver,
            retention_bytes,
            current_state: State::StartBlock,
        }
    }

    /// Consumes as much of `source` as possible, forwarding every complete block to its stream.
    ///
    /// Parsing stops (keeping the current state) when `source` cannot supply the next piece.
    ///
    /// # Errors
    /// `ExitCode::BadLeptonFile` if a header names a thread id with no matching stream.
    pub fn process_buffer(&mut self, source: &mut PartialBuffer<'_>) -> Result<()> {
        while source.continue_processing() {
            let next_state = match self.current_state {
                State::StartBlock => {
                    let Some(header) = source.take_n::<1>(self.retention_bytes) else {
                        break;
                    };
                    let thread_marker = header[0];
                    let thread_id = thread_marker & 0xf;
                    if usize::from(thread_id) >= self.sender_channels.len() {
                        return err_exit_code(
                            ExitCode::BadLeptonFile,
                            format!("invalid thread_id {0}", thread_id),
                        );
                    }
                    if thread_marker < 16 {
                        // No size flags: an explicit 16-bit length follows.
                        State::U16Length(thread_id)
                    } else {
                        // Flags 1/2/3 select 4096/16384/65536-byte blocks; bits 6-7 are ignored.
                        let flags = (thread_marker >> 4) & 3;
                        State::Block(thread_id, 1024 << (2 * flags))
                    }
                }
                State::U16Length(thread_id) => {
                    let Some(length_bytes) = source.take_n::<2>(self.retention_bytes) else {
                        break;
                    };
                    // Little-endian `len - 1`.
                    let low = usize::from(length_bytes[0]);
                    let high = usize::from(length_bytes[1]);
                    State::Block(thread_id, (low | (high << 8)) + 1)
                }
                State::Block(thread_id, data_length) => {
                    let Some(block) = source.take(data_length, self.retention_bytes) else {
                        break;
                    };
                    let tid = usize::from(thread_id);
                    // A processor that already finished (or failed) has dropped its receiver;
                    // its outcome is reported by `complete`, so the send error is ignored here.
                    let _ = self.sender_channels[tid].send(Message::WriteBlock(tid, block));
                    State::StartBlock
                }
            };
            self.current_state = next_state;
        }
        Ok(())
    }

    /// Signals end-of-stream to every processor and waits for all of their outcomes.
    pub fn complete(&mut self) -> Result<Vec<RESULT>> {
        for (thread_id, sender) in self.sender_channels.iter().enumerate() {
            let _ = sender.send(Message::Eof(thread_id));
        }
        self.result_receiver.receive_results()
    }
}
#[cfg(test)]
use crate::DEFAULT_THREAD_POOL;
#[test]
fn test_multiplex_end_to_end() {
    use byteorder::ReadBytesExt;

    const STREAMS: usize = 10;
    const END: u32 = 10000;

    // Stream `t` writes the little-endian sequence t, t+1, ..., END-1.
    let mut encoded = Vec::new();
    let written = multiplex_write(
        &mut encoded,
        STREAMS,
        &DEFAULT_THREAD_POOL,
        |writer, thread_id| -> Result<usize> {
            for value in thread_id as u32..END {
                writer.write_u32::<byteorder::LittleEndian>(value)?;
            }
            Ok(thread_id)
        },
    )
    .unwrap();
    assert_eq!(written[..], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);

    let mut extra = Vec::new();
    let mut reader_state = MultiplexReaderState::new(
        STREAMS,
        &DEFAULT_THREAD_POOL,
        0,
        8,
        |thread_id, reader| -> Result<usize> {
            for expected in thread_id as u32..END {
                let actual = reader.read_u32::<byteorder::LittleEndian>()?;
                assert_eq!(actual, expected);
            }
            Ok(thread_id)
        },
    );

    // Feed one byte at a time so every header/body boundary is split across buffers.
    for byte in encoded.chunks(1) {
        let mut partial = PartialBuffer::new(byte, &mut extra);
        reader_state.process_buffer(&mut partial).unwrap();
    }

    let read_back = reader_state.complete().unwrap();
    assert_eq!(read_back[..], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
}
#[cfg(test)]
use crate::lepton_error::LeptonError;
#[test]
fn test_multiplex_read_error() {
    // Every processor fails immediately; complete() must surface that error.
    let mut reader_state = MultiplexReaderState::new(
        10,
        &DEFAULT_THREAD_POOL,
        0,
        8,
        |_, _| -> Result<usize> { Err(LeptonError::new(ExitCode::FileNotFound, "test error"))? },
    );
    let outcome = reader_state.complete();
    let error: LeptonError = outcome.unwrap_err().into();
    assert_eq!(error.exit_code(), ExitCode::FileNotFound);
    assert!(error.message().starts_with("test error"));
}
#[test]
fn test_multiplex_read_panic() {
    // A panicking processor must be reported as an error rather than unwinding into the caller.
    let mut reader_state = MultiplexReaderState::new(
        10,
        &DEFAULT_THREAD_POOL,
        0,
        8,
        |_, _| -> Result<usize> { panic!() },
    );
    let outcome = reader_state.complete();
    let error: LeptonError = outcome.unwrap_err().into();
    assert_eq!(error.exit_code(), ExitCode::AssertionFailure);
}
#[test]
fn test_multiplex_write_error() {
    let mut sink = Vec::new();
    let outcome = multiplex_write(
        &mut sink,
        10,
        &DEFAULT_THREAD_POOL,
        |_, thread_id| -> Result<usize> {
            // Only stream 3 fails; the rest succeed without writing anything.
            if thread_id != 3 {
                return Ok(0);
            }
            Err(LeptonError::new(ExitCode::FileNotFound, "test error"))?
        },
    );
    let error: LeptonError = outcome.unwrap_err().into();
    assert_eq!(error.exit_code(), ExitCode::FileNotFound);
    assert!(error.message().starts_with("test error"));
}
#[test]
fn test_multiplex_write_panic() {
    let mut sink = Vec::new();
    let outcome = multiplex_write(
        &mut sink,
        10,
        &DEFAULT_THREAD_POOL,
        |_, thread_id| -> Result<usize> {
            // Stream 5 panics; the panic must come back as an error, not unwind into the caller.
            if thread_id == 5 {
                panic!();
            }
            Ok(0)
        },
    );
    let error: LeptonError = outcome.unwrap_err().into();
    assert_eq!(error.exit_code(), ExitCode::AssertionFailure);
}