pargraph 0.2.0

Operator-based parallel graph processing.
Documentation
// SPDX-FileCopyrightText: 2022 Thomas Kramer <code@tkramer.ch>
//
// SPDX-License-Identifier: GPL-3.0-or-later

//! Naive implementation of a worklist based on a `VecDeque` and multi-producer-single-consumer channels.

use std::collections::VecDeque;
use std::sync::mpsc;
use std::thread;

use super::{Worklist, WorklistChannel};

/// Naive worklist implementation based on message passing.
///
/// The worklist content lives in a background server thread. This handle only keeps
/// the sending side of the query channel and the receiving side of the channel on which
/// the server hands back newly created [`NaiveWorklistChannel`]s.
pub struct NaiveWorklist<T> {
    /// Number of items the worklist was created with.
    initial_len: usize,
    /// Sends queries to the server thread.
    query_tx: mpsc::Sender<WorklistQuery<T>>,
    /// Receives channels created by the server thread in answer to `NewChannel` queries.
    management_rx: mpsc::Receiver<NaiveWorklistChannel<T>>,
}

/// State owned by the worklist server thread.
struct WorklistThreadData<T> {
    /// Incoming queries (push, pop, length, ...) from every channel.
    query_rx: mpsc::Receiver<WorklistQuery<T>>,
    /// Sender for the server's own query channel; each new channel gets a clone of it.
    query_tx: mpsc::Sender<WorklistQuery<T>>,
    /// Response senders of all registered channels, indexed by channel ID.
    response_channels: Vec<mpsc::Sender<WorklistResponse<T>>>,
    /// The worklist items, pushed at the back and popped from the front.
    storage: VecDeque<T>,
    /// IDs of channels whose 'pop' query is still unanswered, oldest first.
    waiting_pop_queries: VecDeque<usize>,
}

impl<T> WorklistThreadData<T> {
    /// Register a fresh response channel and build the client handle that goes with it.
    ///
    /// The handle's ID is the index of its response sender in `response_channels`,
    /// which is how the server routes answers back to it.
    fn create_channel(&mut self) -> NaiveWorklistChannel<T> {
        let (response_tx, response_rx) = mpsc::channel();
        // Take the index before pushing so it points at the sender added below.
        let id = self.response_channels.len();
        self.response_channels.push(response_tx);

        NaiveWorklistChannel {
            rx: response_rx,
            tx: self.query_tx.clone(),
            channel_id: id,
        }
    }
}

impl<T> NaiveWorklist<T>
where
    T: Send + Sync + 'static,
{
    /// Create a new empty worklist.
    pub fn empty() -> Self {
        Self::new(Default::default())
    }

    /// Create a new worklist with some initial content.
    /// This will spawn a server thread in the background which
    /// manages the queries from the communication channels.
    pub fn new(initial_content: VecDeque<T>) -> Self {
        let (query_tx, query_rx) = mpsc::channel();
        let (management_tx, management_rx) = mpsc::channel();

        let initial_len = initial_content.len();

        let mut thread_data = WorklistThreadData {
            storage: initial_content,
            query_rx,
            query_tx: query_tx.clone(),
            response_channels: Default::default(),
            waiting_pop_queries: Default::default(),
        };

        thread::spawn(move || {
            while let Ok(msg) = thread_data.query_rx.recv() {
                match msg {
                    WorklistQuery::Push(_channel_id, item) => {
                        if let Some(channel_id) = thread_data.waiting_pop_queries.pop_front() {
                            // There's a channel waiting for an item. Forward it with out storing it.
                            thread_data.response_channels[channel_id]
                                .send(WorklistResponse::Item(item))
                                .expect("failed to send response");
                        } else {
                            // No channel is waiting for an item. Store it.
                            thread_data.storage.push_back(item);
                        }
                    }
                    WorklistQuery::Pop(channel_id) => {
                        if let Some(item) = thread_data.storage.pop_front() {
                            // Answer the query with the first element from the queue.
                            thread_data.response_channels[channel_id]
                                .send(WorklistResponse::Item(item))
                                .expect("failed to send response");
                        } else {
                            // There's no data available. Remember the query to answer it once an item arrives.
                            thread_data.waiting_pop_queries.push_back(channel_id);
                        }
                    }
                    WorklistQuery::WorklistLen(channel_id) => {
                        thread_data.response_channels[channel_id]
                            .send(WorklistResponse::Size(thread_data.storage.len()))
                            .expect("failed to send response");
                    }
                    WorklistQuery::NewChannel => management_tx
                        .send(thread_data.create_channel())
                        .expect("Failed to send new channel."),
                    WorklistQuery::TerminateServerThread => {
                        // Notify waiting channels that there's no more data coming.
                        thread_data.response_channels.iter().for_each(|tx| {
                            let _ignore_err = tx.send(WorklistResponse::Terminated);
                        });
                        break;
                    }
                }
            }
        });

        Self {
            query_tx,
            management_rx,
            initial_len,
        }
    }
}

