1use std::collections::HashSet;
21use std::time::Duration;
22
/// A tree of subscriptions parameterised over the message type `Msg`.
///
/// Leaves are [`Sub::None`] and [`Sub::Interval`]; [`Sub::Batch`] groups any
/// number of child subscriptions, which may themselves be batches.
///
/// `Debug` is derived alongside `Clone` and `Default` so values can be shown
/// in logs and assertion failures. Each derived impl for `Clone`/`Debug` only
/// applies when `Msg` implements the same trait.
#[derive(Clone, Debug, Default)]
pub enum Sub<Msg> {
    /// The empty subscription; this is the [`Default`] value.
    #[default]
    None,

    /// A group of child subscriptions.
    Batch(Vec<Sub<Msg>>),

    /// A single interval entry.
    Interval {
        /// Static identifier of this interval.
        id: &'static str,
        /// Duration attached to the interval (presumably its period; the
        /// consumer of this value is not part of this type).
        duration: Duration,
        /// Message value carried by the interval.
        msg: Msg,
    },
}
51
52impl<Msg: Clone> Sub<Msg> {
53 #[inline]
55 pub fn none() -> Self {
56 Sub::None
57 }
58
59 pub fn batch(subs: impl IntoIterator<Item = Sub<Msg>>) -> Self {
61 let subs: Vec<_> = subs.into_iter().collect();
62 if subs.is_empty() {
63 Sub::None
64 } else if subs.len() == 1 {
65 subs.into_iter().next().unwrap()
66 } else {
67 Sub::Batch(subs)
68 }
69 }
70
71 pub fn interval(id: &'static str, duration: Duration, msg: Msg) -> Self {
92 Sub::Interval { id, duration, msg }
93 }
94
95 pub(crate) fn collect_interval_ids(&self, ids: &mut HashSet<&'static str>) {
97 match self {
98 Sub::None => {}
99 Sub::Batch(subs) => {
100 for sub in subs {
101 sub.collect_interval_ids(ids);
102 }
103 }
104 Sub::Interval { id, .. } => {
105 ids.insert(id);
106 }
107 }
108 }
109
110 pub(crate) fn intervals(&self) -> Vec<(&'static str, Duration, Msg)> {
112 let mut result = Vec::new();
113 self.collect_intervals(&mut result);
114 result
115 }
116
117 fn collect_intervals(&self, result: &mut Vec<(&'static str, Duration, Msg)>) {
118 match self {
119 Sub::None => {}
120 Sub::Batch(subs) => {
121 for sub in subs {
122 sub.collect_intervals(result);
123 }
124 }
125 Sub::Interval { id, duration, msg } => {
126 result.push((id, *duration, msg.clone()));
127 }
128 }
129 }
130}
131
132impl<Msg> Sub<Msg> {
137 #[inline]
139 pub fn is_none(&self) -> bool {
140 matches!(self, Sub::None)
141 }
142
143 #[inline]
145 pub fn is_interval(&self) -> bool {
146 matches!(self, Sub::Interval { .. })
147 }
148
149 #[inline]
151 pub fn is_batch(&self) -> bool {
152 matches!(self, Sub::Batch(_))
153 }
154
155 pub fn len(&self) -> usize {
157 match self {
158 Sub::None => 0,
159 Sub::Batch(subs) => subs.iter().map(|s| s.len()).sum(),
160 Sub::Interval { .. } => 1,
161 }
162 }
163
164 pub fn is_empty(&self) -> bool {
166 self.len() == 0
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a whole-second duration, to keep the fixtures compact.
    fn secs(n: u64) -> Duration {
        Duration::from_secs(n)
    }

    /// `Sub::none()` is the `None` variant and holds no intervals.
    #[test]
    fn test_sub_none() {
        let nothing = Sub::<i32>::none();
        assert_eq!(nothing.len(), 0);
        assert!(nothing.is_none());
    }

    /// A lone interval reports itself as an interval of length one.
    #[test]
    fn test_sub_interval() {
        let lone = Sub::interval("test", secs(1), 42);
        assert_eq!(lone.len(), 1);
        assert!(lone.is_interval());
    }

    /// Two intervals batched together stay a batch and count as two.
    #[test]
    fn test_sub_batch() {
        let first = Sub::interval("a", secs(1), 1);
        let second = Sub::interval("b", secs(2), 2);

        let both = Sub::batch([first, second]);
        assert_eq!(both.len(), 2);
        assert!(both.is_batch());
    }

    /// Id collection visits every interval and ignores `None` entries.
    #[test]
    fn test_collect_interval_ids() {
        let members = vec![
            Sub::interval("a", secs(1), 1),
            Sub::interval("b", secs(2), 2),
            Sub::none(),
        ];

        let mut seen = HashSet::new();
        Sub::batch(members).collect_interval_ids(&mut seen);

        assert_eq!(seen.len(), 2);
        for expected in ["a", "b"] {
            assert!(seen.contains(expected));
        }
    }

    /// `batch` collapses zero elements to `None` and one element to itself.
    #[test]
    fn test_batch_flattening() {
        let from_nothing: Sub<i32> = Sub::batch([]);
        assert!(from_nothing.is_none());

        let from_one = Sub::batch([Sub::interval("x", secs(1), 1)]);
        assert!(from_one.is_interval());
    }

    /// `intervals` yields the entries in insertion order with their payloads.
    #[test]
    fn test_intervals_iterator() {
        let flattened = Sub::batch([
            Sub::interval("a", secs(1), 10),
            Sub::interval("b", secs(2), 20),
        ])
        .intervals();

        assert_eq!(flattened, vec![("a", secs(1), 10), ("b", secs(2), 20)]);
    }
}