1use std::collections::HashMap;
3use std::path::Path;
4use std::{fs, io, net};
5
6pub use nakamoto_common::p2p::peer::*;
7
8#[derive(Debug)]
10pub struct Cache {
11 addrs: HashMap<net::IpAddr, KnownAddress>,
12 file: fs::File,
13}
14
15impl Cache {
16 pub fn open<P: AsRef<Path>>(path: P) -> io::Result<Self> {
18 fs::OpenOptions::new()
19 .read(true)
20 .write(true)
21 .open(path)
22 .and_then(Self::from)
23 }
24
25 pub fn create<P: AsRef<Path>>(path: P) -> io::Result<Self> {
27 let file = fs::OpenOptions::new()
28 .create_new(true)
29 .write(true)
30 .open(path)?;
31
32 Ok(Self {
33 file,
34 addrs: HashMap::new(),
35 })
36 }
37
38 pub fn from(mut file: fs::File) -> io::Result<Self> {
40 use io::Read;
41 use microserde::json::Value;
42 use std::str::FromStr;
43
44 let mut s = String::new();
45 let mut addrs = HashMap::new();
46
47 file.read_to_string(&mut s)?;
48
49 if !s.is_empty() {
50 let val = microserde::json::from_str(&s)
51 .map_err(|_| io::Error::from(io::ErrorKind::InvalidData))?;
52
53 match val {
54 Value::Object(ary) => {
55 for (k, v) in ary.into_iter() {
56 let ka = KnownAddress::from_json(v)
57 .map_err(|_| io::Error::from(io::ErrorKind::InvalidData))?;
58 let ip = net::IpAddr::from_str(k.as_str())
59 .map_err(|_| io::Error::from(io::ErrorKind::InvalidData))?;
60
61 addrs.insert(ip, ka);
62 }
63 }
64 _ => return Err(io::ErrorKind::InvalidData.into()),
65 }
66 }
67
68 Ok(Self { file, addrs })
69 }
70}
71
72impl Store for Cache {
73 fn get_mut(&mut self, ip: &net::IpAddr) -> Option<&mut KnownAddress> {
74 self.addrs.get_mut(ip)
75 }
76
77 fn get(&self, ip: &net::IpAddr) -> Option<&KnownAddress> {
78 self.addrs.get(ip)
79 }
80
81 fn remove(&mut self, ip: &net::IpAddr) -> Option<KnownAddress> {
82 self.addrs.remove(ip)
83 }
84
85 fn insert(&mut self, ip: net::IpAddr, ka: KnownAddress) -> bool {
86 <HashMap<_, _> as Store>::insert(&mut self.addrs, ip, ka)
87 }
88
89 fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (&net::IpAddr, &KnownAddress)> + 'a> {
90 Box::new(self.addrs.iter())
91 }
92
93 fn clear(&mut self) {
94 self.addrs.clear()
95 }
96
97 fn len(&self) -> usize {
98 self.addrs.len()
99 }
100
101 fn flush<'a>(&mut self) -> io::Result<()> {
102 use io::{Seek, Write};
103 use microserde::json::Value;
104
105 let peers: microserde::json::Object = self
106 .addrs
107 .iter()
108 .map(|(ip, ka)| (ip.to_string(), ka.to_json()))
109 .collect();
110 let s = microserde::json::to_string(&Value::Object(peers));
111
112 self.file.set_len(0)?;
113 self.file.seek(io::SeekFrom::Start(0))?;
114 self.file.write_all(s.as_bytes())?;
115 self.file.write_all(&[b'\n'])?;
116 self.file.sync_data()?;
117
118 Ok(())
119 }
120}
121
#[cfg(test)]
mod test {
    use super::*;
    use nakamoto_common::bitcoin::network::address::Address;
    use nakamoto_common::bitcoin::network::constants::ServiceFlags;
    use nakamoto_common::block::time::LocalTime;

    /// A freshly created cache file re-opens as an empty cache.
    #[test]
    fn test_empty() {
        let dir = tempfile::tempdir().unwrap();
        let location = dir.path().join("cache");

        // Create the file and release the handle right away.
        drop(Cache::create(&location).unwrap());

        let reopened = Cache::open(&location).unwrap();
        assert!(reopened.is_empty());
    }

    /// Entries flushed to disk are read back unchanged by a new cache.
    #[test]
    fn test_save_and_load() {
        let dir = tempfile::tempdir().unwrap();
        let location = dir.path().join("cache");
        let mut expected: Vec<(net::IpAddr, KnownAddress)> = Vec::new();

        // First phase: fill a new cache, persist it and remember its entries.
        {
            let mut writer = Cache::create(&location).unwrap();

            for octet in 32u8..48 {
                let ip = net::IpAddr::from([127, 0, 0, octet]);
                let endpoint = net::SocketAddr::from((ip, 8333));
                let entry = KnownAddress {
                    addr: Address::new(&endpoint, ServiceFlags::NETWORK),
                    source: Source::Dns,
                    last_success: Some(LocalTime::from_secs(u64::from(octet))),
                    last_sampled: Some(LocalTime::from_secs(u64::from(octet + 1))),
                    last_attempt: None,
                    last_active: None,
                };
                writer.insert(ip, entry);
            }
            writer.flush().unwrap();

            expected.extend(writer.iter().map(|(ip, ka)| (*ip, ka.clone())));
        }

        // Second phase: load the file again and compare, ignoring map order.
        {
            let reader = Cache::open(&location).unwrap();
            let mut actual: Vec<(net::IpAddr, KnownAddress)> =
                reader.iter().map(|(ip, ka)| (*ip, ka.clone())).collect();

            actual.sort_by(|(a, _), (b, _)| a.cmp(b));
            expected.sort_by(|(a, _), (b, _)| a.cmp(b));

            assert_eq!(actual, expected);
        }
    }
}