below_model/
collector_plugin.rs1use std::sync::Arc;
16use std::sync::Mutex;
17
18use async_trait::async_trait;
19
20use super::*;
21
22#[async_trait]
24pub trait AsyncCollectorPlugin {
25 type T;
26
27 async fn try_collect(&mut self) -> Result<Option<Self::T>>;
37}
38
39type SharedVal<T> = Arc<Mutex<Option<Result<T>>>>;
40
41pub struct AsyncCollector<T, Plugin: AsyncCollectorPlugin<T = T>> {
44 shared: SharedVal<T>,
45 plugin: Plugin,
46}
47
48impl<T, Plugin: AsyncCollectorPlugin<T = T>> AsyncCollector<T, Plugin> {
49 fn new(shared: SharedVal<T>, plugin: Plugin) -> Self {
50 Self { shared, plugin }
51 }
52
53 fn update(&self, value: Result<T>) {
54 *self.shared.lock().unwrap() = Some(value);
55 }
56
57 pub async fn collect_and_update(&mut self) -> Result<bool> {
64 let collect_result = self
65 .plugin
66 .try_collect()
67 .await
68 .context("Collector failed to read");
69
70 match collect_result {
71 Ok(Some(sample)) => {
72 self.update(Ok(sample));
73 Ok(true)
74 }
75 Ok(None) => Ok(false),
76 Err(e) => {
77 let error_msg = format!("{:#}", e);
78 self.update(Err(e));
79 Err(anyhow!(error_msg))
80 }
81 }
82 }
83}
84
85pub struct Consumer<T> {
87 shared: SharedVal<T>,
88}
89
90impl<T> Consumer<T> {
91 fn new(shared: SharedVal<T>) -> Self {
92 Self { shared }
93 }
94
95 pub fn try_take(&self) -> Result<Option<T>> {
98 match self.shared.lock().unwrap().take() {
99 Some(Ok(v)) => Ok(Some(v)),
100 Some(Err(e)) => Err(e),
101 None => Ok(None),
102 }
103 }
104}
105
106pub fn collector_consumer<T, Plugin: AsyncCollectorPlugin<T = T>>(
108 plugin: Plugin,
109) -> (AsyncCollector<T, Plugin>, Consumer<T>) {
110 let shared = Arc::new(Mutex::new(None));
111 (
112 AsyncCollector::new(shared.clone(), plugin),
113 Consumer::new(shared),
114 )
115}
116
#[cfg(test)]
mod test {
    use std::sync::Arc;
    use std::sync::Barrier;
    use std::thread;

    use super::*;

    /// Test plugin that behaves differently on each call.
    ///
    /// The 3rd call yields no sample and the 4th fails with "boom". Every other
    /// call yields the current call count.
    struct TestCollector {
        counter: u64,
    }

    #[async_trait]
    impl AsyncCollectorPlugin for TestCollector {
        type T = u64;

        async fn try_collect(&mut self) -> Result<Option<u64>> {
            self.counter += 1;
            if self.counter == 3 {
                // Nothing to publish on the third call.
                Ok(None)
            } else if self.counter == 4 {
                // Simulated collection failure on the fourth call.
                Err(anyhow!("boom"))
            } else {
                Ok(Some(self.counter))
            }
        }
    }

    /// Runs the collector on a separate thread and checks, step by step, what
    /// the consumer observes.
    ///
    /// Both threads meet at five barrier points, numbered (1)-(5) in the
    /// comments, so each consumer check sees a known state.
    #[test]
    fn test_collect_and_consume() {
        let (mut collector, consumer) = collector_consumer(TestCollector { counter: 0 });
        let barrier = Arc::new(Barrier::new(2));
        let c = barrier.clone();

        let handle = thread::spawn(move || {
            // Outcomes are recorded here and checked only after the last
            // barrier. A panic before a `wait()` would leave the main thread
            // blocked forever instead of failing the test.
            let first = futures::executor::block_on(collector.collect_and_update());
            let second = futures::executor::block_on(collector.collect_and_update());
            c.wait(); // (1) samples 1 and 2 published.
            c.wait(); // (2) consumer has drained the slot.
            let third = futures::executor::block_on(collector.collect_and_update());
            c.wait(); // (3) empty collection done.
            c.wait(); // (4) consumer has checked the slot is still empty.
            let fourth = futures::executor::block_on(collector.collect_and_update());
            c.wait(); // (5) error published.

            assert!(first.unwrap(), "first collection should publish a sample");
            assert!(second.unwrap(), "second collection should publish a sample");
            assert!(!third.unwrap(), "empty collection should not publish anything");
            let err = fourth.expect_err("Collector did not return an error");
            assert_eq!("Collector failed to read: boom", format!("{:#}", err));
        });

        barrier.wait(); // (1)
        // Sample 2 overwrote the untaken sample 1.
        assert_eq!(Some(2), consumer.try_take().unwrap());
        barrier.wait(); // (2)
        barrier.wait(); // (3)
        // An `Ok(None)` collection leaves the slot untouched.
        assert_eq!(None, consumer.try_take().unwrap());
        barrier.wait(); // (4)
        barrier.wait(); // (5)
        let err = consumer
            .try_take()
            .expect_err("Consumer did not receive the collector error");
        assert_eq!("Collector failed to read: boom", format!("{:#}", err));
        // Taking the error empties the slot.
        assert_eq!(None, consumer.try_take().unwrap());

        handle.join().unwrap();
    }
}