// auth_framework/api/security_simple.rs
use std::collections::HashSet;
use std::net::IpAddr;
use std::sync::{PoisonError, RwLock};

use axum::{
    extract::{Path, State},
    Form,
};
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use serde_json::json;

use crate::api::{ApiResponse, ApiState};
18
19lazy_static! {
21 static ref IP_BLACKLIST: RwLock<HashSet<IpAddr>> = RwLock::new(HashSet::new());
22 static ref SECURITY_STATS: RwLock<SecurityStats> = RwLock::new(SecurityStats::default());
23}
24
25#[derive(Debug, Default, Clone, Serialize)]
26struct SecurityStats {
27 blocked_requests: u64,
28 failed_auth_attempts: u64,
29 suspicious_activity: u64,
30 last_updated: Option<i64>,
31}
32
33#[derive(Debug, Deserialize)]
34pub struct BlacklistIpForm {
35 pub ip: String,
36 pub reason: Option<String>,
37}
38
39#[derive(Debug, Serialize)]
40pub struct SecurityStatsResponse {
41 pub blocked_requests: u64,
42 pub failed_auth_attempts: u64,
43 pub suspicious_activity: u64,
44 pub blacklisted_ips: usize,
45 pub last_updated: Option<i64>,
46}
47
48pub async fn blacklist_ip_endpoint(
51 State(_state): State<ApiState>,
52 Form(form): Form<BlacklistIpForm>,
53) -> ApiResponse<serde_json::Value> {
54 let ip: IpAddr = match form.ip.parse() {
55 Ok(ip) => ip,
56 Err(_) => return ApiResponse::error_typed("invalid_ip", "Invalid IP address format"),
57 };
58
59 {
60 let mut blacklist = IP_BLACKLIST.write().unwrap();
61 blacklist.insert(ip);
62 }
63
64 {
66 let mut stats = SECURITY_STATS.write().unwrap();
67 stats.blocked_requests += 1;
68 stats.last_updated = Some(chrono::Utc::now().timestamp());
69 }
70
71 let data = json!({
72 "ip": ip.to_string(),
73 "reason": form.reason.unwrap_or_else(|| "Manual blacklist".to_string())
74 });
75
76 ApiResponse::success_with_message(
77 data,
78 format!("IP {} added to blacklist", ip)
79 )
80}
81
82pub async fn unblock_ip_endpoint(
85 State(_state): State<ApiState>,
86 Path(ip_str): Path<String>,
87) -> ApiResponse<serde_json::Value> {
88 let ip: IpAddr = match ip_str.parse() {
89 Ok(ip) => ip,
90 Err(_) => return ApiResponse::error_typed("invalid_ip", "Invalid IP address format"),
91 };
92
93 let removed = {
94 let mut blacklist = IP_BLACKLIST.write().unwrap();
95 blacklist.remove(&ip)
96 };
97
98 if removed {
99 {
101 let mut stats = SECURITY_STATS.write().unwrap();
102 stats.last_updated = Some(chrono::Utc::now().timestamp());
103 }
104
105 let data = json!({
106 "ip": ip.to_string(),
107 "status": "unblocked"
108 });
109
110 ApiResponse::success_with_message(
111 data,
112 format!("IP {} removed from blacklist", ip)
113 )
114 } else {
115 let data = json!({
116 "ip": ip.to_string(),
117 "status": "not_found"
118 });
119
120 ApiResponse::success_with_message(
121 data,
122 format!("IP {} was not in blacklist", ip)
123 )
124 }
125}
126
127pub async fn stats_endpoint(
130 State(_state): State<ApiState>,
131) -> ApiResponse<SecurityStatsResponse> {
132 let stats = SECURITY_STATS.read().unwrap().clone();
133 let blacklist_size = IP_BLACKLIST.read().unwrap().len();
134
135 let response_data = SecurityStatsResponse {
136 blocked_requests: stats.blocked_requests,
137 failed_auth_attempts: stats.failed_auth_attempts,
138 suspicious_activity: stats.suspicious_activity,
139 blacklisted_ips: blacklist_size,
140 last_updated: stats.last_updated,
141 };
142
143 ApiResponse::success(response_data)
144}
145
146pub fn is_ip_blacklisted(ip: &IpAddr) -> bool {
148 IP_BLACKLIST.read().unwrap().contains(ip)
149}
150
151pub fn increment_failed_auth() {
153 let mut stats = SECURITY_STATS.write().unwrap();
154 stats.failed_auth_attempts += 1;
155 stats.last_updated = Some(chrono::Utc::now().timestamp());
156}
157
158pub fn increment_suspicious_activity() {
160 let mut stats = SECURITY_STATS.write().unwrap();
161 stats.suspicious_activity += 1;
162 stats.last_updated = Some(chrono::Utc::now().timestamp());
163}