use std::collections::VecDeque;
use std::sync::mpsc;
use std::thread;
use super::{Worklist, WorklistChannel};
/// Worklist handle whose items are owned by a dedicated server thread.
///
/// The handle itself stores no items: it talks to the server thread through
/// `query_tx` and receives newly created channels through `management_rx`.
pub struct NaiveWorklist<T> {
/// Sends queries (channel creation, termination) to the server thread.
query_tx: mpsc::Sender<WorklistQuery<T>>,
/// Receives the channels the server thread creates in reply to `NewChannel`.
management_rx: mpsc::Receiver<NaiveWorklistChannel<T>>,
/// Number of items in the content passed to `new`, recorded at construction.
initial_len: usize,
}
/// State owned by the server thread of a `NaiveWorklist`.
struct WorklistThreadData<T> {
/// Items waiting to be popped; pushed at the back and popped from the front.
storage: VecDeque<T>,
/// Ids of channels whose `Pop` arrived while `storage` was empty, in arrival order.
waiting_pop_queries: VecDeque<usize>,
/// Incoming queries from the worklist handle and all channels.
query_rx: mpsc::Receiver<WorklistQuery<T>>,
/// Query sender that is cloned into every newly created channel.
query_tx: mpsc::Sender<WorklistQuery<T>>,
/// Response sender of every channel created so far, indexed by channel id.
response_channels: Vec<mpsc::Sender<WorklistResponse<T>>>,
}
impl<T> WorklistThreadData<T> {
    /// Registers a new client and returns its channel.
    ///
    /// The new channel's id is the index of its response sender in
    /// `response_channels`, so ids are handed out consecutively starting at 0.
    fn create_channel(&mut self) -> NaiveWorklistChannel<T> {
        let (response_tx, response_rx) = mpsc::channel();
        // The id must be taken before the push so that it equals the slot index.
        let id = self.response_channels.len();
        self.response_channels.push(response_tx);
        let query_tx = self.query_tx.clone();
        NaiveWorklistChannel {
            channel_id: id,
            tx: query_tx,
            rx: response_rx,
        }
    }
}
impl<T> NaiveWorklist<T>
where
T: Send + Sync + 'static,
{
pub fn empty() -> Self {
Self::new(Default::default())
}
pub fn new(initial_content: VecDeque<T>) -> Self {
let (query_tx, query_rx) = mpsc::channel();
let (management_tx, management_rx) = mpsc::channel();
let initial_len = initial_content.len();
let mut thread_data = WorklistThreadData {
storage: initial_content,
query_rx,
query_tx: query_tx.clone(),
response_channels: Default::default(),
waiting_pop_queries: Default::default(),
};
thread::spawn(move || {
while let Ok(msg) = thread_data.query_rx.recv() {
match msg {
WorklistQuery::Push(_channel_id, item) => {
if let Some(channel_id) = thread_data.waiting_pop_queries.pop_front() {
thread_data.response_channels[channel_id]
.send(WorklistResponse::Item(item))
.expect("failed to send response");
} else {
thread_data.storage.push_back(item);
}
}
WorklistQuery::Pop(channel_id) => {
if let Some(item) = thread_data.storage.pop_front() {
thread_data.response_channels[channel_id]
.send(WorklistResponse::Item(item))
.expect("failed to send response");
} else {
thread_data.waiting_pop_queries.push_back(channel_id);
}
}
WorklistQuery::WorklistLen(channel_id) => {
thread_data.response_channels[channel_id]
.send(WorklistResponse::Size(thread_data.storage.len()))
.expect("failed to send response");
}
WorklistQuery::NewChannel => management_tx
.send(thread_data.create_channel())
.expect("Failed to send new channel."),
WorklistQuery::TerminateServerThread => {
thread_data.response_channels.iter().for_each(|tx| {
let _ignore_err = tx.send(WorklistResponse::Terminated);
});
break;
}
}
}
});
Self {
query_tx,
management_rx,
initial_len,
}
}
}
impl<T> NaiveWorklist<T> {
    /// Asks the server thread to shut down.
    ///
    /// This is best effort: a failed send is ignored on purpose, so calling this
    /// more than once is harmless.
    fn stop_server(&self) {
        let termination_request = WorklistQuery::TerminateServerThread;
        let _ = self.query_tx.send(termination_request);
    }
}
impl<T> Worklist<T> for NaiveWorklist<T> {
    type Channel = NaiveWorklistChannel<T>;

    /// Requests a new channel from the server thread and blocks until it arrives.
    ///
    /// # Panics
    /// Panics if the request cannot be sent or no channel is received in reply.
    fn create_channel(&mut self) -> Self::Channel {
        let request = self.query_tx.send(WorklistQuery::NewChannel);
        request.expect("Failed to request new channel.");

        let reply = self.management_rx.recv();
        reply.expect("Failed to receive new channel.")
    }

    /// Asks the server thread to terminate.
    fn stop(&mut self) {
        self.stop_server();
    }

