mpl_proxy/
handlers.rs

1//! HTTP handlers for the proxy
2
3use std::sync::Arc;
4
5use axum::{
6    body::Body,
7    extract::{
8        ws::{Message, WebSocket, WebSocketUpgrade},
9        Path, State,
10    },
11    http::{Request, Response, StatusCode},
12    response::IntoResponse,
13    Json,
14};
15use futures_util::{SinkExt, StreamExt};
16use serde::{Deserialize, Serialize};
17use serde_json::json;
18use tracing::{debug, error, info};
19
20use mpl_core::envelope::MplEnvelope;
21use mpl_core::metrics::{TocMethod, TocResult};
22
23use crate::proxy::{AiAlpnClientHello, ProxyState};
24
25/// Health check endpoint
26pub async fn health() -> impl IntoResponse {
27    Json(json!({
28        "status": "healthy",
29        "version": env!("CARGO_PKG_VERSION")
30    }))
31}
32
33/// Prometheus metrics endpoint
34pub async fn metrics(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
35    let metrics = &state.metrics;
36    let schema_pass_rate = metrics.schema_pass_rate();
37    let qom_pass_rate = metrics.qom_pass_rate();
38    let downgrade_rate = metrics.downgrade_rate();
39
40    let output = format!(
41        r#"# HELP mpl_requests_total Total number of requests
42# TYPE mpl_requests_total counter
43mpl_requests_total {}
44
45# HELP mpl_schema_validations_total Schema validation results
46# TYPE mpl_schema_validations_total counter
47mpl_schema_validations_total{{result="pass"}} {}
48mpl_schema_validations_total{{result="fail"}} {}
49
50# HELP mpl_schema_pass_rate Schema validation pass rate
51# TYPE mpl_schema_pass_rate gauge
52mpl_schema_pass_rate {}
53
54# HELP mpl_qom_pass_rate QoM pass rate
55# TYPE mpl_qom_pass_rate gauge
56mpl_qom_pass_rate {}
57
58# HELP mpl_handshakes_total Total AI-ALPN handshakes
59# TYPE mpl_handshakes_total counter
60mpl_handshakes_total {}
61
62# HELP mpl_downgrade_rate Protocol downgrade rate
63# TYPE mpl_downgrade_rate gauge
64mpl_downgrade_rate {}
65"#,
66        metrics.requests_total.load(std::sync::atomic::Ordering::Relaxed),
67        metrics.schema_pass.load(std::sync::atomic::Ordering::Relaxed),
68        metrics.schema_fail.load(std::sync::atomic::Ordering::Relaxed),
69        schema_pass_rate,
70        qom_pass_rate,
71        metrics.handshakes.load(std::sync::atomic::Ordering::Relaxed),
72        downgrade_rate,
73    );
74
75    (
76        StatusCode::OK,
77        [("content-type", "text/plain; charset=utf-8")],
78        output,
79    )
80}
81
82/// AI-ALPN handshake endpoint
83pub async fn ai_alpn_handshake(
84    State(state): State<Arc<ProxyState>>,
85    Json(hello): Json<AiAlpnClientHello>,
86) -> impl IntoResponse {
87    info!("AI-ALPN handshake from client with {} STypes", hello.stypes.len());
88
89    let response = state.handle_handshake(hello);
90
91    info!(
92        "Negotiated {} common STypes, profile: {:?}",
93        response.common_stypes.len(),
94        response.selected_profile
95    );
96
97    Json(response)
98}
99
100/// WebSocket upgrade handler for MCP/A2A connections
101pub async fn websocket_handler(
102    ws: WebSocketUpgrade,
103    State(state): State<Arc<ProxyState>>,
104) -> impl IntoResponse {
105    info!("WebSocket upgrade requested");
106    ws.on_upgrade(move |socket| handle_websocket(socket, state))
107}
108
109/// Handle WebSocket connection
110async fn handle_websocket(socket: WebSocket, state: Arc<ProxyState>) {
111    let (mut sender, mut receiver) = socket.split();
112
113    info!("WebSocket connection established");
114
115    while let Some(msg) = receiver.next().await {
116        match msg {
117            Ok(Message::Text(text)) => {
118                debug!("Received WebSocket message: {} bytes", text.len());
119
120                // Try to parse as MPL envelope
121                let response = match serde_json::from_str::<MplEnvelope>(&text) {
122                    Ok(envelope) => {
123                        // Validate the envelope
124                        let validation = state.validate_request(&envelope).await;
125
126                        if !validation.valid && state.is_strict() {
127                            // Return error response
128                            json!({
129                                "error": "E-SCHEMA-FIDELITY",
130                                "message": "Validation failed",
131                                "details": validation.errors,
132                            })
133                        } else {
134                            // Forward to upstream (simplified - in real impl, maintain upstream WS connection)
135                            // For now, echo back with validation result
136                            json!({
137                                "type": "mpl-response",
138                                "stype": envelope.stype,
139                                "validation": {
140                                    "valid": validation.valid,
141                                    "schema_valid": validation.schema_valid,
142                                    "qom_passed": validation.qom_passed,
143                                },
144                                "payload": envelope.payload,
145                            })
146                        }
147                    }
148                    Err(_) => {
149                        // Try to parse as AI-ALPN handshake
150                        if let Ok(hello) = serde_json::from_str::<AiAlpnClientHello>(&text) {
151                            let select = state.handle_handshake(hello);
152                            serde_json::to_value(&select).unwrap_or_else(|_| json!({"error": "serialization failed"}))
153                        } else {
154                            // Pass through non-MPL messages
155                            json!({
156                                "type": "passthrough",
157                                "message": text,
158                            })
159                        }
160                    }
161                };
162
163                if let Err(e) = sender.send(Message::Text(response.to_string())).await {
164                    error!("Failed to send WebSocket message: {}", e);
165                    break;
166                }
167            }
168            Ok(Message::Binary(data)) => {
169                debug!("Received binary WebSocket message: {} bytes", data.len());
170                // Pass through binary messages
171                if let Err(e) = sender.send(Message::Binary(data)).await {
172                    error!("Failed to send binary WebSocket message: {}", e);
173                    break;
174                }
175            }
176            Ok(Message::Ping(data)) => {
177                if let Err(e) = sender.send(Message::Pong(data)).await {
178                    error!("Failed to send pong: {}", e);
179                    break;
180                }
181            }
182            Ok(Message::Pong(_)) => {}
183            Ok(Message::Close(_)) => {
184                info!("WebSocket connection closed by client");
185                break;
186            }
187            Err(e) => {
188                error!("WebSocket error: {}", e);
189                break;
190            }
191        }
192    }
193
194    info!("WebSocket connection ended");
195}
196
197/// Main proxy handler - forwards requests to upstream
198pub async fn proxy_handler(
199    State(state): State<Arc<ProxyState>>,
200    Path(path): Path<String>,
201    request: Request<Body>,
202) -> impl IntoResponse {
203    debug!("Proxying request to: {}", path);
204
205    match state.forward_request(path, request).await {
206        Ok(response) => response,
207        Err(e) => {
208            error!("Proxy error: {}", e);
209            Response::builder()
210                .status(StatusCode::BAD_GATEWAY)
211                .header("content-type", "application/json")
212                .body(Body::from(
213                    json!({
214                        "error": "E-PROXY-ERROR",
215                        "message": format!("Proxy error: {}", e),
216                    })
217                    .to_string(),
218                ))
219                .unwrap()
220        }
221    }
222}
223
224/// Server capabilities endpoint
225pub async fn capabilities(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
226    let stypes = state.validator.registered_stypes();
227    let profiles: Vec<&str> = state.profiles.iter().map(|p| p.name.as_str()).collect();
228
229    Json(json!({
230        "version": env!("CARGO_PKG_VERSION"),
231        "mpl_version": "1.0",
232        "capabilities": {
233            "schema_validation": state.config.mpl.enforce_schema,
234            "qom_evaluation": true,
235            "semantic_hashing": true,
236            "websocket": true,
237            "toc_callback": true,
238        },
239        "stypes": stypes,
240        "profiles": profiles,
241        "mode": format!("{:?}", state.config.mpl.mode),
242    }))
243}
244
245/// TOC callback request body
246#[derive(Debug, Clone, Serialize, Deserialize)]
247pub struct TocCallbackRequest {
248    /// The callback ID from the original request
249    pub callback_id: String,
250    /// Whether the tool outcome was verified
251    pub verified: bool,
252    /// Optional details about the verification
253    #[serde(skip_serializing_if = "Option::is_none")]
254    pub details: Option<String>,
255    /// Expected outcome (for audit)
256    #[serde(skip_serializing_if = "Option::is_none")]
257    pub expected: Option<String>,
258    /// Actual outcome observed
259    #[serde(skip_serializing_if = "Option::is_none")]
260    pub actual: Option<String>,
261}
262
263/// TOC callback response
264#[derive(Debug, Clone, Serialize, Deserialize)]
265pub struct TocCallbackResponse {
266    /// Whether the callback was accepted
267    pub accepted: bool,
268    /// The callback ID
269    pub callback_id: String,
270    /// Status message
271    pub message: String,
272}
273
274/// TOC callback endpoint - receives verification results from external systems
275///
276/// POST /_mpl/toc/callback
277/// Body: { "callback_id": "...", "verified": true/false, "details": "..." }
278pub async fn toc_callback(
279    State(state): State<Arc<ProxyState>>,
280    Json(request): Json<TocCallbackRequest>,
281) -> impl IntoResponse {
282    info!(
283        "TOC callback received: {} verified={}",
284        request.callback_id, request.verified
285    );
286
287    // Build the TOC result
288    let result = if request.verified {
289        let mut r = TocResult::verified(TocMethod::Callback);
290        r.details = request.details.clone();
291        r.expected = request.expected;
292        r.actual = request.actual;
293        r
294    } else {
295        let mut r = TocResult::failed(
296            TocMethod::Callback,
297            request.details.clone().unwrap_or_else(|| "Verification failed".to_string()),
298        );
299        r.expected = request.expected;
300        r.actual = request.actual;
301        r
302    };
303
304    // Complete the verification
305    let was_pending = state.complete_toc(&request.callback_id, result);
306
307    let response = if was_pending {
308        TocCallbackResponse {
309            accepted: true,
310            callback_id: request.callback_id,
311            message: "TOC verification recorded".to_string(),
312        }
313    } else {
314        TocCallbackResponse {
315            accepted: false,
316            callback_id: request.callback_id,
317            message: "Unknown or expired callback ID".to_string(),
318        }
319    };
320
321    Json(response)
322}
323
324/// Query TOC status for a callback ID
325///
326/// GET /_mpl/toc/status/{callback_id}
327pub async fn toc_status(
328    State(state): State<Arc<ProxyState>>,
329    Path(callback_id): Path<String>,
330) -> impl IntoResponse {
331    // Check if completed
332    if let Some(result) = state.get_toc_result(&callback_id) {
333        return Json(json!({
334            "callback_id": callback_id,
335            "status": "completed",
336            "result": result,
337        }));
338    }
339
340    // Check if pending
341    if let Some(pending) = state.get_pending_toc(&callback_id) {
342        return Json(json!({
343            "callback_id": callback_id,
344            "status": "pending",
345            "stype": pending.stype,
346            "registered_at": pending.timestamp,
347        }));
348    }
349
350    // Unknown
351    Json(json!({
352        "callback_id": callback_id,
353        "status": "unknown",
354        "message": "No verification found for this callback ID",
355    }))
356}
357
358/// List all pending TOC verifications
359///
360/// GET /_mpl/toc/pending
361pub async fn toc_pending_list(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
362    let pending: Vec<_> = state
363        .pending_toc
364        .read()
365        .map(|p| p.values().cloned().collect())
366        .unwrap_or_default();
367
368    Json(json!({
369        "pending_count": pending.len(),
370        "verifications": pending,
371    }))
372}
373
374// ============ QoM API Endpoints ============
375
376/// Query parameters for QoM events
377#[derive(Debug, Deserialize, Default)]
378pub struct QomEventsQuery {
379    /// Maximum number of events to return
380    pub limit: Option<usize>,
381}
382
383/// Query parameters for QoM history
384#[derive(Debug, Deserialize, Default)]
385pub struct QomHistoryQuery {
386    /// Time period: "1h", "6h", "24h", "7d"
387    pub period: Option<String>,
388}
389
390/// Get QoM metrics summary
391///
392/// GET /_mpl/qom
393pub async fn qom_summary(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
394    let summary = state.qom_recorder.get_summary().await;
395
396    Json(json!({
397        "metrics": {
398            "schema_fidelity": summary.schema_fidelity,
399            "instruction_compliance": summary.instruction_compliance,
400            "tool_outcome_correctness": summary.tool_outcome_correctness,
401            "groundedness": summary.groundedness,
402            "determinism_jitter": summary.determinism_jitter,
403            "ontology_adherence": summary.ontology_adherence,
404        }
405    }))
406}
407
408/// Get recent QoM events
409///
410/// GET /_mpl/qom/events?limit=50
411pub async fn qom_events(
412    State(state): State<Arc<ProxyState>>,
413    axum::extract::Query(query): axum::extract::Query<QomEventsQuery>,
414) -> impl IntoResponse {
415    let limit = query.limit.unwrap_or(50);
416    let events = state.qom_recorder.get_events(limit).await;
417
418    // Convert events to JSON-friendly format
419    let events_json: Vec<serde_json::Value> = events
420        .iter()
421        .map(|e| {
422            json!({
423                "id": e.id,
424                "stype": e.stype,
425                "profile": e.profile,
426                "passed": e.passed,
427                "scores": {
428                    "SF": e.scores.sf,
429                    "IC": e.scores.ic,
430                    "TOC": e.scores.toc,
431                    "G": e.scores.g,
432                    "DJ": e.scores.dj,
433                    "OA": e.scores.oa,
434                },
435                "failure_reason": e.failure_reason,
436                "timestamp": e.timestamp.to_rfc3339(),
437            })
438        })
439        .collect();
440
441    Json(json!({
442        "events": events_json,
443        "total": events_json.len(),
444    }))
445}
446
447/// Get QoM history for trends
448///
449/// GET /_mpl/qom/history?period=24h
450pub async fn qom_history(
451    State(state): State<Arc<ProxyState>>,
452    axum::extract::Query(query): axum::extract::Query<QomHistoryQuery>,
453) -> impl IntoResponse {
454    let period = query.period.unwrap_or_else(|| "24h".to_string());
455    let history = state.qom_recorder.get_history(&period).await;
456
457    // Convert history to JSON-friendly format
458    let history_json: Vec<serde_json::Value> = history
459        .iter()
460        .map(|h| {
461            json!({
462                "timestamp": h.timestamp.to_rfc3339(),
463                "count": h.count,
464                "sf": h.sf,
465                "ic": h.ic,
466                "toc": h.toc,
467                "g": h.g,
468                "dj": h.dj,
469                "oa": h.oa,
470                "pass_rate": h.pass_rate,
471            })
472        })
473        .collect();
474
475    Json(json!({
476        "history": history_json,
477        "period": period,
478    }))
479}
480
481/// Persist QoM history to disk (for maintenance)
482///
483/// POST /_mpl/qom/persist
484pub async fn qom_persist(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
485    state.qom_recorder.persist_history().await;
486
487    Json(json!({
488        "status": "ok",
489        "message": "QoM history persisted to disk",
490    }))
491}
492
493/// Get learning/traffic recording statistics
494///
495/// GET /_mpl/learning/stats
496pub async fn learning_stats(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
497    let enabled = state.traffic_recorder.is_enabled();
498    let stats = state.traffic_recorder.get_stats();
499
500    let total_samples: usize = stats.values().sum();
501    let stype_count = stats.len();
502
503    // Get top stypes by sample count
504    let mut stypes_sorted: Vec<_> = stats.into_iter().collect();
505    stypes_sorted.sort_by(|a, b| b.1.cmp(&a.1));
506    let top_stypes: Vec<_> = stypes_sorted.into_iter().take(20).collect();
507
508    Json(json!({
509        "enabled": enabled,
510        "total_samples": total_samples,
511        "stype_count": stype_count,
512        "top_stypes": top_stypes.iter().map(|(stype, count)| {
513            json!({
514                "stype": stype,
515                "samples": count
516            })
517        }).collect::<Vec<_>>()
518    }))
519}
520
521/// Get traffic samples for a specific SType
522///
523/// GET /_mpl/learning/samples/:stype
524pub async fn learning_samples(
525    State(state): State<Arc<ProxyState>>,
526    axum::extract::Path(stype): axum::extract::Path<String>,
527    axum::extract::Query(query): axum::extract::Query<LearningQuery>,
528) -> impl IntoResponse {
529    let samples = state.traffic_recorder.get_samples(&stype);
530    let limit = query.limit.unwrap_or(50);
531
532    let samples_json: Vec<serde_json::Value> = samples
533        .iter()
534        .rev()
535        .take(limit)
536        .map(|s| {
537            json!({
538                "id": s.id,
539                "timestamp": s.timestamp,
540                "method": s.method,
541                "path": s.path,
542                "payload": s.payload,
543                "response": s.response,
544                "status_code": s.status_code,
545                "duration_ms": s.duration_ms,
546                "validation_passed": s.validation_passed,
547            })
548        })
549        .collect();
550
551    Json(json!({
552        "stype": stype,
553        "samples": samples_json,
554        "total": samples.len(),
555    }))
556}
557
558#[derive(Debug, Deserialize)]
559pub struct LearningQuery {
560    pub limit: Option<usize>,
561}