ip2location-bin-format 0.4.0

IP2Location BIN Format
Documentation
use std::{
    io::SeekFrom,
    net::{IpAddr, Ipv4Addr, Ipv6Addr},
};

use futures_util::{AsyncRead, AsyncReadExt as _, AsyncSeek, AsyncSeekExt as _};

use super::error::Error;
use crate::{
    record_field::{RecordFieldContent, RecordFieldContents, RecordFields},
    records::PositionRange,
};

//
/// Query state for one record table read through an underlying stream.
///
/// Bundles the stream with the values `query` needs to locate a record
/// (base offset, record layout) and a reusable read buffer.
#[derive(Debug)]
pub(super) struct Inner<S> {
    /// Underlying stream; `query` seeks and reads on it.
    stream: S,
    /// Upper bound that `query` clamps search positions to
    /// (presumably the number of records in the table; confirm against the caller).
    count: u32,
    /// Absolute offset added to every per-record byte offset before seeking.
    seek_from_start_base: u64,
    /// Record layout; `query` asks it for the byte length of the records
    /// preceding a given position.
    record_fields: RecordFields,
    /// Template that `query` copies and fills in for each successful lookup.
    record_field_contents: RecordFieldContents,
    /// Scratch buffer filled by `read_exact` on every probe; its length
    /// determines how many bytes are read per probe.
    buf: Vec<u8>,
}

impl<S> Inner<S> {
    pub(super) fn new(
        stream: S,
        count: u32,
        seek_from_start_base: u64,
        record_fields: RecordFields,
        record_field_contents: RecordFieldContents,
        buf: Vec<u8>,
    ) -> Self {
        Self {
            stream,
            count,
            seek_from_start_base,
            record_fields,
            record_field_contents,
            buf,
        }
    }
}

