1use std::sync::Arc;
4
5use axum::{
6 extract::{Json, Path, State},
7 http::StatusCode,
8 response::IntoResponse,
9 routing::{get, post},
10 Router,
11};
12use serde::{Deserialize, Serialize};
13
14use super::state::AppState;
15use crate::codec::Algorithm;
16use crate::protocol::{Capabilities, Message, MessageType};
17
18pub fn create_router(state: Arc<AppState>) -> Router {
20 Router::new()
21 .route("/health", get(health_check))
23 .route("/status", get(status))
24 .route("/session", post(create_session))
26 .route("/session/{id}", get(get_session))
27 .route("/session/{id}", axum::routing::delete(delete_session))
28 .route("/compress", post(compress))
30 .route("/decompress", post(decompress))
31 .route("/compress/auto", post(compress_auto))
32 .route("/scan", post(scan_content))
34 .route("/message", post(process_message))
36 .with_state(state)
37}
38
39#[derive(Serialize)]
41pub struct HealthResponse {
42 pub status: &'static str,
43 pub version: &'static str,
44}
45
46pub async fn health_check() -> impl IntoResponse {
48 Json(HealthResponse {
49 status: "ok",
50 version: env!("CARGO_PKG_VERSION"),
51 })
52}
53
54#[derive(Serialize)]
56pub struct StatusResponse {
57 pub status: &'static str,
58 pub version: &'static str,
59 pub uptime_secs: u64,
60 pub active_sessions: usize,
61 pub capabilities: Capabilities,
62}
63
64async fn status(State(state): State<Arc<AppState>>) -> impl IntoResponse {
66 let session_count = state.sessions.count().await;
67
68 Json(StatusResponse {
69 status: "ok",
70 version: env!("CARGO_PKG_VERSION"),
71 uptime_secs: state.uptime().as_secs(),
72 active_sessions: session_count,
73 capabilities: state.capabilities(),
74 })
75}
76
77#[derive(Deserialize)]
79pub struct CreateSessionRequest {
80 #[serde(default)]
81 pub capabilities: Option<Capabilities>,
82}
83
84#[derive(Serialize)]
86pub struct SessionResponse {
87 pub session_id: String,
88 pub capabilities: Capabilities,
89}
90
91async fn create_session(
93 State(state): State<Arc<AppState>>,
94 Json(req): Json<CreateSessionRequest>,
95) -> impl IntoResponse {
96 let client_caps = req.capabilities.unwrap_or_default();
97 let mut session = state.sessions.create(client_caps).await;
98
99 let hello = session.create_hello();
101 let _ = session.process_message(&hello);
102
103 let response = SessionResponse {
104 session_id: session.id().to_string(),
105 capabilities: state.capabilities(),
106 };
107
108 state.sessions.update(&session).await;
109 (StatusCode::CREATED, Json(response))
110}
111
112async fn get_session(
114 State(state): State<Arc<AppState>>,
115 Path(id): Path<String>,
116) -> impl IntoResponse {
117 match state.sessions.get(&id).await {
118 Some(session) => {
119 let stats = session.stats();
120 (
121 StatusCode::OK,
122 Json(serde_json::json!({
123 "session_id": stats.session_id,
124 "state": format!("{:?}", stats.state),
125 "messages_sent": stats.messages_sent,
126 "messages_received": stats.messages_received,
127 "bytes_compressed": stats.bytes_compressed,
128 "bytes_saved": stats.bytes_saved,
129 "compression_ratio": stats.compression_ratio(),
130 })),
131 )
132 },
133 None => (
134 StatusCode::NOT_FOUND,
135 Json(serde_json::json!({"error": "Session not found"})),
136 ),
137 }
138}
139
140async fn delete_session(
142 State(state): State<Arc<AppState>>,
143 Path(id): Path<String>,
144) -> impl IntoResponse {
145 state.sessions.remove(&id).await;
146 StatusCode::NO_CONTENT
147}
148
149#[derive(Deserialize)]
151pub struct CompressRequest {
152 pub content: String,
153 #[serde(default)]
154 pub algorithm: Option<Algorithm>,
155}
156
157#[derive(Serialize)]
159#[allow(dead_code)]
160pub struct CompressResponse {
161 pub data: String,
162 pub algorithm: Algorithm,
163 pub original_bytes: usize,
164 pub compressed_bytes: usize,
165 pub ratio: f64,
166}
167
168async fn compress(
170 State(state): State<Arc<AppState>>,
171 Json(req): Json<CompressRequest>,
172) -> impl IntoResponse {
173 if state.config.security_enabled {
175 let scan_result = state.scanner.scan(&req.content);
176 if let Ok(result) = scan_result {
177 if result.should_block {
178 return (
179 StatusCode::FORBIDDEN,
180 Json(serde_json::json!({
181 "error": "Content blocked by security scan",
182 "threats": result.threats.iter().map(|t| &t.name).collect::<Vec<_>>(),
183 })),
184 );
185 }
186 }
187 }
188
189 let algorithm = req.algorithm.unwrap_or(Algorithm::M2M);
190
191 match state.codec.compress(&req.content, algorithm) {
192 Ok(result) => (
193 StatusCode::OK,
194 Json(serde_json::json!({
195 "data": result.data,
196 "algorithm": result.algorithm,
197 "original_bytes": result.original_bytes,
198 "compressed_bytes": result.compressed_bytes,
199 "ratio": result.byte_ratio(),
200 })),
201 ),
202 Err(e) => (
203 StatusCode::BAD_REQUEST,
204 Json(serde_json::json!({"error": e.to_string()})),
205 ),
206 }
207}
208
209async fn compress_auto(
211 State(state): State<Arc<AppState>>,
212 Json(req): Json<CompressRequest>,
213) -> impl IntoResponse {
214 if state.config.security_enabled {
216 if let Ok(result) = state.scanner.scan(&req.content) {
217 if result.should_block {
218 return (
219 StatusCode::FORBIDDEN,
220 Json(serde_json::json!({
221 "error": "Content blocked by security scan",
222 })),
223 );
224 }
225 }
226 }
227
228 match state.codec.compress_auto(&req.content) {
229 Ok((result, _)) => (
230 StatusCode::OK,
231 Json(serde_json::json!({
232 "data": result.data,
233 "algorithm": result.algorithm,
234 "original_bytes": result.original_bytes,
235 "compressed_bytes": result.compressed_bytes,
236 "ratio": result.byte_ratio(),
237 })),
238 ),
239 Err(e) => (
240 StatusCode::BAD_REQUEST,
241 Json(serde_json::json!({"error": e.to_string()})),
242 ),
243 }
244}
245
246#[derive(Deserialize)]
248pub struct DecompressRequest {
249 pub data: String,
250}
251
252async fn decompress(
254 State(state): State<Arc<AppState>>,
255 Json(req): Json<DecompressRequest>,
256) -> impl IntoResponse {
257 match state.codec.decompress(&req.data) {
258 Ok(content) => (
259 StatusCode::OK,
260 Json(serde_json::json!({
261 "content": content,
262 "bytes": content.len(),
263 })),
264 ),
265 Err(e) => (
266 StatusCode::BAD_REQUEST,
267 Json(serde_json::json!({"error": e.to_string()})),
268 ),
269 }
270}
271
272#[derive(Deserialize)]
274pub struct ScanRequest {
275 pub content: String,
276}
277
278async fn scan_content(
280 State(state): State<Arc<AppState>>,
281 Json(req): Json<ScanRequest>,
282) -> impl IntoResponse {
283 match state.scanner.scan(&req.content) {
284 Ok(result) => (
285 StatusCode::OK,
286 Json(serde_json::json!({
287 "safe": result.safe,
288 "confidence": result.confidence,
289 "threats": result.threats.iter().map(|t| serde_json::json!({
290 "name": t.name,
291 "category": t.category,
292 "severity": t.severity,
293 "description": t.description,
294 })).collect::<Vec<_>>(),
295 "should_block": result.should_block,
296 })),
297 ),
298 Err(e) => (
299 StatusCode::BAD_REQUEST,
300 Json(serde_json::json!({"error": e.to_string()})),
301 ),
302 }
303}
304
305async fn process_message(
307 State(state): State<Arc<AppState>>,
308 Json(message): Json<Message>,
309) -> impl IntoResponse {
310 match message.msg_type {
311 MessageType::Hello => {
312 let caps = message.get_capabilities().cloned().unwrap_or_default();
314 let mut session = state.sessions.create(caps).await;
315
316 match session.process_message(&message) {
317 Ok(Some(response)) => {
318 state.sessions.update(&session).await;
319 (StatusCode::OK, Json(response))
320 },
321 Ok(None) => (
322 StatusCode::OK,
323 Json(Message::accept(session.id(), state.capabilities())),
324 ),
325 Err(e) => (
326 StatusCode::BAD_REQUEST,
327 Json(Message::reject(
328 crate::protocol::RejectionCode::Unknown,
329 &e.to_string(),
330 )),
331 ),
332 }
333 },
334 MessageType::Data => {
335 let Some(session_id) = message.session_id.as_ref() else {
337 return (
338 StatusCode::BAD_REQUEST,
339 Json(Message::reject(
340 crate::protocol::RejectionCode::Unknown,
341 "Missing session ID",
342 )),
343 );
344 };
345
346 match state.sessions.get(session_id).await {
347 Some(mut session) => match session.decompress(&message) {
348 Ok(content) => {
349 state.sessions.update(&session).await;
350 (
351 StatusCode::OK,
352 Json(serde_json::from_str::<Message>(&format!(
353 r#"{{"type":"DATA","session_id":"{session_id}","payload":{{"content":"{content}"}}}}"#
354 )).unwrap_or(message)),
355 )
356 },
357 Err(e) => (
358 StatusCode::BAD_REQUEST,
359 Json(Message::reject(
360 crate::protocol::RejectionCode::Unknown,
361 &e.to_string(),
362 )),
363 ),
364 },
365 None => (
366 StatusCode::NOT_FOUND,
367 Json(Message::reject(
368 crate::protocol::RejectionCode::Unknown,
369 "Session not found",
370 )),
371 ),
372 }
373 },
374 MessageType::Ping => {
375 let session_id = message.session_id.as_deref().unwrap_or("unknown");
376 (StatusCode::OK, Json(Message::pong(session_id)))
377 },
378 MessageType::Close => {
379 if let Some(id) = &message.session_id {
380 state.sessions.remove(id).await;
381 }
382 (StatusCode::OK, Json(message))
383 },
384 _ => (
385 StatusCode::BAD_REQUEST,
386 Json(Message::reject(
387 crate::protocol::RejectionCode::Unknown,
388 "Unsupported message type",
389 )),
390 ),
391 }
392}