ipdb/
lib.rs

1use std::cmp::Ordering;
2use std::fmt::{self, Display, Formatter};
3use std::io;
4use std::marker::PhantomData;
5use std::str;
6
7use std::collections::HashMap;
8use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
9use std::path::Path;
10
11use ipnetwork::IpNetwork;
12use serde::{de, Deserialize};
13use serde_json;
14
15use city::CityInfo;
16
17#[cfg(feature = "mmap")]
18pub use memmap2::Mmap;
19#[cfg(feature = "mmap")]
20use memmap2::MmapOptions;
21#[cfg(feature = "mmap")]
22use std::fs::File;
23
/// Bit flag in `Metadata::ip_version` marking that the database holds IPv4 data.
const IPV4: u16 = 1 << 0;
/// Bit flag in `Metadata::ip_version` marking that the database holds IPv6 data.
const IPV6: u16 = 1 << 1;
26
/// Errors produced while opening or querying an IPDB database.
#[derive(Debug, PartialEq, Eq)]
pub enum IPDBError {
    /// The buffer size disagrees with what the header or metadata declares.
    FileSizeError(String),
    /// The metadata header is too short or incomplete (e.g. no languages or fields).
    MetaDataError(String),
    /// An I/O failure, carrying the underlying error's message.
    IOError(String),

    /// A database-level decoding error (also used for serde custom errors).
    DatabaseError(String),
    /// An offset ran past a bound; holds `(offset, bound)`.
    OutOfBoundError(usize, usize),
    /// Malformed IP address input (not constructed in this file — TODO confirm callers).
    IPFormatError(String),

    /// The requested language or IP version is not provided by the database.
    NotSupportedError(String),

