use lru::LruCache;
use std::collections::BTreeSet;
use std::net::IpAddr;
use std::num::NonZeroUsize;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;
use tracing::debug;
/// Limits how many distinct client IP addresses may be used with each token.
///
/// Per-token state lives in a bounded LRU cache shared behind an async mutex.
/// A token's recorded IP set is reset once it is older than `entry_ttl`.
#[derive(Debug)]
pub struct IpTracker {
/// Per-token IP state keyed by token id.
cache: Arc<Mutex<LruCache<String, TokenIpState>>>,
/// Maximum number of distinct IPs a token may accumulate before new IPs are rejected.
max_ips_per_token: usize,
/// Age after which a token's recorded IP set is cleared and tracking starts over.
entry_ttl: Duration,
}
/// IP addresses recorded for a single token, plus the start of its current tracking window.
struct TokenIpState {
/// Distinct client IPs recorded for the token (stored after dual-stack normalization).
ips: BTreeSet<IpAddr>,
/// When this entry was created or last reset; compared against the tracker's TTL.
created_at: Instant,
}
/// Errors returned by [`IpTracker::check_and_track`].
#[derive(Debug, thiserror::Error)]
pub enum IpTrackerError {
/// The token already has the maximum number of distinct IPs and the
/// request came from an IP that is not among them.
#[error(
"IP limit exceeded: {current_count}/{limit} unique IPs used. Please refresh your token."
)]
LimitExceeded {
/// Id of the token that hit the limit; carried for callers but not part of the rendered message.
token_id: String,
/// Number of distinct IPs recorded for the token when the request was rejected.
current_count: usize,
/// The configured per-token IP limit.
limit: usize,
},
}
impl IpTracker {
    /// Creates a tracker.
    ///
    /// * `max_ips_per_token` - number of distinct IPs a token may use within
    ///   one tracking window.
    /// * `cache_size` - maximum number of tokens kept in the LRU cache.
    /// * `entry_ttl_seconds` - length of a token's tracking window in seconds.
    ///
    /// # Panics
    ///
    /// Panics if `cache_size` is zero, because the cache capacity must be
    /// non-zero.
    pub fn new(max_ips_per_token: usize, cache_size: usize, entry_ttl_seconds: u64) -> Self {
        let capacity = NonZeroUsize::new(cache_size)
            .expect("IpTracker cache_size must be greater than zero");
        Self {
            cache: Arc::new(Mutex::new(LruCache::new(capacity))),
            max_ips_per_token,
            entry_ttl: Duration::from_secs(entry_ttl_seconds),
        }
    }

    /// Records `client_ip` for `token_id`, enforcing the per-token IP limit.
    ///
    /// IPv4-mapped IPv6 addresses are converted to plain IPv4 first, so the
    /// same client is not counted twice on dual-stack listeners. If the
    /// token's entry is older than the configured TTL, its IP set is cleared
    /// and a new window starts before the check is made.
    ///
    /// Returns `Ok(())` when the IP was already recorded or was newly
    /// recorded within the limit.
    ///
    /// # Errors
    ///
    /// Returns [`IpTrackerError::LimitExceeded`] when the IP is new and the
    /// token already holds `max_ips_per_token` distinct IPs. The rejected IP
    /// is not recorded.
    pub async fn check_and_track(
        &self,
        token_id: &str,
        client_ip: IpAddr,
    ) -> Result<(), IpTrackerError> {
        let normalized_ip = normalize_dual_stack_ip(client_ip);
        let mut cache = self.cache.lock().await;

        let state = match cache.get_mut(token_id) {
            Some(state) => {
                // Expired window: forget the old IPs and start counting afresh.
                if state.created_at.elapsed() > self.entry_ttl {
                    debug!("[IP Tracker] TTL expired for token {}, resetting", token_id);
                    state.ips.clear();
                    state.created_at = Instant::now();
                }
                state
            }
            None => {
                cache.put(
                    token_id.to_string(),
                    TokenIpState {
                        ips: BTreeSet::new(),
                        created_at: Instant::now(),
                    },
                );
                cache
                    .get_mut(token_id)
                    .expect("entry for token was inserted immediately above")
            }
        };

        // A known IP never counts against the limit again.
        if state.ips.contains(&normalized_ip) {
            debug!(
                "[IP Tracker] IP {} already tracked for token {}",
                normalized_ip, token_id
            );
            return Ok(());
        }

        if state.ips.len() >= self.max_ips_per_token {
            return Err(IpTrackerError::LimitExceeded {
                token_id: token_id.to_string(),
                current_count: state.ips.len(),
                limit: self.max_ips_per_token,
            });
        }

        state.ips.insert(normalized_ip);
        debug!(
            "[IP Tracker] Tracked new IP {} for token {} ({}/{} IPs)",
            normalized_ip,
            token_id,
            state.ips.len(),
            self.max_ips_per_token
        );
        Ok(())
    }

