1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
//! Looking up SRV records.

use std::io;
use domain_core::bits::name::{
    Dname, ParsedDname, ParsedDnameError, ToRelativeDname, ToDname
};
use domain_core::iana::Rtype;
use domain_core::rdata::parsed::{A, Aaaa, Srv};
use rand;
use rand::distributions::{Distribution, Uniform};
use tokio::prelude::{Async, Future, Poll, Stream};
use crate::resolver::Resolver;
use super::host::{FoundHosts, FoundHostsSocketIter, LookupHost, lookup_host};


//------------ lookup_srv ----------------------------------------------------

/// Creates a future that looks up SRV records.
///
/// The future will use the resolver given in `resolver` to query the
/// DNS for SRV records associated with domain name `name` and service
/// `service`.
///
/// The value returned upon success can be turned into a stream of
/// `ResolvedSrvItem`s corresponding to the found SRV records, ordered as per
/// the usage rules defined in [RFC 2782]. If no matching SRV record is found,
/// A/AAAA queries on the bare domain name `name` will be attempted, yielding
/// a single element upon success using the port given by `fallback_port`,
/// typically the standard port for the service in question.
///
/// Each item in the stream can be turned into an iterator over socket
/// addresses as accepted by, for instance, `TcpStream::connect`.
///
/// The future resolves to `None` whenever the requested service is
/// “decidedly not available” at the requested domain, that is there is a
/// single SRV record with the root label as its target.
///
/// If `service` and `name` cannot be chained into a single domain name,
/// no query is sent and the future fails with `SrvError::LongName` when it
/// is first polled.
///
/// [RFC 2782]: https://tools.ietf.org/html/rfc2782
pub fn lookup_srv<R, S, N>(
    resolver: R,
    service: S,
    name: N,
    fallback_port: u16
) -> LookupSrv<R, S, N>
where
    R: Resolver,
    S: ToRelativeDname + Clone + Send + 'static,
    N: ToDname + Send + 'static
{
    // The SRV query goes to `service` prepended to `name`.
    let query = match (&service).chain(&name) {
        Ok(full_name) => resolver.query((full_name, Rtype::Srv)),
        Err(_) => {
            return LookupSrv {
                data: None,
                query: Err(Some(SrvError::LongName)),
            }
        }
    };

    // Everything needed later to build the result or the fallback.
    let data = LookupData {
        resolver,
        host: name,
        service,
        fallback_port,
    };
    LookupSrv { data: Some(data), query: Ok(query) }
}


//------------ LookupData ----------------------------------------------------

/// Data kept by `LookupSrv` while the SRV query is in flight.
///
/// It is consumed when the query completes: either to build the result from
/// the SRV answer or to build the fallback result for the bare host.
#[derive(Debug)]
struct LookupData<R, S, N> {
    /// The resolver to run queries on.
    resolver: R,

    /// Bare host to be queried.
    ///
    /// This is kept for fallback if no SRV records are found or the SRV
    /// query fails.
    host: N,

    /// Service part of the queried name.
    ///
    /// It was chained in front of `host` for the SRV query and is cloned
    /// into every item built from the SRV answer.
    service: S,

    /// Fallback port, used for the single fallback item built when no SRV
    /// records are found or the SRV query fails.
    fallback_port: u16,
}


//------------ LookupSrv -----------------------------------------------------

/// The future returned by [`lookup_srv()`].
///
/// It resolves to `None` if the service is decidedly not available at the
/// domain, and to the found SRV records (or a fallback result) otherwise.
///
/// [`lookup_srv()`]: fn.lookup_srv.html
pub struct LookupSrv<R: Resolver, S, N> {
    /// Data needed to build the result; taken out once the query completes.
    data: Option<LookupData<R, S, N>>,

    /// The pending SRV query, or the error to report on the first poll if
    /// the query could not be started. The error is taken out when it is
    /// reported.
    query: Result<R::Query, Option<SrvError>>,
}


impl<R, S, N> Future for LookupSrv<R, S, N>
where
    R: Resolver,
    S: ToRelativeDname + Clone + Send + 'static,
    N: ToDname + Send + 'static
{
    type Item = Option<FoundSrvs<R, S>>;
    type Error = SrvError;

    /// Polls the SRV query and builds the result once it has completed.
    ///
    /// A failed SRV query is not reported as an error. Instead the future
    /// resolves to the fallback result for the bare host and fallback port.
    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
        let query = match self.query {
            Ok(ref mut query) => query,
            // The query never started; hand out the stored reason once.
            Err(ref mut err) => {
                return Err(err.take().expect("polled resolved future"))
            }
        };

        // `None` stands for a failed SRV query.
        let answer = match query.poll() {
            Ok(Async::NotReady) => return Ok(Async::NotReady),
            Ok(Async::Ready(answer)) => Some(answer),
            Err(_) => None,
        };

        let data = self.data.take().expect("polled resolved future");
        let found = match answer {
            Some(answer) => FoundSrvs::new(answer, data)?,
            None => Some(FoundSrvs::new_dummy(data)),
        };
        Ok(Async::Ready(found))
    }
}


