matchy_format/mmdb/
tree.rs

1//! MMDB Search Tree Traversal
2//!
3//! Implements binary search tree traversal for IP address lookups.
4//! The tree uses a compact binary representation where each node contains
5//! two records (left and right) that point to either:
6//! - Another node (continue traversal)
7//! - A data section offset (found)
8//! - A "not found" marker
9
10use super::format::MmdbHeader;
11use super::types::{MmdbError, RecordSize};
12use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
13
/// Result of an IP lookup that ended on a data record.
///
/// This is a small plain-data value, so it is `Copy` and can be compared,
/// hashed and passed around freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct LookupResult {
    /// Offset into the data section (relative to data section start)
    pub data_offset: u32,
    /// Network prefix length: the number of leading address bits that were
    /// matched to reach the data record. IPv4 lookups report it relative to
    /// the 32-bit IPv4 address (so it is at most 32), even when the database
    /// uses an IPv6 tree.
    pub prefix_len: u8,
}
22
/// Search tree for IP address lookups
///
/// A borrowed, read-only view: it holds references to the raw bytes and to
/// the parsed header for the lifetime `'a` and owns neither of them.
pub struct SearchTree<'a> {
    /// The raw file data containing the tree. Node records are read from
    /// this slice at offsets computed from index 0 (node index times the
    /// node size).
    data: &'a [u8],
    /// Parsed header information. The traversal code reads `node_count`,
    /// `record_size`, `ip_version` and `tree_size` from it.
    header: &'a MmdbHeader,
}
30
31impl<'a> SearchTree<'a> {
32    /// Create a new search tree
33    #[must_use]
34    pub fn new(data: &'a [u8], header: &'a MmdbHeader) -> Self {
35        Self { data, header }
36    }
37
38    /// Look up an IP address
39    pub fn lookup(&self, ip: IpAddr) -> Result<Option<LookupResult>, MmdbError> {
40        match ip {
41            IpAddr::V4(addr) => self.lookup_v4(addr),
42            IpAddr::V6(addr) => self.lookup_v6(addr),
43        }
44    }
45
46    /// Look up an IPv4 address
47    pub fn lookup_v4(&self, addr: Ipv4Addr) -> Result<Option<LookupResult>, MmdbError> {
48        use super::types::IpVersion;
49
50        // Check if this is an IPv6 tree
51        let (mut node, mut depth) = if self.header.ip_version == IpVersion::V6 {
52            // IPv4 addresses in IPv6 trees require finding the IPv4 start node first.
53            // Per MMDB spec and libmaxminddb, we traverse 96 zero bits (::ffff:0:0/96)
54            // to reach the IPv4 address space within the IPv6 tree.
55            self.find_ipv4_start_node()?
56        } else {
57            // Pure IPv4 tree - start at root
58            (0u32, 0u8)
59        };
60
61        // Now traverse the IPv4 address bits
62        let bits = ipv4_to_bits(addr);
63
64        for bit_index in 0..32 {
65            let bit = ((bits >> (31 - bit_index)) & 1) as u8;
66            let record = self.read_record(node as usize, bit)?;
67
68            if record == self.header.node_count {
69                return Ok(None);
70            } else if record < self.header.node_count {
71                node = record;
72                depth += 1;
73            } else {
74                let data_offset = self.calculate_data_offset(record)?;
75                // For IPv4 lookups, report the prefix as IPv4 prefix length
76                // (depth includes the 96 bits traversed for IPv6 tree, so subtract them)
77                let ipv4_prefix = if depth >= 96 {
78                    depth - 96 + 1
79                } else {
80                    depth + 1
81                };
82                return Ok(Some(LookupResult {
83                    data_offset,
84                    prefix_len: ipv4_prefix,
85                }));
86            }
87        }
88
89        Ok(None)
90    }
91
92    /// Look up an IPv6 address
93    pub fn lookup_v6(&self, addr: Ipv6Addr) -> Result<Option<LookupResult>, MmdbError> {
94        // Convert IPv6 to bits
95        let bits = ipv6_to_bits(addr);
96
97        let mut node = 0u32;
98        let mut depth = 0u8;
99        let max_depth = 128;
100
101        for bit_index in 0..max_depth {
102            // Extract bit from 128-bit value
103            let bit = if bit_index < 64 {
104                (bits.0 >> (63 - bit_index)) & 1
105            } else {
106                (bits.1 >> (127 - bit_index)) & 1
107            };
108
109            let record = self.read_record(node as usize, u8::try_from(bit).unwrap())?;
110
111            if record == self.header.node_count {
112                return Ok(None);
113            } else if record < self.header.node_count {
114                node = record;
115                depth = bit_index + 1;
116            } else {
117                let data_offset = self.calculate_data_offset(record)?;
118                return Ok(Some(LookupResult {
119                    data_offset,
120                    prefix_len: depth + 1,
121                }));
122            }
123        }
124
125        Ok(None)
126    }
127
128    /// Read a record from a node
129    ///
130    /// Each node contains two records. `side` determines which:
131    /// - 0 = left record (for IP bit 0)
132    /// - 1 = right record (for IP bit 1)
133    fn read_record(&self, node: usize, side: u8) -> Result<u32, MmdbError> {
134        if node >= usize::try_from(self.header.node_count).unwrap_or(usize::MAX) {
135            return Err(MmdbError::InvalidFormat(format!(
136                "Node index {} exceeds node count {}",
137                node, self.header.node_count
138            )));
139        }
140
141        match self.header.record_size {
142            RecordSize::Bits24 => self.read_24bit_record(node, side),
143            RecordSize::Bits28 => self.read_28bit_record(node, side),
144            RecordSize::Bits32 => self.read_32bit_record(node, side),
145        }
146    }
147
148    /// Read a 24-bit record (3 bytes per record, 6 bytes per node)
149    fn read_24bit_record(&self, node: usize, side: u8) -> Result<u32, MmdbError> {
150        let node_offset = node * 6; // 6 bytes per node
151        let record_offset = node_offset + (side as usize * 3);
152
153        if record_offset + 3 > self.header.tree_size {
154            return Err(MmdbError::InvalidFormat(format!(
155                "Record offset {} exceeds tree size {}",
156                record_offset, self.header.tree_size
157            )));
158        }
159
160        // Read 3 bytes in big-endian order
161        let b0 = u32::from(self.data[record_offset]);
162        let b1 = u32::from(self.data[record_offset + 1]);
163        let b2 = u32::from(self.data[record_offset + 2]);
164
165        Ok((b0 << 16) | (b1 << 8) | b2)
166    }
167
168    /// Read a 28-bit record (3.5 bytes per record, 7 bytes per node)
169    ///
170    /// Layout: [Left 24 bits][Middle 8 bits][Right 24 bits]
171    /// Middle byte contains 4 high bits of left + 4 high bits of right
172    fn read_28bit_record(&self, node: usize, side: u8) -> Result<u32, MmdbError> {
173        let node_offset = node * 7; // 7 bytes per node
174
175        if node_offset + 7 > self.header.tree_size {
176            return Err(MmdbError::InvalidFormat(format!(
177                "Node offset {} exceeds tree size {}",
178                node_offset, self.header.tree_size
179            )));
180        }
181
182        let bytes = &self.data[node_offset..node_offset + 7];
183
184        if side == 0 {
185            // Left record: bytes[0..3] with 4 high bits from middle byte
186            let high_bits = u32::from((bytes[3] >> 4) & 0x0F);
187            let low_bits =
188                (u32::from(bytes[0]) << 16) | (u32::from(bytes[1]) << 8) | u32::from(bytes[2]);
189            Ok((high_bits << 24) | low_bits)
190        } else {
191            // Right record: bytes[4..7] with 4 low bits from middle byte
192            let high_bits = u32::from(bytes[3] & 0x0F);
193            let low_bits =
194                (u32::from(bytes[4]) << 16) | (u32::from(bytes[5]) << 8) | u32::from(bytes[6]);
195            Ok((high_bits << 24) | low_bits)
196        }
197    }
198
199    /// Read a 32-bit record (4 bytes per record, 8 bytes per node)
200    fn read_32bit_record(&self, node: usize, side: u8) -> Result<u32, MmdbError> {
201        let node_offset = node * 8; // 8 bytes per node
202        let record_offset = node_offset + (side as usize * 4);
203
204        if record_offset + 4 > self.header.tree_size {
205            return Err(MmdbError::InvalidFormat(format!(
206                "Record offset {} exceeds tree size {}",
207                record_offset, self.header.tree_size
208            )));
209        }
210
211        // Read 4 bytes in big-endian order
212        let b0 = u32::from(self.data[record_offset]);
213        let b1 = u32::from(self.data[record_offset + 1]);
214        let b2 = u32::from(self.data[record_offset + 2]);
215        let b3 = u32::from(self.data[record_offset + 3]);
216
217        Ok((b0 << 24) | (b1 << 16) | (b2 << 8) | b3)
218    }
219
220    /// Calculate data section offset from record value
221    ///
222    /// Per MMDB spec:
223    /// - Record value > node_count means it points to data
224    /// - Formula: data_offset = (record_value - node_count) - 16
225    /// - The 16 is the data section separator size
226    fn calculate_data_offset(&self, record: u32) -> Result<u32, MmdbError> {
227        if record <= self.header.node_count {
228            return Err(MmdbError::InvalidFormat(format!(
229                "Record {} is not a data pointer (node_count = {})",
230                record, self.header.node_count
231            )));
232        }
233
234        // Per spec: subtract node count, then subtract 16 for separator
235        let offset_before_separator =
236            record.checked_sub(self.header.node_count).ok_or_else(|| {
237                MmdbError::InvalidFormat(format!(
238                    "Record {} - node_count {} underflow",
239                    record, self.header.node_count
240                ))
241            })?;
242
243        let offset = offset_before_separator.checked_sub(16).ok_or_else(|| {
244            MmdbError::InvalidFormat(format!(
245                "Data pointer {} - 16 underflow (record={}, node_count={})",
246                offset_before_separator, record, self.header.node_count
247            ))
248        })?;
249
250        Ok(offset)
251    }
252
253    /// Find the IPv4 start node in an IPv6 tree
254    ///
255    /// Per MMDB spec, IPv4 addresses in IPv6 trees are accessed via the
256    /// ::ffff:0:0/96 prefix. We traverse 96 zero bits to find where the
257    /// IPv4 address space begins.
258    ///
259    /// Returns (node, depth) where node is the starting node for IPv4 lookups
260    /// and depth is 96 (the number of bits traversed).
261    fn find_ipv4_start_node(&self) -> Result<(u32, u8), MmdbError> {
262        let mut node = 0u32;
263
264        // Traverse 96 zero bits (left record each time)
265        for _ in 0..96 {
266            let record = self.read_record(node as usize, 0)?;
267
268            if record == self.header.node_count {
269                // IPv4 space not found in this tree
270                return Ok((node, 96));
271            } else if record < self.header.node_count {
272                node = record;
273            } else {
274                // Shouldn't hit data in the first 96 bits, but handle it
275                return Ok((node, 96));
276            }
277        }
278
279        Ok((node, 96))
280    }
281}
282
/// Pack the four octets of an IPv4 address into a `u32`, first octet in the
/// most significant byte (network byte order).
fn ipv4_to_bits(addr: Ipv4Addr) -> u32 {
    u32::from_be_bytes(addr.octets())
}
291
/// Split an IPv6 address into its upper and lower 64 bits.
///
/// Returns `(high, low)`: `high` holds segments 0..4 and `low` holds segments
/// 4..8, each with the earlier segment in the more significant position.
fn ipv6_to_bits(addr: Ipv6Addr) -> (u64, u64) {
    // Shift the accumulator left by one 16-bit segment and append the next.
    let pack = |quad: &[u16]| {
        quad.iter()
            .fold(0u64, |acc, &segment| (acc << 16) | u64::from(segment))
    };

    let segments = addr.segments();
    (pack(&segments[..4]), pack(&segments[4..]))
}
305
// Unit tests for the bit-conversion helpers, the per-width record readers,
// the data-offset calculation, and an end-to-end lookup against a bundled
// database file.
#[cfg(test)]
mod tests {
    use super::*;

