use std::collections::HashMap;
use std::env;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc;
use std::time::{Duration, Instant};
use thiserror::Error;
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
/// Errors raised while building a [`SecurityConfig`] from the environment.
#[derive(Error, Debug)]
pub enum SecurityError {
/// `REASONKIT_WEB_TOKEN` could not be read (unset or not valid Unicode).
#[error("REASONKIT_WEB_TOKEN environment variable not set")]
MissingToken,
/// The token value was rejected; the payload explains why (e.g. it is empty).
#[error("Invalid token format: {0}")]
InvalidTokenFormat(String),
/// The rate limit did not parse as a `u32`, or was zero; the payload explains why.
#[error("Invalid rate limit: {0}")]
InvalidRateLimit(String),
/// General configuration error carrying a description.
/// NOTE(review): not constructed anywhere in this file — confirm it is still needed.
#[error("Security configuration error: {0}")]
ConfigError(String),
}
/// Result alias for fallible security-configuration operations.
pub type SecurityResult<T> = std::result::Result<T, SecurityError>;
/// Security settings for the web server, normally loaded with
/// [`SecurityConfig::from_env`].
///
/// Only a digest of the bearer token is stored (see `hash_token`). The raw
/// token is not kept after construction.
#[derive(Debug, Clone)]
pub struct SecurityConfig {
    /// Digest of the expected bearer token, produced by `hash_token`.
    token_hash: [u8; 32],
    /// Set from `REASONKIT_WEB_BIND_ALL=true`; selects `0.0.0.0` as the bind address.
    pub bind_all: bool,
    /// Address the server binds to: `0.0.0.0` when `bind_all`, otherwise `127.0.0.1`.
    pub bind_addr: IpAddr,
    /// Maximum requests per minute per client IP (never 0 when built by `from_env`).
    pub rate_limit_rpm: u32,
    /// CORS origins that may be echoed back. An entry without an explicit port
    /// also matches the same origin followed by `:<digits>`.
    pub allowed_origins: Vec<String>,
    /// Paths served without authentication. Each entry matches itself and
    /// anything below it at a `/` or `?` boundary.
    pub auth_bypass_paths: Vec<String>,
}