//------------ LookupSrvStream -----------------------------------------------

/// Stream over resolved SRV items.
///
/// Items are yielded in the order established by `FoundSrvs`. Items whose
/// target addresses were not part of the SRV answer are resolved on demand
/// via A/AAAA lookups while the stream is polled, so the stream only ever
/// yields `ResolvedSrvItem`s. Lookup failures are not reported as stream
/// errors; the affected item is meant to be skipped.
pub struct LookupSrvStream<R: Resolver, S> {
    /// The resolver to use for A/AAAA requests.
    resolver: R,

    /// A vector of (potentially unresolved) SrvItem elements.
    ///
    /// Note that we take items from this via `pop`, so it needs to be ordered
    /// backwards.
    items: Vec<SrvItem<S>>,

    /// A/AAAA lookup for the last `SrvItem` in `items`, if one is pending.
    lookup: Option<LookupHost<R>>
}

impl<R: Resolver, S> LookupSrvStream<R, S> {
    fn new(found: FoundSrvs<R, S>) -> Self {
        LookupSrvStream {
            resolver: found.resolver,
            items: found.items.into_iter().rev().collect(),
            lookup: None,
        }
    }
}


//--- Stream

impl<R, S> Stream for LookupSrvStream<R, S>
where R: Resolver, S: ToRelativeDname + Clone + Send + 'static {
    type Item = ResolvedSrvItem<S>;
    type Error = SrvError;

    /// Yields the next resolved item, starting A/AAAA lookups as needed.
    ///
    /// An item whose address lookup fails is dropped and the stream moves
    /// on to the next item; lookup failures are never returned as errors.
    fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
        // Loop rather than recurse so that a long run of items that need
        // lookups (or whose lookups fail) cannot grow the stack.
        loop {
            // Drive the pending lookup. It always belongs to the last item
            // in `self.items`.
            if let Some(mut lookup) = self.lookup.take() {
                match lookup.poll() {
                    Ok(Async::NotReady) => {
                        self.lookup = Some(lookup);
                        return Ok(Async::NotReady)
                    }
                    Ok(Async::Ready(found)) => {
                        let item = self.items.pop()
                            .expect("pending lookup without an item");
                        return Ok(Async::Ready(Some(
                            ResolvedSrvItem::from_item_and_hosts(item, found)
                        )))
                    }
                    Err(_) => {
                        // Skip this item. Leaving it in place would restart
                        // a lookup for the same host over and over.
                        self.items.pop();
                    }
                }
            }

            // No lookup pending: check the next item. If it still needs its
            // addresses, start a lookup for it.
            let pending = match self.items.last() {
                None => return Ok(Async::Ready(None)), // we are done.
                Some(item) => match item.state {
                    SrvItemState::Unresolved(ref host) => {
                        Some(lookup_host(&self.resolver, host))
                    }
                    SrvItemState::Resolved(_) => None,
                }
            };

            match pending {
                // Loop around to poll the new lookup right away.
                Some(lookup) => self.lookup = Some(lookup),
                None => {
                    let item = self.items.pop().expect("item checked above");
                    return Ok(Async::Ready(Some(
                        ResolvedSrvItem::from_item(item)
                            .expect("item checked to be resolved")
                    )))
                }
            }
        }
    }
}


//------------ FoundSrvs -----------------------------------------------------

/// The SRV records found by `lookup_srv()`.
///
/// Use `into_stream` to get a stream of resolved items.
#[derive(Clone, Debug)]
pub struct FoundSrvs<R, S> {
    /// Resolver handed on to the stream for A/AAAA lookups of targets that
    /// are still unresolved.
    resolver: R,
    /// The items, ordered by priority and weight via `reorder_items`.
    items: Vec<SrvItem<S>>,
}

impl<R, S> FoundSrvs<R, S> {
    /// Turns the found records into a stream of resolved items.
    ///
    /// Targets whose addresses are not known yet are looked up while the
    /// stream is being polled.
    pub fn into_stream(self) -> LookupSrvStream<R, S>
    where
        R: Resolver,
    {
        LookupSrvStream::new(self)
    }

    /// Moves all results from `other` into `self`, leaving `other` empty.
    ///
    /// The combined items are reordered as if they came from a single query.
    pub fn merge(&mut self, other: &mut Self) {
        self.items.extend(other.items.drain(..));
        Self::reorder_items(&mut self.items);
    }
}

