czdb/
disk.rs

1use crate::{
2    CzError,
3    common::{
4        DbMeta, DbType, decode_aes_key, compare_bytes, decode_region_from_bytes, parse_meta_from_file,
5        read_hyper_header,
6    },
7};
8use std::{
9    fs::File,
10    io::{Read, Seek, SeekFrom},
11    net::IpAddr,
12};
13
/// Disk-backed CZDB searcher.
///
/// Lookups do not load the data section into memory. Each query seeks within
/// the open database file and reads only the index slice and region bytes it
/// needs.
#[derive(Debug)]
pub struct CzdbDisk {
    /// Open handle to the database file; every lookup seeks and reads from it.
    file: File,
    /// Absolute file offset of the data section. Index pointers from the header
    /// lookup and region data pointers are both added to this before seeking.
    data_offset: u64,
    /// Metadata parsed in `open`. It supplies the database type and the header
    /// lookup used to narrow the index range for a query.
    meta: DbMeta,
}
23
24impl CzdbDisk {
25    /// Open a database file for disk-backed queries.
26    ///
27    /// 打开数据库文件用于磁盘查询。
28    pub fn open(db_path: &str, key: &str) -> Result<Self, CzError> {
29        let key_bytes = decode_aes_key(key)?;
30        let mut file = File::open(db_path)?;
31        let header = read_hyper_header(&mut file, &key_bytes)?;
32        let data_offset = (12 + header.padding_size + header.encrypted_block_size) as u64;
33        let file_size_total = file.metadata()?.len();
34        let meta = parse_meta_from_file(
35            &mut file,
36            data_offset,
37            file_size_total,
38            header.padding_size,
39            header.encrypted_block_size,
40            &key_bytes,
41        )?;
42
43        Ok(Self {
44            file,
45            data_offset,
46            meta,
47        })
48    }
49
50    /// Search a single IP address.
51    ///
52    /// 查询指定 IP 地址。
53    pub fn search(&mut self, ip: IpAddr) -> Option<String> {
54        if !self.meta.db_type.compare(&ip) {
55            return None;
56        }
57        let mut ip_bytes = [0u8; 16];
58        match ip {
59            IpAddr::V4(ip) => ip_bytes[..4].copy_from_slice(&ip.octets()),
60            IpAddr::V6(ip) => ip_bytes.copy_from_slice(&ip.octets()),
61        }
62
63        let (sptr, eptr) = self.meta.search_in_header(&ip_bytes)?;
64        let sptr = sptr as usize;
65        let eptr = eptr as usize;
66        if eptr < sptr {
67            return None;
68        }
69
70        let ip_len = self.meta.db_type.bytes_len();
71        let blen = self.meta.db_type.index_block_len();
72        let block_len = eptr - sptr;
73        let read_len = block_len + blen;
74        let mut index_buffer = vec![0u8; read_len];
75        if self
76            .file
77            .seek(SeekFrom::Start(self.data_offset + sptr as u64))
78            .is_err()
79        {
80            return None;
81        }
82        if self.file.read_exact(&mut index_buffer).is_err() {
83            return None;
84        }
85
86        let mut l = 0usize;
87        let mut h = block_len / blen;
88        while l <= h {
89            let m = (l + h) >> 1;
90            let p = m * blen;
91            let start_ip = &index_buffer[p..p + ip_len];
92            let end_ip = &index_buffer[p + ip_len..p + ip_len * 2];
93            let cmp_start = compare_bytes(&ip_bytes, start_ip, ip_len);
94            let cmp_end = compare_bytes(&ip_bytes, end_ip, ip_len);
95
96            if cmp_start != std::cmp::Ordering::Less && cmp_end != std::cmp::Ordering::Greater {
97                let data_ptr = u32::from_le_bytes([
98                    index_buffer[p + ip_len * 2],
99                    index_buffer[p + ip_len * 2 + 1],
100                    index_buffer[p + ip_len * 2 + 2],
101                    index_buffer[p + ip_len * 2 + 3],
102                ]) as usize;
103                let data_len = index_buffer[p + ip_len * 2 + 4] as usize;
104                if data_ptr == 0 || data_len == 0 {
105                    return None;
106                }
107                let mut region_bytes = vec![0u8; data_len];
108                if self
109                    .file
110                    .seek(SeekFrom::Start(self.data_offset + data_ptr as u64))
111                    .is_err()
112                {
113                    return None;
114                }
115                if self.file.read_exact(&mut region_bytes).is_err() {
116                    return None;
117                }
118                return decode_region_from_bytes(&region_bytes, &self.meta);
119            } else if cmp_start == std::cmp::Ordering::Less {
120                if m == 0 {
121                    break;
122                }
123                h = m - 1;
124            } else {
125                l = m + 1;
126            }
127        }
128
129        None
130    }
131
132    /// Search a small batch of IP addresses.
133    ///
134    /// 批量查询 IP(小批量)。
135    pub fn search_many(&mut self, ips: &[IpAddr]) -> Vec<Option<String>> {
136        ips.iter().map(|ip| self.search(*ip)).collect()
137    }
138
139    /// Returns the database IP version.
140    ///
141    /// 返回数据库类型(IPv4 或 IPv6)。
142    pub fn db_type(&self) -> DbType {
143        self.meta.db_type
144    }
145}