1use super::format::MmdbHeader;
11use super::types::{MmdbError, RecordSize};
12use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
13
/// Outcome of a successful search-tree lookup.
///
/// Only two small integers, so it is `Copy` and can be compared with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LookupResult {
    /// Offset of the matched record within the data section, as produced by
    /// `SearchTree::calculate_data_offset`.
    pub data_offset: u32,
    /// Number of address bits consumed to reach the record, i.e. the prefix
    /// length of the matching network. For IPv4 lookups this counts IPv4 bits
    /// only (0..=32), even when the database is an IPv6 tree.
    pub prefix_len: u8,
}
22
23pub struct SearchTree<'a> {
25 data: &'a [u8],
27 header: &'a MmdbHeader,
29}
30
31impl<'a> SearchTree<'a> {
32 #[must_use]
34 pub fn new(data: &'a [u8], header: &'a MmdbHeader) -> Self {
35 Self { data, header }
36 }
37
38 pub fn lookup(&self, ip: IpAddr) -> Result<Option<LookupResult>, MmdbError> {
40 match ip {
41 IpAddr::V4(addr) => self.lookup_v4(addr),
42 IpAddr::V6(addr) => self.lookup_v6(addr),
43 }
44 }
45
46 pub fn lookup_v4(&self, addr: Ipv4Addr) -> Result<Option<LookupResult>, MmdbError> {
48 use super::types::IpVersion;
49
50 let (mut node, mut depth) = if self.header.ip_version == IpVersion::V6 {
52 self.find_ipv4_start_node()?
56 } else {
57 (0u32, 0u8)
59 };
60
61 let bits = ipv4_to_bits(addr);
63
64 for bit_index in 0..32 {
65 let bit = ((bits >> (31 - bit_index)) & 1) as u8;
66 let record = self.read_record(node as usize, bit)?;
67
68 if record == self.header.node_count {
69 return Ok(None);
70 } else if record < self.header.node_count {
71 node = record;
72 depth += 1;
73 } else {
74 let data_offset = self.calculate_data_offset(record)?;
75 let ipv4_prefix = if depth >= 96 {
78 depth - 96 + 1
79 } else {
80 depth + 1
81 };
82 return Ok(Some(LookupResult {
83 data_offset,
84 prefix_len: ipv4_prefix,
85 }));
86 }
87 }
88
89 Ok(None)
90 }
91
92 pub fn lookup_v6(&self, addr: Ipv6Addr) -> Result<Option<LookupResult>, MmdbError> {
94 let bits = ipv6_to_bits(addr);
96
97 let mut node = 0u32;
98 let mut depth = 0u8;
99 let max_depth = 128;
100
101 for bit_index in 0..max_depth {
102 let bit = if bit_index < 64 {
104 (bits.0 >> (63 - bit_index)) & 1
105 } else {
106 (bits.1 >> (127 - bit_index)) & 1
107 };
108
109 let record = self.read_record(node as usize, u8::try_from(bit).unwrap())?;
110
111 if record == self.header.node_count {
112 return Ok(None);
113 } else if record < self.header.node_count {
114 node = record;
115 depth = bit_index + 1;
116 } else {
117 let data_offset = self.calculate_data_offset(record)?;
118 return Ok(Some(LookupResult {
119 data_offset,
120 prefix_len: depth + 1,
121 }));
122 }
123 }
124
125 Ok(None)
126 }
127
128 fn read_record(&self, node: usize, side: u8) -> Result<u32, MmdbError> {
134 if node >= usize::try_from(self.header.node_count).unwrap_or(usize::MAX) {
135 return Err(MmdbError::InvalidFormat(format!(
136 "Node index {} exceeds node count {}",
137 node, self.header.node_count
138 )));
139 }
140
141 match self.header.record_size {
142 RecordSize::Bits24 => self.read_24bit_record(node, side),
143 RecordSize::Bits28 => self.read_28bit_record(node, side),
144 RecordSize::Bits32 => self.read_32bit_record(node, side),
145 }
146 }
147
148 fn read_24bit_record(&self, node: usize, side: u8) -> Result<u32, MmdbError> {
150 let node_offset = node * 6; let record_offset = node_offset + (side as usize * 3);
152
153 if record_offset + 3 > self.header.tree_size {
154 return Err(MmdbError::InvalidFormat(format!(
155 "Record offset {} exceeds tree size {}",
156 record_offset, self.header.tree_size
157 )));
158 }
159
160 let b0 = u32::from(self.data[record_offset]);
162 let b1 = u32::from(self.data[record_offset + 1]);
163 let b2 = u32::from(self.data[record_offset + 2]);
164
165 Ok((b0 << 16) | (b1 << 8) | b2)
166 }
167
168 fn read_28bit_record(&self, node: usize, side: u8) -> Result<u32, MmdbError> {
173 let node_offset = node * 7; if node_offset + 7 > self.header.tree_size {
176 return Err(MmdbError::InvalidFormat(format!(
177 "Node offset {} exceeds tree size {}",
178 node_offset, self.header.tree_size
179 )));
180 }
181
182 let bytes = &self.data[node_offset..node_offset + 7];
183
184 if side == 0 {
185 let high_bits = u32::from((bytes[3] >> 4) & 0x0F);
187 let low_bits =
188 (u32::from(bytes[0]) << 16) | (u32::from(bytes[1]) << 8) | u32::from(bytes[2]);
189 Ok((high_bits << 24) | low_bits)
190 } else {
191 let high_bits = u32::from(bytes[3] & 0x0F);
193 let low_bits =
194 (u32::from(bytes[4]) << 16) | (u32::from(bytes[5]) << 8) | u32::from(bytes[6]);
195 Ok((high_bits << 24) | low_bits)
196 }
197 }
198
199 fn read_32bit_record(&self, node: usize, side: u8) -> Result<u32, MmdbError> {
201 let node_offset = node * 8; let record_offset = node_offset + (side as usize * 4);
203
204 if record_offset + 4 > self.header.tree_size {
205 return Err(MmdbError::InvalidFormat(format!(
206 "Record offset {} exceeds tree size {}",
207 record_offset, self.header.tree_size
208 )));
209 }
210
211 let b0 = u32::from(self.data[record_offset]);
213 let b1 = u32::from(self.data[record_offset + 1]);
214 let b2 = u32::from(self.data[record_offset + 2]);
215 let b3 = u32::from(self.data[record_offset + 3]);
216
217 Ok((b0 << 24) | (b1 << 16) | (b2 << 8) | b3)
218 }
219
220 fn calculate_data_offset(&self, record: u32) -> Result<u32, MmdbError> {
227 if record <= self.header.node_count {
228 return Err(MmdbError::InvalidFormat(format!(
229 "Record {} is not a data pointer (node_count = {})",
230 record, self.header.node_count
231 )));
232 }
233
234 let offset_before_separator =
236 record.checked_sub(self.header.node_count).ok_or_else(|| {
237 MmdbError::InvalidFormat(format!(
238 "Record {} - node_count {} underflow",
239 record, self.header.node_count
240 ))
241 })?;
242
243 let offset = offset_before_separator.checked_sub(16).ok_or_else(|| {
244 MmdbError::InvalidFormat(format!(
245 "Data pointer {} - 16 underflow (record={}, node_count={})",
246 offset_before_separator, record, self.header.node_count
247 ))
248 })?;
249
250 Ok(offset)
251 }
252
253 fn find_ipv4_start_node(&self) -> Result<(u32, u8), MmdbError> {
262 let mut node = 0u32;
263
264 for _ in 0..96 {
266 let record = self.read_record(node as usize, 0)?;
267
268 if record == self.header.node_count {
269 return Ok((node, 96));
271 } else if record < self.header.node_count {
272 node = record;
273 } else {
274 return Ok((node, 96));
276 }
277 }
278
279 Ok((node, 96))
280 }
281}
282
/// Packs an IPv4 address into a `u32`, first octet in the most significant
/// byte (network order), so bit 31 is the first bit walked in the tree.
fn ipv4_to_bits(addr: Ipv4Addr) -> u32 {
    u32::from(addr)
}
291
/// Splits an IPv6 address into its `(high, low)` 64-bit halves, each in
/// network order: `high` holds segments 0..4 and `low` holds segments 4..8.
fn ipv6_to_bits(addr: Ipv6Addr) -> (u64, u64) {
    let packed = u128::from(addr);
    // Truncating casts are intentional: each one selects a 64-bit half.
    let high = (packed >> 64) as u64;
    let low = packed as u64;
    (high, low)
}
305
#[cfg(test)]
mod tests {
    use super::super::types::IpVersion;
    use super::*;

