pipebase/collect/bag.rs

1use serde::Deserialize;
2use tokio::time::Interval;
3
4use super::Collect;
5use crate::common::{Bag, ConfigInto, FromConfig, FromPath, Period};
6use async_trait::async_trait;
7
8/// Collect items
9#[async_trait]
10pub trait BagCollect<T, B>
11where
12    T: Send + 'static,
13    B: Bag<T> + Send,
14{
15    fn get_bag(&mut self) -> &mut B;
16
17    /// Collect item
18    async fn bag_collect(&mut self, t: T) -> anyhow::Result<()> {
19        let b = self.get_bag();
20        b.collect(t).await
21    }
22
23    /// Flush bag and return items
24    async fn flush_bag(&mut self) -> anyhow::Result<Vec<T>> {
25        let bag = self.get_bag();
26        let bag = bag.flush().await?;
27        Ok(bag)
28    }
29}
30
31#[derive(Deserialize)]
32pub struct InMemoryBagCollectorConfig {
33    pub flush_period: Period,
34}
35
36impl FromPath for InMemoryBagCollectorConfig {}
37
38#[async_trait]
39impl<T> ConfigInto<InMemoryBagCollector<T>> for InMemoryBagCollectorConfig {}
40
41/// In memory cache items
42pub struct InMemoryBagCollector<T> {
43    /// Caller should flush cache every flush_period
44    pub flush_period: Period,
45    pub buffer: Vec<T>,
46}
47
48#[async_trait]
49impl<T> FromConfig<InMemoryBagCollectorConfig> for InMemoryBagCollector<T> {
50    async fn from_config(config: InMemoryBagCollectorConfig) -> anyhow::Result<Self> {
51        Ok(InMemoryBagCollector {
52            flush_period: config.flush_period,
53            buffer: vec![],
54        })
55    }
56}
57
58#[async_trait]
59impl<T> BagCollect<T, Vec<T>> for InMemoryBagCollector<T>
60where
61    T: Clone + Send + 'static,
62{
63    fn get_bag(&mut self) -> &mut Vec<T> {
64        &mut self.buffer
65    }
66}
67
68/// # Parameters
69/// * T: input
70/// * Vec<T>: output
71#[async_trait]
72impl<T> Collect<T, Vec<T>, InMemoryBagCollectorConfig> for InMemoryBagCollector<T>
73where
74    T: Clone + Send + 'static,
75{
76    async fn collect(&mut self, t: T) -> anyhow::Result<()> {
77        self.bag_collect(t).await
78    }
79
80    async fn flush(&mut self) -> anyhow::Result<Option<Vec<T>>> {
81        let bag = self.flush_bag().await?;
82        if bag.is_empty() {
83            return Ok(None);
84        }
85        return Ok(Some(bag));
86    }
87
88    /// Call by collector pipe to flush bag in period
89    fn get_flush_interval(&self) -> Interval {
90        let flush_period = self.flush_period.clone();
91        tokio::time::interval(flush_period.into())
92    }
93}
94
#[cfg(test)]
mod tests {
    use crate::prelude::*;
    use tokio::sync::mpsc::Receiver;

    #[derive(Clone, Debug)]
    struct Record {
        pub key: String,
        pub val: i32,
    }

    /// Receive batches from `rx` until the channel closes (all senders
    /// dropped) and return every record in arrival order.
    async fn receive_records(rx: &mut Receiver<Vec<Record>>) -> Vec<Record> {
        let mut all_records: Vec<Record> = Vec::new();
        while let Some(records) = rx.recv().await {
            all_records.extend(records);
        }
        all_records
    }

    #[tokio::test]
    async fn test_in_mem_bag_collector() {
        let (tx0, rx0) = channel!(Record, 10);
        let (tx1, mut rx1) = channel!(Vec<Record>, 10);
        let channels = pipe_channels!(rx0, [tx1]);
        let config = config!(
            InMemoryBagCollectorConfig,
            "resources/catalogs/bag_collector.yml"
        );
        let pipe = collector!("bag_collector");
        let context = pipe.get_context();
        let input: Vec<Record> = (0..3)
            .map(|val| Record {
                key: val.to_string(),
                val,
            })
            .collect();
        let ph = populate_records(tx0, input);
        ph.await;
        join_pipes!([run_pipe!(pipe, config, channels)]);
        let records = receive_records(&mut rx1).await;
        // Compare key and value together so order, count and content are all pinned.
        let actual: Vec<(&str, i32)> = records
            .iter()
            .map(|record| (record.key.as_str(), record.val))
            .collect();
        assert_eq!(vec![("0", 0), ("1", 1), ("2", 2)], actual);
        assert_eq!(State::Done, context.get_state());
    }

    #[tokio::test]
    async fn test_collector_exit() {
        let (tx0, rx0) = channel!(u128, 1024);
        let (tx1, rx1) = channel!(Vec<u128>, 1024);
        let channels0 = pipe_channels!([tx0]);
        let channels1 = pipe_channels!(rx0, [tx1]);
        let config0 = config!(TimerConfig, "resources/catalogs/timer.yml");
        let config1 = config!(
            InMemoryBagCollectorConfig,
            "resources/catalogs/bag_collector.yml"
        );
        let timer = poller!("timer");
        let collector = collector!("tick_collector");
        let run_timer = run_pipe!(timer, config0, channels0);
        let run_collector = run_pipe!(collector, config1, channels1);
        // Instant is monotonic, unlike SystemTime, so measuring elapsed time
        // cannot fail if the wall clock is adjusted while the test runs.
        let start = std::time::Instant::now();
        drop(rx1);
        join_pipes!([run_timer, run_collector]);
        // timer and collector should exit asap since downstream rx1 dropped
        let elapsed = start.elapsed();
        assert!(
            elapsed < std::time::Duration::from_secs(10),
            "pipes took {:?} to exit after downstream receiver was dropped",
            elapsed
        );
    }
}