use crate::PeerId;
use crate::error::{P2PError, P2pResult};
use std::collections::HashMap;
use std::net::IpAddr;
use std::path::Path;
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
/// Number of bytes in one mebibyte; used to spell the size limits below.
const BYTES_PER_MIB: usize = 1024 * 1024;

/// Default upper bound, in bytes, on a network message payload (16 MiB).
const MAX_MESSAGE_SIZE: usize = 16 * BYTES_PER_MIB;
/// Upper bound, in bytes, on the string form of a path checked by `validate_file_path`.
const MAX_PATH_LENGTH: usize = 4096;
/// Default upper bound, in bytes, on a DHT key (1 MiB).
const MAX_KEY_SIZE: usize = BYTES_PER_MIB;
/// Default upper bound, in bytes, on a DHT value (10 MiB).
const MAX_VALUE_SIZE: usize = 10 * BYTES_PER_MIB;
/// Upper bound, in bytes, on a single path component.
#[allow(dead_code)]
const MAX_FILE_NAME_LENGTH: usize = 255;

/// Default length of one rate-limiting window.
const DEFAULT_RATE_LIMIT_WINDOW: Duration = Duration::from_secs(60);
/// Default request budget per rate-limiting window.
const DEFAULT_MAX_REQUESTS_PER_WINDOW: u32 = 1000;
/// Default burst allowance handed to the rate-limit engine.
const DEFAULT_BURST_SIZE: u32 = 100;
/// Errors reported by the validation helpers in this module.
///
/// Each variant is converted into `P2PError::Validation` (carrying this
/// error's `Display` text) by the `From<ValidationError> for P2PError` impl.
#[derive(Debug, Error)]
pub enum ValidationError {
/// A peer identifier was rejected.
/// NOTE(review): not constructed anywhere in this module.
#[error("Invalid peer ID format: {0}")]
InvalidPeerId(String),
/// A network address was rejected.
/// NOTE(review): not constructed anywhere in this module.
#[error("Invalid network address: {0}")]
InvalidAddress(String),
/// A message payload of `size` bytes exceeded the `limit` (see `validate_message_size`).
#[error("Message size exceeds limit: {size} > {limit}")]
MessageTooLarge { size: usize, limit: usize },
/// A file path was malformed: too long, containing NUL bytes or shell
/// metacharacters, or having an over-long component.
#[error("Invalid file path: {0}")]
InvalidPath(String),
/// A path contained a `..`-style traversal sequence (possibly percent-encoded).
#[error("Path traversal attempt detected: {0}")]
PathTraversal(String),
/// A DHT key was larger than the context's `max_key_size`.
#[error("Invalid key size: {size} bytes (max: {max})")]
InvalidKeySize { size: usize, max: usize },
/// A DHT value was larger than the context's `max_value_size`.
#[error("Invalid value size: {size} bytes (max: {max})")]
InvalidValueSize { size: usize, max: usize },
/// A cryptographic key or nonce did not have the expected length.
#[error("Invalid cryptographic parameter: {0}")]
InvalidCryptoParam(String),
/// A rate limit was hit; `identifier` is `"global"` or the client IP address.
#[error("Rate limit exceeded for {identifier}")]
RateLimitExceeded { identifier: String },
/// Catch-all for malformed input (empty DHT key, bad HTTP method, suspicious
/// request parameter, timestamp too far in the future, ...).
#[error("Invalid format: {0}")]
InvalidFormat(String),
/// A numeric value fell outside `[min, max]`.
/// NOTE(review): not constructed anywhere in this module.
#[error("Value out of range: {value} (min: {min}, max: {max})")]
OutOfRange { value: i64, min: i64, max: i64 },
}
impl From<ValidationError> for P2PError {
fn from(err: ValidationError) -> Self {
P2PError::Validation(err.to_string().into())
}
}
/// Limits and policy switches handed to every `Validate::validate` call.
///
/// NOTE(review): `max_path_length`, `allow_localhost`, `allow_private_ips` and
/// `rate_limiter` are not read by any validator in this module
/// (`validate_file_path` uses the `MAX_PATH_LENGTH` constant directly);
/// confirm whether code elsewhere consults them.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct ValidationContext {
/// Maximum accepted `NetworkMessage` payload length, in bytes.
pub max_message_size: usize,
/// Maximum accepted DHT key length, in bytes (checked by `validate_dht_key`).
pub max_key_size: usize,
/// Maximum accepted DHT value length, in bytes (checked by `validate_dht_value`).
pub max_value_size: usize,
/// Maximum path length, in bytes.
pub max_path_length: usize,
/// Whether loopback addresses should be accepted; `false` by default.
pub allow_localhost: bool,
/// Whether private-range IP addresses should be accepted; `false` by default.
pub allow_private_ips: bool,
/// Optional shared rate limiter; `None` by default.
pub rate_limiter: Option<Arc<RateLimiter>>,
}
impl Default for ValidationContext {
fn default() -> Self {
Self {
max_message_size: MAX_MESSAGE_SIZE,
max_key_size: MAX_KEY_SIZE,
max_value_size: MAX_VALUE_SIZE,
max_path_length: MAX_PATH_LENGTH,
allow_localhost: false,
allow_private_ips: false,
rate_limiter: None,
}
}
}
#[allow(dead_code)]
impl ValidationContext {
    /// Creates a context identical to `ValidationContext::default()`.
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns this context with `limiter` attached, replacing any previous one.
    pub fn with_rate_limiting(self, limiter: Arc<RateLimiter>) -> Self {
        Self {
            rate_limiter: Some(limiter),
            ..self
        }
    }

