ip_alloc_lookup/
database.rs

1use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
2use std::{fs, io, path::Path};
3
/// Location of RIPE NCC's "delegated-ripencc-extended-latest" statistics file.
///
/// This is the default source used by `GeoIpDb::update_cache` (only available with
/// the `download` feature).
#[cfg(feature = "download")]
pub const RIPE_EXTENDED_LATEST_URL: &str = concat!(
    "https://ftp.ripe.net/pub/stats/ripencc/",
    "delegated-ripencc-extended-latest"
);
7
/// Geolocation attributes attached to one allocated address range.
///
/// All fields are plain values, so the struct is `Copy` and can be compared and hashed
/// (e.g. to deduplicate or group lookup results).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(C)]
pub struct GeoInfo {
    /// The two bytes of the country code, e.g. `*b"DE"`. The database builders store
    /// `*b"??"` when the source data did not contain a usable code.
    pub country_code: [u8; 2],
    /// `true` when the country is one of the EU member states.
    pub is_eu: bool,
    /// Raw `Region` discriminant (`Region as u8`); decode it with `GeoInfo::region_enum`.
    pub region: u8,
}
15
/// Coarse geographic grouping assigned to a country code.
///
/// The explicit discriminants are the values stored in `GeoInfo::region`, so they must
/// not be renumbered.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Region {
    EuropeanUnion = 1,
    EuropeNonEu = 2,
    EasternEurope = 3,
    Turkey = 4,
    MiddleEast = 5,
    NorthAfrica = 6,
    CentralAsia = 7,
    GulfStates = 8,
    Other = 255,
}

impl Region {
    /// Human-readable name of the region (also used by the `Display` impl).
    pub fn as_str(self) -> &'static str {
        match self {
            Region::EuropeanUnion => "European Union",
            Region::EuropeNonEu => "Europe (non-EU)",
            Region::EasternEurope => "Eastern Europe",
            Region::Turkey => "Turkey",
            Region::MiddleEast => "Middle East",
            Region::NorthAfrica => "North Africa",
            Region::CentralAsia => "Central Asia",
            Region::GulfStates => "Gulf States",
            Region::Other => "Other",
        }
    }
}