    /// Returns the number of items the worklist was constructed with.
    fn initial_len(&self) -> usize {
        self.initial_len
    }
}
impl<T> Drop for NaiveWorklist<T> {
    /// Requests termination of the server thread when the handle goes away.
    fn drop(&mut self) {
        Self::stop_server(self);
    }
}
/// Message sent to the server thread by a `NaiveWorklist` or one of its channels.
enum WorklistQuery<T> {
/// Add the item to the worklist. The first field is the sender's channel id,
/// which the server thread does not use.
Push(usize, T),
/// Request the next item on behalf of the channel with the given id.
Pop(usize),
/// Request creation of a new channel; the reply arrives on the management channel.
NewChannel,
/// Request the number of stored items on behalf of the channel with the given id.
WorklistLen(usize),
/// Ask the server thread to notify all channels and exit.
TerminateServerThread,
}
/// Message sent from the server thread back to a single channel.
#[derive(Debug, Copy, Clone)]
enum WorklistResponse<T> {
/// An item handed to a channel that asked for one with `Pop`.
Item(T),
/// The number of stored items, sent in reply to `WorklistLen`.
Size(usize),
/// Sent to every channel when the server thread shuts down.
Terminated,
}
/// Client-side handle for talking to the server thread of a `NaiveWorklist`.
pub struct NaiveWorklistChannel<T> {
/// Id assigned by the server thread; it is included in queries so that the
/// server knows where to send the response.
channel_id: usize,
/// Sends queries to the server thread.
tx: mpsc::Sender<WorklistQuery<T>>,
/// Receives the responses addressed to this channel.
rx: mpsc::Receiver<WorklistResponse<T>>,
}
impl<T> WorklistChannel<T> for NaiveWorklistChannel<T> {
    /// Sends `item` to the server thread to be added to the worklist.
    ///
    /// # Panics
    /// Panics if the query cannot be delivered, i.e. the server thread has already
    /// shut down. There is no return value through which this could be reported.
    fn push(&self, item: T) {
        self.tx
            .send(WorklistQuery::Push(self.channel_id, item))
            .expect("Failed to send query to worklist.");
    }

    /// Requests the next item and blocks until the server thread answers.
    ///
    /// Returns `Some(item)` once an item is available, and `None` when the worklist
    /// has been terminated. A terminated worklist is detected either through the
    /// explicit `Terminated` response or through a disconnected channel (the server
    /// thread is gone), so popping after shutdown yields `None` instead of panicking.
    fn pop(&self) -> Option<T> {
        // A failed send means the server thread no longer exists: nothing to pop.
        self.tx.send(WorklistQuery::Pop(self.channel_id)).ok()?;
        match self.rx.recv() {
            Ok(WorklistResponse::Item(item)) => Some(item),
            Ok(WorklistResponse::Size(_)) => {
                unreachable!("received a size response to a 'pop()' query")
            }
            Ok(WorklistResponse::Terminated) | Err(_) => None,
        }
    }

    /// This implementation keeps no per-channel buffer, so the local length is
    /// always 0.
    fn local_len(&self) -> usize {
        0
    }

    /// Asks the server thread for the number of items it currently stores.
    ///
    /// Returns 0 if the worklist has been terminated: the server thread drops all
    /// stored items when it exits. This also covers a call that races with shutdown
    /// and receives `Terminated` instead of a size.
    fn global_len(&self) -> usize {
        let query = WorklistQuery::WorklistLen(self.channel_id);
        if self.tx.send(query).is_err() {
            return 0;
        }
        match self.rx.recv() {
            Ok(WorklistResponse::Size(len)) => len,
            Ok(WorklistResponse::Item(_)) => {
                unreachable!("received an item in response to a 'global_len()' query")
            }
            Ok(WorklistResponse::Terminated) | Err(_) => 0,
        }
    }

    /// Closes the channel by dropping its query sender and response receiver.
    ///
    /// The server thread keeps this channel's slot in its response table (ids are
    /// indices into that table). On termination it ignores send errors, so a closed
    /// channel does not disturb shutdown.
    fn close(self) {
        drop(self);
    }
}
/// Pushes through two channels from a single thread and checks that items come
/// back in push order, regardless of which channel pops them.
#[test]
fn test_naive_worklist() {
    for _ in 0..10000 {
        let mut worklist = NaiveWorklist::empty();
        let first = worklist.create_channel();
        let second = worklist.create_channel();

        first.push(1);
        second.push(2);
        first.push(3);

        // Both channels feed the same query queue from this one thread, so the
        // server sees the pushes in program order and serves them first-in-first-out.
        assert_eq!(second.pop(), Some(1));
        assert_eq!(first.pop(), Some(2));
        assert_eq!(first.pop(), Some(3));

        worklist.stop();
    }
}
/// Pushes three items from two producer threads and checks that a channel created
/// afterwards can pop three items.
#[test]
fn test_naive_worklist_from_other_thread() {
    for _ in 0..10000 {
        let mut worklist = NaiveWorklist::empty();
        let producer_a = worklist.create_channel();
        let producer_b = worklist.create_channel();

        let producers = vec![
            thread::spawn(move || {
                producer_a.push(1);
                producer_a.push(2);
            }),
            thread::spawn(move || {
                producer_b.push(3);
            }),
        ];
        // Wait for both producers so that all three pushes have been sent.
        for producer in producers {
            producer.join().unwrap();
        }

        // The interleaving between the two producers is not fixed, so only the
        // presence of three items is checked, not their order.
        let consumer = worklist.create_channel();
        for _ in 0..3 {
            assert_ne!(consumer.pop(), None);
        }

        worklist.stop();
    }
}