impl<R: Resolver, S: Clone> FoundSrvs<R, S> {
    fn new<N: ToDname>(
        answer: R::Answer,
        data: LookupData<R, S, N>
    ) -> Result<Option<Self>, SrvError> {
        let name = answer.as_ref().canonical_name().unwrap();
        let mut rrs = Vec::new();
        Self::process_records(&mut rrs, &answer, &name)?;

        if rrs.len() == 0 {
            return Ok(Some(Self::new_dummy(data)))
        }
        if rrs.len() == 1 && rrs[0].target().is_root() {
            // Exactly one record with target "." indicates no service.
            return Ok(None)
        }

        // Build results including potentially resolved IP addresses
        let mut items = Vec::with_capacity(rrs.len());
        Self::items_from_rrs(&rrs, &answer, &mut items, &data)?;
        Self::reorder_items(&mut items);

        Ok(Some(FoundSrvs {
            resolver: data.resolver,
            items
        }))
    }

    fn new_dummy<N: ToDname>(data: LookupData<R, S, N>) -> Self {
        FoundSrvs {
            resolver: data.resolver,
            items: vec![
                SrvItem {
                    priority: 0,
                    weight: 0,
                    port: data.fallback_port,
                    service: None,
                    state: SrvItemState::Unresolved(data.host.to_name())
                }
            ]
        }
    }

    fn process_records(
        rrs: &mut Vec<Srv>,
        answer: &R::Answer,
        name: &ParsedDname
    ) -> Result<(), SrvError> {
        for record in answer.as_ref().answer()?.limit_to::<Srv>() {
            if let Ok(record) = record {
                if record.owner() == name {
                    rrs.push(record.data().clone())
                }
            }
        }
        Ok(())
    }

    fn items_from_rrs<N>(
        rrs: &[Srv],
        answer: &R::Answer,
        result: &mut Vec<SrvItem<S>>,
        data: &LookupData<R, S, N>,
    ) -> Result<(), SrvError> {
        for rr in rrs {
            let mut addrs = Vec::new();
            let name = rr.target().to_name();
            for record in answer.as_ref().additional()?.limit_to::<A>() {
                if let Ok(record) = record {
                    if record.owner() == &name {
                        addrs.push(record.data().addr().into())
                    }
                }
            }
            for record in answer.as_ref().additional()?.limit_to::<Aaaa>() {
                if let Ok(record) = record {
                    if record.owner() == &name {
                        addrs.push(record.data().addr().into())
                    }
                }
            }
            let state = if addrs.is_empty() {
                SrvItemState::Unresolved(name)
            }
            else {
                SrvItemState::Resolved(FoundHosts::new(name, addrs))
            };
            result.push(SrvItem  {
                priority: rr.priority(),
                weight: rr.weight(),
                state: state,
                port: rr.port(),
                service: Some(data.service.clone())
            })
        }
        Ok(())
    }
}

impl<R, S> FoundSrvs<R, S> {
    /// Orders `items` following the target selection rules of RFC 2782.
    ///
    /// Items are sorted by ascending priority. Within each priority level
    /// the order is randomized so that items with a higher weight are more
    /// likely to come first (see `reorder_by_weight`).
    fn reorder_items(items: &mut [SrvItem<S>]) {
        // First, reorder by priority and weight. This groups items by
        // priority and puts weight 0 records at the beginning of each group,
        // as the weighted selection requires.
        items.sort_by_key(|k| (k.priority, k.weight));

        // Find each priority group and reorder it using reorder_by_weight.
        let mut current_prio = 0;
        let mut weight_sum = 0;
        let mut first_index = 0;
        for i in 0..items.len() {
            if current_prio != items[i].priority {
                current_prio = items[i].priority;
                Self::reorder_by_weight(&mut items[first_index..i], weight_sum);
                weight_sum = 0;
                first_index = i;
            }
            weight_sum += u32::from(items[i].weight);
        }
        Self::reorder_by_weight(&mut items[first_index..], weight_sum);
    }

    /// Reorders the items of a single priority level based on their weight.
    ///
    /// For each position in turn, this follows RFC 2782:
    ///
    /// 1. Pick a uniform random number between 0 and the weight sum of the
    ///    not-yet-ordered items, inclusive.
    /// 2. Find the first not-yet-ordered item whose running weight sum is
    ///    greater than or equal to that number.
    /// 3. Move that item into the current position.
    ///
    /// `weight_sum` must be the sum of the weights of all `items`.
    fn reorder_by_weight(items: &mut [SrvItem<S>], weight_sum: u32) {
        let mut rng = rand::thread_rng();
        let mut weight_sum = weight_sum;
        for i in 0..items.len() {
            let pick = Uniform::new(0, weight_sum + 1).sample(&mut rng);

            // Only items from `i` on are still unordered. Items already
            // placed must not be selected again; doing so also made
            // `weight_sum` underflow.
            let mut running: u32 = 0;
            let selected = items[i..].iter().position(|item| {
                running += u32::from(item.weight);
                running >= pick
            });

            if let Some(offset) = selected {
                let j = i + offset;
                weight_sum -= u32::from(items[j].weight);
                // Move the selected item to position `i` and shift the
                // skipped items right by one. Unlike a swap, this keeps the
                // remaining items in order, so weight 0 items stay at the
                // front of the unordered part.
                items[i..=j].rotate_right(1);
            }
        }
    }
}