impl<T> NaiveWorklist<T> {
    /// Ask the background server thread to notify all channels and shut down.
    fn stop_server(&self) {
        let request = WorklistQuery::TerminateServerThread;
        // `send` only fails when the server has already dropped its query receiver,
        // i.e. the thread is already stopped. That is not an error here.
        let _ignore_error = self.query_tx.send(request);
    }
}

impl<T> Worklist<T> for NaiveWorklist<T> {
    type Channel = NaiveWorklistChannel<T>;

    /// Request a new channel from the server thread and block until it is delivered.
    ///
    /// # Panics
    /// Panics if the server thread is no longer running.
    fn create_channel(&mut self) -> Self::Channel {
        let Self {
            query_tx,
            management_rx,
            ..
        } = self;
        // The server registers the channel and replies on the management channel.
        query_tx
            .send(WorklistQuery::NewChannel)
            .expect("Failed to request new channel.");
        management_rx
            .recv()
            .expect("Failed to receive new channel.")
    }

    /// Signal that no more items will arrive by shutting down the server thread.
    fn stop(&mut self) {
        self.stop_server()
    }

    /// Number of items the worklist was created with.
    fn initial_len(&self) -> usize {
        let len = self.initial_len;
        len
    }
}

impl<T> Drop for NaiveWorklist<T> {
    /// Stop the server thread when leaving scope.
    fn drop(&mut self) {
        self.stop_server();
    }
}

/// Query sent to the worklist server by the main thread or by worker threads.
///
/// Variants carrying a `usize` include the ID of the channel that sent them.
enum WorklistQuery<T> {
    /// Ask the server to register and return a new communication channel.
    NewChannel,
    /// Add an item to the back of the worklist.
    Push(usize, T),
    /// Request an item from the front of the worklist; answered once one is available.
    Pop(usize),
    /// Request the current number of stored items.
    WorklistLen(usize),
    /// Ask the server to notify all channels and stop its thread.
    TerminateServerThread,
}

/// Answer sent by the worklist server back to a single channel.
#[derive(Debug, Copy, Clone)]
enum WorklistResponse<T> {
    /// The server thread is shutting down; no more items will be delivered.
    Terminated,
    /// An item answering a `pop()` query.
    Item(T),
    /// Number of items currently stored in the global worklist.
    Size(usize),
}

/// Handle through which a worker thread talks to the worklist server.
///
/// Every handle shares the server's query channel and owns a private response
/// channel. The server uses `channel_id` to decide where to send its answers.
pub struct NaiveWorklistChannel<T> {
    /// Outgoing side: push/pop/length queries for the worklist server.
    tx: mpsc::Sender<WorklistQuery<T>>,
    /// Incoming side: items, sizes and termination notices from the server.
    rx: mpsc::Receiver<WorklistResponse<T>>,
    /// Index of this channel's response sender in the server's channel registry.
    channel_id: usize,
}

