use std::sync::Mutex;
use crate::types::AgentMessage;
/// Source of [`AgentMessage`]s, exposed as two separately polled streams:
/// *steering* messages and *follow-up* messages.
///
/// Implementors must be `Send + Sync`, so a provider can be shared behind an
/// `Arc<dyn MessageProvider>` (as [`ComposedMessageProvider`] does).
pub trait MessageProvider: Send + Sync {
    /// Returns the steering messages this provider currently has available.
    fn poll_steering(&self) -> Vec<AgentMessage>;

    /// Returns the follow-up messages this provider currently has available.
    fn poll_follow_up(&self) -> Vec<AgentMessage>;

    /// Reports whether steering messages are pending.
    ///
    /// The default answer is always `false`. Providers that can check for pending
    /// steering messages cheaply should override this.
    fn has_steering(&self) -> bool { false }
}
/// A [`MessageProvider`] backed by two closures: one yields steering messages and
/// the other yields follow-up messages.
///
/// Build one with [`from_fns`]. The `Fn` bounds are declared on the `impl` blocks
/// that actually need them rather than on the struct. This way the type can be
/// named without repeating them, while construction stays gated by `from_fns`,
/// because the fields are private.
pub struct FnMessageProvider<S, F> {
    /// Closure invoked for every `poll_steering` call.
    steering: S,
    /// Closure invoked for every `poll_follow_up` call.
    follow_up: F,
}
/// Every poll calls the matching closure and returns whatever it produced.
///
/// `has_steering` is not overridden, so it keeps the trait default of `false`.
/// An opaque closure cannot be checked for pending messages without calling it.
impl<S, F> MessageProvider for FnMessageProvider<S, F>
where
    S: Fn() -> Vec<AgentMessage> + Send + Sync,
    F: Fn() -> Vec<AgentMessage> + Send + Sync,
{
    fn poll_steering(&self) -> Vec<AgentMessage> {
        let produce = &self.steering;
        produce()
    }

    fn poll_follow_up(&self) -> Vec<AgentMessage> {
        let produce = &self.follow_up;
        produce()
    }
}
/// Builds a [`FnMessageProvider`] from a pair of closures.
///
/// `steering` backs `poll_steering`, and `follow_up` backs `poll_follow_up`.
pub const fn from_fns<S, F>(steering: S, follow_up: F) -> FnMessageProvider<S, F>
where
    S: Fn() -> Vec<AgentMessage> + Send + Sync,
    F: Fn() -> Vec<AgentMessage> + Send + Sync,
{
    FnMessageProvider { steering, follow_up }
}
/// Cloneable handle for pushing messages to the provider created alongside it by
/// [`message_channel`].
///
/// Steering and follow-up messages travel on two separate unbounded channels.
#[derive(Clone)]
pub struct MessageSender {
    /// Sending half of the steering channel.
    steering_tx: tokio::sync::mpsc::UnboundedSender<AgentMessage>,
    /// Sending half of the follow-up channel.
    follow_up_tx: tokio::sync::mpsc::UnboundedSender<AgentMessage>,
}
impl MessageSender {
pub fn send_steering(&self, message: AgentMessage) -> bool {
self.steering_tx.send(message).is_ok()
}
pub fn send_follow_up(&self, message: AgentMessage) -> bool {
self.follow_up_tx.send(message).is_ok()
}
pub fn send(&self, message: AgentMessage) -> bool {
self.send_follow_up(message)
}
}
/// Opaque `Debug` output. The channel handles are deliberately hidden, and only
/// the type name is printed, as `MessageSender { .. }`.
impl std::fmt::Debug for MessageSender {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("MessageSender { .. }")
    }
}
/// [`MessageProvider`] fed by the unbounded channels that [`message_channel`] creates.
///
/// Each receiver sits behind a `Mutex`. The trait's polling methods only get
/// `&self`, but pulling messages out of a tokio receiver needs mutable access.
pub struct ChannelMessageProvider {
    /// Receiving half of the steering channel.
    steering_rx: Mutex<tokio::sync::mpsc::UnboundedReceiver<AgentMessage>>,
    /// Receiving half of the follow-up channel.
    follow_up_rx: Mutex<tokio::sync::mpsc::UnboundedReceiver<AgentMessage>>,
}
impl ChannelMessageProvider {
    /// Takes every message currently buffered in `rx`, in receive order, without
    /// waiting for more.
    ///
    /// Collection stops at the first `try_recv` error, whether the channel is just
    /// empty or disconnected. A poisoned lock is recovered instead of propagated, so
    /// a panic that happened while the lock was held does not make polling unusable.
    fn drain_receiver(
        rx: &Mutex<tokio::sync::mpsc::UnboundedReceiver<AgentMessage>>,
    ) -> Vec<AgentMessage> {
        let mut receiver = rx.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
        std::iter::from_fn(|| receiver.try_recv().ok()).collect()
    }
}
impl MessageProvider for ChannelMessageProvider {
    /// Drains every steering message queued so far.
    fn poll_steering(&self) -> Vec<AgentMessage> {
        Self::drain_receiver(&self.steering_rx)
    }