/// Formats the region with its human-readable name, so `region.to_string()` and
/// `format!("{}", region)` work directly.
impl std::fmt::Display for Region {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
45
/// Converts a two-letter country code such as `"DE"` into `[b'D', b'E']`.
///
/// RIPE data should always carry two-letter codes. If the first two bytes are not ASCII
/// letters (input too short, non-ASCII text, digits or other junk), this falls back to
/// `*b"??"`, so the stored code is always plain ASCII. Longer inputs keep their first two
/// letters, as before.
fn cc2(country: &str) -> [u8; 2] {
    match country.as_bytes() {
        [a, b, ..] if a.is_ascii_alphabetic() && b.is_ascii_alphabetic() => [*a, *b],
        _ => *b"??",
    }
}
52
53// For display/testing convenience.
54impl GeoInfo {
55    pub fn country_code_str(&self) -> &str {
56        // Always valid for ASCII 2-letter codes; fallback if somehow invalid.
57        std::str::from_utf8(&self.country_code).unwrap_or("??")
58    }
59
60    pub fn region_enum(&self) -> Region {
61        match self.region {
62            1 => Region::EuropeanUnion,
63            2 => Region::EuropeNonEu,
64            3 => Region::EasternEurope,
65            4 => Region::Turkey,
66            5 => Region::MiddleEast,
67            6 => Region::NorthAfrica,
68            7 => Region::CentralAsia,
69            8 => Region::GulfStates,
70            _ => Region::Other,
71        }
72    }
73}
74
75
76pub struct GeoIpDb {
77    v4_ranges: Vec<(u32, u32, GeoInfo)>,
78    v6_ranges: Vec<(u128, u128, GeoInfo)>,
79}
80
/// Country codes of the 27 EU member states (membership as of 2025).
///
/// Used both for the `is_eu` flag and as the first check in `determine_region`.
const EU_COUNTRIES: &[&str] = &[
    "AT", "BE", "BG", "HR", "CY", "CZ", "DK", "EE", "FI",
    "FR", "DE", "GR", "HU", "IE", "IT", "LV", "LT", "LU",
    "MT", "NL", "PL", "PT", "RO", "SK", "SI", "ES", "SE",
];
87
// Include the generated data from build.rs.
//
// Judging by how `GeoIpDb::new` consumes it, this file is expected to define
// `IPV4_RANGES` and `IPV6_RANGES`: references to slices of `(start, end, country_code)`
// tuples — presumably `(u32, u32, &str)` and `(u128, u128, &str)`, with inclusive
// bounds. NOTE(review): build.rs is assumed to emit them sorted by start; confirm there.
include!(concat!(env!("OUT_DIR"), "/generated_data.rs"));
90
91impl GeoIpDb {
92    /// Create a new database with embedded RIPE data
93    pub fn new() -> Self {
94        let mut v4_ranges = Vec::with_capacity(IPV4_RANGES.len());
95        let mut v6_ranges = Vec::with_capacity(IPV6_RANGES.len());
96
97        // Process IPv4 ranges
98        for &(start, end, country) in IPV4_RANGES {
99            let is_eu = EU_COUNTRIES.contains(&country);
100            let region = determine_region(country);
101
102            let geo_info = GeoInfo {
103				country_code: cc2(country),
104				is_eu,
105				region: region as u8,
106			};
107
108            v4_ranges.push((start, end, geo_info));
109        }
110
111        // Process IPv6 ranges
112        for &(start, end, country) in IPV6_RANGES {
113            let is_eu = EU_COUNTRIES.contains(&country);
114            let region = determine_region(country);
115
116            let geo_info = GeoInfo {
117				country_code: cc2(country),
118				is_eu,
119				region: region as u8,
120			};
121
122            v6_ranges.push((start, end, geo_info));
123        }
124
125        // Data should already be sorted from build.rs, but let's be safe
126        //v4_ranges.sort_by_key(|r| r.0);
127        //v6_ranges.sort_by_key(|r| r.0);
128
129        GeoIpDb { v4_ranges, v6_ranges }
130    }
131	
132	/// Build a DB from RIPE delegated stats *content* (runtime).
133    pub fn from_ripe_delegated_str(content: &str) -> Self {
134        let parsed = crate::parse_ripe_delegated(content);
135
136        let mut v4_ranges: Vec<(u32, u32, GeoInfo)> = Vec::new();
137        let mut v6_ranges: Vec<(u128, u128, GeoInfo)> = Vec::new();
138
139        for r in parsed {
140            let is_eu = EU_COUNTRIES.contains(&r.country.as_str());
141            let region = determine_region(&r.country);
142
143            let geo = GeoInfo {
144                country_code: cc2(&r.country),
145                is_eu,
146                region: region as u8,
147            };
148
149            if let Some(v4) = r.start_v4 {
150                let start: u32 = v4.into();
151                let end = start.saturating_add((r.count as u32).saturating_sub(1));
152                v4_ranges.push((start, end, geo));
153            } else if let Some(v6) = r.start_v6 {
154                let start: u128 = v6.into();
155                let end = start.saturating_add(r.count.saturating_sub(1));
156                v6_ranges.push((start, end, geo));
157            }
158        }
159
160        v4_ranges.sort_by_key(|r| r.0);
161        v6_ranges.sort_by_key(|r| r.0);
162
163        GeoIpDb { v4_ranges, v6_ranges }
164    }
165
166    /// Load RIPE delegated stats data from a file at runtime.
167    pub fn from_ripe_delegated_file<P: AsRef<Path>>(path: P) -> io::Result<Self> {
168        let content = fs::read_to_string(path)?;
169        Ok(Self::from_ripe_delegated_str(&content))
170    }
171
172    /// Try to load from a cache file; if missing/unreadable, fall back to embedded data.
173    pub fn from_cache_or_embedded<P: AsRef<Path>>(cache_path: P) -> Self {
174        match Self::from_ripe_delegated_file(cache_path) {
175            Ok(db) => db,
176            Err(_) => Self::new(),
177        }
178    }
179
180    /// Look up an IPv4 address
181	#[inline]
182    pub fn lookup_v4(&self, ip: Ipv4Addr) -> Option<&GeoInfo> {
183		let ip_u32: u32 = ip.into();
184		
185		match self.v4_ranges.binary_search_by_key(&ip_u32, |&(start, _, _)| start) {
186			Ok(idx) => Some(&self.v4_ranges[idx].2),
187			Err(idx) => {
188				if idx > 0 {
189					let (start, end, geo) = &self.v4_ranges[idx - 1];
190					if ip_u32 >= *start && ip_u32 <= *end {
191						return Some(geo);
192					}
193				}
194				None
195			}
196		}
197	}
198
199    /// Look up an IPv6 address
200	#[inline]
201	pub fn lookup_v6(&self, ip: Ipv6Addr) -> Option<&GeoInfo> {
202		let ip_u128: u128 = ip.into();
203		let ranges = &self.v6_ranges;
204
205		if ranges.is_empty() {
206			return None;
207		}
208
209		// upper_bound: first index where start > ip
210		let mut lo: usize = 0;
211		let mut hi: usize = ranges.len();
212		while lo < hi {
213			let mid = lo + (hi - lo) / 2;
214			if ip_u128 < ranges[mid].0 {
215				hi = mid;
216			} else {
217				lo = mid + 1;
218			}
219		}
220
221		if lo == 0 {
222			return None;
223		}
224
225		let (start, end, geo) = &ranges[lo - 1];
226		if ip_u128 >= *start && ip_u128 <= *end {
227			Some(geo)
228		} else {
229			None
230		}
231	}
232
233    /// Look up any IP address (IPv4 or IPv6)
234    pub fn lookup(&self, ip: IpAddr) -> Option<&GeoInfo> {
235        match ip {
236            IpAddr::V4(v4) => self.lookup_v4(v4),
237            IpAddr::V6(v6) => self.lookup_v6(v6),
238        }
239    }
240
241    /// Convenience method: check if IP is in EU
242	#[inline]
243    pub fn is_eu(&self, ip: IpAddr) -> bool {
244        self.lookup(ip).map(|info| info.is_eu).unwrap_or(false)
245    }
246
247    /// Get statistics about the database
248    pub fn stats(&self) -> DbStats {
249        let total_v4_ranges = self.v4_ranges.len();
250        let total_v6_ranges = self.v6_ranges.len();
251        let eu_v4_ranges = self.v4_ranges.iter().filter(|(_, _, info)| info.is_eu).count();
252        let eu_v6_ranges = self.v6_ranges.iter().filter(|(_, _, info)| info.is_eu).count();
253
254        DbStats {
255            total_v4_ranges,
256            total_v6_ranges,
257            eu_v4_ranges,
258            eu_v6_ranges,
259            non_eu_v4_ranges: total_v4_ranges - eu_v4_ranges,
260            non_eu_v6_ranges: total_v6_ranges - eu_v6_ranges,
261        }
262    }
263}
264
#[cfg(feature = "download")]
impl GeoIpDb {
    /// Download RIPE delegated data from `url` and replace `cache_path` with it.
    ///
    /// The body is first written to a sibling `<file name>.tmp` file and fsynced, then
    /// renamed over `cache_path`. `fs::rename` replaces an existing destination, so the
    /// cache is never deleted ahead of time. On Unix this makes the swap atomic: readers
    /// see either the old cache or the new one, never a missing file.
    ///
    /// Returns the number of bytes written.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - `cache_path` has no file name;
    /// - the request fails or returns a non-success status;
    /// - the body is empty (the existing cache is left untouched);
    /// - any filesystem operation fails.
    ///
    /// The temporary file is removed if writing or renaming it fails.
    pub fn update_cache_from_url<P: AsRef<Path>>(cache_path: P, url: &str) -> io::Result<u64> {
        let cache_path = cache_path.as_ref();

        // Append ".tmp" to the full file name rather than replacing the extension, so the
        // temp path can never coincide with `cache_path` itself (e.g. "cache.tmp").
        let mut tmp_name = cache_path
            .file_name()
            .ok_or_else(|| {
                io::Error::new(io::ErrorKind::InvalidInput, "cache path has no file name")
            })?
            .to_os_string();
        tmp_name.push(".tmp");
        let tmp_path = cache_path.with_file_name(tmp_name);

        // Ensure parent dir exists
        if let Some(parent) = cache_path.parent() {
            fs::create_dir_all(parent)?;
        }

        // Download
        let resp = reqwest::blocking::get(url)
            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
            .error_for_status()
            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;

        let bytes = resp
            .bytes()
            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;

        // Never replace a (possibly good) cache with nothing.
        if bytes.is_empty() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "downloaded RIPE data is empty; keeping existing cache",
            ));
        }

        // The temp file lives next to the destination so the rename stays on one filesystem.
        let result = Self::write_synced(&tmp_path, &bytes)
            .and_then(|_| fs::rename(&tmp_path, cache_path));
        if let Err(e) = result {
            // Best-effort cleanup; the original error is what the caller needs to see.
            let _ = fs::remove_file(&tmp_path);
            return Err(e);
        }

        Ok(bytes.len() as u64)
    }

    /// Convenience: update from RIPE "extended latest".
    pub fn update_cache<P: AsRef<Path>>(cache_path: P) -> io::Result<u64> {
        Self::update_cache_from_url(cache_path, RIPE_EXTENDED_LATEST_URL)
    }

    /// Writes `data` to a freshly created file at `path` and fsyncs it to disk.
    fn write_synced(path: &Path, data: &[u8]) -> io::Result<()> {
        use std::io::Write;

        let mut file = fs::File::create(path)?;
        file.write_all(data)?;
        file.sync_all()
    }
}
312
313impl Default for GeoIpDb {
314    fn default() -> Self {
315        Self::new()
316    }
317}
318
/// Summary counts returned by `GeoIpDb::stats`.
///
/// For each address family, `total = eu + non_eu`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct DbStats {
    /// Number of IPv4 ranges in the database.
    pub total_v4_ranges: usize,
    /// Number of IPv6 ranges in the database.
    pub total_v6_ranges: usize,
    /// IPv4 ranges whose country is an EU member state.
    pub eu_v4_ranges: usize,
    /// IPv6 ranges whose country is an EU member state.
    pub eu_v6_ranges: usize,
    /// IPv4 ranges whose country is not an EU member state.
    pub non_eu_v4_ranges: usize,
    /// IPv6 ranges whose country is not an EU member state.
    pub non_eu_v6_ranges: usize,
}
328
329fn determine_region(country_code: &str) -> Region {
330    if EU_COUNTRIES.contains(&country_code) {
331        Region::EuropeanUnion
332    } else {
333        match country_code {
334            "GB" | "NO" | "CH" | "IS" | "LI" => Region::EuropeNonEu,
335            "RU" | "UA" | "BY" | "MD" => Region::EasternEurope,
336            "TR" => Region::Turkey,
337            "IL" | "PS" => Region::MiddleEast,
338            "EG" | "TN" | "MA" | "DZ" => Region::NorthAfrica,
339            "KZ" | "UZ" | "TM" | "KG" | "TJ" => Region::CentralAsia,
340            "AE" | "SA" | "QA" | "KW" | "BH" | "OM" => Region::GulfStates,
341            _ => Region::Other,
342        }
343    }
344}
345
#[cfg(test)]
mod tests {
    use super::*;