//
//
//
impl<S> Inner<S>
where
    S: AsyncSeek + AsyncRead + Unpin,
{
    /// Binary-searches the record table for the range containing `ip`.
    ///
    /// The search window starts as the given [`PositionRange`], clamped so
    /// that `end <= self.count` and `start <= end`. Each probe seeks to the
    /// record at the midpoint, fills `self.buf` from the stream and decodes
    /// the record's lower bound (`ip_from`) from the start of the buffer.
    /// The upper bound (`ip_to`) is decoded from the trailing bytes of the
    /// buffer while `high < self.count`; otherwise it is `ip_from + 1`
    /// (saturating).
    ///
    /// NOTE(review): this assumes `self.buf` is sized to hold one record plus
    /// the next record's `ip_from`; confirm against the code that builds it.
    ///
    /// All multi-byte values are decoded as little-endian, which is the
    /// on-disk byte order of the BIN format, so results do not depend on the
    /// host's endianness.
    ///
    /// # Returns
    ///
    /// * `Ok(Some((ip_from, ip_to, contents)))` when `ip_from <= ip < ip_to`
    ///   for some probed record. `contents` is a copy of the field template
    ///   with each field filled from the record's 4-byte slots.
    /// * `Ok(None)` when the window is exhausted without a match.
    ///
    /// # Errors
    ///
    /// * [`Error::SeekFailed`] if seeking the stream fails.
    /// * [`Error::ReadFailed`] if the buffer cannot be filled.
    /// * [`Error::MaxDepthReached`] if the search has not finished after the
    ///   maximum number of probes.
    pub(super) async fn query(
        &mut self,
        ip: IpAddr,
        PositionRange {
            start: mut low,
            end: mut high,
        }: PositionRange,
    ) -> Result<Option<(IpAddr, IpAddr, RecordFieldContents)>, Error> {
        // Clamp the requested window to the table.
        high = high.min(self.count);
        low = low.min(high);

        // Width in bytes of an address of the queried family.
        let addr_len: usize = match ip {
            IpAddr::V4(_) => 4,
            IpAddr::V6(_) => 16,
        };

        let mut n_depth = 0;
        while low <= high {
            // `low + high` may exceed `u32::MAX` for very large tables, so
            // compute the midpoint from the window width instead.
            let mid = low + ((high - low) >> 1);

            let records_bytes_len = match ip {
                IpAddr::V4(_) => self.record_fields.records_bytes_len_for_ipv4(mid),
                IpAddr::V6(_) => self.record_fields.records_bytes_len_for_ipv6(mid),
            };
            let seek_from_start = self.seek_from_start_base + records_bytes_len as u64;

            self.stream
                .seek(SeekFrom::Start(seek_from_start))
                .await
                .map_err(Error::SeekFailed)?;

            self.stream
                .read_exact(&mut self.buf)
                .await
                .map_err(Error::ReadFailed)?;

            let ip_from = decode_ip_le(ip, &self.buf);
            let ip_to: IpAddr = if high < self.count {
                // The upper bound is taken from the last `addr_len` bytes of the buffer.
                let tail_start = self.buf.len() - addr_len;
                decode_ip_le(ip, &self.buf[tail_start..])
            } else {
                match ip_from {
                    IpAddr::V4(ip_from) => {
                        Ipv4Addr::from(u32::from(ip_from).saturating_add(1)).into()
                    }
                    IpAddr::V6(ip_from) => {
                        Ipv6Addr::from(u128::from(ip_from).saturating_add(1)).into()
                    }
                }
            };

            if (ip >= ip_from) && (ip < ip_to) {
                let mut record_field_contents = self.record_field_contents.to_owned();
                for (n, record_field_content) in record_field_contents.iter_mut().enumerate() {
                    // Field slots are 4 bytes each and follow the leading address.
                    let index = addr_len + n * 4;
                    let raw: [u8; 4] = self.buf[index..index + 4]
                        .try_into()
                        .expect("slice is exactly 4 bytes long");
                    let content_index = u32::from_le_bytes(raw);

                    match record_field_content {
                        // Coordinates are stored inline as 32-bit floats.
                        RecordFieldContent::LATITUDE(v) | RecordFieldContent::LONGITUDE(v) => {
                            *v = f32::from_le_bytes(raw)
                        }
                        // Every other field stores a 32-bit index.
                        RecordFieldContent::COUNTRY(i, _, _) => *i = content_index,
                        RecordFieldContent::REGION(i, _)
                        | RecordFieldContent::CITY(i, _)
                        | RecordFieldContent::ISP(i, _)
                        | RecordFieldContent::DOMAIN(i, _)
                        | RecordFieldContent::ZIPCODE(i, _)
                        | RecordFieldContent::TIMEZONE(i, _)
                        | RecordFieldContent::NETSPEED(i, _)
                        | RecordFieldContent::PROXYTYPE(i, _)
                        | RecordFieldContent::USAGETYPE(i, _)
                        | RecordFieldContent::ASN(i, _)
                        | RecordFieldContent::AS(i, _)
                        | RecordFieldContent::LASTSEEN(i, _)
                        | RecordFieldContent::THREAT(i, _)
                        | RecordFieldContent::RESIDENTIAL(i, _)
                        | RecordFieldContent::PROVIDER(i, _) => *i = content_index,
                    }
                }

                return Ok(Some((ip_from, ip_to, record_field_contents)));
            } else if ip < ip_from {
                high = mid.saturating_sub(1);
            } else {
                low = mid.saturating_add(1);
            }

            // The window collapsed onto position 0: stop searching.
            if high == 0 {
                return Ok(None);
            }
            // `low` ran off the end of the table. When `count` is `u32::MAX`
            // the saturating increment can never exceed it, so equality is
            // the end-of-table signal in that case.
            let past_end = if self.count == u32::MAX {
                low == self.count
            } else {
                low > self.count
            };
            if past_end {
                return Ok(None);
            }

            // Guard against a search that fails to converge.
            if n_depth > 30 {
                return Err(Error::MaxDepthReached);
            }

            n_depth += 1;
        }

        Ok(None)
    }
}

/// Decodes a little-endian IP address of the same family as `family` from
/// the first 4 (IPv4) or 16 (IPv6) bytes of `bytes`.
///
/// Panics if `bytes` is shorter than the address width.
fn decode_ip_le(family: IpAddr, bytes: &[u8]) -> IpAddr {
    match family {
        IpAddr::V4(_) => {
            let raw: [u8; 4] = bytes[..4]
                .try_into()
                .expect("slice is exactly 4 bytes long");
            Ipv4Addr::from(u32::from_le_bytes(raw)).into()
        }
        IpAddr::V6(_) => {
            let raw: [u8; 16] = bytes[..16]
                .try_into()
                .expect("slice is exactly 16 bytes long");
            Ipv6Addr::from(u128::from_le_bytes(raw)).into()
        }
    }
}