external_ip/
consensus.rs

1use crate::sources;
2
3use log::{debug, error};
4use std::collections::HashMap;
5use std::net::IpAddr;
6use std::option::Option;
7use std::vec::Vec;
8
9use crate::sources::Family;
10
/// Type alias for a list of boxed `sources::Source` trait objects, provided for easier usage of
/// the library.
pub type Sources = Vec<Box<dyn sources::Source>>;
13
14use std::default::Default;
15
/// Policies for Consensus resolution
#[derive(Debug, Copy, Clone, Default)]
pub enum Policy {
    /// Requires all sources to be queried. Sources returning errors are ignored, and the IP
    /// with the most replies is returned as the result.
    #[default]
    All,
    /// Tests the sources one by one, in order, until one of them succeeds, and returns that
    /// reply as the result.
    First,
}
27
/// Consensus system that aggregates the various sources of information and returns a single
/// reply: the most common one under `Policy::All`, or the first successful one under
/// `Policy::First`.
pub struct Consensus {
    /// Sources that are queried for the address.
    voters: Sources,
    /// Strategy used to turn the sources' replies into one result.
    policy: Policy,
    /// Address family passed to each source's `get_ip` call.
    family: Family,
}
35
/// Consensus builder: collects the sources, the policy and the address family, and hands them
/// over unchanged to a `Consensus` when `build` is called.
pub struct ConsensusBuilder {
    /// Sources accumulated so far, in the order they were added.
    voters: Sources,
    /// Policy the built `Consensus` will use.
    policy: Policy,
    /// Address family the built `Consensus` will request from its sources.
    family: Family,
}
42
43impl Default for ConsensusBuilder {
44    fn default() -> Self {
45        Self::new()
46    }
47}
48
49impl ConsensusBuilder {
50    pub fn new() -> ConsensusBuilder {
51        ConsensusBuilder {
52            voters: vec![],
53            policy: Policy::default(),
54            family: Family::default(),
55        }
56    }
57
58    /// Adds sources to the builder
59    ///
60    /// # Arguments
61    ///
62    /// * `source` - Iterable of sources to add
63    pub fn add_sources<T>(mut self, source: T) -> ConsensusBuilder
64    where
65        T: IntoIterator<Item = Box<dyn sources::Source>>,
66    {
67        self.voters.extend(source);
68        self
69    }
70
71    pub fn policy(mut self, policy: Policy) -> ConsensusBuilder {
72        self.policy = policy;
73        self
74    }
75
76    pub fn family(mut self, family: Family) -> ConsensusBuilder {
77        self.family = family;
78        self
79    }
80
81    /// Returns the configured consensus struct from the builder
82    pub fn build(self) -> Consensus {
83        Consensus {
84            voters: self.voters,
85            policy: self.policy,
86            family: self.family,
87        }
88    }
89}
90
91impl Consensus {
92    /// Returns the IP address it found or None if no source worked.
93    pub async fn get_consensus(&self) -> Option<IpAddr> {
94        match self.policy {
95            Policy::All => self.all().await,
96            Policy::First => self.first().await,
97        }
98    }
99
100    async fn all(&self) -> Option<IpAddr> {
101        let results =
102            futures::future::join_all(self.voters.iter().map(|voter| voter.get_ip(self.family)))
103                .await;
104
105        debug!("Results {:?}", results);
106        let mut accumulate = HashMap::new();
107        for (pos, result) in results.into_iter().enumerate() {
108            match result {
109                Ok(result) => {
110                    accumulate
111                        .entry(result)
112                        .and_modify(|c| *c += 1)
113                        .or_insert(1);
114                }
115                Err(err) => error!("Source {} failed {:?}", self.voters[pos], err),
116            };
117        }
118
119        let mut ordered_output: Vec<_> = accumulate.iter().collect();
120        ordered_output.sort_unstable_by(|(_, a), (_, b)| a.cmp(b));
121        debug!("Sorted results {:?}", ordered_output);
122
123        ordered_output.pop().map(|x| *x.0)
124    }
125
126    async fn first(&self) -> Option<IpAddr> {
127        for voter in &self.voters {
128            let result = voter.get_ip(self.family).await;
129            debug!("Results {:?}", result);
130            if result.is_ok() {
131                return result.ok();
132            }
133        }
134        debug!("Tried all sources");
135        None
136    }
137}
138
#[cfg(test)]
mod tests {
    use super::*;

