use log::debug;
use radixtarget::{RadixTarget, ScopeMode};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::error::Error as StdError;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use tokio::sync::{Mutex, RwLock};
#[cfg(feature = "py")]
mod python;
/// Built-in location of the provider signature file.
///
/// Used by `CloudCheck::with_config` only when no explicit URL is passed and
/// the `CLOUDCHECK_SIGNATURE_URL` environment variable cannot be read.
const CLOUDCHECK_SIGNATURE_URL: &str = "https://raw.githubusercontent.com/blacklanternsecurity/cloudcheck/refs/heads/stable/cloud_providers_v3.json";
/// Description of a cloud provider, as returned by `CloudCheck::lookup`.
///
/// The type can be serialized and deserialized with serde and derives
/// `utoipa::ToSchema`.
#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
pub struct CloudProvider {
/// Provider name. Also used to avoid attaching the same provider twice to
/// one lookup key.
pub name: String,
/// Tags attached to the provider in the signature data.
pub tags: Vec<String>,
/// Short description; an empty string when absent from the input.
#[serde(default)]
pub short_description: String,
/// Long description; an empty string when absent from the input.
#[serde(default)]
pub long_description: String,
}
/// One provider record as it appears in the signature JSON.
///
/// `build_data_structures` deserializes a map of these, inserts every entry of
/// `cidrs` and `domains` into the radix tree, and copies the remaining fields
/// into a `CloudProvider`.
#[derive(Debug, Deserialize)]
struct ProviderData {
// Provider name, copied into `CloudProvider::name`.
name: String,
// Tags, copied into `CloudProvider::tags`.
tags: Vec<String>,
// Network entries inserted into the radix tree (required in the JSON).
cidrs: Vec<String>,
// Domain entries inserted into the radix tree (required in the JSON).
domains: Vec<String>,
// Optional in the JSON; defaults to an empty string.
#[serde(default)]
short_description: String,
// Optional in the JSON; defaults to an empty string.
#[serde(default)]
long_description: String,
}
/// Maps a key returned by the radix tree (`get`/`insert`) to the providers
/// registered for that key.
type ProvidersMap = HashMap<String, Vec<CloudProvider>>;
/// Boxed, thread-safe error type used by every fallible function in this file.
type Error = Box<dyn std::error::Error + Send + Sync>;
/// Looks up which cloud providers a host, IP address or domain belongs to.
///
/// The loaded state lives behind `Arc`s, so clones of a `CloudCheck` share the
/// same radix tree, providers map and fetch timestamp.
#[derive(Clone)]
pub struct CloudCheck {
// Radix tree built from the signature data; `None` until the first load.
radix: Arc<RwLock<Option<RadixTarget>>>,
// Providers keyed by the radix tree's keys; `None` until the first load.
providers: Arc<RwLock<Option<ProvidersMap>>>,
// In-memory timestamp consulted before falling back to the cache file's
// modification time when deciding whether a refresh is due.
last_fetch: Arc<Mutex<Option<SystemTime>>>,
// URL the signature JSON is downloaded from.
signature_url: String,
// Number of retries after the first attempt (so `max_retries + 1` attempts).
max_retries: u32,
// Pause between two download attempts, in seconds.
retry_delay_seconds: u64,
// When true, `ensure_loaded` always treats the data as needing a refresh.
force_refresh: bool,
}
impl Default for CloudCheck {
fn default() -> Self {
Self::new()
}
}
impl CloudCheck {
/// Creates a checker that uses the default settings for every option.
///
/// Equivalent to calling [`CloudCheck::with_config`] with `None` for each
/// argument.
pub fn new() -> Self {
    // Each `None` selects the corresponding default inside `with_config`.
    let (signature_url, max_retries, retry_delay_seconds, force_refresh) =
        (None, None, None, None);
    Self::with_config(signature_url, max_retries, retry_delay_seconds, force_refresh)
}
pub fn with_config(
signature_url: Option<String>,
max_retries: Option<u32>,
retry_delay_seconds: Option<u64>,
force_refresh: Option<bool>,
) -> Self {
let url = signature_url
.or_else(|| std::env::var("CLOUDCHECK_SIGNATURE_URL").ok())
.unwrap_or_else(|| CLOUDCHECK_SIGNATURE_URL.to_string());
CloudCheck {
radix: Arc::new(RwLock::new(None)),
providers: Arc::new(RwLock::new(None)),
last_fetch: Arc::new(Mutex::new(None)),
signature_url: url,
max_retries: max_retries.unwrap_or(10),
retry_delay_seconds: retry_delay_seconds.unwrap_or(1),
force_refresh: force_refresh.unwrap_or(false),
}
}
/// Returns the location of the cached signature file:
/// `<home>/.cache/cloudcheck/cloud_providers_v3.json`.
///
/// `<home>` is taken from `HOME`, or from `USERPROFILE` when `HOME` is unset
/// or empty. The variables are read as OS strings, so a home directory that is
/// not valid Unicode is still accepted.
///
/// # Errors
///
/// Returns an error when neither variable holds a non-empty value.
fn get_cache_path() -> Result<PathBuf, Error> {
    // An empty value is treated as unset: joining onto "" would silently
    // produce a path relative to the current working directory.
    let home = ["HOME", "USERPROFILE"]
        .iter()
        .filter_map(|name| std::env::var_os(name))
        .find(|value| !value.is_empty())
        .ok_or("cannot determine cache directory: neither HOME nor USERPROFILE is set")?;
    Ok(PathBuf::from(home)
        .join(".cache")
        .join("cloudcheck")
        .join("cloud_providers_v3.json"))
}
async fn fetch_and_cache(&self, cache_path: &PathBuf) -> Result<String, Error> {
let url = &self.signature_url;
let max_retries = self.max_retries;
let retry_delay_seconds = self.retry_delay_seconds;
log::info!(
"Fetching data from URL: {} (max_retries={}, retry_delay={}s)",
url,
max_retries,
retry_delay_seconds
);
let mut last_error = None;
for attempt in 0..=max_retries {
log::info!("Fetch attempt {}/{}", attempt + 1, max_retries + 1);
let result = match reqwest::get(url).await {
Ok(response) => {
let status = response.status();
log::info!(
"HTTP response received, status: {} {}",
status.as_u16(),
status
);
if !status.is_success() {
let error_msg = format!("HTTP error: {} {}", status.as_u16(), status);
log::warn!("{}", error_msg);
Err(Box::new(std::io::Error::other(error_msg)) as Error)
} else {
response.text().await.map_err(|e| {
let error_msg = format!("Failed to read response body: {}", e);
log::warn!("{}", error_msg);
Box::new(std::io::Error::other(error_msg)) as Error
})
}
}
Err(e) => {
let error_type = if e.is_timeout() {
"timeout"
} else if e.is_connect() {
"connection"
} else if e.is_request() {
"request"
} else {
"unknown"
};
let mut error_details = format!("{}", e);
let mut current_source: Option<&(dyn StdError + 'static)> =
StdError::source(&e);
while let Some(source) = current_source {
error_details = format!("{}: {}", error_details, source);
current_source = source.source();
}
log::warn!(
"HTTP request failed ({} error): {}",
error_type,
error_details
);
Err(Box::new(e) as Error)
}
};
match result {
Ok(json_data) => {
log::info!("Fetched {} bytes from network", json_data.len());
if let Some(parent) = cache_path.parent() {
log::debug!("Creating cache directory: {:?}", parent);
tokio::fs::create_dir_all(parent).await?;
}
log::debug!("Writing cache file: {:?}", cache_path);
tokio::fs::write(cache_path, &json_data).await?;
log::info!("Cache file written successfully");
return Ok(json_data);
}
Err(e) => {
last_error = Some(e);
if attempt < max_retries {
log::warn!(
"Failed to fetch (attempt {}/{}), retrying in {} second(s): {}",
attempt + 1,
max_retries + 1,
retry_delay_seconds,
last_error.as_ref().unwrap()
);
tokio::time::sleep(tokio::time::Duration::from_secs(retry_delay_seconds))
.await;
} else {
log::error!(
"Failed to fetch after {} attempts: {}",
max_retries + 1,
last_error.as_ref().unwrap()
);
}
}
}
}
let final_error = last_error.unwrap_or_else(|| {
Box::new(std::io::Error::other("Failed to fetch data after retries"))
});
Err(Box::new(std::io::Error::other(format!(
"Failed to fetch cloud provider data from {} after {} attempts: {}",
url,
max_retries + 1,
final_error
))) as Error)
}
/// Determines when the signature data was last obtained.
///
/// The in-memory timestamp wins when it is set. Otherwise the modification
/// time of the file at `cache_path` is used. `Ok(None)` is returned when the
/// file's metadata or modification time cannot be read. This function never
/// returns `Err`.
async fn get_last_fetch_time(&self, cache_path: &PathBuf) -> Result<Option<SystemTime>, Error> {
    // Copy the timestamp out so the mutex is not held during file I/O.
    let remembered = *self.last_fetch.lock().await;
    if let Some(time) = remembered {
        debug!("Using in-memory last_fetch timestamp: {:?}", time);
        return Ok(Some(time));
    }
    debug!(
        "No in-memory timestamp, checking cache file modification time: {:?}",
        cache_path
    );
    let Ok(metadata) = tokio::fs::metadata(cache_path).await else {
        // Any metadata error is reported as a missing cache file.
        debug!("Cache file does not exist: {:?}", cache_path);
        return Ok(None);
    };
    match metadata.modified() {
        Ok(modified) => {
            debug!("Cache file modification time: {:?}", modified);
            Ok(Some(modified))
        }
        Err(_) => {
            debug!("Cache file exists but modification time unavailable");
            Ok(None)
        }
    }
}
/// Obtains the signature JSON, either from the network or from the cache file.
///
/// Returns the JSON text together with a flag that is `true` when the data was
/// downloaded and `false` when it was read from `cache_path`.
///
/// * `needs_refresh == true`: always download.
/// * `needs_refresh == false`: read the cache file, and download only if the
///   file cannot be read.
///
/// When data is taken from the cache and no in-memory timestamp exists yet,
/// the timestamp is set to the cache file's modification time (falling back to
/// the current time if that is unavailable). Using the file time rather than
/// "now" keeps the age of cached data honest, so an old cache file is not
/// treated as freshly fetched.
///
/// # Errors
///
/// Propagates the error from `fetch_and_cache` when a download is required and
/// fails.
async fn load_json_data(
    &self,
    cache_path: &PathBuf,
    needs_refresh: bool,
) -> Result<(String, bool), Error> {
    if needs_refresh {
        log::info!("Refresh needed, fetching from network");
        let data = self.fetch_and_cache(cache_path).await?;
        return Ok((data, true));
    }
    log::info!("No refresh needed, loading from cache: {:?}", cache_path);
    match tokio::fs::read_to_string(cache_path).await {
        Ok(data) => {
            debug!("Successfully loaded {} bytes from cache", data.len());
            // The data is as old as the file, not as old as this read.
            let cached_at = tokio::fs::metadata(cache_path)
                .await
                .ok()
                .and_then(|metadata| metadata.modified().ok())
                .unwrap_or_else(SystemTime::now);
            let mut last_fetch = self.last_fetch.lock().await;
            if last_fetch.is_none() {
                debug!(
                    "Setting in-memory last_fetch timestamp to cache file time: {:?}",
                    cached_at
                );
                *last_fetch = Some(cached_at);
            } else {
                debug!("In-memory last_fetch timestamp already set, keeping existing value");
            }
            Ok((data, false))
        }
        Err(e) => {
            log::warn!(
                "Failed to read cache file ({}), falling back to network fetch",
                e
            );
            let data = self.fetch_and_cache(cache_path).await?;
            Ok((data, true))
        }
    }
}
fn build_data_structures(json_data: &str) -> Result<(RadixTarget, ProvidersMap), Error> {
let providers_data: HashMap<String, ProviderData> = serde_json::from_str(json_data)?;
let mut radix = RadixTarget::new(&[], ScopeMode::Acl)?;
let mut providers_map: ProvidersMap = HashMap::new();
for _ in 0..2 {
for provider in providers_data.values() {
let cloud_provider = CloudProvider {
name: provider.name.clone(),
tags: provider.tags.clone(),
short_description: provider.short_description.clone(),
long_description: provider.long_description.clone(),
};
for cidr in &provider.cidrs {
let normalized = match radix.get(cidr) {
Some(n) => n,
None => match radix.insert(cidr) {
Ok(Some(n)) => n,
Ok(None) => continue,
Err(e) => {
log::warn!("Error inserting CIDR '{}': {}", cidr, e);
continue;
}
},
};
let providers_list = providers_map.entry(normalized.clone()).or_default();
if !providers_list.iter().any(|p| p.name == cloud_provider.name) {
providers_list.push(cloud_provider.clone());
}
}
for domain in &provider.domains {
let normalized = match radix.get(domain) {
Some(n) => n,
None => match radix.insert(domain) {
Ok(Some(n)) => n,
Ok(None) => continue,
Err(e) => {
log::warn!("Error inserting domain '{}': {}", domain, e);
continue;
}
},
};
let providers_list = providers_map.entry(normalized.clone()).or_default();
if !providers_list.iter().any(|p| p.name == cloud_provider.name) {
providers_list.push(cloud_provider.clone());
}
}
}
}
Ok((radix, providers_map))
}
/// Makes sure the radix tree and providers map are loaded and not stale.
///
/// Data counts as stale when it is at least 24 hours old (judged by the
/// in-memory timestamp or, failing that, the cache file's modification time),
/// when no timestamp is available, or when `force_refresh` is set. Fresh,
/// already-loaded data causes an early return. Otherwise the JSON is loaded
/// (from network or cache), parsed on a blocking thread, and published.
///
/// The radix tree and the providers map are replaced while holding both write
/// locks, so a concurrent `lookup` never sees a tree without its matching map.
///
/// # Errors
///
/// Fails when the cache path cannot be determined, the data cannot be loaded,
/// the blocking task fails, or the JSON cannot be turned into lookup
/// structures.
async fn ensure_loaded(&self) -> Result<(), Error> {
    let cache_valid_duration = Duration::from_secs(24 * 60 * 60);
    let now = SystemTime::now();
    let cache_path = Self::get_cache_path()?;
    log::info!("ensure_loaded: checking cache at {:?}", cache_path);
    let last_fetch_time = self.get_last_fetch_time(&cache_path).await?;
    // NOTE(review): with `force_refresh` set this is true on every call, so
    // each lookup downloads and rebuilds the data - confirm this is intended.
    let needs_refresh = if self.force_refresh {
        debug!("force_refresh is enabled, needs_refresh=true");
        true
    } else {
        match last_fetch_time {
            Some(fetch_time) => {
                // `duration_since` fails when the timestamp lies in the future;
                // that case is treated as stale.
                let elapsed = now.duration_since(fetch_time).ok();
                let needs = elapsed.map(|e| e >= cache_valid_duration).unwrap_or(true);
                if let Some(e) = elapsed {
                    debug!("Time since last fetch: {:?}, needs_refresh={}", e, needs);
                } else {
                    debug!("Could not calculate duration since last fetch, needs_refresh=true");
                }
                needs
            }
            None => {
                debug!("No last_fetch_time available, needs_refresh=true");
                true
            }
        }
    };
    {
        let radix_guard = self.radix.read().await;
        if radix_guard.is_some() && !needs_refresh {
            log::info!("Data already loaded and fresh, returning early");
            return Ok(());
        }
        log::info!("Data not loaded or needs refresh, proceeding to load");
    }
    let (json_data, fetched_fresh) = self.load_json_data(&cache_path, needs_refresh).await?;
    log::info!(
        "Loaded JSON data, fetched_fresh={}, building data structures",
        fetched_fresh
    );
    // Parsing and tree construction are CPU-bound, so keep them off the
    // async worker threads.
    let (radix, providers_map) =
        tokio::task::spawn_blocking(move || Self::build_data_structures(&json_data)).await??;
    debug!("Built data structures: radix tree and providers map");
    {
        // Acquire both write locks before assigning either value. Releasing
        // the radix lock first would let a reader see the new tree with a
        // missing or outdated providers map. The order (radix, then
        // providers) matches the order `lookup` uses for its read locks.
        let mut radix_guard = self.radix.write().await;
        let mut providers_guard = self.providers.write().await;
        *radix_guard = Some(radix);
        debug!("Updated radix tree in memory");
        *providers_guard = Some(providers_map);
        debug!("Updated providers map in memory");
    }
    if fetched_fresh {
        let mut last_fetch = self.last_fetch.lock().await;
        *last_fetch = Some(now);
        debug!("Updated in-memory last_fetch timestamp to {:?}", now);
    }
    Ok(())
}
pub async fn lookup(&self, target: &str) -> Result<Vec<CloudProvider>, Error> {
log::info!("lookup called for target: {}", target);
match self.ensure_loaded().await {
Ok(()) => log::debug!("ensure_loaded succeeded"),
Err(e) => {
log::error!("ensure_loaded failed: {}", e);
return Err(e);
}
}
let radix_guard = self.radix.read().await;
let providers_guard = self.providers.read().await;
let radix = radix_guard.as_ref().unwrap();
let providers = providers_guard.as_ref().unwrap();
if let Some(normalized) = radix.get(target) {
debug!("Found normalized target: {} for {}", normalized, target);
let result = providers.get(&normalized).cloned().unwrap_or_default();
debug!("Returning {} providers", result.len());
Ok(result)
} else {
debug!("No match found for target: {}", target);
Ok(Vec::new())
}
}
}
// Integration-style tests: each one builds a default `CloudCheck` and performs
// a real lookup, so the signature data must be obtainable from the cache file
// or the network.
#[cfg(test)]
mod tests {
    use super::*;

    /// Looks `target` up with a fresh default checker and returns the names of
    /// all matching providers.
    async fn provider_names(target: &str) -> Vec<String> {
        let cloudcheck = CloudCheck::new();
        cloudcheck
            .lookup(target)
            .await
            .unwrap()
            .into_iter()
            .map(|provider| provider.name)
            .collect()
    }

    #[tokio::test]
    async fn test_lookup_google_dns() {
        let names = provider_names("8.8.8.8").await;
        assert!(
            names.iter().any(|name| name == "Google"),
            "Expected Google in results: {:?}",
            names
        );
    }

    #[tokio::test]
    async fn test_lookup_amazon_domain() {
        let names = provider_names("asdf.amazon.com").await;
        assert!(
            names.iter().any(|name| name == "Amazon"),
            "Expected Amazon in results: {:?}",
            names
        );
    }

    #[tokio::test]
    async fn test_lookup_windows_blob_domain() {
        let names = provider_names("asdf.blob.core.windows.net").await;
        assert!(
            names.iter().any(|name| name == "GitHub"),
            "Expected GitHub in results: {:?}",
            names
        );
        assert!(
            names.iter().any(|name| name == "Microsoft"),
            "Expected Microsoft in results: {:?}",
            names
        );
    }
}