use std::net::IpAddr;
use ipnetwork::IpNetwork;
use serde::Deserialize;
use crate::decoder::{TYPE_ARRAY, TYPE_MAP};
use crate::error::MaxMindDbError;
use crate::reader::Reader;
/// The outcome of resolving an IP address against a [`Reader`].
///
/// Holds the address, the prefix length and, when a record was found, the offset of that
/// record. Nothing is decoded until [`LookupResult::decode`] or
/// [`LookupResult::decode_path`] is called.
#[derive(Debug)]
pub struct LookupResult<'a, S: AsRef<[u8]>> {
    /// Reader whose buffer contains the record.
    reader: &'a Reader<S>,
    /// Record offset, measured from the reader's `pointer_base`; `None` if nothing was found.
    data_offset: Option<usize>,
    /// Prefix length reported by the lookup; `network()` interprets it per `network_kind`.
    prefix_len: u8,
    /// Address associated with this result (the queried address for single lookups;
    /// presumably a network address when produced by iteration — confirm at call sites).
    ip: IpAddr,
    /// Which code path created this result.
    source: LookupSource,
    /// Which tree / address family `prefix_len` refers to.
    network_kind: NetworkKind,
}

// `Clone`/`Copy` are implemented by hand: `#[derive]` would require `S: Clone` / `S: Copy`,
// which needlessly excludes owned buffers such as `Vec<u8>`. Every field (including the
// `&Reader<S>` borrow) is `Copy` regardless of `S`.
impl<S: AsRef<[u8]>> Clone for LookupResult<'_, S> {
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}

