use std::collections::HashSet;
use std::time::Duration;
/// A declarative description of the subscriptions a program wants active.
///
/// Subscriptions form a tree: [`Sub::Batch`] nodes group children (and may
/// nest), while [`Sub::Interval`] leaves are identified by their `id`.
/// [`Sub::None`] is the empty subscription and the [`Default`] value.
#[derive(Clone, Debug, Default, PartialEq)]
pub enum Sub<Msg> {
    /// No subscription.
    #[default]
    None,
    /// Several subscriptions combined into one; entries may themselves be batches.
    Batch(Vec<Sub<Msg>>),
    /// A single interval subscription.
    Interval {
        /// Identifier for this interval.
        id: &'static str,
        /// The interval's duration; how it is scheduled is up to the code that
        /// consumes the subscription, not this type.
        duration: Duration,
        /// Message associated with this interval.
        msg: Msg,
    },
}
/// Constructors and id collection; none of these need `Msg: Clone`.
impl<Msg> Sub<Msg> {
    /// Returns the empty subscription, [`Sub::None`].
    #[inline]
    pub fn none() -> Self {
        Sub::None
    }

    /// Combines `subs` into a single subscription.
    ///
    /// [`Sub::None`] entries are dropped because they contribute nothing, and
    /// the survivors are normalized: zero yields [`Sub::None`], one is returned
    /// unwrapped, and two or more become a [`Sub::Batch`] in their original
    /// order. Nested batches are kept as-is.
    pub fn batch(subs: impl IntoIterator<Item = Sub<Msg>>) -> Self {
        let mut subs: Vec<Sub<Msg>> = subs
            .into_iter()
            .filter(|sub| !matches!(sub, Sub::None))
            .collect();
        if subs.len() > 1 {
            Sub::Batch(subs)
        } else {
            // Zero or one survivor: unwrap it, or fall back to `None`.
            subs.pop().unwrap_or(Sub::None)
        }
    }

    /// Creates an interval subscription identified by `id` that carries `msg`.
    pub fn interval(id: &'static str, duration: Duration, msg: Msg) -> Self {
        Sub::Interval { id, duration, msg }
    }

    /// Inserts the `id` of every [`Sub::Interval`] in this tree into `ids`.
    ///
    /// Nested batches are walked recursively; repeated ids collapse to a single
    /// entry because `ids` is a set.
    pub(crate) fn collect_interval_ids(&self, ids: &mut HashSet<&'static str>) {
        match self {
            Sub::None => {}
            Sub::Batch(subs) => {
                for sub in subs {
                    sub.collect_interval_ids(ids);
                }
            }
            Sub::Interval { id, .. } => {
                // `id` is `&&'static str` here; copy the inner `&'static str` out
                // instead of relying on deref coercion.
                ids.insert(*id);
            }
        }
    }
}

/// Traversals that hand out owned copies of each interval's message.
impl<Msg: Clone> Sub<Msg> {
    /// Returns every interval in this tree as `(id, duration, msg)`, with
    /// `msg` cloned.
    ///
    /// Order is depth-first, left to right; duplicate ids are not removed.
    pub(crate) fn intervals(&self) -> Vec<(&'static str, Duration, Msg)> {
        let mut result = Vec::new();
        self.collect_intervals(&mut result);
        result
    }

    /// Appends this subtree's intervals to `result` (helper for `intervals`).
    fn collect_intervals(&self, result: &mut Vec<(&'static str, Duration, Msg)>) {
        match self {
            Sub::None => {}
            Sub::Batch(subs) => {
                for sub in subs {
                    sub.collect_intervals(result);
                }
            }
            Sub::Interval { id, duration, msg } => {
                result.push((*id, *duration, msg.clone()));
            }
        }
    }
}
/// Queries about a subscription's shape; these never touch the messages.
impl<Msg> Sub<Msg> {
    /// Returns `true` if this is exactly [`Sub::None`].
    ///
    /// A [`Sub::Batch`] with no intervals inside is *not* `None`; use
    /// [`Sub::is_empty`] to ask whether the tree contains any interval.
    #[inline]
    pub fn is_none(&self) -> bool {
        matches!(self, Sub::None)
    }

    /// Returns `true` if this is a single [`Sub::Interval`].
    #[inline]
    pub fn is_interval(&self) -> bool {
        matches!(self, Sub::Interval { .. })
    }

    /// Returns `true` if this is a [`Sub::Batch`], regardless of its contents.
    #[inline]
    pub fn is_batch(&self) -> bool {
        matches!(self, Sub::Batch(_))
    }

    /// Number of [`Sub::Interval`] leaves in the tree, counting through
    /// nested batches. `None` entries and empty batches contribute zero.
    pub fn len(&self) -> usize {
        match self {
            Sub::None => 0,
            Sub::Batch(subs) => subs.iter().map(|s| s.len()).sum(),
            Sub::Interval { .. } => 1,
        }
    }

    /// Returns `true` when the tree contains no [`Sub::Interval`].
    ///
    /// Same result as `self.len() == 0`, but stops at the first interval found
    /// instead of counting every leaf.
    pub fn is_empty(&self) -> bool {
        match self {
            Sub::None => true,
            Sub::Batch(subs) => subs.iter().all(|s| s.is_empty()),
            Sub::Interval { .. } => false,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for an interval subscription carrying an `i32` message.
    fn tick(id: &'static str, secs: u64, msg: i32) -> Sub<i32> {
        Sub::interval(id, Duration::from_secs(secs), msg)
    }

    #[test]
    fn test_sub_none() {
        let empty = Sub::<i32>::none();
        assert_eq!(empty.len(), 0);
        assert!(empty.is_none());
    }

    #[test]
    fn test_sub_interval() {
        let one = tick("test", 1, 42);
        assert_eq!(one.len(), 1);
        assert!(one.is_interval());
    }

    #[test]
    fn test_sub_batch() {
        let pair = Sub::batch(vec![tick("a", 1, 1), tick("b", 2, 2)]);
        assert_eq!(pair.len(), 2);
        assert!(pair.is_batch());
    }

    #[test]
    fn test_collect_interval_ids() {
        let tree = Sub::batch(vec![tick("a", 1, 1), tick("b", 2, 2), Sub::none()]);
        let mut seen: HashSet<&'static str> = HashSet::new();
        tree.collect_interval_ids(&mut seen);
        assert_eq!(seen, HashSet::from(["a", "b"]));
    }

    #[test]
    fn test_batch_flattening() {
        let nothing: Sub<i32> = Sub::batch(Vec::new());
        assert!(nothing.is_none());

        let lone = Sub::batch(std::iter::once(tick("x", 1, 1)));
        assert!(lone.is_interval());
    }

    #[test]
    fn test_intervals_iterator() {
        let got = Sub::batch(vec![tick("a", 1, 10), tick("b", 2, 20)]).intervals();
        let want = vec![
            ("a", Duration::from_secs(1), 10),
            ("b", Duration::from_secs(2), 20),
        ];
        assert_eq!(got, want);
    }
}