ipinfo/
ipinfo.rs

1//   Copyright 2019-2024 IPinfo library developers
2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
15use std::{collections::HashMap, num::NonZeroUsize, time::Duration};
16
17use crate::{
18    cache_key, is_bogon, Continent, CountryCurrency, CountryFlag, IpDetails,
19    IpError, BATCH_MAX_SIZE, BATCH_REQ_TIMEOUT_DEFAULT, CONTINENTS, COUNTRIES,
20    CURRENCIES, EU, FLAGS, VERSION,
21};
22
23use lru::LruCache;
24use serde_json::json;
25
26use reqwest::header::{
27    HeaderMap, HeaderValue, ACCEPT, CONTENT_TYPE, USER_AGENT,
28};
29
30use tokio::time::timeout;
31
/// CDN prefix for country flag images; a country code and `.svg` are
/// appended to it to form a flag URL.
const COUNTRY_FLAG_URL: &str =
    "https://cdn.ipinfo.io/static/images/countries-flags/";

/// Default API endpoint, used for lookups, batch requests and maps.
const BASE_URL: &str = "https://ipinfo.io";
/// API endpoint used by `lookup_self_v6`.
const BASE_URL_V6: &str = "https://v6.ipinfo.io";
37
38/// IpInfo structure configuration.
39pub struct IpInfoConfig {
40    /// IPinfo access token.
41    pub token: Option<String>,
42
43    /// The timeout of HTTP requests. (default: 3 seconds)
44    pub timeout: Duration,
45
46    /// The size of the LRU cache. (default: 100 IPs)
47    pub cache_size: usize,
48
49    // Default mapping of country codes to country names
50    pub defaut_countries: Option<HashMap<String, String>>,
51
52    // Default list of EU countries
53    pub default_eu: Option<Vec<String>>,
54
55    // Default mapping of country codes to their respective flag emoji and unicode
56    pub default_flags: Option<HashMap<String, CountryFlag>>,
57
58    // Default mapping of currencies to their respective currency code and symbol
59    pub default_currencies: Option<HashMap<String, CountryCurrency>>,
60
61    // Default mapping of country codes to their respective continent code and name
62    pub default_continents: Option<HashMap<String, Continent>>,
63}
64
65impl Default for IpInfoConfig {
66    fn default() -> Self {
67        Self {
68            token: None,
69            timeout: Duration::from_secs(3),
70            cache_size: 100,
71            defaut_countries: None,
72            default_eu: None,
73            default_flags: None,
74            default_currencies: None,
75            default_continents: None,
76        }
77    }
78}
79
80/// IPinfo requests context structure.
81pub struct IpInfo {
82    token: Option<String>,
83    client: reqwest::Client,
84    cache: LruCache<String, IpDetails>,
85    countries: HashMap<String, String>,
86    eu: Vec<String>,
87    country_flags: HashMap<String, CountryFlag>,
88    country_currencies: HashMap<String, CountryCurrency>,
89    continents: HashMap<String, Continent>,
90}
91
/// Options controlling a batch lookup (`IpInfo::lookup_batch`).
///
/// Start from `BatchReqOpts::default()` and override individual fields, e.g.
/// `BatchReqOpts { timeout_total: Some(Duration::from_secs(10)), ..Default::default() }`.
#[derive(Debug, Clone, Copy)]
pub struct BatchReqOpts {
    /// Maximum number of IPs sent in a single batch request.
    pub batch_size: u64,

    /// Timeout applied to each individual batch request.
    pub timeout_per_batch: Duration,