impl<T> WorklistChannel<T> for NaiveWorklistChannel<T> {
    /// Add `item` to the global worklist.
    ///
    /// # Panics
    /// Panics if the worklist server thread is no longer running.
    fn push(&self, item: T) {
        self.tx
            .send(WorklistQuery::Push(self.channel_id, item))
            .expect("Failed to send query to worklist.");
    }

    /// Take the next item from the worklist, blocking while the worklist is empty.
    ///
    /// Returns `None` once the worklist server has shut down. This covers both a
    /// termination notice from the server and a server that is already gone.
    fn pop(&self) -> Option<T> {
        // Sending fails only when the server thread has dropped its query receiver.
        // No more items can arrive then, so report exhaustion instead of panicking.
        self.tx.send(WorklistQuery::Pop(self.channel_id)).ok()?;
        // Receiving fails when the server dropped its response senders on shutdown.
        match self.rx.recv().ok()? {
            WorklistResponse::Item(i) => Some(i),
            WorklistResponse::Terminated => None,
            WorklistResponse::Size(_) => {
                unreachable!("received a size response to a 'pop()' query")
            }
        }
    }

    /// Number of items buffered locally in this channel. It is always 0 because this
    /// implementation forwards every operation to the server.
    fn local_len(&self) -> usize {
        0
    }

    /// Ask the server for the number of items currently stored in the worklist.
    ///
    /// # Panics
    /// Panics if the server thread is not running or terminates before answering.
    fn global_len(&self) -> usize {
        // Send query.
        self.tx
            .send(WorklistQuery::WorklistLen(self.channel_id))
            .expect("Failed to send query to worklist.");
        // Receive response.
        match self
            .rx
            .recv()
            .expect("Failed to receive result to 'global_len()' query.")
        {
            WorklistResponse::Size(s) => s,
            WorklistResponse::Item(_) => {
                unreachable!("received an item in response to a 'global_len()' query")
            }
            // The server sends `Terminated` to every channel when it shuts down, so this
            // arm is reachable if the worklist is stopped before the query is handled.
            WorklistResponse::Terminated => {
                panic!("worklist server terminated before answering 'global_len()' query")
            }
        }
    }

    /// Close this channel.
    ///
    /// The channel holds no local items (see `local_len`), so there is nothing to hand
    /// back. Closing just releases the query sender and response receiver.
    /// NOTE(review): assumes the trait's `close` contract needs no further action for
    /// a channel without local state — confirm against the `WorklistChannel` docs.
    fn close(self) {
        drop(self);
    }
}

#[test]
fn test_naive_worklist() {
    const REPETITIONS: usize = 10000;
    (0..REPETITIONS).for_each(|_| {
        let mut worklist = NaiveWorklist::empty();
        let first = worklist.create_channel();
        let second = worklist.create_channel();

        // Items come out in FIFO order, no matter which channel pushed or pops them.
        first.push(1);
        second.push(2);
        first.push(3);
        assert_eq!(second.pop(), Some(1));
        assert_eq!(first.pop(), Some(2));
        assert_eq!(first.pop(), Some(3));

        // Popping from the now-empty worklist before stopping it would block, so no
        // further pops are made here.
        worklist.stop();
    });
}

#[test]
fn test_naive_worklist_from_other_thread() {
    const REPETITIONS: usize = 10000;
    (0..REPETITIONS).for_each(|_| {
        let mut worklist = NaiveWorklist::empty();
        let (producer_a, producer_b) = (worklist.create_channel(), worklist.create_channel());

        // Two producer threads push three items in total.
        let producers = [
            thread::spawn(move || {
                producer_a.push(1);
                producer_a.push(2);
            }),
            thread::spawn(move || {
                producer_b.push(3);
            }),
        ];
        for producer in producers {
            producer.join().unwrap();
        }

        // A channel created afterwards must see every pushed item.
        let consumer = worklist.create_channel();
        for _ in 0..3 {
            assert_ne!(consumer.pop(), None);
        }
        worklist.stop();
    });
}