    /// Prints the embedded database stats and checks that IPv4 data was embedded.
    #[test]
    fn test_embedded_db() {
        let db = GeoIpDb::new();

        let stats = db.stats();
        println!("\n📊 Embedded Database Stats:");
        println!(
            "  IPv4 ranges: {} (EU: {}, non-EU: {})",
            stats.total_v4_ranges, stats.eu_v4_ranges, stats.non_eu_v4_ranges
        );
        println!(
            "  IPv6 ranges: {} (EU: {}, non-EU: {})",
            stats.total_v6_ranges, stats.eu_v6_ranges, stats.non_eu_v6_ranges
        );

        assert!(stats.total_v4_ranges > 0, "Should have IPv4 ranges");
    }

    #[test]
    fn test_lookup_german_ipv4() {
        let db = GeoIpDb::new();
        let ip: Ipv4Addr = "46.4.0.1".parse().unwrap();

        let info = db.lookup_v4(ip).expect("German IP should be found");
        assert_eq!(info.country_code_str(), "DE");
        assert!(info.is_eu);
    }

    #[test]
    fn test_lookup_german_ipv6() {
        let db = GeoIpDb::new();
        // Example German IPv6 address (2a00::/12 is RIPE NCC address space)
        let ip: Ipv6Addr = "2a01:4f8::1".parse().unwrap();

        if let Some(info) = db.lookup_v6(ip) {
            println!("Found IPv6: {} in {}", ip, info.country_code_str());
            // Just verify we can look it up, actual country depends on data
        }
    }

