//! IP-address-to-ASN lookup.
//!
//! [`Builder`] reads line-oriented range data from a file, from any
//! [`BufRead`](std::io::BufRead) source or (with the `fetch` feature) from a
//! URL, and produces an [`IpAsnMap`]. The map answers longest-prefix-match
//! queries through [`IpAsnMap::lookup`] and [`IpAsnMap::lookup_owned`].
#![deny(missing_docs)]
// String interner; the builder uses it to store each organization name once.
mod interner;
/// Line parsing; provides the `parse_line` function and `ParsedLine` type used by the builder.
pub mod parser;
/// Conversion of a start/end IP range into CIDR networks (`range_to_cidrs`).
pub mod range;
/// Record types stored in the lookup table (`AsnRecord`).
pub mod types;
use crate::interner::StringInterner;
use crate::parser::{parse_line, ParsedLine};
use crate::range::range_to_cidrs;
use crate::types::AsnRecord;
use flate2::read::GzDecoder;
use ip_network::IpNetwork;
use ip_network_table::IpNetworkTable;
use std::error::Error as StdError;
use std::fmt;
use std::fs::File;
use std::hash::{Hash, Hasher};
use std::io::{BufRead, BufReader};
use std::net::IpAddr;
use std::path::Path;
/// Errors that can prevent an [`IpAsnMap`] from being built.
#[derive(Debug)]
#[non_exhaustive]
pub enum Error {
/// An I/O operation failed, or [`Builder::build`] was called without a data source.
Io(std::io::Error),
/// An HTTP request failed. Only available when the `fetch` feature is enabled.
#[cfg(feature = "fetch")]
Http(reqwest::Error),
/// A data line could not be parsed. [`Builder::build`] returns this only in strict mode.
Parse {
/// 1-based number of the offending line within the source.
line_number: usize,
/// Text of the offending line as read from the source.
line_content: String,
/// What exactly was wrong with the line.
kind: ParseErrorKind,
},
}
impl StdError for Error {
fn source(&self) -> Option<&(dyn StdError + 'static)> {
match self {
Error::Io(e) => Some(e),
#[cfg(feature = "fetch")]
Error::Http(e) => Some(e),
Error::Parse { .. } => None,
}
}
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Error::Io(e) => write!(f, "I/O error: {e}"),
#[cfg(feature = "fetch")]
Error::Http(e) => write!(f, "HTTP error: {e}"),
Error::Parse {
line_number,
line_content,
kind,
} => write!(
f,
"Parse error on line {line_number}: {kind} in line: \"{line_content}\""
),
}
}
}
impl From<std::io::Error> for Error {
fn from(err: std::io::Error) -> Self {
Error::Io(err)
}
}
#[cfg(feature = "fetch")]
impl From<reqwest::Error> for Error {
    /// Wraps an HTTP client error in [`Error::Http`], which lets `?` convert request failures.
    fn from(err: reqwest::Error) -> Self {
        Self::Http(err)
    }
}
/// Non-fatal problems reported while building a map in non-strict mode.
///
/// Warnings are delivered to the callback registered with `Builder::on_warning`;
/// the line that triggered the warning is skipped.
#[derive(Debug)]
#[non_exhaustive]
pub enum Warning {
    /// A line could not be parsed for a reason other than an IP family mismatch.
    Parse {
        /// 1-based number of the skipped line within the source.
        line_number: usize,
        /// Text of the skipped line as read from the source.
        line_content: String,
        /// Description of the problem, derived by the builder from the parse error.
        message: String,
    },
    /// The start and end addresses of a line belong to different IP families.
    IpFamilyMismatch {
        /// 1-based number of the skipped line within the source.
        line_number: usize,
        /// Text of the skipped line as read from the source.
        line_content: String,
    },
}

