use crate::config::{AuthType, SourceConfig};
use anyhow::{Context, Result};
use reqwest::Client;
use serde_json::Value;
use std::time::Duration;
/// Downloads feed data over HTTP for a single configured source.
///
/// The fetch methods read the target URL, authentication settings and retry
/// count from the stored configuration and issue requests through the
/// stored client.
pub struct FeedFetcher {
/// Source settings (URL, auth type, credentials, timeout, retry count).
config: SourceConfig,
/// HTTP client built once in `new` with the configured timeout.
client: Client,
}
impl FeedFetcher {
/// Creates a fetcher for `config`.
///
/// The HTTP client is built once here, with its request timeout set to
/// `config.timeout_secs` seconds, and reused for every request.
///
/// # Panics
///
/// Panics if the `reqwest` client cannot be constructed.
pub fn new(config: SourceConfig) -> Self {
    let timeout = Duration::from_secs(config.timeout_secs);
    let client = Client::builder()
        .timeout(timeout)
        .build()
        // The constructor is infallible by signature, so a builder failure
        // can only surface as a panic; make the panic say what went wrong.
        .expect("failed to build reqwest HTTP client for FeedFetcher");
    Self { config, client }
}
pub async fn fetch_json(&self) -> Result<Value> {
let url = self
.config
.api_url
.as_ref()
.context("API URL not configured")?;
let mut retry_count = 0;
let max_retries = self.config.retry_count;
loop {
match self.fetch_with_auth(url).await {
Ok(response) => {
let json: Value = response.json().await.context("Failed to parse JSON")?;
return Ok(json);
}
Err(e) => {
retry_count += 1;
if retry_count >= max_retries {
return Err(e).context(format!(
"Failed to fetch after {} retries",
max_retries
));
}
let delay = Duration::from_secs(2_u64.pow(retry_count));
tokio::time::sleep(delay).await;
}
}
}
}
pub async fn fetch_text(&self) -> Result<String> {
let url = self
.config
.api_url
.as_ref()
.context("API URL not configured")?;
let mut retry_count = 0;
let max_retries = self.config.retry_count;
loop {
match self.fetch_with_auth(url).await {
Ok(response) => {
let text = response.text().await.context("Failed to read text")?;
return Ok(text);
}
Err(e) => {
retry_count += 1;
if retry_count >= max_retries {
return Err(e).context(format!(
"Failed to fetch after {} retries",
max_retries
));
}
let delay = Duration::from_secs(2_u64.pow(retry_count));
tokio::time::sleep(delay).await;
}
}
}
}
/// Fetches the configured API URL as JSON with extra checks before and
/// after the request.
///
/// Steps, in order: URL checks (`validate_source_authenticity`,
/// `apply_tls_pinning`, `sanitize_input`), a fixed pause
/// (`apply_rate_limiting`), the retried request with response-header
/// validation (`fetch_with_enhanced_timeout`), JSON decoding, and finally
/// a scan of every string in the decoded document
/// (`sanitize_response_data`).
///
/// # Errors
///
/// Fails when no API URL is configured or when any of the steps above fails.
pub async fn fetch_json_secure(&self) -> Result<Value> {
    let raw_url = self
        .config
        .api_url
        .as_ref()
        .context("API URL not configured")?;

    // Pre-flight checks on the configured URL; each one aborts on failure.
    self.validate_source_authenticity(raw_url)?;
    self.apply_tls_pinning(raw_url)?;
    let clean_url = self.sanitize_input(raw_url)?;

    self.apply_rate_limiting().await?;

    let response = self.fetch_with_enhanced_timeout(&clean_url).await?;
    let body = response
        .json::<Value>()
        .await
        .context("Failed to parse JSON")?;
    self.sanitize_response_data(body)
}
/// Requires HTTPS for URLs that mention one of the critical feed hosts.
///
/// The URL is compared case-insensitively, because URL schemes and host
/// names are case-insensitive: `http://CVE.MITRE.ORG/` must be rejected and
/// `HTTPS://cve.mitre.org/` must not be. The host name is matched as a
/// substring anywhere in the URL, not only in the authority part.
///
/// # Errors
///
/// Returns an error when `url` mentions a critical host but does not start
/// with `https://`.
fn validate_source_authenticity(&self, url: &str) -> Result<()> {
    const CRITICAL_SOURCES: [&str; 3] = ["cve.mitre.org", "attack.mitre.org", "feeds.abuse.ch"];

    let normalized = url.to_ascii_lowercase();
    let mentions_critical_source = CRITICAL_SOURCES
        .iter()
        .any(|source| normalized.contains(*source));

    if mentions_critical_source && !normalized.starts_with("https://") {
        // Report the URL exactly as it was supplied, not the lowered copy.
        return Err(anyhow::anyhow!("Critical source must use HTTPS: {}", url));
    }
    Ok(())
}
/// Rejects any URL that does not start with `https://`.
///
/// Despite the name, no certificate or public-key pinning happens here; the
/// only check is the (case-sensitive) scheme prefix.
///
/// # Errors
///
/// Returns an error when `url` does not begin with `https://`.
fn apply_tls_pinning(&self, url: &str) -> Result<()> {
    if url.starts_with("https://") {
        Ok(())
    } else {
        Err(anyhow::anyhow!("All sources must use HTTPS: {}", url))
    }
}
/// Cleans and validates a URL before it is requested.
///
/// NUL characters are stripped first. The result is then rejected if it
/// contains any entry of the suspicious-pattern list (matched
/// case-insensitively, anywhere in the URL), if it does not start with
/// `https://`, or if it is longer than 2048 bytes. The checks run in that
/// order, so the first failing one determines the error.
///
/// # Errors
///
/// Returns an error describing the first failed check.
fn sanitize_input(&self, url: &str) -> Result<String> {
    const SUSPICIOUS_PATTERNS: [&str; 8] = [
        "javascript:",
        "data:",
        "file:",
        "ftp:",
        "..",
        "<script",
        "<?php",
        "<?xml",
    ];
    const MAX_URL_LEN: usize = 2048;

    let sanitized = url.replace('\0', "");

    // Lower-case once up front instead of once per pattern.
    let lowered = sanitized.to_lowercase();
    if let Some(pattern) = SUSPICIOUS_PATTERNS
        .iter()
        .find(|pattern| lowered.contains(**pattern))
    {
        return Err(anyhow::anyhow!("Suspicious pattern detected in URL: {}", pattern));
    }

    if !sanitized.starts_with("https://") {
        return Err(anyhow::anyhow!("Only HTTPS URLs are allowed: {}", sanitized));
    }
    // `len()` counts bytes, which is what the limit is compared against.
    if sanitized.len() > MAX_URL_LEN {
        return Err(anyhow::anyhow!("URL too long: {} characters", sanitized.len()));
    }
    Ok(sanitized)
}
/// Pauses for a fixed 100 ms before a request is sent.
///
/// This is an unconditional delay; no request counting or per-host state
/// is kept. It always returns `Ok(())`.
async fn apply_rate_limiting(&self) -> Result<()> {
    let pause = Duration::from_millis(100);
    tokio::time::sleep(pause).await;
    Ok(())
}
async fn fetch_with_enhanced_timeout(&self, url: &str) -> Result<reqwest::Response> {
let mut retry_count = 0;
let max_retries = self.config.retry_count;
loop {
match self.fetch_with_auth(url).await {
Ok(response) => {
self.validate_response_headers(&response)?;
return Ok(response);
}
Err(e) => {
retry_count += 1;
if retry_count >= max_retries {
return Err(e).context(format!(
"Failed to fetch after {} retries",
max_retries
));
}
let delay = Duration::from_secs(2_u64.pow(retry_count));
tokio::time::sleep(delay).await;
}
}
}
}
/// Checks a response's headers before its body is consumed.
///
/// * `Content-Type`, when present, must mention `application/json` or
///   `text/plain`. The comparison is case-insensitive because media types
///   are. A value that is not visible ASCII is treated as empty and
///   therefore rejected.
/// * `x-powered-by`, `server` and `x-aspnet-version` only produce a warning
///   on stderr; they never fail the request.
/// * `Content-Length`, when present and numeric, must not exceed
///   100,000,000 bytes. A missing or non-numeric value is not rejected.
///
/// # Errors
///
/// Returns an error for an unexpected content type or an oversized
/// declared body length.
fn validate_response_headers(&self, response: &reqwest::Response) -> Result<()> {
    const MAX_CONTENT_LENGTH: u64 = 100_000_000;

    let headers = response.headers();

    if let Some(content_type) = headers.get("content-type") {
        let content_type_str = content_type.to_str().unwrap_or("");
        let normalized = content_type_str.to_ascii_lowercase();
        let accepted =
            normalized.contains("application/json") || normalized.contains("text/plain");
        if !accepted {
            return Err(anyhow::anyhow!("Unexpected content type: {}", content_type_str));
        }
    }

    let suspicious_headers = ["x-powered-by", "server", "x-aspnet-version"];
    for header_name in suspicious_headers.iter() {
        if headers.contains_key(*header_name) {
            eprintln!("WARN: Suspicious header detected: {}", header_name);
        }
    }

    // Parse as u64 rather than usize: on 32-bit targets a length above
    // `usize::MAX` would fail to parse and slip past the size limit.
    let declared_length = headers
        .get("content-length")
        .and_then(|value| value.to_str().ok())
        .and_then(|text| text.parse::<u64>().ok());
    if let Some(length) = declared_length {
        if length > MAX_CONTENT_LENGTH {
            return Err(anyhow::anyhow!("Response too large: {} bytes", length));
        }
    }
    Ok(())
}
/// Runs `sanitize_json_value` over an owned JSON document and hands the
/// cleaned document back.
///
/// # Errors
///
/// Propagates the error from `sanitize_json_value`; the document is dropped
/// in that case.
fn sanitize_response_data(&self, mut json: Value) -> Result<Value> {
    self.sanitize_json_value(&mut json).map(|_| json)
}
/// Recursively sanitizes every string inside a JSON value, in place.
///
/// For each string: NUL characters are removed, then the string is rejected
/// if it contains any suspicious pattern (case-insensitive) or is longer
/// than 1,000,000 bytes. Arrays and object values are visited recursively;
/// object keys, numbers, booleans and nulls are left untouched.
///
/// NOTE(review): a single matching string fails the whole document, so a
/// feed whose text legitimately quotes e.g. `<script` is rejected outright.
/// Confirm that this is the intended policy.
///
/// # Errors
///
/// Returns an error naming the first suspicious pattern found, or
/// reporting the length of an over-long string.
fn sanitize_json_value(&self, value: &mut Value) -> Result<()> {
    const SUSPICIOUS_PATTERNS: [&str; 8] = [
        "<script",
        "javascript:",
        "data:",
        "<?php",
        "<?xml",
        "eval(",
        "exec(",
        "system(",
    ];
    const MAX_STRING_LEN: usize = 1_000_000;

    match value {
        Value::String(s) => {
            // Strip NULs in place rather than allocating a replacement string.
            s.retain(|c| c != '\0');

            // Lower-case once up front instead of once per pattern.
            let lowered = s.to_lowercase();
            if let Some(pattern) = SUSPICIOUS_PATTERNS
                .iter()
                .find(|pattern| lowered.contains(**pattern))
            {
                return Err(anyhow::anyhow!("Suspicious pattern in JSON: {}", pattern));
            }

            // `len()` counts bytes, which is what the limit is compared against.
            if s.len() > MAX_STRING_LEN {
                return Err(anyhow::anyhow!("String too long: {} characters", s.len()));
            }
        }
        Value::Array(items) => {
            for item in items.iter_mut() {
                self.sanitize_json_value(item)?;
            }
        }
        Value::Object(fields) => {
            for field_value in fields.values_mut() {
                self.sanitize_json_value(field_value)?;
            }
        }
        // Null, booleans and numbers carry no text to inspect.
        _ => {}
    }
    Ok(())
}
async fn fetch_with_auth(&self, url: &str) -> Result<reqwest::Response> {
let mut request = self.client.get(url);
request = match self.config.auth_type {
AuthType::None => request,
AuthType::ApiKey => {
if let Some(api_key) = &self.config.api_key {
request.header("X-API-Key", api_key)
} else {
request
}
}
AuthType::Bearer => {
if let Some(token) = &self.config.api_key {
request.bearer_auth(token)
} else {
request
}
}
AuthType::Basic => {
if let Some(credentials) = &self.config.api_key {
let parts: Vec<&str> = credentials.split(':').collect();
if parts.len() == 2 {
request.basic_auth(parts[0], Some(parts[1]))
} else {
request
}
} else {
request
}
}
};
request = request.header("User-Agent", "threat-intel/0.1.0");
let response = request.send().await.context("HTTP request failed")?;
if !response.status().is_success() {
anyhow::bail!("HTTP error: {}", response.status());
}
Ok(response)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::{SourceCapability, SourceType, UpdateFrequency};
/// Builds a `SourceConfig` for tests with the given authentication
/// settings; every other field is a fixed value.
fn create_test_config(auth_type: AuthType, api_key: Option<String>) -> SourceConfig {
    SourceConfig {
        // Identity
        id: String::from("test"),
        name: String::from("Test"),
        source_type: SourceType::Custom,
        enabled: true,
        // Endpoint and credentials
        api_url: Some(String::from("https://httpbin.org/get")),
        api_key,
        auth_type,
        // Scheduling and capabilities
        update_frequency: UpdateFrequency::Manual,
        priority: 1,
        capabilities: vec![SourceCapability::Vulnerabilities],
        // Network behavior
        timeout_secs: 30,
        retry_count: 1,
    }
}
/// A fetcher keeps the configuration it was constructed with.
#[tokio::test]
async fn test_fetcher_creation() {
    let fetcher = FeedFetcher::new(create_test_config(AuthType::None, None));
    assert_eq!(fetcher.config.id, "test");
}
/// An unauthenticated JSON fetch against the test URL succeeds.
///
/// Ignored by default; presumably because it needs network access to the
/// host named in `create_test_config`.
#[tokio::test]
#[ignore]
async fn test_fetch_json_no_auth() {
    let fetcher = FeedFetcher::new(create_test_config(AuthType::None, None));
    let outcome = fetcher.fetch_json().await;
    assert!(outcome.is_ok());
}
/// Exercises the API-key request path; the outcome is deliberately not
/// asserted, so this only checks that the call completes without panicking.
///
/// Ignored by default; presumably because it needs network access to the
/// host named in `create_test_config`.
#[tokio::test]
#[ignore]
async fn test_fetch_with_api_key() {
    let api_key = Some(String::from("test-key"));
    let fetcher = FeedFetcher::new(create_test_config(AuthType::ApiKey, api_key));
    let _ = fetcher.fetch_json().await;
}
/// Critical feed hosts must be reached over HTTPS; other hosts are not
/// constrained by this check.
#[test]
fn test_validate_source_authenticity() {
    let fetcher = FeedFetcher::new(create_test_config(AuthType::None, None));
    let check = |url: &str| fetcher.validate_source_authenticity(url);

    // Critical hosts over HTTPS are accepted.
    assert!(check("https://cve.mitre.org/api").is_ok());
    assert!(check("https://attack.mitre.org/api").is_ok());
    assert!(check("https://feeds.abuse.ch/api").is_ok());

    // Critical hosts over plain HTTP are rejected.
    assert!(check("http://cve.mitre.org/api").is_err());
    assert!(check("http://attack.mitre.org/api").is_err());

    // A non-critical host passes this particular check.
    assert!(check("https://example.com/api").is_ok());
}
/// Only URLs with the `https://` scheme pass the TLS check.
#[test]
fn test_apply_tls_pinning() {
    let fetcher = FeedFetcher::new(create_test_config(AuthType::None, None));
    let check = |url: &str| fetcher.apply_tls_pinning(url);

    // HTTPS is accepted.
    assert!(check("https://example.com").is_ok());
    assert!(check("https://api.example.com/v1").is_ok());

    // Any other scheme is rejected.
    assert!(check("http://example.com").is_err());
    assert!(check("ftp://example.com").is_err());
}
/// Covers the accept and reject paths of `sanitize_input`.
#[test]
fn test_sanitize_input() {
    let config = create_test_config(AuthType::None, None);
    let fetcher = FeedFetcher::new(config);

    // Well-formed HTTPS URLs pass through unchanged.
    assert_eq!(
        fetcher.sanitize_input("https://example.com").unwrap(),
        "https://example.com"
    );
    assert_eq!(
        fetcher
            .sanitize_input("https://api.example.com/v1/data")
            .unwrap(),
        "https://api.example.com/v1/data"
    );

    // NUL characters are stripped rather than rejected.
    let result = fetcher.sanitize_input("https://example.com\0");
    assert!(result.is_ok());
    assert_eq!(result.unwrap(), "https://example.com");

    // Dangerous schemes and path traversal are rejected.
    assert!(fetcher.sanitize_input("javascript:alert(1)").is_err());
    assert!(fetcher
        .sanitize_input("data:text/html,<script>alert(1)</script>")
        .is_err());
    assert!(fetcher.sanitize_input("file:///etc/passwd").is_err());
    assert!(fetcher.sanitize_input("ftp://example.com").is_err());
    assert!(fetcher
        .sanitize_input("https://example.com/../etc/passwd")
        .is_err());

    // A doubled slash matches none of the suspicious patterns, so it is
    // accepted as-is.
    assert_eq!(
        fetcher.sanitize_input("https://example.com//admin").unwrap(),
        "https://example.com//admin"
    );

    // Embedded markup is rejected.
    assert!(fetcher
        .sanitize_input("https://example.com<script>alert(1)</script>")
        .is_err());
    assert!(fetcher
        .sanitize_input("https://example.com<?php system('ls'); ?>")
        .is_err());
    assert!(fetcher
        .sanitize_input("https://example.com<?xml version='1.0'?>")
        .is_err());

    // Plain HTTP is rejected.
    assert!(fetcher.sanitize_input("http://example.com").is_err());

    // URLs over the 2048-byte limit are rejected.
    let long_url = "https://example.com/".to_string() + &"a".repeat(3000);
    assert!(fetcher.sanitize_input(&long_url).is_err());
}
/// Sanity-checks the literal values used for header validation.
///
/// NOTE(review): this test never calls `validate_response_headers`; it
/// builds a header map and then asserts only on local string literals, so
/// it does not exercise the method it is named after.
#[test]
fn test_validate_response_headers() {
    let _fetcher = FeedFetcher::new(create_test_config(AuthType::None, None));

    let mut header_map = reqwest::header::HeaderMap::new();
    header_map.insert("content-type", "application/json".parse().unwrap());
    header_map.insert("content-length", "1024".parse().unwrap());

    let declared_type = "application/json";
    assert!(declared_type.contains("application/json"));

    let declared_length = "1024";
    if let Ok(length) = declared_length.parse::<usize>() {
        assert!(length <= 100_000_000);
    }
}
/// Covers NUL stripping, pattern rejection, the length limit and recursion
/// into nested objects and arrays.
#[test]
fn test_sanitize_json_value() {
    let fetcher = FeedFetcher::new(create_test_config(AuthType::None, None));

    // Wraps `payload` in a small document and reports whether the sanitizer
    // rejects it.
    let rejects = |payload: &str| {
        let mut document = serde_json::json!({
            "description": payload,
            "data": "normal data"
        });
        fetcher.sanitize_json_value(&mut document).is_err()
    };

    // NUL characters are removed from top-level strings.
    let mut flat = serde_json::json!({
        "description": "Test\0description",
        "data": "normal data"
    });
    assert!(fetcher.sanitize_json_value(&mut flat).is_ok());
    assert_eq!(flat["description"], "Testdescription");

    // Each suspicious pattern causes rejection.
    assert!(rejects("<script>alert('xss')</script>"));
    assert!(rejects("javascript:alert('xss')"));
    assert!(rejects("<?php system('ls'); ?>"));
    assert!(rejects("<?xml version='1.0'?>"));
    assert!(rejects("eval('malicious code')"));
    assert!(rejects("exec('rm -rf /')"));
    assert!(rejects("system('cat /etc/passwd')"));

    // Strings over the length limit are rejected.
    let oversized = "a".repeat(2_000_000);
    assert!(rejects(&oversized));

    // Sanitization reaches strings nested inside objects...
    let mut nested = serde_json::json!({
        "level1": {
            "level2": {
                "description": "Test\0description",
                "data": "normal data"
            }
        }
    });
    assert!(fetcher.sanitize_json_value(&mut nested).is_ok());
    assert_eq!(nested["level1"]["level2"]["description"], "Testdescription");

    // ...and inside arrays.
    let mut with_array = serde_json::json!({
        "items": [
            "normal item",
            "Test\0item",
            "another normal item"
        ]
    });
    assert!(fetcher.sanitize_json_value(&mut with_array).is_ok());
    assert_eq!(with_array["items"][1], "Testitem");

    // Non-string values and harmless strings are accepted.
    let mut harmless = serde_json::json!({
        "description": "This is a normal description",
        "data": "normal data",
        "number": 42,
        "boolean": true,
        "null_value": null
    });
    assert!(fetcher.sanitize_json_value(&mut harmless).is_ok());
}
/// A document containing suspicious strings anywhere in its tree is
/// rejected as a whole.
#[test]
fn test_sanitize_response_data() {
    let fetcher = FeedFetcher::new(create_test_config(AuthType::None, None));

    let tainted = serde_json::json!({
        "description": "Test\0description",
        "nested": {
            "data": "<script>alert('xss')</script>",
            "safe_data": "normal data"
        },
        "array": [
            "normal item",
            "javascript:alert('xss')",
            "another normal item"
        ]
    });

    let outcome = fetcher.sanitize_response_data(tainted);
    assert!(outcome.is_err());
}
/// End-to-end check of URL sanitization: hostile inputs are rejected and a
/// well-formed HTTPS URL is returned unchanged.
#[test]
fn test_security_integration() {
    let config = create_test_config(AuthType::None, None);
    let fetcher = FeedFetcher::new(config);

    let malicious_url = "javascript:alert('xss')";
    assert!(fetcher.sanitize_input(malicious_url).is_err());

    let suspicious_url = "https://example.com<script>alert('xss')</script>";
    assert!(fetcher.sanitize_input(suspicious_url).is_err());

    // Assert the accept path instead of only printing it, so that a
    // regression here actually fails the test.
    let valid_url = "https://api.example.com/v1/data";
    assert_eq!(fetcher.sanitize_input(valid_url).unwrap(), valid_url);
}
/// Distinct `AuthType` variants compare as unequal.
#[test]
fn test_auth_type_variants() {
    let pairs = [
        (AuthType::None, AuthType::ApiKey),
        (AuthType::Bearer, AuthType::Basic),
    ];
    for (left, right) in pairs.iter() {
        assert_ne!(left, right);
    }
}
}