public_ip_address/
cache.rs

1//! # 🗄️ Response cache Module
2//!
3//! This module provides a `ResponseCache` struct that holds the current IP address lookup response and the time it was created, and when it should expire.
4//!
5//! The `ResponseCache` can be saved to disk, loaded from disk, and deleted from disk. It also provides methods to clear the cache,
6//! update the cache with a new response, check if the cache has expired, and retrieve the IP address or the entire response from the cache.
7//!
//! By default the cache is stored as JSON in the system cache directory; a custom file name can also be provided.
9//!
10//! If the `encryption` feature is enabled, the cache is encrypted using AEAD.
11//!
12//! ## Example
13//! ```rust
14//! use std::error::Error;
15//! use public_ip_address::{cache::ResponseCache, response::LookupResponse};
16//! use public_ip_address::lookup::LookupProvider;
17//!
18//! fn main() -> Result<(), Box<dyn Error>> {
19//!     let response = LookupResponse::new(
20//!             "1.1.1.1".parse::<std::net::IpAddr>()?,
21//!             LookupProvider::IpBase,
22//!     );
23//!     let mut response_cache = ResponseCache::new(None);
24//!     response_cache.update_current(&response, None);
25//!     response_cache.save()?;
26//!     let cached = ResponseCache::load(None)?;
27//!     println!("{:?}", cached);
28//!     cached.delete()?;
29//!     Ok(())
30//! }
31//! ```
32
33use crate::{error::CacheError, LookupResponse};
34use etcetera::{choose_base_strategy, BaseStrategy};
35use log::{debug, trace};
36use serde::{Deserialize, Serialize};
37use std::{
38    collections::BTreeMap,
39    fs,
40    fs::File,
41    io::prelude::*,
42    net::IpAddr,
43    time::{Duration, SystemTime},
44};
45
46#[cfg(feature = "encryption")]
47use cocoon::Cocoon;
48
/// Result type used by the cache module.
///
/// Shorthand for `std::result::Result` with the error type fixed to `CacheError`.
pub type Result<T> = std::result::Result<T, CacheError>;
51
52/// Represents an entry of the cached response
53///
54/// It contains the `LookupResponse`, the time when the response was cached, and the time-to-live (TTL) of the cache.
55#[derive(Serialize, Deserialize, PartialEq, Debug)]
56#[non_exhaustive]
57pub struct ResponseRecord {
58    /// Cached response
59    pub response: LookupResponse,
60    response_time: SystemTime,
61    ttl: Option<u64>,
62}
63
64impl ResponseRecord {
65    /// Creates a new `ResponseRecord` instance.
66    ///
67    /// # Arguments
68    ///
69    /// * `response` - A `LookupResponse` to be cached.
70    /// * `ttl` - An optional `u64` value representing after how many seconds the cached value expires.
71    ///   None means the cache never expires.
72    pub fn new(response: LookupResponse, ttl: Option<u64>) -> ResponseRecord {
73        ResponseRecord {
74            response,
75            response_time: SystemTime::now(),
76            ttl,
77        }
78    }
79
80    /// Determines if the cached response has expired.
81    ///
82    /// If the TTL is not set, the function assumes that the cache never expires and returns false.
83    pub fn is_expired(&self) -> bool {
84        if let Some(ttl) = self.ttl {
85            let difference = SystemTime::now()
86                .duration_since(self.response_time)
87                .unwrap_or_default();
88            difference >= Duration::from_secs(ttl)
89        } else {
90            // No TTL, cache never expires
91            false
92        }
93    }
94
95    /// Returns the IP address of the cached response.
96    pub fn ip(&self) -> std::net::IpAddr {
97        self.response.ip
98    }
99}
100
101/// Holds the current IP address lookup response
102///
103/// The cache can be saved to disk, loaded from disk, and deleted from disk. It also provides methods to clear the cache,
104/// update the cache with a new response, check if the cache has expired, and retrieve the IP address or the entire response from the cache.
105#[derive(Serialize, Deserialize, Debug, Default, PartialEq)]
106#[non_exhaustive]
107pub struct ResponseCache {
108    /// The current IP address lookup response
109    pub current_address: Option<ResponseRecord>,
110    /// A tree of arbitrary IP address responses
111    pub lookup_address: BTreeMap<IpAddr, ResponseRecord>,
112    /// The cache file name
113    file_name: Option<String>,
114}
115
116impl ResponseCache {
117    /// Creates a new `ResponseCache` instance.
118    ///
119    /// The `ResponseRecord` is stored as the `current_address` in the `ResponseCache`.
120    ///
121    /// # Arguments
122    ///
123    /// * `file_name` - An `Option<String>` representing the name of the file where the cache will be stored. If `None`, no file will be used.
124    ///
125    /// # Examples
126    ///
127    /// ```
128    /// # use public_ip_address::cache::ResponseCache;
129    /// # use public_ip_address::lookup::LookupProvider;
130    /// # use public_ip_address::response::LookupResponse;
131    /// let response = LookupResponse::new(
132    ///             "1.1.1.1".parse::<std::net::IpAddr>().unwrap(),
133    ///             LookupProvider::IpBase);
134    /// let mut cache = ResponseCache::new(None);
135    /// cache.update_current(&response, None);
136    /// ```
137    ///
138    /// ```
139    /// # use public_ip_address::cache::ResponseCache;
140    /// let cache = ResponseCache::new(Some("cache.txt".to_string()));
141    /// ```
142    ///
143    pub fn new(file_name: Option<String>) -> ResponseCache {
144        trace!("Creating new cache structure");
145        ResponseCache {
146            current_address: None,
147            lookup_address: BTreeMap::new(),
148            file_name,
149        }
150    }
151
152    /// Clears the cache.
153    ///
154    /// # Examples
155    ///
156    /// ```
157    /// # use public_ip_address::cache::ResponseCache;
158    /// let mut cache = ResponseCache::default();
159    /// cache.clear();
160    /// assert!(cache.current_response().is_none());
161    /// ```
162    pub fn clear(&mut self) {
163        trace!("Clearing cache");
164        self.current_address = None;
165        self.lookup_address.clear();
166    }
167
168    /// Updates the cache entry for the current host with a new response.
169    ///
170    /// # Arguments
171    ///
172    /// * `response` - A `LookupResponse` instance representing the new address to be cached.
173    /// * `ttl` - An `Option<u64>` representing the time-to-live (TTL) in seconds for the new cached response. If `None`, the cache never expires.
174    ///
175    pub fn update_current(&mut self, response: &LookupResponse, ttl: Option<u64>) {
176        self.current_address = Some(ResponseRecord::new(response.to_owned(), ttl));
177    }
178
179    /// Checks if the `current_address` cache entry has expired.
180    pub fn current_is_expired(&self) -> bool {
181        match self.current_address {
182            Some(ref current) => current.is_expired(),
183            None => true,
184        }
185    }
186
187    /// Returns the IP address of the current host cache entry.
188    pub fn current_ip(&self) -> Option<std::net::IpAddr> {
189        self.current_address.as_ref().map(|current| current.ip())
190    }
191
192    /// Returns the `current_address` cache entry.
193    pub fn current_response(&self) -> Option<LookupResponse> {
194        self.current_address
195            .as_ref()
196            .map(|current| current.response.to_owned())
197    }
198
199    /// Updates the lookup cache with a new response.
200    pub fn update_target(&mut self, ip: IpAddr, response: &LookupResponse, ttl: Option<u64>) {
201        self.lookup_address
202            .insert(ip, ResponseRecord::new(response.to_owned(), ttl));
203    }
204
205    /// Checks if the lookup cache entry for the given IP address has expired.
206    pub fn target_is_expired(&self, ip: &IpAddr) -> bool {
207        match self.lookup_address.get(ip) {
208            Some(lookup) => lookup.is_expired(),
209            None => true,
210        }
211    }
212
213    /// Returns lookup cached entry for the given IP address.
214    pub fn target_response(&self, ip: &IpAddr) -> Option<LookupResponse> {
215        self.lookup_address
216            .get(ip)
217            .map(|lookup| lookup.response.to_owned())
218    }
219
220    /// Writes the `ResponseCache` instance to a file on disk.
221    ///
222    /// This method serializes the `ResponseCache` instance into a JSON string, encrypts the data if the "encryption" feature is enabled,
223    /// and then writes the encrypted (or plain text) data to a file. The file is located at the path specified by the `file_name` field of the `ResponseCache` instance.
224    ///
225    /// # Examples
226    ///
227    /// ```
228    /// # use public_ip_address::cache::ResponseCache;
229    /// let cache = ResponseCache::new(Some("cache.txt".to_string()));
230    /// _ = cache.save();
231    /// ```
232    pub fn save(&self) -> Result<()> {
233        debug!("Saving cache to {}", get_cache_path(&self.file_name));
234        let data = serde_json::to_string(self)?.into_bytes();
235
236        #[cfg(feature = "encryption")]
237        let data = encrypt(data)?;
238
239        let mut file = File::create(get_cache_path(&self.file_name))?;
240        file.write_all(&data)?;
241        Ok(())
242    }
243
244    /// Loads the `ResponseCache` instance from a file on disk.
245    ///
246    /// This method reads the file specified by `file_name`, decrypts the data if the "encryption" feature is enabled,
247    /// and then deserializes the data into a `ResponseCache` instance.
248    ///
249    /// # Arguments
250    ///
251    /// * `file_name` - An `Option<String>` representing the name of the file from which the cache will be loaded.
252    ///   If `None`, the default file name `lookup.cache` will be used.
253    ///
254    /// # Examples
255    ///
256    /// ```
257    /// # use public_ip_address::cache::ResponseCache;
258    /// let cache = ResponseCache::load(Some("cache.txt".to_string()));
259    /// ```
260    pub fn load(file_name: Option<String>) -> Result<ResponseCache> {
261        debug!("Loading cache from {}", get_cache_path(&file_name));
262        let mut file = File::open(get_cache_path(&file_name))?;
263        let mut data = Vec::new();
264        file.read_to_end(&mut data)?;
265
266        #[cfg(feature = "encryption")]
267        let data = decrypt(data)?;
268
269        let decoded = String::from_utf8(data).unwrap_or_default();
270        let deserialized: ResponseCache = serde_json::from_str(&decoded)?;
271        Ok(deserialized)
272    }
273
274    /// Deletes the `ResponseCache` instance from disk.
275    pub fn delete(self) -> Result<()> {
276        trace!("Deleting cache file {}", get_cache_path(&self.file_name));
277        fs::remove_file(get_cache_path(&self.file_name))?;
278        Ok(())
279    }
280}
281
282/// Determines the path for the cache file.
283///
284/// This function uses a series of fallbacks to find a suitable directory for the cache file:
285/// 1. It first tries to use the system's cache directory, as determined by the `BaseStrategy`.
286/// 2. If the cache directory doesn't exist, it tries to create it.
287/// 3. If it can't create the cache directory, it falls back to the system's data directory.
288/// 4. If it can't use the data directory, it falls back to the user's home directory.
289/// 5. If it can't use the home directory, it falls back to the current directory.
290///
291/// The cache file is named "lookup.cache" by default. However, this can be overridden by providing a different name as a parameter.
292///
293/// # Arguments
294///
295/// * `file_name` - An `Option<String>` representing the desired name of the cache file. If `None`, the default name "lookup.cache" is used.
296///
297/// # Returns
298///
299/// * `String` - The path to the cache file.
300///
301/// # Examples
302///
303/// ```
304/// # use public_ip_address::cache::get_cache_path;
305/// let cache_path = get_cache_path(&Some("my_cache.txt".to_string()));
306/// ```
307pub fn get_cache_path(file_name: &Option<String>) -> String {
308    let file_name = if let Some(file_name) = file_name {
309        file_name
310    } else {
311        "lookup.cache"
312    };
313
314    if let Ok(base_dirs) = choose_base_strategy() {
315        let mut dir = base_dirs.cache_dir();
316        // Create cache directory if it doesn't exist
317        if !dir.exists() && fs::create_dir_all(&dir).is_err() {
318            // If we can't create the cache directory, fallback to data directory
319            dir = base_dirs.data_dir();
320            if !dir.exists() && fs::create_dir_all(&dir).is_err() {
321                // If we can't create the data directory, fallback to home directory
322                dir = base_dirs.home_dir().to_path_buf();
323            }
324        }
325        if let Some(path) = dir.join(file_name).to_str() {
326            return path.to_string();
327        }
328    };
329    // As last resort, fallback to current directory
330    file_name.to_string()
331}
332
/// Decrypts data that was previously produced by `encrypt`.
///
/// The password is obtained from `mid::get` for this package name
/// (presumably a machine-specific identifier — confirm against the `mid` crate);
/// if that call fails, the fixed string "lookup" is used instead.
///
/// In debug builds a weaker key derivation function is used for faster speed.
///
/// # Arguments
///
/// * `data` - The data to be decrypted, as a vector of bytes.
///
/// # Returns
///
/// * `Ok(Vec<u8>)` - The decrypted data, as a vector of bytes.
/// * `Err(CacheError::EncryptionError)` - If there was an error during decryption.
#[cfg(feature = "encryption")]
fn decrypt(data: Vec<u8>) -> Result<Vec<u8>> {
    trace!("Decrypting data");
    let secret = mid::get(env!("CARGO_PKG_NAME")).unwrap_or_else(|_| "lookup".to_string());
    let base = Cocoon::new(secret.as_bytes());
    let cocoon = if cfg!(debug_assertions) {
        base.with_weak_kdf()
    } else {
        base
    };
    cocoon
        .unwrap(&data)
        .map_err(|e| CacheError::EncryptionError(format!("Error decrypting: {:?}", e)))
}
362
/// Encrypts data so that it can later be restored with `decrypt`.
///
/// The password is obtained from `mid::get` for this package name
/// (presumably a machine-specific identifier — confirm against the `mid` crate);
/// if that call fails, the fixed string "lookup" is used instead.
///
/// In debug builds a weaker key derivation function is used for faster speed.
///
/// # Arguments
///
/// * `data` - The data to be encrypted, as a vector of bytes.
///
/// # Returns
///
/// * `Ok(Vec<u8>)` - The encrypted data, as a vector of bytes.
/// * `Err(CacheError::EncryptionError)` - If there was an error during encryption.
#[cfg(feature = "encryption")]
fn encrypt(data: Vec<u8>) -> Result<Vec<u8>> {
    trace!("Encrypting data");
    let secret = mid::get(env!("CARGO_PKG_NAME")).unwrap_or_else(|_| "lookup".to_string());
    let base = Cocoon::new(secret.as_bytes());
    let mut cocoon = if cfg!(debug_assertions) {
        base.with_weak_kdf()
    } else {
        base
    };
    cocoon
        .wrap(&data)
        .map_err(|e| CacheError::EncryptionError(format!("Error encrypting: {:?}", e)))
}
392
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lookup::LookupProvider;
    use serial_test::serial;

    /// Round-trips a cache through the default cache file: save, load, compare, delete.
    #[test]
    #[serial]
    fn test_cache_file() {
        let response = LookupResponse::new(
            "1.1.1.1".parse().unwrap(),
            LookupProvider::Mock("1.1.1.1".to_string(), "localhost".to_string()),
        );
        println!("{}", get_cache_path(&None));
        let mut cache = ResponseCache::new(None);
        cache.update_current(&response, None);
        cache.save().unwrap();
        let cached = ResponseCache::load(None).unwrap();
        assert_eq!(
            cached.current_ip().unwrap(),
            "1.1.1.1".parse::<std::net::IpAddr>().unwrap(),
            "IP address not matching"
        );
        cache.delete().unwrap();
    }

    /// Checks expiry for an empty cache, an entry without TTL, and an entry with a one-second TTL.
    #[test]
    fn test_expired() {
        let response = LookupResponse::new(
            "1.1.1.1".parse().unwrap(),
            LookupProvider::Mock("1.1.1.1".to_string(), "localhost".to_string()),
        );
        let mut cache = ResponseCache::default();
        assert!(cache.current_is_expired(), "Empty cache should be expired");
        cache.update_current(&response, None);
        assert_eq!(
            cache.current_ip().unwrap(),
            "1.1.1.1".parse::<std::net::IpAddr>().unwrap(),
            "IP address not matching"
        );
        assert!(
            !cache.current_is_expired(),
            "Cache with no TTL should not be expired"
        );
        cache.update_current(&response, Some(1));
        assert!(
            !cache.current_is_expired(),
            "Fresh cache should not be expired {cache:#?}"
        );
        // Wait for cache to expire
        std::thread::sleep(Duration::from_secs(1));
        assert!(
            cache.current_is_expired(),
            "Expired cache should be expired"
        );
    }

    /// Stores several target addresses and reads each of them back.
    #[test]
    fn test_cache_tree() {
        let addresses = [
            "1.1.1.1".parse().unwrap(),
            "2.1.1.1".parse().unwrap(),
            "3.1.1.1".parse().unwrap(),
        ];
        let mut cache = ResponseCache::default();
        for address in &addresses {
            let response = LookupResponse::new(*address, LookupProvider::Ipify);
            cache.update_target(*address, &response, None);
        }

        for address in &addresses {
            assert_eq!(
                cache.target_response(address).unwrap().ip,
                *address,
                "IP address not matching: {cache:#?}"
            );
        }
    }

    /// Fills both the current and the target entries, then verifies `clear` empties them.
    #[test]
    fn test_cache_clear() {
        let response = LookupResponse::new(
            "1.1.1.1".parse().unwrap(),
            LookupProvider::Mock("1.1.1.1".to_string(), "localhost".to_string()),
        );
        let mut cache = ResponseCache::new(None);
        cache.update_current(&response, None);
        let response = LookupResponse::new("2.2.2.2".parse().unwrap(), LookupProvider::Ipify);
        cache.update_target(response.ip, &response, None);
        // Make sure there is actually something to clear.
        assert_ne!(
            cache,
            ResponseCache::default(),
            "Cache should be populated before clearing"
        );
        cache.clear();
        // Compare the cleared cache itself (not a fresh instance) against the default.
        assert_eq!(
            cache,
            ResponseCache::default(),
            "Cache not cleared properly: {cache:#?}"
        );
    }

    /// Encrypting and then decrypting must yield the original bytes.
    #[test]
    #[cfg(feature = "encryption")]
    fn test_encrypt_decrypt() {
        let data = b"hello world".to_vec();
        let encrypted = encrypt(data.clone()).unwrap();
        let decrypted = decrypt(encrypted).unwrap();
        assert_eq!(data, decrypted);
    }
}