use std::cell::UnsafeCell;
use std::sync::atomic::{AtomicU16, Ordering};
/// Number of storage cells in a channel; also the capacity of each packed index queue.
const SLOTS: usize = 5;
/// Width in bits of one entry in a packed index queue.
const BITS: u16 = 3;
/// Selects a single `BITS`-wide entry; derived from `BITS` so the two cannot drift apart.
const MASK: u16 = (1 << BITS) - 1;

// Compile-time layout checks for the packed queues. All `SLOTS` entries must fit
// in one `u16`. Every 1-based slot index must also fit in a single entry, because
// an entry of 0 is reserved to mean "unused position".
const _: () = assert!(
    SLOTS * (BITS as usize) <= u16::BITS as usize,
    "SLOTS * BITS must fit in a u16"
);
const _: () = assert!(
    SLOTS <= MASK as usize,
    "slot indices 1..=SLOTS must fit in BITS bits"
);
/// Extracts the `BITS`-wide entry stored at slot position `idx` of the packed word `n`.
fn get(n: u16, idx: u16) -> u16 {
    let shift = idx * BITS;
    MASK & (n >> shift)
}
/// Returns a copy of `n` whose slot `idx` holds `v`; every other slot is kept.
///
/// `v` is shifted into place without being masked. Callers must therefore pass
/// a value that fits in `BITS` bits, or it will spill into the next slot.
fn set(n: u16, idx: u16, v: u16) -> u16 {
    let shift = idx * BITS;
    let cleared = n & !(MASK << shift);
    cleared | (v << shift)
}
/// Appends `val` to the packed queue `q` by writing it into the lowest slot that holds 0.
///
/// `val` should be non-zero, because a 0 entry is what marks a free slot.
/// The successful exchange uses `Release`, which pairs with the `Acquire` in `dequeue`.
///
/// # Panics
///
/// Panics with "No empty slot available" when every slot is occupied.
fn enqueue(q: &AtomicU16, val: u16) {
    let mut seen = q.load(Ordering::Relaxed);
    loop {
        let free_slot = (0..SLOTS as u16)
            .find(|&slot| get(seen, slot) == 0)
            .expect("No empty slot available");
        let updated = set(seen, free_slot, val);
        let outcome =
            q.compare_exchange_weak(seen, updated, Ordering::Release, Ordering::Relaxed);
        if let Err(actual) = outcome {
            // Lost a race (or failed spuriously): recompute against the fresh value.
            seen = actual;
        } else {
            return;
        }
    }
}
/// Pops the entry in slot 0 of the packed queue `q`.
///
/// Returns `None` when slot 0 holds 0, meaning the queue is empty. On success,
/// the remaining entries move down one slot. The successful exchange uses
/// `Acquire`, which pairs with the `Release` in `enqueue`.
fn dequeue(q: &AtomicU16) -> Option<u16> {
    let mut seen = q.load(Ordering::Relaxed);
    loop {
        let head = seen & MASK;
        if head == 0 {
            return None;
        }
        // Drop the head entry by shifting every remaining entry down one slot.
        let rest = seen >> BITS;
        let outcome = q.compare_exchange_weak(seen, rest, Ordering::Acquire, Ordering::Relaxed);
        if let Err(actual) = outcome {
            // Lost a race (or failed spuriously): retry against the fresh value.
            seen = actual;
        } else {
            return Some(head);
        }
    }
}
/// Fixed-capacity channel that holds at most `SLOTS` values.
///
/// Cell ownership is tracked by two packed index queues instead of a lock.
/// An index sits in `empty` while its cell is free, and in `full` while the
/// cell holds a value waiting to be received.
pub struct Channel<T> {
    /// Queue of 1-based indices whose storage cell is currently free.
    empty: AtomicU16,
    /// Queue of 1-based indices whose storage cell holds a value, in enqueue order.
    full: AtomicU16,
    /// Backing cells; a queue entry `i` refers to `storage[i - 1]`.
    storage: [UnsafeCell<Option<T>>; SLOTS],
}
impl<T> Channel<T> {
    /// Creates an empty channel whose `SLOTS` storage cells are all free.
    pub fn new() -> Self {
        let me = Self {
            storage: Default::default(),
            empty: AtomicU16::new(0),
            full: AtomicU16::new(0),
        };
        // Indices are 1-based because a 0 entry marks an unused queue position.
        for i in 1..=SLOTS {
            enqueue(&me.empty, i as u16);
        }
        me
    }