impl SecurityConfig {
    /// Builds the configuration from environment variables.
    ///
    /// * `REASONKIT_WEB_TOKEN` — required bearer token. A warning is logged if it
    ///   is shorter than 32 bytes.
    /// * `REASONKIT_WEB_BIND_ALL` — `true` (any case) binds to `0.0.0.0`.
    ///   Anything else, or unset, binds to `127.0.0.1`.
    /// * `REASONKIT_WEB_RATE_LIMIT` — requests per minute per IP; defaults to `100`.
    ///
    /// # Errors
    ///
    /// * [`SecurityError::MissingToken`] if the token variable is unset or not valid Unicode.
    /// * [`SecurityError::InvalidTokenFormat`] if the token is empty.
    /// * [`SecurityError::InvalidRateLimit`] if the rate limit is not a `u32` or is `0`.
    pub fn from_env() -> SecurityResult<Self> {
        let token = env::var("REASONKIT_WEB_TOKEN").map_err(|_| SecurityError::MissingToken)?;
        if token.is_empty() {
            return Err(SecurityError::InvalidTokenFormat(
                "Token cannot be empty".to_string(),
            ));
        }
        if token.len() < 32 {
            warn!("SECURITY WARNING: REASONKIT_WEB_TOKEN is less than 32 characters");
        }
        let token_hash = Self::hash_token(&token);

        let bind_all = env::var("REASONKIT_WEB_BIND_ALL")
            .map(|v| v.to_lowercase() == "true")
            .unwrap_or(false);
        let bind_addr = if bind_all {
            warn!("SECURITY: Binding to 0.0.0.0 (REASONKIT_WEB_BIND_ALL=true)");
            IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0))
        } else {
            info!("SECURITY: Binding to localhost only (127.0.0.1)");
            IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))
        };

        let rate_limit_rpm = env::var("REASONKIT_WEB_RATE_LIMIT")
            .unwrap_or_else(|_| "100".to_string())
            .parse::<u32>()
            .map_err(|e| SecurityError::InvalidRateLimit(e.to_string()))?;
        if rate_limit_rpm == 0 {
            return Err(SecurityError::InvalidRateLimit(
                "Rate limit cannot be 0".to_string(),
            ));
        }
        info!(
            "SECURITY: Rate limit set to {} requests/minute per IP",
            rate_limit_rpm
        );

        Ok(Self {
            token_hash,
            bind_all,
            bind_addr,
            rate_limit_rpm,
            allowed_origins: vec![
                "http://localhost".to_string(),
                "http://127.0.0.1".to_string(),
                "http://localhost:3000".to_string(),
                "http://localhost:9100".to_string(),
                "http://127.0.0.1:3000".to_string(),
                "http://127.0.0.1:9100".to_string(),
            ],
            auth_bypass_paths: vec!["/health".to_string(), "/healthz".to_string()],
        })
    }

    /// Fixed configuration for unit tests (token `test-token-for-unit-tests-only`).
    #[cfg(test)]
    pub fn test_config() -> Self {
        Self {
            token_hash: Self::hash_token("test-token-for-unit-tests-only"),
            bind_all: false,
            bind_addr: IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
            rate_limit_rpm: 100,
            allowed_origins: vec!["http://localhost".to_string()],
            auth_bypass_paths: vec!["/health".to_string()],
        }
    }

    /// Maps a token to a fixed-size 32-byte digest so that
    /// [`SecurityConfig::verify_token`] always compares equal-length values.
    ///
    /// Each 8-byte lane is a separate `DefaultHasher` (SipHash) run,
    /// domain-separated by the lane index. The digest therefore carries four
    /// independent 64-bit values rather than one value repeated. `DefaultHasher::new()`
    /// instances are identical within a process, so digests are stable for the
    /// process lifetime. They are not guaranteed stable across Rust releases and
    /// must never be persisted.
    ///
    /// NOTE(review): this is not a cryptographic hash. Switch to SHA-256 if a
    /// crypto crate becomes available.
    fn hash_token(token: &str) -> [u8; 32] {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut digest = [0u8; 32];
        for (lane, chunk) in (0u8..).zip(digest.chunks_exact_mut(8)) {
            let mut hasher = DefaultHasher::new();
            lane.hash(&mut hasher);
            token.hash(&mut hasher);
            chunk.copy_from_slice(&hasher.finish().to_le_bytes());
        }
        digest
    }

    /// Returns `true` if `token` hashes to the stored digest.
    ///
    /// The digests are compared with `constant_time_compare`.
    pub fn verify_token(&self, token: &str) -> bool {
        let provided_hash = Self::hash_token(token);
        constant_time_compare(&self.token_hash, &provided_hash)
    }

    /// Socket address combining the configured bind address with `port`.
    pub fn socket_addr(&self, port: u16) -> SocketAddr {
        SocketAddr::new(self.bind_addr, port)
    }

    /// Returns `true` if `path` may be served without authentication.
    ///
    /// A path matches an entry in any of these cases:
    /// * it equals the entry;
    /// * it continues the entry at a `/` or `?` boundary (`/health/live`, `/health?x=1`);
    /// * the entry itself ends in `/`.
    ///
    /// A bare prefix match is not enough, because it would exempt unrelated
    /// routes such as `/healthy-admin`. Paths containing `.` or `..` segments
    /// always require authentication, since a normalising proxy or router could
    /// resolve them outside the exempt tree.
    pub fn is_auth_bypass_path(&self, path: &str) -> bool {
        if path
            .split('/')
            .any(|segment| segment == "." || segment == "..")
        {
            return false;
        }
        self.auth_bypass_paths
            .iter()
            .any(|exempt| match path.strip_prefix(exempt.as_str()) {
                Some(rest) => {
                    exempt.ends_with('/')
                        || rest.is_empty()
                        || rest.starts_with('/')
                        || rest.starts_with('?')
                }
                None => false,
            })
    }

    /// Returns `true` if `origin` (an `Origin` header value) is allowed for CORS.
    ///
    /// An origin matches an entry when it is identical to it, or when it is the
    /// entry followed by `:<digits>` (an explicit port). Plain prefix matching is
    /// deliberately avoided: it would accept hosts such as
    /// `http://localhost.evil.com` or `http://127.0.0.1.attacker.net`.
    pub fn is_origin_allowed(&self, origin: &str) -> bool {
        self.allowed_origins
            .iter()
            .any(|allowed| match origin.strip_prefix(allowed.as_str()) {
                Some("") => true,
                Some(rest) => match rest.strip_prefix(':') {
                    Some(port) => !port.is_empty() && port.bytes().all(|b| b.is_ascii_digit()),
                    None => false,
                },
                None => false,
            })
    }
}
/// Compares two byte slices for equality without stopping at the first
/// differing byte.
///
/// Slices of different lengths return `false` immediately. Only the contents
/// of equal-length slices are compared in a way that does not depend on where
/// they differ.
fn constant_time_compare(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        // OR together the XOR of every byte pair; any difference leaves a bit set.
        let diff = a
            .iter()
            .zip(b)
            .fold(0u8, |acc, (&left, &right)| acc | (left ^ right));
        diff == 0
    } else {
        false
    }
}
/// Fixed-window, per-IP request limiter.
///
/// Each client IP gets a counter that resets once `window` has elapsed since
/// the first request of its current window.
#[derive(Debug)]
pub struct RateLimiter {
    /// Maximum requests permitted per IP within one window.
    max_requests: u32,
    /// Window length (60 seconds, set in [`RateLimiter::new`]).
    window: Duration,
    /// Per-IP counters behind an async lock.
    buckets: Arc<RwLock<HashMap<IpAddr, RateBucket>>>,
}

