1use bincode::{deserialize_from, serialize_into};
50use error_context::*;
51pub use ipnet::Ipv4Net;
52use ipnet::Ipv4Subnets;
53use serde_derive::{Deserialize, Serialize};
54use std::cmp::Ordering;
55use std::error::Error;
56use std::fmt;
57use std::io;
58use std::io::{Read, Write};
59pub use std::net::Ipv4Addr;
60
/// Magic bytes written at the very start of a stored database (see `Db::store` / `Db::load`).
const DATABASE_DATA_TAG: &[u8; 4] = b"ASDB";
/// Format version written right after the tag; `Db::load` rejects any other value.
const DATABASE_DATA_VERSION: &[u8; 4] = b"bin1";
63
/// One IPv4 network from the ASN database together with the ASN metadata of the TSV row it
/// was derived from.
///
/// The comparison trait impls for this type look only at the network fields (`ip`,
/// `prefix_len`), never at the ASN metadata.
///
/// NOTE: field order defines the bincode layout of stored databases; do not reorder.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Record {
    /// Network address as a `u32` (the value produced by `u32::from(Ipv4Addr)`).
    pub ip: u32,
    /// CIDR prefix length of the network (the `/n` part).
    pub prefix_len: u8,
    /// Autonomous System number, parsed from the third TSV column.
    pub as_number: u32,
    /// Country column (fourth TSV column), copied verbatim — presumably a country code.
    pub country: String,
    /// Owner / AS description column (fifth TSV column), copied verbatim.
    pub owner: String,
}
78
79impl PartialEq for Record {
80 fn eq(&self, other: &Record) -> bool {
81 self.ip == other.ip && self.prefix_len == other.prefix_len
82 }
83}
84
// `PartialEq` for `Record` compares only plain integer fields, so it is a full
// equivalence relation and `Eq` can be asserted.
impl Eq for Record {}
86
87impl Ord for Record {
88 fn cmp(&self, other: &Record) -> Ordering {
89 self.ip.cmp(&other.ip)
90 }
91}
92
93impl PartialOrd for Record {
94 fn partial_cmp(&self, other: &Record) -> Option<Ordering> {
95 Some(self.cmp(other))
96 }
97}
98
99impl Record {
100 pub fn network(&self) -> Ipv4Net {
102 Ipv4Net::new(self.ip.into(), self.prefix_len).expect("bad network")
103 }
104}
105
/// Errors produced while reading ASN records from TSV input (see `read_asn_tsv`).
#[derive(Debug)]
pub enum TsvParseError {
    /// The underlying `csv` reader reported an error for a row.
    TsvError(csv::Error),
    /// An IP address column failed to parse; the string describes what was being parsed.
    AddrFieldParseError(std::net::AddrParseError, &'static str),
    /// An integer column failed to parse; the string describes what was being parsed.
    IntFieldParseError(std::num::ParseIntError, &'static str),
}
112
113impl fmt::Display for TsvParseError {
114 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
115 match self {
116 TsvParseError::TsvError(_) => write!(f, "TSV format error"),
117 TsvParseError::AddrFieldParseError(_, context) => {
118 write!(f, "error parsing IP address while {}", context)
119 }
120 TsvParseError::IntFieldParseError(_, context) => {
121 write!(f, "error parsing integer while {}", context)
122 }
123 }
124 }
125}
126
127impl Error for TsvParseError {
128 fn source(&self) -> Option<&(dyn Error + 'static)> {
129 match self {
130 TsvParseError::TsvError(err) => Some(err),
131 TsvParseError::AddrFieldParseError(err, _) => Some(err),
132 TsvParseError::IntFieldParseError(err, _) => Some(err),
133 }
134 }
135}
136
137impl From<csv::Error> for TsvParseError {
138 fn from(error: csv::Error) -> TsvParseError {
139 TsvParseError::TsvError(error)
140 }
141}
142
143impl From<ErrorContext<std::net::AddrParseError, &'static str>> for TsvParseError {
144 fn from(ec: ErrorContext<std::net::AddrParseError, &'static str>) -> TsvParseError {
145 TsvParseError::AddrFieldParseError(ec.error, ec.context)
146 }
147}
148
149impl From<ErrorContext<std::num::ParseIntError, &'static str>> for TsvParseError {
150 fn from(ec: ErrorContext<std::num::ParseIntError, &'static str>) -> TsvParseError {
151 TsvParseError::IntFieldParseError(ec.error, ec.context)
152 }
153}
154
155pub fn read_asn_tsv<'d, R: io::Read>(
157 data: &'d mut csv::Reader<R>,
158) -> impl Iterator<Item = Result<Record, TsvParseError>> + 'd {
159 data.records()
160 .filter(|record| {
161 if let Ok(record) = record {
162 let owner = &record[4];
163 !(owner == "Not routed" || owner == "None")
164 } else {
165 true
166 }
167 })
168 .map(|record| record.map_err(Into::<TsvParseError>::into))
169 .map(|record| {
170 record.and_then(|record| {
171 let range_start: Ipv4Addr = record[0]
172 .parse()
173 .wrap_error_while("parsing range_start IP")?;
174 let range_end: Ipv4Addr =
175 record[1].parse().wrap_error_while("parsing range_end IP")?;
176 let as_number: u32 = record[2].parse().wrap_error_while("parsing as_number")?;
177 let country = record[3].to_owned();
178 let owner = record[4].to_owned();
179 Ok((range_start, range_end, as_number, country, owner))
180 })
181 })
182 .map(|record| {
183 record.map(|(range_start, range_end, as_number, country, owner)| {
184 Ipv4Subnets::new(range_start, range_end, 8).map(move |subnet| Record {
186 ip: subnet.network().into(),
187 prefix_len: subnet.prefix_len(),
188 country: country.clone(),
189 as_number,
190 owner: owner.clone(),
191 })
192 })
193 })
194 .flat_map(|subnet_records| {
195 let mut records = None;
197 let mut error = None;
198
199 match subnet_records {
200 Ok(subnet_records) => records = Some(subnet_records),
201 Err(err) => error = Some(TsvParseError::from(err)),
202 }
203
204 records
205 .into_iter()
206 .flatten()
207 .map(Ok)
208 .chain(error.into_iter().map(Err))
209 })
210}
211
/// Errors returned while building, loading or storing a [`Db`].
#[derive(Debug)]
pub enum DbError {
    /// Reading or parsing the TSV source failed.
    TsvError(TsvParseError),
    /// The stored database content is invalid (e.g. wrong tag or unsupported version).
    DbDataError(&'static str),
    /// An I/O operation failed; the string describes what was being done.
    FileError(io::Error, &'static str),
    /// bincode (de)serialization failed; the string describes what was being done.
    BincodeError(bincode::Error, &'static str),
}
219
220impl fmt::Display for DbError {
221 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
222 match self {
223 DbError::TsvError(_) => write!(f, "error opening ASN DB from TSV file"),
224 DbError::FileError(_, context) => {
225 write!(f, "error opening ASN DB from file while {}", context)
226 }
227 DbError::BincodeError(_, context) => write!(
228 f,
229 "error (de)serializing ASN DB to bincode format while {}",
230 context
231 ),
232 DbError::DbDataError(message) => write!(f, "error while reading database: {}", message),
233 }
234 }
235}
236
237impl Error for DbError {
238 fn source(&self) -> Option<&(dyn Error + 'static)> {
239 match self {
240 DbError::TsvError(err) => Some(err),
241 DbError::FileError(err, _) => Some(err),
242 DbError::BincodeError(err, _) => Some(err),
243 DbError::DbDataError(_) => None,
244 }
245 }
246}
247
248impl From<TsvParseError> for DbError {
249 fn from(err: TsvParseError) -> DbError {
250 DbError::TsvError(err)
251 }
252}
253
254impl From<ErrorContext<io::Error, &'static str>> for DbError {
255 fn from(err: ErrorContext<io::Error, &'static str>) -> DbError {
256 DbError::FileError(err.error, err.context)
257 }
258}
259
260impl From<ErrorContext<bincode::Error, &'static str>> for DbError {
261 fn from(err: ErrorContext<bincode::Error, &'static str>) -> DbError {
262 DbError::BincodeError(err.error, err.context)
263 }
264}
265
/// In-memory ASN database: a list of [`Record`]s that `Db::lookup` binary-searches by
/// network address, so it must be kept sorted (as `Db::form_tsv` produces it).
pub struct Db(Vec<Record>);
271
272impl fmt::Debug for Db {
273 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
274 write!(f, "asn_db::Db[total records: {}]", self.0.len())
275 }
276}
277
278impl Db {
279 pub fn form_tsv(data: impl Read) -> Result<Db, DbError> {
281 let mut rdr = csv::ReaderBuilder::new()
282 .delimiter(b'\t')
283 .has_headers(false)
284 .from_reader(data);
285 let mut records = read_asn_tsv(&mut rdr).collect::<Result<Vec<_>, _>>()?;
286 records.sort();
287 Ok(Db(records))
288 }
289
290 pub fn load(mut db_data: impl Read) -> Result<Db, DbError> {
292 let mut tag = [0; 4];
293 db_data
294 .read_exact(&mut tag)
295 .wrap_error_while("reading database tag")?;
296 if &tag != DATABASE_DATA_TAG {
297 return Err(DbError::DbDataError("bad database data tag"));
298 }
299
300 let mut version = [0; 4];
301 db_data
302 .read_exact(&mut version)
303 .wrap_error_while("reading database version")?;
304 if &version != DATABASE_DATA_VERSION {
305 return Err(DbError::DbDataError("unsuported database version"));
306 }
307
308 let records: Vec<Record> =
309 deserialize_from(db_data).wrap_error_while("reading bincode DB file")?;
310
311 Ok(Db(records))
312 }
313
314 pub fn store(&self, mut db_data: impl Write) -> Result<(), DbError> {
316 db_data
317 .write(DATABASE_DATA_TAG)
318 .wrap_error_while("error writing tag")?;
319 db_data
320 .write(DATABASE_DATA_VERSION)
321 .wrap_error_while("error writing version")?;
322 serialize_into(db_data, &self.0).wrap_error_while("stroing DB")?;
323 Ok(())
324 }
325
326 pub fn lookup(&self, ip: Ipv4Addr) -> Option<&Record> {
328 match self.0.binary_search_by_key(&ip.into(), |record| record.ip) {
329 Ok(index) => return Some(&self.0[index]), Err(index) => {
331 if index != 0 {
333 let record = &self.0[index - 1];
334 if record.network().contains(&ip) {
335 return Some(record);
336 }
337 }
338 }
339 }
340 None
341 }
342}
343
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs::File;
    use std::io::{BufReader, BufWriter};
    use tempfile::tempdir;

    /// Addresses with a well-known owner, checked before and after the store/load round trip.
    const KNOWN_OWNERS: &[(&str, &str)] = &[
        ("1.1.1.1", "CLOUDFLARENET"),
        ("8.8.8.8", "GOOGLE"),
        ("8.8.4.4", "GOOGLE"),
    ];

    /// Asserts that every address in `KNOWN_OWNERS` resolves to a record naming its owner.
    fn assert_known_owners(db: &Db) {
        for &(addr, owner) in KNOWN_OWNERS {
            assert!(db
                .lookup(addr.parse().unwrap())
                .unwrap()
                .owner
                .contains(owner));
        }
    }

    #[test]
    fn test_db() {
        // Requires the `ip2asn-v4.tsv` dump in the test's working directory.
        let tsv_file = File::open("ip2asn-v4.tsv").unwrap();
        let from_tsv = Db::form_tsv(BufReader::new(tsv_file)).unwrap();
        assert_known_owners(&from_tsv);

        let temp_dir = tempdir().unwrap();
        let db_path = temp_dir.path().join("asn-db.dat");

        let out_file = File::create(&db_path).unwrap();
        from_tsv.store(BufWriter::new(out_file)).unwrap();

        let in_file = File::open(&db_path).unwrap();
        let reloaded = Db::load(BufReader::new(in_file)).unwrap();

        // The loaded database must not depend on the backing file still existing.
        drop(db_path);
        drop(temp_dir);

        assert_known_owners(&reloaded);
    }
}