    /// Attempts to send `val` without blocking.
    ///
    /// # Errors
    ///
    /// Returns `Err(val)` when every cell is occupied, handing the value back
    /// to the caller instead of dropping it.
    pub fn try_send(&self, val: T) -> Result<(), T> {
        match dequeue(&self.empty) {
            Some(idx) => {
                // SAFETY: `idx` was just popped from `empty`, so this thread is the
                // only one allowed to touch the cell. Senders only write cells they
                // dequeued from `empty`, and receivers only take from cells they
                // dequeued from `full`. Any earlier `take()` on this cell was
                // published by the `Release` in `enqueue(&self.empty, ..)` and is
                // observed by the `Acquire` in `dequeue` above.
                unsafe { *self.storage[usize::from(idx) - 1].get() = Some(val) };
                // Publish the written value; pairs with the `Acquire` dequeue in `recv`.
                enqueue(&self.full, idx);
                Ok(())
            }
            None => Err(val),
        }
    }

    /// Sends `val` if a cell is free.
    ///
    /// When the channel is full the value is silently dropped; use
    /// [`Channel::try_send`] to get it back instead.
    pub fn send(&self, val: T) {
        // Discarding the rejected value keeps the fire-and-forget contract of `send`.
        let _ = self.try_send(val);
    }

    /// Receives the oldest queued value, or `None` if nothing is queued.
    ///
    /// # Panics
    ///
    /// Panics if an index popped from `full` refers to an empty cell. That would
    /// mean the channel's internal bookkeeping is corrupted.
    pub fn recv(&self) -> Option<T> {
        let idx = dequeue(&self.full)?;
        // SAFETY: `idx` was just popped from `full`, so this thread has exclusive
        // access to the cell. The sender's write was published by the `Release`
        // in `enqueue(&self.full, ..)` and is observed by the `Acquire` in `dequeue`.
        let val = unsafe { (*self.storage[usize::from(idx) - 1].get()).take() }
            .expect("Full slot with nothing in it");
        // Hand the now-empty cell back to senders.
        enqueue(&self.empty, idx);
        Some(val)
    }
}
impl<T> Default for Channel<T> {
fn default() -> Self {
Self::new()
}
}
unsafe impl<T: Send> Send for Channel<T> {}
unsafe impl<T: Send> Sync for Channel<T> {}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// Polling a freshly built channel repeatedly yields nothing.
    #[test]
    fn new_empty() {
        let ch: Channel<usize> = Channel::new();
        for _ in 0..2 {
            assert!(ch.recv().is_none());
        }
    }

    /// One value round-trips and leaves the channel drained.
    #[test]
    fn pass_value() {
        let ch = Channel::new();
        ch.send(42);
        assert_eq!(ch.recv(), Some(42));
        assert_eq!(ch.recv(), None);
    }

    /// Alternating send/recv many times keeps returning each value exactly once.
    #[test]
    fn multiple() {
        let ch = Channel::new();
        (0..1000).for_each(|i| {
            ch.send(i);
            assert_eq!(ch.recv(), Some(i));
            assert_eq!(ch.recv(), None);
        });
    }

    /// Only the first `SLOTS` sends are kept; the rest are dropped by `send`.
    #[test]
    fn overflow() {
        let ch = Channel::new();
        for i in 0..10 {
            ch.send(i);
        }
        let drained: Vec<_> = (0..5).map(|_| ch.recv().unwrap()).collect();
        assert_eq!(drained, vec![0, 1, 2, 3, 4]);
        assert!(ch.recv().is_none());
    }

    /// Values sent from another thread arrive in order.
    #[test]
    fn multi_thread() {
        let ch = Arc::new(Channel::<usize>::new());
        let producer = {
            let ch = Arc::clone(&ch);
            thread::spawn(move || (0..4).for_each(|i| ch.send(i)))
        };
        let mut got = Vec::new();
        while got.len() < 4 {
            if let Some(v) = ch.recv() {
                got.push(v);
            }
        }
        assert_eq!(got, vec![0, 1, 2, 3]);
        producer.join().unwrap();
    }
}