impl<S: AsRef<[u8]>> Copy for LookupResult<'_, S> {}
/// Identifies the code path that created a [`LookupResult`].
///
/// [`LookupResult::network`] consults this: an IPv4 address paired with
/// [`NetworkKind::V6`] is only accepted when the source is `Lookup`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum LookupSource {
    /// Single-address lookup (presumably `Reader::lookup` — confirm at call sites).
    Lookup,
    /// Produced while iterating over the tree (presumably the reader's network iterator —
    /// confirm at call sites).
    Iter,
}
/// Describes how a [`LookupResult`]'s IP and prefix length combine into a network.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum NetworkKind {
    /// IPv4 address; the prefix length applies to it directly.
    V4,
    /// IPv6 prefix length. With an IPv6 address it applies directly; with an IPv4 address
    /// (accepted only for `LookupSource::Lookup`) the network is reported relative to `::`.
    V6,
    /// IPv4 address resolved inside the IPv4 subtree of an IPv6 tree; `network()` subtracts
    /// the reader's `ipv4_start_bit_depth` from the stored prefix length.
    V4InV6Subtree,
}
impl<'a, S: AsRef<[u8]>> LookupResult<'a, S> {
    /// Builds a decoder positioned at `offset`.
    ///
    /// The decoder is handed the slice of the reader's buffer that starts at
    /// `pointer_base`, so `offset` is measured from there.
    #[inline]
    fn decoder(&self, offset: usize) -> super::decoder::Decoder<'a> {
        let buf = &self.reader.buf.as_ref()[self.reader.pointer_base..];
        super::decoder::Decoder::new(buf, offset)
    }

    /// Creates a result for a lookup that found a record at `data_offset`.
    pub(crate) fn new_found(
        reader: &'a Reader<S>,
        data_offset: usize,
        prefix_len: u8,
        ip: IpAddr,
        source: LookupSource,
        network_kind: NetworkKind,
    ) -> Self {
        LookupResult {
            reader,
            data_offset: Some(data_offset),
            prefix_len,
            ip,
            source,
            network_kind,
        }
    }

    /// Creates a result for a lookup that found no record.
    pub(crate) fn new_not_found(
        reader: &'a Reader<S>,
        prefix_len: u8,
        ip: IpAddr,
        source: LookupSource,
        network_kind: NetworkKind,
    ) -> Self {
        LookupResult {
            reader,
            data_offset: None,
            prefix_len,
            ip,
            source,
            network_kind,
        }
    }

    /// Returns `true` if this result points at a data record.
    #[inline]
    pub fn has_data(&self) -> bool {
        self.data_offset.is_some()
    }

    /// Returns the network this result covers: the stored IP with every bit beyond the
    /// prefix cleared, paired with that prefix length.
    ///
    /// For [`NetworkKind::V4InV6Subtree`] the reader's `ipv4_start_bit_depth` is subtracted
    /// from the stored prefix length so the result is an IPv4 network.
    ///
    /// # Errors
    ///
    /// Returns [`MaxMindDbError::InvalidNetwork`] if `IpNetwork::new` rejects the
    /// computed address/prefix pair.
    ///
    /// # Panics
    ///
    /// Panics if the (source, network kind, address family) combination is not one of the
    /// handled cases, which indicates an internally inconsistent result.
    pub fn network(&self) -> Result<IpNetwork, MaxMindDbError> {
        let (ip, prefix) = match (self.source, self.network_kind, self.ip) {
            (_, NetworkKind::V4, IpAddr::V4(v4)) => (IpAddr::V4(v4), self.prefix_len),
            // Assumes `prefix_len` is counted from the IPv6 root here; subtracting the
            // subtree depth re-bases it onto the 32-bit IPv4 address.
            (_, NetworkKind::V4InV6Subtree, IpAddr::V4(v4)) => (
                IpAddr::V4(v4),
                self.prefix_len - self.reader.ipv4_start_bit_depth as u8,
            ),
            // Presumably an IPv4 lookup that stopped above the IPv4 subtree of an IPv6
            // tree (TODO confirm at construction site); reported as an IPv6 network at `::`.
            (LookupSource::Lookup, NetworkKind::V6, IpAddr::V4(_)) => {
                use std::net::Ipv6Addr;
                (IpAddr::V6(Ipv6Addr::UNSPECIFIED), self.prefix_len)
            }
            (_, NetworkKind::V6, IpAddr::V6(v6)) => (IpAddr::V6(v6), self.prefix_len),
            (_, _, ip) => unreachable!("unexpected lookup result state for network: {ip:?}"),
        };
        let network_ip = mask_ip(ip, prefix);
        IpNetwork::new(network_ip, prefix).map_err(MaxMindDbError::InvalidNetwork)
    }

    /// Returns the record offset (relative to the reader's `pointer_base`), or `None` if
    /// the lookup found no record.
    #[inline]
    pub fn offset(&self) -> Option<usize> {
        self.data_offset
    }

    /// Deserializes the whole record as `T`.
    ///
    /// Returns `Ok(None)` when the lookup found no record.
    ///
    /// # Errors
    ///
    /// Returns the deserializer's error if the record cannot be decoded as `T`.
    pub fn decode<T>(&self) -> Result<Option<T>, MaxMindDbError>
    where
        T: Deserialize<'a>,
    {
        let Some(offset) = self.data_offset else {
            return Ok(None);
        };
        let mut decoder = self.decoder(offset);
        T::deserialize(&mut decoder).map(Some)
    }

    /// Decodes the value reached by following `path` from the root of the record.
    ///
    /// [`PathElement::Key`] descends into the map entry with that key (compared
    /// byte-for-byte; the first matching entry wins). [`PathElement::Index`] and
    /// [`PathElement::IndexFromEnd`] descend into an array element counted from the front
    /// or the back respectively. The value at the end of the path is deserialized as `T`;
    /// an empty `path` deserializes the whole record.
    ///
    /// Returns `Ok(None)` when the lookup found no record, a key is absent, or an index is
    /// out of range.
    ///
    /// # Errors
    ///
    /// Returns a decoding error carrying the path up to and including the failing element
    /// when a `Key` meets a non-map value or an index meets a non-array value. Decoding
    /// errors raised while walking or deserializing that do not already carry a path are
    /// annotated with one.
    pub fn decode_path<T>(&self, path: &[PathElement<'_>]) -> Result<Option<T>, MaxMindDbError>
    where
        T: Deserialize<'a>,
    {
        let Some(offset) = self.data_offset else {
            return Ok(None);
        };
        let mut decoder = self.decoder(offset);
        for (i, element) in path.iter().enumerate() {
            // Errors raised at step `i` report the path walked so far, including `element`.
            let with_path = |e| add_path_context(e, &path[..=i]);
            match *element {
                PathElement::Key(key) => {
                    let (_, type_num) = decoder.peek_type().map_err(with_path)?;
                    if type_num != TYPE_MAP {
                        return Err(MaxMindDbError::decoding_at_path(
                            format!("expected map for Key(\"{key}\"), got type {type_num}"),
                            decoder.offset(),
                            render_path(&path[..=i]),
                        ));
                    }
                    let size = decoder.consume_map_header().map_err(with_path)?;
                    let key_bytes = key.as_bytes();
                    let mut found = false;
                    for _ in 0..size {
                        let k = decoder.read_str_as_bytes().map_err(with_path)?;
                        if k == key_bytes {
                            // The decoder now sits on the matching value; stop scanning.
                            found = true;
                            break;
                        }
                        decoder.skip_value().map_err(with_path)?;
                    }
                    if !found {
                        return Ok(None);
                    }
                }
                PathElement::Index(idx) | PathElement::IndexFromEnd(idx) => {
                    let from_end = matches!(*element, PathElement::IndexFromEnd(_));
                    let (_, type_num) = decoder.peek_type().map_err(with_path)?;
                    if type_num != TYPE_ARRAY {
                        let elem = if from_end {
                            format!("IndexFromEnd({idx})")
                        } else {
                            format!("Index({idx})")
                        };
                        return Err(MaxMindDbError::decoding_at_path(
                            format!("expected array for {elem}, got type {type_num}"),
                            decoder.offset(),
                            render_path(&path[..=i]),
                        ));
                    }
                    let size = decoder.consume_array_header().map_err(with_path)?;
                    if idx >= size {
                        return Ok(None);
                    }
                    // `idx < size` here, so `size - 1 - idx` cannot underflow.
                    let actual_idx = if from_end { size - 1 - idx } else { idx };
                    for _ in 0..actual_idx {
                        decoder.skip_value().map_err(with_path)?;
                    }
                }
            }
        }
        T::deserialize(&mut decoder)
            .map(Some)
            .map_err(|e| add_path_context(e, path))
    }
}
/// Attaches the rendered `path` to a decoding error that does not already carry one.
///
/// Errors other than `MaxMindDbError::Decoding` are returned unchanged. So are decoding
/// errors that already have a path, and any error when `path` is empty. Skipping the empty
/// path keeps `decode_path(&[])` errors identical to `decode()` errors instead of
/// reporting a blank path.
fn add_path_context(err: MaxMindDbError, path: &[PathElement<'_>]) -> MaxMindDbError {
    if path.is_empty() {
        return err;
    }
    match err {
        MaxMindDbError::Decoding {
            message,
            offset,
            path: None,
        } => MaxMindDbError::Decoding {
            message,
            offset,
            path: Some(render_path(path)),
        },
        other => other,
    }
}
/// Renders `path` as a slash-separated string for error messages, e.g. `/city/names`.
///
/// Each element is prefixed with `/`:
/// - keys are written verbatim;
/// - `Index(i)` is written as `i`;
/// - `IndexFromEnd(i)` is written as the negative index `-(i + 1)`, so `IndexFromEnd(0)`
///   renders as `-1`.
///
/// An empty path renders as the empty string.
fn render_path(path: &[PathElement<'_>]) -> String {
    use std::fmt::Write;
    let mut s = String::new();
    for elem in path {
        s.push('/');
        // Writing into a `String` cannot fail, so the `unwrap`s below never panic.
        match elem {
            PathElement::Key(k) => s.push_str(k),
            PathElement::Index(i) => write!(s, "{i}").unwrap(),
            // `i + 1` is computed in u128 (usize -> u128 is lossless) so that large
            // offsets such as `usize::MAX` neither overflow nor wrap to the wrong value,
            // which the previous `-((i as isize) + 1)` did.
            PathElement::IndexFromEnd(i) => write!(s, "-{}", (*i as u128) + 1).unwrap(),
        }
    }
    s
}

/// One step in a path passed to [`LookupResult::decode_path`].
///
/// Usually built with the `path!` macro or the `From` conversions:
/// - `&str` becomes `Key`;
/// - `usize` and non-negative signed integers become `Index`;
/// - negative signed integers become `IndexFromEnd`, where `-1` is the last element.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PathElement<'a> {
    /// Selects the value stored under this key in a map.
    Key(&'a str),
    /// Selects the array element at this zero-based index.
    Index(usize),
    /// Selects an array element counted from the end; `0` is the last element.
    IndexFromEnd(usize),
}
impl<'a> From<&'a str> for PathElement<'a> {
fn from(s: &'a str) -> Self {
PathElement::Key(s)
}
}
/// Turns an `i32` into an array step.
///
/// Non-negative values index from the front; negative values index from the back, with
/// `-1` meaning the last element. Integer literals in `path!` resolve to this impl.
impl From<i32> for PathElement<'_> {
    fn from(index: i32) -> Self {
        let widened = index as isize;
        signed_index_to_path_element(widened)
    }
}
impl From<usize> for PathElement<'_> {
fn from(n: usize) -> Self {
PathElement::Index(n)
}
}
/// Turns an `isize` into an array step; negative values count from the end
/// (`-1` is the last element).
impl From<isize> for PathElement<'_> {
    fn from(index: isize) -> Self {
        let element = signed_index_to_path_element(index);
        element
    }
}
/// Maps a signed array index onto a [`PathElement`].
///
/// Non-negative `n` becomes `Index(n)`. Negative `n` counts from the end:
/// - `-1` becomes `IndexFromEnd(0)`;
/// - `-2` becomes `IndexFromEnd(1)`;
/// - and so on.
///
/// `isize::MIN` has no positive counterpart and saturates to `IndexFromEnd(usize::MAX)`.
fn signed_index_to_path_element<'a>(n: isize) -> PathElement<'a> {
    if !n.is_negative() {
        return PathElement::Index(n as usize);
    }
    let from_end = if n == isize::MIN {
        usize::MAX
    } else {
        // `-n >= 1` here, so `-n - 1` neither overflows nor goes negative.
        (-n - 1) as usize
    };
    PathElement::IndexFromEnd(from_end)
}
/// Builds an array of `PathElement`s for `LookupResult::decode_path`.
///
/// Each argument goes through `PathElement::from`:
/// - string slices become `Key`s;
/// - non-negative integers become `Index`es;
/// - negative integers index from the end (`-1` is the last element).
///
/// A trailing comma is accepted, and `path![]` yields an empty array. For example,
/// `path!["subdivisions", 0, "names", "en"]` expands to
/// `[Key("subdivisions"), Index(0), Key("names"), Key("en")]`.
#[macro_export]
macro_rules! path {
    ($($elem:expr),* $(,)?) => {
        // `$crate::` keeps the expansion valid from downstream crates.
        [$($crate::PathElement::from($elem)),*]
    };
}
/// Clears every bit of `ip` after the first `prefix` bits.
///
/// A `prefix` at or above the address width (32 for IPv4, 128 for IPv6) returns `ip`
/// unchanged. A `prefix` of 0 yields the all-zeros address of the same family.
fn mask_ip(ip: IpAddr, prefix: u8) -> IpAddr {
    match ip {
        IpAddr::V4(addr) if prefix < 32 => {
            // A full-width shift (prefix 0) is `None` from `checked_shl`: keep no bits.
            let keep = u32::MAX.checked_shl(32 - u32::from(prefix)).unwrap_or(0);
            let bits: u32 = addr.into();
            IpAddr::V4((bits & keep).into())
        }
        IpAddr::V6(addr) if prefix < 128 => {
            let keep = u128::MAX.checked_shl(128 - u32::from(prefix)).unwrap_or(0);
            let bits: u128 = addr.into();
            IpAddr::V6((bits & keep).into())
        }
        // The prefix spans the whole address: nothing to clear.
        unmasked => unmasked,
    }
}
// Unit tests for prefix masking, PathElement conversions, the `path!` macro, path
// rendering, and the path context attached to `decode_path` errors.
#[cfg(test)]
mod tests {
    use super::*;