    /// Drains every follow-up message queued so far.
    fn poll_follow_up(&self) -> Vec<AgentMessage> {
        Self::drain_receiver(&self.follow_up_rx)
    }

    /// Peeks at the steering channel and returns `true` when at least one message
    /// is buffered.
    ///
    /// Nothing is consumed. As with polling, a poisoned lock is recovered rather
    /// than propagated.
    fn has_steering(&self) -> bool {
        let receiver = match self.steering_rx.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        !receiver.is_empty()
    }
}
/// [`MessageProvider`] that merges two shared providers into one.
///
/// Every poll returns `primary`'s messages first, followed by `secondary`'s.
pub struct ComposedMessageProvider {
    /// Polled first; its messages lead each combined batch.
    primary: std::sync::Arc<dyn MessageProvider>,
    /// Polled second; its messages come after `primary`'s.
    secondary: std::sync::Arc<dyn MessageProvider>,
}
impl ComposedMessageProvider {
pub fn new(
primary: std::sync::Arc<dyn MessageProvider>,
secondary: std::sync::Arc<dyn MessageProvider>,
) -> Self {
Self { primary, secondary }
}
}
impl MessageProvider for ComposedMessageProvider {
    /// Steering messages from `primary`, followed by those from `secondary`.
    fn poll_steering(&self) -> Vec<AgentMessage> {
        let mut combined = self.primary.poll_steering();
        combined.append(&mut self.secondary.poll_steering());
        combined
    }

    /// Follow-up messages from `primary`, followed by those from `secondary`.
    fn poll_follow_up(&self) -> Vec<AgentMessage> {
        let mut combined = self.primary.poll_follow_up();
        combined.append(&mut self.secondary.poll_follow_up());
        combined
    }

    /// Returns `true` if either provider reports pending steering messages.
    ///
    /// `secondary` is only consulted when `primary` answers `false`.
    fn has_steering(&self) -> bool {
        [&self.primary, &self.secondary]
            .iter()
            .any(|provider| provider.has_steering())
    }
}
/// Creates a connected ([`ChannelMessageProvider`], [`MessageSender`]) pair.
///
/// Steering and follow-up messages each get their own unbounded tokio mpsc
/// channel. Whatever is sent through the sender is returned by the provider's
/// matching poll. The sender is `Clone`, so several producers can feed the same
/// provider.
pub fn message_channel() -> (ChannelMessageProvider, MessageSender) {
    use tokio::sync::mpsc::unbounded_channel;

    let (steering_tx, steering_rx) = unbounded_channel();
    let (follow_up_tx, follow_up_rx) = unbounded_channel();

    (
        ChannelMessageProvider {
            steering_rx: Mutex::new(steering_rx),
            follow_up_rx: Mutex::new(follow_up_rx),
        },
        MessageSender {
            steering_tx,
            follow_up_tx,
        },
    )
}