    /// Returns how many distinct IPs currently count against `token_id`.
    ///
    /// Returns `0` when the token is unknown or when its tracking window has
    /// expired. The entry itself is left untouched; it is only reset by the
    /// next [`check_and_track`](Self::check_and_track) call.
    pub async fn get_ip_count(&self, token_id: &str) -> usize {
        let cache = self.cache.lock().await;
        cache
            .peek(token_id)
            // Same expiry rule as `check_and_track`: an expired window would be
            // wiped on the next check, so its stale IPs must not be reported.
            .filter(|state| state.created_at.elapsed() <= self.entry_ttl)
            .map_or(0, |state| state.ips.len())
    }
}
/// Collapses an IPv4-mapped IPv6 address (`::ffff:a.b.c.d`) to its IPv4 form.
///
/// Any other address, IPv4 or IPv6, is returned unchanged.
fn normalize_dual_stack_ip(ip: IpAddr) -> IpAddr {
    match ip {
        // Only the IPv4-mapped range converts; every other IPv6 address is kept.
        IpAddr::V6(v6) => v6.to_ipv4_mapped().map_or(ip, IpAddr::V4),
        IpAddr::V4(_) => ip,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses an IP literal used by the tests below.
    fn ip(literal: &str) -> IpAddr {
        literal.parse().unwrap()
    }

    /// A first IP for a fresh token is accepted and counted.
    #[tokio::test]
    async fn test_track_new_ip() {
        let tracker = IpTracker::new(5, 100, 3600);

        let outcome = tracker.check_and_track("token1", ip("1.2.3.4")).await;

        assert!(outcome.is_ok());
        assert_eq!(tracker.get_ip_count("token1").await, 1);
    }

    /// Seeing the same IP twice does not increase the count.
    #[tokio::test]
    async fn test_track_duplicate_ip() {
        let tracker = IpTracker::new(5, 100, 3600);

        for _ in 0..2 {
            tracker
                .check_and_track("token1", ip("1.2.3.4"))
                .await
                .unwrap();
        }

        assert_eq!(tracker.get_ip_count("token1").await, 1);
    }

    /// Once the limit is reached, a further new IP is rejected with the counts attached.
    #[tokio::test]
    async fn test_ip_limit_enforcement() {
        let tracker = IpTracker::new(3, 100, 3600);

        for addr in ["1.2.3.4", "1.2.3.5", "1.2.3.6"] {
            tracker.check_and_track("token1", ip(addr)).await.unwrap();
        }
        assert_eq!(tracker.get_ip_count("token1").await, 3);

        match tracker.check_and_track("token1", ip("1.2.3.7")).await {
            Err(IpTrackerError::LimitExceeded {
                current_count,
                limit,
                ..
            }) => {
                assert_eq!(current_count, 3);
                assert_eq!(limit, 3);
            }
            other => panic!("expected LimitExceeded, got {:?}", other),
        }
    }

    /// Each token has its own independent IP set.
    #[tokio::test]
    async fn test_multiple_tokens() {
        let tracker = IpTracker::new(5, 100, 3600);

        for token in ["token1", "token2"] {
            tracker.check_and_track(token, ip("1.2.3.4")).await.unwrap();
        }

        let first = tracker.get_ip_count("token1").await;
        let second = tracker.get_ip_count("token2").await;
        assert_eq!(first, 1);
        assert_eq!(second, 1);
    }

    /// After the TTL elapses, the next check starts from an empty IP set.
    #[tokio::test]
    async fn test_ttl_expiration() {
        let tracker = IpTracker::new(5, 100, 1);

        for addr in ["1.2.3.4", "1.2.3.5"] {
            tracker.check_and_track("token1", ip(addr)).await.unwrap();
        }
        assert_eq!(tracker.get_ip_count("token1").await, 2);

        // Wait past the one-second TTL so the entry counts as expired.
        tokio::time::sleep(Duration::from_secs(2)).await;

        tracker
            .check_and_track("token1", ip("1.2.3.6"))
            .await
            .unwrap();
        assert_eq!(tracker.get_ip_count("token1").await, 1);
    }

    /// Plain IPv4 addresses pass through unchanged.
    #[test]
    fn test_normalize_ipv4() {
        let original = ip("192.0.2.1");
        assert_eq!(normalize_dual_stack_ip(original), original);
    }

    /// Native IPv6 addresses pass through unchanged.
    #[test]
    fn test_normalize_ipv6() {
        let original = ip("2001:db8::1");
        assert_eq!(normalize_dual_stack_ip(original), original);
    }

    /// IPv4-mapped IPv6 addresses are converted to their IPv4 form.
    #[test]
    fn test_normalize_ipv4_mapped() {
        let mapped = ip("::ffff:192.0.2.1");
        assert_eq!(normalize_dual_stack_ip(mapped), ip("192.0.2.1"));
    }
}