use crate::eval::Value;
use crate::diagnostics::{Error, Result};
use super::{ConcurrencyError, futures::Future};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{mpsc, broadcast, watch};
use tokio::time::timeout;
use futures::FutureExt;
/// A channel handle bundling a sender with a receiver that sits behind an
/// async (`tokio::sync::Mutex`) lock.
///
/// Cloning a `Channel` clones the sender and the `Arc` around the receiver, so
/// every clone reads from the same underlying receiver.
#[derive(Debug, Clone)]
pub struct Channel {
/// Sending half; handed out by value via `Channel::sender`.
sender: ChannelSender,
/// Receiving half, shared between clones and handed out via `Channel::receiver`.
receiver: Arc<tokio::sync::Mutex<ChannelReceiver>>,
}
/// Sending half of a [`Channel`]. Clones wrap clones of the same tokio sender,
/// so they all feed the same channel.
#[derive(Debug, Clone)]
pub struct ChannelSender {
/// The concrete tokio sender for the channel's flavour.
inner: SenderInner,
}
/// Receiving half of a [`Channel`]. Not `Clone`; broadcast and watch channels
/// can produce additional receivers via `Channel::subscribe`.
#[derive(Debug)]
pub struct ChannelReceiver {
/// The concrete tokio receiver for the channel's flavour.
inner: ReceiverInner,
}
/// Tokio sender for each supported channel flavour; the variant always matches
/// the variant of the paired `ReceiverInner`.
#[derive(Debug, Clone)]
enum SenderInner {
/// `tokio::sync::mpsc` bounded channel; `send` waits for buffer space.
Bounded(mpsc::Sender<Value>),
/// `tokio::sync::mpsc` unbounded channel; `send` never waits.
Unbounded(mpsc::UnboundedSender<Value>),
/// `tokio::sync::broadcast` channel; each subscriber sees every value.
Broadcast(broadcast::Sender<Value>),
/// `tokio::sync::watch` channel; receivers only observe the latest value.
Watch(watch::Sender<Value>),
}
/// Tokio receiver for each supported channel flavour (see `SenderInner`).
#[derive(Debug)]
enum ReceiverInner {
/// Receiver of a bounded `mpsc` channel.
Bounded(mpsc::Receiver<Value>),
/// Receiver of an unbounded `mpsc` channel.
Unbounded(mpsc::UnboundedReceiver<Value>),
/// One subscriber of a `broadcast` channel.
Broadcast(broadcast::Receiver<Value>),
/// One observer of a `watch` channel.
Watch(watch::Receiver<Value>),
}
/// Parameters accepted by `Channel::new`.
#[derive(Debug, Clone)]
pub struct ChannelConfig {
/// Buffer capacity. Required for `MpscBounded`; used as the broadcast capacity
/// (falling back to 1000 when `None`); ignored for `MpscUnbounded` and `Watch`.
pub buffer_size: Option<usize>,
/// Which tokio channel flavour to build.
pub channel_type: ChannelType,
/// Backpressure flag. NOTE(review): `Channel::new` does not read this field;
/// the flavour alone decides whether senders can be made to wait.
pub backpressure: bool,
}
/// Channel flavours supported by `Channel::new`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChannelType {
/// Multi-producer, single-consumer queue with a fixed buffer (`mpsc::channel`).
MpscBounded,
/// Multi-producer, single-consumer queue with no buffer limit (`mpsc::unbounded_channel`).
MpscUnbounded,
/// Every subscriber receives every value (`broadcast::channel`).
Broadcast,
/// Single latest value, initialised to `Value::Unspecified` (`watch::channel`).
Watch,
}
impl Default for ChannelConfig {
    /// A bounded MPSC channel holding up to 100 values, with backpressure enabled.
    fn default() -> Self {
        let channel_type = ChannelType::MpscBounded;
        let buffer_size = Some(100);
        Self {
            channel_type,
            buffer_size,
            backpressure: true,
        }
    }
}
impl Channel {
pub fn new(config: ChannelConfig) -> Result<Self> {
match config.channel_type {
ChannelType::MpscBounded => {
let buffer_size = config.buffer_size
.ok_or_else(|| Error::runtime_error("Buffer size required for bounded channel".to_string(), None))?;
let (tx, rx) = mpsc::channel(buffer_size);
Ok(Self {
sender: ChannelSender { inner: SenderInner::Bounded(tx) },
receiver: Arc::new(tokio::sync::Mutex::new(ChannelReceiver { inner: ReceiverInner::Bounded(rx) })),
})
}
ChannelType::MpscUnbounded => {
let (tx, rx) = mpsc::unbounded_channel();
Ok(Self {
sender: ChannelSender { inner: SenderInner::Unbounded(tx) },
receiver: Arc::new(tokio::sync::Mutex::new(ChannelReceiver { inner: ReceiverInner::Unbounded(rx) })),
})
}
ChannelType::Broadcast => {
let capacity = config.buffer_size.unwrap_or(1000);
let (tx, rx) = broadcast::channel(capacity);
Ok(Self {
sender: ChannelSender { inner: SenderInner::Broadcast(tx) },
receiver: Arc::new(tokio::sync::Mutex::new(ChannelReceiver { inner: ReceiverInner::Broadcast(rx) })),
})
}
ChannelType::Watch => {
let (tx, rx) = watch::channel(Value::Unspecified);
Ok(Self {
sender: ChannelSender { inner: SenderInner::Watch(tx) },
receiver: Arc::new(tokio::sync::Mutex::new(ChannelReceiver { inner: ReceiverInner::Watch(rx) })),
})
}
}
}
pub fn bounded(buffer_size: usize) -> Result<Self> {
Self::new(ChannelConfig {
buffer_size: Some(buffer_size),
channel_type: ChannelType::MpscBounded,
backpressure: true,
})
}
pub fn unbounded() -> Result<Self> {
Self::new(ChannelConfig {
buffer_size: None,
channel_type: ChannelType::MpscUnbounded,
backpressure: false,
})
}
pub fn broadcast(capacity: usize) -> Result<Self> {
Self::new(ChannelConfig {
buffer_size: Some(capacity),
channel_type: ChannelType::Broadcast,
backpressure: false,
})
}
pub fn watch() -> Result<Self> {
Self::new(ChannelConfig {
buffer_size: None,
channel_type: ChannelType::Watch,
backpressure: false,
})
}
pub fn sender(&self) -> ChannelSender {
self.sender.clone()
}
pub fn receiver(&self) -> Arc<tokio::sync::Mutex<ChannelReceiver>> {
self.receiver.clone()
}
pub fn senders(&self, count: usize) -> Vec<ChannelSender> {
(0..count).map(|_| self.sender.clone()).collect()
}
pub fn subscribe(&self) -> Result<ChannelReceiver> {
match &self.sender.inner {
SenderInner::Broadcast(tx) => {
Ok(ChannelReceiver { inner: ReceiverInner::Broadcast(tx.subscribe()) })
}
SenderInner::Watch(tx) => {
Ok(ChannelReceiver { inner: ReceiverInner::Watch(tx.subscribe()) })
}
_ => Err(Box::new(Error::runtime_error("Channel type does not support subscriptions".to_string(), None))),
}
}
}
impl ChannelSender {
pub async fn send(&self, value: Value) -> Result<()> {
match &self.inner {
SenderInner::Bounded(tx) => {
tx.send(value).await
.map_err(|_| ConcurrencyError::ChannelClosed.boxed())
}
SenderInner::Unbounded(tx) => {
tx.send(value)
.map_err(|_| ConcurrencyError::ChannelClosed.boxed())
}
SenderInner::Broadcast(tx) => {
tx.send(value)
.map(|_| ())
.map_err(|_| ConcurrencyError::ChannelClosed.boxed())
}
SenderInner::Watch(tx) => {
tx.send(value)
.map_err(|_| ConcurrencyError::ChannelClosed.boxed())
}
}
}
pub fn try_send(&self, value: Value) -> Result<()> {
match &self.inner {
SenderInner::Bounded(tx) => {
tx.try_send(value)
.map_err(|e| match e {
mpsc::error::TrySendError::Closed(_) => ConcurrencyError::ChannelClosed.boxed(),
mpsc::error::TrySendError::Full(_) => Error::runtime_error("Channel full".to_string(), None).into(),
})
}
SenderInner::Unbounded(tx) => {
tx.send(value)
.map_err(|_| ConcurrencyError::ChannelClosed.boxed())
}
SenderInner::Broadcast(tx) => {
tx.send(value)
.map(|_| ())
.map_err(|_| ConcurrencyError::ChannelClosed.boxed())
}
SenderInner::Watch(tx) => {
tx.send(value)
.map_err(|_| ConcurrencyError::ChannelClosed.boxed())
}
}
}
pub async fn send_timeout(&self, value: Value, duration: Duration) -> Result<()> {
match timeout(duration, self.send(value)).await {
Ok(result) => result,
Err(_) => Err(ConcurrencyError::Timeout.into()),
}
}
pub fn is_closed(&self) -> bool {
match &self.inner {
SenderInner::Bounded(tx) => tx.is_closed(),
SenderInner::Unbounded(tx) => tx.is_closed(),
SenderInner::Broadcast(tx) => tx.receiver_count() == 0,
SenderInner::Watch(tx) => tx.receiver_count() == 0,
}
}
pub fn receiver_count(&self) -> usize {
match &self.inner {
SenderInner::Bounded(_) | SenderInner::Unbounded(_) => 1,
SenderInner::Broadcast(tx) => tx.receiver_count(),
SenderInner::Watch(tx) => tx.receiver_count(),
}
}
}
impl ChannelReceiver {
pub async fn recv(&mut self) -> Result<Value> {
match &mut self.inner {
ReceiverInner::Bounded(rx) => {
rx.recv().await
.ok_or_else(|| ConcurrencyError::ChannelClosed.boxed())
}
ReceiverInner::Unbounded(rx) => {
rx.recv().await
.ok_or_else(|| ConcurrencyError::ChannelClosed.boxed())
}
ReceiverInner::Broadcast(rx) => {
rx.recv().await
.map_err(|e| match e {
broadcast::error::RecvError::Closed => ConcurrencyError::ChannelClosed.into(),
broadcast::error::RecvError::Lagged(n) =>
Error::runtime_error(format!("Lagged behind by {n} messages"), None).into(),
})
}
ReceiverInner::Watch(rx) => {
rx.changed().await
.map_err(|_| ConcurrencyError::ChannelClosed.boxed())?;
Ok(rx.borrow().clone())
}
}
}
pub fn try_recv(&mut self) -> Result<Value> {
match &mut self.inner {
ReceiverInner::Bounded(rx) => {
rx.try_recv()
.map_err(|e| match e {
mpsc::error::TryRecvError::Empty => Error::runtime_error("Channel empty".to_string(), None).into(),
mpsc::error::TryRecvError::Disconnected => ConcurrencyError::ChannelClosed.boxed(),
})
}
ReceiverInner::Unbounded(rx) => {
rx.try_recv()
.map_err(|e| match e {
mpsc::error::TryRecvError::Empty => Error::runtime_error("Channel empty".to_string(), None).into(),
mpsc::error::TryRecvError::Disconnected => ConcurrencyError::ChannelClosed.boxed(),
})
}
ReceiverInner::Broadcast(rx) => {
rx.try_recv()
.map_err(|e| match e {
broadcast::error::TryRecvError::Empty => Error::runtime_error("Channel empty".to_string(), None).into(),
broadcast::error::TryRecvError::Closed => ConcurrencyError::ChannelClosed.boxed(),
broadcast::error::TryRecvError::Lagged(n) =>
Error::runtime_error(format!("Lagged behind by {n} messages"), None).into(),
})
}
ReceiverInner::Watch(rx) => {
match rx.has_changed() {
Ok(true) => Ok(rx.borrow_and_update().clone()),
Ok(false) => Err(Box::new(Error::runtime_error("No new value available".to_string(), None))),
Err(e) => Err(Box::new(Error::runtime_error(format!("Watch receiver error: {e}"), None))),
}
}
}
}
pub async fn recv_timeout(&mut self, duration: Duration) -> Result<Value> {
match timeout(duration, self.recv()).await {
Ok(result) => result,
Err(_) => Err(ConcurrencyError::Timeout.into()),
}
}
}
/// Races several channel operations (and timeouts); `execute` yields the result
/// of whichever branch finishes first.
pub struct Select {
/// Branches in the order they were added.
futures: Vec<SelectBranch>,
}
/// One racing operation inside a [`Select`].
struct SelectBranch {
/// The operation; it resolves to a `(tag id payload)` list built by the `Select` builders.
future: Future,
/// Caller-chosen identifier. NOTE(review): only stored here; the id that reaches
/// the caller is the one embedded in the future's result.
id: usize,
}
impl Select {
pub fn new() -> Self {
Self {
futures: Vec::new(),
}
}
pub fn recv(mut self, id: usize, receiver: Arc<tokio::sync::Mutex<ChannelReceiver>>) -> Self {
let future = Future::new(async move {
let mut rx = receiver.lock().await;
let value = rx.recv().await?;
Ok(Value::from_vec(vec![
Value::symbol_from_str("recv"),
Value::integer(id as i64),
value,
]))
});
self.futures.push(SelectBranch { future, id });
self
}
pub fn send(mut self, id: usize, sender: ChannelSender, value: Value) -> Self {
let future = Future::new(async move {
sender.send(value).await?;
Ok(Value::from_vec(vec![
Value::symbol_from_str("send"),
Value::integer(id as i64),
Value::Unspecified,
]))
});
self.futures.push(SelectBranch { future, id });
self
}
pub fn timeout(mut self, id: usize, duration: Duration) -> Self {
let future = Future::new(async move {
tokio::time::sleep(duration).await;
Ok(Value::from_vec(vec![
Value::symbol_from_str("timeout"),
Value::integer(id as i64),
Value::Unspecified,
]))
});
self.futures.push(SelectBranch { future, id });
self
}
pub async fn execute(self) -> Result<Value> {
if self.futures.is_empty() {
return Err(Box::new(Error::runtime_error("No operations in select".to_string(), None)))
}
let futures: Vec<_> = self.futures.into_iter()
.map(|branch| async move { branch.future.await_result().await }.boxed())
.collect();
futures::future::select_all(futures).await.0
}
}
impl Default for Select {
fn default() -> Self {
Self::new()
}
}
/// Namespace for higher-level channel helpers (spec-based construction, pipelines).
pub struct ChannelOps;
impl ChannelOps {
    /// Builds a channel from a language-level spec.
    ///
    /// NOTE(review): `spec` is currently ignored. This always returns a channel
    /// built from `ChannelConfig::default()`.
    pub fn from_spec(_spec: Value) -> Result<Channel> {
        Channel::new(ChannelConfig::default())
    }

