https_dns/
cache.rs

1extern crate lru;
2
3use lru::LruCache;
4use std::{
5    sync::{Arc, Mutex},
6    time::{Duration, Instant},
7};
8use trust_dns_proto::op::{message::Message, Query};
9
10#[derive(Debug, Hash, PartialEq, Eq)]
11struct Key {
12    query: Query,
13}
14
15#[derive(Debug)]
16struct Value {
17    message: Message,
18    instant: Instant,
19    ttl: Duration,
20}
21
22#[derive(Clone, Debug)]
23pub struct Cache {
24    lru_cache: Arc<Mutex<LruCache<Key, Value>>>,
25}
26
27impl Cache {
28    pub fn new() -> Self {
29        Cache {
30            lru_cache: Arc::new(Mutex::new(LruCache::new(1024))),
31        }
32    }
33
34    pub fn put(&mut self, message: Message) {
35        if message.queries().is_empty() {
36            return;
37        }
38
39        let query = message.queries()[0].clone();
40        let key = Key { query };
41
42        if let Some(min_record) = message
43            .answers()
44            .iter()
45            .min_by(|record_1, record_2| record_1.ttl().cmp(&record_2.ttl()))
46        {
47            let value = Value {
48                ttl: Duration::from_secs(min_record.ttl().into()),
49                instant: Instant::now(),
50                message,
51            };
52
53            let mut lru_cache = self.lru_cache.lock().unwrap();
54            lru_cache.put(key, value);
55        };
56    }
57
58    pub fn get(&mut self, message: &Message) -> Option<Message> {
59        let mut lru_cache = self.lru_cache.lock().unwrap();
60        if lru_cache.len() == 0 || message.queries().is_empty() {
61            return None;
62        }
63
64        let message_id = message.id();
65        let query = message.queries()[0].clone();
66        let cache_key = Key { query };
67
68        let cache_value = match lru_cache.get(&cache_key) {
69            Some(cache_value) => cache_value,
70            None => {
71                return None;
72            }
73        };
74
75        let instant = cache_value.instant;
76        let ttl = cache_value.ttl;
77        let mut message = cache_value.message.clone();
78
79        if instant.elapsed() < ttl {
80            message.set_id(message_id);
81            Some(message)
82        } else {
83            lru_cache.pop(&cache_key);
84            None
85        }
86    }
87}
88
89impl Default for Cache {
90    fn default() -> Self {
91        Self::new()
92    }
93}
94
#[cfg(test)]
mod tests {
    use super::Cache;
    use std::net::Ipv4Addr;
    use trust_dns_proto::{
        op::{message::Message, Query},
        rr::{Name, RData, Record, RecordType},
    };

    /// Builds an `example.com` query and a response answering it with one A record of `ttl`.
    fn query_and_response(ttl: u32) -> (Query, Message) {
        let name: Name = "example.com".parse().unwrap();
        let mut query = Query::new();
        query.set_name(name.clone());

        let mut answer = Record::with(name, RecordType::A, ttl);
        answer.set_data(Some(RData::A(Ipv4Addr::new(1, 1, 1, 1))));

        let mut response = Message::new();
        response.add_query(query.clone());
        response.add_answer(answer);
        (query, response)
    }

    /// Builds a request message carrying `query` under the given message ID.
    fn request(query: Query, id: u16) -> Message {
        let mut request = Message::new();
        request.set_id(id);
        request.add_query(query);
        request
    }

    #[test]
    fn test_cache_hit() {
        let mut cache = Cache::new();
        let (query, response) = query_and_response(1000);
        cache.put(response);

        let cached = cache
            .get(&request(query.clone(), 42))
            .expect("a fresh response should be served from the cache");
        // The cached response must be returned under the requester's ID.
        assert_eq!(cached.id(), 42);
        assert_eq!(cached.queries().first(), Some(&query));
        assert_eq!(cached.answers().len(), 1);
    }

    #[test]
    fn test_cache_expire() {
        let mut cache = Cache::new();
        let (query, response) = query_and_response(0);
        cache.put(response);

        assert!(cache.get(&request(query, 1)).is_none());
    }

    #[test]
    fn test_zero_ttl_response_supersedes_older_entry() {
        let mut cache = Cache::new();
        let (query, fresh) = query_and_response(1000);
        let (_, expired) = query_and_response(0);
        cache.put(fresh);
        cache.put(expired);

        assert!(cache.get(&request(query, 1)).is_none());
    }

    #[test]
    fn test_request_without_query_misses() {
        let mut cache = Cache::new();
        let (_, response) = query_and_response(1000);
        cache.put(response);

        assert!(cache.get(&Message::new()).is_none());
    }

    #[test]
    fn test_response_without_answers_is_not_cached() {
        let mut cache = Cache::new();
        let (query, _) = query_and_response(1000);
        let mut response = Message::new();
        response.add_query(query.clone());
        cache.put(response);

        assert!(cache.get(&request(query, 1)).is_none());
    }
}