1use std::{collections::HashMap, num::NonZeroUsize, time::Duration};
16
17use crate::{
18 cache_key, is_bogon, Continent, CountryCurrency, CountryFlag, IpDetails,
19 IpError, BATCH_MAX_SIZE, BATCH_REQ_TIMEOUT_DEFAULT, CONTINENTS, COUNTRIES,
20 CURRENCIES, EU, FLAGS, VERSION,
21};
22
23use lru::LruCache;
24use serde_json::json;
25
26use reqwest::header::{
27 HeaderMap, HeaderValue, ACCEPT, CONTENT_TYPE, USER_AGENT,
28};
29
30use tokio::time::timeout;
31
/// CDN prefix for per-country flag images; a two-letter country code plus
/// `.svg` is appended to it.
const COUNTRY_FLAG_URL: &str = "https://cdn.ipinfo.io/static/images/countries-flags/";

/// Primary API endpoint (single, batch, map and IPv4 self lookups).
const BASE_URL: &str = "https://ipinfo.io";
/// Endpoint used for IPv6 self lookups.
const BASE_URL_V6: &str = "https://v6.ipinfo.io";
37
38pub struct IpInfoConfig {
40 pub token: Option<String>,
42
43 pub timeout: Duration,
45
46 pub cache_size: usize,
48
49 pub defaut_countries: Option<HashMap<String, String>>,
51
52 pub default_eu: Option<Vec<String>>,
54
55 pub default_flags: Option<HashMap<String, CountryFlag>>,
57
58 pub default_currencies: Option<HashMap<String, CountryCurrency>>,
60
61 pub default_continents: Option<HashMap<String, Continent>>,
63}
64
65impl Default for IpInfoConfig {
66 fn default() -> Self {
67 Self {
68 token: None,
69 timeout: Duration::from_secs(3),
70 cache_size: 100,
71 defaut_countries: None,
72 default_eu: None,
73 default_flags: None,
74 default_currencies: None,
75 default_continents: None,
76 }
77 }
78}
79
80pub struct IpInfo {
82 token: Option<String>,
83 client: reqwest::Client,
84 cache: LruCache<String, IpDetails>,
85 countries: HashMap<String, String>,
86 eu: Vec<String>,
87 country_flags: HashMap<String, CountryFlag>,
88 country_currencies: HashMap<String, CountryCurrency>,
89 continents: HashMap<String, Continent>,
90}
91
/// Options for `IpInfo::lookup_batch`.
///
/// The fields are public so callers outside this crate can tune them.
/// Start from `BatchReqOpts::default()` and override what you need, e.g.
/// `BatchReqOpts { batch_size: 100, ..BatchReqOpts::default() }`.
#[derive(Debug, Clone)]
pub struct BatchReqOpts {
    /// Maximum number of IPs sent in a single batch request.
    pub batch_size: u64,
    /// Timeout applied to each individual batch request.
    pub timeout_per_batch: Duration,
    /// Optional deadline for the whole `lookup_batch` call, across all
    /// batches.
    pub timeout_total: Option<Duration>,
}
97
98impl Default for BatchReqOpts {
99 fn default() -> Self {
100 Self {
101 batch_size: BATCH_MAX_SIZE,
102 timeout_per_batch: BATCH_REQ_TIMEOUT_DEFAULT,
103 timeout_total: None,
104 }
105 }
106}
107
108impl IpInfo {
109 pub fn new(config: IpInfoConfig) -> Result<Self, IpError> {
119 let client =
120 reqwest::Client::builder().timeout(config.timeout).build()?;
121
122 let mut ipinfo_obj = Self {
123 client,
124 token: config.token,
125 cache: LruCache::new(
126 NonZeroUsize::new(config.cache_size).unwrap(),
127 ),
128 countries: HashMap::new(),
129 eu: Vec::new(),
130 country_flags: HashMap::new(),
131 country_currencies: HashMap::new(),
132 continents: HashMap::new(),
133 };
134
135 if config.defaut_countries.is_none() {
136 ipinfo_obj.countries = COUNTRIES.clone();
137 } else {
138 ipinfo_obj.countries = config.defaut_countries.unwrap();
139 }
140
141 if config.default_eu.is_none() {
142 ipinfo_obj.eu = EU.clone();
143 } else {
144 ipinfo_obj.eu = config.default_eu.unwrap();
145 }
146
147 if config.default_flags.is_none() {
148 ipinfo_obj.country_flags = FLAGS.clone();
149 } else {
150 ipinfo_obj.country_flags = config.default_flags.unwrap();
151 }
152
153 if config.default_currencies.is_none() {
154 ipinfo_obj.country_currencies = CURRENCIES.clone();
155 } else {
156 ipinfo_obj.country_currencies = config.default_currencies.unwrap();
157 }
158
159 if config.default_continents.is_none() {
160 ipinfo_obj.continents = CONTINENTS.clone();
161 } else {
162 ipinfo_obj.continents = config.default_continents.unwrap();
163 }
164
165 Ok(ipinfo_obj)
166 }
167
168 pub async fn lookup_batch(
181 &mut self,
182 ips: &[&str],
183 batch_config: BatchReqOpts,
184 ) -> Result<HashMap<String, IpDetails>, IpError> {
185 if let Some(total_timeout) = batch_config.timeout_total {
187 match timeout(total_timeout, self._lookup_batch(ips, batch_config))
188 .await
189 {
190 Ok(result) => result,
191 Err(_) => Err(err!(TimeOutError)),
192 }
193 } else {
194 self._lookup_batch(ips, batch_config).await
195 }
196 }
197
198 async fn _lookup_batch(
200 &mut self,
201 ips: &[&str],
202 batch_config: BatchReqOpts,
203 ) -> Result<HashMap<String, IpDetails>, IpError> {
204 let mut results: HashMap<String, IpDetails> = HashMap::new();
205
206 let mut work = vec![];
209 for ip in ips.iter() {
210 if is_bogon(ip) {
211 results.insert(
212 ip.to_string(),
213 IpDetails {
214 ip: ip.to_string(),
215 bogon: Some(true),
216 ..Default::default()
217 },
218 );
219 } else if let Some(detail) = self.cache.get(&cache_key(ip)) {
220 results.insert(ip.to_string(), detail.clone());
221 } else {
222 work.push(*ip);
223 }
224 }
225
226 let client = reqwest::Client::builder()
227 .timeout(batch_config.timeout_per_batch)
228 .build()?;
229
230 work.sort();
232 work.dedup();
233
234 for batch in work.chunks(batch_config.batch_size as usize) {
236 let response = self.batch_request(client.clone(), batch).await?;
237 results.extend(response);
238 }
239
240 for detail in results.values_mut() {
242 self.populate_static_details(detail);
243 }
244
245 results
247 .iter()
248 .filter(|(ip, _)| !is_bogon(ip))
249 .for_each(|x| {
250 self.cache.put(cache_key(x.0.as_str()), x.1.clone());
251 });
252
253 Ok(results)
254 }
255
256 async fn batch_request(
257 &self,
258 client: reqwest::Client,
259 ips: &[&str],
260 ) -> Result<HashMap<String, IpDetails>, IpError> {
261 let response = client
263 .post(format!("{BASE_URL}/batch"))
264 .headers(Self::construct_headers())
265 .bearer_auth(self.token.as_deref().unwrap_or_default())
266 .json(&json!(ips))
267 .send()
268 .await?;
269
270 if let reqwest::StatusCode::TOO_MANY_REQUESTS = response.status() {
272 return Err(err!(RateLimitExceededError));
273 }
274
275 let raw_resp = response.error_for_status()?.text().await?;
277
278 let resp: serde_json::Value = serde_json::from_str(&raw_resp)?;
280
281 if let Some(e) = resp["error"].as_str() {
283 return Err(err!(IpRequestError, e));
284 }
285
286 let result: HashMap<String, IpDetails> =
288 serde_json::from_str(&raw_resp)?;
289 Ok(result)
290 }
291
292 pub async fn lookup(&mut self, ip: &str) -> Result<IpDetails, IpError> {
306 self._lookup(ip, BASE_URL).await
307 }
308
309 pub async fn lookup_self_v4(&mut self) -> Result<IpDetails, IpError> {
323 self._lookup("", BASE_URL).await
324 }
325
326 pub async fn lookup_self_v6(&mut self) -> Result<IpDetails, IpError> {
340 self._lookup("", BASE_URL_V6).await
341 }
342
343 async fn _lookup(
344 &mut self,
345 ip: &str,
346 base_url: &str,
347 ) -> Result<IpDetails, IpError> {
348 if is_bogon(ip) {
349 return Ok(IpDetails {
350 ip: ip.to_string(),
351 bogon: Some(true),
352 ..Default::default() });
354 }
355
356 let cached_detail = self.cache.get(&cache_key(ip));
358
359 if let Some(cached_detail) = cached_detail {
360 return Ok(cached_detail.clone());
361 }
362
363 let response = self
365 .client
366 .get(format!("{base_url}/{ip}"))
367 .headers(Self::construct_headers())
368 .bearer_auth(self.token.as_deref().unwrap_or_default())
369 .send()
370 .await?;
371
372 if let reqwest::StatusCode::TOO_MANY_REQUESTS = response.status() {
374 return Err(err!(RateLimitExceededError));
375 }
376
377 let raw_resp = response.error_for_status()?.text().await?;
379
380 let resp: serde_json::Value = serde_json::from_str(&raw_resp)?;
382
383 if let Some(e) = resp["error"].as_str() {
385 return Err(err!(IpRequestError, e));
386 }
387
388 let mut details: IpDetails = serde_json::from_str(&raw_resp)?;
390 self.populate_static_details(&mut details);
391
392 self.cache.put(cache_key(ip), details.clone());
394 Ok(details)
395 }
396
397 pub async fn get_map(&self, ips: &[&str]) -> Result<String, IpError> {
411 if ips.len() > 500_000 {
412 return Err(err!(MapLimitError));
413 }
414
415 let map_url = &format!("{BASE_URL}/tools/map?cli=1");
416 let client = self.client.clone();
417 let json_ips = serde_json::json!(ips);
418
419 let response = client.post(map_url).json(&json_ips).send().await?;
420 if !response.status().is_success() {
421 return Err(err!(HTTPClientError));
422 }
423
424 let response_json: serde_json::Value = response.json().await?;
425 let report_url = response_json["reportUrl"]
426 .as_str()
427 .ok_or("Report URL not found");
428 Ok(report_url.unwrap().to_string())
429 }
430
431 fn populate_static_details(&self, details: &mut IpDetails) {
433 if !&details.country.is_empty() {
434 let country_name = self.countries.get(&details.country).unwrap();
435 details.country_name = Some(country_name.to_string());
436 details.is_eu = Some(self.eu.contains(&details.country));
437 let country_flag =
438 self.country_flags.get(&details.country).unwrap();
439 details.country_flag = Some(country_flag.to_owned());
440 let file_ext = ".svg";
441 details.country_flag_url = Some(
442 COUNTRY_FLAG_URL.to_string() + &details.country + file_ext,
443 );
444 let country_currency =
445 self.country_currencies.get(&details.country).unwrap();
446 details.country_currency = Some(country_currency.to_owned());
447 let continent = self.continents.get(&details.country).unwrap();
448 details.continent = Some(continent.to_owned());
449 }
450 }
451
452 fn construct_headers() -> HeaderMap {
454 let mut headers = HeaderMap::new();
455 headers.insert(
456 USER_AGENT,
457 HeaderValue::from_str(&format!("IPinfoClient/Rust/{VERSION}"))
458 .unwrap(),
459 );
460 headers.insert(
461 CONTENT_TYPE,
462 HeaderValue::from_static("application/json"),
463 );
464 headers.insert(ACCEPT, HeaderValue::from_static("application/json"));
465 headers
466 }
467}
468
#[cfg(test)]
mod tests {
    use super::*;
    use crate::IpErrorKind;
    use std::env;