    #[test]
    fn test_lookup_any_ip() {
        let db = GeoIpDb::new();

        // Test with IPv4
        let ipv4: IpAddr = "46.4.0.1".parse().unwrap();
        if let Some(info) = db.lookup(ipv4) {
            assert_eq!(info.country_code_str(), "DE");
        }

        // Test with IPv6
        let ipv6: IpAddr = "2a01:4f8::1".parse().unwrap();
        let _ = db.lookup(ipv6);
    }

    #[test]
    fn test_is_eu_method() {
        let db = GeoIpDb::new();

        // Test IPv4
        let ipv4: IpAddr = "46.4.0.1".parse().unwrap();
        if db.lookup(ipv4).is_some() {
            assert!(db.is_eu(ipv4));
        }
    }

    /// Independent of the embedded data: a 256-address block covers exactly
    /// `start ..= start + 255`.
    #[test]
    fn test_from_ripe_delegated_str_range_bounds() {
        let db = GeoIpDb::from_ripe_delegated_str(
            "ripencc|DE|ipv4|46.4.0.0|256|20250101|allocated\n",
        );

        for ip in ["46.4.0.0", "46.4.0.255"].iter() {
            let ip: Ipv4Addr = ip.parse().unwrap();
            let info = db.lookup_v4(ip).expect("address inside the delegated block");
            assert_eq!(info.country_code_str(), "DE");
        }
        for ip in ["46.3.255.255", "46.4.1.0"].iter() {
            let ip: Ipv4Addr = ip.parse().unwrap();
            assert!(db.lookup_v4(ip).is_none(), "{} is outside the block", ip);
        }
    }

