1use std::collections::HashSet;
4use std::sync::Arc;
5use std::iter::FromIterator;
6use std::net::{IpAddr, SocketAddr, AddrParseError};
7use std::slice::Iter as VecIter;
8
9use rand::{thread_rng, Rng};
10use rand::distributions::{IndependentSample, Range};
11
/// Relative weight of an address inside one priority level.
///
/// `WeightedSet::pick_one` returns an address with probability proportional
/// to its weight; when every weight in a level is zero it falls back to a
/// uniform choice.
pub type Weight = u64;
16
17#[derive(Clone, Debug)]
26pub struct Address(Arc<Internal>);
27
28
29#[derive(Debug)]
30struct Internal {
31 addresses: Vec<Vec<(Weight, SocketAddr)>>,
32}
33
34#[derive(Debug)]
46pub struct Builder {
47 addresses: Vec<Vec<(Weight, SocketAddr)>>,
48}
49
50#[derive(Debug)]
52pub struct WeightedSet<'a> {
53 addresses: &'a [(Weight, SocketAddr)],
54}
55
56#[derive(Debug)]
59pub struct PriorityIter<'a>(VecIter<'a, Vec<(Weight, SocketAddr)>>);
60
61#[derive(Debug)]
65pub struct OwnedAddressIter(Arc<Internal>, usize, usize);
66
67#[derive(Debug)]
72pub struct AddressIter<'a>(VecIter<'a, (Weight, SocketAddr)>);
73
74impl<'a> Iterator for PriorityIter<'a> {
75 type Item = WeightedSet<'a>;
76 fn next(&mut self) -> Option<WeightedSet<'a>> {
77 self.0.next().map(|vec| WeightedSet {
78 addresses: &vec,
79 })
80 }
81}
82
83impl<'a> Iterator for OwnedAddressIter {
84 type Item = SocketAddr;
85 fn next(&mut self) -> Option<SocketAddr> {
86 let n = self.2;
87 self.2 += 1;
88 self.0.addresses.get(self.1)
89 .and_then(|vec| vec.get(n))
90 .map(|&(_, addr)| addr)
91 }
92}
93
94impl<'a> Iterator for AddressIter<'a> {
95 type Item = SocketAddr;
96 fn next(&mut self) -> Option<SocketAddr> {
97 self.0.next().map(|&(_weight, addr)| addr)
98 }
99}
100
101impl From<(IpAddr, u16)> for Address {
102 fn from((ip, port): (IpAddr, u16)) -> Address {
103 Address(Arc::new(Internal {
104 addresses: vec![vec![(0, SocketAddr::new(ip, port))]],
105 }))
106 }
107}
108
109impl From<SocketAddr> for Address {
110 fn from(addr: SocketAddr) -> Address {
111 Address(Arc::new(Internal {
112 addresses: vec![vec![(0, addr)]],
113 }))
114 }
115}
116
117impl<'a> From<&'a [SocketAddr]> for Address {
118 fn from(addr: &[SocketAddr]) -> Address {
119 Address(Arc::new(Internal {
120 addresses: vec![
121 addr.iter().map(|&a| (0, a)).collect()
122 ],
123 }))
124 }
125}
126
127impl FromIterator<SocketAddr> for Address {
128 fn from_iter<T>(iter: T) -> Self
129 where T: IntoIterator<Item=SocketAddr>
130 {
131 Address(Arc::new(Internal {
132 addresses: vec![iter.into_iter().map(|a| (0, a)).collect()],
133 }))
134 }
135}
136
137impl AsRef<Address> for Address {
138 fn as_ref(&self) -> &Address {
139 self
140 }
141}
142
143impl Builder {
144 pub fn new() -> Builder {
146 return Builder {
147 addresses: vec![Vec::new()],
148 }
149 }
150
151 pub fn add_addresses<'x, I>(&mut self, items: I) -> &mut Builder
157 where I: IntoIterator<Item=&'x (Weight, SocketAddr)>
158 {
159 self.addresses.push(items.into_iter().cloned().collect());
160 self
161 }
162 pub fn into_address(self) -> Address {
166 Address(Arc::new(Internal {
167 addresses: self.addresses.into_iter()
168 .filter(|vec| vec.len() > 0)
169 .collect(),
170 }))
171 }
172}
173
174
175impl Address {
176 pub fn pick_one(&self) -> Option<SocketAddr> {
186 self.at(0).pick_one()
187 }
188
189 pub fn addresses_at(&self, priority: usize) -> OwnedAddressIter {
195 OwnedAddressIter(self.0.clone(), priority, 0)
196 }
197
198 pub fn at(&self, priority: usize) -> WeightedSet {
207 self.0.addresses.get(priority)
208 .map(|vec| WeightedSet { addresses: vec })
209 .unwrap_or(WeightedSet{ addresses: &[] })
210 }
211
212 pub fn iter(&self) -> PriorityIter {
214 PriorityIter(self.0.addresses.iter())
215 }
216
217 pub fn parse_list<I>(iter: I)
222 -> Result<Address, AddrParseError>
223 where I: IntoIterator,
224 I::Item: AsRef<str>
225 {
226 Ok(Address(Arc::new(Internal {
227 addresses: vec![
228 iter.into_iter()
229 .map(|x| x.as_ref().parse().map(|sa| (0, sa)))
230 .collect::<Result<Vec<_>, _>>()?
231 ],
232 })))
233 }
234}
235
236impl PartialEq for Address {
237 fn eq(&self, other: &Address) -> bool {
238 self.0.addresses.len() == other.0.addresses.len() &&
239 self.iter().zip(other.iter()).all(|(s, o)| s == o)
240 }
241}
242
243impl Eq for Address {}
244
245
246impl<'a> WeightedSet<'a> {
247 pub fn pick_one(&self) -> Option<SocketAddr> {
254 if self.addresses.len() == 0 {
255 return None
256 }
257 let total_weight = self.addresses.iter().map(|&(w, _)| w).sum();
258 if total_weight == 0 {
259 return Some(thread_rng().choose(self.addresses).unwrap().1)
261 }
262 let range = Range::new(0, total_weight);
263 let mut n = range.ind_sample(&mut thread_rng());
264 for &(w, addr) in self.addresses {
265 if n < w {
266 return Some(addr);
267 }
268 n -= w;
269 }
270 unreachable!();
271 }
272 pub fn addresses(&self) -> AddressIter {
278 AddressIter(self.addresses.iter())
279 }
280
281 pub fn compare_addresses(&self, other: &WeightedSet)
286 -> (Vec<SocketAddr>, Vec<SocketAddr>)
287 {
288 let mut old = Vec::new();
290 let mut new = Vec::new();
291 for &(_, a) in self.addresses {
292 if !other.addresses.iter().find(|&&(_, a1)| a == a1).is_some() {
293 old.push(a);
294 }
295 }
296 for &(_, a) in other.addresses {
297 if !self.addresses.iter().find(|&&(_, a1)| a == a1).is_some() {
298 new.push(a);
299 }
300 }
301 return (old, new);
302 }
303
304 pub fn len(&self) -> usize {
306 self.addresses.len()
307 }
308}
309
310impl<'a> PartialEq for WeightedSet<'a> {
311 fn eq(&self, other: &WeightedSet) -> bool {
312 if self.addresses.len() != other.addresses.len() {
317 return false;
318 }
319 for &pair in self.addresses {
320 if !other.addresses.iter().find(|&&pair1| pair == pair1).is_some()
321 {
322 return false;
323 }
324 }
325 for &pair in other.addresses {
326 if !self.addresses.iter().find(|&&pair1| pair == pair1).is_some()
327 {
328 return false;
329 }
330 }
331 return true;
332 }
333}
334
335pub fn union<I>(iter: I) -> Address
341 where I: IntoIterator,
342 I::Item: AsRef<Address>,
343{
344 let mut set = HashSet::new();
345 for child in iter {
346 set.extend(child.as_ref().at(0).addresses());
347 }
348 return set.into_iter().collect();
349}
350
#[cfg(test)]
mod test {

