use std::sync::Arc;
use std::sync::Mutex;
use async_trait::async_trait;
use super::*;
/// A pluggable source of values, polled by an [`AsyncCollector`].
///
/// Implementations are driven through [`AsyncCollector::collect_and_update`],
/// which calls [`AsyncCollectorPlugin::try_collect`] once per invocation.
#[async_trait]
pub trait AsyncCollectorPlugin {
/// The type of value produced by a successful collection.
type T;
/// Attempts to collect one value.
///
/// How the collector treats each outcome:
/// * `Ok(Some(value))` — `value` is published to the shared slot.
/// * `Ok(None)` — nothing is published; the shared slot is left untouched.
/// * `Err(_)` — the error (with added context) is published to the shared
///   slot and a rendered copy is also returned to the collector's caller.
async fn try_collect(&mut self) -> Result<Option<Self::T>>;
}
/// Single-value slot shared by an [`AsyncCollector`] (writer) and its
/// [`Consumer`] (reader).
///
/// `None` means nothing is pending. The slot starts empty, each publish
/// overwrites whatever is there, and each [`Consumer::try_take`] empties it.
type SharedVal<T> = Arc<Mutex<Option<Result<T>>>>;
/// Producer half of a pair created by [`collector_consumer`].
///
/// Wraps an [`AsyncCollectorPlugin`] and publishes every collected value or
/// collection error into a slot shared with the matching [`Consumer`].
pub struct AsyncCollector<T, Plugin: AsyncCollectorPlugin<T = T>> {
    /// Slot shared with the [`Consumer`]; each publish overwrites it.
    shared: SharedVal<T>,
    /// The source that is polled for new values.
    plugin: Plugin,
}

impl<T, Plugin> AsyncCollector<T, Plugin>
where
    Plugin: AsyncCollectorPlugin<T = T>,
{
    /// Builds a collector that publishes into `shared`.
    fn new(shared: SharedVal<T>, plugin: Plugin) -> Self {
        AsyncCollector { plugin, shared }
    }

    /// Stores `value` in the shared slot, replacing (and dropping) any outcome
    /// the consumer has not taken yet.
    ///
    /// # Panics
    ///
    /// Panics if the shared mutex is poisoned.
    fn update(&self, value: Result<T>) {
        let mut slot = self.shared.lock().unwrap();
        *slot = Some(value);
    }

    /// Polls the plugin once and publishes the outcome.
    ///
    /// Returns:
    /// * `Ok(true)` — a value was collected and stored for the consumer.
    /// * `Ok(false)` — the plugin produced nothing; the slot is left as is.
    /// * `Err(_)` — collection failed. The error, with context
    ///   `"Collector failed to read"`, is stored for the consumer. The caller
    ///   receives a new error carrying the same rendered message chain.
    pub async fn collect_and_update(&mut self) -> Result<bool> {
        let outcome = self.plugin.try_collect().await;
        let sample = match outcome.context("Collector failed to read") {
            Ok(Some(sample)) => sample,
            Ok(None) => return Ok(false),
            Err(err) => {
                // `anyhow::Error` is not `Clone`: the original error goes to the
                // consumer, and the caller gets a fresh error built from the
                // alternate (`{:#}`) rendering, which includes the whole chain.
                let rendered = format!("{:#}", err);
                self.update(Err(err));
                return Err(anyhow!(rendered));
            }
        };
        self.update(Ok(sample));
        Ok(true)
    }
}
/// Consumer half of a pair created by [`collector_consumer`].
///
/// Reads, and clears, the latest outcome published by the matching
/// [`AsyncCollector`].
pub struct Consumer<T> {
    /// Slot shared with the [`AsyncCollector`].
    shared: SharedVal<T>,
}

impl<T> Consumer<T> {
    /// Builds a consumer that reads from `shared`.
    fn new(shared: SharedVal<T>) -> Self {
        Consumer { shared }
    }

    /// Takes the pending outcome out of the shared slot, leaving it empty.
    ///
    /// Returns `Ok(Some(value))` for a pending value, `Err(_)` for a pending
    /// error, and `Ok(None)` when nothing is pending.
    ///
    /// # Panics
    ///
    /// Panics if the shared mutex is poisoned.
    pub fn try_take(&self) -> Result<Option<T>> {
        // `Option<Result<T>>` -> `Result<Option<T>>`.
        self.shared.lock().unwrap().take().transpose()
    }
}
pub fn collector_consumer<T, Plugin: AsyncCollectorPlugin<T = T>>(
plugin: Plugin,
) -> (AsyncCollector<T, Plugin>, Consumer<T>) {
let shared = Arc::new(Mutex::new(None));
(
AsyncCollector::new(shared.clone(), plugin),
Consumer::new(shared),
)
}
#[cfg(test)]
mod test {
    use std::sync::mpsc;
    use std::thread;

    use super::*;

    /// Test plugin: its n-th call (1-based) yields `Some(n)`, except that
    /// call 3 yields `None` and call 4 fails with "boom".
    struct TestCollector {
        counter: u64,
    }

    #[async_trait]
    impl AsyncCollectorPlugin for TestCollector {
        type T = u64;
        async fn try_collect(&mut self) -> Result<Option<u64>> {
            self.counter += 1;
            match self.counter {
                3 => Ok(None),
                4 => Err(anyhow!("boom")),
                n => Ok(Some(n)),
            }
        }
    }

    /// Runs the collector on a worker thread and checks, step by step, what
    /// the consumer observes on the test thread.
    #[test]
    fn test_collect_and_consume() {
        let (mut collector, consumer) = collector_consumer(TestCollector { counter: 0 });

        // Use one channel per direction instead of a shared `Barrier`.
        // `recv` fails once the peer's sender is dropped, so if either thread
        // panics, the other fails promptly instead of blocking forever at a
        // rendezvous point.
        let (worker_done, wait_for_worker) = mpsc::channel::<()>();
        let (main_done, wait_for_main) = mpsc::channel::<()>();

        let handle = thread::spawn(move || {
            // Calls 1 and 2 both publish; only the latest value (2) stays in the slot.
            assert!(futures::executor::block_on(collector.collect_and_update()).unwrap());
            assert!(futures::executor::block_on(collector.collect_and_update()).unwrap());
            worker_done.send(()).expect("test thread stopped early");
            wait_for_main.recv().expect("test thread stopped early");

            // Call 3 yields nothing, so the slot the consumer just drained stays empty.
            assert!(!futures::executor::block_on(collector.collect_and_update()).unwrap());
            worker_done.send(()).expect("test thread stopped early");
            wait_for_main.recv().expect("test thread stopped early");

            // Call 4 fails: the error is returned here and also published.
            let is_error = futures::executor::block_on(collector.collect_and_update()).is_err();
            assert!(is_error, "Collector did not return an error");
            worker_done.send(()).expect("test thread stopped early");
        });

        wait_for_worker.recv().expect("collector thread stopped early");
        assert_eq!(Some(2), consumer.try_take().unwrap());
        main_done.send(()).expect("collector thread stopped early");

        wait_for_worker.recv().expect("collector thread stopped early");
        assert_eq!(None, consumer.try_take().unwrap());
        main_done.send(()).expect("collector thread stopped early");

        wait_for_worker.recv().expect("collector thread stopped early");
        assert!(consumer.try_take().is_err());

        handle.join().unwrap();
    }
}