    use crate::sources::MockSource;
    use mockall::predicate::eq;
    use std::net::Ipv4Addr;
    use tokio_test::block_on;

    const IP0: IpAddr = IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0));

    /// Runs `get_consensus` to completion on the test executor.
    fn resolve(consensus: &Consensus) -> Option<IpAddr> {
        block_on(consensus.get_consensus())
    }

    /// Source that must be queried exactly once and answers with `ip`.
    fn make_success(ip: IpAddr) -> Box<dyn sources::Source> {
        let mut source = MockSource::new();
        source
            .expect_get_ip()
            .with(eq(Family::Any))
            .times(1)
            .returning(move |_| Box::pin(futures::future::ready(Ok(ip))));
        Box::new(source)
    }

    /// Source that must be queried exactly once and answers with an invalid-address error.
    fn make_fail() -> Box<dyn sources::Source> {
        let mut source = MockSource::new();
        source
            .expect_get_ip()
            .with(eq(Family::Any))
            .times(1)
            .returning(|_| {
                // Parsing a malformed address is the simplest way to obtain an `AddrParseError`.
                let parse_error = "x.0.0.0".parse::<IpAddr>().unwrap_err();
                Box::pin(futures::future::ready(Err(
                    sources::Error::InvalidAddress(parse_error),
                )))
            });
        Box::new(source)
    }

    /// Source that must never be queried.
    fn make_untouched() -> Box<dyn sources::Source> {
        let mut source = MockSource::new();
        source.expect_get_ip().with(eq(Family::Any)).times(0);
        Box::new(source)
    }

    #[test]
    fn test_success() {
        let sources: Sources = vec![make_success(IP0)];
        let consensus = ConsensusBuilder::new().add_sources(sources).build();

        assert_eq!(Some(IP0), resolve(&consensus));
    }

    #[test]
    fn test_all_success_multiple_same() {
        let voters = vec![make_success(IP0), make_success(IP0)];
        let consensus = ConsensusBuilder::new()
            .add_sources(voters)
            .policy(Policy::All)
            .build();

        assert_eq!(Some(IP0), resolve(&consensus));
    }

    #[test]
    fn test_all_success_multiple_same_diff() {
        let ip2 = IpAddr::V4(Ipv4Addr::new(0, 0, 0, 1));
        let voters = vec![make_success(IP0), make_success(IP0), make_success(ip2)];
        let consensus = ConsensusBuilder::new()
            .policy(Policy::All)
            .add_sources(voters)
            .build();

        assert_eq!(Some(IP0), resolve(&consensus));
    }

    #[test]
    fn test_all_success_multiple_with_fails() {
        let voters = vec![make_success(IP0), make_fail()];
        let consensus = ConsensusBuilder::new()
            .add_sources(voters)
            .policy(Policy::All)
            .build();

        assert_eq!(Some(IP0), resolve(&consensus));
    }

    #[test]
    fn test_only_failures() {
        for policy in [Policy::All, Policy::First] {
            let consensus = ConsensusBuilder::new()
                .add_sources(vec![make_fail()])
                .policy(policy)
                .build();

            assert_eq!(None, resolve(&consensus));
        }
    }

    #[test]
    fn test_add_sources_multiple_times() {
        let consensus = ConsensusBuilder::new()
            .add_sources(vec![make_fail()])
            .add_sources(vec![make_success(IP0)])
            .policy(Policy::All)
            .build();

        assert_eq!(Some(IP0), resolve(&consensus));
    }

    #[test]
    fn test_first_success_multiple_with_fails() {
        let voters = vec![make_fail(), make_success(IP0)];
        let consensus = ConsensusBuilder::new()
            .add_sources(voters)
            .policy(Policy::First)
            .build();

        assert_eq!(Some(IP0), resolve(&consensus));
    }

    #[test]
    fn test_first_success_with_first_success() {
        let voters = vec![make_success(IP0), make_untouched()];
        let consensus = ConsensusBuilder::new()
            .add_sources(voters)
            .policy(Policy::First)
            .build();

        assert_eq!(Some(IP0), resolve(&consensus));
    }
}