    // mask_ip clears host bits for IPv4, including the full-width and zero-prefix edges.
    #[test]
    fn test_mask_ipv4() {
        let ip: IpAddr = "192.168.1.100".parse().unwrap();
        assert_eq!(mask_ip(ip, 24), "192.168.1.0".parse::<IpAddr>().unwrap());
        assert_eq!(mask_ip(ip, 16), "192.168.0.0".parse::<IpAddr>().unwrap());
        // A full-width prefix leaves the address untouched.
        assert_eq!(mask_ip(ip, 32), "192.168.1.100".parse::<IpAddr>().unwrap());
        // A zero prefix clears every bit.
        assert_eq!(mask_ip(ip, 0), "0.0.0.0".parse::<IpAddr>().unwrap());
    }

    // mask_ip clears host bits for IPv6.
    #[test]
    fn test_mask_ipv6() {
        let ip: IpAddr = "2001:db8:85a3::8a2e:370:7334".parse().unwrap();
        assert_eq!(
            mask_ip(ip, 64),
            "2001:db8:85a3::".parse::<IpAddr>().unwrap()
        );
        assert_eq!(mask_ip(ip, 32), "2001:db8::".parse::<IpAddr>().unwrap());
    }

    // The derived Debug output is part of what users see in logs.
    #[test]
    fn test_path_element_debug() {
        assert_eq!(format!("{:?}", PathElement::Key("test")), "Key(\"test\")");
        assert_eq!(format!("{:?}", PathElement::Index(5)), "Index(5)");
        assert_eq!(
            format!("{:?}", PathElement::IndexFromEnd(0)),
            "IndexFromEnd(0)"
        );
    }

