use crate::prelude::{ErrorKind, Message, PortId};
use crate::types::{Data, DataMessage, DeserializerFn, LinkMessage};
use crate::{bail, Result};
use flume::TryRecvError;
use std::collections::HashMap;
use std::ops::Deref;
use std::sync::Arc;
use uhlc::Timestamp;
/// The inputs of a node, indexed by port identifier.
///
/// Every port maps to all the channel receivers connected to it. The underlying map is readable
/// through [`Deref`]; a port is removed from it (and turned into a builder) with [`Inputs::take`].
pub struct Inputs {
    pub(crate) hmap: HashMap<PortId, Vec<flume::Receiver<LinkMessage>>>,
}

impl Deref for Inputs {
    type Target = HashMap<PortId, Vec<flume::Receiver<LinkMessage>>>;

    fn deref(&self) -> &Self::Target {
        &self.hmap
    }
}

impl Inputs {
    /// Creates an empty set of inputs.
    pub(crate) fn new() -> Self {
        let hmap = HashMap::new();
        Self { hmap }
    }

    /// Adds `rx` as one more channel for `port_id`, creating the port entry if it is absent.
    pub(crate) fn insert(&mut self, port_id: PortId, rx: flume::Receiver<LinkMessage>) {
        let receivers = self.hmap.entry(port_id).or_default();
        receivers.push(rx);
    }

    /// Removes the port `port_id` and returns an [`InputBuilder`] holding all of its channels.
    ///
    /// Returns `None` if no such port exists, including when it has already been taken.
    pub fn take(&mut self, port_id: impl AsRef<str>) -> Option<InputBuilder> {
        let key = port_id.as_ref();
        let receivers = self.hmap.remove(key)?;
        Some(InputBuilder {
            port_id: key.into(),
            receivers,
        })
    }
}
/// A single input port whose receiving mode has not been chosen yet.
///
/// Obtained from [`Inputs::take`]. Call [`raw`](Self::raw) to receive [`LinkMessage`]s as-is, or
/// [`typed`](Self::typed) to have data payloads deserialized into a `T`.
pub struct InputBuilder {
    pub(crate) port_id: PortId,
    pub(crate) receivers: Vec<flume::Receiver<LinkMessage>>,
}

impl InputBuilder {
    /// Consumes the builder and returns an untyped [`InputRaw`] over the same channels.
    pub fn raw(self) -> InputRaw {
        let Self { port_id, receivers } = self;
        InputRaw { port_id, receivers }
    }

    /// Consumes the builder and returns an [`Input<T>`].
    ///
    /// `deserializer` is stored behind an [`Arc`] and used to turn received payload bytes into `T`.
    pub fn typed<T>(
        self,
        deserializer: impl Fn(&[u8]) -> anyhow::Result<T> + Send + Sync + 'static,
    ) -> Input<T> {
        let input_raw = self.raw();
        Input {
            input_raw,
            deserializer: Arc::new(deserializer),
        }
    }
}
/// An untyped receiving end for a single input port.
///
/// Aggregates every channel connected to `port_id`; a message may come from any of them. Messages
/// are handed out as [`LinkMessage`]s, without touching their payload.
#[derive(Clone, Debug)]
pub struct InputRaw {
    pub(crate) port_id: PortId,
    pub(crate) receivers: Vec<flume::Receiver<LinkMessage>>,
}

impl InputRaw {
    /// Returns the identifier of the port this input is attached to.
    pub fn port_id(&self) -> &PortId {
        &self.port_id
    }

    /// Returns the number of channels feeding this input.
    pub fn channels_count(&self) -> usize {
        self.receivers.len()
    }

    /// Attempts to receive a message without waiting.
    ///
    /// Channels are checked in order and the first available message is returned. A disconnected
    /// channel is logged and skipped.
    ///
    /// # Errors
    ///
    /// Returns an error of kind `ErrorKind::Empty` when no channel currently holds a message.
    pub fn try_recv(&self) -> Result<LinkMessage> {
        for receiver in &self.receivers {
            match receiver.try_recv() {
                Ok(message) => return Ok(message),
                Err(TryRecvError::Disconnected) => {
                    log::error!("[Input: {}] A channel is disconnected", self.port_id);
                }
                Err(_) => {}
            }
        }
        bail!(ErrorKind::Empty, "[Input: {}] No message", self.port_id)
    }

    /// Waits until any of the channels yields a message and returns it.
    ///
    /// A channel that disconnects while waiting is logged and removed from the set of channels
    /// being waited on.
    ///
    /// # Errors
    ///
    /// Returns an error of kind `ErrorKind::Disconnected` when every channel has disconnected, or
    /// when this input has no channel at all.
    pub async fn recv(&self) -> Result<LinkMessage> {
        // `select_all` panics on an empty iterator. An input without any channel can never
        // receive anything, so report it as disconnected instead of panicking.
        if self.receivers.is_empty() {
            bail!(
                ErrorKind::Disconnected,
                "[Input: {}] No channel to receive from",
                self.port_id
            );
        }

        let mut recv_futures = self
            .receivers
            .iter()
            .map(|link| link.recv_async())
            .collect::<Vec<_>>();

        loop {
            let (res, _, remaining) = futures::future::select_all(recv_futures).await;
            match res {
                Ok(message) => return Ok(message),
                Err(_disconnected) => {
                    log::error!("[Input: {}] A channel is disconnected", self.port_id);
                    // Re-entering `select_all` with an empty Vec would panic: stop here.
                    if remaining.is_empty() {
                        bail!(
                            ErrorKind::Disconnected,
                            "[Input: {}] All channels are disconnected",
                            self.port_id
                        );
                    }
                    recv_futures = remaining;
                }
            }
        }
    }
}
pub struct Input<T> {
pub(crate) input_raw: InputRaw,
pub(crate) deserializer: Arc<DeserializerFn<T>>,
}
impl<T: Send + Sync + 'static> Deref for Input<T> {
type Target = InputRaw;
fn deref(&self) -> &Self::Target {
&self.input_raw
}
}
impl<T: Send + Sync + 'static> Input<T> {
pub async fn recv(&self) -> Result<(Message<T>, Timestamp)> {
match self.input_raw.recv().await? {
LinkMessage::Data(DataMessage { data, timestamp }) => Ok((
Message::Data(Data::try_from_payload(data, self.deserializer.clone())?),
timestamp,
)),
LinkMessage::Watermark(timestamp) => Ok((Message::Watermark, timestamp)),
}
}
pub fn try_recv(&self) -> Result<(Message<T>, Timestamp)> {
match self.input_raw.try_recv()? {
LinkMessage::Data(DataMessage { data, timestamp }) => Ok((
Message::Data(Data::try_from_payload(data, self.deserializer.clone())?),
timestamp,
)),
LinkMessage::Watermark(ts) => Ok((Message::Watermark, ts)),
}
}
}
#[cfg(test)]
#[path = "./tests/input-tests.rs"]
mod tests;