use std::hash::Hash;
use crossbeam::channel::{Receiver, Select, SelectedOperation, Sender};
use rustc_hash::FxHashMap;
pub struct TaggedSelect<'a, T: Sized + Copy + Eq + Hash> {
tag_by_index: FxHashMap<usize, T>,
index_by_tag: FxHashMap<T, usize>,
select: Select<'a>,
}
impl<'a, T: Sized + Copy + Eq + Hash> TaggedSelect<'a, T> {
pub fn new() -> Self {
TaggedSelect {
tag_by_index: FxHashMap::default(),
index_by_tag: FxHashMap::default(),
select: Select::new(),
}
}
pub fn recv<E>(&mut self, tag: T, receiver: &'a Receiver<E>) {
if !self.index_by_tag.contains_key(&tag) {
let operation_index = self.select.recv(receiver);
self.tag_by_index.insert(operation_index, tag);
self.index_by_tag.insert(tag, operation_index);
}
}
#[allow(dead_code)]
pub fn send<E>(&mut self, tag: T, sender: &'a Sender<E>) {
if !self.index_by_tag.contains_key(&tag) {
let operation_index = self.select.send(sender);
self.tag_by_index.insert(operation_index, tag);
self.index_by_tag.insert(tag, operation_index);
}
}
pub fn remove(&mut self, marker: T) {
if let Some(index) = self.index_by_tag.remove(&marker) {
self.tag_by_index.remove(&index);
self.select.remove(index);
}
}
pub fn select(&mut self) -> (T, SelectedOperation) {
let operation = self.select.select();
self.tag_by_index
.get(&operation.index())
.map(|&t| (t, operation))
.unwrap()
}
}
#[cfg(test)]
mod tests {
    use anyhow::{anyhow, Context, Result};
    use crossbeam::channel::{bounded, unbounded};

    use super::TaggedSelect;

    /// Tags used to label the operations registered in these tests.
    #[derive(Debug, Eq, PartialEq, Hash, Copy, Clone)]
    enum TestTag {
        First,
        Second,
    }

    /// Only `recv_1` has a pending message, so `select` must report `First`.
    #[test]
    fn select_recv() -> Result<()> {
        let mut select = TaggedSelect::new();
        let (send_1, recv_1) = unbounded::<u8>();
        // Keep the sender alive so `recv_2` stays connected; a disconnected
        // receiver would also count as ready.
        let (_send_2, recv_2) = unbounded::<u16>();
        select.recv(TestTag::First, &recv_1);
        select.recv(TestTag::Second, &recv_2);
        send_1.send(12u8)?;
        // NOTE: binding the match to a local (instead of returning it directly)
        // keeps the temporary from `select.select()` from outliving `select` and
        // the channels, since tail-expression temporaries are dropped after locals.
        let result = match select.select() {
            (TestTag::First, oper) => {
                if let Ok(payload) = oper.recv(&recv_1) {
                    assert_eq!(payload, 12u8);
                    Ok(())
                } else {
                    Err(anyhow!("Got unexpected message"))
                }
            }
            (TestTag::Second, _oper) => Err(anyhow!("Got unexpected operation")),
        };
        result
    }

    /// `send_1`'s bounded buffer is already full, so only `Second` can proceed.
    #[test]
    fn select_send() -> Result<()> {
        let mut select = TaggedSelect::new();
        let (send_1, _recv_1) = bounded::<u8>(1);
        let (send_2, _recv_2) = unbounded::<u16>();
        select.send(TestTag::Second, &send_2);
        select.send(TestTag::First, &send_1);
        send_1.send(12u8)?;
        let result = match select.select() {
            (TestTag::First, _oper) => Err(anyhow!("Got unexpected operation")),
            (TestTag::Second, oper) => oper.send(&send_2, 16u16).context("Send operation failed"),
        };
        result
    }

    /// A removed tag is never reported by `select`, even when its channel is ready,
    /// and removing an unknown tag is a no-op.
    #[test]
    fn remove_excludes_operation() -> Result<()> {
        let mut select = TaggedSelect::new();
        let (send_1, recv_1) = unbounded::<u8>();
        let (send_2, recv_2) = unbounded::<u16>();
        select.recv(TestTag::First, &recv_1);
        select.recv(TestTag::Second, &recv_2);
        send_1.send(12u8)?;
        send_2.send(16u16)?;
        select.remove(TestTag::First);
        select.remove(TestTag::First);
        let (tag, oper) = select.select();
        let payload = oper.recv(&recv_2).context("Receive operation failed")?;
        assert_eq!(tag, TestTag::Second);
        assert_eq!(payload, 16u16);
        Ok(())
    }

    /// Registering a tag twice keeps the first operation and ignores the second.
    #[test]
    fn duplicate_tag_keeps_first_registration() -> Result<()> {
        let mut select = TaggedSelect::new();
        let (send_1, recv_1) = unbounded::<u8>();
        let (send_2, recv_2) = unbounded::<u8>();
        select.recv(TestTag::First, &recv_1);
        select.recv(TestTag::First, &recv_2);
        send_1.send(12u8)?;
        send_2.send(16u8)?;
        let (tag, oper) = select.select();
        // Only `recv_1` was registered, so it is the only receiver that may complete `oper`.
        let payload = oper.recv(&recv_1).context("Receive operation failed")?;
        assert_eq!(tag, TestTag::First);
        assert_eq!(payload, 12u8);
        Ok(())
    }
}