    /// Region classification and the `GeoInfo::region` byte round-trip.
    #[test]
    fn test_region_mapping() {
        assert_eq!(determine_region("DE"), Region::EuropeanUnion);
        assert_eq!(determine_region("GB"), Region::EuropeNonEu);
        assert_eq!(determine_region("TR"), Region::Turkey);
        assert_eq!(determine_region("US"), Region::Other);

        let all = [
            Region::EuropeanUnion,
            Region::EuropeNonEu,
            Region::EasternEurope,
            Region::Turkey,
            Region::MiddleEast,
            Region::NorthAfrica,
            Region::CentralAsia,
            Region::GulfStates,
            Region::Other,
        ];
        for &region in all.iter() {
            let info = GeoInfo { country_code: *b"XX", is_eu: false, region: region as u8 };
            assert_eq!(info.region_enum(), region);
        }
    }

    /// Serves `body` once over HTTP on an ephemeral localhost port and returns its URL.
    #[cfg(feature = "download")]
    fn serve_once(body: &'static str) -> String {
        use std::io::{Read, Write};
        use std::net::TcpListener;

        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();

        std::thread::spawn(move || {
            let (mut stream, _) = listener.accept().unwrap();

            // read request (ignore contents)
            let mut buf = [0u8; 1024];
            let _ = stream.read(&mut buf);

            let resp = format!(
                "HTTP/1.1 200 OK\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
                body.len(),
                body
            );
            let _ = stream.write_all(resp.as_bytes());
            let _ = stream.flush();
        });

        format!("http://{}", addr)
    }

