public_ip_address/cache.rs
1//! # 🗄️ Response cache Module
2//!
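//! This module provides a `ResponseCache` struct that holds the lookup response for the current host's IP address, as well as lookup responses for arbitrary IP addresses. Each cached response records when it was created and, optionally, after how many seconds it expires.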
4//!
5//! The `ResponseCache` can be saved to disk, loaded from disk, and deleted from disk. It also provides methods to clear the cache,
6//! update the cache with a new response, check if the cache has expired, and retrieve the IP address or the entire response from the cache.
7//!
//! By default, the cache is stored as JSON in the system cache directory; a custom file name can be provided.
9//!
10//! If the `encryption` feature is enabled, the cache is encrypted using AEAD.
11//!
12//! ## Example
13//! ```rust
14//! use std::error::Error;
15//! use public_ip_address::{cache::ResponseCache, response::LookupResponse};
16//! use public_ip_address::lookup::LookupProvider;
17//!
18//! fn main() -> Result<(), Box<dyn Error>> {
19//! let response = LookupResponse::new(
20//! "1.1.1.1".parse::<std::net::IpAddr>()?,
21//! LookupProvider::IpBase,
22//! );
23//! let mut response_cache = ResponseCache::new(None);
24//! response_cache.update_current(&response, None);
25//! response_cache.save()?;
26//! let cached = ResponseCache::load(None)?;
27//! println!("{:?}", cached);
28//! cached.delete()?;
29//! Ok(())
30//! }
31//! ```
32
33use crate::{error::CacheError, LookupResponse};
34use etcetera::{choose_base_strategy, BaseStrategy};
35use log::{debug, trace};
36use serde::{Deserialize, Serialize};
37use std::{
38 collections::BTreeMap,
39 fs,
40 fs::File,
41 io::prelude::*,
42 net::IpAddr,
43 time::{Duration, SystemTime},
44};
45
46#[cfg(feature = "encryption")]
47use cocoon::Cocoon;
48
49/// Result type wrapper for the cache
50pub type Result<T> = std::result::Result<T, CacheError>;
51
52/// Represents an entry of the cached response
53///
54/// It contains the `LookupResponse`, the time when the response was cached, and the time-to-live (TTL) of the cache.
55#[derive(Serialize, Deserialize, PartialEq, Debug)]
56#[non_exhaustive]
57pub struct ResponseRecord {
58 /// Cached response
59 pub response: LookupResponse,
60 response_time: SystemTime,
61 ttl: Option<u64>,
62}
63
64impl ResponseRecord {
65 /// Creates a new `ResponseRecord` instance.
66 ///
67 /// # Arguments
68 ///
69 /// * `response` - A `LookupResponse` to be cached.
70 /// * `ttl` - An optional `u64` value representing after how many seconds the cached value expires.
71 /// None means the cache never expires.
72 pub fn new(response: LookupResponse, ttl: Option<u64>) -> ResponseRecord {
73 ResponseRecord {
74 response,
75 response_time: SystemTime::now(),
76 ttl,
77 }
78 }
79
80 /// Determines if the cached response has expired.
81 ///
82 /// If the TTL is not set, the function assumes that the cache never expires and returns false.
83 pub fn is_expired(&self) -> bool {
84 if let Some(ttl) = self.ttl {
85 let difference = SystemTime::now()
86 .duration_since(self.response_time)
87 .unwrap_or_default();
88 difference >= Duration::from_secs(ttl)
89 } else {
90 // No TTL, cache never expires
91 false
92 }
93 }
94
95 /// Returns the IP address of the cached response.
96 pub fn ip(&self) -> std::net::IpAddr {
97 self.response.ip
98 }
99}
100
101/// Holds the current IP address lookup response
102///
103/// The cache can be saved to disk, loaded from disk, and deleted from disk. It also provides methods to clear the cache,
104/// update the cache with a new response, check if the cache has expired, and retrieve the IP address or the entire response from the cache.
105#[derive(Serialize, Deserialize, Debug, Default, PartialEq)]
106#[non_exhaustive]
107pub struct ResponseCache {
108 /// The current IP address lookup response
109 pub current_address: Option<ResponseRecord>,
110 /// A tree of arbitrary IP address responses
111 pub lookup_address: BTreeMap<IpAddr, ResponseRecord>,
112 /// The cache file name
113 file_name: Option<String>,
114}
115
116impl ResponseCache {
117 /// Creates a new `ResponseCache` instance.
118 ///
119 /// The `ResponseRecord` is stored as the `current_address` in the `ResponseCache`.
120 ///
121 /// # Arguments
122 ///
123 /// * `file_name` - An `Option<String>` representing the name of the file where the cache will be stored. If `None`, no file will be used.
124 ///
125 /// # Examples
126 ///
127 /// ```
128 /// # use public_ip_address::cache::ResponseCache;
129 /// # use public_ip_address::lookup::LookupProvider;
130 /// # use public_ip_address::response::LookupResponse;
131 /// let response = LookupResponse::new(
132 /// "1.1.1.1".parse::<std::net::IpAddr>().unwrap(),
133 /// LookupProvider::IpBase);
134 /// let mut cache = ResponseCache::new(None);
135 /// cache.update_current(&response, None);
136 /// ```
137 ///
138 /// ```
139 /// # use public_ip_address::cache::ResponseCache;
140 /// let cache = ResponseCache::new(Some("cache.txt".to_string()));
141 /// ```
142 ///
143 pub fn new(file_name: Option<String>) -> ResponseCache {
144 trace!("Creating new cache structure");
145 ResponseCache {
146 current_address: None,
147 lookup_address: BTreeMap::new(),
148 file_name,
149 }
150 }
151
152 /// Clears the cache.
153 ///
154 /// # Examples
155 ///
156 /// ```
157 /// # use public_ip_address::cache::ResponseCache;
158 /// let mut cache = ResponseCache::default();
159 /// cache.clear();
160 /// assert!(cache.current_response().is_none());
161 /// ```
162 pub fn clear(&mut self) {
163 trace!("Clearing cache");
164 self.current_address = None;
165 self.lookup_address.clear();
166 }
167
168 /// Updates the cache entry for the current host with a new response.
169 ///
170 /// # Arguments
171 ///
172 /// * `response` - A `LookupResponse` instance representing the new address to be cached.
173 /// * `ttl` - An `Option<u64>` representing the time-to-live (TTL) in seconds for the new cached response. If `None`, the cache never expires.
174 ///
175 pub fn update_current(&mut self, response: &LookupResponse, ttl: Option<u64>) {
176 self.current_address = Some(ResponseRecord::new(response.to_owned(), ttl));
177 }
178
179 /// Checks if the `current_address` cache entry has expired.
180 pub fn current_is_expired(&self) -> bool {
181 match self.current_address {
182 Some(ref current) => current.is_expired(),
183 None => true,
184 }
185 }
186
187 /// Returns the IP address of the current host cache entry.
188 pub fn current_ip(&self) -> Option<std::net::IpAddr> {
189 self.current_address.as_ref().map(|current| current.ip())
190 }
191
192 /// Returns the `current_address` cache entry.
193 pub fn current_response(&self) -> Option<LookupResponse> {
194 self.current_address
195 .as_ref()
196 .map(|current| current.response.to_owned())
197 }
198
199 /// Updates the lookup cache with a new response.
200 pub fn update_target(&mut self, ip: IpAddr, response: &LookupResponse, ttl: Option<u64>) {
201 self.lookup_address
202 .insert(ip, ResponseRecord::new(response.to_owned(), ttl));
203 }
204
205 /// Checks if the lookup cache entry for the given IP address has expired.
206 pub fn target_is_expired(&self, ip: &IpAddr) -> bool {
207 match self.lookup_address.get(ip) {
208 Some(lookup) => lookup.is_expired(),
209 None => true,
210 }
211 }
212
213 /// Returns lookup cached entry for the given IP address.
214 pub fn target_response(&self, ip: &IpAddr) -> Option<LookupResponse> {
215 self.lookup_address
216 .get(ip)
217 .map(|lookup| lookup.response.to_owned())
218 }
219
220 /// Writes the `ResponseCache` instance to a file on disk.
221 ///
222 /// This method serializes the `ResponseCache` instance into a JSON string, encrypts the data if the "encryption" feature is enabled,
223 /// and then writes the encrypted (or plain text) data to a file. The file is located at the path specified by the `file_name` field of the `ResponseCache` instance.
224 ///
225 /// # Examples
226 ///
227 /// ```
228 /// # use public_ip_address::cache::ResponseCache;
229 /// let cache = ResponseCache::new(Some("cache.txt".to_string()));
230 /// _ = cache.save();
231 /// ```
232 pub fn save(&self) -> Result<()> {
233 debug!("Saving cache to {}", get_cache_path(&self.file_name));
234 let data = serde_json::to_string(self)?.into_bytes();
235
236 #[cfg(feature = "encryption")]
237 let data = encrypt(data)?;
238
239 let mut file = File::create(get_cache_path(&self.file_name))?;
240 file.write_all(&data)?;
241 Ok(())
242 }
243
244 /// Loads the `ResponseCache` instance from a file on disk.
245 ///
246 /// This method reads the file specified by `file_name`, decrypts the data if the "encryption" feature is enabled,
247 /// and then deserializes the data into a `ResponseCache` instance.
248 ///
249 /// # Arguments
250 ///
251 /// * `file_name` - An `Option<String>` representing the name of the file from which the cache will be loaded.
252 /// If `None`, the default file name `lookup.cache` will be used.
253 ///
254 /// # Examples
255 ///
256 /// ```
257 /// # use public_ip_address::cache::ResponseCache;
258 /// let cache = ResponseCache::load(Some("cache.txt".to_string()));
259 /// ```
260 pub fn load(file_name: Option<String>) -> Result<ResponseCache> {
261 debug!("Loading cache from {}", get_cache_path(&file_name));
262 let mut file = File::open(get_cache_path(&file_name))?;
263 let mut data = Vec::new();
264 file.read_to_end(&mut data)?;
265
266 #[cfg(feature = "encryption")]
267 let data = decrypt(data)?;
268
269 let decoded = String::from_utf8(data).unwrap_or_default();
270 let deserialized: ResponseCache = serde_json::from_str(&decoded)?;
271 Ok(deserialized)
272 }
273
274 /// Deletes the `ResponseCache` instance from disk.
275 pub fn delete(self) -> Result<()> {
276 trace!("Deleting cache file {}", get_cache_path(&self.file_name));
277 fs::remove_file(get_cache_path(&self.file_name))?;
278 Ok(())
279 }
280}
281
282/// Determines the path for the cache file.
283///
284/// This function uses a series of fallbacks to find a suitable directory for the cache file:
285/// 1. It first tries to use the system's cache directory, as determined by the `BaseStrategy`.
286/// 2. If the cache directory doesn't exist, it tries to create it.
287/// 3. If it can't create the cache directory, it falls back to the system's data directory.
288/// 4. If it can't use the data directory, it falls back to the user's home directory.
289/// 5. If it can't use the home directory, it falls back to the current directory.
290///
291/// The cache file is named "lookup.cache" by default. However, this can be overridden by providing a different name as a parameter.
292///
293/// # Arguments
294///
295/// * `file_name` - An `Option<String>` representing the desired name of the cache file. If `None`, the default name "lookup.cache" is used.
296///
297/// # Returns
298///
299/// * `String` - The path to the cache file.
300///
301/// # Examples
302///
303/// ```
304/// # use public_ip_address::cache::get_cache_path;
305/// let cache_path = get_cache_path(&Some("my_cache.txt".to_string()));
306/// ```
307pub fn get_cache_path(file_name: &Option<String>) -> String {
308 let file_name = if let Some(file_name) = file_name {
309 file_name
310 } else {
311 "lookup.cache"
312 };
313
314 if let Ok(base_dirs) = choose_base_strategy() {
315 let mut dir = base_dirs.cache_dir();
316 // Create cache directory if it doesn't exist
317 if !dir.exists() && fs::create_dir_all(&dir).is_err() {
318 // If we can't create the cache directory, fallback to data directory
319 dir = base_dirs.data_dir();
320 if !dir.exists() && fs::create_dir_all(&dir).is_err() {
321 // If we can't create the data directory, fallback to home directory
322 dir = base_dirs.home_dir().to_path_buf();
323 }
324 }
325 if let Some(path) = dir.join(file_name).to_str() {
326 return path.to_string();
327 }
328 };
329 // As last resort, fallback to current directory
330 file_name.to_string()
331}
332
/// Returns the password from which the cache encryption key is derived.
///
/// It is the value returned by `mid::get` for this crate's package name (presumably a
/// machine-specific identifier — see the `mid` crate). If that call fails, the fixed string
/// `"lookup"` is used instead; because that fallback is visible in the source, the cache is
/// then only obfuscated rather than protected.
#[cfg(feature = "encryption")]
fn cache_password() -> String {
    mid::get(env!("CARGO_PKG_NAME")).unwrap_or("lookup".to_string())
}

/// Decrypts the given data using AEAD.
///
/// In debug builds (`debug_assertions`), a weaker key derivation function is used for speed.
///
/// # Arguments
///
/// * `data` - The data to be decrypted, as a vector of bytes.
///
/// # Returns
///
/// * `Ok(Vec<u8>)` - The decrypted data, as a vector of bytes.
/// * `Err(CacheError::EncryptionError)` - If there was an error during decryption.
#[cfg(feature = "encryption")]
fn decrypt(data: Vec<u8>) -> Result<Vec<u8>> {
    trace!("Decrypting data");
    let password = cache_password();
    let cocoon = if cfg!(debug_assertions) {
        Cocoon::new(password.as_bytes()).with_weak_kdf()
    } else {
        Cocoon::new(password.as_bytes())
    };
    cocoon
        .unwrap(&data)
        .map_err(|e| CacheError::EncryptionError(format!("Error decrypting: {:?}", e)))
}

/// Encrypts the given data using AEAD.
///
/// In debug builds (`debug_assertions`), a weaker key derivation function is used for speed.
///
/// # Arguments
///
/// * `data` - The data to be encrypted, as a vector of bytes.
///
/// # Returns
///
/// * `Ok(Vec<u8>)` - The encrypted data, as a vector of bytes.
/// * `Err(CacheError::EncryptionError)` - If there was an error during encryption.
#[cfg(feature = "encryption")]
fn encrypt(data: Vec<u8>) -> Result<Vec<u8>> {
    trace!("Encrypting data");
    let password = cache_password();
    // `wrap` needs `&mut self`, hence the mutable binding.
    let mut cocoon = if cfg!(debug_assertions) {
        Cocoon::new(password.as_bytes()).with_weak_kdf()
    } else {
        Cocoon::new(password.as_bytes())
    };
    cocoon
        .wrap(&data)
        .map_err(|e| CacheError::EncryptionError(format!("Error encrypting: {:?}", e)))
}
392
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lookup::LookupProvider;
    use serial_test::serial;

    /// Dedicated cache file for the on-disk round-trip test, so running the test suite never
    /// overwrites or deletes the real default `lookup.cache` file.
    const TEST_CACHE_FILE: &str = "public_ip_address_test.cache";

    #[test]
    #[serial]
    fn test_cache_file() {
        let file_name = Some(TEST_CACHE_FILE.to_string());
        let response = LookupResponse::new(
            "1.1.1.1".parse().unwrap(),
            LookupProvider::Mock("1.1.1.1".to_string(), "localhost".to_string()),
        );
        println!("{}", get_cache_path(&file_name));
        let mut cache = ResponseCache::new(file_name.clone());
        cache.update_current(&response, None);
        cache.save().unwrap();
        let cached = ResponseCache::load(file_name).unwrap();
        assert_eq!(
            cached.current_ip().unwrap(),
            "1.1.1.1".parse::<std::net::IpAddr>().unwrap(),
            "IP address not matching"
        );
        cache.delete().unwrap();
    }

    #[test]
    fn test_expired() {
        let response = LookupResponse::new(
            "1.1.1.1".parse().unwrap(),
            LookupProvider::Mock("1.1.1.1".to_string(), "localhost".to_string()),
        );
        let mut cache = ResponseCache::default();
        assert!(cache.current_is_expired(), "Empty cache should be expired");
        cache.update_current(&response, None);
        assert_eq!(
            cache.current_ip().unwrap(),
            "1.1.1.1".parse::<std::net::IpAddr>().unwrap(),
            "IP address not matching"
        );
        assert!(
            !cache.current_is_expired(),
            "Cache with no TTL should not be expired"
        );
        cache.update_current(&response, Some(1));
        assert!(
            !cache.current_is_expired(),
            "Fresh cache should not be expired {cache:#?}"
        );
        // Wait for cache to expire
        std::thread::sleep(Duration::from_secs(1));
        assert!(
            cache.current_is_expired(),
            "Expired cache should be expired"
        );
    }

    #[test]
    fn test_cache_tree() {
        let addresses = [
            "1.1.1.1".parse().unwrap(),
            "2.1.1.1".parse().unwrap(),
            "3.1.1.1".parse().unwrap(),
        ];
        let mut cache = ResponseCache::default();
        for address in &addresses {
            let response = LookupResponse::new(*address, LookupProvider::Ipify);
            cache.update_target(*address, &response, None);
        }

        for address in &addresses {
            assert_eq!(
                cache.target_response(address).unwrap().ip,
                *address,
                "IP address not matching: {cache:#?}"
            );
            assert!(
                !cache.target_is_expired(address),
                "Entry without TTL should not be expired: {cache:#?}"
            );
        }

        let unknown: IpAddr = "4.1.1.1".parse().unwrap();
        assert!(cache.target_response(&unknown).is_none());
        assert!(
            cache.target_is_expired(&unknown),
            "Missing entry should be reported as expired"
        );
    }

    #[test]
    fn test_cache_clear() {
        let response = LookupResponse::new(
            "1.1.1.1".parse().unwrap(),
            LookupProvider::Mock("1.1.1.1".to_string(), "localhost".to_string()),
        );
        let mut cache = ResponseCache::new(None);
        cache.update_current(&response, None);
        let response = LookupResponse::new("2.2.2.2".parse().unwrap(), LookupProvider::Ipify);
        cache.update_target(response.ip, &response, None);
        cache.clear();
        // Assert on the cleared cache itself (previously it was shadowed by a fresh default).
        assert!(cache.current_response().is_none(), "Current entry not cleared");
        assert!(cache.lookup_address.is_empty(), "Lookup entries not cleared");
        assert_eq!(
            cache,
            ResponseCache::default(),
            "Cache not cleared properly: {cache:#?}"
        );
    }

    #[test]
    #[cfg(feature = "encryption")]
    fn test_encrypt_decrypt() {
        let data = b"hello world".to_vec();
        let encrypted = encrypt(data.clone()).unwrap();
        let decrypted = decrypt(encrypted).unwrap();
        assert_eq!(data, decrypted);
    }
}