    /// Optional deadline for the whole batch lookup; `None` means no overall
    /// limit.
    pub timeout_total: Option<Duration>,
}
97
98impl Default for BatchReqOpts {
99    fn default() -> Self {
100        Self {
101            batch_size: BATCH_MAX_SIZE,
102            timeout_per_batch: BATCH_REQ_TIMEOUT_DEFAULT,
103            timeout_total: None,
104        }
105    }
106}
107
108impl IpInfo {
109    /// Construct a new IpInfo structure.
110    ///
111    /// # Examples
112    ///
113    /// ```
114    /// use ipinfo::IpInfo;
115    ///
116    /// let ipinfo = IpInfo::new(Default::default()).expect("should construct");
117    /// ```
118    pub fn new(config: IpInfoConfig) -> Result<Self, IpError> {
119        let client =
120            reqwest::Client::builder().timeout(config.timeout).build()?;
121
122        let mut ipinfo_obj = Self {
123            client,
124            token: config.token,
125            cache: LruCache::new(
126                NonZeroUsize::new(config.cache_size).unwrap(),
127            ),
128            countries: HashMap::new(),
129            eu: Vec::new(),
130            country_flags: HashMap::new(),
131            country_currencies: HashMap::new(),
132            continents: HashMap::new(),
133        };
134
135        if config.defaut_countries.is_none() {
136            ipinfo_obj.countries = COUNTRIES.clone();
137        } else {
138            ipinfo_obj.countries = config.defaut_countries.unwrap();
139        }
140
141        if config.default_eu.is_none() {
142            ipinfo_obj.eu = EU.clone();
143        } else {
144            ipinfo_obj.eu = config.default_eu.unwrap();
145        }
146
147        if config.default_flags.is_none() {
148            ipinfo_obj.country_flags = FLAGS.clone();
149        } else {
150            ipinfo_obj.country_flags = config.default_flags.unwrap();
151        }
152
153        if config.default_currencies.is_none() {
154            ipinfo_obj.country_currencies = CURRENCIES.clone();
155        } else {
156            ipinfo_obj.country_currencies = config.default_currencies.unwrap();
157        }
158
159        if config.default_continents.is_none() {
160            ipinfo_obj.continents = CONTINENTS.clone();
161        } else {
162            ipinfo_obj.continents = config.default_continents.unwrap();
163        }
164
165        Ok(ipinfo_obj)
166    }
167
168    /// Lookup IPDetails for a list of one or more IP addresses.
169    ///
170    /// # Examples
171    ///
172    /// ```no_run
173    /// use ipinfo::{IpInfo, BatchReqOpts};
174    /// #[tokio::main]
175    /// async fn main() {
176    ///     let mut ipinfo = IpInfo::new(Default::default()).expect("should construct");
177    ///     let res = ipinfo.lookup_batch(&["8.8.8.8"], BatchReqOpts::default()).await.expect("should run");
178    /// }
179    /// ```
180    pub async fn lookup_batch(
181        &mut self,
182        ips: &[&str],
183        batch_config: BatchReqOpts,
184    ) -> Result<HashMap<String, IpDetails>, IpError> {
185        // Handle the total timeout condition
186        if let Some(total_timeout) = batch_config.timeout_total {
187            match timeout(total_timeout, self._lookup_batch(ips, batch_config))
188                .await
189            {
190                Ok(result) => result,
191                Err(_) => Err(err!(TimeOutError)),
192            }
193        } else {
194            self._lookup_batch(ips, batch_config).await
195        }
196    }
197
198    // Internal lookup_batch function. This ignores the total timeout condition
199    async fn _lookup_batch(
200        &mut self,
201        ips: &[&str],
202        batch_config: BatchReqOpts,
203    ) -> Result<HashMap<String, IpDetails>, IpError> {
204        let mut results: HashMap<String, IpDetails> = HashMap::new();
205
206        // Collect a list of ips we need to lookup.
207        // Filters out bogons and cache hits
208        let mut work = vec![];
209        for ip in ips.iter() {
210            if is_bogon(ip) {
211                results.insert(
212                    ip.to_string(),
213                    IpDetails {
214                        ip: ip.to_string(),
215                        bogon: Some(true),
216                        ..Default::default()
217                    },
218                );
219            } else if let Some(detail) = self.cache.get(&cache_key(ip)) {
220                results.insert(ip.to_string(), detail.clone());
221            } else {
222                work.push(*ip);
223            }
224        }
225
226        let client = reqwest::Client::builder()
227            .timeout(batch_config.timeout_per_batch)
228            .build()?;
229
230        // Remove duplicates
231        work.sort();
232        work.dedup();
233
234        // Make batched requests
235        for batch in work.chunks(batch_config.batch_size as usize) {
236            let response = self.batch_request(client.clone(), batch).await?;
237            results.extend(response);
238        }
239
240        // Add country_name and EU status to response
241        for detail in results.values_mut() {
242            self.populate_static_details(detail);
243        }
244
245        // Update cache
246        results
247            .iter()
248            .filter(|(ip, _)| !is_bogon(ip))
249            .for_each(|x| {
250                self.cache.put(cache_key(x.0.as_str()), x.1.clone());
251            });
252
253        Ok(results)
254    }
255
256    async fn batch_request(
257        &self,
258        client: reqwest::Client,
259        ips: &[&str],
260    ) -> Result<HashMap<String, IpDetails>, IpError> {
261        // Lookup cache misses which are not bogon
262        let response = client
263            .post(format!("{BASE_URL}/batch"))
264            .headers(Self::construct_headers())
265            .bearer_auth(self.token.as_deref().unwrap_or_default())
266            .json(&json!(ips))
267            .send()
268            .await?;
269
270        // Check if we exhausted our request quota
271        if let reqwest::StatusCode::TOO_MANY_REQUESTS = response.status() {
272            return Err(err!(RateLimitExceededError));
273        }
274
275        // Acquire response
276        let raw_resp = response.error_for_status()?.text().await?;
277
278        // Parse the response
279        let resp: serde_json::Value = serde_json::from_str(&raw_resp)?;
280
281        // Return if an error occurred
282        if let Some(e) = resp["error"].as_str() {
283            return Err(err!(IpRequestError, e));
284        }
285
286        // Parse the results
287        let result: HashMap<String, IpDetails> =
288            serde_json::from_str(&raw_resp)?;
289        Ok(result)
290    }
291
292    /// looks up IPDetails for a single IP Address
293    ///
294    /// # Example
295    ///
296    /// ```no_run
297    /// use ipinfo::IpInfo;
298    ///
299    ///  #[tokio::main]
300    /// async fn main() {
301    ///     let mut ipinfo = IpInfo::new(Default::default()).expect("should construct");
302    ///     let res = ipinfo.lookup("8.8.8.8").await.expect("should run");
303    /// }
304    /// ```
305    pub async fn lookup(&mut self, ip: &str) -> Result<IpDetails, IpError> {
306        self._lookup(ip, BASE_URL).await
307    }
308
309    /// looks up IPDetails of your own v4 IP
310    ///
311    /// # Example
312    ///
313    /// ```no_run
314    /// use ipinfo::IpInfo;
315    ///
316    ///  #[tokio::main]
317    /// async fn main() {
318    ///     let mut ipinfo = IpInfo::new(Default::default()).expect("should construct");
319    ///     let res = ipinfo.lookup_self_v4().await.expect("should run");
320    /// }
321    /// ```
322    pub async fn lookup_self_v4(&mut self) -> Result<IpDetails, IpError> {
323        self._lookup("", BASE_URL).await
324    }
325
326    /// looks up IPDetails of your own v6 IP
327    ///
328    /// # Example
329    ///
330    /// ```no_run
331    /// use ipinfo::IpInfo;
332    ///
333    ///  #[tokio::main]
334    /// async fn main() {
335    ///     let mut ipinfo = IpInfo::new(Default::default()).expect("should construct");
336    ///     let res = ipinfo.lookup_self_v6().await.expect("should run");
337    /// }
338    /// ```
339    pub async fn lookup_self_v6(&mut self) -> Result<IpDetails, IpError> {
340        self._lookup("", BASE_URL_V6).await
341    }
342
343    async fn _lookup(
344        &mut self,
345        ip: &str,
346        base_url: &str,
347    ) -> Result<IpDetails, IpError> {
348        if is_bogon(ip) {
349            return Ok(IpDetails {
350                ip: ip.to_string(),
351                bogon: Some(true),
352                ..Default::default() // fill remaining with default values
353            });
354        }
355
356        // Check for cache hit
357        let cached_detail = self.cache.get(&cache_key(ip));
358
359        if let Some(cached_detail) = cached_detail {
360            return Ok(cached_detail.clone());
361        }
362
363        // lookup in case of a cache miss
364        let response = self
365            .client
366            .get(format!("{base_url}/{ip}"))
367            .headers(Self::construct_headers())
368            .bearer_auth(self.token.as_deref().unwrap_or_default())
369            .send()
370            .await?;
371
372        // Check if we exhausted our request quota
373        if let reqwest::StatusCode::TOO_MANY_REQUESTS = response.status() {
374            return Err(err!(RateLimitExceededError));
375        }
376
377        // Acquire response
378        let raw_resp = response.error_for_status()?.text().await?;
379
380        // Parse the response
381        let resp: serde_json::Value = serde_json::from_str(&raw_resp)?;
382
383        // Return if an error occurred
384        if let Some(e) = resp["error"].as_str() {
385            return Err(err!(IpRequestError, e));
386        }
387
388        // Parse the results and add additional country details
389        let mut details: IpDetails = serde_json::from_str(&raw_resp)?;
390        self.populate_static_details(&mut details);
391
392        // update cache
393        self.cache.put(cache_key(ip), details.clone());
394        Ok(details)
395    }
396
397    /// Get a mapping of a list of IPs on a world map
398    ///
399    /// # Example
400    ///
401    /// ```no_run
402    /// use ipinfo::IpInfo;
403    ///
404    ///  #[tokio::main]
405    /// async fn main() {
406    ///     let ipinfo = IpInfo::new(Default::default()).expect("should construct");
407    ///     let map_url = ipinfo.get_map(&["8.8.8.8", "4.2.2.4"]).await.expect("should run");
408    /// }
409    /// ```
410    pub async fn get_map(&self, ips: &[&str]) -> Result<String, IpError> {
411        if ips.len() > 500_000 {
412            return Err(err!(MapLimitError));
413        }
414
415        let map_url = &format!("{BASE_URL}/tools/map?cli=1");
416        let client = self.client.clone();
417        let json_ips = serde_json::json!(ips);
418
419        let response = client.post(map_url).json(&json_ips).send().await?;
420        if !response.status().is_success() {
421            return Err(err!(HTTPClientError));
422        }
423
424        let response_json: serde_json::Value = response.json().await?;
425        let report_url = response_json["reportUrl"]
426            .as_str()
427            .ok_or("Report URL not found");
428        Ok(report_url.unwrap().to_string())
429    }
430
431    // Add country details and EU status to response
432    fn populate_static_details(&self, details: &mut IpDetails) {
433        if !&details.country.is_empty() {
434            let country_name = self.countries.get(&details.country).unwrap();
435            details.country_name = Some(country_name.to_string());
436            details.is_eu = Some(self.eu.contains(&details.country));
437            let country_flag =
438                self.country_flags.get(&details.country).unwrap();
439            details.country_flag = Some(country_flag.to_owned());
440            let file_ext = ".svg";
441            details.country_flag_url = Some(
442                COUNTRY_FLAG_URL.to_string() + &details.country + file_ext,
443            );
444            let country_currency =
445                self.country_currencies.get(&details.country).unwrap();
446            details.country_currency = Some(country_currency.to_owned());
447            let continent = self.continents.get(&details.country).unwrap();
448            details.continent = Some(continent.to_owned());
449        }
450    }
451
452    /// Construct API request headers.
453    fn construct_headers() -> HeaderMap {
454        let mut headers = HeaderMap::new();
455        headers.insert(
456            USER_AGENT,
457            HeaderValue::from_str(&format!("IPinfoClient/Rust/{VERSION}"))
458                .unwrap(),
459        );
460        headers.insert(
461            CONTENT_TYPE,
462            HeaderValue::from_static("application/json"),
463        );
464        headers.insert(ACCEPT, HeaderValue::from_static("application/json"));
465        headers
466    }
467}
468
#[cfg(test)]
mod tests {
    use super::*;
    use crate::IpErrorKind;
    use std::env;

