use std::time::Duration;
use tokio::sync::broadcast;
use crate::channels::{ComponentEvent, ComponentStatus};
/// Waits for `component_id` to report `expected_status` on `event_rx`.
///
/// Thin wrapper around [`wait_for_event`]. It skips events for other components, and
/// events for this component that carry a different status. It returns the first
/// event that matches both the id and the status.
///
/// # Panics
///
/// Panics (via [`wait_for_event`]) if `timeout` elapses, or if the channel closes,
/// before a matching event is received.
pub async fn wait_for_component_status(
    event_rx: &mut broadcast::Receiver<ComponentEvent>,
    component_id: &str,
    expected_status: ComponentStatus,
    timeout: Duration,
) -> ComponentEvent {
    let matches_target = |event: &ComponentEvent| {
        event.component_id == component_id && event.status == expected_status
    };
    let description = format!("component '{component_id}' to reach {expected_status:?}");
    wait_for_event(event_rx, matches_target, timeout, &description).await
}
pub async fn wait_for_event(
event_rx: &mut broadcast::Receiver<ComponentEvent>,
predicate: impl Fn(&ComponentEvent) -> bool,
timeout: Duration,
description: &str,
) -> ComponentEvent {
match tokio::time::timeout(timeout, async {
loop {
match event_rx.recv().await {
Ok(event) if predicate(&event) => return event,
Ok(_) => continue,
Err(broadcast::error::RecvError::Lagged(n)) => {
eprintln!("Warning: event receiver lagged by {n} events");
continue;
}
Err(broadcast::error::RecvError::Closed) => {
panic!("Event channel closed while waiting for {description}");
}
}
}
})
.await
{
Ok(event) => event,
Err(_) => panic!("Timed out ({timeout:?}) waiting for {description}"),
}
}
pub async fn wait_for_all_statuses(
event_rx: &mut broadcast::Receiver<ComponentEvent>,
targets: &[(&str, ComponentStatus)],
timeout: Duration,
) {
let mut remaining: Vec<(String, ComponentStatus)> =
targets.iter().map(|(id, s)| (id.to_string(), *s)).collect();
match tokio::time::timeout(timeout, async {
while !remaining.is_empty() {
match event_rx.recv().await {
Ok(event) => {
remaining.retain(|(id, status)| {
!(event.component_id == *id && event.status == *status)
});
}
Err(broadcast::error::RecvError::Lagged(n)) => {
eprintln!("Warning: event receiver lagged by {n} events");
continue;
}
Err(broadcast::error::RecvError::Closed) => {
panic!(
"Event channel closed while waiting for statuses. Remaining: {remaining:?}",
);
}
}
}
})
.await
{
Ok(()) => {}
Err(_) => panic!(
"Timed out ({timeout:?}) waiting for component statuses. Remaining: {remaining:?}",
),
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::channels::ComponentType;

    /// Builds a `Source` event for `id` with the given `status` and no message.
    fn make_event(id: &str, status: ComponentStatus) -> ComponentEvent {
        ComponentEvent {
            component_id: id.to_string(),
            component_type: ComponentType::Source,
            status,
            timestamp: chrono::Utc::now(),
            message: None,
        }
    }

    #[tokio::test]
    async fn test_wait_for_component_status_immediate() {
        let (tx, mut rx) = broadcast::channel(16);
        tx.send(make_event("s1", ComponentStatus::Running)).unwrap();
        let event = wait_for_component_status(
            &mut rx,
            "s1",
            ComponentStatus::Running,
            Duration::from_secs(1),
        )
        .await;
        assert_eq!(event.component_id, "s1");
        assert_eq!(event.status, ComponentStatus::Running);
    }

    #[tokio::test]
    async fn test_wait_for_component_status_skips_non_matching() {
        let (tx, mut rx) = broadcast::channel(16);
        tx.send(make_event("s1", ComponentStatus::Starting))
            .unwrap();
        tx.send(make_event("s2", ComponentStatus::Running)).unwrap();
        tx.send(make_event("s1", ComponentStatus::Running)).unwrap();
        let event = wait_for_component_status(
            &mut rx,
            "s1",
            ComponentStatus::Running,
            Duration::from_secs(1),
        )
        .await;
        assert_eq!(event.component_id, "s1");
        assert_eq!(event.status, ComponentStatus::Running);
    }

    #[tokio::test]
    async fn test_wait_for_component_status_survives_lag() {
        // Capacity 2 with 5 sends makes the first recv return `Lagged`; the helper must
        // warn, continue, and still find the last (retained) matching event.
        let (tx, mut rx) = broadcast::channel(2);
        for _ in 0..4 {
            tx.send(make_event("s2", ComponentStatus::Starting)).unwrap();
        }
        tx.send(make_event("s1", ComponentStatus::Running)).unwrap();
        let event = wait_for_component_status(
            &mut rx,
            "s1",
            ComponentStatus::Running,
            Duration::from_secs(1),
        )
        .await;
        assert_eq!(event.component_id, "s1");
        assert_eq!(event.status, ComponentStatus::Running);
    }

    #[tokio::test]
    #[should_panic(expected = "Event channel closed")]
    async fn test_wait_for_component_status_channel_closed() {
        let (tx, mut rx) = broadcast::channel::<ComponentEvent>(16);
        // Dropping the only sender closes the channel, so `recv` yields `Closed`.
        drop(tx);
        wait_for_component_status(
            &mut rx,
            "s1",
            ComponentStatus::Running,
            Duration::from_secs(1),
        )
        .await;
    }

    #[tokio::test]
    #[should_panic(expected = "Timed out")]
    async fn test_wait_for_component_status_timeout() {
        // Keep the sender alive so the channel stays open and only the timeout can fire.
        let (_tx, mut rx) = broadcast::channel::<ComponentEvent>(16);
        wait_for_component_status(
            &mut rx,
            "s1",
            ComponentStatus::Running,
            Duration::from_millis(50),
        )
        .await;
    }

    #[tokio::test]
    async fn test_wait_for_all_statuses() {
        let (tx, mut rx) = broadcast::channel(16);
        tx.send(make_event("s1", ComponentStatus::Running)).unwrap();
        tx.send(make_event("s2", ComponentStatus::Running)).unwrap();
        wait_for_all_statuses(
            &mut rx,
            &[("s1", ComponentStatus::Running), ("s2", ComponentStatus::Running)],
            Duration::from_secs(1),
        )
        .await;
    }

    #[tokio::test]
    async fn test_wait_for_all_statuses_ignores_irrelevant_and_out_of_order() {
        let (tx, mut rx) = broadcast::channel(16);
        tx.send(make_event("s2", ComponentStatus::Starting)).unwrap();
        tx.send(make_event("s2", ComponentStatus::Running)).unwrap();
        tx.send(make_event("s3", ComponentStatus::Running)).unwrap();
        tx.send(make_event("s1", ComponentStatus::Running)).unwrap();
        wait_for_all_statuses(
            &mut rx,
            &[("s1", ComponentStatus::Running), ("s2", ComponentStatus::Running)],
            Duration::from_secs(1),
        )
        .await;
    }

    #[tokio::test]
    async fn test_wait_for_all_statuses_empty_targets_returns_immediately() {
        let (_tx, mut rx) = broadcast::channel::<ComponentEvent>(16);
        wait_for_all_statuses(&mut rx, &[], Duration::from_millis(50)).await;
    }

    #[tokio::test]
    #[should_panic(expected = "Timed out")]
    async fn test_wait_for_all_statuses_timeout() {
        let (tx, mut rx) = broadcast::channel(16);
        tx.send(make_event("s1", ComponentStatus::Running)).unwrap();
        // `tx` stays alive until the end of the test, so only the timeout can end the wait.
        wait_for_all_statuses(
            &mut rx,
            &[("s1", ComponentStatus::Running), ("s2", ComponentStatus::Running)],
            Duration::from_millis(50),
        )
        .await;
    }

    #[tokio::test]
    #[should_panic(expected = "Event channel closed while waiting for statuses")]
    async fn test_wait_for_all_statuses_channel_closed() {
        let (tx, mut rx) = broadcast::channel(16);
        tx.send(make_event("s1", ComponentStatus::Running)).unwrap();
        // Buffered events are still delivered after close; then `recv` yields `Closed`.
        drop(tx);
        wait_for_all_statuses(
            &mut rx,
            &[("s1", ComponentStatus::Running), ("s2", ComponentStatus::Running)],
            Duration::from_secs(1),
        )
        .await;
    }
}