use std::collections::BinaryHeap;
use tokio::sync::mpsc::{Receiver, Sender, channel};
/// Records which receivers `Linearizer::get` has to read from on its next call.
enum PollState {
/// Read from no receiver; set once `get` found nothing left to return.
NoPoll,
/// Read only from the receiver at this index, i.e. the one whose value `get` returned last.
Poll(usize),
/// Read from every receiver once; this is the state a new linearizer starts in.
PollAll,
}
/// A value waiting in the merge heap, tagged with the index of the receiver it came from.
///
/// Equality and ordering are defined by `value` alone; `sender_id` never takes part in a
/// comparison.
struct LinearedItem<T> {
    /// Payload that decides the item's position in the heap.
    value: T,
    /// Index of the receiver that produced `value`.
    sender_id: usize,
}

impl<T: Ord> Ord for LinearedItem<T> {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        Ord::cmp(&self.value, &other.value)
    }
}

impl<T: Ord> PartialOrd for LinearedItem<T> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        // Delegating to the total order keeps both comparisons in lockstep.
        Some(Ord::cmp(self, other))
    }
}

impl<T: Ord> PartialEq for LinearedItem<T> {
    fn eq(&self, other: &Self) -> bool {
        self.value == other.value
    }
}

impl<T: Ord> Eq for LinearedItem<T> {}
/// Merges the values arriving on one or more channels into a single stream, handing out the
/// greatest buffered value (by `T`'s `Ord`) on each `get` call.
pub struct Linearizer<T> {
/// Receiving ends of the channels; a value's sender id is its receiver's index in this vector.
receivers: Vec<Receiver<T>>,
/// Which receivers the next `get` call reads from before it pops from `heap`.
poll_state: PollState,
/// Values received but not yet returned, each tagged with the receiver it came from.
heap: BinaryHeap<LinearedItem<T>>,
}
impl<T: Ord> Linearizer<T> {
    /// Creates a linearizer fed by `num_inserters` separate bounded channels.
    ///
    /// Each inserter gets its own channel of capacity `buffer_size`. The inserter at position
    /// `i` of the returned vector feeds the receiver with index `i`.
    ///
    /// # Panics
    ///
    /// Panics if `buffer_size` is 0 while `num_inserters` is greater than 0, because tokio's
    /// bounded `channel` rejects a zero capacity.
    pub fn new(num_inserters: usize, buffer_size: usize) -> (Self, Vec<Inserter<T>>) {
        let mut receivers = Vec::with_capacity(num_inserters);
        let mut inserters = Vec::with_capacity(num_inserters);

        for _ in 0..num_inserters {
            let (sender, receiver) = channel(buffer_size);
            receivers.push(receiver);
            inserters.push(Inserter { sender });
        }

        let slf = Self {
            receivers,
            poll_state: PollState::PollAll,
            // `get` keeps at most one buffered value per receiver in the heap.
            heap: BinaryHeap::with_capacity(num_inserters),
        };
        (slf, inserters)
    }

    /// Creates a linearizer that merges by order only when `maintain_order` is `true`.
    ///
    /// With `maintain_order == true` this is exactly [`Linearizer::new`]. Otherwise all
    /// inserters are clones of one sender on a single shared channel of capacity
    /// `buffer_size * num_inserters`, so `get` hands values out in the order they arrive on
    /// that channel.
    ///
    /// Passing `num_inserters == 0` is allowed in both modes and yields a linearizer whose
    /// `get` returns `None`.
    ///
    /// # Panics
    ///
    /// Panics if `buffer_size` is 0 while `num_inserters` is greater than 0, because tokio's
    /// bounded `channel` rejects a zero capacity.
    pub fn new_with_maintain_order(
        num_inserters: usize,
        buffer_size: usize,
        maintain_order: bool,
    ) -> (Self, Vec<Inserter<T>>) {
        if maintain_order {
            return Self::new(num_inserters, buffer_size);
        }

        // With no inserters the product would be 0, and tokio's bounded `channel` panics on a
        // zero capacity even though the ordered mode accepts zero inserters. Use the smallest
        // legal capacity instead; nothing can be sent on that channel because no sender clone
        // is handed out. Saturate otherwise so an overflowing product cannot silently wrap
        // around to a tiny buffer in release builds.
        let capacity = if num_inserters == 0 {
            1
        } else {
            buffer_size.saturating_mul(num_inserters)
        };
        let (sender, receiver) = channel(capacity);

        let receivers = vec![receiver];
        // Only clones are handed out; the original `sender` is dropped when this function
        // returns, so the channel closes as soon as every inserter is gone.
        let inserters = (0..num_inserters)
            .map(|_| Inserter {
                sender: sender.clone(),
            })
            .collect();

        let slf = Self {
            receivers,
            poll_state: PollState::PollAll,
            heap: BinaryHeap::with_capacity(1),
        };
        (slf, inserters)
    }

    /// Returns the next value of the merged stream, or `None` once nothing is left.
    ///
    /// Before a value is popped, the receivers selected by the current poll state are read:
    /// all of them on the first call, afterwards only the one whose value was returned by the
    /// previous call. The receivers are awaited one after another, so the first call does not
    /// return until every channel has either produced a value or been closed.
    ///
    /// The value returned is the greatest one currently buffered under `T`'s `Ord`, because
    /// `BinaryHeap` is a max-heap. For the overall output to be ordered, each inserter
    /// presumably has to send its own values already ordered (greatest first); this is not
    /// checked here.
    pub async fn get(&mut self) -> Option<T> {
        let poll_range = match self.poll_state {
            PollState::NoPoll => 0..0,
            PollState::Poll(i) => i..i + 1,
            PollState::PollAll => 0..self.receivers.len(),
        };

        for sender_id in poll_range {
            // `None` means this channel is closed and drained, so it simply stops contributing
            // to the heap.
            if let Some(value) = self.receivers[sender_id].recv().await {
                self.heap.push(LinearedItem { value, sender_id });
            }
        }

        match self.heap.pop() {
            Some(LinearedItem { value, sender_id }) => {
                // The popped value's channel is the only one without a buffered value now, so
                // it is the only one the next call has to read.
                self.poll_state = PollState::Poll(sender_id);
                Some(value)
            },
            None => {
                // Nothing buffered and nothing received: stop reading on later calls.
                self.poll_state = PollState::NoPoll;
                None
            },
        }
    }
}
/// Sending handle through which a producer feeds values to a `Linearizer`.
pub struct Inserter<T> {
/// Sending end of a channel whose receiving end is held by the linearizer.
sender: Sender<T>,
}
impl<T: Ord> Inserter<T> {
    /// Sends `value` over this inserter's channel and awaits the outcome of the send.
    ///
    /// # Errors
    ///
    /// If the channel send fails, the rejected value is handed back as `Err(value)`. For a
    /// tokio mpsc channel that happens when the receiving side has been closed or dropped.
    pub async fn insert(&mut self, value: T) -> Result<(), T> {
        match self.sender.send(value).await {
            Ok(()) => Ok(()),
            // The send error wraps the value that could not be delivered; unwrap and return it.
            Err(rejected) => Err(rejected.0),
        }
    }
}