    /// Build a client authenticated with the `IPINFO_TOKEN` environment
    /// variable. Network tests need a real token.
    fn get_ipinfo_client() -> IpInfo {
        let token = env::var("IPINFO_TOKEN")
            .expect("IPINFO_TOKEN must be set to run the network tests");
        IpInfo::new(IpInfoConfig {
            token: Some(token),
            timeout: Duration::from_secs(3),
            cache_size: 100,
            ..Default::default()
        })
        .expect("should construct")
    }

    #[test]
    fn ipinfo_config_defaults_reasonable() {
        let ipinfo_config = IpInfoConfig::default();

        assert_eq!(ipinfo_config.timeout, Duration::from_secs(3));
        assert_eq!(ipinfo_config.cache_size, 100);
    }

    #[test]
    fn batch_req_opts_defaults_reasonable() {
        let opts = BatchReqOpts::default();

        assert_eq!(opts.batch_size, BATCH_MAX_SIZE);
        assert_eq!(opts.timeout_per_batch, BATCH_REQ_TIMEOUT_DEFAULT);
        assert_eq!(opts.timeout_total, None);
    }

    #[test]
    fn request_headers_are_canonical() {
        let headers = IpInfo::construct_headers();

        assert_eq!(
            headers[USER_AGENT],
            format!("IPinfoClient/Rust/{}", VERSION)
        );
        assert_eq!(headers[CONTENT_TYPE], "application/json");
        assert_eq!(headers[ACCEPT], "application/json");
    }

