use std::collections::HashSet;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use async_trait::async_trait;
use crate::error::Result;
use crate::frames::{Frame, FrameDirection, FrameHandler, FrameKind, FrameProcessor};
use super::bus::{AgentBus, BusMessage, BusPayload};
fn default_excluded(kind: FrameKind) -> bool {
matches!(
kind,
FrameKind::Start
| FrameKind::End
| FrameKind::Cancel
| FrameKind::Stop
| FrameKind::Error
| FrameKind::Heartbeat
| FrameKind::EndTask
| FrameKind::CancelTask
| FrameKind::StopTask
| FrameKind::InterruptionTask
| FrameKind::PauseProcessor
| FrameKind::PauseProcessorUrgent
| FrameKind::ResumeProcessor
| FrameKind::ResumeProcessorUrgent
)
}
struct EdgeInner {
source_name: String,
targets: Vec<String>,
exclude: HashSet<FrameKind>,
active: tokio::sync::OnceCell<Arc<AtomicBool>>,
bus: tokio::sync::OnceCell<Arc<dyn AgentBus>>,
}
impl EdgeInner {
fn gate_open(&self) -> bool {
self.active
.get()
.map(|a| a.load(Ordering::Relaxed))
.unwrap_or(true)
}
fn should_publish(&self, frame: &Frame) -> bool {
let kind = frame.kind();
!default_excluded(kind) && !self.exclude.contains(&kind)
}
}
#[derive(Clone)]
pub struct BusOutputEdge {
inner: Arc<EdgeInner>,
}
impl BusOutputEdge {
pub fn new(source_name: impl Into<String>, targets: Vec<String>) -> Self {
Self {
inner: Arc::new(EdgeInner {
source_name: source_name.into(),
targets,
exclude: HashSet::new(),
active: tokio::sync::OnceCell::new(),
bus: tokio::sync::OnceCell::new(),
}),
}
}
pub fn with_exclude(
source_name: impl Into<String>,
targets: Vec<String>,
exclude: HashSet<FrameKind>,
) -> Self {
Self {
inner: Arc::new(EdgeInner {
source_name: source_name.into(),
targets,
exclude,
active: tokio::sync::OnceCell::new(),
bus: tokio::sync::OnceCell::new(),
}),
}
}
pub fn set_bus(&self, bus: Arc<dyn AgentBus>) {
let _ = self.inner.bus.set(bus);
}
pub fn bind_activation(&self, active: Arc<AtomicBool>) {
let _ = self.inner.active.set(active);
}
pub fn to_processor(&self) -> FrameProcessor {
FrameProcessor::new(
format!("BusOutputEdge({})", self.inner.source_name),
Box::new(BusOutputEdgeHandler {
inner: self.inner.clone(),
}),
true, )
}
}
struct BusOutputEdgeHandler {
inner: Arc<EdgeInner>,
}
#[async_trait]
impl FrameHandler for BusOutputEdgeHandler {
async fn on_process_frame(
&self,
processor: &FrameProcessor,
frame: Frame,
direction: FrameDirection,
) -> Result<()> {
if self.inner.should_publish(&frame) && self.inner.gate_open() {
if let Some(bus) = self.inner.bus.get() {
if self.inner.targets.is_empty() {
bus.send(BusMessage::new(
self.inner.source_name.clone(),
None,
BusPayload::Frame {
frame: frame.clone(),
direction,
},
))
.await;
} else {
for target in &self.inner.targets {
bus.send(BusMessage::new(
self.inner.source_name.clone(),
Some(target.clone()),
BusPayload::Frame {
frame: frame.clone(),
direction,
},
))
.await;
}
}
}
}
processor.push_frame(frame, direction).await
}
}