czdb/
memory.rs

1use crate::{
2    CzError,
3    common::{
4        DbMeta, DbType, decode_aes_key, decode_region_from_bytes, parse_meta_from_bytes,
5        read_hyper_header, compare_bytes,
6    },
7};
8use std::{
9    collections::HashMap,
10    fs::File,
11    io::{Cursor, Read},
12    net::IpAddr,
13};
14
15#[derive(Debug)]
16struct MemoryIndex {
17    entries_v4: Vec<IndexEntryV4>,
18    entries_v6: Vec<IndexEntryV6>,
19    regions: RegionPool,
20}
21
/// One IPv4 index record: the inclusive address range `start_ip..=end_ip` mapped to a region.
#[derive(Debug)]
struct IndexEntryV4 {
    /// First address of the range, as a big-endian-decoded `u32`.
    start_ip: u32,
    /// Last address of the range (inclusive), as a big-endian-decoded `u32`.
    end_ip: u32,
    /// Position of this range's region text in the region pool.
    region_id: usize,
}
28
/// One IPv6 index record: an inclusive range of raw 16-byte addresses mapped to a region.
#[derive(Debug)]
struct IndexEntryV6 {
    /// First address of the range, in network byte order.
    start_ip: [u8; 16],
    /// Last address of the range (inclusive), in network byte order.
    end_ip: [u8; 16],
    /// Position of this range's region text in the region pool.
    region_id: usize,
}
35
/// Byte range `start..start + len` of one region string inside the pooled text.
#[derive(Debug)]
struct RegionSpan {
    /// Byte offset where the region text begins.
    start: usize,
    /// Length of the region text in bytes.
    len: usize,
}
41
42#[derive(Debug)]
43struct RegionPool {
44    data: Box<str>,
45    spans: Vec<RegionSpan>,
46}
47
48impl RegionPool {
49    fn get(&self, region_id: usize) -> &str {
50        let span = &self.spans[region_id];
51        &self.data[span.start..span.start + span.len]
52    }
53}
54
55/// In-memory CZDB searcher with a prebuilt index and string pool.
56///
57/// 预构建索引与字符串池的内存 CZDB 查询器。
58#[derive(Debug)]
59pub struct CzdbMemory {
60    meta: DbMeta,
61    memory_index: MemoryIndex,
62}
63
64impl CzdbMemory {
65    /// Open a database file and build in-memory indices.
66    ///
67    /// 打开数据库文件并构建内存索引。
68    pub fn open(db_path: &str, key: &str) -> Result<Self, CzError> {
69        let mut file = File::open(db_path)?;
70        let mut data = Vec::new();
71        file.read_to_end(&mut data)?;
72        Self::from_bytes(data, key)
73    }
74
75    /// Build from raw bytes and construct in-memory indices.
76    ///
77    /// 从原始字节构建并生成内存索引。
78    pub fn from_bytes(data: Vec<u8>, key: &str) -> Result<Self, CzError> {
79        let key_bytes = decode_aes_key(key)?;
80        let mut cursor = Cursor::new(&data);
81        let header = read_hyper_header(&mut cursor, &key_bytes)?;
82        let data_offset = (12 + header.padding_size + header.encrypted_block_size) as usize;
83        if data_offset > data.len() {
84            return Err(CzError::DatabaseFileCorrupted);
85        }
86        let file_size_total = data.len() as u64;
87        let meta = parse_meta_from_bytes(
88            &data[data_offset..],
89            file_size_total,
90            header.padding_size,
91            header.encrypted_block_size,
92            &key_bytes,
93        )?;
94        let memory_index = build_memory_index(&data[data_offset..], &meta)?;
95
96        Ok(Self {
97            meta,
98            memory_index,
99        })
100    }
101
102    /// Search a single IP address.
103    ///
104    /// 查询指定 IP 地址。
105    pub fn search(&self, ip: IpAddr) -> Option<String> {
106        self.search_ref(ip).map(str::to_string)
107    }
108
109    /// Search a single IP address and return a borrowed string.
110    ///
111    /// 查询指定 IP 并返回借用字符串。
112    pub fn search_ref(&self, ip: IpAddr) -> Option<&str> {
113        if !self.meta.db_type.compare(&ip) {
114            return None;
115        }
116        match ip {
117            IpAddr::V4(ip) => {
118                if self.memory_index.entries_v4.is_empty() {
119                    return None;
120                }
121                let ip_num = u32::from_be_bytes(ip.octets());
122                let mut l = 0usize;
123                let mut h = self.memory_index.entries_v4.len() - 1;
124                while l <= h {
125                    let m = (l + h) >> 1;
126                    let entry = &self.memory_index.entries_v4[m];
127                    if ip_num >= entry.start_ip && ip_num <= entry.end_ip {
128                        return Some(self.memory_index.regions.get(entry.region_id));
129                    } else if ip_num < entry.start_ip {
130                        if m == 0 {
131                            break;
132                        }
133                        h = m - 1;
134                    } else {
135                        l = m + 1;
136                    }
137                }
138                None
139            }
140            IpAddr::V6(ip) => {
141                if self.memory_index.entries_v6.is_empty() {
142                    return None;
143                }
144                let mut ip_bytes = [0u8; 16];
145                ip_bytes.copy_from_slice(&ip.octets());
146                let mut l = 0usize;
147                let mut h = self.memory_index.entries_v6.len() - 1;
148                while l <= h {
149                    let m = (l + h) >> 1;
150                    let entry = &self.memory_index.entries_v6[m];
151                    let cmp_start = compare_bytes(&ip_bytes, &entry.start_ip, 16);
152                    let cmp_end = compare_bytes(&ip_bytes, &entry.end_ip, 16);
153                    if cmp_start != std::cmp::Ordering::Less
154                        && cmp_end != std::cmp::Ordering::Greater
155                    {
156                        return Some(self.memory_index.regions.get(entry.region_id));
157                    } else if cmp_start == std::cmp::Ordering::Less {
158                        if m == 0 {
159                            break;
160                        }
161                        h = m - 1;
162                    } else {
163                        l = m + 1;
164                    }
165                }
166                None
167            }
168        }
169    }
170
171    /// Search a small batch of IP addresses.
172    ///
173    /// 批量查询 IP(小批量)。
174    pub fn search_many(&self, ips: &[IpAddr]) -> Vec<Option<String>> {
175        ips.iter().map(|ip| self.search(*ip)).collect()
176    }
177
178    /// Search a batch of IP addresses and return borrowed strings.
179    ///
180    /// 批量查询 IP 并返回借用字符串。
181    pub fn search_many_ref<'a>(&'a self, ips: &[IpAddr]) -> Vec<Option<&'a str>> {
182        ips.iter().map(|ip| self.search_ref(*ip)).collect()
183    }
184
185    /// Search a large batch by sorting and scanning.
186    ///
187    /// 对大批量 IP 进行排序后扫描查询。
188    pub fn search_many_scan<'a>(&'a self, ips: &[IpAddr]) -> Vec<Option<&'a str>> {
189        let mut results = vec![None; ips.len()];
190        let mut v4 = Vec::new();
191        let mut v6 = Vec::new();
192        for (idx, ip) in ips.iter().copied().enumerate() {
193            match ip {
194                IpAddr::V4(ipv4) => v4.push((u32::from_be_bytes(ipv4.octets()), idx)),
195                IpAddr::V6(ipv6) => v6.push((ipv6.octets(), idx)),
196            }
197        }
198
199        if !v4.is_empty() && !self.memory_index.entries_v4.is_empty() {
200            v4.sort_unstable_by_key(|(ip, _)| *ip);
201            let mut entry_idx = 0usize;
202            for (ip_num, original_idx) in v4 {
203                while entry_idx < self.memory_index.entries_v4.len()
204                    && self.memory_index.entries_v4[entry_idx].end_ip < ip_num
205                {
206                    entry_idx += 1;
207                }
208                if entry_idx >= self.memory_index.entries_v4.len() {
209                    break;
210                }
211                let entry = &self.memory_index.entries_v4[entry_idx];
212                if ip_num >= entry.start_ip && ip_num <= entry.end_ip {
213                    results[original_idx] = Some(self.memory_index.regions.get(entry.region_id));
214                }
215            }
216        }
217
218        if !v6.is_empty() && !self.memory_index.entries_v6.is_empty() {
219            v6.sort_unstable_by(|(a, _), (b, _)| compare_bytes(a, b, 16));
220            let mut entry_idx = 0usize;
221            for (ip_bytes, original_idx) in v6 {
222                while entry_idx < self.memory_index.entries_v6.len()
223                    && compare_bytes(&self.memory_index.entries_v6[entry_idx].end_ip, &ip_bytes, 16)
224                        == std::cmp::Ordering::Less
225                {
226                    entry_idx += 1;
227                }
228                if entry_idx >= self.memory_index.entries_v6.len() {
229                    break;
230                }
231                let entry = &self.memory_index.entries_v6[entry_idx];
232                let cmp_start = compare_bytes(&ip_bytes, &entry.start_ip, 16);
233                let cmp_end = compare_bytes(&ip_bytes, &entry.end_ip, 16);
234                if cmp_start != std::cmp::Ordering::Less
235                    && cmp_end != std::cmp::Ordering::Greater
236                {
237                    results[original_idx] = Some(self.memory_index.regions.get(entry.region_id));
238                }
239            }
240        }
241
242        results
243    }
244
245    /// Returns the database IP version.
246    ///
247    /// 返回数据库类型(IPv4 或 IPv6)。
248    pub fn db_type(&self) -> DbType {
249        self.meta.db_type
250    }
251}
252
253fn build_memory_index(bindata: &[u8], meta: &DbMeta) -> Result<MemoryIndex, CzError> {
254    let ip_len = meta.db_type.bytes_len();
255    let blen = meta.db_type.index_block_len();
256    let start = meta.start_index as usize;
257    let end = meta.end_index as usize;
258
259    if end < start {
260        return Err(CzError::DatabaseFileCorrupted);
261    }
262    if end + blen > bindata.len() {
263        return Err(CzError::DatabaseFileCorrupted);
264    }
265
266    let total_blocks = (end - start) / blen + 1;
267    let mut entries_v4 = Vec::with_capacity(total_blocks);
268    let mut entries_v6 = Vec::with_capacity(total_blocks);
269    let mut regions = Vec::<RegionSpan>::new();
270    let mut region_text = String::new();
271    let mut region_cache = HashMap::<(usize, usize), usize>::new();
272
273    let mut p = start;
274    while p <= end {
275        if p + blen > bindata.len() {
276            return Err(CzError::DatabaseFileCorrupted);
277        }
278        let mut start_ip_bytes = [0u8; 16];
279        let mut end_ip_bytes = [0u8; 16];
280        start_ip_bytes[..ip_len].copy_from_slice(&bindata[p..p + ip_len]);
281        end_ip_bytes[..ip_len].copy_from_slice(&bindata[p + ip_len..p + ip_len * 2]);
282        let data_ptr = u32::from_le_bytes([
283            bindata[p + ip_len * 2],
284            bindata[p + ip_len * 2 + 1],
285            bindata[p + ip_len * 2 + 2],
286            bindata[p + ip_len * 2 + 3],
287        ]) as usize;
288        let data_len = bindata[p + ip_len * 2 + 4] as usize;
289
290        let region_id = match region_cache.get(&(data_ptr, data_len)) {
291            Some(id) => *id,
292            None => {
293                if data_ptr + data_len > bindata.len() {
294                    return Err(CzError::DatabaseFileCorrupted);
295                }
296                let region = decode_region_from_bytes(
297                    &bindata[data_ptr..data_ptr + data_len],
298                    meta,
299                )
300                .ok_or(CzError::DatabaseFileCorrupted)?;
301                let start_offset = region_text.len();
302                region_text.push_str(&region);
303                let len = region.len();
304                let id = regions.len();
305                regions.push(RegionSpan {
306                    start: start_offset,
307                    len,
308                });
309                region_cache.insert((data_ptr, data_len), id);
310                id
311            }
312        };
313
314        if meta.db_type == DbType::Ipv4 {
315            let start_ip = u32::from_be_bytes(start_ip_bytes[..4].try_into().unwrap());
316            let end_ip = u32::from_be_bytes(end_ip_bytes[..4].try_into().unwrap());
317            entries_v4.push(IndexEntryV4 {
318                start_ip,
319                end_ip,
320                region_id,
321            });
322        } else {
323            entries_v6.push(IndexEntryV6 {
324                start_ip: start_ip_bytes,
325                end_ip: end_ip_bytes,
326                region_id,
327            });
328        }
329
330        p += blen;
331    }
332
333    Ok(MemoryIndex {
334        entries_v4,
335        entries_v6,
336        regions: RegionPool {
337            data: region_text.into_boxed_str(),
338            spans: regions,
339        },
340    })
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use rmpv::{Value, encode::write_value};
    use std::net::{Ipv4Addr, Ipv6Addr};

    /// Encodes a region record the way the fixture stores it: a msgpack integer `0` followed by
    /// the region name as a msgpack string.
    fn encode_region(name: &str) -> Vec<u8> {
        let mut buf = Vec::new();
        write_value(&mut buf, &Value::Integer(0.into())).unwrap();
        write_value(&mut buf, &Value::String(name.into())).unwrap();
        buf
    }

    /// Writes one IPv4 index block at `at`. The layout is start ip, end ip, region pointer
    /// (u32, little-endian) and region length.
    fn write_v4_block(
        buf: &mut [u8],
        at: usize,
        start: [u8; 4],
        end: [u8; 4],
        region_ptr: u32,
        region_len: u8,
    ) {
        buf[at..at + 4].copy_from_slice(&start);
        buf[at + 4..at + 8].copy_from_slice(&end);
        buf[at + 8..at + 12].copy_from_slice(&region_ptr.to_le_bytes());
        buf[at + 12] = region_len;
    }

    /// Builds a two-range IPv4 database directly in memory:
    /// `1.1.1.0..=1.1.1.255 -> "region1"` and `2.2.2.0..=2.2.2.255 -> "region2"`.
    ///
    /// The synthetic payload is laid out as follows, with region pointers being offsets into
    /// this payload:
    /// - 4 bytes of padding;
    /// - two index blocks;
    /// - the two encoded region records.
    fn build_test_db() -> CzdbMemory {
        let block_len = DbType::Ipv4.index_block_len();
        let padding = 4usize;
        let mut bindata = vec![0u8; padding + block_len * 2];

        let region1 = encode_region("region1");
        let region2 = encode_region("region2");
        let region1_ptr = (padding + block_len * 2) as u32;
        let region2_ptr = region1_ptr + region1.len() as u32;

        let first_offset = padding;
        let second_offset = padding + block_len;
        write_v4_block(
            &mut bindata,
            first_offset,
            [1, 1, 1, 0],
            [1, 1, 1, 255],
            region1_ptr,
            region1.len() as u8,
        );
        write_v4_block(
            &mut bindata,
            second_offset,
            [2, 2, 2, 0],
            [2, 2, 2, 255],
            region2_ptr,
            region2.len() as u8,
        );

        bindata.extend_from_slice(&region1);
        bindata.extend_from_slice(&region2);

        let mut ip1 = [0u8; 16];
        let mut ip2 = [0u8; 16];
        ip1[..4].copy_from_slice(&[1, 1, 1, 0]);
        ip2[..4].copy_from_slice(&[2, 2, 2, 0]);

        let meta = DbMeta {
            db_type: DbType::Ipv4,
            header_sip: vec![ip1, ip2],
            header_ptr: vec![first_offset as u32, second_offset as u32],
            column_selection: 0,
            geo_map_data: None,
            start_index: first_offset as u32,
            end_index: second_offset as u32,
        };

        let memory_index = build_memory_index(&bindata, &meta).unwrap();
        CzdbMemory { meta, memory_index }
    }

    #[test]
    fn search_handles_start_boundary_correctly() {
        let db = build_test_db();
        assert_eq!(
            db.search(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 0))),
            Some("region1".to_string())
        );
    }

    #[test]
    fn search_returns_expected_results() {
        let db = build_test_db();
        assert_eq!(
            db.search(IpAddr::V4(Ipv4Addr::new(2, 2, 2, 2))),
            Some("region2".to_string())
        );
        assert!(db.search(IpAddr::V4(Ipv4Addr::new(3, 3, 3, 3))).is_none());
    }

    #[test]
    fn search_handles_end_boundary_and_gaps() {
        let db = build_test_db();
        assert_eq!(db.search_ref(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 255))), Some("region1"));
        assert_eq!(db.search_ref(IpAddr::V4(Ipv4Addr::new(2, 2, 2, 255))), Some("region2"));
        // Below the first range and between the two ranges.
        assert!(db.search_ref(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 1))).is_none());
        assert!(db.search_ref(IpAddr::V4(Ipv4Addr::new(1, 5, 0, 0))).is_none());
    }

    #[test]
    fn search_many_scan_agrees_with_search_many_ref() {
        let db = build_test_db();
        let ips = [
            IpAddr::V4(Ipv4Addr::new(2, 2, 2, 2)),
            IpAddr::V4(Ipv4Addr::new(3, 3, 3, 3)),
            IpAddr::V4(Ipv4Addr::new(1, 1, 1, 0)),
            IpAddr::V4(Ipv4Addr::new(1, 5, 0, 0)),
            IpAddr::V4(Ipv4Addr::new(0, 0, 0, 1)),
            IpAddr::V6(Ipv6Addr::LOCALHOST),
        ];
        let scanned = db.search_many_scan(&ips);
        assert_eq!(
            scanned,
            vec![Some("region2"), None, Some("region1"), None, None, None]
        );
        assert_eq!(scanned, db.search_many_ref(&ips));
    }
}