/// Request counter for one IP within its current window.
#[derive(Debug, Clone)]
struct RateBucket {
    /// Requests recorded since `window_start`, including rejected ones.
    count: u32,
    /// When the current window began.
    window_start: Instant,
}

impl RateLimiter {
    /// Creates a limiter that allows `requests_per_minute` requests per IP in
    /// each 60-second window.
    pub fn new(requests_per_minute: u32) -> Self {
        Self {
            max_requests: requests_per_minute,
            window: Duration::from_secs(60),
            buckets: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Records one request from `ip` and reports whether it is within the limit.
    ///
    /// Every call is counted, including calls that return `Exceeded`. On
    /// `Exceeded`, `retry_after` is the time left in the current window, rounded
    /// up to whole seconds. A whole-second `Retry-After` derived from it is
    /// therefore never `0` while the window is still open.
    pub async fn check(&self, ip: IpAddr) -> RateLimitResult {
        let mut buckets = self.buckets.write().await;
        // Read the clock only after taking the lock. That way `now` can never be
        // earlier than a `window_start` written by another task.
        let now = Instant::now();
        let bucket = buckets.entry(ip).or_insert_with(|| RateBucket {
            count: 0,
            window_start: now,
        });
        if now.duration_since(bucket.window_start) >= self.window {
            bucket.count = 0;
            bucket.window_start = now;
        }
        // Saturate instead of `+= 1`, which would panic (debug) or wrap back to
        // "allowed" (release) for a client that exceeds u32::MAX requests.
        bucket.count = bucket.count.saturating_add(1);
        if bucket.count > self.max_requests {
            let remaining_time = self
                .window
                .saturating_sub(now.duration_since(bucket.window_start));
            RateLimitResult::Exceeded {
                retry_after: Self::ceil_to_whole_secs(remaining_time),
                limit: self.max_requests,
            }
        } else {
            RateLimitResult::Allowed {
                remaining: self.max_requests - bucket.count,
                limit: self.max_requests,
            }
        }
    }

    /// Drops buckets whose window started at least two windows ago.
    ///
    /// `check` never evicts entries itself, so the owner must call this
    /// periodically (for example via `SecurityLayer::cleanup_rate_limiter`) to
    /// bound memory.
    pub async fn cleanup(&self) {
        let mut buckets = self.buckets.write().await;
        let now = Instant::now();
        let window = self.window;
        buckets.retain(|_, bucket| now.duration_since(bucket.window_start) < window * 2);
    }

    /// Rounds `duration` up to the next whole second (exact seconds are unchanged).
    fn ceil_to_whole_secs(duration: Duration) -> Duration {
        let extra = u64::from(duration.subsec_nanos() > 0);
        Duration::from_secs(duration.as_secs().saturating_add(extra))
    }
}
/// Outcome of a single rate-limit check for one request.
#[derive(Debug, Clone)]
pub enum RateLimitResult {
    /// The request fits within the caller's current window.
    Allowed {
        /// Requests still permitted in the current window.
        remaining: u32,
        /// Configured per-window maximum.
        limit: u32,
    },
    /// The caller has exceeded the per-window maximum.
    Exceeded {
        /// Time remaining in the caller's current window, as computed by the limiter.
        retry_after: Duration,
        /// Configured per-window maximum.
        limit: u32,
    },
}

impl RateLimitResult {
    /// Returns `true` when the request may proceed.
    pub fn is_allowed(&self) -> bool {
        match self {
            Self::Allowed { .. } => true,
            Self::Exceeded { .. } => false,
        }
    }
}
/// Source-IP allowlist applied before rate limiting and authentication.
///
/// Unless `allow_all` is set, only loopback addresses and entries in
/// `allowed_ips` are accepted.
#[derive(Debug, Clone)]
pub struct IpFilter {
    /// When `true`, every address is accepted.
    allow_all: bool,
    /// Explicitly permitted addresses, in addition to all loopback addresses.
    allowed_ips: Vec<IpAddr>,
}

impl IpFilter {
    /// Creates a filter. With `allow_all == false` only loopback and the
    /// seeded entries are accepted.
    pub fn new(allow_all: bool) -> Self {
        Self {
            allow_all,
            allowed_ips: vec![
                IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
                // NOTE(review): 0.0.0.0 is the unspecified address, not a real
                // peer address. It is kept for backward compatibility — confirm it is needed.
                IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)),
                IpAddr::V6(std::net::Ipv6Addr::LOCALHOST),
            ],
        }
    }