    #[test]
    fn test_path_element_from_str() {
        let elem: PathElement = "key".into();
        assert_eq!(elem, PathElement::Key("key"));
    }

    // Negative i32 values count from the end: -1 is the last element.
    #[test]
    fn test_path_element_from_i32() {
        let elem: PathElement = PathElement::from(0i32);
        assert_eq!(elem, PathElement::Index(0));
        let elem: PathElement = PathElement::from(5i32);
        assert_eq!(elem, PathElement::Index(5));
        let elem: PathElement = PathElement::from(-1i32);
        assert_eq!(elem, PathElement::IndexFromEnd(0));
        let elem: PathElement = PathElement::from(-2i32);
        assert_eq!(elem, PathElement::IndexFromEnd(1));
        let elem: PathElement = PathElement::from(-3i32);
        assert_eq!(elem, PathElement::IndexFromEnd(2));
    }

    #[test]
    fn test_path_element_from_usize() {
        let elem: PathElement = PathElement::from(0usize);
        assert_eq!(elem, PathElement::Index(0));
        let elem: PathElement = PathElement::from(42usize);
        assert_eq!(elem, PathElement::Index(42));
    }

    // isize::MIN cannot be negated; the conversion saturates to usize::MAX.
    #[test]
    fn test_path_element_from_isize() {
        let elem: PathElement = PathElement::from(0isize);
        assert_eq!(elem, PathElement::Index(0));
        let elem: PathElement = PathElement::from(-1isize);
        assert_eq!(elem, PathElement::IndexFromEnd(0));
        let elem: PathElement = PathElement::from(isize::MIN);
        assert_eq!(elem, PathElement::IndexFromEnd(usize::MAX));
    }

