1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
//! Simplified URL parsing infrastructure, including domain resolution: the
//! built-in `addr`-based resolver when the `embedded-domain-resolver` feature
//! is enabled, or the hook for registering an external resolver when it is
//! disabled.

mod parser;
// mod parser_regex;

/// Process-wide slot holding the externally supplied domain resolver.
///
/// Only compiled when the `embedded-domain-resolver` feature is disabled. The
/// cell starts out empty; it is filled by `set_domain_resolver` and read by
/// `get_host_domain`.
#[cfg(not(feature = "embedded-domain-resolver"))]
static DOMAIN_RESOLVER: once_cell::sync::OnceCell<Box<dyn ResolvesDomain>> = once_cell::sync::OnceCell::new();

/// Sets the library's domain resolver implementation.
///
/// If the `embedded-domain-resolver` feature is disabled and the library is
/// used without this having been set, panics may occur!
///
/// The resolver is stored in a write-once cell, so only the first call takes
/// effect.
///
/// # Errors
///
/// If a resolver has already been set, the stored resolver is left untouched
/// and the `resolver` passed to this call is handed back as `Err(resolver)`.
#[cfg(not(feature = "embedded-domain-resolver"))]
pub fn set_domain_resolver(resolver: Box<dyn ResolvesDomain>) -> Result<(), Box<dyn ResolvesDomain>> {
    // `OnceCell::set` returns the rejected value in `Err` when the cell is already full.
    DOMAIN_RESOLVER.set(resolver)
}

/// Built-in domain resolver backed by the `addr` crate's public suffix list.
///
/// Only compiled when the `embedded-domain-resolver` feature is enabled.
#[cfg(feature = "embedded-domain-resolver")]
struct DefaultResolver;

#[cfg(feature = "embedded-domain-resolver")]
impl ResolvesDomain for DefaultResolver {
    /// Locates the domain inside `host` and returns it as `(start, end)` byte
    /// offsets into `host`.
    ///
    /// * An empty `host` yields `(0, 0)`.
    /// * A `host` that `addr` cannot parse yields `(0, host.len())`, i.e. the
    ///   whole hostname.
    /// * Otherwise the range covers the trailing part of `host` whose length
    ///   equals the parsed name's root, or its suffix when no root is reported.
    fn get_host_domain(&self, host: &str) -> (usize, usize) {
        use addr::parser::DomainName;
        use addr::psl::List;

        let end = host.len();
        if end == 0 {
            return (0, 0);
        }

        // The domain always ends at the end of the hostname, so only the
        // start offset depends on the parse result.
        let start = match List.parse_domain_name(host) {
            Ok(name) => {
                let registrable = name.root().unwrap_or_else(|| name.suffix());
                end - registrable.len()
            }
            // Unparseable hostnames are treated as being their own domain.
            Err(_) => 0,
        };
        (start, end)
    }
}

/// Required trait for any domain resolution implementation used with this
/// crate.
///
/// Implementors must be `Send + Sync`.
pub trait ResolvesDomain: Send + Sync {
    /// Return the start and end indices of the domain (eTLD+1) of the given hostname.
    ///
    /// The indices are byte offsets into `host` and form a half-open range,
    /// so the domain can be obtained with `&host[start..end]`.
    ///
    /// If there isn't a valid domain, `(0, host.len())` should be returned.
    ///
    /// # Examples
    ///
    /// ```
    /// # use adblock::url_parser::ResolvesDomain;
    /// # /// I'd use DefaultResolver here, but I can't use private structs in doctests.
    /// # /// Enjoy this mock implementation instead :(
    /// # struct Resolver;
    /// # impl ResolvesDomain for Resolver {
    /// #     fn get_host_domain(&self, host: &str) -> (usize, usize) {
    /// #         match host {
    /// #             "api.m.example.com" => (6, 17),
    /// #             "a.b.co.uk" => (2, 9),
    /// #             _ => unreachable!()
    /// #         }
    /// #     }
    /// # }
    /// # let resolver = Resolver;
    /// let host = "api.m.example.com";
    /// let (start, end) = resolver.get_host_domain(host);
    /// assert_eq!(&host[start..end], "example.com");
    ///
    /// let host = "a.b.co.uk";
    /// let (start, end) = resolver.get_host_domain(host);
    /// assert_eq!(&host[start..end], "b.co.uk");
    /// ```
    fn get_host_domain(&self, host: &str) -> (usize, usize);
}