//------------ SrvItem -------------------------------------------------------

/// A single SRV target, possibly still lacking its addresses.
#[derive(Clone, Debug)]
pub struct SrvItem<S> {
    /// Priority from the SRV record (0 for the fallback item).
    priority: u16,
    /// Weight from the SRV record (0 for the fallback item).
    weight: u16,
    /// Port from the SRV record, or the fallback port.
    port: u16,
    /// Service part of the queried name; `None` for the fallback item.
    service: Option<S>,
    /// Target name, together with its addresses once they are known.
    state: SrvItemState
}

/// Resolution state of the target of an `SrvItem`.
#[derive(Clone, Debug)]
pub enum SrvItemState {
    /// Only the target name is known; an A/AAAA lookup is still needed.
    Unresolved(Dname),
    /// The host lookup result for the target is known.
    Resolved(FoundHosts)
}

impl<S> SrvItem<S> {

    /// Returns a reference to the service + proto part of the domain name.
    ///
    /// Useful when mixing results from different SRV queries.
    pub fn txt_service(&self) -> Option<&S> {
        self.service.as_ref()
    }

    /// Returns a reference to the name of the target.
    pub fn target(&self) -> &Dname {
        match self.state {
            SrvItemState::Unresolved(ref target) => target,
            SrvItemState::Resolved(ref found_hosts) => found_hosts.canonical_name()
        }
    }
}


//------------ ResolvedSrvItem -----------------------------------------------

/// An SRV target together with the result of its host lookup.
///
/// This is the item type yielded by `LookupSrvStream`.
#[derive(Clone, Debug)]
pub struct ResolvedSrvItem<S> {
    /// Priority copied from the originating `SrvItem`.
    priority: u16,
    /// Weight copied from the originating `SrvItem`.
    weight: u16,
    /// Port to combine with each host address.
    port: u16,
    /// Service part of the queried name; `None` for the fallback item.
    service: Option<S>,
    /// Host lookup result for the target.
    hosts: FoundHosts,
}

impl<S> ResolvedSrvItem<S> {
    /// Returns an iterator over socket addresses matching an SRV record.
    ///
    /// SrvItem does not implement the `ToSocketAddrs` trait as the result
    /// of `to_socket_addrs()` does not have a static lifetime.
    pub fn to_socket_addrs(&self) -> FoundHostsSocketIter {
        self.hosts.port_iter(self.port)
    }

    fn from_item(item: SrvItem<S>) -> Option<Self> {
        if let SrvItemState::Resolved(hosts) = item.state {
            Some(ResolvedSrvItem {
            priority: item.priority,
            weight: item.weight,
            port: item.port,
            service: item.service,
            hosts: hosts
            })
        }
        else {
            None
        }
    }

    fn from_item_and_hosts(item: SrvItem<S>, hosts: FoundHosts) -> Self {
        ResolvedSrvItem {
            priority: item.priority,
            weight: item.weight,
            port: item.port,
            service: item.service,
            hosts: hosts
        }
    }
}


//------------ SrvError ------------------------------------------------------

/// An error that happened while looking up SRV records.
#[derive(Debug)]
pub enum SrvError {
    /// The service and host name could not be chained into a single
    /// domain name for the SRV query.
    LongName,

    /// The answer to the SRV query could not be used.
    MalformedAnswer,

    /// An I/O error occurred.
    Query(io::Error),
}

impl std::fmt::Display for SrvError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match *self {
            SrvError::LongName => f.write_str("domain name too long"),
            SrvError::MalformedAnswer => f.write_str("malformed answer"),
            SrvError::Query(ref err) => write!(f, "query failed: {}", err),
        }
    }
}

impl std::error::Error for SrvError {
    /// Exposes the wrapped I/O error of `SrvError::Query` as the cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match *self {
            SrvError::Query(ref err) => Some(err),
            SrvError::LongName | SrvError::MalformedAnswer => None,
        }
    }
}

impl From<io::Error> for SrvError {
    /// Wraps an I/O error as `SrvError::Query`.
    fn from(source: io::Error) -> Self {
        SrvError::Query(source)
    }
}

impl From<ParsedDnameError> for SrvError {
    /// Maps any domain name parsing failure to `SrvError::MalformedAnswer`.
    fn from(_err: ParsedDnameError) -> Self {
        SrvError::MalformedAnswer
    }
}