    #[test]
    fn test_path_macro_keys_only() {
        let p = path!["country", "iso_code"];
        assert_eq!(p.len(), 2);
        assert_eq!(p[0], PathElement::Key("country"));
        assert_eq!(p[1], PathElement::Key("iso_code"));
    }

    // Integer literals mixed with keys resolve through the From<i32> impl.
    #[test]
    fn test_path_macro_mixed() {
        let p = path!["subdivisions", 0, "names", "en"];
        assert_eq!(p.len(), 4);
        assert_eq!(p[0], PathElement::Key("subdivisions"));
        assert_eq!(p[1], PathElement::Index(0));
        assert_eq!(p[2], PathElement::Key("names"));
        assert_eq!(p[3], PathElement::Key("en"));
    }

    #[test]
    fn test_path_macro_negative_indexes() {
        let p = path!["array", -1];
        assert_eq!(p.len(), 2);
        assert_eq!(p[0], PathElement::Key("array"));
        assert_eq!(p[1], PathElement::IndexFromEnd(0));
        let p = path!["data", -2, "value"];
        assert_eq!(p[1], PathElement::IndexFromEnd(1));
    }

    #[test]
    fn test_path_macro_trailing_comma() {
        let p = path!["a", "b",];
        assert_eq!(p.len(), 2);
    }