    /// Builds a client authenticated with `IPINFO_TOKEN`. The network tests
    /// below need that variable to be set.
    fn get_ipinfo_client() -> IpInfo {
        let token = env::var("IPINFO_TOKEN")
            .expect("IPINFO_TOKEN must be set to run the network tests");

        IpInfo::new(IpInfoConfig {
            token: Some(token),
            timeout: Duration::from_secs(3),
            cache_size: 100,
            ..Default::default()
        })
        .expect("should construct")
    }

    #[test]
    fn ipinfo_config_defaults_reasonable() {
        let ipinfo_config = IpInfoConfig::default();

        assert_eq!(ipinfo_config.timeout, Duration::from_secs(3));
        assert_eq!(ipinfo_config.cache_size, 100);
    }

    #[test]
    fn request_headers_are_canonical() {
        let headers = IpInfo::construct_headers();

        assert_eq!(
            headers[USER_AGENT],
            format!("IPinfoClient/Rust/{}", VERSION)
        );
        assert_eq!(headers[CONTENT_TYPE], "application/json");
        assert_eq!(headers[ACCEPT], "application/json");
    }

    #[tokio::test]
    async fn request_single_ip() {
        let mut ipinfo = get_ipinfo_client();

        let details =
            ipinfo.lookup("66.87.125.72").await.expect("should lookup");

        assert_eq!(details.ip, "66.87.125.72");
    }