    /// Returns this context with `allow_localhost` switched on.
    pub fn allow_localhost(self) -> Self {
        Self {
            allow_localhost: true,
            ..self
        }
    }

    /// Returns this context with `allow_private_ips` switched on.
    pub fn allow_private_ips(self) -> Self {
        Self {
            allow_private_ips: true,
            ..self
        }
    }
}
/// Types that can check themselves against a `ValidationContext`.
#[allow(dead_code)]
pub trait Validate {
/// Returns `Ok(())` if `self` is acceptable under `ctx`, otherwise a
/// `P2PError` describing the problem. The implementations in this module
/// return on the first violation they find.
fn validate(&self, ctx: &ValidationContext) -> P2pResult<()>;
}
/// Types that can produce a cleaned copy of themselves.
///
/// NOTE(review): no implementations are visible in this module.
#[allow(dead_code)]
pub trait Sanitize {
/// Returns a sanitized copy; takes `&self`, so the original is not modified.
fn sanitize(&self) -> Self;
}
/// Validates a peer identifier.
///
/// NOTE(review): currently a placeholder. It accepts every `PeerId` and performs
/// no checks, so callers such as `NetworkMessage::validate` get no peer-ID
/// validation from it. Confirm whether that is intentional.
#[allow(dead_code)]
pub fn validate_peer_id(_peer_id: &PeerId) -> P2pResult<()> {
Ok(())
}
/// Checks that a message of `size` bytes does not exceed `max_size`.
///
/// A size exactly equal to the limit is accepted.
///
/// # Errors
/// `ValidationError::MessageTooLarge` (as `P2PError`) when `size > max_size`.
#[allow(dead_code)]
pub fn validate_message_size(size: usize, max_size: usize) -> P2pResult<()> {
    if size <= max_size {
        return Ok(());
    }
    let too_large = ValidationError::MessageTooLarge {
        size,
        limit: max_size,
    };
    Err(too_large.into())
}
/// Validates a filesystem path that may come from an untrusted source.
///
/// Absolute paths are accepted. The path is rejected when:
/// * its lossy UTF-8 form is longer than `MAX_PATH_LENGTH` bytes;
/// * it contains a traversal pattern, in particular `..` anywhere. The check
///   runs on the raw path and again after case-insensitive decoding of
///   `%2e` / `%2f` / `%5c`. The double-encoded `%252e%252e` is also rejected
///   in any letter case;
/// * it contains a NUL byte;
/// * it contains any of `|`, `&`, `;`, `$`, `` ` `` or a newline;
/// * any component is longer than `MAX_FILE_NAME_LENGTH` bytes.
///
/// The `..` test is a plain substring match, so names such as `file..txt` are
/// rejected as well.
///
/// # Errors
/// `ValidationError::InvalidPath` or `ValidationError::PathTraversal`,
/// converted into `P2PError`.
#[allow(dead_code)]
pub fn validate_file_path(path: &Path) -> P2pResult<()> {
    // `..` alone subsumes most entries; the explicit forms are kept for clarity.
    const TRAVERSAL_PATTERNS: &[&str] =
        &["../", "..\\", "..", "..;", "....//", "%2e%2e", "%252e%252e"];
    const DANGEROUS_CHARS: &[char] = &['|', '&', ';', '$', '`', '\n'];

    let path_str = path.to_string_lossy();
    if path_str.len() > MAX_PATH_LENGTH {
        return Err(ValidationError::InvalidPath(format!(
            "Path too long: {} > {}",
            path_str.len(),
            MAX_PATH_LENGTH
        ))
        .into());
    }

    // Percent-encoding hex digits are case-insensitive (RFC 3986 even
    // recommends upper case), so fold ASCII case before decoding. Otherwise
    // `%2E%2E/` or `%252E%252E` would slip through. `to_ascii_lowercase`
    // preserves byte positions, and every pattern is lower-case, so checking
    // `lowered` also covers every literal match on `path_str`.
    let lowered = path_str.to_ascii_lowercase();
    let decoded = lowered
        .replace("%2e", ".")
        .replace("%2f", "/")
        .replace("%5c", "\\");
    if TRAVERSAL_PATTERNS
        .iter()
        .any(|pattern| lowered.contains(pattern) || decoded.contains(pattern))
    {
        return Err(ValidationError::PathTraversal(path_str.to_string()).into());
    }

    // A lossy conversion never drops a NUL byte, so this also covers every
    // individual component.
    if path_str.contains('\0') {
        return Err(ValidationError::InvalidPath("Path contains null bytes".to_string()).into());
    }

    if path_str.chars().any(|c| DANGEROUS_CHARS.contains(&c)) {
        return Err(
            ValidationError::InvalidPath("Path contains dangerous characters".to_string()).into(),
        );
    }

    for component in path.components() {
        // Measure the raw OS string so non-UTF-8 names are length-checked too.
        let raw = component.as_os_str();
        if raw.len() > MAX_FILE_NAME_LENGTH {
            return Err(ValidationError::InvalidPath(format!(
                "Component '{}' exceeds maximum length",
                raw.to_string_lossy()
            ))
            .into());
        }
    }

    Ok(())
}
/// Checks that a cryptographic key is exactly `expected` bytes long.
///
/// # Errors
/// `ValidationError::InvalidCryptoParam` (as `P2PError`) when `size != expected`.
#[allow(dead_code)]
pub fn validate_key_size(size: usize, expected: usize) -> P2pResult<()> {
    if size == expected {
        Ok(())
    } else {
        let detail = format!(
            "Invalid key size: expected {} bytes, got {}",
            expected, size
        );
        Err(ValidationError::InvalidCryptoParam(detail).into())
    }
}
/// Checks that a cryptographic nonce is exactly `expected` bytes long.
///
/// # Errors
/// `ValidationError::InvalidCryptoParam` (as `P2PError`) when `size != expected`.
#[allow(dead_code)]
pub fn validate_nonce_size(size: usize, expected: usize) -> P2pResult<()> {
    if size == expected {
        Ok(())
    } else {
        let detail = format!(
            "Invalid nonce size: expected {} bytes, got {}",
            expected, size
        );
        Err(ValidationError::InvalidCryptoParam(detail).into())
    }
}
/// Checks that a DHT key is non-empty and at most `ctx.max_key_size` bytes.
///
/// The emptiness check takes precedence over the size check.
///
/// # Errors
/// `ValidationError::InvalidFormat` for an empty key, or
/// `ValidationError::InvalidKeySize` for an oversized one (as `P2PError`).
#[allow(dead_code)]
pub fn validate_dht_key(key: &[u8], ctx: &ValidationContext) -> P2pResult<()> {
    let limit = ctx.max_key_size;
    match key.len() {
        0 => Err(ValidationError::InvalidFormat("DHT key cannot be empty".to_string()).into()),
        len if len > limit => Err(ValidationError::InvalidKeySize {
            size: len,
            max: limit,
        }
        .into()),
        _ => Ok(()),
    }
}
/// Checks that a DHT value is at most `ctx.max_value_size` bytes.
///
/// Empty values are accepted.
///
/// # Errors
/// `ValidationError::InvalidValueSize` (as `P2PError`) when the value is too large.
#[allow(dead_code)]
pub fn validate_dht_value(value: &[u8], ctx: &ValidationContext) -> P2pResult<()> {
    let (size, max) = (value.len(), ctx.max_value_size);
    if size <= max {
        Ok(())
    } else {
        Err(ValidationError::InvalidValueSize { size, max }.into())
    }
}
/// Request rate limiter with a shared global budget and a per-IP budget,
/// both enforced by `crate::rate_limit::Engine`.
#[derive(Debug)]
pub struct RateLimiter {
/// Engine holding the global bucket and one bucket per `IpAddr`.
engine: crate::rate_limit::SharedEngine<IpAddr>,
/// The configuration this limiter was built from. Only `window`,
/// `max_requests` and `burst_size` are forwarded to the engine.
#[allow(dead_code)]
config: RateLimitConfig,
}
/// Configuration for `RateLimiter`.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct RateLimitConfig {
/// Length of one rate-limiting window (forwarded to the engine).
pub window: Duration,
/// Request budget per window (forwarded to the engine).
pub max_requests: u32,
/// Burst allowance (forwarded to the engine). Exact semantics are defined
/// by `crate::rate_limit`.
pub burst_size: u32,
/// NOTE(review): not read by `RateLimiter::new`; confirm whether it is used elsewhere.
pub adaptive: bool,
/// NOTE(review): not read by `RateLimiter::new`, and `RateLimiter::cleanup`
/// is a no-op; confirm whether it is used elsewhere.
pub cleanup_interval: Duration,
}
impl Default for RateLimitConfig {
fn default() -> Self {
Self {
window: DEFAULT_RATE_LIMIT_WINDOW,
max_requests: DEFAULT_MAX_REQUESTS_PER_WINDOW,
burst_size: DEFAULT_BURST_SIZE,
adaptive: true,
cleanup_interval: Duration::from_secs(300), }
}
}
impl RateLimiter {
pub fn new(config: RateLimitConfig) -> Self {
let engine_cfg = crate::rate_limit::EngineConfig {
window: config.window,
max_requests: config.max_requests,
burst_size: config.burst_size,
};
Self {
engine: std::sync::Arc::new(crate::rate_limit::Engine::new(engine_cfg)),
config,
}
}
pub fn check_ip(&self, ip: &IpAddr) -> P2pResult<()> {
if !self.engine.try_consume_global() {
return Err(ValidationError::RateLimitExceeded {
identifier: "global".to_string(),
}
.into());
}
if !self.engine.try_consume_key(ip) {
return Err(ValidationError::RateLimitExceeded {
identifier: ip.to_string(),
}
.into());
}
Ok(())
}
#[allow(dead_code)]
pub fn cleanup(&self) {
}
}
/// A message attributed to a peer, checked through its `Validate` impl.
#[derive(Debug)]
#[allow(dead_code)]
pub struct NetworkMessage {
/// Peer the message is attributed to.
pub peer_id: PeerId,
/// Raw message body; its length is checked against `ValidationContext::max_message_size`.
pub payload: Vec<u8>,
/// Sender-supplied timestamp. `validate` compares it with the local Unix
/// time in seconds.
pub timestamp: u64,
}
impl Validate for NetworkMessage {
fn validate(&self, ctx: &ValidationContext) -> P2pResult<()> {
validate_peer_id(&self.peer_id)?;
validate_message_size(self.payload.len(), ctx.max_message_size)?;
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map_err(|e| P2PError::Internal(format!("System time error: {}", e).into()))?
.as_secs();
if self.timestamp > now + 300 {
return Err(
ValidationError::InvalidFormat("Timestamp too far in future".to_string()).into(),
);
}
Ok(())
}
}
/// An inbound API request, checked through its `Validate` impl.
#[derive(Debug)]
#[allow(dead_code)]
pub struct ApiRequest {
/// HTTP method name. `validate` accepts only `GET`, `POST`, `PUT` and `DELETE`.
pub method: String,
/// Request path. `validate` requires a leading `/` and rejects `..`.
pub path: String,
/// Request parameters as key/value pairs.
pub params: HashMap<String, String>,
}
impl Validate for ApiRequest {
    /// Validates an inbound API request; `_ctx` is not consulted.
    ///
    /// Checks, in order, returning on the first failure:
    /// 1. `method` is exactly one of `GET`, `POST`, `PUT`, `DELETE`;
    /// 2. `path` starts with `/`;
    /// 3. `path` contains no `..`, either literally or after case-insensitive
    ///    decoding of `%25` then `%2e`, so `%2e%2e`, `.%2E` and `%252e%252e`
    ///    are caught too;
    /// 4. every parameter key is non-empty, and every value contains none of
    ///    the SQL-like tokens (matched case-insensitively) and none of the
    ///    shell/control characters listed below.
    ///
    /// Parameters live in a `HashMap`, so the order in which they are checked
    /// is unspecified. That order decides which error is reported when several
    /// parameters are bad.
    ///
    /// NOTE(review): step 4 is a blocklist heuristic. It also rejects benign
    /// values containing quotes or ` and ` / ` or `. It is no substitute for
    /// parameterized queries or escaping where the values are used. Keys are
    /// only checked for emptiness.
    ///
    /// # Errors
    /// `ValidationError::InvalidFormat` or `ValidationError::PathTraversal`,
    /// converted into `P2PError`.
    fn validate(&self, _ctx: &ValidationContext) -> P2pResult<()> {
        // Lower-case substrings that mark a parameter value as a likely SQL injection.
        const SQL_PATTERNS: &[&str] = &[
            "select ", "insert ", "update ", "delete ", "drop ", "union ", "exec ", "--", "/*",
            "*/", "'", "\"", " or ", " and ", "1=1", "1='1",
        ];
        // Shell metacharacters and control characters rejected in parameter values.
        const DANGEROUS_CHARS: &[char] = &['|', '&', ';', '$', '`', '\n', '\0'];

        match self.method.as_str() {
            "GET" | "POST" | "PUT" | "DELETE" => {}
            _ => {
                return Err(ValidationError::InvalidFormat(format!(
                    "Invalid HTTP method: {}",
                    self.method
                ))
                .into());
            }
        }

        if !self.path.starts_with('/') {
            return Err(
                ValidationError::InvalidFormat("Path must start with /".to_string()).into(),
            );
        }

        // URL paths are routinely percent-encoded, and the hex digits are
        // case-insensitive. Fold ASCII case, undo one level of `%25` (double
        // encoding), then decode `%2e` before looking for `..`. None of these
        // replacements can remove a literal `..`, so every path the plain check
        // rejected is still rejected.
        let decoded_path = self
            .path
            .to_ascii_lowercase()
            .replace("%25", "%")
            .replace("%2e", ".");
        if decoded_path.contains("..") {
            return Err(ValidationError::PathTraversal(self.path.clone()).into());
        }

        for (key, value) in &self.params {
            if key.is_empty() {
                return Err(
                    ValidationError::InvalidFormat("Empty parameter key".to_string()).into(),
                );
            }

            let lower_value = value.to_lowercase();
            if SQL_PATTERNS
                .iter()
                .any(|pattern| lower_value.contains(pattern))
            {
                return Err(ValidationError::InvalidFormat(
                    "Suspicious parameter value: potential SQL injection".to_string(),
                )
                .into());
            }

            if value.chars().any(|c| DANGEROUS_CHARS.contains(&c)) {
                return Err(ValidationError::InvalidFormat(
                    "Dangerous characters in parameter value".to_string(),
                )
                .into());
            }
        }

        Ok(())
    }
}
/// Reduces `input` to a conservative, identifier-like string of at most
/// `max_length` characters.
///
/// Processing order:
/// 1. `<` and `>` are removed. Then each of `script`, `javascript:`,
///    `onerror`, `onload`, `onclick`, `alert`, `iframe` is removed in that
///    order. Matching is case-sensitive and makes a single pass per word, so
///    e.g. `scrscriptipt` becomes `script`.
/// 2. Only Unicode alphanumerics and `_`, `-`, `.` are kept. Invisible
///    filler and zero-width characters are always dropped.
/// 3. The result is truncated to `max_length` `char`s (not bytes).
///
/// Step 2 is the effective safety net. After it, no markup brackets or URL
/// scheme punctuation can remain, whatever step 1 missed.
#[allow(dead_code)]
pub fn sanitize_string(input: &str, max_length: usize) -> String {
    // Characters that render as blank. The Hangul fillers are Unicode letters
    // (general category Lo), so `is_alphanumeric` alone would keep them. The
    // zero-width/format characters are listed for clarity, even though the
    // allowlist below already rejects them.
    const INVISIBLE_CHARS: &[char] = &[
        '\u{2060}', // WORD JOINER
        '\u{200b}', // ZERO WIDTH SPACE
        '\u{200c}', // ZERO WIDTH NON-JOINER
        '\u{200d}', // ZERO WIDTH JOINER
        '\u{115f}', // HANGUL CHOSEONG FILLER
        '\u{1160}', // HANGUL JUNGSEONG FILLER
        '\u{3164}', // HANGUL FILLER
        '\u{ffa0}', // HALFWIDTH HANGUL FILLER
    ];
    // Removed in this order, one non-recursive pass each.
    const BLOCKED_WORDS: &[&str] = &[
        "script",
        "javascript:",
        "onerror",
        "onload",
        "onclick",
        "alert",
        "iframe",
    ];

    let mut cleaned = input.replace(['<', '>'], "");
    for word in BLOCKED_WORDS {
        cleaned = cleaned.replace(word, "");
    }

    cleaned
        .chars()
        .filter(|c| !INVISIBLE_CHARS.contains(c))
        .filter(|c| c.is_alphanumeric() || matches!(*c, '_' | '-' | '.'))
        .take(max_length)
        .collect()
}
// Unit tests for the validators, the rate limiter and the sanitizer.
#[cfg(test)]
mod tests {
use super::*;
/// `validate_peer_id` accepts a freshly generated peer ID.
#[test]
fn test_peer_id_validation() {
let peer = PeerId::random();
assert!(validate_peer_id(&peer).is_ok());
}
/// Relative and absolute paths pass; `..` traversal and NUL bytes are rejected.
#[test]
fn test_file_path_validation() {
assert!(validate_file_path(Path::new("data/file.txt")).is_ok());
assert!(validate_file_path(Path::new("/usr/local/bin")).is_ok());
assert!(validate_file_path(Path::new("../etc/passwd")).is_err());
assert!(validate_file_path(Path::new("file\0name")).is_err());
}
/// Exhausting the budget is rejected, and the budget recovers after the window.
///
/// NOTE(review): assumes the engine admits exactly `burst_size` (5) back-to-back
/// requests. The test relies on wall-clock sleeping (600 ms > the 500 ms
/// window), so it is slow and timing-sensitive.
#[test]
fn test_rate_limiter() {
let config = RateLimitConfig {
window: Duration::from_millis(500), max_requests: 10,
burst_size: 5,
..Default::default()
};
let limiter = RateLimiter::new(config);
let ip: IpAddr = "192.168.1.1".parse().unwrap();
// Use up the burst allowance.
for _ in 0..5 {
assert!(limiter.check_ip(&ip).is_ok());
}
// The next request exceeds it.
assert!(limiter.check_ip(&ip).is_err());
// Wait past the window so capacity is restored.
std::thread::sleep(Duration::from_millis(600));
assert!(limiter.check_ip(&ip).is_ok());
}
/// A small payload stamped with the current time passes the default context.
#[test]
fn test_message_validation() {
let ctx = ValidationContext::default();
let valid_msg = NetworkMessage {
peer_id: PeerId::random(),
payload: vec![0u8; 1024],
timestamp: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs(),
};
assert!(valid_msg.validate(&ctx).is_ok());
}
/// Punctuation and whitespace are stripped, and output is truncated to `max_length` chars.
#[test]
fn test_sanitization() {
assert_eq!(sanitize_string("hello world!", 20), "helloworld");
assert_eq!(sanitize_string("test@#$%123", 20), "test123");
assert_eq!(
sanitize_string("very_long_string_that_exceeds_limit", 10),
"very_long_"
);
}
}