    /// Returns `true` if a connection from `ip` is permitted.
    ///
    /// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are judged by their
    /// embedded IPv4 address. Dual-stack listeners report IPv4 clients in this
    /// form, and `Ipv6Addr::is_loopback` is `false` for `::ffff:127.0.0.1`.
    pub fn is_allowed(&self, ip: IpAddr) -> bool {
        if self.allow_all || self.allowed_ips.contains(&ip) {
            return true;
        }
        match ip {
            IpAddr::V4(ipv4) => ipv4.is_loopback(),
            IpAddr::V6(ipv6) => match ipv6.to_ipv4_mapped() {
                Some(ipv4) => {
                    ipv4.is_loopback() || self.allowed_ips.contains(&IpAddr::V4(ipv4))
                }
                None => ipv6.is_loopback(),
            },
        }
    }

    /// Adds `ip` to the allowlist (no-op if already present).
    pub fn allow_ip(&mut self, ip: IpAddr) {
        if !self.allowed_ips.contains(&ip) {
            self.allowed_ips.push(ip);
        }
    }
}

impl Default for IpFilter {
    /// Localhost-only filter, equivalent to `IpFilter::new(false)`.
    fn default() -> Self {
        Self::new(false)
    }
}
/// Outcome of checking a request's `Authorization` header.
#[derive(Debug, Clone, PartialEq)]
pub enum AuthResult {
    /// A bearer token was presented and matched the configured token.
    Authenticated,
    /// No `Authorization` header was present.
    MissingHeader,
    /// The header was present but not a well-formed bearer credential.
    InvalidFormat,
    /// The header was well formed but the token did not match.
    InvalidToken,
    /// The request path is exempt from authentication.
    Bypassed,
}

impl AuthResult {
    /// Client-facing explanation of a failure, or `None` when the request may proceed.
    ///
    /// This is the single exhaustive mapping over the variants; `is_ok` and
    /// `status_code` are derived from it.
    pub fn error_message(&self) -> Option<&'static str> {
        let message = match self {
            Self::Authenticated | Self::Bypassed => return None,
            Self::MissingHeader => "Missing Authorization header",
            Self::InvalidFormat => "Invalid Authorization format. Expected: Bearer <token>",
            Self::InvalidToken => "Invalid token",
        };
        Some(message)
    }

    /// `true` when the request may proceed (authenticated or exempt).
    pub fn is_ok(&self) -> bool {
        self.error_message().is_none()
    }

    /// HTTP status for this outcome: `200` on success, `401` otherwise.
    pub fn status_code(&self) -> u16 {
        if self.is_ok() {
            200
        } else {
            401
        }
    }
}
/// Validates `Authorization: Bearer <token>` headers against a [`SecurityConfig`].
#[derive(Debug, Clone)]
pub struct TokenAuthenticator {
    /// Shared configuration providing the bypass paths and the expected token digest.
    config: Arc<SecurityConfig>,
}

impl TokenAuthenticator {
    /// Creates an authenticator backed by `config`.
    pub fn new(config: Arc<SecurityConfig>) -> Self {
        Self { config }
    }

    /// Authenticates a request for `path` carrying the optional `auth_header`.
    ///
    /// Exempt paths return [`AuthResult::Bypassed`] without looking at the header.
    /// For any other path, the header must be `<scheme> <token>` where:
    /// * the scheme is `Bearer`, compared case-insensitively as RFC 7235 §2.1 requires;
    /// * the token is non-empty after trimming surrounding whitespace;
    /// * the token matches the configured one.
    pub fn authenticate(&self, path: &str, auth_header: Option<&str>) -> AuthResult {
        if self.config.is_auth_bypass_path(path) {
            debug!("Auth bypass for path: {}", path);
            return AuthResult::Bypassed;
        }
        let header = match auth_header {
            Some(h) => h,
            None => return AuthResult::MissingHeader,
        };
        // Accept "bearer"/"BEARER" as well as "Bearer"; any other scheme
        // (e.g. "Basic") or a header without a space is malformed.
        let token = match header.split_once(' ') {
            Some((scheme, credentials)) if scheme.eq_ignore_ascii_case("Bearer") => {
                credentials.trim()
            }
            _ => return AuthResult::InvalidFormat,
        };
        if token.is_empty() {
            return AuthResult::InvalidFormat;
        }
        if self.config.verify_token(token) {
            AuthResult::Authenticated
        } else {
            warn!("Invalid authentication token attempt");
            AuthResult::InvalidToken
        }
    }
}
/// Namespace for building HTTP response headers (baseline security, CORS and
/// rate-limit headers).
#[derive(Debug, Clone)]
pub struct SecurityHeaders;

