1use crate::sources;
2
3use log::{debug, error};
4use std::collections::HashMap;
5use std::net::IpAddr;
6use std::option::Option;
7use std::vec::Vec;
8
9use crate::sources::Family;
10
11pub type Sources = Vec<Box<dyn sources::Source>>;
13
14use std::default::Default;
15
/// Strategy a [`Consensus`] uses to combine the answers of its sources.
#[derive(Debug, Copy, Clone, Default, PartialEq, Eq)]
pub enum Policy {
    /// Query every source concurrently and return the address reported by the
    /// most sources. This is the default policy.
    #[default]
    All,
    /// Query sources one at a time, in the order they were added, and return
    /// the first successful answer.
    First,
}
27
28pub struct Consensus {
31 voters: Sources,
32 policy: Policy,
33 family: Family,
34}
35
36pub struct ConsensusBuilder {
38 voters: Sources,
39 policy: Policy,
40 family: Family,
41}
42
43impl Default for ConsensusBuilder {
44 fn default() -> Self {
45 Self::new()
46 }
47}
48
49impl ConsensusBuilder {
50 pub fn new() -> ConsensusBuilder {
51 ConsensusBuilder {
52 voters: vec![],
53 policy: Policy::default(),
54 family: Family::default(),
55 }
56 }
57
58 pub fn add_sources<T>(mut self, source: T) -> ConsensusBuilder
64 where
65 T: IntoIterator<Item = Box<dyn sources::Source>>,
66 {
67 self.voters.extend(source);
68 self
69 }
70
71 pub fn policy(mut self, policy: Policy) -> ConsensusBuilder {
72 self.policy = policy;
73 self
74 }
75
76 pub fn family(mut self, family: Family) -> ConsensusBuilder {
77 self.family = family;
78 self
79 }
80
81 pub fn build(self) -> Consensus {
83 Consensus {
84 voters: self.voters,
85 policy: self.policy,
86 family: self.family,
87 }
88 }
89}
90
91impl Consensus {
92 pub async fn get_consensus(&self) -> Option<IpAddr> {
94 match self.policy {
95 Policy::All => self.all().await,
96 Policy::First => self.first().await,
97 }
98 }
99
100 async fn all(&self) -> Option<IpAddr> {
101 let results =
102 futures::future::join_all(self.voters.iter().map(|voter| voter.get_ip(self.family)))
103 .await;
104
105 debug!("Results {:?}", results);
106 let mut accumulate = HashMap::new();
107 for (pos, result) in results.into_iter().enumerate() {
108 match result {
109 Ok(result) => {
110 accumulate
111 .entry(result)
112 .and_modify(|c| *c += 1)
113 .or_insert(1);
114 }
115 Err(err) => error!("Source {} failed {:?}", self.voters[pos], err),
116 };
117 }
118
119 let mut ordered_output: Vec<_> = accumulate.iter().collect();
120 ordered_output.sort_unstable_by(|(_, a), (_, b)| a.cmp(b));
121 debug!("Sorted results {:?}", ordered_output);
122
123 ordered_output.pop().map(|x| *x.0)
124 }
125
126 async fn first(&self) -> Option<IpAddr> {
127 for voter in &self.voters {
128 let result = voter.get_ip(self.family).await;
129 debug!("Results {:?}", result);
130 if result.is_ok() {
131 return result.ok();
132 }
133 }
134 debug!("Tried all sources");
135 None
136 }
137}
138
#[cfg(test)]
mod tests {
    use super::*;

    use crate::sources::MockSource;
    use mockall::predicate::eq;
    use std::net::Ipv4Addr;
    use tokio_test::block_on;

    const IP0: IpAddr = IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0));

    /// Drives `consensus` to completion on the current thread.
    fn resolve(consensus: &Consensus) -> Option<IpAddr> {
        block_on(consensus.get_consensus())
    }

    /// Source that must be queried exactly once (with `Family::Any`) and answers `ip`.
    fn make_success(ip: IpAddr) -> Box<dyn sources::Source> {
        let mut source = MockSource::new();
        source
            .expect_get_ip()
            .with(eq(Family::Any))
            .times(1)
            .returning(move |_| Box::pin(futures::future::ready(Ok(ip))));
        Box::new(source)
    }

    /// Source that must be queried exactly once (with `Family::Any`) and fails
    /// with an `InvalidAddress` error.
    fn make_fail() -> Box<dyn sources::Source> {
        let mut source = MockSource::new();
        source
            .expect_get_ip()
            .with(eq(Family::Any))
            .times(1)
            .returning(move |_| {
                let parse_error = "x.0.0.0".parse::<IpAddr>().unwrap_err();
                Box::pin(futures::future::ready(Err(sources::Error::InvalidAddress(
                    parse_error,
                ))))
            });
        Box::new(source)
    }

    /// Source that must never be queried.
    fn make_untouched() -> Box<dyn sources::Source> {
        let mut source = MockSource::new();
        source.expect_get_ip().with(eq(Family::Any)).times(0);
        Box::new(source)
    }

    #[test]
    fn test_success() {
        let sources: Sources = vec![make_success(IP0)];
        let consensus = ConsensusBuilder::new().add_sources(sources).build();
        assert_eq!(Some(IP0), resolve(&consensus));
    }

    #[test]
    fn test_all_success_multiple_same() {
        let consensus = ConsensusBuilder::new()
            .add_sources(vec![make_success(IP0), make_success(IP0)])
            .policy(Policy::All)
            .build();
        assert_eq!(Some(IP0), resolve(&consensus));
    }

    #[test]
    fn test_all_success_multiple_same_diff() {
        let other_ip = "0.0.0.1".parse().expect("valid ip");
        let consensus = ConsensusBuilder::new()
            .policy(Policy::All)
            .add_sources(vec![
                make_success(IP0),
                make_success(IP0),
                make_success(other_ip),
            ])
            .build();
        assert_eq!(Some(IP0), resolve(&consensus));
    }

    #[test]
    fn test_all_success_multiple_with_fails() {
        let consensus = ConsensusBuilder::new()
            .add_sources(vec![make_success(IP0), make_fail()])
            .policy(Policy::All)
            .build();
        assert_eq!(Some(IP0), resolve(&consensus));
    }

    #[test]
    fn test_only_failures() {
        for policy in [Policy::All, Policy::First] {
            let consensus = ConsensusBuilder::new()
                .add_sources(vec![make_fail()])
                .policy(policy)
                .build();
            assert_eq!(None, resolve(&consensus));
        }
    }

    #[test]
    fn test_add_sources_multiple_times() {
        let consensus = ConsensusBuilder::new()
            .add_sources(vec![make_fail()])
            .add_sources(vec![make_success(IP0)])
            .policy(Policy::All)
            .build();
        assert_eq!(Some(IP0), resolve(&consensus));
    }

    #[test]
    fn test_first_success_multiple_with_fails() {
        let consensus = ConsensusBuilder::new()
            .add_sources(vec![make_fail(), make_success(IP0)])
            .policy(Policy::First)
            .build();
        assert_eq!(Some(IP0), resolve(&consensus));
    }

    #[test]
    fn test_first_success_with_first_success() {
        let consensus = ConsensusBuilder::new()
            .add_sources(vec![make_success(IP0), make_untouched()])
            .policy(Policy::First)
            .build();
        assert_eq!(Some(IP0), resolve(&consensus));
    }
}