/// A URL together with the positions of its schema, hostname and domain.
///
/// The components are kept as byte offsets into `url` rather than as separate
/// strings; the accessor methods slice `url` on demand.
pub struct RequestUrl {
    /// Full URL text that all stored offsets refer to.
    pub url: String,
    /// Byte offset into `url` where the schema ends (exclusive).
    schema_end: usize,
    /// `(start, end)` byte offsets of the hostname within `url`.
    pub hostname_pos: (usize, usize),
    /// `(start, end)` byte offsets of the domain, measured from the start of
    /// the hostname (`hostname_pos.0`) rather than from the start of `url`.
    domain: (usize, usize),
}

impl RequestUrl {
    /// Returns the part of the URL before `schema_end`.
    pub fn schema(&self) -> &str {
        let end = self.schema_end;
        &self.url[..end]
    }

    /// Returns the hostname portion of the URL.
    pub fn hostname(&self) -> &str {
        let (start, end) = self.hostname_pos;
        &self.url[start..end]
    }

    /// Returns the domain portion of the hostname.
    pub fn domain(&self) -> &str {
        // Domain offsets are hostname-relative, so shift them by the
        // hostname's starting position before slicing the full URL.
        let base = self.hostname_pos.0;
        let (rel_start, rel_end) = self.domain;
        &self.url[base + rel_start..base + rel_end]
    }
}

/// Resolves the domain of `host` using whichever resolver this build is
/// configured with, returning its `(start, end)` indices within `host`.
///
/// With the `embedded-domain-resolver` feature enabled the built-in
/// `DefaultResolver` is used; otherwise the resolver registered through
/// `set_domain_resolver` is used.
///
/// # Panics
///
/// Panics when the `embedded-domain-resolver` feature is disabled and no
/// external resolver has been registered.
pub(crate) fn get_host_domain(host: &str) -> (usize, usize) {
    #[cfg(feature = "embedded-domain-resolver")]
    let resolver = DefaultResolver;
    #[cfg(not(feature = "embedded-domain-resolver"))]
    let resolver = DOMAIN_RESOLVER
        .get()
        .expect("An external domain resolver must be set when the `embedded-domain-resolver` feature is disabled.");

    resolver.get_host_domain(host)
}

/// Parses `url` and records where its schema, hostname and domain are.
///
/// Returns `None` when the URL cannot be parsed or when it has no host.
///
/// The returned [`RequestUrl`] stores the parser's own rendering of the URL
/// (`url_str()`), and all offsets refer to that string. According to the
/// original notes for this function, hostnames containing non-ASCII
/// characters are converted to their punycode (plain ASCII) form during
/// parsing, and it is that converted URL which is used for further matching.
/// NOTE(review): the conversion happens inside `parser` and is not visible
/// here — confirm against that module.
pub fn parse_url(url: &str) -> Option<RequestUrl> {
    let parsed = parser::Hostname::parse(url).ok()?;
    // Bail out for URLs without a host; the host text itself is re-sliced
    // below from the recorded offsets.
    parsed.host_str()?;

    let hostname_pos = (parsed.host_start, parsed.host_end);
    let domain = get_host_domain(&parsed.url_str()[hostname_pos.0..hostname_pos.1]);

    Some(RequestUrl {
        url: parsed.url_str().to_owned(),
        schema_end: parsed.scheme_end,
        hostname_pos,
        domain,
    })
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the built-in resolver against a table of
    /// `(hostname, expected domain)` pairs.
    #[test]
    #[cfg(feature = "embedded-domain-resolver")]
    fn test_get_host_domain() {
        let resolver = DefaultResolver;
        let cases = [
            ("www.google.com", "google.com"),
            // A trailing dot is kept as part of the domain.
            ("google.com.", "google.com."),
            ("a.b.co.uk", "b.co.uk"),
            ("foo.bar", "foo.bar"),
        ];

        for &(host, expected) in cases.iter() {
            let (start, end) = resolver.get_host_domain(host);
            assert_eq!(&host[start..end], expected);
        }
    }
}