impl SecurityHeaders {
    /// Baseline hardening headers, attached to every response by `SecurityCheck`.
    pub fn headers() -> Vec<(&'static str, &'static str)> {
        const BASELINE: [(&str, &str); 8] = [
            ("X-Content-Type-Options", "nosniff"),
            ("X-Frame-Options", "DENY"),
            (
                "Content-Security-Policy",
                "default-src 'none'; frame-ancestors 'none'",
            ),
            ("X-XSS-Protection", "0"),
            ("Referrer-Policy", "no-referrer"),
            (
                "Cache-Control",
                "no-store, no-cache, must-revalidate, private",
            ),
            ("Pragma", "no-cache"),
            (
                "Permissions-Policy",
                "geolocation=(), microphone=(), camera=()",
            ),
        ];
        BASELINE.to_vec()
    }

    /// CORS headers for a request carrying `origin`.
    ///
    /// Returns an empty list when there is no origin, or when `config` does not
    /// allow it (rejections are logged at debug level). Otherwise the origin is
    /// echoed back along with the permitted methods and headers.
    pub fn cors_headers(
        origin: Option<&str>,
        config: &SecurityConfig,
    ) -> Vec<(&'static str, String)> {
        let origin = match origin {
            Some(value) => value,
            None => return Vec::new(),
        };
        if !config.is_origin_allowed(origin) {
            debug!("Rejected CORS origin: {}", origin);
            return Vec::new();
        }
        vec![
            ("Access-Control-Allow-Origin", origin.to_owned()),
            ("Access-Control-Allow-Methods", String::from("GET, POST")),
            (
                "Access-Control-Allow-Headers",
                String::from("Authorization, Content-Type"),
            ),
            ("Access-Control-Max-Age", String::from("3600")),
            ("Access-Control-Allow-Credentials", String::from("false")),
        ]
    }

    /// `X-RateLimit-*` headers describing `result`.
    ///
    /// Rejected requests report zero remaining and add `Retry-After` in whole
    /// seconds (truncated from the limiter's duration).
    pub fn rate_limit_headers(result: &RateLimitResult) -> Vec<(&'static str, String)> {
        let (limit, remaining, retry_after) = match result {
            RateLimitResult::Allowed { remaining, limit } => (*limit, *remaining, None),
            RateLimitResult::Exceeded { retry_after, limit } => (*limit, 0, Some(*retry_after)),
        };
        let mut headers = vec![
            ("X-RateLimit-Limit", limit.to_string()),
            ("X-RateLimit-Remaining", remaining.to_string()),
        ];
        if let Some(wait) = retry_after {
            headers.push(("Retry-After", wait.as_secs().to_string()));
        }
        headers
    }
}
/// Bundles the configuration, per-IP rate limiter, IP filter and token
/// authenticator used to vet requests.
#[derive(Clone)]
pub struct SecurityLayer {
    /// Shared configuration; the same `Arc` is handed to `authenticator`.
    config: Arc<SecurityConfig>,
    /// Rate limiter behind an `Arc`, so clones of the layer share one set of counters.
    rate_limiter: Arc<RateLimiter>,
    /// Source-IP allowlist, open to every address when `config.bind_all` is set.
    ip_filter: IpFilter,
    /// Bearer-token checker over `config`.
    authenticator: TokenAuthenticator,
}

impl SecurityLayer {
    /// Takes ownership of `config` and builds every component from it.
    pub fn new(config: SecurityConfig) -> Self {
        let shared = Arc::new(config);
        let authenticator = TokenAuthenticator::new(Arc::clone(&shared));
        let ip_filter = IpFilter::new(shared.bind_all);
        let rate_limiter = Arc::new(RateLimiter::new(shared.rate_limit_rpm));
        Self {
            config: shared,
            rate_limiter,
            ip_filter,
            authenticator,
        }
    }

    /// Socket address for `port` on the configured bind address.
    pub fn bind_addr(&self, port: u16) -> SocketAddr {
        let config: &SecurityConfig = &self.config;
        config.socket_addr(port)
    }

