// chrootable_https/cache.rs

1use lru_cache::LruCache;
2use std::net::IpAddr;
3use std::time::{Duration, Instant};
4
5
/// Default upper bound, in seconds, on how long a cached answer is kept: one day.
///
/// RFC 2181 §8 (<https://tools.ietf.org/html/rfc2181#section-8>) permits record
/// TTLs up to 2^31 - 1 seconds; this cache deliberately uses a much smaller cap.
/// It is the default for both the positive and negative maximum TTL in
/// `TtlConfig::default`.
pub const MAX_TTL: u64 = 24 * 60 * 60;
8
9
/// One cached answer together with the last instant at which it may be served.
#[derive(Debug)]
struct LruValue {
    /// The resolved address, or `None` when the cached answer is negative (an NX).
    ipaddr: Option<IpAddr>,
    /// Last instant at which this entry still counts as fresh.
    valid_until: Instant,
}

impl LruValue {
    /// Returns `true` as long as `now` has not moved past `valid_until`;
    /// the expiry instant itself is still considered fresh.
    fn is_fresh(&self, now: Instant) -> bool {
        self.valid_until >= now
    }
}
22
/// Outcome of a `DnsCache::get` lookup.
///
/// Every variant holds only `Copy` data, so the type is `Copy` and `Eq`,
/// which lets callers pass lookup results around freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Value {
    /// Nothing usable is cached: the name was never inserted, was evicted,
    /// or its entry had expired (expired entries are removed on lookup).
    None,
    /// A fresh negative answer (cached without an address) exists for the name.
    NX,
    /// A fresh positive answer with this address exists for the name.
    Some(IpAddr),
}
29
30pub struct TtlConfig {
31    pub positive_min_ttl: Duration,
32    pub negative_min_ttl: Duration,
33    pub positive_max_ttl: Duration,
34    pub negative_max_ttl: Duration,
35}
36
37impl Default for TtlConfig {
38    fn default() -> TtlConfig {
39        TtlConfig {
40            positive_min_ttl: Duration::from_secs(0),
41            negative_min_ttl: Duration::from_secs(0),
42            positive_max_ttl: Duration::from_secs(MAX_TTL),
43            negative_max_ttl: Duration::from_secs(MAX_TTL),
44        }
45    }
46}
47
48#[derive(Debug)]
49pub struct DnsCache {
50    cache: LruCache<String, LruValue>,
51    positive_min_ttl: Duration,
52    negative_min_ttl: Duration,
53    positive_max_ttl: Duration,
54    negative_max_ttl: Duration,
55}
56
57impl DnsCache {
58    pub fn new(capacity: usize, ttl: TtlConfig) -> DnsCache {
59        let cache = LruCache::new(capacity);
60        DnsCache {
61            cache,
62            positive_min_ttl: ttl.positive_min_ttl,
63            negative_min_ttl: ttl.negative_min_ttl,
64            positive_max_ttl: ttl.positive_max_ttl,
65            negative_max_ttl: ttl.negative_max_ttl,
66        }
67    }
68
69    pub fn insert(&mut self, query: String, ipaddr: Option<IpAddr>, mut ttl: Duration, now: Instant) {
70        if ipaddr.is_some() {
71            if ttl < self.positive_min_ttl {
72                ttl = self.positive_min_ttl;
73            } else if ttl > self.positive_max_ttl {
74                ttl = self.positive_max_ttl;
75            }
76        } else if ttl < self.negative_min_ttl {
77            ttl = self.negative_min_ttl;
78        } else if ttl > self.negative_max_ttl {
79            ttl = self.negative_max_ttl;
80        }
81
82        let valid_until = now + ttl;
83
84        self.cache.insert(query, LruValue {
85            ipaddr,
86            valid_until,
87        });
88    }
89
90    pub fn get(&mut self, query: &str, now: Instant) -> Value {
91        if let Some(ipaddr) = self.cache.get_mut(query) {
92            if !ipaddr.is_fresh(now) {
93                self.cache.remove(query);
94                Value::None
95            } else if let Some(ipaddr) = ipaddr.ipaddr {
96                Value::Some(ipaddr)
97            } else {
98                Value::NX
99            }
100        } else {
101            Value::None
102        }
103    }
104}
105
106impl Default for DnsCache {
107    fn default() -> DnsCache {
108        DnsCache::new(32, TtlConfig::default())
109    }
110}
111
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::IpAddr;
    use std::time::{Duration, Instant};

    /// Parses an address literal used by the tests below.
    fn ip(s: &str) -> IpAddr {
        s.parse().unwrap()
    }

    #[test]
    fn verify_insert() {
        let now = Instant::now();
        let mut cache = DnsCache::default();
        cache.insert("example.com".into(), Some(ip("1.1.1.1")), Duration::from_secs(1), now);
        // A second insert for the same name replaces the earlier answer.
        cache.insert("example.com".into(), Some(ip("8.8.8.8")), Duration::from_secs(1), now);
        assert_eq!(cache.get("example.com", now), Value::Some(ip("8.8.8.8")));
    }

    #[test]
    fn verify_get() {
        let now = Instant::now();
        let mut cache = DnsCache::default();
        let ipaddr = ip("1.1.1.1");
        cache.insert("example.com".into(), Some(ipaddr), Duration::from_secs(1), now);
        assert_eq!(cache.get("example.com", now), Value::Some(ipaddr));
        assert_eq!(cache.get("example.org", now), Value::None);
    }

    #[test]
    fn verify_get_nx() {
        let now = Instant::now();
        let mut cache = DnsCache::default();
        cache.insert("nx.example.com".into(), None, Duration::from_secs(1), now);
        assert_eq!(cache.get("nx.example.com", now), Value::NX);
    }

    #[test]
    fn verify_expire() {
        let now = Instant::now();
        let mut cache = DnsCache::default();
        let ipaddr = ip("1.1.1.1");
        cache.insert("example.com".into(), Some(ipaddr), Duration::from_secs(1), now);
        // Still fresh exactly at the expiry instant...
        assert_eq!(cache.get("example.com", now + Duration::from_secs(1)), Value::Some(ipaddr));
        // ...stale afterwards...
        assert_eq!(cache.get("example.com", now + Duration::from_secs(2)), Value::None);
        // ...and the stale entry was removed, not merely hidden.
        assert_eq!(cache.get("example.com", now), Value::None);
    }

    #[test]
    fn verify_min_ttl() {
        let now = Instant::now();
        let mut cache = DnsCache::new(32, TtlConfig {
            positive_min_ttl: Duration::from_secs(10),
            negative_min_ttl: Duration::from_secs(20),
            ..TtlConfig::default()
        });
        cache.insert("example.com".into(), Some(ip("1.1.1.1")), Duration::from_secs(1), now);
        cache.insert("nx.example.com".into(), None, Duration::from_secs(1), now);
        let later = now + Duration::from_secs(15);
        // Positive TTL was raised to 10s (expired); negative to 20s (still fresh).
        assert_eq!(cache.get("example.com", later), Value::None);
        assert_eq!(cache.get("nx.example.com", later), Value::NX);
    }

    #[test]
    fn verify_max_ttl() {
        let now = Instant::now();
        let mut cache = DnsCache::new(32, TtlConfig {
            positive_max_ttl: Duration::from_secs(60),
            negative_max_ttl: Duration::from_secs(5),
            ..TtlConfig::default()
        });
        cache.insert("example.com".into(), Some(ip("1.1.1.1")), Duration::from_secs(3600), now);
        cache.insert("nx.example.com".into(), None, Duration::from_secs(3600), now);
        let later = now + Duration::from_secs(30);
        // Positive TTL was capped at 60s (still fresh); negative at 5s (expired).
        assert_eq!(cache.get("example.com", later), Value::Some(ip("1.1.1.1")));
        assert_eq!(cache.get("nx.example.com", later), Value::None);
    }

    #[test]
    fn verify_capacity() {
        let now = Instant::now();
        let mut cache = DnsCache::new(1, TtlConfig::default());
        cache.insert("a.example".into(), Some(ip("1.1.1.1")), Duration::from_secs(60), now);
        cache.insert("b.example".into(), Some(ip("8.8.8.8")), Duration::from_secs(60), now);
        // With room for one name, the older entry is evicted.
        assert_eq!(cache.get("a.example", now), Value::None);
        assert_eq!(cache.get("b.example", now), Value::Some(ip("8.8.8.8")));
    }
}