below_model/
collector_plugin.rs

1// Copyright (c) Facebook, Inc. and its affiliates.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16use std::sync::Mutex;
17
18use async_trait::async_trait;
19
20use super::*;
21
22// For data collection that should be performed on a different thread
23#[async_trait]
24pub trait AsyncCollectorPlugin {
25    type T;
26
27    // Try to collect a sample of type `T`.
28    //
29    // On success, this should return `Ok(Some(sample))`.
30    //
31    // On a recoverable error, this should return `Ok(None)`. The
32    // function itself should consume the error (e.g. log the error)
33    // so that it does not get sent to a consumer thread
34    //
35    // On unrecoverable error, this should return `Err(e)`.
36    async fn try_collect(&mut self) -> Result<Option<Self::T>>;
37}
38
39type SharedVal<T> = Arc<Mutex<Option<Result<T>>>>;
40
41// A wrapper around an `AsyncCollectorPlugin` that allows samples to
42// be sent to a `Consumer`.
43pub struct AsyncCollector<T, Plugin: AsyncCollectorPlugin<T = T>> {
44    shared: SharedVal<T>,
45    plugin: Plugin,
46}
47
48impl<T, Plugin: AsyncCollectorPlugin<T = T>> AsyncCollector<T, Plugin> {
49    fn new(shared: SharedVal<T>, plugin: Plugin) -> Self {
50        Self { shared, plugin }
51    }
52
53    fn update(&self, value: Result<T>) {
54        *self.shared.lock().unwrap() = Some(value);
55    }
56
57    // Collect sample and update value shared with consumer. Replaces
58    // any existing sample that consumer has not consumed yet.
59    //
60    // Returns true if data was collected and sent. Returns false if
61    // there was a recoverable error. Returns an error if there was an
62    // unrecoverable error.
63    pub async fn collect_and_update(&mut self) -> Result<bool> {
64        let collect_result = self
65            .plugin
66            .try_collect()
67            .await
68            .context("Collector failed to read");
69
70        match collect_result {
71            Ok(Some(sample)) => {
72                self.update(Ok(sample));
73                Ok(true)
74            }
75            Ok(None) => Ok(false),
76            Err(e) => {
77                let error_msg = format!("{:#}", e);
78                self.update(Err(e));
79                Err(anyhow!(error_msg))
80            }
81        }
82    }
83}
84
85// A consumer for samples collected from a `AsyncCollector`
86pub struct Consumer<T> {
87    shared: SharedVal<T>,
88}
89
90impl<T> Consumer<T> {
91    fn new(shared: SharedVal<T>) -> Self {
92        Self { shared }
93    }
94
95    // Try to get latest sample of data if it exists.
96    // Returns the error if the collector sent an error.
97    pub fn try_take(&self) -> Result<Option<T>> {
98        match self.shared.lock().unwrap().take() {
99            Some(Ok(v)) => Ok(Some(v)),
100            Some(Err(e)) => Err(e),
101            None => Ok(None),
102        }
103    }
104}
105
106// Create a collector consumer pair for a collector plugin
107pub fn collector_consumer<T, Plugin: AsyncCollectorPlugin<T = T>>(
108    plugin: Plugin,
109) -> (AsyncCollector<T, Plugin>, Consumer<T>) {
110    let shared = Arc::new(Mutex::new(None));
111    (
112        AsyncCollector::new(shared.clone(), plugin),
113        Consumer::new(shared),
114    )
115}
116
#[cfg(test)]
mod test {
    use std::sync::Arc;
    use std::sync::Barrier;
    use std::thread;

    use super::*;

    /// Plugin that yields its call count, except on the 3rd call
    /// (recoverable error, `Ok(None)`) and the 4th call (unrecoverable error).
    struct TestCollector {
        counter: u64,
    }

    #[async_trait]
    impl AsyncCollectorPlugin for TestCollector {
        type T = u64;

        async fn try_collect(&mut self) -> Result<Option<u64>> {
            self.counter += 1;
            match self.counter {
                // Recoverable error
                3 => Ok(None),
                // Unrecoverable error
                4 => Err(anyhow!("boom")),
                n => Ok(Some(n)),
            }
        }
    }

    #[test]
    fn test_collect_and_consume() {
        let (mut collector, consumer) = collector_consumer(TestCollector { counter: 0 });
        let barrier = Arc::new(Barrier::new(2));
        let c = Arc::clone(&barrier);

        // The collector thread must not panic before its last `wait()`: a
        // panic there would leave the main thread blocked on the barrier
        // forever, hanging the test instead of failing it. It therefore
        // records every result and asserts only after the final handshake.
        let handle = thread::spawn(move || {
            let first = futures::executor::block_on(collector.collect_and_update());
            // Test overwriting sample
            let second = futures::executor::block_on(collector.collect_and_update());
            c.wait(); // <-- 1
            // Consumer checking overwritten sample
            c.wait(); // <-- 2
            // Test sending None
            let third = futures::executor::block_on(collector.collect_and_update());
            c.wait(); // <-- 3
            // Consumer checking None
            c.wait(); // <-- 4
            // Test sending error. Will fail on both collector and consumer threads.
            let fourth = futures::executor::block_on(collector.collect_and_update());
            c.wait(); // <-- 5

            assert!(matches!(first, Ok(true)), "first collect: {:?}", first);
            assert!(matches!(second, Ok(true)), "second collect: {:?}", second);
            assert!(matches!(third, Ok(false)), "recoverable error: {:?}", third);
            let err = fourth.expect_err("Collector did not return an error");
            assert_eq!(format!("{:#}", err), "Collector failed to read: boom");
        });
        // Collector overwriting sample
        barrier.wait(); // <-- 1
        assert_eq!(Some(2), consumer.try_take().unwrap());
        barrier.wait(); // <-- 2
        // Collector sending None
        barrier.wait(); // <-- 3
        assert_eq!(None, consumer.try_take().unwrap());
        barrier.wait(); // <-- 4
        // Collector sending error
        barrier.wait(); // <-- 5
        let err = consumer
            .try_take()
            .expect_err("Consumer did not receive the error");
        assert_eq!(format!("{:#}", err), "Collector failed to read: boom");
        // The error is handed out only once.
        assert_eq!(None, consumer.try_take().unwrap());

        handle.join().unwrap();
    }
}
184}