    /// Records a request from `ip` against the shared rate limiter.
    pub async fn check_rate_limit(&self, ip: IpAddr) -> RateLimitResult {
        let limiter = &self.rate_limiter;
        limiter.check(ip).await
    }

    /// Returns `true` if `ip` passes the IP filter.
    pub fn check_ip(&self, ip: IpAddr) -> bool {
        let filter = &self.ip_filter;
        filter.is_allowed(ip)
    }

    /// Authenticates a request for `path` with the optional `Authorization` header.
    pub fn authenticate(&self, path: &str, auth_header: Option<&str>) -> AuthResult {
        let authenticator = &self.authenticator;
        authenticator.authenticate(path, auth_header)
    }

    /// Baseline security headers (see [`SecurityHeaders::headers`]).
    pub fn security_headers(&self) -> Vec<(&'static str, &'static str)> {
        SecurityHeaders::headers()
    }

    /// CORS headers for `origin`, using this layer's configuration.
    pub fn cors_headers(&self, origin: Option<&str>) -> Vec<(&'static str, String)> {
        let config: &SecurityConfig = &self.config;
        SecurityHeaders::cors_headers(origin, config)
    }

    /// Rate-limit headers describing `result`.
    pub fn rate_limit_headers(&self, result: &RateLimitResult) -> Vec<(&'static str, String)> {
        SecurityHeaders::rate_limit_headers(result)
    }

    /// Evicts stale rate-limit buckets; intended to be called periodically.
    pub async fn cleanup_rate_limiter(&self) {
        let limiter = &self.rate_limiter;
        limiter.cleanup().await;
    }
}
/// Runs the request-vetting pipeline of a [`SecurityLayer`] in this order:
/// IP filter, rate limit, authentication, then response-header assembly.
pub struct SecurityCheck {
    layer: SecurityLayer,
}

impl SecurityCheck {
    /// Creates a checker over `layer`.
    pub fn new(layer: SecurityLayer) -> Self {
        Self { layer }
    }

    /// Vets one request and returns either the headers to attach or a rejection.
    ///
    /// Checks run in order and stop at the first failure:
    /// 1. source IP filter → `403`;
    /// 2. per-IP rate limit → `429` with rate-limit headers;
    /// 3. authentication → the status from [`AuthResult::status_code`] plus
    ///    `WWW-Authenticate: Bearer`.
    ///
    /// Every response carries the baseline security headers. Allowed requests
    /// also get rate-limit headers and, for permitted origins, CORS headers.
    /// The rate limit is checked before authentication, so failed
    /// authentication attempts also count against the caller's quota.
    pub async fn validate(
        &self,
        remote_ip: IpAddr,
        path: &str,
        auth_header: Option<&str>,
        origin: Option<&str>,
    ) -> SecurityCheckResult {
        if !self.layer.check_ip(remote_ip) {
            warn!("Rejected request from non-localhost IP: {}", remote_ip);
            return SecurityCheckResult::Rejected {
                status: 403,
                message: "Forbidden: Only localhost connections allowed".to_string(),
                headers: self.base_headers(),
            };
        }

        let rate_result = self.layer.check_rate_limit(remote_ip).await;
        if let RateLimitResult::Exceeded { retry_after, .. } = &rate_result {
            warn!("Rate limit exceeded for IP: {}", remote_ip);
            let mut headers = self.base_headers();
            headers.extend(Self::owned_pairs(
                self.layer.rate_limit_headers(&rate_result),
            ));
            return SecurityCheckResult::Rejected {
                status: 429,
                message: format!(
                    "Too Many Requests. Retry after {} seconds.",
                    retry_after.as_secs()
                ),
                headers,
            };
        }

        let auth_result = self.layer.authenticate(path, auth_header);
        if !auth_result.is_ok() {
            warn!("Authentication failed for path {}: {:?}", path, auth_result);
            let mut headers = self.base_headers();
            headers.push(("WWW-Authenticate".to_string(), "Bearer".to_string()));
            return SecurityCheckResult::Rejected {
                status: auth_result.status_code(),
                message: auth_result
                    .error_message()
                    .unwrap_or("Unauthorized")
                    .to_string(),
                headers,
            };
        }

        let mut headers = self.base_headers();
        headers.extend(Self::owned_pairs(
            self.layer.rate_limit_headers(&rate_result),
        ));
        headers.extend(Self::owned_pairs(self.layer.cors_headers(origin)));
        SecurityCheckResult::Allowed { headers }
    }