    #[tokio::test]
    async fn request_single_ip() {
        let mut ipinfo = get_ipinfo_client();

        let details =
            ipinfo.lookup("66.87.125.72").await.expect("should lookup");

        assert_eq!(details.ip, "66.87.125.72");
    }

    #[tokio::test]
    async fn request_no_token() {
        let mut ipinfo =
            IpInfo::new(Default::default()).expect("should construct");

        assert_eq!(
            ipinfo
                .lookup_batch(&["8.8.8.8"], BatchReqOpts::default())
                .await
                .err()
                .unwrap()
                .kind(),
            IpErrorKind::IpRequestError
        );
    }

    #[tokio::test]
    async fn request_multiple_ip() {
        let mut ipinfo = get_ipinfo_client();

        let details = ipinfo
            .lookup_batch(&["8.8.8.8", "4.2.2.4"], BatchReqOpts::default())
            .await
            .expect("should lookup");

        // Assert successful lookup
        assert!(details.contains_key("8.8.8.8"));
        assert!(details.contains_key("4.2.2.4"));

        // Assert 8.8.8.8
        let ip8 = &details["8.8.8.8"];
        assert_eq!(ip8.ip, "8.8.8.8");
        assert_eq!(ip8.hostname, Some("dns.google".to_owned()));
        assert_eq!(ip8.city, "Mountain View");
        assert_eq!(ip8.region, "California");
        assert_eq!(ip8.country, "US");
        assert_eq!(
            ip8.country_flag_url,
            Some(
                "https://cdn.ipinfo.io/static/images/countries-flags/US.svg"
                    .to_owned()
            )
        );
        assert_eq!(
            ip8.country_flag,
            Some(CountryFlag {
                emoji: "🇺🇸".to_owned(),
                unicode: "U+1F1FA U+1F1F8".to_owned()
            })
        );
        assert_eq!(
            ip8.country_currency,
            Some(CountryCurrency {
                code: "USD".to_owned(),
                symbol: "$".to_owned()
            })
        );
        assert_eq!(
            ip8.continent,
            Some(Continent {
                code: "NA".to_owned(),
                name: "North America".to_owned()
            })
        );
        assert_eq!(ip8.loc, "38.0088,-122.1175");
        assert_eq!(ip8.postal, Some("94043".to_owned()));
        assert_eq!(ip8.timezone, Some("America/Los_Angeles".to_owned()));

        // Assert 4.2.2.4
        let ip4 = &details["4.2.2.4"];
        assert_eq!(ip4.ip, "4.2.2.4");
        assert_eq!(ip4.hostname, Some("d.resolvers.level3.net".to_owned()));
        assert_eq!(ip4.city, "Monroe");
        assert_eq!(ip4.region, "Louisiana");
        assert_eq!(ip4.country, "US");
        assert_eq!(ip4.loc, "32.5530,-92.0422");
        assert_eq!(ip4.postal, Some("71203".to_owned()));
        assert_eq!(ip4.timezone, Some("America/Chicago".to_owned()));
    }

    #[tokio::test]
    async fn request_cache_miss_and_hit() {
        let mut ipinfo = get_ipinfo_client();

        // Populate the cache with 8.8.8.8
        let details = ipinfo
            .lookup_batch(&["8.8.8.8"], BatchReqOpts::default())
            .await
            .expect("should lookup");

        // Assert 1 result
        assert!(details.contains_key("8.8.8.8"));
        assert_eq!(details.len(), 1);

        // Should have a cache hit for 8.8.8.8 and query for 4.2.2.4
        let details = ipinfo
            .lookup_batch(&["4.2.2.4", "8.8.8.8"], BatchReqOpts::default())
            .await
            .expect("should lookup");

        // Assert 2 results
        assert!(details.contains_key("8.8.8.8"));
        assert!(details.contains_key("4.2.2.4"));
        assert_eq!(details.len(), 2);
    }
}