1use serde::Deserialize;
2use tokio::time::Interval;
3
4use super::Collect;
5use crate::common::{Bag, ConfigInto, FromConfig, FromPath, Period};
6use async_trait::async_trait;
7
8#[async_trait]
10pub trait BagCollect<T, B>
11where
12 T: Send + 'static,
13 B: Bag<T> + Send,
14{
15 fn get_bag(&mut self) -> &mut B;
16
17 async fn bag_collect(&mut self, t: T) -> anyhow::Result<()> {
19 let b = self.get_bag();
20 b.collect(t).await
21 }
22
23 async fn flush_bag(&mut self) -> anyhow::Result<Vec<T>> {
25 let bag = self.get_bag();
26 let bag = bag.flush().await?;
27 Ok(bag)
28 }
29}
30
31#[derive(Deserialize)]
32pub struct InMemoryBagCollectorConfig {
33 pub flush_period: Period,
34}
35
36impl FromPath for InMemoryBagCollectorConfig {}
37
38#[async_trait]
39impl<T> ConfigInto<InMemoryBagCollector<T>> for InMemoryBagCollectorConfig {}
40
41pub struct InMemoryBagCollector<T> {
43 pub flush_period: Period,
45 pub buffer: Vec<T>,
46}
47
48#[async_trait]
49impl<T> FromConfig<InMemoryBagCollectorConfig> for InMemoryBagCollector<T> {
50 async fn from_config(config: InMemoryBagCollectorConfig) -> anyhow::Result<Self> {
51 Ok(InMemoryBagCollector {
52 flush_period: config.flush_period,
53 buffer: vec![],
54 })
55 }
56}
57
58#[async_trait]
59impl<T> BagCollect<T, Vec<T>> for InMemoryBagCollector<T>
60where
61 T: Clone + Send + 'static,
62{
63 fn get_bag(&mut self) -> &mut Vec<T> {
64 &mut self.buffer
65 }
66}
67
68#[async_trait]
72impl<T> Collect<T, Vec<T>, InMemoryBagCollectorConfig> for InMemoryBagCollector<T>
73where
74 T: Clone + Send + 'static,
75{
76 async fn collect(&mut self, t: T) -> anyhow::Result<()> {
77 self.bag_collect(t).await
78 }
79
80 async fn flush(&mut self) -> anyhow::Result<Option<Vec<T>>> {
81 let bag = self.flush_bag().await?;
82 if bag.is_empty() {
83 return Ok(None);
84 }
85 return Ok(Some(bag));
86 }
87
88 fn get_flush_interval(&self) -> Interval {
90 let flush_period = self.flush_period.clone();
91 tokio::time::interval(flush_period.into())
92 }
93}
94
#[cfg(test)]
mod tests {
    use crate::prelude::*;
    use tokio::sync::mpsc::Receiver;

    #[derive(Clone, Debug)]
    struct Record {
        pub key: String,
        pub val: i32,
    }

    /// Receives batches from `rx` until the channel is closed (all senders
    /// dropped), concatenating every batch in arrival order.
    async fn receive_records(rx: &mut Receiver<Vec<Record>>) -> Vec<Record> {
        let mut all_records: Vec<Record> = Vec::new();
        while let Some(records) = rx.recv().await {
            all_records.extend(records);
        }
        all_records
    }

    #[tokio::test]
    async fn test_in_mem_bag_collector() {
        let (tx0, rx0) = channel!(Record, 10);
        let (tx1, mut rx1) = channel!(Vec<Record>, 10);
        let channels = pipe_channels!(rx0, [tx1]);
        let config = config!(
            InMemoryBagCollectorConfig,
            "resources/catalogs/bag_collector.yml"
        );
        let pipe = collector!("bag_collector");
        let context = pipe.get_context();
        let ph = populate_records(
            tx0,
            vec![
                Record {
                    key: "0".to_owned(),
                    val: 0,
                },
                Record {
                    key: "1".to_owned(),
                    val: 1,
                },
                Record {
                    key: "2".to_owned(),
                    val: 2,
                },
            ],
        );
        ph.await;
        join_pipes!([run_pipe!(pipe, config, channels)]);
        let records = receive_records(&mut rx1).await;
        // Check both fields so whole records (not just one field) are verified
        // to survive the collect/flush round-trip, in order.
        let keys: Vec<&str> = records.iter().map(|r| r.key.as_str()).collect();
        let vals: Vec<i32> = records.iter().map(|r| r.val).collect();
        assert_eq!(vec!["0", "1", "2"], keys);
        assert_eq!(vec![0, 1, 2], vals);
        assert_eq!(State::Done, context.get_state());
    }

    #[tokio::test]
    async fn test_collector_exit() {
        let (tx0, rx0) = channel!(u128, 1024);
        let (tx1, rx1) = channel!(Vec<u128>, 1024);
        let channels0 = pipe_channels!([tx0]);
        let channels1 = pipe_channels!(rx0, [tx1]);
        let config0 = config!(TimerConfig, "resources/catalogs/timer.yml");
        let config1 = config!(
            InMemoryBagCollectorConfig,
            "resources/catalogs/bag_collector.yml"
        );
        let timer = poller!("timer");
        let collector = collector!("tick_collector");
        let run_timer = run_pipe!(timer, config0, channels0);
        let run_collector = run_pipe!(collector, config1, channels1);
        // `Instant` is monotonic; `SystemTime` can step backwards, which would
        // make `duration_since(..).unwrap()` panic or mis-measure.
        let start = std::time::Instant::now();
        // Presumably dropping the downstream receiver makes the collector exit,
        // which in turn stops the timer — the test asserts this happens quickly.
        drop(rx1);
        join_pipes!([run_timer, run_collector]);
        let elapsed = start.elapsed();
        assert!(
            elapsed.as_secs() < 10,
            "pipes took {:?} to exit after the receiver was dropped",
            elapsed
        );
    }
}