    /// Builds a header for an IPv6 tree with the given record layout.
    fn v6_header(record_size: RecordSize, node_count: u32, tree_size: usize) -> MmdbHeader {
        MmdbHeader {
            node_count,
            record_size,
            ip_version: IpVersion::V6,
            tree_size,
        }
    }

    #[test]
    fn test_ipv4_to_bits() {
        let addr = Ipv4Addr::new(192, 168, 1, 1);
        assert_eq!(ipv4_to_bits(addr), 0xC0A8_0101);
    }

    #[test]
    fn test_ipv6_to_bits() {
        let addr = Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0, 0, 1);
        let (high, low) = ipv6_to_bits(addr);
        assert_eq!(high, 0x2001_0db8_0000_0000);
        assert_eq!(low, 0x0000_0000_0000_0001);
    }

    #[test]
    fn test_read_24bit_record() {
        let mut data = vec![0u8; 1000];
        // Node 0: left record = 1, right record = 2 (3 big-endian bytes each).
        data[..6].copy_from_slice(&[0x00, 0x00, 0x01, 0x00, 0x00, 0x02]);
        let header = v6_header(RecordSize::Bits24, 10, 60);

        let tree = SearchTree::new(&data, &header);

        assert_eq!(tree.read_24bit_record(0, 0).unwrap(), 1);
        assert_eq!(tree.read_24bit_record(0, 1).unwrap(), 2);
    }

    #[test]
    fn test_read_24bit_record_rejects_offset_past_tree() {
        let data = vec![0u8; 1000];
        // The tree holds a single 6-byte node; node 1 starts past its end.
        let header = v6_header(RecordSize::Bits24, 10, 6);

        let tree = SearchTree::new(&data, &header);

        assert!(tree.read_24bit_record(0, 1).is_ok());
        assert!(tree.read_24bit_record(1, 0).is_err());
    }

    #[test]
    fn test_read_28bit_record() {
        let mut data = vec![0u8; 1000];
        // Byte 3 = 0x12: high nibble 1 tops the left record,
        // low nibble 2 tops the right record.
        data[..7].copy_from_slice(&[0x00, 0x00, 0x01, 0x12, 0x00, 0x00, 0x02]);
        let header = v6_header(RecordSize::Bits28, 10, 70);

        let tree = SearchTree::new(&data, &header);

        assert_eq!(tree.read_28bit_record(0, 0).unwrap(), 0x100_0001);
        assert_eq!(tree.read_28bit_record(0, 1).unwrap(), 0x200_0002);
    }

    #[test]
    fn test_read_32bit_record() {
        let mut data = vec![0u8; 1000];
        data[..8].copy_from_slice(&[0x00, 0x00, 0x00, 0x03, 0x01, 0x02, 0x03, 0x04]);
        let header = v6_header(RecordSize::Bits32, 10, 80);

        let tree = SearchTree::new(&data, &header);

        assert_eq!(tree.read_32bit_record(0, 0).unwrap(), 3);
        assert_eq!(tree.read_32bit_record(0, 1).unwrap(), 0x0102_0304);
    }

    #[test]
    fn test_read_record_rejects_node_out_of_range() {
        let data = vec![0u8; 1000];
        let header = v6_header(RecordSize::Bits24, 10, 60);

        let tree = SearchTree::new(&data, &header);

        assert!(tree.read_record(9, 0).is_ok());
        assert!(tree.read_record(10, 0).is_err());
    }

    #[test]
    fn test_calculate_data_offset() {
        let header = v6_header(RecordSize::Bits24, 100, 600);

        let tree = SearchTree::new(&[], &header);

        // node_count (100) + 16-byte separator → first byte of the data section.
        assert_eq!(tree.calculate_data_offset(116).unwrap(), 0);
        assert_eq!(tree.calculate_data_offset(200).unwrap(), 84);

        // The empty marker and internal node indices are not data pointers.
        assert!(tree.calculate_data_offset(100).is_err());
        assert!(tree.calculate_data_offset(50).is_err());
        // Pointers into the 16-byte separator are rejected.
        assert!(tree.calculate_data_offset(115).is_err());
    }

    #[test]
    fn test_lookup_v6_single_node_tree() {
        // One node: the left record points at data offset 5
        // (5 + 16 + node_count = 22); the right record is the empty marker
        // (== node_count).
        let data: [u8; 6] = [0x00, 0x00, 22, 0x00, 0x00, 0x01];
        let header = v6_header(RecordSize::Bits24, 1, 6);

        let tree = SearchTree::new(&data, &header);
        let expected = Some(LookupResult {
            data_offset: 5,
            prefix_len: 1,
        });

        let low = Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1);
        assert_eq!(tree.lookup_v6(low).unwrap(), expected);
        assert_eq!(tree.lookup(IpAddr::V6(low)).unwrap(), expected);

        let high = Ipv6Addr::new(0x8000, 0, 0, 0, 0, 0, 0, 0);
        assert_eq!(tree.lookup_v6(high).unwrap(), None);
    }

    #[test]
    fn test_lookup_with_real_database() {
        let data = include_bytes!("../../tests/data/GeoLite2-Country.mmdb");

        let header = MmdbHeader::from_file(data).unwrap();
        let tree = SearchTree::new(data, &header);

        let result = tree.lookup_v4(Ipv4Addr::new(1, 1, 1, 1)).unwrap();
        let Some(found) = result else {
            panic!("Should find data for 1.1.1.1");
        };
        assert!(found.data_offset > 0, "Data offset should be non-zero");
        assert!(found.prefix_len > 0, "Prefix length should be positive");
        assert!(found.prefix_len <= 32, "IPv4 prefix should be <= 32");

        let result2 = tree.lookup_v4(Ipv4Addr::new(8, 8, 8, 8)).unwrap();
        assert!(result2.is_some(), "Should find data for 8.8.8.8");
    }
}