1use std::sync::Arc;
4
5use axum::{
6 body::Body,
7 extract::{
8 ws::{Message, WebSocket, WebSocketUpgrade},
9 Path, State,
10 },
11 http::{Request, Response, StatusCode},
12 response::IntoResponse,
13 Json,
14};
15use futures_util::{SinkExt, StreamExt};
16use serde::{Deserialize, Serialize};
17use serde_json::json;
18use tracing::{debug, error, info};
19
20use mpl_core::envelope::MplEnvelope;
21use mpl_core::metrics::{TocMethod, TocResult};
22
23use crate::proxy::{AiAlpnClientHello, ProxyState};
24
25pub async fn health() -> impl IntoResponse {
27 Json(json!({
28 "status": "healthy",
29 "version": env!("CARGO_PKG_VERSION")
30 }))
31}
32
33pub async fn metrics(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
35 let metrics = &state.metrics;
36 let schema_pass_rate = metrics.schema_pass_rate();
37 let qom_pass_rate = metrics.qom_pass_rate();
38 let downgrade_rate = metrics.downgrade_rate();
39
40 let output = format!(
41 r#"# HELP mpl_requests_total Total number of requests
42# TYPE mpl_requests_total counter
43mpl_requests_total {}
44
45# HELP mpl_schema_validations_total Schema validation results
46# TYPE mpl_schema_validations_total counter
47mpl_schema_validations_total{{result="pass"}} {}
48mpl_schema_validations_total{{result="fail"}} {}
49
50# HELP mpl_schema_pass_rate Schema validation pass rate
51# TYPE mpl_schema_pass_rate gauge
52mpl_schema_pass_rate {}
53
54# HELP mpl_qom_pass_rate QoM pass rate
55# TYPE mpl_qom_pass_rate gauge
56mpl_qom_pass_rate {}
57
58# HELP mpl_handshakes_total Total AI-ALPN handshakes
59# TYPE mpl_handshakes_total counter
60mpl_handshakes_total {}
61
62# HELP mpl_downgrade_rate Protocol downgrade rate
63# TYPE mpl_downgrade_rate gauge
64mpl_downgrade_rate {}
65"#,
66 metrics.requests_total.load(std::sync::atomic::Ordering::Relaxed),
67 metrics.schema_pass.load(std::sync::atomic::Ordering::Relaxed),
68 metrics.schema_fail.load(std::sync::atomic::Ordering::Relaxed),
69 schema_pass_rate,
70 qom_pass_rate,
71 metrics.handshakes.load(std::sync::atomic::Ordering::Relaxed),
72 downgrade_rate,
73 );
74
75 (
76 StatusCode::OK,
77 [("content-type", "text/plain; charset=utf-8")],
78 output,
79 )
80}
81
82pub async fn ai_alpn_handshake(
84 State(state): State<Arc<ProxyState>>,
85 Json(hello): Json<AiAlpnClientHello>,
86) -> impl IntoResponse {
87 info!("AI-ALPN handshake from client with {} STypes", hello.stypes.len());
88
89 let response = state.handle_handshake(hello);
90
91 info!(
92 "Negotiated {} common STypes, profile: {:?}",
93 response.common_stypes.len(),
94 response.selected_profile
95 );
96
97 Json(response)
98}
99
100pub async fn websocket_handler(
102 ws: WebSocketUpgrade,
103 State(state): State<Arc<ProxyState>>,
104) -> impl IntoResponse {
105 info!("WebSocket upgrade requested");
106 ws.on_upgrade(move |socket| handle_websocket(socket, state))
107}
108
109async fn handle_websocket(socket: WebSocket, state: Arc<ProxyState>) {
111 let (mut sender, mut receiver) = socket.split();
112
113 info!("WebSocket connection established");
114
115 while let Some(msg) = receiver.next().await {
116 match msg {
117 Ok(Message::Text(text)) => {
118 debug!("Received WebSocket message: {} bytes", text.len());
119
120 let response = match serde_json::from_str::<MplEnvelope>(&text) {
122 Ok(envelope) => {
123 let validation = state.validate_request(&envelope).await;
125
126 if !validation.valid && state.is_strict() {
127 json!({
129 "error": "E-SCHEMA-FIDELITY",
130 "message": "Validation failed",
131 "details": validation.errors,
132 })
133 } else {
134 json!({
137 "type": "mpl-response",
138 "stype": envelope.stype,
139 "validation": {
140 "valid": validation.valid,
141 "schema_valid": validation.schema_valid,
142 "qom_passed": validation.qom_passed,
143 },
144 "payload": envelope.payload,
145 })
146 }
147 }
148 Err(_) => {
149 if let Ok(hello) = serde_json::from_str::<AiAlpnClientHello>(&text) {
151 let select = state.handle_handshake(hello);
152 serde_json::to_value(&select).unwrap_or_else(|_| json!({"error": "serialization failed"}))
153 } else {
154 json!({
156 "type": "passthrough",
157 "message": text,
158 })
159 }
160 }
161 };
162
163 if let Err(e) = sender.send(Message::Text(response.to_string())).await {
164 error!("Failed to send WebSocket message: {}", e);
165 break;
166 }
167 }
168 Ok(Message::Binary(data)) => {
169 debug!("Received binary WebSocket message: {} bytes", data.len());
170 if let Err(e) = sender.send(Message::Binary(data)).await {
172 error!("Failed to send binary WebSocket message: {}", e);
173 break;
174 }
175 }
176 Ok(Message::Ping(data)) => {
177 if let Err(e) = sender.send(Message::Pong(data)).await {
178 error!("Failed to send pong: {}", e);
179 break;
180 }
181 }
182 Ok(Message::Pong(_)) => {}
183 Ok(Message::Close(_)) => {
184 info!("WebSocket connection closed by client");
185 break;
186 }
187 Err(e) => {
188 error!("WebSocket error: {}", e);
189 break;
190 }
191 }
192 }
193
194 info!("WebSocket connection ended");
195}
196
197pub async fn proxy_handler(
199 State(state): State<Arc<ProxyState>>,
200 Path(path): Path<String>,
201 request: Request<Body>,
202) -> impl IntoResponse {
203 debug!("Proxying request to: {}", path);
204
205 match state.forward_request(path, request).await {
206 Ok(response) => response,
207 Err(e) => {
208 error!("Proxy error: {}", e);
209 Response::builder()
210 .status(StatusCode::BAD_GATEWAY)
211 .header("content-type", "application/json")
212 .body(Body::from(
213 json!({
214 "error": "E-PROXY-ERROR",
215 "message": format!("Proxy error: {}", e),
216 })
217 .to_string(),
218 ))
219 .unwrap()
220 }
221 }
222}
223
224pub async fn capabilities(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
226 let stypes = state.validator.registered_stypes();
227 let profiles: Vec<&str> = state.profiles.iter().map(|p| p.name.as_str()).collect();
228
229 Json(json!({
230 "version": env!("CARGO_PKG_VERSION"),
231 "mpl_version": "1.0",
232 "capabilities": {
233 "schema_validation": state.config.mpl.enforce_schema,
234 "qom_evaluation": true,
235 "semantic_hashing": true,
236 "websocket": true,
237 "toc_callback": true,
238 },
239 "stypes": stypes,
240 "profiles": profiles,
241 "mode": format!("{:?}", state.config.mpl.mode),
242 }))
243}
244
245#[derive(Debug, Clone, Serialize, Deserialize)]
247pub struct TocCallbackRequest {
248 pub callback_id: String,
250 pub verified: bool,
252 #[serde(skip_serializing_if = "Option::is_none")]
254 pub details: Option<String>,
255 #[serde(skip_serializing_if = "Option::is_none")]
257 pub expected: Option<String>,
258 #[serde(skip_serializing_if = "Option::is_none")]
260 pub actual: Option<String>,
261}
262
263#[derive(Debug, Clone, Serialize, Deserialize)]
265pub struct TocCallbackResponse {
266 pub accepted: bool,
268 pub callback_id: String,
270 pub message: String,
272}
273
274pub async fn toc_callback(
279 State(state): State<Arc<ProxyState>>,
280 Json(request): Json<TocCallbackRequest>,
281) -> impl IntoResponse {
282 info!(
283 "TOC callback received: {} verified={}",
284 request.callback_id, request.verified
285 );
286
287 let result = if request.verified {
289 let mut r = TocResult::verified(TocMethod::Callback);
290 r.details = request.details.clone();
291 r.expected = request.expected;
292 r.actual = request.actual;
293 r
294 } else {
295 let mut r = TocResult::failed(
296 TocMethod::Callback,
297 request.details.clone().unwrap_or_else(|| "Verification failed".to_string()),
298 );
299 r.expected = request.expected;
300 r.actual = request.actual;
301 r
302 };
303
304 let was_pending = state.complete_toc(&request.callback_id, result);
306
307 let response = if was_pending {
308 TocCallbackResponse {
309 accepted: true,
310 callback_id: request.callback_id,
311 message: "TOC verification recorded".to_string(),
312 }
313 } else {
314 TocCallbackResponse {
315 accepted: false,
316 callback_id: request.callback_id,
317 message: "Unknown or expired callback ID".to_string(),
318 }
319 };
320
321 Json(response)
322}
323
324pub async fn toc_status(
328 State(state): State<Arc<ProxyState>>,
329 Path(callback_id): Path<String>,
330) -> impl IntoResponse {
331 if let Some(result) = state.get_toc_result(&callback_id) {
333 return Json(json!({
334 "callback_id": callback_id,
335 "status": "completed",
336 "result": result,
337 }));
338 }
339
340 if let Some(pending) = state.get_pending_toc(&callback_id) {
342 return Json(json!({
343 "callback_id": callback_id,
344 "status": "pending",
345 "stype": pending.stype,
346 "registered_at": pending.timestamp,
347 }));
348 }
349
350 Json(json!({
352 "callback_id": callback_id,
353 "status": "unknown",
354 "message": "No verification found for this callback ID",
355 }))
356}
357
358pub async fn toc_pending_list(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
362 let pending: Vec<_> = state
363 .pending_toc
364 .read()
365 .map(|p| p.values().cloned().collect())
366 .unwrap_or_default();
367
368 Json(json!({
369 "pending_count": pending.len(),
370 "verifications": pending,
371 }))
372}
373
374#[derive(Debug, Deserialize, Default)]
378pub struct QomEventsQuery {
379 pub limit: Option<usize>,
381}
382
383#[derive(Debug, Deserialize, Default)]
385pub struct QomHistoryQuery {
386 pub period: Option<String>,
388}
389
390pub async fn qom_summary(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
394 let summary = state.qom_recorder.get_summary().await;
395
396 Json(json!({
397 "metrics": {
398 "schema_fidelity": summary.schema_fidelity,
399 "instruction_compliance": summary.instruction_compliance,
400 "tool_outcome_correctness": summary.tool_outcome_correctness,
401 "groundedness": summary.groundedness,
402 "determinism_jitter": summary.determinism_jitter,
403 "ontology_adherence": summary.ontology_adherence,
404 }
405 }))
406}
407
408pub async fn qom_events(
412 State(state): State<Arc<ProxyState>>,
413 axum::extract::Query(query): axum::extract::Query<QomEventsQuery>,
414) -> impl IntoResponse {
415 let limit = query.limit.unwrap_or(50);
416 let events = state.qom_recorder.get_events(limit).await;
417
418 let events_json: Vec<serde_json::Value> = events
420 .iter()
421 .map(|e| {
422 json!({
423 "id": e.id,
424 "stype": e.stype,
425 "profile": e.profile,
426 "passed": e.passed,
427 "scores": {
428 "SF": e.scores.sf,
429 "IC": e.scores.ic,
430 "TOC": e.scores.toc,
431 "G": e.scores.g,
432 "DJ": e.scores.dj,
433 "OA": e.scores.oa,
434 },
435 "failure_reason": e.failure_reason,
436 "timestamp": e.timestamp.to_rfc3339(),
437 })
438 })
439 .collect();
440
441 Json(json!({
442 "events": events_json,
443 "total": events_json.len(),
444 }))
445}
446
447pub async fn qom_history(
451 State(state): State<Arc<ProxyState>>,
452 axum::extract::Query(query): axum::extract::Query<QomHistoryQuery>,
453) -> impl IntoResponse {
454 let period = query.period.unwrap_or_else(|| "24h".to_string());
455 let history = state.qom_recorder.get_history(&period).await;
456
457 let history_json: Vec<serde_json::Value> = history
459 .iter()
460 .map(|h| {
461 json!({
462 "timestamp": h.timestamp.to_rfc3339(),
463 "count": h.count,
464 "sf": h.sf,
465 "ic": h.ic,
466 "toc": h.toc,
467 "g": h.g,
468 "dj": h.dj,
469 "oa": h.oa,
470 "pass_rate": h.pass_rate,
471 })
472 })
473 .collect();
474
475 Json(json!({
476 "history": history_json,
477 "period": period,
478 }))
479}
480
481pub async fn qom_persist(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
485 state.qom_recorder.persist_history().await;
486
487 Json(json!({
488 "status": "ok",
489 "message": "QoM history persisted to disk",
490 }))
491}
492
493pub async fn learning_stats(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
497 let enabled = state.traffic_recorder.is_enabled();
498 let stats = state.traffic_recorder.get_stats();
499
500 let total_samples: usize = stats.values().sum();
501 let stype_count = stats.len();
502
503 let mut stypes_sorted: Vec<_> = stats.into_iter().collect();
505 stypes_sorted.sort_by(|a, b| b.1.cmp(&a.1));
506 let top_stypes: Vec<_> = stypes_sorted.into_iter().take(20).collect();
507
508 Json(json!({
509 "enabled": enabled,
510 "total_samples": total_samples,
511 "stype_count": stype_count,
512 "top_stypes": top_stypes.iter().map(|(stype, count)| {
513 json!({
514 "stype": stype,
515 "samples": count
516 })
517 }).collect::<Vec<_>>()
518 }))
519}
520
521pub async fn learning_samples(
525 State(state): State<Arc<ProxyState>>,
526 axum::extract::Path(stype): axum::extract::Path<String>,
527 axum::extract::Query(query): axum::extract::Query<LearningQuery>,
528) -> impl IntoResponse {
529 let samples = state.traffic_recorder.get_samples(&stype);
530 let limit = query.limit.unwrap_or(50);
531
532 let samples_json: Vec<serde_json::Value> = samples
533 .iter()
534 .rev()
535 .take(limit)
536 .map(|s| {
537 json!({
538 "id": s.id,
539 "timestamp": s.timestamp,
540 "method": s.method,
541 "path": s.path,
542 "payload": s.payload,
543 "response": s.response,
544 "status_code": s.status_code,
545 "duration_ms": s.duration_ms,
546 "validation_passed": s.validation_passed,
547 })
548 })
549 .collect();
550
551 Json(json!({
552 "stype": stype,
553 "samples": samples_json,
554 "total": samples.len(),
555 }))
556}
557
558#[derive(Debug, Deserialize)]
559pub struct LearningQuery {
560 pub limit: Option<usize>,
561}