1use std::collections::HashSet;
21use std::time::Duration;
22
/// A tree of subscriptions, parameterized over the message type `Msg`.
///
/// Values are normally built with [`Sub::none`], [`Sub::batch`] and
/// [`Sub::interval`]. `Debug`/`PartialEq`/`Eq` are derived so subscriptions
/// can be logged and compared; each derive only applies when `Msg` supports it.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Sub<Msg> {
    /// No subscription at all.
    None,

    /// Several subscriptions grouped together; children may themselves be batches.
    Batch(Vec<Sub<Msg>>),

    /// An interval (timer) subscription.
    ///
    /// NOTE(review): the timer is driven outside this file; presumably `msg`
    /// is delivered once per `duration` — confirm against the runtime.
    Interval {
        /// Identifier of this interval; collected by `collect_interval_ids`.
        id: &'static str,

        /// The interval's period.
        duration: Duration,

        /// The message associated with this interval.
        msg: Msg,
    },
}
50
51impl<Msg> Default for Sub<Msg> {
52 fn default() -> Self {
53 Sub::None
54 }
55}
56
57impl<Msg: Clone> Sub<Msg> {
58 #[inline]
60 pub fn none() -> Self {
61 Sub::None
62 }
63
64 pub fn batch(subs: impl IntoIterator<Item = Sub<Msg>>) -> Self {
66 let subs: Vec<_> = subs.into_iter().collect();
67 if subs.is_empty() {
68 Sub::None
69 } else if subs.len() == 1 {
70 subs.into_iter().next().unwrap()
71 } else {
72 Sub::Batch(subs)
73 }
74 }
75
76 pub fn interval(id: &'static str, duration: Duration, msg: Msg) -> Self {
97 Sub::Interval { id, duration, msg }
98 }
99
100 pub(crate) fn collect_interval_ids(&self, ids: &mut HashSet<&'static str>) {
102 match self {
103 Sub::None => {}
104 Sub::Batch(subs) => {
105 for sub in subs {
106 sub.collect_interval_ids(ids);
107 }
108 }
109 Sub::Interval { id, .. } => {
110 ids.insert(id);
111 }
112 }
113 }
114
115 pub(crate) fn intervals(&self) -> Vec<(&'static str, Duration, Msg)> {
117 let mut result = Vec::new();
118 self.collect_intervals(&mut result);
119 result
120 }
121
122 fn collect_intervals(&self, result: &mut Vec<(&'static str, Duration, Msg)>) {
123 match self {
124 Sub::None => {}
125 Sub::Batch(subs) => {
126 for sub in subs {
127 sub.collect_intervals(result);
128 }
129 }
130 Sub::Interval { id, duration, msg } => {
131 result.push((id, *duration, msg.clone()));
132 }
133 }
134 }
135}
136
137impl<Msg> Sub<Msg> {
142 #[inline]
144 pub fn is_none(&self) -> bool {
145 matches!(self, Sub::None)
146 }
147
148 #[inline]
150 pub fn is_interval(&self) -> bool {
151 matches!(self, Sub::Interval { .. })
152 }
153
154 #[inline]
156 pub fn is_batch(&self) -> bool {
157 matches!(self, Sub::Batch(_))
158 }
159
160 pub fn len(&self) -> usize {
162 match self {
163 Sub::None => 0,
164 Sub::Batch(subs) => subs.iter().map(|s| s.len()).sum(),
165 Sub::Interval { .. } => 1,
166 }
167 }
168
169 pub fn is_empty(&self) -> bool {
171 self.len() == 0
172 }
173}
174
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sub_none() {
        let sub: Sub<i32> = Sub::none();
        assert!(sub.is_none());
        assert_eq!(sub.len(), 0);
    }

    #[test]
    fn test_default_is_none() {
        let sub: Sub<i32> = Sub::default();
        assert!(sub.is_none());
        assert!(sub.is_empty());
    }

    #[test]
    fn test_sub_interval() {
        let sub = Sub::interval("test", Duration::from_secs(1), 42);
        assert!(sub.is_interval());
        assert_eq!(sub.len(), 1);
    }

    #[test]
    fn test_sub_batch() {
        let sub = Sub::batch([
            Sub::interval("a", Duration::from_secs(1), 1),
            Sub::interval("b", Duration::from_secs(2), 2),
        ]);
        assert!(sub.is_batch());
        assert_eq!(sub.len(), 2);
    }

    #[test]
    fn test_collect_interval_ids() {
        let sub = Sub::batch([
            Sub::interval("a", Duration::from_secs(1), 1),
            Sub::interval("b", Duration::from_secs(2), 2),
            Sub::none(),
        ]);

        let mut ids = HashSet::new();
        sub.collect_interval_ids(&mut ids);

        assert!(ids.contains("a"));
        assert!(ids.contains("b"));
        assert_eq!(ids.len(), 2);
    }

    #[test]
    fn test_duplicate_ids_collapse_in_id_set_only() {
        let sub = Sub::batch([
            Sub::interval("a", Duration::from_secs(1), 1),
            Sub::interval("a", Duration::from_secs(2), 2),
        ]);

        let mut ids = HashSet::new();
        sub.collect_interval_ids(&mut ids);

        // The id set deduplicates, but `intervals()` reports every interval.
        assert_eq!(ids.len(), 1);
        assert_eq!(sub.intervals().len(), 2);
    }

    #[test]
    fn test_batch_flattening() {
        let empty: Sub<i32> = Sub::batch([]);
        assert!(empty.is_none());

        let single = Sub::batch([Sub::interval("x", Duration::from_secs(1), 1)]);
        assert!(single.is_interval());
    }

    #[test]
    fn test_batch_of_nones_is_empty() {
        let sub: Sub<i32> = Sub::batch([Sub::none(), Sub::none()]);
        assert!(sub.is_empty());
        assert_eq!(sub.len(), 0);
        assert!(sub.intervals().is_empty());
    }

    #[test]
    fn test_nested_batch_len_and_order() {
        let sub = Sub::batch([
            Sub::batch([
                Sub::interval("a", Duration::from_secs(1), 1),
                Sub::interval("b", Duration::from_secs(2), 2),
            ]),
            Sub::interval("c", Duration::from_secs(3), 3),
        ]);

        assert_eq!(sub.len(), 3);
        assert!(!sub.is_empty());
        let ids: Vec<_> = sub.intervals().into_iter().map(|(id, _, _)| id).collect();
        assert_eq!(ids, ["a", "b", "c"]);
    }

    #[test]
    fn test_intervals_in_order() {
        let sub = Sub::batch([
            Sub::interval("a", Duration::from_secs(1), 10),
            Sub::interval("b", Duration::from_secs(2), 20),
        ]);

        let intervals = sub.intervals();
        assert_eq!(intervals.len(), 2);
        assert_eq!(intervals[0], ("a", Duration::from_secs(1), 10));
        assert_eq!(intervals[1], ("b", Duration::from_secs(2), 20));
    }
}