1extern crate lru;
2
3use lru::LruCache;
4use std::{
5 sync::{Arc, Mutex},
6 time::{Duration, Instant},
7};
8use trust_dns_proto::op::{message::Message, Query};
9
10#[derive(Debug, Hash, PartialEq, Eq)]
11struct Key {
12 query: Query,
13}
14
15#[derive(Debug)]
16struct Value {
17 message: Message,
18 instant: Instant,
19 ttl: Duration,
20}
21
22#[derive(Clone, Debug)]
23pub struct Cache {
24 lru_cache: Arc<Mutex<LruCache<Key, Value>>>,
25}
26
27impl Cache {
28 pub fn new() -> Self {
29 Cache {
30 lru_cache: Arc::new(Mutex::new(LruCache::new(1024))),
31 }
32 }
33
34 pub fn put(&mut self, message: Message) {
35 if message.queries().is_empty() {
36 return;
37 }
38
39 let query = message.queries()[0].clone();
40 let key = Key { query };
41
42 if let Some(min_record) = message
43 .answers()
44 .iter()
45 .min_by(|record_1, record_2| record_1.ttl().cmp(&record_2.ttl()))
46 {
47 let value = Value {
48 ttl: Duration::from_secs(min_record.ttl().into()),
49 instant: Instant::now(),
50 message,
51 };
52
53 let mut lru_cache = self.lru_cache.lock().unwrap();
54 lru_cache.put(key, value);
55 };
56 }
57
58 pub fn get(&mut self, message: &Message) -> Option<Message> {
59 let mut lru_cache = self.lru_cache.lock().unwrap();
60 if lru_cache.len() == 0 || message.queries().is_empty() {
61 return None;
62 }
63
64 let message_id = message.id();
65 let query = message.queries()[0].clone();
66 let cache_key = Key { query };
67
68 let cache_value = match lru_cache.get(&cache_key) {
69 Some(cache_value) => cache_value,
70 None => {
71 return None;
72 }
73 };
74
75 let instant = cache_value.instant;
76 let ttl = cache_value.ttl;
77 let mut message = cache_value.message.clone();
78
79 if instant.elapsed() < ttl {
80 message.set_id(message_id);
81 Some(message)
82 } else {
83 lru_cache.pop(&cache_key);
84 None
85 }
86 }
87}
88
89impl Default for Cache {
90 fn default() -> Self {
91 Self::new()
92 }
93}
94
#[cfg(test)]
mod tests {
    use super::Cache;
    use std::net::Ipv4Addr;
    use trust_dns_proto::{
        op::{message::Message, Query},
        rr::{Name, RData, Record, RecordType},
    };

    /// Builds a query for `example.com`.
    fn example_query() -> Query {
        let mut query = Query::new();
        let name: Name = "example.com".parse().unwrap();
        query.set_name(name);
        query
    }

    /// Builds a response to `query` carrying a single `A` record with `ttl`.
    fn response_with_ttl(query: &Query, ttl: u32) -> Message {
        let mut answer = Record::with(query.name().clone(), RecordType::A, ttl);
        answer.set_data(Some(RData::A(Ipv4Addr::new(1, 1, 1, 1))));

        let mut response = Message::new();
        response.add_query(query.clone());
        response.add_answer(answer);
        response
    }

    /// Builds a request for `query` with message ID `id`.
    fn request_with_id(query: &Query, id: u16) -> Message {
        let mut request = Message::new();
        request.set_id(id);
        request.add_query(query.clone());
        request
    }

    #[test]
    fn test_cache_hit() {
        let mut cache = Cache::new();
        let query = example_query();
        cache.put(response_with_ttl(&query, 1000));

        let cached = cache
            .get(&request_with_id(&query, 42))
            .expect("fresh entry should be served");
        // The cached response must be re-addressed to the new request.
        assert_eq!(cached.id(), 42);
        assert_eq!(cached.answers().len(), 1);
    }

    #[test]
    fn test_cache_expire() {
        let mut cache = Cache::new();
        let query = example_query();
        cache.put(response_with_ttl(&query, 0));

        // Assert the miss explicitly: `#[should_panic]` would also pass on any
        // unrelated panic and so would not prove the entry expired.
        assert!(cache.get(&request_with_id(&query, 0)).is_none());
    }

    #[test]
    fn test_request_without_query_misses() {
        let mut cache = Cache::new();
        let query = example_query();
        cache.put(response_with_ttl(&query, 1000));

        assert!(cache.get(&Message::new()).is_none());
    }

    #[test]
    fn test_response_without_answers_is_not_cached() {
        let mut cache = Cache::new();
        let query = example_query();

        let mut response = Message::new();
        response.add_query(query.clone());
        cache.put(response);

        assert!(cache.get(&request_with_id(&query, 1)).is_none());
    }
}