async_profiler_agent/reporter/
multi.rs

1use async_trait::async_trait;
2
3use crate::metadata::ReportMetadata;
4
5use super::Reporter;
6
7use std::fmt;
8
9/// An aggregated error that contains an error per reporter. A reporter is identified
10/// by the result of its Debug impl.
11#[derive(Debug, thiserror::Error)]
12struct MultiError {
13    errors: Vec<(String, Box<dyn std::error::Error + Send>)>,
14}
15
16impl fmt::Display for MultiError {
17    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
18        write!(f, "{{")?;
19        let mut first = true;
20        for (reporter, err) in self.errors.iter() {
21            if !first {
22                write!(f, ", ")?;
23            }
24            first = false;
25            write!(f, "{}: {}", reporter, err)?;
26        }
27        write!(f, "}}")
28    }
29}
30
31#[derive(Debug)]
32/// A reporter that reports profiling results to several destinations.
33///
34/// If one of the destinations errors, it will continue reporting to the other ones.
35pub struct MultiReporter {
36    reporters: Vec<Box<dyn Reporter + Send + Sync>>,
37}
38
39impl MultiReporter {
40    /// Create a new MultiReporter from a set of reporters
41    pub fn new(reporters: Vec<Box<dyn Reporter + Send + Sync>>) -> Self {
42        MultiReporter { reporters }
43    }
44}
45
46#[async_trait]
47impl Reporter for MultiReporter {
48    async fn report(
49        &self,
50        jfr: Vec<u8>,
51        metadata: &ReportMetadata,
52    ) -> Result<(), Box<dyn std::error::Error + Send>> {
53        let jfr_ref = &jfr[..];
54        let errors = futures::future::join_all(self.reporters.iter().map(|reporter| async move {
55            reporter
56                .report(jfr_ref.to_owned(), metadata)
57                .await
58                .map_err(move |e| (format!("{:?}", reporter), e))
59        }))
60        .await;
61        // return all errors
62        let errors: Vec<_> = errors.into_iter().flat_map(|e| e.err()).collect();
63        if errors.is_empty() {
64            Ok(())
65        } else {
66            Err(Box::new(MultiError { errors }))
67        }
68    }
69}
70
#[cfg(test)]
pub mod test {
    use std::{
        sync::{
            atomic::{self, AtomicBool},
            Arc,
        },
        time::Duration,
    };

    use async_trait::async_trait;

    use crate::{
        metadata::{ReportMetadata, DUMMY_METADATA},
        reporter::Reporter,
    };

    use super::MultiReporter;

    /// Test reporter that always succeeds: it sleeps for one second and then sets its
    /// shared flag, so a test can observe that the report ran to completion.
    #[derive(Debug)]
    struct OkReporter(Arc<AtomicBool>);
    #[async_trait]
    impl Reporter for OkReporter {
        async fn report(
            &self,
            _jfr: Vec<u8>,
            _metadata: &ReportMetadata,
        ) -> Result<(), Box<dyn std::error::Error + Send>> {
            tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
            self.0.store(true, atomic::Ordering::Relaxed);
            Ok(())
        }
    }

    /// Error type returned by `ErrReporter`.
    #[derive(Debug, thiserror::Error)]
    enum Error {
        #[error("failed: {0}")]
        Failed(String),
    }

    /// Test reporter that always fails immediately, echoing its label in the error.
    #[derive(Debug)]
    struct ErrReporter(String);
    #[async_trait]
    impl Reporter for ErrReporter {
        async fn report(
            &self,
            _jfr: Vec<u8>,
            _metadata: &ReportMetadata,
        ) -> Result<(), Box<dyn std::error::Error + Send>> {
            Err(Box::new(Error::Failed(self.0.clone())))
        }
    }

    /// Ten one-second reporters must all finish within a two-second timeout, which is
    /// only possible if they are not run one after another.
    // `start_paused = true` runs the test on tokio's paused (virtual) clock.
    #[tokio::test(start_paused = true)]
    async fn test_multi_reporter_ok() {
        let signals: Vec<_> = (0..10).map(|_| Arc::new(AtomicBool::new(false))).collect();
        let reporter = MultiReporter::new(
            signals
                .iter()
                .map(|signal| {
                    Box::new(OkReporter(signal.clone())) as Box<dyn Reporter + Send + Sync>
                })
                .collect(),
        );
        // test that reports are done in parallel
        tokio::time::timeout(
            Duration::from_secs(2),
            reporter.report(vec![], &DUMMY_METADATA),
        )
        .await
        .unwrap()
        .unwrap();
        // test that reports are done
        assert!(signals.iter().all(|s| s.load(atomic::Ordering::Relaxed)));
    }

    /// Failing reporters are collected into one aggregated error, while the succeeding
    /// reporters placed before and after them still complete.
    #[tokio::test(start_paused = true)]
    async fn test_multi_reporter_err() {
        let signal_before = Arc::new(AtomicBool::new(false));
        let signal_after = Arc::new(AtomicBool::new(false));
        let reporter = MultiReporter::new(vec![
            Box::new(OkReporter(signal_before.clone())) as Box<dyn Reporter + Send + Sync>,
            Box::new(ErrReporter("foo".to_owned())) as Box<dyn Reporter + Send + Sync>,
            Box::new(ErrReporter("bar".to_owned())) as Box<dyn Reporter + Send + Sync>,
            Box::new(OkReporter(signal_after.clone())) as Box<dyn Reporter + Send + Sync>,
        ]);
        // test that reports are done and return an error
        let err = format!(
            "{}",
            reporter.report(vec![], &DUMMY_METADATA).await.unwrap_err()
        );
        // Only the failing reporters appear, in the order they were registered.
        assert_eq!(
            err,
            "{ErrReporter(\"foo\"): failed: foo, ErrReporter(\"bar\"): failed: bar}"
        );
        // test that reports are done even though a reporter errored
        assert!(signal_before.load(atomic::Ordering::Relaxed));
        assert!(signal_after.load(atomic::Ordering::Relaxed));
    }
}