impl fmt::Display for Warning {
    /// Writes a single-line description that includes the line number and the quoted line text.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::IpFamilyMismatch {
                line_number,
                line_content,
            } => {
                write!(
                    f,
                    "IP family mismatch on line {line_number}: \"{line_content}\""
                )
            }
            Self::Parse {
                line_number,
                line_content,
                message,
            } => {
                write!(
                    f,
                    "Parse warning on line {line_number}: {message} in line: \"{line_content}\""
                )
            }
        }
    }
}
/// The specific reason a data line was rejected by the parser.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ParseErrorKind {
    /// The line did not have the expected number of columns.
    IncorrectColumnCount {
        /// Number of columns the parser expected.
        expected: usize,
        /// Number of columns actually present.
        found: usize,
    },
    /// A field that should hold an IP address could not be parsed as one.
    InvalidIpAddress {
        /// Name of the field that held the bad value.
        field: String,
        /// The text that failed to parse.
        value: String,
    },
    /// The ASN field could not be parsed.
    InvalidAsnNumber {
        /// The text that failed to parse.
        value: String,
    },
    /// The start address is greater than the end address.
    InvalidRange {
        /// Start address of the rejected range.
        start_ip: IpAddr,
        /// End address of the rejected range.
        end_ip: IpAddr,
    },
    /// The start and end addresses belong to different IP families.
    IpFamilyMismatch,
    /// The country code field was rejected.
    InvalidCountryCode {
        /// The text that was rejected.
        value: String,
    },
}