    #[test]
    #[cfg(feature = "download")]
    fn test_update_cache_and_load() {
        // Minimal delegated content:
        // - one IPv4 block: 46.4.0.0/24 (256 addrs)
        // - one IPv6 block: 2a01:4f8::/32
        let delegated = concat!(
            "# comment\n",
            "2|ripencc|20250101|0000|summary|whatever\n",
            "ripencc|DE|ipv4|46.4.0.0|256|20250101|allocated\n",
            "ripencc|DE|ipv6|2a01:4f8::|32|20250101|allocated\n",
        );

        let url = serve_once(delegated);

        let dir = tempfile::tempdir().unwrap();
        let cache_path = dir.path().join("ripe-cache.txt");

        let bytes = GeoIpDb::update_cache_from_url(&cache_path, &url).unwrap();
        assert!(bytes > 0);
        assert!(cache_path.exists());

        let db = GeoIpDb::from_ripe_delegated_file(&cache_path).unwrap();

        let ip: IpAddr = "46.4.0.1".parse().unwrap();
        let info = db.lookup(ip).expect("should find 46.4.0.1");
        assert_eq!(info.country_code_str(), "DE");
        assert!(info.is_eu);
    }

    #[test]
    #[cfg(feature = "download")]
    fn test_update_cache_replaces_existing_file() {
        let old = "ripencc|FR|ipv4|46.4.0.0|256|20250101|allocated\n";
        let new = "ripencc|DE|ipv4|46.4.0.0|256|20250101|allocated\n";

        let url = serve_once(new);

        let dir = tempfile::tempdir().unwrap();
        let cache_path = dir.path().join("ripe-cache.txt");

        std::fs::write(&cache_path, old).unwrap();

        GeoIpDb::update_cache_from_url(&cache_path, &url).unwrap();

        let db = GeoIpDb::from_ripe_delegated_file(&cache_path).unwrap();
        let info = db.lookup("46.4.0.1".parse().unwrap()).unwrap();
        assert_eq!(info.country_code_str(), "DE");
    }

    #[test]
    #[ignore]
    #[cfg(feature = "download")]
    fn smoke_test_real_ripe_download_and_lookup() {
        // Portable location instead of a hard-coded /tmp path.
        let cache = std::env::temp_dir().join("ripe-cache.txt");

        // Download real RIPE data
        let bytes = GeoIpDb::update_cache(&cache).unwrap();
        assert!(bytes > 1_000_000, "too small, download probably failed");

        // Load from cache
        let db = GeoIpDb::from_ripe_delegated_file(&cache).unwrap();

        // Known Hetzner range is commonly DE
        let ip: IpAddr = "88.198.0.1".parse().unwrap();
        let info = db.lookup(ip).unwrap();
        println!("88.198.0.1 -> {}", info.country_code_str());
    }
}