    use super::{Address, union};
    use std::collections::HashSet;
    use std::net::{SocketAddr, IpAddr};
    use std::str::FromStr;

    use futures::Future;
    use futures::stream::{Stream, iter_ok};

    /// Parses every fixture string as a `SocketAddr`; fixtures are always
    /// well-formed, so a parse failure is a bug in the test itself.
    fn sockets(items: &[&str]) -> Vec<SocketAddr> {
        items.iter()
            .map(|text| SocketAddr::from_str(text).unwrap())
            .collect()
    }

    /// Builds a single-level `Address` from fixture strings via
    /// `FromIterator<SocketAddr>`.
    fn address_of(items: &[&str]) -> Address {
        sockets(items).into_iter().collect()
    }

    #[test]
    fn test_iter() {
        let ab = address_of(&["127.0.0.1:1234", "10.0.0.1:3456"]);
        let mut levels = Vec::new();
        for level in ab.iter() {
            levels.push(level.addresses().collect::<Vec<_>>());
        }
        assert_eq!(levels,
                   vec![sockets(&["127.0.0.1:1234", "10.0.0.1:3456"])]);
    }

    #[test]
    fn from_socket_addr() {
        let socket = SocketAddr::from_str("127.0.0.1:1234").unwrap();
        Address::from(socket);
    }