    #[test]
    fn test_path_macro_empty() {
        let p: [PathElement; 0] = path![];
        assert_eq!(p.len(), 0);
    }

    // IndexFromEnd(i) renders as the negative index -(i + 1).
    #[test]
    fn test_render_path() {
        assert_eq!(render_path(&[]), "");
        assert_eq!(render_path(&[PathElement::Key("city")]), "/city");
        assert_eq!(
            render_path(&[PathElement::Key("city"), PathElement::Key("names")]),
            "/city/names"
        );
        assert_eq!(
            render_path(&[PathElement::Key("arr"), PathElement::Index(0)]),
            "/arr/0"
        );
        assert_eq!(
            render_path(&[PathElement::Key("arr"), PathElement::Index(42)]),
            "/arr/42"
        );
        assert_eq!(
            render_path(&[PathElement::Key("arr"), PathElement::IndexFromEnd(0)]),
            "/arr/-1"
        );
        assert_eq!(
            render_path(&[PathElement::Key("arr"), PathElement::IndexFromEnd(1)]),
            "/arr/-2"
        );
    }

    // Type-mismatch errors from decode_path carry the path up to the failing element.
    // Requires the GeoIP2 City test database fixture on disk.
    #[test]
    fn test_decode_path_error_includes_path() {
        use crate::Reader;
        let reader = Reader::open_readfile("test-data/test-data/GeoIP2-City-Test.mmdb").unwrap();
        let ip: IpAddr = "89.160.20.128".parse().unwrap();
        let result = reader.lookup(ip).unwrap();
        // Indexing the record root fails because the root is not an array.
        let err = result
            .decode_path::<String>(&[PathElement::Index(0)])
            .unwrap_err();
        let err_str = err.to_string();
        assert!(
            err_str.contains("path: /0"),
            "error should include path context: {err_str}"
        );
        assert!(
            err_str.contains("expected array"),
            "error should mention expected type: {err_str}"
        );
        // A nested failure reports the full path walked so far.
        let err = result
            .decode_path::<String>(&[PathElement::Key("city"), PathElement::Index(0)])
            .unwrap_err();
        let err_str = err.to_string();
        assert!(
            err_str.contains("path: /city/0"),
            "error should include full path to failure: {err_str}"
        );
    }
}