1use crate::{
2 CzError,
3 common::{
4 DbMeta, decode_aes_key, decode_region_from_bytes, parse_meta_from_bytes, read_hyper_header,
5 compare_bytes,
6 },
7};
8use memmap2::{Mmap, MmapOptions};
9use std::{
10 fs::File,
11 net::IpAddr,
12};
13
14#[derive(Debug)]
15struct MmapBytes {
16 mmap: Mmap,
17 offset: usize,
18}
19
20impl MmapBytes {
21 fn as_slice(&self) -> &[u8] {
22 &self.mmap[self.offset..]
23 }
24}
25
/// Database reader that answers lookups directly from a memory-mapped file.
///
/// Built by [`CzdbMmap::open`]; queried with [`CzdbMmap::search`] and
/// [`CzdbMmap::search_many`].
#[derive(Debug)]
pub struct CzdbMmap {
    /// Mapped file bytes, exposed starting at the data offset computed in `open`.
    bindata: MmapBytes,
    /// Metadata parsed from the data section by `parse_meta_from_bytes`.
    meta: DbMeta,
}
34
35impl CzdbMmap {
36 pub fn open(db_path: &str, key: &str) -> Result<Self, CzError> {
40 let key_bytes = decode_aes_key(key)?;
41 let mut file = File::open(db_path)?;
42 let header = read_hyper_header(&mut file, &key_bytes)?;
43 let data_offset = (12 + header.padding_size + header.encrypted_block_size) as usize;
44 let mmap = unsafe { MmapOptions::new().map(&file)? };
45 if data_offset > mmap.len() {
46 return Err(CzError::DatabaseFileCorrupted);
47 }
48 let file_size_total = file.metadata()?.len();
49 let bindata = MmapBytes { mmap, offset: data_offset };
50 let meta = parse_meta_from_bytes(
51 bindata.as_slice(),
52 file_size_total,
53 header.padding_size,
54 header.encrypted_block_size,
55 &key_bytes,
56 )?;
57
58 Ok(Self { bindata, meta })
59 }
60
61 pub fn search(&self, ip: IpAddr) -> Option<String> {
65 if !self.meta.db_type.compare(&ip) {
66 return None;
67 }
68 let mut ip_bytes = [0u8; 16];
69 match ip {
70 IpAddr::V4(ip) => ip_bytes[..4].copy_from_slice(&ip.octets()),
71 IpAddr::V6(ip) => ip_bytes.copy_from_slice(&ip.octets()),
72 }
73
74 let (sptr, eptr) = self.meta.search_in_header(&ip_bytes)?;
75 let sptr = sptr as usize;
76 let eptr = eptr as usize;
77 if eptr < sptr {
78 return None;
79 }
80
81 let bindata = self.bindata.as_slice();
82 let ip_len = self.meta.db_type.bytes_len();
83 let blen = self.meta.db_type.index_block_len();
84 let block_len = eptr - sptr;
85 let max_len = sptr.saturating_add(block_len).saturating_add(blen);
86 if max_len > bindata.len() {
87 return None;
88 }
89
90 let mut l = 0usize;
91 let mut h = block_len / blen;
92 while l <= h {
93 let m = (l + h) >> 1;
94 let p = sptr + m * blen;
95 let start_ip = &bindata[p..p + ip_len];
96 let end_ip = &bindata[p + ip_len..p + ip_len * 2];
97 let cmp_start = compare_bytes(&ip_bytes, start_ip, ip_len);
98 let cmp_end = compare_bytes(&ip_bytes, end_ip, ip_len);
99
100 if cmp_start != std::cmp::Ordering::Less && cmp_end != std::cmp::Ordering::Greater {
101 let data_ptr = u32::from_le_bytes([
102 bindata[p + ip_len * 2],
103 bindata[p + ip_len * 2 + 1],
104 bindata[p + ip_len * 2 + 2],
105 bindata[p + ip_len * 2 + 3],
106 ]) as usize;
107 let data_len = bindata[p + ip_len * 2 + 4] as usize;
108 if data_ptr + data_len > bindata.len() {
109 return None;
110 }
111 return decode_region_from_bytes(
112 &bindata[data_ptr..data_ptr + data_len],
113 &self.meta,
114 );
115 } else if cmp_start == std::cmp::Ordering::Less {
116 if m == 0 {
117 break;
118 }
119 h = m - 1;
120 } else {
121 l = m + 1;
122 }
123 }
124
125 None
126 }
127
128 pub fn search_many(&self, ips: &[IpAddr]) -> Vec<Option<String>> {
132 ips.iter().map(|ip| self.search(*ip)).collect()
133 }
134}