    #[test]
    fn from_ip() {
        let ip = IpAddr::from_str("127.0.0.1").unwrap();
        Address::from((ip, 1234));
    }

    #[test]
    fn from_slice() {
        let one = [SocketAddr::from_str("127.0.0.1:1234").unwrap()];
        Address::from(&one[..]);
    }

    #[test]
    fn test_eq() {
        let a1 = address_of(&["127.0.0.1:1234", "10.0.0.1:3456"]);
        let a2 = address_of(&["127.0.0.1:1234", "10.0.0.1:3456"]);
        assert_eq!(a1, a2);
    }

    #[test]
    fn test_eq_reverse() {
        let a1 = address_of(&["127.0.0.1:1234", "10.0.0.1:3456"]);
        let a2 = address_of(&["10.0.0.1:3456", "127.0.0.1:1234"]);
        assert_eq!(a1, a2);
    }

    #[test]
    fn test_ne() {
        let a1 = address_of(&["127.0.0.1:1234", "10.0.0.1:5555"]);
        let a2 = address_of(&["10.0.0.1:3456", "127.0.0.1:1234"]);
        assert_ne!(a1, a2);
    }

    #[test]
    fn test_diff() {
        let a1 = address_of(&["127.0.0.1:1234", "10.0.0.1:3456"]);
        let a2 = address_of(&["127.0.0.2:1234", "10.0.0.1:3456"]);

        let l1 = a1.iter().next().unwrap();
        let l2 = a2.iter().next().unwrap();

        let expected = (sockets(&["127.0.0.1:1234"]),
                        sockets(&["127.0.0.2:1234"]));
        assert_eq!(l1.compare_addresses(&l2), expected);
    }

    #[test]
    fn test_union() {
        let a1 = Address::parse_list(&["127.0.0.1:1234", "10.0.0.1:3456"])
            .unwrap();
        let a2 = Address::parse_list(&["127.0.0.2:1234", "10.0.0.1:3456"])
            .unwrap();

        let merged = union([a1, a2].iter());
        let expected: HashSet<SocketAddr> = sockets(&[
            "127.0.0.1:1234",
            "127.0.0.2:1234",
            "10.0.0.1:3456",
        ]).into_iter().collect();
        assert_eq!(merged.at(0).addresses().collect::<HashSet<_>>(), expected);
        assert_eq!(merged.at(0).addresses().count(), 3);
    }

    /// Compile-time check that the stream yields items iterable as
    /// `SocketAddr`s; returns the stream unchanged.
    fn check_type<S: Stream>(stream: S) -> S
        where S::Item: IntoIterator<Item=SocketAddr>
    {
        stream
    }

    #[test]
    fn test_addresses_at_lifetime() {
        let parsed = Address::parse_list(&["127.0.0.1:8080", "172.0.0.1:8010"])
            .unwrap();
        let stream = check_type(
            iter_ok::<_, ()>(vec![parsed]).map(|a| a.addresses_at(0)));
        let counts = stream
            .map(|a| a.into_iter().count())
            .collect()
            .wait()
            .unwrap();
        let total: usize = counts.into_iter().sum();
        assert_eq!(2usize, total);
    }
}