use std::collections::HashSet;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tracing::{debug, warn};
use crate::errors::AppError;
use crate::services::settings_service::SettingsService;
/// JSON body returned by the sanctions API's `/v1/lists` endpoint.
///
/// Entries are kept exactly as received; normalization happens when the
/// cache snapshot is built from this response.
#[derive(Debug, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
struct SanctionsListResponse {
    /// Sanctioned wallet addresses, as sent by the API.
    addresses: Vec<String>,
    /// Sanctioned country codes, in whatever case the API sends them.
    countries: Vec<String>,
}
/// In-memory snapshot of the most recently fetched sanctions list.
struct SanctionsCache {
    /// Sanctioned addresses; membership is an exact string match.
    addresses: HashSet<String>,
    /// Sanctioned country codes, stored upper-cased.
    countries: HashSet<String>,
    /// When this snapshot was built; the staleness check compares against it.
    fetched_at: Instant,
}
/// Point-in-time summary of the sanctions cache, as produced by
/// `SanctionsService::stats`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SanctionsStats {
    /// Number of cached sanctioned addresses (0 when nothing is cached).
    pub address_count: usize,
    /// Number of cached sanctioned country codes (0 when nothing is cached).
    pub country_count: usize,
    /// When the cached list was fetched, or `None` if no list is cached.
    pub last_refresh: Option<Instant>,
    /// Whole seconds elapsed since `last_refresh`, or `None` if no list is cached.
    pub cache_age_secs: Option<u64>,
}
/// Timeout, in seconds, configured on the sanctions API HTTP client.
const HTTP_TIMEOUT_SECS: u64 = 10;
/// Floor, in seconds, applied to the configured refresh interval (one minute).
const MIN_REFRESH_INTERVAL_SECS: u64 = 60;
/// Refresh interval, in seconds, used when none is configured (one hour).
const DEFAULT_REFRESH_INTERVAL_SECS: u64 = 60 * 60;
pub struct SanctionsService {
client: reqwest::Client,
cache: Arc<RwLock<Option<SanctionsCache>>>,
settings_service: Arc<SettingsService>,
}
impl SanctionsService {
pub fn new(settings_service: Arc<SettingsService>) -> Self {
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(HTTP_TIMEOUT_SECS))
.build()
.unwrap_or_else(|e| {
tracing::error!(
error = %e,
"Failed to build HTTP client for SanctionsService; falling back to default"
);
reqwest::Client::new()
});
Self {
client,
cache: Arc::new(RwLock::new(None)),
settings_service,
}
}
pub async fn refresh(&self) -> Result<(), AppError> {
let api_url = self
.settings_service
.get("sanctions_api_url")
.await?
.unwrap_or_default();
if api_url.is_empty() {
return Err(AppError::Config("sanctions_api_url not configured".into()));
}
let url = format!("{}/v1/lists", api_url.trim_end_matches('/'));
let response = self.client.get(&url).send().await.map_err(|e| {
AppError::Internal(anyhow::anyhow!("Sanctions API request failed: {}", e))
})?;
if !response.status().is_success() {
return Err(AppError::Internal(anyhow::anyhow!(
"Sanctions API returned status {}",
response.status()
)));
}
let body: SanctionsListResponse = response.json().await.map_err(|e| {
AppError::Internal(anyhow::anyhow!("Failed to parse sanctions response: {}", e))
})?;
let entry = SanctionsCache {
addresses: body.addresses.into_iter().collect(),
countries: body
.countries
.into_iter()
.map(|c| c.to_uppercase())
.collect(),
fetched_at: Instant::now(),
};
let addr_count = entry.addresses.len();
let country_count = entry.countries.len();
{
let mut cache = self.cache.write().await;
*cache = Some(entry);
}
debug!(
addresses = addr_count,
countries = country_count,
"Sanctions list refreshed"
);
Ok(())
}
pub async fn ensure_fresh(&self) {
let ttl_secs = self
.settings_service
.get_u64("sanctions_refresh_interval_secs")
.await
.ok()
.flatten()
.unwrap_or(DEFAULT_REFRESH_INTERVAL_SECS)
.max(MIN_REFRESH_INTERVAL_SECS);
let is_stale = {
let cache = self.cache.read().await;
match cache.as_ref() {
Some(entry) => entry.fetched_at.elapsed() >= Duration::from_secs(ttl_secs),
None => true,
}
};
if is_stale {
if let Err(e) = self.refresh().await {
warn!(error = %e, "Failed to refresh sanctions list; keeping stale cache");
}
}
}
pub async fn is_address_sanctioned(&self, address: &str) -> bool {
let enabled = self
.settings_service
.get_bool("sanctions_enabled")
.await
.ok()
.flatten()
.unwrap_or(false);
if !enabled {
return false;
}
let cache = self.cache.read().await;
match cache.as_ref() {
Some(entry) => entry.addresses.contains(address),
None => false,
}
}
pub async fn check_address(&self, address: &str) -> Result<(), AppError> {
self.ensure_fresh().await;
if self.is_address_sanctioned(address).await {
return Err(AppError::Forbidden(
"Transaction blocked: sanctioned address".into(),
));
}
Ok(())
}
pub async fn is_country_sanctioned(&self, country_code: &str) -> bool {
let enabled = self
.settings_service
.get_bool("sanctions_enabled")
.await
.ok()
.flatten()
.unwrap_or(false);
if !enabled {
return false;
}
let upper = country_code.to_uppercase();
let cache = self.cache.read().await;
match cache.as_ref() {
Some(entry) => entry.countries.contains(&upper),
None => false,
}
}
pub async fn check_country_from_request(
&self,
headers: &axum::http::HeaderMap,
) -> Result<(), AppError> {
let custom_header = self
.settings_service
.get("sanctions_geoip_header")
.await
.ok()
.flatten()
.unwrap_or_default();
let country_code = if !custom_header.is_empty() {
headers
.get(custom_header.as_str())
.and_then(|v| v.to_str().ok())
.map(|s| s.trim().to_uppercase())
} else {
None
}
.or_else(|| {
headers
.get("cf-ipcountry")
.and_then(|v| v.to_str().ok())
.map(|s| s.trim().to_uppercase())
})
.or_else(|| {
headers
.get("x-country-code")
.and_then(|v| v.to_str().ok())
.map(|s| s.trim().to_uppercase())
});
let country_code = match country_code {
Some(cc) if cc.len() == 2 => cc,
_ => return Ok(()), };
self.ensure_fresh().await;
if self.is_country_sanctioned(&country_code).await {
tracing::warn!(
country_code = %country_code,
"Request blocked: sanctioned country"
);
return Err(AppError::Forbidden(
"Access blocked: restricted country".into(),
));
}
Ok(())
}
pub async fn stats(&self) -> SanctionsStats {
let guard = self.cache.read().await;
match guard.as_ref() {
Some(entry) => SanctionsStats {
address_count: entry.addresses.len(),
country_count: entry.countries.len(),
last_refresh: Some(entry.fetched_at),
cache_age_secs: Some(entry.fetched_at.elapsed().as_secs()),
},
None => SanctionsStats {
address_count: 0,
country_count: 0,
last_refresh: None,
cache_age_secs: None,
},
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::Storage;

    /// Builds a service backed by fresh in-memory settings.
    ///
    /// No settings are seeded, so `sanctions_enabled` is unset and the service
    /// screens nothing.
    fn make_service() -> SanctionsService {
        let storage = Storage::in_memory();
        let settings_service = Arc::new(SettingsService::new(storage.system_settings_repo));
        SanctionsService::new(settings_service)
    }

    /// Installs a fresh cache snapshot with `addresses` and the countries KP and IR.
    async fn seed_cache(service: &SanctionsService, addresses: Vec<String>) {
        let mut cache = service.cache.write().await;
        *cache = Some(SanctionsCache {
            addresses: addresses.into_iter().collect(),
            countries: ["KP".to_string(), "IR".to_string()].into_iter().collect(),
            fetched_at: Instant::now(),
        });
    }

    #[tokio::test]
    async fn disabled_service_allows_cached_address() {
        let svc = make_service();
        seed_cache(&svc, vec!["BadWallet111".to_string()]).await;
        assert!(!svc.is_address_sanctioned("BadWallet111").await);
        let result = svc.check_address("BadWallet111").await;
        assert!(result.is_ok(), "disabled service should pass any address");
    }

    #[tokio::test]
    async fn address_not_in_cache_is_not_sanctioned() {
        let svc = make_service();
        seed_cache(&svc, vec!["BadWallet111".to_string()]).await;
        let cache = svc.cache.read().await;
        let entry = cache.as_ref().unwrap();
        assert!(!entry.addresses.contains("CleanWallet999"));
    }

    #[tokio::test]
    async fn disabled_service_allows_cached_country() {
        let svc = make_service();
        seed_cache(&svc, Vec::new()).await;
        assert!(!svc.is_country_sanctioned("kp").await);
        let mut headers = axum::http::HeaderMap::new();
        headers.insert("cf-ipcountry", axum::http::HeaderValue::from_static("KP"));
        let result = svc.check_country_from_request(&headers).await;
        assert!(result.is_ok(), "disabled service should pass any country");
    }

    #[tokio::test]
    async fn request_without_geo_headers_passes() {
        let svc = make_service();
        let headers = axum::http::HeaderMap::new();
        assert!(svc.check_country_from_request(&headers).await.is_ok());
    }

    #[tokio::test]
    async fn stats_returns_none_when_cache_empty() {
        let svc = make_service();
        let stats = svc.stats().await;
        assert_eq!(stats.address_count, 0);
        assert_eq!(stats.country_count, 0);
        assert!(stats.last_refresh.is_none());
        assert!(stats.cache_age_secs.is_none());
    }

    #[tokio::test]
    async fn stats_returns_counts_when_cache_seeded() {
        let svc = make_service();
        seed_cache(&svc, vec!["Wallet1".to_string(), "Wallet2".to_string()]).await;
        let stats = svc.stats().await;
        assert_eq!(stats.address_count, 2);
        assert_eq!(stats.country_count, 2);
        assert!(stats.last_refresh.is_some());
        assert!(stats.cache_age_secs.is_some());
    }
}