Skip to main content

aegis_server/
shield_handlers.rs

1//! Aegis Shield API Handlers
2//!
3//! REST API handlers for the integrated security shield.
4
5use crate::state::AppState;
6use axum::{
7    extract::{Path, Query, State},
8    http::StatusCode,
9    response::IntoResponse,
10    Json,
11};
12use serde::Deserialize;
13
14// =============================================================================
15// Request Types
16// =============================================================================
17
18#[derive(Deserialize, Default)]
19pub struct EventsQuery {
20    pub limit: Option<usize>,
21    pub level: Option<String>,
22}
23
24#[derive(Deserialize)]
25pub struct BlockIpRequest {
26    pub ip: String,
27    pub reason: Option<String>,
28    pub duration_secs: Option<u64>,
29}
30
31#[derive(Deserialize)]
32pub struct AllowlistRequest {
33    pub ip: String,
34}
35
36#[derive(Deserialize)]
37pub struct UpdatePolicyRequest {
38    pub preset: Option<String>,
39    pub sql_injection_enabled: Option<bool>,
40    pub anomaly_detection_enabled: Option<bool>,
41    pub ip_reputation_enabled: Option<bool>,
42    pub fingerprinting_enabled: Option<bool>,
43    pub auto_blocking_enabled: Option<bool>,
44    pub auto_block_threshold: Option<u32>,
45}
46
47// =============================================================================
48// Handlers
49// =============================================================================
50
51/// GET /api/v1/shield/status
52pub async fn shield_status(State(state): State<AppState>) -> impl IntoResponse {
53    let status = state.shield.get_status();
54    Json(serde_json::json!({
55        "enabled": status.enabled,
56        "preset": format!("{:?}", status.preset),
57        "uptime_secs": status.uptime_secs,
58        "total_requests_analyzed": status.total_requests_analyzed,
59        "total_threats_detected": status.total_threats_detected,
60        "active_bans": status.active_bans,
61        "blocked_ips": status.blocked_ips,
62    }))
63}
64
65/// GET /api/v1/shield/stats
66pub async fn shield_stats(State(state): State<AppState>) -> impl IntoResponse {
67    let stats = state.shield.get_stats();
68    Json(serde_json::to_value(&stats).unwrap_or_default())
69}
70
71/// GET /api/v1/shield/events
72pub async fn shield_events(
73    State(state): State<AppState>,
74    Query(params): Query<EventsQuery>,
75) -> impl IntoResponse {
76    let limit = params.limit.unwrap_or(50);
77    let events = state.shield.get_recent_events(limit);
78    Json(serde_json::json!({"events": events, "count": events.len()}))
79}
80
81/// GET /api/v1/shield/blocked
82pub async fn list_blocked(State(state): State<AppState>) -> impl IntoResponse {
83    let blocked = state.shield.get_blocked_ips();
84    Json(serde_json::json!({"blocked": blocked, "count": blocked.len()}))
85}
86
87/// POST /api/v1/shield/blocked
88pub async fn block_ip(
89    State(state): State<AppState>,
90    Json(req): Json<BlockIpRequest>,
91) -> impl IntoResponse {
92    let reason = req.reason.unwrap_or_else(|| "Manual block".to_string());
93    let duration = req.duration_secs.unwrap_or(3600);
94    state.shield.manual_block(&req.ip, &reason, duration);
95    (
96        StatusCode::CREATED,
97        Json(serde_json::json!({
98            "status": "blocked",
99            "ip": req.ip,
100            "duration_secs": duration,
101        })),
102    )
103}
104
105/// DELETE /api/v1/shield/blocked/:ip
106pub async fn unblock_ip(
107    State(state): State<AppState>,
108    Path(ip): Path<String>,
109) -> impl IntoResponse {
110    if state.shield.unblock_ip(&ip) {
111        (
112            StatusCode::OK,
113            Json(serde_json::json!({"status": "unblocked", "ip": ip})),
114        )
115    } else {
116        (
117            StatusCode::NOT_FOUND,
118            Json(serde_json::json!({"error": "IP not found in block list"})),
119        )
120    }
121}
122
123/// GET /api/v1/shield/allowlist
124pub async fn get_allowlist(State(state): State<AppState>) -> impl IntoResponse {
125    let list = state.shield.get_allowlist();
126    Json(serde_json::json!({"allowlist": list}))
127}
128
129/// POST /api/v1/shield/allowlist
130pub async fn add_to_allowlist(
131    State(state): State<AppState>,
132    Json(req): Json<AllowlistRequest>,
133) -> impl IntoResponse {
134    state.shield.add_to_allowlist(&req.ip);
135    (
136        StatusCode::CREATED,
137        Json(serde_json::json!({"status": "added", "ip": req.ip})),
138    )
139}
140
141/// DELETE /api/v1/shield/allowlist/:ip
142pub async fn remove_from_allowlist(
143    State(state): State<AppState>,
144    Path(ip): Path<String>,
145) -> impl IntoResponse {
146    state.shield.remove_from_allowlist(&ip);
147    (
148        StatusCode::OK,
149        Json(serde_json::json!({"status": "removed", "ip": ip})),
150    )
151}
152
153/// GET /api/v1/shield/policy
154pub async fn get_policy(State(state): State<AppState>) -> impl IntoResponse {
155    let policy = state.shield.get_policy();
156    Json(serde_json::to_value(&policy).unwrap_or_default())
157}
158
159/// PUT /api/v1/shield/policy
160pub async fn update_policy(
161    State(state): State<AppState>,
162    Json(req): Json<UpdatePolicyRequest>,
163) -> impl IntoResponse {
164    let mut policy = state.shield.get_policy();
165
166    if let Some(preset_str) = &req.preset {
167        match preset_str.to_lowercase().as_str() {
168            "strict" => {
169                policy =
170                    aegis_shield::SecurityPolicy::from_preset(aegis_shield::SecurityPreset::Strict)
171            }
172            "moderate" => {
173                policy = aegis_shield::SecurityPolicy::from_preset(
174                    aegis_shield::SecurityPreset::Moderate,
175                )
176            }
177            "permissive" => {
178                policy = aegis_shield::SecurityPolicy::from_preset(
179                    aegis_shield::SecurityPreset::Permissive,
180                )
181            }
182            _ => {
183                return (
184                    StatusCode::BAD_REQUEST,
185                    Json(serde_json::json!({
186                        "error": "Invalid preset. Use: strict, moderate, permissive"
187                    })),
188                );
189            }
190        }
191    }
192
193    if let Some(v) = req.sql_injection_enabled {
194        policy.sql_injection_enabled = v;
195    }
196    if let Some(v) = req.anomaly_detection_enabled {
197        policy.anomaly_detection_enabled = v;
198    }
199    if let Some(v) = req.ip_reputation_enabled {
200        policy.ip_reputation_enabled = v;
201    }
202    if let Some(v) = req.fingerprinting_enabled {
203        policy.fingerprinting_enabled = v;
204    }
205    if let Some(v) = req.auto_blocking_enabled {
206        policy.auto_blocking_enabled = v;
207    }
208    state.shield.update_policy(policy);
209    (
210        StatusCode::OK,
211        Json(serde_json::json!({"status": "updated"})),
212    )
213}
214
215/// GET /api/v1/shield/ip/:ip
216pub async fn get_ip_reputation(
217    State(state): State<AppState>,
218    Path(ip): Path<String>,
219) -> impl IntoResponse {
220    match state.shield.get_ip_reputation(&ip) {
221        Some(rep) => (
222            StatusCode::OK,
223            Json(serde_json::to_value(&rep).unwrap_or_default()),
224        ),
225        None => (
226            StatusCode::NOT_FOUND,
227            Json(serde_json::json!({"error": "No data for this IP"})),
228        ),
229    }
230}
231
232/// GET /api/v1/shield/feed
233pub async fn shield_feed(State(state): State<AppState>) -> impl IntoResponse {
234    let stats = state.shield.get_stats();
235    let recent = state.shield.get_recent_events(10);
236    Json(serde_json::json!({
237        "stats": stats,
238        "recent_events": recent,
239    }))
240}