czdb/
mmap.rs

1use crate::{
2    CzError,
3    common::{
4        DbMeta, decode_aes_key, decode_region_from_bytes, parse_meta_from_bytes, read_hyper_header,
5        compare_bytes,
6    },
7};
8use memmap2::{Mmap, MmapOptions};
9use std::{
10    fs::File,
11    net::IpAddr,
12};
13
/// A memory-mapped file together with the byte offset at which the
/// payload of interest begins.
#[derive(Debug)]
struct MmapBytes {
    /// Mapping of the whole database file.
    mmap: Mmap,
    /// Number of leading bytes of `mmap` that `as_slice` skips.
    offset: usize,
}
19
20impl MmapBytes {
21    fn as_slice(&self) -> &[u8] {
22        &self.mmap[self.offset..]
23    }
24}
25
/// Mmap-backed CZDB searcher.
///
/// Holds the memory-mapped database bytes (positioned at the data offset
/// computed in [`CzdbMmap::open`]) together with the metadata parsed from
/// them.
#[derive(Debug)]
pub struct CzdbMmap {
    /// Mapped database bytes; lookups read from `bindata.as_slice()`.
    bindata: MmapBytes,
    /// Metadata produced by `parse_meta_from_bytes` when the file was opened.
    meta: DbMeta,
}
34
35impl CzdbMmap {
36    /// Open a database file using memory mapping.
37    ///
38    /// 使用内存映射打开数据库文件。
39    pub fn open(db_path: &str, key: &str) -> Result<Self, CzError> {
40        let key_bytes = decode_aes_key(key)?;
41        let mut file = File::open(db_path)?;
42        let header = read_hyper_header(&mut file, &key_bytes)?;
43        let data_offset = (12 + header.padding_size + header.encrypted_block_size) as usize;
44        let mmap = unsafe { MmapOptions::new().map(&file)? };
45        if data_offset > mmap.len() {
46            return Err(CzError::DatabaseFileCorrupted);
47        }
48        let file_size_total = file.metadata()?.len();
49        let bindata = MmapBytes { mmap, offset: data_offset };
50        let meta = parse_meta_from_bytes(
51            bindata.as_slice(),
52            file_size_total,
53            header.padding_size,
54            header.encrypted_block_size,
55            &key_bytes,
56        )?;
57
58        Ok(Self { bindata, meta })
59    }
60
61    /// Search a single IP address.
62    ///
63    /// 查询指定 IP 地址。
64    pub fn search(&self, ip: IpAddr) -> Option<String> {
65        if !self.meta.db_type.compare(&ip) {
66            return None;
67        }
68        let mut ip_bytes = [0u8; 16];
69        match ip {
70            IpAddr::V4(ip) => ip_bytes[..4].copy_from_slice(&ip.octets()),
71            IpAddr::V6(ip) => ip_bytes.copy_from_slice(&ip.octets()),
72        }
73
74        let (sptr, eptr) = self.meta.search_in_header(&ip_bytes)?;
75        let sptr = sptr as usize;
76        let eptr = eptr as usize;
77        if eptr < sptr {
78            return None;
79        }
80
81        let bindata = self.bindata.as_slice();
82        let ip_len = self.meta.db_type.bytes_len();
83        let blen = self.meta.db_type.index_block_len();
84        let block_len = eptr - sptr;
85        let max_len = sptr.saturating_add(block_len).saturating_add(blen);
86        if max_len > bindata.len() {
87            return None;
88        }
89
90        let mut l = 0usize;
91        let mut h = block_len / blen;
92        while l <= h {
93            let m = (l + h) >> 1;
94            let p = sptr + m * blen;
95            let start_ip = &bindata[p..p + ip_len];
96            let end_ip = &bindata[p + ip_len..p + ip_len * 2];
97            let cmp_start = compare_bytes(&ip_bytes, start_ip, ip_len);
98            let cmp_end = compare_bytes(&ip_bytes, end_ip, ip_len);
99
100            if cmp_start != std::cmp::Ordering::Less && cmp_end != std::cmp::Ordering::Greater {
101                let data_ptr = u32::from_le_bytes([
102                    bindata[p + ip_len * 2],
103                    bindata[p + ip_len * 2 + 1],
104                    bindata[p + ip_len * 2 + 2],
105                    bindata[p + ip_len * 2 + 3],
106                ]) as usize;
107                let data_len = bindata[p + ip_len * 2 + 4] as usize;
108                if data_ptr + data_len > bindata.len() {
109                    return None;
110                }
111                return decode_region_from_bytes(
112                    &bindata[data_ptr..data_ptr + data_len],
113                    &self.meta,
114                );
115            } else if cmp_start == std::cmp::Ordering::Less {
116                if m == 0 {
117                    break;
118                }
119                h = m - 1;
120            } else {
121                l = m + 1;
122            }
123        }
124
125        None
126    }
127
128    /// Search a small batch of IP addresses.
129    ///
130    /// 批量查询 IP(小批量)。
131    pub fn search_many(&self, ips: &[IpAddr]) -> Vec<Option<String>> {
132        ips.iter().map(|ip| self.search(*ip)).collect()
133    }
134}