    /// Chains `stages` into a pipeline of unbounded channels.
    ///
    /// One Tokio task is spawned per stage. Each task reads from the previous
    /// channel, applies its stage, and forwards the result to the next channel.
    /// The returned sender feeds the first stage. The returned receiver yields
    /// values that have passed through *every* stage, in order.
    ///
    /// A stage error, or a closed downstream channel, stops that stage's task.
    /// Dropping its sender then closes the channels after it, so the final
    /// receiver eventually reports `ChannelClosed`. The stage error value itself
    /// is discarded.
    ///
    /// Must be called from within a Tokio runtime, because `tokio::spawn`
    /// panics otherwise.
    ///
    /// # Errors
    /// Fails with "Empty pipeline" when `stages` is empty.
    pub fn pipeline(stages: Vec<Box<dyn Fn(Value) -> Result<Value> + Send + Sync>>) -> Result<(ChannelSender, Arc<tokio::sync::Mutex<ChannelReceiver>>)> {
        if stages.is_empty() {
            return Err(Box::new(Error::runtime_error("Empty pipeline".to_string(), None)));
        }
        let first_channel = Channel::unbounded()?;
        let input = first_channel.sender();
        let mut current_receiver = first_channel.receiver();
        // Every stage, including the last, gets its own worker. The previous
        // version stopped before the final stage, so that stage never ran.
        for stage in stages {
            let next_channel = Channel::unbounded()?;
            let next_sender = next_channel.sender();
            let next_receiver = next_channel.receiver();
            let upstream = current_receiver;
            tokio::spawn(async move {
                // Only this task reads the upstream channel, so take the lock once.
                let mut rx = upstream.lock().await;
                while let Ok(value) = rx.recv().await {
                    let transformed = match stage(value) {
                        Ok(transformed) => transformed,
                        Err(_) => break,
                    };
                    if next_sender.send(transformed).await.is_err() {
                        break;
                    }
                }
            });
            current_receiver = next_receiver;
        }
        Ok((input, current_receiver))
    }
}