use std::sync::Arc;
use std::time::Duration;
use axum::{
Router,
Extension,
extract::Request,
middleware,
response::{IntoResponse, Response, Html},
http::StatusCode,
routing::get,
};
use clap::Parser;
use hyper_util::client::legacy::Client;
use hyper_util::rt::TokioExecutor;
use http_body_util::BodyExt;
use tokio::net::TcpListener;
use tracing_subscriber::{EnvFilter, fmt};
use nexus_shield::{
Shield, ShieldConfig, shield_middleware,
audit_chain::SecurityEventType,
endpoint::{EndpointConfig, EndpointEngine},
};
#[derive(Parser, Debug)]
#[command(name = "nexus-shield", about = "NexusShield Local Security Gateway")]
struct Args {
#[arg(short, long, default_value = "8080")]
port: u16,
#[arg(short, long)]
upstream: Option<String>,
#[arg(short, long, default_value = "/etc/nexus-shield/config.toml")]
config: String,
#[arg(long, default_value = "0.7")]
block_threshold: f64,
#[arg(long, default_value = "0.4")]
warn_threshold: f64,
#[arg(long, default_value = "50")]
rps: f64,
#[arg(long, default_value = "false")]
standalone: bool,
#[arg(long, default_value = "false")]
endpoint: bool,
#[arg(long)]
scan: Option<String>,
#[arg(long)]
scan_file: Option<String>,
#[arg(long, default_value = "~/.nexus-shield")]
endpoint_data_dir: String,
}
#[tokio::main]
async fn main() {
fmt()
.with_env_filter(
EnvFilter::try_from_default_env()
.unwrap_or_else(|_| EnvFilter::new("info,nexus_shield=debug")),
)
.with_target(true)
.with_thread_ids(false)
.init();
let args = Args::parse();
tracing::info!(
r#"
╔══════════════════════════════════════════════╗
║ NexusShield v0.4.0 ║
║ Adaptive Zero-Trust Security Gateway ║
║ + Real-Time Endpoint Protection ║
║ AutomataNexus Engineering ║
╚══════════════════════════════════════════════╝
"#
);
let mut config = ShieldConfig::default();
config.block_threshold = args.block_threshold;
config.warn_threshold = args.warn_threshold;
config.rate.requests_per_second = args.rps;
let mut shield = Shield::new(config);
let endpoint_engine: Option<Arc<EndpointEngine>> = if args.endpoint
|| args.scan.is_some()
|| args.scan_file.is_some()
{
let ep_config = EndpointConfig::default();
let engine = Arc::new(EndpointEngine::new(ep_config));
shield.endpoint = Some(engine.clone());
Some(engine)
} else {
None
};
if let Some(ref file_path) = args.scan_file {
let engine = endpoint_engine.as_ref().expect("Endpoint engine required");
let path = std::path::Path::new(file_path);
tracing::info!("Scanning file: {}", path.display());
let results = engine.scan_file(path).await;
if results.is_empty() {
println!("CLEAN: No threats detected in {}", path.display());
} else {
println!("THREATS FOUND in {}:", path.display());
for r in &results {
println!(
" [{:?}] {} (confidence: {:.0}%) — {}",
r.severity,
r.scanner,
r.confidence * 100.0,
r.description
);
}
}
std::process::exit(if results.is_empty() { 0 } else { 1 });
}
if let Some(ref dir_path) = args.scan {
let engine = endpoint_engine.as_ref().expect("Endpoint engine required");
let path = std::path::Path::new(dir_path);
tracing::info!("Scanning directory: {}", path.display());
let results = engine.scan_dir(path).await;
if results.is_empty() {
println!("CLEAN: No threats detected in {}", path.display());
} else {
println!("THREATS FOUND ({} detections):", results.len());
for r in &results {
println!(
" [{:?}] {} — {} ({})",
r.severity, r.target, r.description, r.scanner
);
}
}
std::process::exit(if results.is_empty() { 0 } else { 1 });
}
let shield = Arc::new(shield);
let shield_bg = shield.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(60));
loop {
interval.tick().await;
shield_bg.rate_governor.prune_stale(Duration::from_secs(600));
shield_bg.fingerprinter.prune_stale(600);
shield_bg.email_limiter.prune();
tracing::debug!("Pruned stale security state");
}
});
{
let audit_fwd = shield.audit.clone();
let journal_config = nexus_shield::journal::JournalConfig::default();
tokio::spawn(async move {
let mut last_count = audit_fwd.len();
let mut interval = tokio::time::interval(Duration::from_millis(500));
loop {
interval.tick().await;
let current_count = audit_fwd.len();
if current_count > last_count {
let new_events = audit_fwd.recent(current_count - last_count);
for event in new_events.iter().rev() {
nexus_shield::journal::log_to_journal(event, &journal_config);
}
last_count = current_count;
}
}
});
}
if let Some(ref engine) = endpoint_engine {
if args.endpoint {
let handles = engine.start(shield.audit.clone()).await;
tracing::info!(
monitors = handles.len(),
"Endpoint protection started"
);
}
}
let shield_status = shield.clone();
let shield_audit = shield.clone();
let shield_stats = shield.clone();
let endpoint_for_routes = endpoint_engine.clone();
let shield_events = shield.clone();
let shield_report = shield.clone();
let endpoint_routes = if let Some(engine) = endpoint_for_routes {
let ep1 = engine.clone();
let ep2 = engine.clone();
let ep3 = engine.clone();
let ep4 = engine.clone();
Some(
Router::new()
.route("/endpoint/status", get(move || {
let e = ep1.clone();
async move { endpoint_status_handler(e).await }
}))
.route("/endpoint/detections", get(move || {
let e = ep2.clone();
async move { endpoint_detections_handler(e).await }
}))
.route("/endpoint/quarantine", get(move || {
let e = ep3.clone();
async move { endpoint_quarantine_handler(e).await }
}))
.route("/endpoint/scan", axum::routing::post(move |body: String| {
let e = ep4.clone();
async move { endpoint_scan_handler(e, body).await }
}))
)
} else {
None
};
let app = if args.standalone || args.upstream.is_none() {
tracing::info!("Running in standalone mode (no upstream proxy)");
let mut router = Router::new()
.route("/health", get(|| async { "NexusShield OK" }))
.route("/dashboard", get(dashboard_handler))
.route("/logo.png", get(logo_handler))
.route("/status", get(move || async move {
status_handler(shield_status.clone()).await
}))
.route("/audit", get(move || async move {
audit_handler(shield_audit.clone()).await
}))
.route("/stats", get(move || async move {
stats_handler(shield_stats.clone()).await
}))
.route("/events", get(move || {
let audit = shield_events.audit.clone();
async move {
nexus_shield::sse_events::audit_event_stream(audit, 500)
}
}))
.route("/report", get(move || {
let s = shield_report.clone();
async move {
let config = nexus_shield::compliance_report::ReportConfig::default();
let modules = vec![
"sql_firewall".into(), "ssrf_guard".into(), "rate_governor".into(),
"fingerprint".into(), "quarantine".into(), "email_guard".into(),
"credential_vault".into(), "audit_chain".into(), "sanitizer".into(),
"threat_score".into(), "siem_export".into(), "journal".into(),
"sse_events".into(), "compliance_report".into(),
"signatures".into(), "heuristics".into(), "yara_engine".into(),
"watcher".into(), "process_monitor".into(), "network_monitor".into(),
"dns_filter".into(), "usb_monitor".into(), "fim".into(),
"container_scanner".into(), "supply_chain".into(),
];
let shield_cfg = serde_json::json!({
"block_threshold": s.config.block_threshold,
"warn_threshold": s.config.warn_threshold,
"rate_rps": s.config.rate.requests_per_second,
});
let html = nexus_shield::compliance_report::generate_html_report(
&s.audit, &config, &modules, &shield_cfg,
);
Html(html)
}
}));
if let Some(ep_routes) = endpoint_routes {
router = router.merge(ep_routes);
}
router
.fallback(|| async {
(StatusCode::OK, "NexusShield: request inspected and allowed")
})
.layer(middleware::from_fn(shield_middleware))
.layer(Extension(shield.clone()))
} else {
let upstream = args.upstream.clone().unwrap();
tracing::info!(upstream = %upstream, "Running in reverse proxy mode");
let client = Client::builder(TokioExecutor::new()).build_http();
let upstream = Arc::new(upstream);
let proxy_upstream = upstream.clone();
let proxy_client = Arc::new(client);
Router::new()
.route("/health", get(|| async { "NexusShield OK" }))
.route("/dashboard", get(dashboard_handler))
.route("/logo.png", get(logo_handler))
.route("/status", get(move || async move {
status_handler(shield_status.clone()).await
}))
.route("/audit", get(move || async move {
audit_handler(shield_audit.clone()).await
}))
.route("/stats", get(move || async move {
stats_handler(shield_stats.clone()).await
}))
.fallback(move |req: Request| {
let upstream = proxy_upstream.clone();
let client = proxy_client.clone();
async move {
proxy_handler(req, &upstream, &client).await
}
})
.layer(middleware::from_fn(shield_middleware))
.layer(Extension(shield.clone()))
};
let addr = format!("0.0.0.0:{}", args.port);
tracing::info!(listen = %addr, "NexusShield gateway starting");
let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
tracing::info!(
"NexusShield active — protecting on port {}",
args.port
);
axum::serve(listener, app)
.await
.expect("Server failed");
}
async fn proxy_handler(
mut req: Request,
upstream: &str,
client: &Client<hyper_util::client::legacy::connect::HttpConnector, axum::body::Body>,
) -> Response {
let path = req.uri().path_and_query()
.map(|pq| pq.as_str())
.unwrap_or("/");
let upstream_uri = format!("{}{}", upstream, path);
match upstream_uri.parse::<hyper::Uri>() {
Ok(uri) => {
*req.uri_mut() = uri;
match client.request(req).await {
Ok(resp) => {
let (parts, body) = resp.into_parts();
let bytes = body.collect().await
.map(|b| b.to_bytes())
.unwrap_or_default();
Response::from_parts(parts, axum::body::Body::from(bytes))
}
Err(e) => {
tracing::error!(error = %e, "Upstream request failed");
(StatusCode::BAD_GATEWAY, "Upstream unavailable").into_response()
}
}
}
Err(_) => {
(StatusCode::BAD_REQUEST, "Invalid upstream URI").into_response()
}
}
}
async fn dashboard_handler() -> Html<&'static str> {
Html(include_str!("../../widget/index.html"))
}
async fn logo_handler() -> impl IntoResponse {
let bytes: &'static [u8] = include_bytes!("../../assets/NexusShield_logo.png");
(
StatusCode::OK,
[("content-type", "image/png"), ("cache-control", "public, max-age=86400")],
bytes,
)
}
async fn status_handler(shield: Arc<Shield>) -> impl IntoResponse {
let chain_verification = shield.audit.verify_chain();
let audit_count = shield.audit.len();
let status = serde_json::json!({
"service": "NexusShield",
"version": "0.4.0",
"status": "active",
"config": {
"block_threshold": shield.config.block_threshold,
"warn_threshold": shield.config.warn_threshold,
"rate_rps": shield.config.rate.requests_per_second,
"rate_burst": shield.config.rate.burst_capacity,
},
"audit_chain": {
"total_events": audit_count,
"chain_valid": chain_verification.valid,
},
"modules": {
"gateway": [
"sql_firewall",
"ssrf_guard",
"rate_governor",
"fingerprint",
"quarantine",
"email_guard",
"credential_vault",
"audit_chain",
"sanitizer",
"threat_score",
"siem_export",
"journal",
"sse_events",
"compliance_report",
],
"endpoint": [
"signatures",
"heuristics",
"yara_engine",
"watcher",
"process_monitor",
"network_monitor",
"memory_scanner",
"rootkit_detector",
"dns_filter",
"usb_monitor",
"fim",
"container_scanner",
"supply_chain",
"allowlist",
"threat_intel",
"file_quarantine",
]
}
});
(StatusCode::OK, axum::Json(status))
}
async fn audit_handler(shield: Arc<Shield>) -> impl IntoResponse {
let recent = shield.audit.recent(50);
let events: Vec<serde_json::Value> = recent.iter().map(|e| {
serde_json::json!({
"id": e.id,
"timestamp": e.timestamp.to_rfc3339(),
"event_type": format!("{:?}", e.event_type),
"source_ip": e.source_ip,
"details": e.details,
"threat_score": e.threat_score,
})
}).collect();
(StatusCode::OK, axum::Json(serde_json::json!({
"recent_events": events,
"total": shield.audit.len(),
"chain_valid": shield.audit.verify_chain().valid,
})))
}
async fn stats_handler(shield: Arc<Shield>) -> impl IntoResponse {
let now = chrono::Utc::now();
let last_hour = now - chrono::Duration::hours(1);
let last_5min = now - chrono::Duration::minutes(5);
let stats = serde_json::json!({
"last_5min": {
"blocked": shield.audit.count_since(&SecurityEventType::RequestBlocked, last_5min),
"rate_limited": shield.audit.count_since(&SecurityEventType::RateLimitHit, last_5min),
"sql_injection": shield.audit.count_since(&SecurityEventType::SqlInjectionAttempt, last_5min),
"ssrf": shield.audit.count_since(&SecurityEventType::SsrfAttempt, last_5min),
},
"last_hour": {
"blocked": shield.audit.count_since(&SecurityEventType::RequestBlocked, last_hour),
"rate_limited": shield.audit.count_since(&SecurityEventType::RateLimitHit, last_hour),
"sql_injection": shield.audit.count_since(&SecurityEventType::SqlInjectionAttempt, last_hour),
"ssrf": shield.audit.count_since(&SecurityEventType::SsrfAttempt, last_hour),
},
"total_audit_events": shield.audit.len(),
});
(StatusCode::OK, axum::Json(stats))
}
async fn endpoint_status_handler(engine: Arc<EndpointEngine>) -> impl IntoResponse {
let stats = engine.stats();
(StatusCode::OK, axum::Json(serde_json::json!({
"endpoint_protection": "active",
"total_files_scanned": stats.total_files_scanned,
"total_threats_detected": stats.total_threats_detected,
"active_monitors": stats.active_monitors,
"scanners_active": stats.scanners_active,
"quarantined_files": stats.quarantined_files,
"last_scan_time": stats.last_scan_time.map(|t| t.to_rfc3339()),
})))
}
async fn endpoint_detections_handler(engine: Arc<EndpointEngine>) -> impl IntoResponse {
let detections = engine.recent_detections(100);
let events: Vec<serde_json::Value> = detections.iter().map(|d| {
serde_json::json!({
"id": d.id,
"timestamp": d.timestamp.to_rfc3339(),
"scanner": d.scanner,
"target": d.target,
"severity": format!("{}", d.severity),
"description": d.description,
"confidence": d.confidence,
"action": format!("{}", d.action),
"artifact_hash": d.artifact_hash,
})
}).collect();
(StatusCode::OK, axum::Json(serde_json::json!({
"detections": events,
"total": detections.len(),
})))
}
async fn endpoint_quarantine_handler(engine: Arc<EndpointEngine>) -> impl IntoResponse {
let entries = engine.quarantine.list_entries();
let items: Vec<serde_json::Value> = entries.iter().map(|e| {
serde_json::json!({
"id": e.id,
"original_path": e.original_path.to_string_lossy(),
"sha256": e.sha256,
"detection_reason": e.detection_reason,
"scanner": e.scanner,
"severity": format!("{}", e.severity),
"quarantined_at": e.quarantined_at.to_rfc3339(),
"file_size": e.file_size,
})
}).collect();
(StatusCode::OK, axum::Json(serde_json::json!({
"quarantined_files": items,
"total": entries.len(),
"vault_size_bytes": engine.quarantine.vault_size(),
})))
}
async fn endpoint_scan_handler(engine: Arc<EndpointEngine>, body: String) -> impl IntoResponse {
let path = std::path::Path::new(body.trim());
if !path.exists() {
return (StatusCode::BAD_REQUEST, axum::Json(serde_json::json!({
"error": "Path does not exist",
"path": body.trim(),
})));
}
let results = if path.is_dir() {
engine.scan_dir(path).await
} else {
engine.scan_file(path).await
};
let detections: Vec<serde_json::Value> = results.iter().map(|r| {
serde_json::json!({
"scanner": r.scanner,
"target": r.target,
"severity": format!("{}", r.severity),
"description": r.description,
"confidence": r.confidence,
"artifact_hash": r.artifact_hash,
})
}).collect();
(StatusCode::OK, axum::Json(serde_json::json!({
"path": body.trim(),
"clean": results.is_empty(),
"threats_found": results.len(),
"detections": detections,
})))
}