use std::{collections::HashMap, num::NonZeroUsize, time::Duration};
use crate::{
cache_key, is_bogon, Continent, CountryCurrency, CountryFlag, IpDetails,
IpError, ResproxyDetails, BATCH_MAX_SIZE, BATCH_REQ_TIMEOUT_DEFAULT,
CONTINENTS, COUNTRIES, CURRENCIES, EU, FLAGS, VERSION,
};
use lru::LruCache;
use serde_json::json;
use reqwest::header::{
HeaderMap, HeaderValue, ACCEPT, CONTENT_TYPE, USER_AGENT,
};
use tokio::time::timeout;
/// Prefix for country flag images; `<COUNTRY_CODE>.svg` is appended per country.
const COUNTRY_FLAG_URL: &str = "https://cdn.ipinfo.io/static/images/countries-flags/";
/// Default root of the IPinfo API.
const BASE_URL: &str = "https://ipinfo.io";
/// API root used by `lookup_self_v6`.
const BASE_URL_V6: &str = "https://v6.ipinfo.io";
/// Settings used to construct an `IpInfo` client.
///
/// Each `default_*` table left as `None` is replaced by the crate's bundled
/// data when the client is built.
pub struct IpInfoConfig {
    /// API token, sent as a bearer token (an empty token is sent when `None`).
    pub token: Option<String>,
    /// Timeout for the client's regular HTTP requests. Batch lookups build
    /// their own client using `BatchReqOpts` timeouts instead.
    pub timeout: Duration,
    /// Capacity of the LRU lookup cache.
    pub cache_size: usize,
    /// Country code -> country name table. (The field name is misspelt; it is
    /// kept as-is for compatibility.)
    pub defaut_countries: Option<HashMap<String, String>>,
    /// Country codes reported as EU members.
    pub default_eu: Option<Vec<String>>,
    /// Country code -> flag table.
    pub default_flags: Option<HashMap<String, CountryFlag>>,
    /// Country code -> currency table.
    pub default_currencies: Option<HashMap<String, CountryCurrency>>,
    /// Country code -> continent table.
    pub default_continents: Option<HashMap<String, Continent>>,
    #[doc(hidden)]
    pub base_url: Option<String>,
}

impl Default for IpInfoConfig {
    /// No token, a three-second timeout, a 100-entry cache, the bundled
    /// lookup tables and the default API root.
    fn default() -> Self {
        Self {
            cache_size: 100,
            timeout: Duration::from_secs(3),
            token: None,
            base_url: None,
            defaut_countries: None,
            default_continents: None,
            default_currencies: None,
            default_eu: None,
            default_flags: None,
        }
    }
}
/// Client for the IPinfo API, holding an HTTP client, an LRU cache of
/// lookups and the static tables used to enrich results by country code.
pub struct IpInfo {
/// Token sent as a bearer token on authenticated requests.
token: Option<String>,
/// HTTP client built with `IpInfoConfig::timeout`; batch lookups build
/// their own client instead.
client: reqwest::Client,
/// Cached lookup results, keyed by `cache_key(ip)`.
cache: LruCache<String, IpDetails>,
/// Country code (the `IpDetails::country` value) -> country name.
countries: HashMap<String, String>,
/// Country codes for which `is_eu` is reported as `true`.
eu: Vec<String>,
/// Country code -> flag.
country_flags: HashMap<String, CountryFlag>,
/// Country code -> currency.
country_currencies: HashMap<String, CountryCurrency>,
/// Country code -> continent.
continents: HashMap<String, Continent>,
/// Root URL of the API (configurable via the hidden `IpInfoConfig::base_url`).
base_url: String,
}
/// Options for `IpInfo::lookup_batch`.
///
/// The fields are public so callers can override individual settings with
/// struct-update syntax, e.g.
/// `BatchReqOpts { timeout_total: Some(Duration::from_secs(10)), ..Default::default() }`.
/// (Previously the fields were private and the struct could only be built
/// via `Default`, which made every option unconfigurable.)
#[derive(Debug, Clone)]
pub struct BatchReqOpts {
    /// Number of IPs sent in each `/batch` request; should be between 1 and
    /// `BATCH_MAX_SIZE`.
    pub batch_size: u64,
    /// Timeout applied to each individual batch request.
    pub timeout_per_batch: Duration,
    /// Optional deadline for the whole `lookup_batch` call across all batches.
    pub timeout_total: Option<Duration>,
}