    /// Baseline security headers as owned `(name, value)` pairs.
    fn base_headers(&self) -> Vec<(String, String)> {
        self.layer
            .security_headers()
            .into_iter()
            .map(|(name, value)| (name.to_string(), value.to_string()))
            .collect()
    }

    /// Converts `(&'static str, String)` header pairs into owned pairs without
    /// re-cloning the values.
    fn owned_pairs(
        headers: Vec<(&'static str, String)>,
    ) -> impl Iterator<Item = (String, String)> {
        headers
            .into_iter()
            .map(|(name, value)| (name.to_string(), value))
    }
}
/// Final verdict produced by `SecurityCheck::validate`.
#[derive(Debug)]
pub enum SecurityCheckResult {
    /// The request may proceed; attach `headers` to the response.
    Allowed {
        headers: Vec<(String, String)>,
    },
    /// The request must be refused with `status`, a body `message` and `headers`.
    Rejected {
        status: u16,
        message: String,
        headers: Vec<(String, String)>,
    },
}

impl SecurityCheckResult {
    /// `true` only for the `Allowed` variant.
    pub fn is_allowed(&self) -> bool {
        match self {
            Self::Allowed { .. } => true,
            Self::Rejected { .. } => false,
        }
    }

    /// HTTP status to respond with: `200` when allowed, otherwise the rejection status.
    pub fn status_code(&self) -> u16 {
        if let Self::Rejected { status, .. } = self {
            *status
        } else {
            200
        }
    }
}
// Unit tests for the security primitives in this module. Tests that need a
// configuration build it with `SecurityConfig::test_config()` rather than
// reading environment variables.
#[cfg(test)]
mod tests {
use super::*;
use std::net::Ipv4Addr;
// The configured test token verifies; a different token and the empty
// string do not.
#[test]
fn test_security_config_token_verification() {
let config = SecurityConfig::test_config();
assert!(config.verify_token("test-token-for-unit-tests-only"));
assert!(!config.verify_token("wrong-token"));
assert!(!config.verify_token(""));
}
// "/health" is the bypass path in `test_config`; a sub-path below it is also
// exempt, while API routes are not.
#[test]
fn test_security_config_auth_bypass() {
let config = SecurityConfig::test_config();
assert!(config.is_auth_bypass_path("/health"));
assert!(config.is_auth_bypass_path("/health/live"));
assert!(!config.is_auth_bypass_path("/api/tools"));
}
// `test_config` lists only "http://localhost"; the same host with an explicit
// port is expected to be accepted, and foreign origins rejected.
#[test]
fn test_security_config_origin_allowed() {
let config = SecurityConfig::test_config();
assert!(config.is_origin_allowed("http://localhost"));
assert!(config.is_origin_allowed("http://localhost:3000"));
assert!(!config.is_origin_allowed("http://example.com"));
assert!(!config.is_origin_allowed("https://malicious.com"));
}
// Without `allow_all`, loopback is accepted and private/public addresses are
// rejected.
#[test]
fn test_ip_filter_localhost() {
let filter = IpFilter::new(false);
assert!(filter.is_allowed(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))));
assert!(filter.is_allowed(IpAddr::V4(Ipv4Addr::LOCALHOST)));
assert!(!filter.is_allowed(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1))));
assert!(!filter.is_allowed(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))));
}
// With `allow_all`, every address is accepted.
#[test]
fn test_ip_filter_allow_all() {
let filter = IpFilter::new(true);
assert!(filter.is_allowed(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))));
assert!(filter.is_allowed(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1))));
assert!(filter.is_allowed(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))));
}
// Exercises every `AuthResult` variant `authenticate` can produce:
// Authenticated, InvalidToken, MissingHeader, InvalidFormat and Bypassed.
#[test]
fn test_token_authenticator() {
let config = Arc::new(SecurityConfig::test_config());
let auth = TokenAuthenticator::new(config);
assert_eq!(
auth.authenticate("/api/test", Some("Bearer test-token-for-unit-tests-only")),
AuthResult::Authenticated
);
assert_eq!(
auth.authenticate("/api/test", Some("Bearer wrong-token")),
AuthResult::InvalidToken
);
assert_eq!(
auth.authenticate("/api/test", None),
AuthResult::MissingHeader
);
// A non-Bearer scheme is a format error, not a token mismatch.
assert_eq!(
auth.authenticate("/api/test", Some("Basic dXNlcjpwYXNz")),
AuthResult::InvalidFormat
);
assert_eq!(auth.authenticate("/health", None), AuthResult::Bypassed);
}
// With a limit of 3, the fourth request from the same IP in one window is
// rejected.
#[tokio::test]
async fn test_rate_limiter() {
let limiter = RateLimiter::new(3); let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
assert!(limiter.check(ip).await.is_allowed());
assert!(limiter.check(ip).await.is_allowed());
assert!(limiter.check(ip).await.is_allowed());
assert!(!limiter.check(ip).await.is_allowed());
}
// Counters are per IP: exhausting one IP's quota does not affect another.
#[tokio::test]
async fn test_rate_limiter_different_ips() {
let limiter = RateLimiter::new(2);
let ip1 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
let ip2 = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1));
assert!(limiter.check(ip1).await.is_allowed());
assert!(limiter.check(ip1).await.is_allowed());
assert!(!limiter.check(ip1).await.is_allowed());
assert!(limiter.check(ip2).await.is_allowed());
assert!(limiter.check(ip2).await.is_allowed());
assert!(!limiter.check(ip2).await.is_allowed());
}
// The baseline header set includes the core hardening headers with strict
// values for framing and content policy.
#[test]
fn test_security_headers() {
let headers = SecurityHeaders::headers();
let header_names: Vec<&str> = headers.iter().map(|(name, _)| *name).collect();
assert!(header_names.contains(&"X-Content-Type-Options"));
assert!(header_names.contains(&"X-Frame-Options"));
assert!(header_names.contains(&"Content-Security-Policy"));
assert!(header_names.contains(&"Referrer-Policy"));
assert!(header_names.contains(&"Cache-Control"));
let x_frame = headers.iter().find(|(name, _)| *name == "X-Frame-Options");
assert_eq!(x_frame.unwrap().1, "DENY");
let csp = headers
.iter()
.find(|(name, _)| *name == "Content-Security-Policy");
assert!(csp.unwrap().1.contains("default-src 'none'"));
}
// An allowed origin is echoed back verbatim in Access-Control-Allow-Origin.
#[test]
fn test_cors_headers_allowed_origin() {
let config = SecurityConfig::test_config();
let headers = SecurityHeaders::cors_headers(Some("http://localhost"), &config);
assert!(!headers.is_empty());
let origin_header = headers
.iter()
.find(|(name, _)| *name == "Access-Control-Allow-Origin");
assert!(origin_header.is_some());
assert_eq!(origin_header.unwrap().1, "http://localhost");
}
// A disallowed origin yields no CORS headers at all.
#[test]
fn test_cors_headers_disallowed_origin() {
let config = SecurityConfig::test_config();
let headers = SecurityHeaders::cors_headers(Some("http://evil.com"), &config);
assert!(headers.is_empty());
}
// Equal slices match; differing content or differing length do not.
#[test]
fn test_constant_time_compare() {
let a = [1u8, 2, 3, 4];
let b = [1u8, 2, 3, 4];
let c = [1u8, 2, 3, 5];
let d = [1u8, 2, 3];
assert!(constant_time_compare(&a, &b));
assert!(!constant_time_compare(&a, &c));
assert!(!constant_time_compare(&a, &d));
}
// End-to-end: an authenticated API request from localhost and an
// unauthenticated health check both pass the full pipeline.
#[tokio::test]
async fn test_security_check_full_flow() {
let config = SecurityConfig::test_config();
let layer = SecurityLayer::new(config);
let check = SecurityCheck::new(layer);
let result = check
.validate(
IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
"/api/test",
Some("Bearer test-token-for-unit-tests-only"),
Some("http://localhost"),
)
.await;
assert!(result.is_allowed());
let result = check
.validate(
IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
"/health",
None,
None,
)
.await;
assert!(result.is_allowed());
}
// Missing and wrong tokens on a protected path are both rejected with 401.
#[tokio::test]
async fn test_security_check_rejected_cases() {
let config = SecurityConfig::test_config();
let layer = SecurityLayer::new(config);
let check = SecurityCheck::new(layer);
let result = check
.validate(
IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
"/api/test",
None,
None,
)
.await;
assert!(!result.is_allowed());
assert_eq!(result.status_code(), 401);
let result = check
.validate(
IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
"/api/test",
Some("Bearer wrong-token"),
None,
)
.await;
assert!(!result.is_allowed());
assert_eq!(result.status_code(), 401);
}
// Success variants map to 200; every failure variant maps to 401.
#[test]
fn test_auth_result_status_codes() {
assert_eq!(AuthResult::Authenticated.status_code(), 200);
assert_eq!(AuthResult::Bypassed.status_code(), 200);
assert_eq!(AuthResult::MissingHeader.status_code(), 401);
assert_eq!(AuthResult::InvalidFormat.status_code(), 401);
assert_eq!(AuthResult::InvalidToken.status_code(), 401);
}
}