    #[tokio::test]
    async fn request_no_token() {
        let mut ipinfo =
            IpInfo::new(Default::default()).expect("should construct");

        let error = ipinfo
            .lookup_batch(&["8.8.8.8"], BatchReqOpts::default())
            .await
            .err()
            .expect("a batch lookup without a token should fail");

        assert_eq!(error.kind(), IpErrorKind::IpRequestError);
    }

    #[tokio::test]
    async fn request_multiple_ip() {
        let mut ipinfo = get_ipinfo_client();

        let details = ipinfo
            .lookup_batch(&["8.8.8.8", "4.2.2.4"], BatchReqOpts::default())
            .await
            .expect("should lookup");

        assert!(details.contains_key("8.8.8.8"));
        assert!(details.contains_key("4.2.2.4"));

        let ip8 = &details["8.8.8.8"];
        assert_eq!(ip8.ip, "8.8.8.8");
        assert_eq!(ip8.hostname, Some("dns.google".to_owned()));
        assert_eq!(ip8.city, "Mountain View");
        assert_eq!(ip8.region, "California");
        assert_eq!(ip8.country, "US");
        assert_eq!(
            ip8.country_flag_url,
            Some(
                "https://cdn.ipinfo.io/static/images/countries-flags/US.svg"
                    .to_owned()
            )
        );
        assert_eq!(
            ip8.country_flag,
            Some(CountryFlag {
                emoji: "🇺🇸".to_owned(),
                unicode: "U+1F1FA U+1F1F8".to_owned()
            })
        );
        assert_eq!(
            ip8.country_currency,
            Some(CountryCurrency {
                code: "USD".to_owned(),
                symbol: "$".to_owned()
            })
        );
        assert_eq!(
            ip8.continent,
            Some(Continent {
                code: "NA".to_owned(),
                name: "North America".to_owned()
            })
        );
        assert_eq!(ip8.loc, "38.0088,-122.1175");
        assert_eq!(ip8.postal, Some("94043".to_owned()));
        assert_eq!(ip8.timezone, Some("America/Los_Angeles".to_owned()));

        let ip4 = &details["4.2.2.4"];
        assert_eq!(ip4.ip, "4.2.2.4");
        assert_eq!(ip4.hostname, Some("d.resolvers.level3.net".to_owned()));
        assert_eq!(ip4.city, "Monroe");
        assert_eq!(ip4.region, "Louisiana");
        assert_eq!(ip4.country, "US");
        assert_eq!(ip4.loc, "32.5530,-92.0422");
        assert_eq!(ip4.postal, Some("71203".to_owned()));
        assert_eq!(ip4.timezone, Some("America/Chicago".to_owned()));
    }

    #[tokio::test]
    async fn request_cache_miss_and_hit() {
        let mut ipinfo = get_ipinfo_client();

        // First call: 8.8.8.8 is a cache miss and gets fetched.
        let details = ipinfo
            .lookup_batch(&["8.8.8.8"], BatchReqOpts::default())
            .await
            .expect("should lookup");

        assert!(details.contains_key("8.8.8.8"));
        assert_eq!(details.len(), 1);

        // Second call: 8.8.8.8 comes from the cache, 4.2.2.4 is fetched.
        let details = ipinfo
            .lookup_batch(&["4.2.2.4", "8.8.8.8"], BatchReqOpts::default())
            .await
            .expect("should lookup");

        assert!(details.contains_key("8.8.8.8"));
        assert!(details.contains_key("4.2.2.4"));
        assert_eq!(details.len(), 2);
    }
}