impl Default for BatchReqOpts {
    /// `BATCH_MAX_SIZE` IPs per request, `BATCH_REQ_TIMEOUT_DEFAULT` per
    /// request, and no overall deadline.
    fn default() -> Self {
        Self {
            batch_size: BATCH_MAX_SIZE,
            timeout_per_batch: BATCH_REQ_TIMEOUT_DEFAULT,
            timeout_total: None,
        }
    }
}
impl IpInfo {
pub fn new(config: IpInfoConfig) -> Result<Self, IpError> {
let client =
reqwest::Client::builder().timeout(config.timeout).build()?;
let mut ipinfo_obj = Self {
client,
token: config.token,
cache: LruCache::new(
NonZeroUsize::new(config.cache_size).unwrap(),
),
countries: HashMap::new(),
eu: Vec::new(),
country_flags: HashMap::new(),
country_currencies: HashMap::new(),
continents: HashMap::new(),
base_url: config.base_url.unwrap_or_else(|| BASE_URL.to_string()),
};
ipinfo_obj.countries =
config.defaut_countries.unwrap_or_else(|| COUNTRIES.clone());
ipinfo_obj.eu = config.default_eu.unwrap_or_else(|| EU.clone());
ipinfo_obj.country_flags =
config.default_flags.unwrap_or_else(|| FLAGS.clone());
ipinfo_obj.country_currencies = config
.default_currencies
.unwrap_or_else(|| CURRENCIES.clone());
ipinfo_obj.continents = config
.default_continents
.unwrap_or_else(|| CONTINENTS.clone());
Ok(ipinfo_obj)
}
pub async fn lookup_batch(
&mut self,
ips: &[&str],
batch_config: BatchReqOpts,
) -> Result<HashMap<String, IpDetails>, IpError> {
if let Some(total_timeout) = batch_config.timeout_total {
match timeout(total_timeout, self._lookup_batch(ips, batch_config))
.await
{
Ok(result) => result,
Err(_) => Err(err!(TimeOutError)),
}
} else {
self._lookup_batch(ips, batch_config).await
}
}
async fn _lookup_batch(
&mut self,
ips: &[&str],
batch_config: BatchReqOpts,
) -> Result<HashMap<String, IpDetails>, IpError> {
let mut results: HashMap<String, IpDetails> = HashMap::new();
let mut work = vec![];
for ip in ips.iter() {
if is_bogon(ip) {
results.insert(
ip.to_string(),
IpDetails {
ip: ip.to_string(),
bogon: Some(true),
..Default::default()
},
);
} else if let Some(detail) = self.cache.get(&cache_key(ip)) {
results.insert(ip.to_string(), detail.clone());
} else {
work.push(*ip);
}
}
let client = reqwest::Client::builder()
.timeout(batch_config.timeout_per_batch)
.build()?;
work.sort();
work.dedup();
for batch in work.chunks(batch_config.batch_size as usize) {
let response = self.batch_request(client.clone(), batch).await?;
results.extend(response);
}
for detail in results.values_mut() {
self.populate_static_details(detail);
}
results
.iter()
.filter(|(ip, _)| !is_bogon(ip))
.for_each(|x| {
self.cache.put(cache_key(x.0.as_str()), x.1.clone());
});
Ok(results)
}
async fn batch_request(
&self,
client: reqwest::Client,
ips: &[&str],
) -> Result<HashMap<String, IpDetails>, IpError> {
let response = client
.post(format!("{BASE_URL}/batch"))
.headers(Self::construct_headers())
.bearer_auth(self.token.as_deref().unwrap_or_default())
.json(&json!(ips))
.send()
.await?;
if let reqwest::StatusCode::TOO_MANY_REQUESTS = response.status() {
return Err(err!(RateLimitExceededError));
}
let raw_resp = response.error_for_status()?.text().await?;
let resp: serde_json::Value = serde_json::from_str(&raw_resp)?;
if let Some(e) = resp["error"].as_str() {
return Err(err!(IpRequestError, e));
}
let result: HashMap<String, IpDetails> =
serde_json::from_str(&raw_resp)?;
Ok(result)
}
pub async fn lookup(&mut self, ip: &str) -> Result<IpDetails, IpError> {
self._lookup(ip, BASE_URL).await
}
pub async fn lookup_self_v4(&mut self) -> Result<IpDetails, IpError> {
self._lookup("", BASE_URL).await
}
pub async fn lookup_self_v6(&mut self) -> Result<IpDetails, IpError> {
self._lookup("", BASE_URL_V6).await
}
async fn _lookup(
&mut self,
ip: &str,
base_url: &str,
) -> Result<IpDetails, IpError> {
if is_bogon(ip) {
return Ok(IpDetails {
ip: ip.to_string(),
bogon: Some(true),
..Default::default() });
}
let cached_detail = self.cache.get(&cache_key(ip));
if let Some(cached_detail) = cached_detail {
return Ok(cached_detail.clone());
}
let response = self
.client
.get(format!("{base_url}/{ip}"))
.headers(Self::construct_headers())
.bearer_auth(self.token.as_deref().unwrap_or_default())
.send()
.await?;
if let reqwest::StatusCode::TOO_MANY_REQUESTS = response.status() {
return Err(err!(RateLimitExceededError));
}
let raw_resp = response.error_for_status()?.text().await?;
let resp: serde_json::Value = serde_json::from_str(&raw_resp)?;
if let Some(e) = resp["error"].as_str() {
return Err(err!(IpRequestError, e));
}
let mut details: IpDetails = serde_json::from_str(&raw_resp)?;
self.populate_static_details(&mut details);
self.cache.put(cache_key(ip), details.clone());
Ok(details)
}
pub async fn get_map(&self, ips: &[&str]) -> Result<String, IpError> {
if ips.len() > 500_000 {
return Err(err!(MapLimitError));
}
let map_url = &format!("{BASE_URL}/tools/map?cli=1");
let client = self.client.clone();
let json_ips = serde_json::json!(ips);
let response = client.post(map_url).json(&json_ips).send().await?;
if !response.status().is_success() {
return Err(err!(HTTPClientError));
}
let response_json: serde_json::Value = response.json().await?;
let report_url = response_json["reportUrl"]
.as_str()
.ok_or("Report URL not found");
Ok(report_url.unwrap().to_string())
}
pub async fn lookup_resproxy(
&self,
ip: &str,
) -> Result<ResproxyDetails, IpError> {
let response = self
.client
.get(format!("{}/resproxy/{ip}", self.base_url))
.headers(Self::construct_headers())
.bearer_auth(self.token.as_deref().unwrap_or_default())
.send()
.await?;
if let reqwest::StatusCode::TOO_MANY_REQUESTS = response.status() {
return Err(err!(RateLimitExceededError));
}
let raw_resp = response.error_for_status()?.text().await?;
let resp: serde_json::Value = serde_json::from_str(&raw_resp)?;
if let Some(e) = resp["error"].as_str() {
return Err(err!(IpRequestError, e));
}
let details: ResproxyDetails = serde_json::from_str(&raw_resp)?;
Ok(details)
}
fn populate_static_details(&self, details: &mut IpDetails) {
if !&details.country.is_empty() {
let country_name = self.countries.get(&details.country).unwrap();
details.country_name = Some(country_name.to_string());
details.is_eu = Some(self.eu.contains(&details.country));
let country_flag =
self.country_flags.get(&details.country).unwrap();
details.country_flag = Some(country_flag.to_owned());
let file_ext = ".svg";
details.country_flag_url = Some(
COUNTRY_FLAG_URL.to_string() + &details.country + file_ext,
);
let country_currency =
self.country_currencies.get(&details.country).unwrap();
details.country_currency = Some(country_currency.to_owned());
let continent = self.continents.get(&details.country).unwrap();
details.continent = Some(continent.to_owned());
}
}
fn construct_headers() -> HeaderMap {
let mut headers = HeaderMap::new();
headers.insert(
USER_AGENT,
HeaderValue::from_str(&format!("IPinfoClient/Rust/{VERSION}"))
.unwrap(),
);
headers.insert(
CONTENT_TYPE,
HeaderValue::from_static("application/json"),
);
headers.insert(ACCEPT, HeaderValue::from_static("application/json"));
headers
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::IpErrorKind;
    use std::env;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Client for live-API tests; requires the `IPINFO_TOKEN` env variable.
    fn get_ipinfo_client() -> IpInfo {
        IpInfo::new(IpInfoConfig {
            token: Some(
                env::var("IPINFO_TOKEN")
                    .expect("IPINFO_TOKEN must be set for live API tests"),
            ),
            timeout: Duration::from_secs(3),
            cache_size: 100,
            ..Default::default()
        })
        .expect("should construct")
    }

    /// Client pointed at a local mock server.
    fn get_mock_client(mock_server: &MockServer) -> IpInfo {
        IpInfo::new(IpInfoConfig {
            token: Some("test_token".to_string()),
            base_url: Some(mock_server.uri()),
            ..Default::default()
        })
        .expect("should construct")
    }

    #[test]
    fn ipinfo_config_defaults_reasonable() {
        let ipinfo_config = IpInfoConfig::default();
        assert_eq!(ipinfo_config.timeout, Duration::from_secs(3));
        assert_eq!(ipinfo_config.cache_size, 100);
    }

    #[test]
    fn request_headers_are_canonical() {
        let headers = IpInfo::construct_headers();
        assert_eq!(
            headers[USER_AGENT],
            format!("IPinfoClient/Rust/{VERSION}")
        );
        assert_eq!(headers[CONTENT_TYPE], "application/json");
        assert_eq!(headers[ACCEPT], "application/json");
    }

    #[tokio::test]
    async fn request_single_ip() {
        let mut ipinfo = get_ipinfo_client();
        let details =
            ipinfo.lookup("66.87.125.72").await.expect("should lookup");
        assert_eq!(details.ip, "66.87.125.72");
    }

    #[tokio::test]
    async fn request_no_token() {
        let mut ipinfo =
            IpInfo::new(Default::default()).expect("should construct");
        assert_eq!(
            ipinfo
                .lookup_batch(&["8.8.8.8"], BatchReqOpts::default())
                .await
                .err()
                .unwrap()
                .kind(),
            IpErrorKind::IpRequestError
        );
    }

    #[tokio::test]
    async fn request_multiple_ip() {
        let mut ipinfo = get_ipinfo_client();
        let details = ipinfo
            .lookup_batch(&["8.8.8.8", "4.2.2.4"], BatchReqOpts::default())
            .await
            .expect("should lookup");
        assert!(details.contains_key("8.8.8.8"));
        assert!(details.contains_key("4.2.2.4"));
        let ip8 = &details["8.8.8.8"];
        assert_eq!(ip8.ip, "8.8.8.8");
        assert_eq!(ip8.hostname, Some("dns.google".to_owned()));
        assert_eq!(ip8.city, "Mountain View");
        assert_eq!(ip8.region, "California");
        assert_eq!(ip8.country, "US");
        assert_eq!(
            ip8.country_flag_url,
            Some(
                "https://cdn.ipinfo.io/static/images/countries-flags/US.svg"
                    .to_owned()
            )
        );
        assert_eq!(
            ip8.country_flag,
            Some(CountryFlag {
                emoji: "🇺🇸".to_owned(),
                unicode: "U+1F1FA U+1F1F8".to_owned()
            })
        );
        assert_eq!(
            ip8.country_currency,
            Some(CountryCurrency {
                code: "USD".to_owned(),
                symbol: "$".to_owned()
            })
        );
        assert_eq!(
            ip8.continent,
            Some(Continent {
                code: "NA".to_owned(),
                name: "North America".to_owned()
            })
        );
        assert_ne!(ip8.loc, "");
        assert_eq!(ip8.postal, Some("94043".to_owned()));
        assert_eq!(ip8.timezone, Some("America/Los_Angeles".to_owned()));
        let ip4 = &details["4.2.2.4"];
        assert_eq!(ip4.ip, "4.2.2.4");
        assert_eq!(ip4.hostname, Some("d.resolvers.level3.net".to_owned()));
        assert_eq!(ip4.city, "Monroe");
        assert_eq!(ip4.region, "Louisiana");
        assert_eq!(ip4.country, "US");
        assert_eq!(ip4.loc, "32.5530,-92.0422");
        assert_eq!(ip4.postal, Some("71203".to_owned()));
        assert_eq!(ip4.timezone, Some("America/Chicago".to_owned()));
    }

    #[tokio::test]
    async fn request_cache_miss_and_hit() {
        let mut ipinfo = get_ipinfo_client();
        let details = ipinfo
            .lookup_batch(&["8.8.8.8"], BatchReqOpts::default())
            .await
            .expect("should lookup");
        assert!(details.contains_key("8.8.8.8"));
        assert_eq!(details.len(), 1);
        let details = ipinfo
            .lookup_batch(&["4.2.2.4", "8.8.8.8"], BatchReqOpts::default())
            .await
            .expect("should lookup");
        assert!(details.contains_key("8.8.8.8"));
        assert!(details.contains_key("4.2.2.4"));
        assert_eq!(details.len(), 2);
    }

    #[tokio::test]
    async fn request_resproxy() {
        let mock_server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/resproxy/175.107.211.204"))
            .respond_with(ResponseTemplate::new(200).set_body_json(
                serde_json::json!({
                    "ip": "175.107.211.204",
                    "last_seen": "2025-01-20",
                    "percent_days_seen": 0.85,
                    "service": "example_service"
                }),
            ))
            .mount(&mock_server)
            .await;
        let ipinfo = get_mock_client(&mock_server);
        let details = ipinfo
            .lookup_resproxy("175.107.211.204")
            .await
            .expect("should lookup resproxy");
        assert_eq!(details.ip, "175.107.211.204");
        assert_eq!(details.last_seen, Some("2025-01-20".to_string()));
        assert_eq!(details.percent_days_seen, Some(0.85));
        assert_eq!(details.service, Some("example_service".to_string()));
    }

    #[tokio::test]
    async fn request_resproxy_empty() {
        let mock_server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/resproxy/8.8.8.8"))
            .respond_with(
                ResponseTemplate::new(200)
                    .set_body_json(serde_json::json!({})),
            )
            .mount(&mock_server)
            .await;
        let ipinfo = get_mock_client(&mock_server);
        let details = ipinfo
            .lookup_resproxy("8.8.8.8")
            .await
            .expect("should lookup resproxy");
        assert!(details.ip.is_empty());
        assert!(details.last_seen.is_none());
        assert!(details.percent_days_seen.is_none());
        assert!(details.service.is_none());
    }

    #[tokio::test]
    async fn request_resproxy_rate_limited() {
        let mock_server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/resproxy/8.8.8.8"))
            .respond_with(ResponseTemplate::new(429))
            .mount(&mock_server)
            .await;
        let ipinfo = get_mock_client(&mock_server);
        let err = ipinfo
            .lookup_resproxy("8.8.8.8")
            .await
            .err()
            .expect("should be rate limited");
        assert_eq!(err.kind(), IpErrorKind::RateLimitExceededError);
    }

    #[tokio::test]
    async fn request_resproxy_api_error() {
        let mock_server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/resproxy/8.8.8.8"))
            .respond_with(ResponseTemplate::new(200).set_body_json(
                serde_json::json!({ "error": "Unknown token" }),
            ))
            .mount(&mock_server)
            .await;
        let ipinfo = get_mock_client(&mock_server);
        let err = ipinfo
            .lookup_resproxy("8.8.8.8")
            .await
            .err()
            .expect("should surface the API error");
        assert_eq!(err.kind(), IpErrorKind::IpRequestError);
    }
}