impl fmt::Display for ParseErrorKind {
    /// Writes a short lower-case description suitable for embedding in a longer message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // The only variant without payload: a fixed message.
            Self::IpFamilyMismatch => f.write_str("start and end IPs are of different families"),
            Self::IncorrectColumnCount { expected, found } => write!(
                f,
                "expected {expected} columns, but found {found}"
            ),
            Self::InvalidIpAddress { field, value } => write!(
                f,
                "invalid IP address for field `{field}`: {value}"
            ),
            Self::InvalidAsnNumber { value } => write!(f, "invalid ASN: {value}"),
            Self::InvalidRange { start_ip, end_ip } => write!(
                f,
                "start IP {start_ip} is greater than end IP {end_ip}"
            ),
            Self::InvalidCountryCode { value } => write!(f, "invalid country code: {value}"),
        }
    }
}
/// An in-memory table mapping IP networks to ASN information.
///
/// Build one with [`IpAsnMap::builder`] and query it with [`IpAsnMap::lookup`].
pub struct IpAsnMap {
/// Prefix table holding one `AsnRecord` per inserted network.
table: IpNetworkTable<AsnRecord>,
/// Organization names, indexed by each record's `organization_idx`.
organizations: Vec<String>,
}
impl fmt::Debug for IpAsnMap {
    /// Prints a compact summary: only the number of stored organization names.
    ///
    /// The network table itself is left out and marked as omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let organization_count = self.organizations.len();
        let mut repr = f.debug_struct("IpAsnMap");
        repr.field("organizations", &organization_count);
        repr.finish_non_exhaustive()
    }
}
impl Default for IpAsnMap {
    /// Produces an empty map: no networks and no organization names.
    fn default() -> Self {
        let organizations = Vec::new();
        let table = IpNetworkTable::new();
        Self {
            table,
            organizations,
        }
    }
}
impl IpAsnMap {
    /// Creates an empty map; every lookup on it returns `None`.
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns a [`Builder`] for loading data into a new map.
    pub fn builder() -> Builder<'static> {
        Builder::new()
    }

    /// Finds the most specific network containing `ip`.
    ///
    /// Returns a view that borrows its strings from the map, or `None` when no
    /// stored network contains the address. If the stored country code bytes
    /// are not valid UTF-8, the view's `country_code` is the empty string.
    pub fn lookup(&self, ip: IpAddr) -> Option<AsnInfoView> {
        // Bail out early when no stored network covers the address.
        let (network, record) = self.table.longest_match(ip)?;
        let org_slot = record.organization_idx as usize;
        let country_code = std::str::from_utf8(&record.country_code).unwrap_or_default();
        Some(AsnInfoView {
            network,
            asn: record.asn,
            country_code,
            organization: &self.organizations[org_slot],
        })
    }

    /// Same as [`IpAsnMap::lookup`], but returns an owned [`AsnInfo`] with copied strings.
    pub fn lookup_owned(&self, ip: IpAddr) -> Option<AsnInfo> {
        let view = self.lookup(ip)?;
        Some(view.into())
    }
}
/// Owned ASN information for a matched network.
///
/// This is the owned counterpart of [`AsnInfoView`] and is returned by
/// [`IpAsnMap::lookup_owned`]. Values order by `asn` first, then `network`,
/// `country_code` and `organization`.
#[derive(Debug, Clone, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct AsnInfo {
/// The matched network.
pub network: IpNetwork,
/// The autonomous system number recorded for the network.
pub asn: u32,
/// The country code recorded for the network.
pub country_code: String,
/// The organization name recorded for the network.
pub organization: String,
}
impl PartialEq for AsnInfo {
    /// Two values are equal when all four fields are equal.
    fn eq(&self, other: &Self) -> bool {
        let Self {
            network,
            asn,
            country_code,
            organization,
        } = self;
        *network == other.network
            && *asn == other.asn
            && *country_code == other.country_code
            && *organization == other.organization
    }
}
impl PartialOrd for AsnInfo {
/// Delegates to [`Ord::cmp`], so the partial order always agrees with the total order.
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl Ord for AsnInfo {
    /// Orders by ASN first, then by network, country code and organization.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Tuples compare lexicographically, which gives the field priority above.
        let lhs = (
            self.asn,
            &self.network,
            &self.country_code,
            &self.organization,
        );
        let rhs = (
            other.asn,
            &other.network,
            &other.country_code,
            &other.organization,
        );
        lhs.cmp(&rhs)
    }
}
impl Hash for AsnInfo {
    /// Feeds all four fields into the hasher, matching the fields compared by `PartialEq`.
    fn hash<H: Hasher>(&self, state: &mut H) {
        let Self {
            network,
            asn,
            country_code,
            organization,
        } = self;
        // Field order is kept stable so hash values do not change.
        network.hash(state);
        asn.hash(state);
        country_code.hash(state);
        organization.hash(state);
    }
}
impl fmt::Display for AsnInfo {
    /// Formats as `AS<asn> <organization> (<country_code>) in <network>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            network,
            asn,
            country_code,
            organization,
        } = self;
        write!(
            f,
            "AS{} {} ({}) in {}",
            asn, organization, country_code, network
        )
    }
}
impl<'a> From<AsnInfoView<'a>> for AsnInfo {
fn from(view: AsnInfoView<'a>) -> Self {
Self {
network: view.network,
asn: view.asn,
country_code: view.country_code.to_string(),
organization: view.organization.to_string(),
}
}
}
/// Configures and builds an [`IpAsnMap`].
///
/// The lifetime `'a` bounds the data source and the warning callback.
#[derive(Default)]
pub struct Builder<'a> {
/// Reader that supplies the data lines; `None` until a source has been set.
source: Option<Box<dyn BufRead + Send + 'a>>,
/// When `true`, the first unparsable line aborts the build with an error.
strict: bool,
/// Callback that receives a warning for each line skipped in non-strict mode.
on_warning: Option<Box<dyn Fn(Warning) + Send + 'a>>,
}
impl<'a> fmt::Debug for Builder<'a> {
    /// Prints the configuration state without the source or callback themselves.
    ///
    /// Both are trait objects, so only their presence is reported.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let has_source = self.source.is_some();
        let has_on_warning = self.on_warning.is_some();
        let mut repr = f.debug_struct("Builder");
        repr.field("has_source", &has_source);
        repr.field("strict", &self.strict);
        repr.field("has_on_warning", &has_on_warning);
        repr.finish()
    }
}
impl<'a> Builder<'a> {
pub fn new() -> Self {
Self::default()
}
pub fn from_path<P: AsRef<Path>>(mut self, path: P) -> Result<Self, Error> {
let file = File::open(path.as_ref())?;
let reader = BufReader::new(file);
self.source = Some(self.create_source_from_reader(reader)?);
Ok(self)
}
pub fn with_source(mut self, source: impl BufRead + Send + 'a) -> Result<Self, Error> {
self.source = Some(self.create_source_from_reader(source)?);
Ok(self)
}
#[cfg(feature = "fetch")]
pub fn from_url(mut self, url: &str) -> Result<Self, Error> {
let response = reqwest::blocking::get(url)?;
let response = response.error_for_status()?;
let reader = BufReader::new(response);
self.source = Some(self.create_source_from_reader(reader)?);
Ok(self)
}
pub fn strict(mut self) -> Self {
self.strict = true;
self
}
pub fn on_warning<F>(mut self, callback: F) -> Self
where
F: Fn(Warning) + Send + 'a,
{
self.on_warning = Some(Box::new(callback));
self
}
fn create_source_from_reader(
&self,
mut reader: impl BufRead + Send + 'a,
) -> Result<Box<dyn BufRead + Send + 'a>, Error> {
let is_gzipped = {
let header = reader.fill_buf()?;
header.starts_with(&[0x1f, 0x8b])
};
let source: Box<dyn BufRead + Send + 'a> = if is_gzipped {
Box::new(BufReader::new(GzDecoder::new(reader)))
} else {
Box::new(reader)
};
Ok(source)
}
pub fn build(self) -> Result<IpAsnMap, Error> {
let source = self.source.ok_or_else(|| {
Error::Io(std::io::Error::new(
std::io::ErrorKind::NotFound,
"No data source provided",
))
})?;
let mut interner = StringInterner::new();
let mut table = IpNetworkTable::new();
for (i, line_result) in source.lines().enumerate() {
let line_number = i + 1;
let line = line_result?;
if line.is_empty() || line.starts_with('#') {
continue;
}
let parsed: ParsedLine = match parse_line(&line) {
Ok(p) => p,
Err(kind) => {
if self.strict {
return Err(Error::Parse {
line_number,
line_content: line,
kind,
});
} else if let Some(callback) = &self.on_warning {
let warning = if kind == ParseErrorKind::IpFamilyMismatch {
Warning::IpFamilyMismatch {
line_number,
line_content: line,
}
} else {
Warning::Parse {
line_number,
line_content: line,
message: format!("{kind:?}"),
}
};
callback(warning);
}
continue;
}
};
let org_idx = interner.get_or_intern(parsed.organization);
let record = AsnRecord {
asn: parsed.asn,
country_code: parsed.country_code,
organization_idx: org_idx,
};
for cidr in range_to_cidrs(parsed.start_ip, parsed.end_ip) {
table.insert(cidr, record);
}
}
let organizations = interner.into_vec();
Ok(IpAsnMap {
table,
organizations,
})
}
}
/// Borrowed ASN information for a matched network, as returned by [`IpAsnMap::lookup`].
///
/// The string fields borrow from the map. Convert to [`AsnInfo`] with `From`/`Into`
/// to get an owned copy.
#[derive(Debug, PartialEq, Eq)]
#[non_exhaustive]
pub struct AsnInfoView<'a> {
/// The matched network.
pub network: IpNetwork,
/// The autonomous system number recorded for the network.
pub asn: u32,
/// The country code recorded for the network; empty if the stored bytes are not valid UTF-8.
pub country_code: &'a str,
/// The organization name recorded for the network.
pub organization: &'a str,
}
// Unit tests for error/warning formatting, the `AsnInfo` trait impls and `Builder` behaviour.
#[cfg(test)]
mod tests {
use super::*;
use std::io;
// `Io` (and `Http` under the `fetch` feature) expose their inner error via `source()`;
// `Parse` has no source.
#[test]
fn test_error_source() {
let io_error = Error::Io(io::Error::new(io::ErrorKind::NotFound, "file not found"));
assert!(io_error.source().is_some());
#[cfg(feature = "fetch")]
{
// Executes a request against 0.0.0.0:1 and expects it to fail, which yields
// a genuine `reqwest::Error` to wrap.
let client = reqwest::blocking::Client::new();
let req = client.get("http://0.0.0.0:1").build().unwrap();
let http_error = Error::Http(client.execute(req).unwrap_err());
assert!(http_error.source().is_some());
}
let parse_error = Error::Parse {
line_number: 1,
line_content: "bad line".to_string(),
kind: ParseErrorKind::IncorrectColumnCount {
expected: 5,
found: 1,
},
};
assert!(parse_error.source().is_none());
}
// Pins the exact `Display` text of each `Error` variant.
#[test]
fn test_error_display() {
let io_error = Error::Io(io::Error::new(io::ErrorKind::NotFound, "file not found"));
assert_eq!(io_error.to_string(), "I/O error: file not found");
#[cfg(feature = "fetch")]
{
// The reqwest message text is not pinned; only the prefix is checked.
let client = reqwest::blocking::Client::new();
let req = client.get("http://0.0.0.0:1").build().unwrap();
let http_error = Error::Http(client.execute(req).unwrap_err());
assert!(http_error.to_string().starts_with("HTTP error:"));
}
let parse_error = Error::Parse {
line_number: 42,
line_content: "bad line".to_string(),
kind: ParseErrorKind::InvalidAsnNumber {
value: "not-a-number".to_string(),
},
};
assert_eq!(
parse_error.to_string(),
"Parse error on line 42: invalid ASN: not-a-number in line: \"bad line\""
);
}
// Pins the exact `Display` text of each `Warning` variant.
#[test]
fn test_warning_display() {
let parse_warning = Warning::Parse {
line_number: 10,
line_content: "another bad line".to_string(),
message: "some issue".to_string(),
};
assert_eq!(
parse_warning.to_string(),
"Parse warning on line 10: some issue in line: \"another bad line\""
);
let mismatch_warning = Warning::IpFamilyMismatch {
line_number: 20,
line_content: "v4-and-v6".to_string(),
};
assert_eq!(
mismatch_warning.to_string(),
"IP family mismatch on line 20: \"v4-and-v6\""
);
}
// Pins the exact `Display` text of every `ParseErrorKind` variant.
#[test]
fn test_parse_error_kind_display() {
let err = ParseErrorKind::IncorrectColumnCount {
expected: 5,
found: 4,
};
assert_eq!(err.to_string(), "expected 5 columns, but found 4");
let err = ParseErrorKind::InvalidIpAddress {
field: "start_ip".to_string(),
value: "not-an-ip".to_string(),
};
assert_eq!(
err.to_string(),
"invalid IP address for field `start_ip`: not-an-ip"
);
let err = ParseErrorKind::InvalidAsnNumber {
value: "not-a-number".to_string(),
};
assert_eq!(err.to_string(), "invalid ASN: not-a-number");
let err = ParseErrorKind::InvalidRange {
start_ip: "1.1.1.1".parse().unwrap(),
end_ip: "1.1.1.0".parse().unwrap(),
};
assert_eq!(
err.to_string(),
"start IP 1.1.1.1 is greater than end IP 1.1.1.0"
);
let err = ParseErrorKind::IpFamilyMismatch;
assert_eq!(
err.to_string(),
"start and end IPs are of different families"
);
let err = ParseErrorKind::InvalidCountryCode {
value: "USA".to_string(),
};
assert_eq!(err.to_string(), "invalid country code: USA");
}
// A builder obtained from `IpAsnMap::builder()` starts in non-strict mode.
#[test]
fn test_ip_asn_map_builder() {
let builder = IpAsnMap::builder();
assert!(!builder.strict);
}
// Checks that ordering is driven by ASN first and that equal values collapse in a `HashSet`.
#[test]
fn test_asn_info_ord_and_hash() {
use std::collections::HashSet;
let info1 = AsnInfo {
network: "1.0.0.0/24".parse().unwrap(),
asn: 13335,
country_code: "AU".to_string(),
organization: "CLOUDFLARENET".to_string(),
};
// Field-for-field identical to `info1`.
let info2 = AsnInfo {
network: "1.0.0.0/24".parse().unwrap(),
asn: 13335,
country_code: "AU".to_string(),
organization: "CLOUDFLARENET".to_string(),
};
let info3 = AsnInfo {
network: "8.8.8.0/24".parse().unwrap(),
asn: 15169,
country_code: "US".to_string(),
organization: "GOOGLE".to_string(),
};
// Differs from `info1` only in the ASN (13336 instead of 13335).
let info4 = AsnInfo {
network: "1.0.0.0/24".parse().unwrap(),
asn: 13336, country_code: "AU".to_string(),
organization: "CLOUDFLARENET".to_string(),
};
assert_eq!(info1.cmp(&info2), std::cmp::Ordering::Equal);
assert_eq!(info1.cmp(&info3), std::cmp::Ordering::Less);
assert_eq!(info3.cmp(&info1), std::cmp::Ordering::Greater);
assert_eq!(info1.cmp(&info4), std::cmp::Ordering::Less);
let mut set = HashSet::new();
assert!(set.insert(info1.clone()));
// `info2` equals `info1`, so it is rejected as a duplicate; `info3` is new.
assert!(!set.insert(info2.clone())); assert!(set.insert(info3.clone()));
assert_eq!(set.len(), 2);
}
// Pins the exact `Display` text of `AsnInfo`.
#[test]
fn test_asn_info_display() {
let info = AsnInfo {
network: "192.0.2.0/24".parse().unwrap(),
asn: 64496,
country_code: "ZZ".to_string(),
organization: "TEST-NET".to_string(),
};
assert_eq!(info.to_string(), "AS64496 TEST-NET (ZZ) in 192.0.2.0/24");
}
// The `Debug` output reports presence flags and flips `has_source` once a source is set.
#[test]
fn test_builder_debug_impl() {
let builder = Builder::new();
let debug_str = format!("{builder:?}");
assert!(debug_str.contains("Builder"));
assert!(debug_str.contains("has_source: false"));
assert!(debug_str.contains("strict: false"));
assert!(debug_str.contains("has_on_warning: false"));
let builder_with_source = builder.with_source("".as_bytes()).unwrap();
let debug_str_with_source = format!("{builder_with_source:?}");
assert!(debug_str_with_source.contains("has_source: true"));
}
// Building without a source fails with an `Io` error of kind `NotFound`.
#[test]
fn test_builder_build_no_source() {
let result = Builder::new().build();
assert!(result.is_err());
let err = result.unwrap_err();
assert!(matches!(err, Error::Io(_)));
if let Error::Io(io_err) = err {
assert_eq!(io_err.kind(), io::ErrorKind::NotFound);
assert_eq!(io_err.to_string(), "No data source provided");
}
}
// Non-strict mode skips bad lines: with a callback the warning is delivered,
// without one the line is dropped silently.
#[test]
fn test_builder_warning_handling() {
use std::sync::{Arc, Mutex};
// IPv4 start address paired with an IPv6 end address.
let data = "1.0.0.0\t::1\t123\tUS\tTEST";
let warnings = Arc::new(Mutex::new(Vec::new()));
let warnings_clone = warnings.clone();
let builder = Builder::new()
.with_source(data.as_bytes())
.unwrap()
.on_warning(move |w| {
warnings_clone.lock().unwrap().push(format!("{w:?}"));
});
let map = builder.build().unwrap();
// The bad line was skipped, so the map holds nothing; exactly one warning was recorded.
assert!(map.lookup("1.1.1.1".parse::<IpAddr>().unwrap()).is_none()); let warnings_guard = warnings.lock().unwrap();
assert_eq!(warnings_guard.len(), 1);
assert!(warnings_guard[0].contains("IpFamilyMismatch"));
// No callback registered: the build must still succeed and skip the line.
let data = "invalid line";
let builder = Builder::new().with_source(data.as_bytes()).unwrap();
let map = builder.build().unwrap();
assert!(map.lookup("1.1.1.1".parse::<IpAddr>().unwrap()).is_none()); }
}