1use crate::{
2 CzError,
3 common::{
4 DbMeta, DbType, decode_aes_key, decode_region_from_bytes, parse_meta_from_bytes,
5 read_hyper_header, compare_bytes,
6 },
7};
8use std::{
9 collections::HashMap,
10 fs::File,
11 io::{Cursor, Read},
12 net::IpAddr,
13};
14
15#[derive(Debug)]
16struct MemoryIndex {
17 entries_v4: Vec<IndexEntryV4>,
18 entries_v6: Vec<IndexEntryV6>,
19 regions: RegionPool,
20}
21
/// One IPv4 index range mapped to a region.
///
/// Both bounds are inclusive and hold the address as a big-endian-decoded
/// `u32`, so numeric order equals address order.
#[derive(Debug)]
struct IndexEntryV4 {
    /// First address covered by this range.
    start_ip: u32,
    /// Last address covered by this range.
    end_ip: u32,
    /// Index into the region pool's spans.
    region_id: usize,
}
28
/// One IPv6 index range mapped to a region.
///
/// Both bounds are inclusive raw 16-byte addresses; lookups order them with
/// `compare_bytes`.
#[derive(Debug)]
struct IndexEntryV6 {
    /// First address covered by this range.
    start_ip: [u8; 16],
    /// Last address covered by this range.
    end_ip: [u8; 16],
    /// Index into the region pool's spans.
    region_id: usize,
}
35
/// Location of one region string inside the pool's shared text buffer.
#[derive(Debug)]
struct RegionSpan {
    /// Byte offset of the string's first byte.
    start: usize,
    /// Length of the string in bytes.
    len: usize,
}
41
42#[derive(Debug)]
43struct RegionPool {
44 data: Box<str>,
45 spans: Vec<RegionSpan>,
46}
47
48impl RegionPool {
49 fn get(&self, region_id: usize) -> &str {
50 let span = &self.spans[region_id];
51 &self.data[span.start..span.start + span.len]
52 }
53}
54
55#[derive(Debug)]
59pub struct CzdbMemory {
60 meta: DbMeta,
61 memory_index: MemoryIndex,
62}
63
64impl CzdbMemory {
65 pub fn open(db_path: &str, key: &str) -> Result<Self, CzError> {
69 let mut file = File::open(db_path)?;
70 let mut data = Vec::new();
71 file.read_to_end(&mut data)?;
72 Self::from_bytes(data, key)
73 }
74
75 pub fn from_bytes(data: Vec<u8>, key: &str) -> Result<Self, CzError> {
79 let key_bytes = decode_aes_key(key)?;
80 let mut cursor = Cursor::new(&data);
81 let header = read_hyper_header(&mut cursor, &key_bytes)?;
82 let data_offset = (12 + header.padding_size + header.encrypted_block_size) as usize;
83 if data_offset > data.len() {
84 return Err(CzError::DatabaseFileCorrupted);
85 }
86 let file_size_total = data.len() as u64;
87 let meta = parse_meta_from_bytes(
88 &data[data_offset..],
89 file_size_total,
90 header.padding_size,
91 header.encrypted_block_size,
92 &key_bytes,
93 )?;
94 let memory_index = build_memory_index(&data[data_offset..], &meta)?;
95
96 Ok(Self {
97 meta,
98 memory_index,
99 })
100 }
101
102 pub fn search(&self, ip: IpAddr) -> Option<String> {
106 self.search_ref(ip).map(str::to_string)
107 }
108
109 pub fn search_ref(&self, ip: IpAddr) -> Option<&str> {
113 if !self.meta.db_type.compare(&ip) {
114 return None;
115 }
116 match ip {
117 IpAddr::V4(ip) => {
118 if self.memory_index.entries_v4.is_empty() {
119 return None;
120 }
121 let ip_num = u32::from_be_bytes(ip.octets());
122 let mut l = 0usize;
123 let mut h = self.memory_index.entries_v4.len() - 1;
124 while l <= h {
125 let m = (l + h) >> 1;
126 let entry = &self.memory_index.entries_v4[m];
127 if ip_num >= entry.start_ip && ip_num <= entry.end_ip {
128 return Some(self.memory_index.regions.get(entry.region_id));
129 } else if ip_num < entry.start_ip {
130 if m == 0 {
131 break;
132 }
133 h = m - 1;
134 } else {
135 l = m + 1;
136 }
137 }
138 None
139 }
140 IpAddr::V6(ip) => {
141 if self.memory_index.entries_v6.is_empty() {
142 return None;
143 }
144 let mut ip_bytes = [0u8; 16];
145 ip_bytes.copy_from_slice(&ip.octets());
146 let mut l = 0usize;
147 let mut h = self.memory_index.entries_v6.len() - 1;
148 while l <= h {
149 let m = (l + h) >> 1;
150 let entry = &self.memory_index.entries_v6[m];
151 let cmp_start = compare_bytes(&ip_bytes, &entry.start_ip, 16);
152 let cmp_end = compare_bytes(&ip_bytes, &entry.end_ip, 16);
153 if cmp_start != std::cmp::Ordering::Less
154 && cmp_end != std::cmp::Ordering::Greater
155 {
156 return Some(self.memory_index.regions.get(entry.region_id));
157 } else if cmp_start == std::cmp::Ordering::Less {
158 if m == 0 {
159 break;
160 }
161 h = m - 1;
162 } else {
163 l = m + 1;
164 }
165 }
166 None
167 }
168 }
169 }
170
171 pub fn search_many(&self, ips: &[IpAddr]) -> Vec<Option<String>> {
175 ips.iter().map(|ip| self.search(*ip)).collect()
176 }
177
178 pub fn search_many_ref<'a>(&'a self, ips: &[IpAddr]) -> Vec<Option<&'a str>> {
182 ips.iter().map(|ip| self.search_ref(*ip)).collect()
183 }
184
185 pub fn search_many_scan<'a>(&'a self, ips: &[IpAddr]) -> Vec<Option<&'a str>> {
189 let mut results = vec![None; ips.len()];
190 let mut v4 = Vec::new();
191 let mut v6 = Vec::new();
192 for (idx, ip) in ips.iter().copied().enumerate() {
193 match ip {
194 IpAddr::V4(ipv4) => v4.push((u32::from_be_bytes(ipv4.octets()), idx)),
195 IpAddr::V6(ipv6) => v6.push((ipv6.octets(), idx)),
196 }
197 }
198
199 if !v4.is_empty() && !self.memory_index.entries_v4.is_empty() {
200 v4.sort_unstable_by_key(|(ip, _)| *ip);
201 let mut entry_idx = 0usize;
202 for (ip_num, original_idx) in v4 {
203 while entry_idx < self.memory_index.entries_v4.len()
204 && self.memory_index.entries_v4[entry_idx].end_ip < ip_num
205 {
206 entry_idx += 1;
207 }
208 if entry_idx >= self.memory_index.entries_v4.len() {
209 break;
210 }
211 let entry = &self.memory_index.entries_v4[entry_idx];
212 if ip_num >= entry.start_ip && ip_num <= entry.end_ip {
213 results[original_idx] = Some(self.memory_index.regions.get(entry.region_id));
214 }
215 }
216 }
217
218 if !v6.is_empty() && !self.memory_index.entries_v6.is_empty() {
219 v6.sort_unstable_by(|(a, _), (b, _)| compare_bytes(a, b, 16));
220 let mut entry_idx = 0usize;
221 for (ip_bytes, original_idx) in v6 {
222 while entry_idx < self.memory_index.entries_v6.len()
223 && compare_bytes(&self.memory_index.entries_v6[entry_idx].end_ip, &ip_bytes, 16)
224 == std::cmp::Ordering::Less
225 {
226 entry_idx += 1;
227 }
228 if entry_idx >= self.memory_index.entries_v6.len() {
229 break;
230 }
231 let entry = &self.memory_index.entries_v6[entry_idx];
232 let cmp_start = compare_bytes(&ip_bytes, &entry.start_ip, 16);
233 let cmp_end = compare_bytes(&ip_bytes, &entry.end_ip, 16);
234 if cmp_start != std::cmp::Ordering::Less
235 && cmp_end != std::cmp::Ordering::Greater
236 {
237 results[original_idx] = Some(self.memory_index.regions.get(entry.region_id));
238 }
239 }
240 }
241
242 results
243 }
244
245 pub fn db_type(&self) -> DbType {
249 self.meta.db_type
250 }
251}
252
253fn build_memory_index(bindata: &[u8], meta: &DbMeta) -> Result<MemoryIndex, CzError> {
254 let ip_len = meta.db_type.bytes_len();
255 let blen = meta.db_type.index_block_len();
256 let start = meta.start_index as usize;
257 let end = meta.end_index as usize;
258
259 if end < start {
260 return Err(CzError::DatabaseFileCorrupted);
261 }
262 if end + blen > bindata.len() {
263 return Err(CzError::DatabaseFileCorrupted);
264 }
265
266 let total_blocks = (end - start) / blen + 1;
267 let mut entries_v4 = Vec::with_capacity(total_blocks);
268 let mut entries_v6 = Vec::with_capacity(total_blocks);
269 let mut regions = Vec::<RegionSpan>::new();
270 let mut region_text = String::new();
271 let mut region_cache = HashMap::<(usize, usize), usize>::new();
272
273 let mut p = start;
274 while p <= end {
275 if p + blen > bindata.len() {
276 return Err(CzError::DatabaseFileCorrupted);
277 }
278 let mut start_ip_bytes = [0u8; 16];
279 let mut end_ip_bytes = [0u8; 16];
280 start_ip_bytes[..ip_len].copy_from_slice(&bindata[p..p + ip_len]);
281 end_ip_bytes[..ip_len].copy_from_slice(&bindata[p + ip_len..p + ip_len * 2]);
282 let data_ptr = u32::from_le_bytes([
283 bindata[p + ip_len * 2],
284 bindata[p + ip_len * 2 + 1],
285 bindata[p + ip_len * 2 + 2],
286 bindata[p + ip_len * 2 + 3],
287 ]) as usize;
288 let data_len = bindata[p + ip_len * 2 + 4] as usize;
289
290 let region_id = match region_cache.get(&(data_ptr, data_len)) {
291 Some(id) => *id,
292 None => {
293 if data_ptr + data_len > bindata.len() {
294 return Err(CzError::DatabaseFileCorrupted);
295 }
296 let region = decode_region_from_bytes(
297 &bindata[data_ptr..data_ptr + data_len],
298 meta,
299 )
300 .ok_or(CzError::DatabaseFileCorrupted)?;
301 let start_offset = region_text.len();
302 region_text.push_str(®ion);
303 let len = region.len();
304 let id = regions.len();
305 regions.push(RegionSpan {
306 start: start_offset,
307 len,
308 });
309 region_cache.insert((data_ptr, data_len), id);
310 id
311 }
312 };
313
314 if meta.db_type == DbType::Ipv4 {
315 let start_ip = u32::from_be_bytes(start_ip_bytes[..4].try_into().unwrap());
316 let end_ip = u32::from_be_bytes(end_ip_bytes[..4].try_into().unwrap());
317 entries_v4.push(IndexEntryV4 {
318 start_ip,
319 end_ip,
320 region_id,
321 });
322 } else {
323 entries_v6.push(IndexEntryV6 {
324 start_ip: start_ip_bytes,
325 end_ip: end_ip_bytes,
326 region_id,
327 });
328 }
329
330 p += blen;
331 }
332
333 Ok(MemoryIndex {
334 entries_v4,
335 entries_v6,
336 regions: RegionPool {
337 data: region_text.into_boxed_str(),
338 spans: regions,
339 },
340 })
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use rmpv::{Value, encode::write_value};
    use std::net::Ipv4Addr;

    /// MessagePack-encodes a region record as an integer `0` followed by `name`.
    /// This is the record layout this fixture has always used.
    fn encode_region(name: &str) -> Vec<u8> {
        let mut record = Vec::new();
        write_value(&mut record, &Value::Integer(0.into())).unwrap();
        write_value(&mut record, &Value::String(name.into())).unwrap();
        record
    }

    /// Writes one IPv4 index block at `offset`. The block holds the start IP,
    /// the end IP, a little-endian region pointer and a one-byte region length.
    fn write_block(
        bindata: &mut [u8],
        offset: usize,
        start: [u8; 4],
        end: [u8; 4],
        region_ptr: u32,
        region_len: usize,
    ) {
        bindata[offset..offset + 4].copy_from_slice(&start);
        bindata[offset + 4..offset + 8].copy_from_slice(&end);
        bindata[offset + 8..offset + 12].copy_from_slice(&region_ptr.to_le_bytes());
        bindata[offset + 12] = region_len as u8;
    }

    /// Builds a two-range IPv4 database entirely in memory:
    /// - `1.1.1.0..=1.1.1.255` maps to `"region1"`,
    /// - `2.2.2.0..=2.2.2.255` maps to `"region2"`.
    fn build_test_db() -> CzdbMemory {
        let block_len = DbType::Ipv4.index_block_len();
        let padding = 4usize;
        let mut bindata = vec![0u8; padding + block_len * 2];

        let region1 = encode_region("region1");
        let region2 = encode_region("region2");
        let region1_ptr = (padding + block_len * 2) as u32;
        let region2_ptr = region1_ptr + region1.len() as u32;

        let first_offset = padding;
        let second_offset = padding + block_len;
        write_block(
            &mut bindata,
            first_offset,
            [1, 1, 1, 0],
            [1, 1, 1, 255],
            region1_ptr,
            region1.len(),
        );
        write_block(
            &mut bindata,
            second_offset,
            [2, 2, 2, 0],
            [2, 2, 2, 255],
            region2_ptr,
            region2.len(),
        );
        bindata.extend_from_slice(&region1);
        bindata.extend_from_slice(&region2);

        let mut ip1 = [0u8; 16];
        let mut ip2 = [0u8; 16];
        ip1[..4].copy_from_slice(&[1, 1, 1, 0]);
        ip2[..4].copy_from_slice(&[2, 2, 2, 0]);

        let meta = DbMeta {
            db_type: DbType::Ipv4,
            header_sip: vec![ip1, ip2],
            header_ptr: vec![first_offset as u32, second_offset as u32],
            column_selection: 0,
            geo_map_data: None,
            start_index: first_offset as u32,
            end_index: second_offset as u32,
        };

        // The index copies every region string, so `bindata` is not needed
        // after this call.
        let memory_index = build_memory_index(&bindata, &meta).unwrap();
        CzdbMemory { meta, memory_index }
    }

    fn v4(a: u8, b: u8, c: u8, d: u8) -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(a, b, c, d))
    }

    #[test]
    fn search_handles_start_boundary_correctly() {
        let db = build_test_db();
        assert_eq!(db.search(v4(1, 1, 1, 0)), Some("region1".to_string()));
    }

    #[test]
    fn search_returns_expected_results() {
        let db = build_test_db();
        assert_eq!(db.search(v4(2, 2, 2, 2)), Some("region2".to_string()));
        assert!(db.search(v4(3, 3, 3, 3)).is_none());
    }

    #[test]
    fn search_handles_end_boundary_and_gaps() {
        let db = build_test_db();
        assert_eq!(db.search_ref(v4(1, 1, 1, 255)), Some("region1"));
        assert_eq!(db.search_ref(v4(2, 2, 2, 0)), Some("region2"));
        assert_eq!(db.search_ref(v4(2, 2, 2, 255)), Some("region2"));
        // Between the two ranges.
        assert_eq!(db.search_ref(v4(1, 1, 2, 0)), None);
        // Below the first range: exercises the leftmost-miss path.
        assert_eq!(db.search_ref(v4(0, 0, 0, 0)), None);
    }

    #[test]
    fn batch_searches_agree_with_single_lookups() {
        let db = build_test_db();
        let ips = [
            v4(2, 2, 2, 2),
            v4(1, 1, 1, 0),
            v4(3, 3, 3, 3),
            v4(1, 1, 1, 255),
            v4(0, 0, 0, 1),
        ];
        let expected = vec![
            Some("region2"),
            Some("region1"),
            None,
            Some("region1"),
            None,
        ];
        assert_eq!(db.search_many_ref(&ips), expected);
        assert_eq!(db.search_many_scan(&ips), expected);
        let owned: Vec<Option<String>> = expected
            .iter()
            .map(|region| region.map(str::to_string))
            .collect();
        assert_eq!(db.search_many(&ips), owned);
    }

    #[test]
    fn db_type_reports_ipv4() {
        let db = build_test_db();
        assert!(db.db_type() == DbType::Ipv4);
    }
}