    // 192.168.1.1 packs to 0xC0A80101 (first octet in the top byte).
    #[test]
    fn test_ipv4_to_bits() {
        let addr = Ipv4Addr::new(192, 168, 1, 1);
        let bits = ipv4_to_bits(addr);
        assert_eq!(bits, 0xC0A80101);
    }

    // Segments 0..4 land in the high half, segments 4..8 in the low half.
    #[test]
    fn test_ipv6_to_bits() {
        let addr = Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0, 0, 1);
        let (high, low) = ipv6_to_bits(addr);
        assert_eq!(high, 0x20010db800000000);
        assert_eq!(low, 0x0000000000000001);
    }

    // Reads both records of node 0 from a hand-built 24-bit tree.
    #[test]
    fn test_read_24bit_record() {
        use super::super::types::IpVersion;

        // Create a small test tree with 24-bit records
        // Node 0: left=1, right=2
        let mut data = vec![0u8; 1000];
        data[0] = 0x00; // left record high byte
        data[1] = 0x00;
        data[2] = 0x01; // left = 1
        data[3] = 0x00; // right record high byte
        data[4] = 0x00;
        data[5] = 0x02; // right = 2

        let header = MmdbHeader {
            node_count: 10,
            record_size: RecordSize::Bits24,
            ip_version: IpVersion::V6,
            tree_size: 60, // 10 nodes * 6 bytes
        };

        let tree = SearchTree::new(&data, &header);

        assert_eq!(tree.read_24bit_record(0, 0).unwrap(), 1);
        assert_eq!(tree.read_24bit_record(0, 1).unwrap(), 2);
    }

    // Checks that the shared middle byte is split correctly: its upper nibble
    // becomes the top 4 bits of the left record, its lower nibble those of the
    // right record.
    #[test]
    fn test_read_28bit_record() {
        use super::super::types::IpVersion;

        // Create test data for 28-bit records
        let mut data = vec![0u8; 1000];
        // Node 0 with 28-bit records
        // Left: 0x1000001, Right: 0x2000002
        data[0] = 0x00; // left low 24 bits
        data[1] = 0x00;
        data[2] = 0x01;
        data[3] = 0x12; // middle byte: 0x1 for left high, 0x2 for right high
        data[4] = 0x00; // right low 24 bits
        data[5] = 0x00;
        data[6] = 0x02;

        let header = MmdbHeader {
            node_count: 10,
            record_size: RecordSize::Bits28,
            ip_version: IpVersion::V6,
            tree_size: 70, // 10 nodes * 7 bytes
        };

        let tree = SearchTree::new(&data, &header);

        assert_eq!(tree.read_28bit_record(0, 0).unwrap(), 0x1000001);
        assert_eq!(tree.read_28bit_record(0, 1).unwrap(), 0x2000002);
    }

    // The offset calculation only reads `node_count` from the header, so an
    // empty data slice is sufficient here.
    #[test]
    fn test_calculate_data_offset() {
        use super::super::types::IpVersion;

        let header = MmdbHeader {
            node_count: 100,
            record_size: RecordSize::Bits24,
            ip_version: IpVersion::V6,
            tree_size: 600,
        };

        let tree = SearchTree::new(&[], &header);

        // Record 116 -> data offset 0
        // (116 - 100 - 16 = 0)
        assert_eq!(tree.calculate_data_offset(116).unwrap(), 0);

        // Record 200 -> data offset 84
        // (200 - 100 - 16 = 84)
        assert_eq!(tree.calculate_data_offset(200).unwrap(), 84);
    }

    // End-to-end lookup against a database embedded at compile time.
    // NOTE(review): this requires the fixture file to exist at the path given
    // to `include_bytes!` (resolved relative to this source file); the build
    // of the test module fails without it.
    #[test]
    fn test_lookup_with_real_database() {
        // This test uses the actual GeoLite2-Country.mmdb file
        let data = include_bytes!("../../tests/data/GeoLite2-Country.mmdb");

        // Parse header
        let header = MmdbHeader::from_file(data).unwrap();
        let tree = SearchTree::new(data, &header);

        // Test a known IP (1.1.1.1 - Cloudflare, should be in database)
        let ip = Ipv4Addr::new(1, 1, 1, 1);
        let result = tree.lookup_v4(ip).unwrap();

        // Should find something for this well-known IP
        assert!(result.is_some(), "Should find data for 1.1.1.1");

        if let Some(lookup_result) = result {
            // NOTE(review): offset 0 is a value `calculate_data_offset` can
            // legitimately produce; this assertion presumably relies on this
            // particular record not being first in the data section - confirm.
            assert!(
                lookup_result.data_offset > 0,
                "Data offset should be non-zero"
            );
            assert!(
                lookup_result.prefix_len > 0,
                "Prefix length should be positive"
            );
            assert!(
                lookup_result.prefix_len <= 32,
                "IPv4 prefix should be <= 32"
            );
        }

        // Test another well-known IP
        let ip2 = Ipv4Addr::new(8, 8, 8, 8);
        let result2 = tree.lookup_v4(ip2).unwrap();
        assert!(result2.is_some(), "Should find data for 8.8.8.8");
    }
}