    /// The database has no record for the requested address.
    DataNotFoundError(String),
    /// An address/prefix pair could not be turned into a network.
    InvalidNetworkError(String),
}
42
43impl From<io::Error> for IPDBError {
44    fn from(err: io::Error) -> IPDBError {
45        // clean up and clean up MaxMindDBError generally
46        IPDBError::IOError(err.to_string())
47    }
48}
49
50impl Display for IPDBError {
51    fn fmt(&self, fmt: &mut Formatter<'_>) -> Result<(), fmt::Error> {
52        match self {
53            IPDBError::FileSizeError(msg) => write!(fmt, "FileSizeError: {}", msg)?,
54            IPDBError::MetaDataError(msg) => write!(fmt, "MetaDataError: {}", msg)?,
55            IPDBError::IOError(msg) => write!(fmt, "IOError: {}", msg)?,
56
57            IPDBError::DatabaseError(msg) => write!(fmt, "DatabaseError: {}", msg)?,
58            IPDBError::OutOfBoundError(a, b) => write!(fmt, "OutOfBoundError: {} > {}", a, b)?,
59
60            IPDBError::IPFormatError(msg) => write!(fmt, "IPFormatError: {}", msg)?,
61            IPDBError::NotSupportedError(msg) => write!(fmt, "NotSupportedError: {}", msg)?,
62            IPDBError::DataNotFoundError(msg) => write!(fmt, "DataNotFoundError: {}", msg)?,
63
64            IPDBError::InvalidNetworkError(msg) => write!(fmt, "InvalidNetworkError: {}", msg)?,
65        }
66        Ok(())
67    }
68}
69
70// Use default implementation for `std::error::Error`
71impl std::error::Error for IPDBError {}
72
73impl de::Error for IPDBError {
74    fn custom<T: Display>(msg: T) -> Self {
75        IPDBError::DatabaseError(format!("{}", msg))
76    }
77}
78
79#[derive(Deserialize, Debug)]
80pub struct Metadata {
81    pub build: i64,
82    pub ip_version: u16,
83    pub node_count: usize,
84    pub total_size: usize,
85
86    pub fields: Vec<String>,
87    pub languages: HashMap<String, usize>,
88}
89
90#[derive(Debug)]
91pub struct Reader<S: AsRef<[u8]>> {
92    buf: S,
93
94    pub meta: Metadata,
95
96    pointer_base: usize,
97    ipv4_offset: usize,
98}
99
#[cfg(feature = "mmap")]
impl Reader<Mmap> {
    /// Open an IPDB database file by memory mapping it.
    ///
    /// The mapped bytes are validated by [`Reader::from_source`].
    ///
    /// # Errors
    ///
    /// [`IPDBError::IOError`] if the file cannot be opened or mapped; otherwise
    /// any error reported by [`Reader::from_source`].
    ///
    /// # Example
    ///
    /// ```
    /// let reader = ipdb::Reader::open_mmap("ipdb.ipdb").unwrap();
    /// ```
    pub fn open_mmap<P: AsRef<Path>>(database: P) -> Result<Reader<Mmap>, IPDBError> {
        let file_read = File::open(database)?;
        // SAFETY: memmap2 mappings are only sound while no one modifies or
        // truncates the underlying file. This reader assumes the database file
        // stays unchanged for the lifetime of the returned `Reader`; callers
        // must not replace or rewrite it in place while it is open.
        let mmap = unsafe { MmapOptions::new().map(&file_read) }?;
        Reader::from_source(mmap)
    }
}
115
116impl Reader<Vec<u8>> {
117    /// Open an IPDB database file by loading it into memory.
118    ///
119    /// # Example
120    ///
121    /// ```
122    /// let reader = ipdb::Reader::open_readfile("ipdb.ipdb").unwrap();
123    /// ```
124    pub fn open_readfile<P: AsRef<Path>>(database: P) -> Result<Reader<Vec<u8>>, IPDBError> {
125        use std::fs;
126
127        let buf: Vec<u8> = fs::read(&database)?;
128        Reader::from_source(buf)
129    }
130}
131
132impl<'de, S: AsRef<[u8]>> Reader<S> {
133    pub fn from_source(buf: S) -> Result<Reader<S>, IPDBError> {
134        let file_size = buf.as_ref().len();
135        let meta_bytes: [u8; 4] = buf.as_ref()[0..4].try_into().map_err(|_| {
136            IPDBError::MetaDataError(format!(
137                "The file size is too small to be a valid database: {}",
138                file_size
139            ))
140        })?;
141        let meta_length = u32::from_be_bytes(meta_bytes) as usize;
142
143        // validate file size
144        if file_size < 4 + meta_length {
145            return Err(IPDBError::FileSizeError(format!(
146                "File size is too small. Expected at least {} bytes, got {}",
147                4 + meta_length,
148                file_size
149            )));
150        }
151
152        let meta: Metadata = serde_json::from_slice(&buf.as_ref()[4..4 + meta_length]).unwrap();
153
154        // validate metadata
155        if meta.languages.len() == 0 {
156            return Err(IPDBError::MetaDataError(
157                "No languages specified in metadata.".to_owned(),
158            ));
159        } else if meta.fields.len() == 0 {
160            return Err(IPDBError::MetaDataError(
161                "No fields specified in metadata.".to_owned(),
162            ));
163        }
164
165        // validate if filesize matches metadata
166        if file_size != (4 + meta_length + meta.total_size) {
167            return Err(IPDBError::FileSizeError(format!(
168                "File size does not match metadata. Expected {} bytes, got {}",
169                meta.total_size, file_size
170            )));
171        }
172
173        let mut r = Reader {
174            buf,
175            meta,
176            pointer_base: 4 + meta_length,
177            ipv4_offset: 0,
178        };
179        r.ipv4_offset = r.find_ipv4_start()?;
180
181        Ok(r)
182    }
183
184    fn find_ipv4_start(&self) -> Result<usize, IPDBError> {
185        let mut node: usize = 0_usize;
186        for i in 0_u8..96 {
187            if node >= self.meta.node_count {
188                break;
189            }
190
191            if i >= 80 {
192                node = self.read_node(node, 1)?;
193            } else {
194                node = self.read_node(node, 0)?;
195            }
196        }
197
198        Ok(node)
199    }
200
201    #[inline]
202    fn is_ipv4_supported(&self) -> bool {
203        (self.meta.ip_version & IPV4) == IPV4
204    }
205
206    #[inline]
207    fn is_ipv6_supported(&self) -> bool {
208        (self.meta.ip_version & IPV6) == IPV6
209    }
210
211    fn read_node(&self, node: usize, index: usize) -> Result<usize, IPDBError> {
212        let offset = self.pointer_base + node * 8 + index * 4;
213        let bytes = &self.buf.as_ref()[offset..offset + 4];
214        match u32::from_be_bytes(bytes.try_into().unwrap()) as usize {
215            0 => Err(IPDBError::DataNotFoundError("Data not found".to_owned())),
216            x => Ok(x),
217        }
218    }
219
220    fn search_node(&self, ip: Vec<u8>) -> Result<(usize, usize), IPDBError> {
221        let bit_count = ip.len() * 8;
222        let mut node: usize = 0_usize;
223
224        let mut prefix_len = bit_count;
225
226        if bit_count == 32 {
227            node = self.ipv4_offset;
228        }
229
230        for i in 0_usize..bit_count {
231            if node > self.meta.node_count {
232                prefix_len = i;
233                break;
234            }
235
236            let index = (0xFF & (ip[i >> 3])) >> (7 - (i % 8)) & 1;
237            match self.read_node(node, index as usize) {
238                Ok(x) => node = x,
239                Err(e) => return Err(e),
240            }
241        }
242
243        if node <= self.meta.node_count {
244            return Err(IPDBError::DataNotFoundError("Data not found".to_owned()));
245        }
246
247        Ok((node, prefix_len))
248    }
249
250    fn resolve_data_pointer(&self, node: usize) -> Result<(usize, usize), IPDBError> {
251        let start = self.pointer_base + node - self.meta.node_count + self.meta.node_count * 8;
252        if start >= self.meta.total_size {
253            return Err(IPDBError::OutOfBoundError(start, self.meta.total_size));
254        }
255
256        let size = u32::from_be_bytes([
257            0u8,
258            0u8,
259            self.buf.as_ref()[start],
260            self.buf.as_ref()[start + 1],
261        ]) as usize;
262        let offset = start + 2 + size;
263
264        if offset > self.meta.total_size {
265            return Err(IPDBError::OutOfBoundError(offset, self.meta.total_size));
266        }
267
268        Ok((start + 2, offset))
269    }
270
271    fn parse_data(
272        &self,
273        start: usize,
274        offset: usize,
275        skip_columns: usize,
276    ) -> Result<CityInfo, IPDBError> {
277        use std::str::from_utf8_unchecked;
278        let bytes = &self.buf.as_ref()[start..offset];
279        let data = unsafe { from_utf8_unchecked(bytes) };
280
281        let sp: Vec<&str> = data.split('\t').skip(skip_columns).collect();
282
283        Ok(sp.into())
284    }
285
286    pub fn lookup(&self, address: IpAddr, language: String) -> Result<city::CityInfo, IPDBError> {
287        let (info, _prefixlen) = self.lookup_prefix(address, language)?;
288        Ok(info)
289    }
290    /// Lookup the socket address in the opened IPDB database
291    ///
292    /// Example:
293    ///
294    /// ```
295    /// use ipdb;
296    /// use std::net::IpAddr;
297    /// use std::str::FromStr;
298    ///
299    /// let reader = ipdb::Reader::open_readfile("ipdb.ipdb").unwrap();
300    ///
301    /// let ip: IpAddr = "1.1.1.1".parse().unwrap();
302    /// let data = reader.lookup_prefix(ip, "EN".to_owned()).unwrap();
303    /// println!("{:#?}", data);
304    /// ```
305    pub fn lookup_prefix(
306        &self,
307        address: IpAddr,
308        language: String,
309    ) -> Result<(CityInfo, usize), IPDBError> {
310        // check if language is supported
311        let skip = match self.meta.languages.get(&language) {
312            Some(x) => x,
313            None => {
314                return Err(IPDBError::NotSupportedError(
315                    "Language not supported".to_owned(),
316                ))
317            }
318        };
319
320        // check if ip version is supported
321        let ip_bytes = ip_to_bytes(address);
322        match address {
323            IpAddr::V4(_) => {
324                if !self.is_ipv4_supported() {
325                    return Err(IPDBError::NotSupportedError(
326                        "IPv4 is not supported by this database.".to_owned(),
327                    ));
328                }
329            }
330            IpAddr::V6(_) => {
331                if !self.is_ipv6_supported() {
332                    return Err(IPDBError::NotSupportedError(
333                        "IPv6 is not supported by this database.".to_owned(),
334                    ));
335                }
336            }
337        }
338
339        let (pointer, prefix_len) = self.search_node(ip_bytes)?;
340
341        let (start, offset) = self.resolve_data_pointer(pointer)?;
342        let data = self.parse_data(start, offset, *skip)?;
343
344        Ok((data, prefix_len))
345    }
346
347    pub fn within(&'de self, cidr: IpNetwork) -> Result<Within<S>, IPDBError>
348    {
349        let ip_address = cidr.network();
350        let prefix_len = cidr.prefix() as usize;
351        let ip_bytes = ip_to_bytes(ip_address);
352        let bit_count = ip_bytes.len() * 8;
353
354        // IPv6 isn't implemented yet
355        let mut node = self.ipv4_offset;
356        let node_count = self.meta.node_count as usize;
357
358        let mut stack: Vec<WithinNode> = Vec::with_capacity(bit_count - prefix_len);
359
360        // Traverse down the tree to the level that matches the cidr mark
361        let mut i = 0_usize;
362        while i < prefix_len {
363            let bit = 1 & (ip_bytes[i >> 3] >> (7 - (i % 8))) as usize;
364            node = self.read_node(node, bit)?;
365            if node >= node_count {
366                // We've hit a dead end before we exhausted our prefix
367                break;
368            }
369
370            i += 1;
371        }
372
373        if node < node_count {
374            // Ok, now anything that's below node in the tree is "within", start with the node we
375            // traversed to as our to be processed stack.
376            stack.push(WithinNode {
377                node,
378                ip_bytes,
379                prefix_len,
380            });
381        }
382        // else the stack will be empty and we'll be returning an iterator that visits nothing,
383        // which makes sense.
384
385        let within: Within<S> = Within {
386            reader: self,
387            node_count,
388            stack,
389            phantom: PhantomData,
390        };
391
392        Ok(within)
393    }
394
395}
396
/// A pending subtree in the `Within` traversal.
#[derive(Debug)]
struct WithinNode {
    /// Node id (or data pointer) at the root of this subtree.
    node: usize,
    /// Address bytes whose first `prefix_len` bits encode the path taken.
    ip_bytes: Vec<u8>,
    /// Number of leading address bits fixed so far.
    prefix_len: usize,
}
403
404#[derive(Debug)]
405pub struct Within<'de, S: AsRef<[u8]>> {
406    reader: &'de Reader<S>,
407    node_count: usize,
408    stack: Vec<WithinNode>,
409    phantom: PhantomData<CityInfo<'de>>,
410}
411
412#[derive(Debug)]
413pub struct WithinItem<'a> {
414    pub ip_net: IpNetwork,
415    pub info:CityInfo<'a>,
416}
417
418impl<'de, S: AsRef<[u8]>> Iterator for Within<'de, S> {
419    type Item = Result<WithinItem<'de>, IPDBError>;
420
421    fn next(&mut self) -> Option<Self::Item> {
422        while !self.stack.is_empty() {
423            let current = self.stack.pop().unwrap();
424            let bit_count = current.ip_bytes.len() * 8;
425
426            // Skip networks that are aliases for the IPv4 network
427            if self.reader.ipv4_offset != 0
428                && current.node == self.reader.ipv4_offset
429                && bit_count == 128
430                && current.ip_bytes[..12].iter().any(|&b| b != 0)
431            {
432                continue;
433            }
434
435            match current.node.cmp(&self.node_count) {
436                Ordering::Greater => {
437                    // This is a data node, emit it and we're done (until the following next call)
438                    let ip_net = match bytes_and_prefix_to_net(
439                        &current.ip_bytes,
440                        current.prefix_len as u8,
441                    ) {
442                        Ok(ip_net) => ip_net,
443                        Err(e) => return Some(Err(e)),
444                    };
445
446                    let rec = match self.reader.resolve_data_pointer(current.node) {
447                        Ok(rec) => rec,
448                        Err(e) => return Some(Err(e)),
449                    };
450                    
451                    let data = self.reader.parse_data(rec.0, rec.1, 0);
452                    return match data {
453                        Ok(info) => Some(Ok(WithinItem { ip_net, info })),
454                        Err(e) => Some(Err(e)),
455                    };
456                }
457                Ordering::Equal => {
458                    // Dead end, nothing to do
459                }
460                Ordering::Less => {
461                    // In order traversal of our children
462                    // right/1-bit
463                    let mut right_ip_bytes = current.ip_bytes.clone();
464                    right_ip_bytes[current.prefix_len >> 3] |=
465                        1 << ((bit_count - current.prefix_len - 1) % 8);
466                    let node = match self.reader.read_node(current.node, 1) {
467                        Ok(node) => node,
468                        Err(e) => return Some(Err(e)),
469                    };
470                    self.stack.push(WithinNode {
471                        node,
472                        ip_bytes: right_ip_bytes,
473                        prefix_len: current.prefix_len + 1,
474                    });
475                    // left/0-bit
476                    let node = match self.reader.read_node(current.node, 0) {
477                        Ok(node) => node,
478                        Err(e) => return Some(Err(e)),
479                    };
480                    self.stack.push(WithinNode {
481                        node,
482                        ip_bytes: current.ip_bytes.clone(),
483                        prefix_len: current.prefix_len + 1,
484                    });
485                }
486            }
487        }
488        None
489    }
490}
491
/// Provides `city::CityInfo`, the record type returned by lookups and by
/// the `within` iterator.
pub mod city;
493
/// Return the octets of `address` in network order: 4 bytes for IPv4,
/// 16 bytes for IPv6.
#[inline]
fn ip_to_bytes(address: IpAddr) -> Vec<u8> {
    let octets = match address {
        IpAddr::V4(v4) => Vec::from(v4.octets()),
        IpAddr::V6(v6) => Vec::from(v6.octets()),
    };
    octets
}
501
502#[allow(clippy::many_single_char_names)]
503fn bytes_and_prefix_to_net(bytes: &[u8], prefix: u8) -> Result<IpNetwork, IPDBError> {
504    let (ip, pre) = match bytes.len() {
505        4 => (
506            IpAddr::V4(Ipv4Addr::new(bytes[0], bytes[1], bytes[2], bytes[3])),
507            prefix,
508        ),
509        16 => {
510            if bytes[0] == 0
511                && bytes[1] == 0
512                && bytes[2] == 0
513                && bytes[3] == 0
514                && bytes[4] == 0
515                && bytes[5] == 0
516                && bytes[6] == 0
517                && bytes[7] == 0
518                && bytes[8] == 0
519                && bytes[9] == 0
520                && bytes[10] == 0
521                && bytes[11] == 0
522            {
523                // It's actually v4, but in v6 form, convert would be nice if ipnetwork had this
524                // logic.
525                (
526                    IpAddr::V4(Ipv4Addr::new(bytes[12], bytes[13], bytes[14], bytes[15])),
527                    prefix - 96,
528                )
529            } else {
530                let a = (bytes[0] as u16) << 8 | bytes[1] as u16;
531                let b = (bytes[2] as u16) << 8 | bytes[3] as u16;
532                let c = (bytes[4] as u16) << 8 | bytes[5] as u16;
533                let d = (bytes[6] as u16) << 8 | bytes[7] as u16;
534                let e = (bytes[8] as u16) << 8 | bytes[9] as u16;
535                let f = (bytes[10] as u16) << 8 | bytes[11] as u16;
536                let g = (bytes[12] as u16) << 8 | bytes[13] as u16;
537                let h = (bytes[14] as u16) << 8 | bytes[15] as u16;
538                (IpAddr::V6(Ipv6Addr::new(a, b, c, d, e, f, g, h)), prefix)
539            }
540        }
541        // This should never happen
542        _ => return Err(IPDBError::InvalidNetworkError("invalid address".to_owned())),
543    };
544    IpNetwork::new(ip, pre).map_err(|e| IPDBError::InvalidNetworkError(e.to_string()))
545}