mpl_proxy/
proxy.rs

1//! Core proxy logic
2
3use std::collections::HashMap;
4use std::path::Path;
5use std::sync::{Arc, RwLock};
6
7use anyhow::Result;
8use axum::body::Body;
9use axum::http::{Request, Response, StatusCode};
10use reqwest::Client;
11use serde::{Deserialize, Serialize};
12use tracing::{debug, info, warn};
13
14use mpl_core::assertions::{AssertionSet, EvaluationContext};
15use mpl_core::envelope::MplEnvelope;
16use mpl_core::hash::{semantic_hash, verify_hash};
17use mpl_core::metrics::{TocMethod, TocResult};
18use mpl_core::ontology::Ontology;
19use mpl_core::qom::{QomMetrics, QomProfile};
20use mpl_core::validation::SchemaValidator;
21
22use crate::config::{ProxyConfig, ProxyMode};
23use crate::metrics::MetricsState;
24use crate::qom_recorder::{QomRecorder, QomRecorderConfig, QomScores};
25use crate::traffic::TrafficRecorder;
26
// MPL protocol HTTP header names.

/// Header carrying the SType (semantic type) of the request payload.
pub const HEADER_STYPE: &str = "X-MPL-SType";
/// MPL profile header (`X-MPL-Profile`).
pub const HEADER_PROFILE: &str = "X-MPL-Profile";
/// Semantic hash header (`X-MPL-Sem-Hash`).
pub const HEADER_SEM_HASH: &str = "X-MPL-Sem-Hash";
/// QoM outcome header set on proxied responses (`"pass"` or `"fail"`).
pub const HEADER_QOM_RESULT: &str = "X-MPL-QoM-Result";
/// TOC verification result header (values: "verified", "failed", "pending", "skip").
pub const HEADER_TOC_RESULT: &str = "X-MPL-TOC-Result";
/// TOC verification callback ID (for async verification).
pub const HEADER_TOC_CALLBACK_ID: &str = "X-MPL-TOC-Callback-Id";
36
37/// Validation result for a request
38#[derive(Debug, Clone, Serialize)]
39pub struct ValidationResult {
40    pub valid: bool,
41    pub stype: Option<String>,
42    pub schema_valid: bool,
43    pub qom_passed: bool,
44    pub hash_valid: bool,
45    /// TOC result if available (from header or callback)
46    #[serde(skip_serializing_if = "Option::is_none")]
47    pub toc_result: Option<TocResult>,
48    /// IC score if computed
49    #[serde(skip_serializing_if = "Option::is_none")]
50    pub ic_score: Option<f64>,
51    /// Profile used for evaluation
52    #[serde(skip_serializing_if = "Option::is_none")]
53    pub profile_used: Option<String>,
54    /// Whether profile was degraded from original
55    #[serde(default)]
56    pub degraded: bool,
57    /// Original profile before degradation
58    #[serde(skip_serializing_if = "Option::is_none")]
59    pub original_profile: Option<String>,
60    /// Retry count (0 = first attempt)
61    #[serde(default)]
62    pub retry_count: u32,
63    pub errors: Vec<String>,
64}
65
66impl Default for ValidationResult {
67    fn default() -> Self {
68        Self {
69            valid: true,
70            stype: None,
71            schema_valid: true,
72            qom_passed: true,
73            hash_valid: true,
74            toc_result: None,
75            ic_score: None,
76            profile_used: None,
77            degraded: false,
78            original_profile: None,
79            retry_count: 0,
80            errors: Vec::new(),
81        }
82    }
83}
84
85/// Pending TOC verification tracking
86#[derive(Debug, Clone, Serialize, Deserialize)]
87pub struct PendingTocVerification {
88    /// Unique callback ID
89    pub callback_id: String,
90    /// SType being verified
91    pub stype: String,
92    /// Request timestamp
93    pub timestamp: String,
94    /// Expected outcome (if specified)
95    pub expected_outcome: Option<String>,
96    /// Tool name (if applicable)
97    pub tool_name: Option<String>,
98}
99
100/// Shared proxy state
101pub struct ProxyState {
102    /// Configuration
103    pub config: ProxyConfig,
104
105    /// HTTP client for upstream requests
106    pub client: Client,
107
108    /// Schema validator
109    pub validator: SchemaValidator,
110
111    /// Assertions by SType (for IC computation)
112    pub assertions: Arc<RwLock<HashMap<String, AssertionSet>>>,
113
114    /// QoM profiles
115    pub profiles: Vec<QomProfile>,
116
117    /// Metrics
118    pub metrics: Arc<MetricsState>,
119
120    /// QoM recorder for full metric tracking and persistence
121    pub qom_recorder: Arc<QomRecorder>,
122
123    /// Traffic recorder for schema inference
124    pub traffic_recorder: Arc<TrafficRecorder>,
125
126    /// Pending TOC verifications (callback_id -> verification)
127    pub pending_toc: Arc<RwLock<HashMap<String, PendingTocVerification>>>,
128
129    /// Completed TOC results (callback_id -> result)
130    pub completed_toc: Arc<RwLock<HashMap<String, TocResult>>>,
131
132    /// Counter for generating callback IDs
133    toc_counter: std::sync::atomic::AtomicU64,
134}
135
136impl ProxyState {
137    /// Create a new proxy state from configuration
138    pub async fn new(config: ProxyConfig) -> Result<Self> {
139        Self::with_options(config, None, false).await
140    }
141
142    /// Create a new proxy state with traffic recording options
143    pub async fn with_options(
144        config: ProxyConfig,
145        data_dir: Option<&str>,
146        learning_enabled: bool,
147    ) -> Result<Self> {
148        // Build HTTP client with configured timeouts
149        let client = Client::builder()
150            .connect_timeout(config.transport.connect_timeout())
151            .timeout(config.transport.request_timeout())
152            .pool_idle_timeout(config.transport.idle_timeout())
153            .build()?;
154
155        let mut validator = SchemaValidator::new();
156        let mut assertions_map: HashMap<String, AssertionSet> = HashMap::new();
157
158        // Load schemas and assertions from registry if it's a local path
159        let registry_path = &config.mpl.registry;
160        if Path::new(registry_path).exists() {
161            Self::load_schemas_from_registry(&mut validator, registry_path)?;
162            Self::load_assertions_from_registry(&mut assertions_map, registry_path)?;
163        }
164
165        let profiles = vec![
166            QomProfile::basic(),
167            QomProfile::strict_argcheck(),
168            QomProfile::outcome(),
169            QomProfile::comprehensive(),
170        ];
171
172        let metrics = Arc::new(MetricsState::new());
173
174        // Initialize traffic recorder
175        let data_path = data_dir
176            .map(Path::new)
177            .unwrap_or_else(|| Path::new("~/.mpl"));
178        let traffic_recorder = Arc::new(TrafficRecorder::new(data_path, learning_enabled));
179
180        if learning_enabled {
181            // Load existing samples from disk
182            if let Err(e) = traffic_recorder.load_from_disk() {
183                warn!("Failed to load existing traffic samples: {}", e);
184            }
185            info!("Traffic recording enabled");
186        }
187
188        // Initialize QoM recorder
189        let qom_data_dir = data_path.join("qom");
190        let qom_recorder = Arc::new(QomRecorder::new(QomRecorderConfig {
191            data_dir: qom_data_dir,
192            ..Default::default()
193        }));
194
195        // Load QoM events from disk
196        if let Err(e) = qom_recorder.load_from_disk().await {
197            warn!("Failed to load QoM events: {}", e);
198        }
199
200        // Load ontology specs from registry
201        Self::load_ontologies_from_registry(&qom_recorder, registry_path).await;
202
203        info!("Proxy state initialized");
204        info!("Mode: {:?}", config.mpl.mode);
205        info!("Loaded {} schemas", validator.registered_stypes().len());
206        info!("Loaded {} assertion sets", assertions_map.len());
207        info!("Loaded {} QoM profiles", profiles.len());
208
209        Ok(Self {
210            config,
211            client,
212            validator,
213            assertions: Arc::new(RwLock::new(assertions_map)),
214            profiles,
215            metrics,
216            qom_recorder,
217            traffic_recorder,
218            pending_toc: Arc::new(RwLock::new(HashMap::new())),
219            completed_toc: Arc::new(RwLock::new(HashMap::new())),
220            toc_counter: std::sync::atomic::AtomicU64::new(0),
221        })
222    }
223
224    /// Generate a unique TOC callback ID
225    pub fn next_toc_callback_id(&self) -> String {
226        let count = self.toc_counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
227        format!("toc-{:016x}", count)
228    }
229
230    /// Register a pending TOC verification
231    pub fn register_pending_toc(&self, verification: PendingTocVerification) {
232        if let Ok(mut pending) = self.pending_toc.write() {
233            pending.insert(verification.callback_id.clone(), verification);
234        }
235    }
236
237    /// Complete a TOC verification (called from callback endpoint)
238    pub fn complete_toc(&self, callback_id: &str, result: TocResult) -> bool {
239        // Remove from pending
240        let mut was_pending = false;
241        if let Ok(mut pending) = self.pending_toc.write() {
242            was_pending = pending.remove(callback_id).is_some();
243        }
244
245        // Add to completed
246        if let Ok(mut completed) = self.completed_toc.write() {
247            completed.insert(callback_id.to_string(), result);
248        }
249
250        was_pending
251    }
252
253    /// Get completed TOC result for a callback ID
254    pub fn get_toc_result(&self, callback_id: &str) -> Option<TocResult> {
255        self.completed_toc
256            .read()
257            .ok()
258            .and_then(|completed| completed.get(callback_id).cloned())
259    }
260
261    /// Get pending TOC verification
262    pub fn get_pending_toc(&self, callback_id: &str) -> Option<PendingTocVerification> {
263        self.pending_toc
264            .read()
265            .ok()
266            .and_then(|pending| pending.get(callback_id).cloned())
267    }
268
269    /// Parse TOC result from header value
270    pub fn parse_toc_header(value: &str) -> Option<TocResult> {
271        match value.to_lowercase().as_str() {
272            "verified" | "pass" | "true" | "1" => Some(TocResult::verified(TocMethod::Header)),
273            "failed" | "fail" | "false" | "0" => {
274                Some(TocResult::failed(TocMethod::Header, "Verification failed"))
275            }
276            "pending" => None, // Still waiting
277            "skip" | "na" => Some(TocResult::verified(TocMethod::None)), // Not applicable
278            _ => None,
279        }
280    }
281
282    /// Load schemas from local registry directory
283    fn load_schemas_from_registry(validator: &mut SchemaValidator, registry_path: &str) -> Result<()> {
284        let stypes_path = Path::new(registry_path).join("stypes");
285        if !stypes_path.exists() {
286            debug!("Registry stypes path does not exist: {}", stypes_path.display());
287            return Ok(());
288        }
289
290        // Walk the stypes directory structure: namespace/domain/Name/vN/schema.json
291        Self::walk_registry_dir(validator, &stypes_path, Vec::new())?;
292        Ok(())
293    }
294
295    fn walk_registry_dir(
296        validator: &mut SchemaValidator,
297        path: &Path,
298        parts: Vec<String>,
299    ) -> Result<()> {
300        if !path.is_dir() {
301            return Ok(());
302        }
303
304        for entry in std::fs::read_dir(path)? {
305            let entry = entry?;
306            let entry_path = entry.path();
307            let name = entry.file_name().to_string_lossy().to_string();
308
309            if entry_path.is_dir() {
310                let mut new_parts = parts.clone();
311                new_parts.push(name);
312                Self::walk_registry_dir(validator, &entry_path, new_parts)?;
313            } else if name == "schema.json" && parts.len() >= 4 {
314                // We have: namespace/domain/Name/vN/schema.json
315                // parts = [namespace, domain, Name, vN]
316                let version_str = &parts[parts.len() - 1];
317                if let Some(version) = version_str.strip_prefix('v') {
318                    if version.parse::<u32>().is_ok() {
319                        let namespace = parts[..parts.len() - 2].join(".");
320                        let name = &parts[parts.len() - 2];
321                        let stype = format!(
322                            "{}.{}.{}",
323                            namespace,
324                            name,
325                            version_str
326                        );
327
328                        // Read and register schema
329                        if let Ok(schema_content) = std::fs::read_to_string(&entry_path) {
330                            if validator.register_json(&stype, &schema_content).is_ok() {
331                                debug!("Registered schema for {}", stype);
332                            }
333                        }
334                    }
335                }
336            }
337        }
338        Ok(())
339    }
340
341    /// Load assertions from local registry directory
342    fn load_assertions_from_registry(
343        assertions: &mut HashMap<String, AssertionSet>,
344        registry_path: &str,
345    ) -> Result<()> {
346        let stypes_path = Path::new(registry_path).join("stypes");
347        if !stypes_path.exists() {
348            return Ok(());
349        }
350
351        // Walk the stypes directory structure looking for assertions.json files
352        Self::walk_registry_dir_for_assertions(assertions, &stypes_path, Vec::new())?;
353        Ok(())
354    }
355
356    fn walk_registry_dir_for_assertions(
357        assertions: &mut HashMap<String, AssertionSet>,
358        path: &Path,
359        parts: Vec<String>,
360    ) -> Result<()> {
361        if !path.is_dir() {
362            return Ok(());
363        }
364
365        for entry in std::fs::read_dir(path)? {
366            let entry = entry?;
367            let entry_path = entry.path();
368            let name = entry.file_name().to_string_lossy().to_string();
369
370            if entry_path.is_dir() {
371                let mut new_parts = parts.clone();
372                new_parts.push(name);
373                Self::walk_registry_dir_for_assertions(assertions, &entry_path, new_parts)?;
374            } else if name == "assertions.json" && parts.len() >= 4 {
375                // We have: namespace/domain/Name/vN/assertions.json
376                let version_str = &parts[parts.len() - 1];
377                if let Some(version) = version_str.strip_prefix('v') {
378                    if version.parse::<u32>().is_ok() {
379                        let namespace = parts[..parts.len() - 2].join(".");
380                        let type_name = &parts[parts.len() - 2];
381                        let stype = format!("{}.{}.{}", namespace, type_name, version_str);
382
383                        // Read and parse assertions
384                        if let Ok(content) = std::fs::read_to_string(&entry_path) {
385                            match serde_json::from_str::<AssertionSet>(&content) {
386                                Ok(assertion_set) => {
387                                    debug!(
388                                        "Loaded {} assertions for {}",
389                                        assertion_set.assertions.len(),
390                                        stype
391                                    );
392                                    assertions.insert(stype, assertion_set);
393                                }
394                                Err(e) => {
395                                    warn!(
396                                        "Failed to parse assertions for {}: {}",
397                                        entry_path.display(),
398                                        e
399                                    );
400                                }
401                            }
402                        }
403                    }
404                }
405            }
406        }
407        Ok(())
408    }
409
410    /// Load ontology specs from local registry directory
411    async fn load_ontologies_from_registry(qom_recorder: &QomRecorder, registry_path: &str) {
412        let stypes_path = Path::new(registry_path).join("stypes");
413        if !stypes_path.exists() {
414            return;
415        }
416
417        // Collect all ontologies synchronously first
418        let ontologies = Self::collect_ontologies_from_registry(&stypes_path, Vec::new());
419
420        // Then load them asynchronously
421        for (stype, spec) in ontologies {
422            debug!("Loaded ontology for {}", stype);
423            qom_recorder.load_ontology(&stype, spec).await;
424        }
425    }
426
427    fn collect_ontologies_from_registry(
428        path: &Path,
429        parts: Vec<String>,
430    ) -> Vec<(String, Ontology)> {
431        let mut result = Vec::new();
432
433        if !path.is_dir() {
434            return result;
435        }
436
437        if let Ok(entries) = std::fs::read_dir(path) {
438            for entry in entries.flatten() {
439                let entry_path = entry.path();
440                let name = entry.file_name().to_string_lossy().to_string();
441
442                if entry_path.is_dir() {
443                    let mut new_parts = parts.clone();
444                    new_parts.push(name);
445                    result.extend(Self::collect_ontologies_from_registry(&entry_path, new_parts));
446                } else if name == "ontology.json" && parts.len() >= 4 {
447                    // We have: namespace/domain/Name/vN/ontology.json
448                    let version_str = &parts[parts.len() - 1];
449                    if let Some(version) = version_str.strip_prefix('v') {
450                        if version.parse::<u32>().is_ok() {
451                            let namespace = parts[..parts.len() - 2].join(".");
452                            let type_name = &parts[parts.len() - 2];
453                            let stype = format!("{}.{}.{}", namespace, type_name, version_str);
454
455                            // Read and parse ontology spec
456                            if let Ok(content) = std::fs::read_to_string(&entry_path) {
457                                match serde_json::from_str::<Ontology>(&content) {
458                                    Ok(spec) => {
459                                        result.push((stype, spec));
460                                    }
461                                    Err(e) => {
462                                        warn!(
463                                            "Failed to parse ontology for {}: {}",
464                                            entry_path.display(),
465                                            e
466                                        );
467                                    }
468                                }
469                            }
470                        }
471                    }
472                }
473            }
474        }
475
476        result
477    }
478
479    /// Get assertions for an SType
480    pub fn get_assertions(&self, stype: &str) -> Option<AssertionSet> {
481        self.assertions
482            .read()
483            .ok()
484            .and_then(|a| a.get(stype).cloned())
485    }
486
487    /// Validate an MPL request
488    pub async fn validate_request(&self, envelope: &MplEnvelope) -> ValidationResult {
489        self.validate_request_full(envelope, None).await
490    }
491
492    /// Validate an MPL request with optional response for determinism checking
493    pub async fn validate_request_full(
494        &self,
495        envelope: &MplEnvelope,
496        response: Option<&serde_json::Value>,
497    ) -> ValidationResult {
498        let mut result = ValidationResult {
499            stype: Some(envelope.stype.clone()),
500            ..Default::default()
501        };
502
503        // Compute payload hash for determinism tracking
504        let payload_hash = semantic_hash(&envelope.payload).ok();
505
506        // Schema validation
507        let sf_score = if self.config.mpl.enforce_schema {
508            match self.validator.validate(&envelope.stype, &envelope.payload) {
509                Ok(validation) => {
510                    result.schema_valid = validation.valid;
511                    if !validation.valid {
512                        result.valid = false;
513                        for err in validation.errors {
514                            result.errors.push(format!("Schema error at {}: {}", err.path, err.message));
515                        }
516                    }
517                    if validation.valid { 1.0 } else { 0.0 }
518                }
519                Err(e) => {
520                    // Unknown SType - check mode
521                    if self.is_strict() {
522                        result.valid = false;
523                        result.schema_valid = false;
524                        result.errors.push(format!("Unknown SType: {} ({})", envelope.stype, e));
525                        0.0
526                    } else {
527                        warn!("Unknown SType: {}, allowing in transparent mode", envelope.stype);
528                        1.0
529                    }
530                }
531            }
532        } else {
533            1.0
534        };
535
536        // Hash verification (if provided)
537        if let Some(ref expected_hash) = envelope.sem_hash {
538            match verify_hash(&envelope.payload, expected_hash) {
539                Ok(valid) => {
540                    result.hash_valid = valid;
541                    if !valid {
542                        result.valid = false;
543                        result.errors.push("Semantic hash mismatch".to_string());
544                    }
545                }
546                Err(e) => {
547                    result.hash_valid = false;
548                    result.valid = false;
549                    result.errors.push(format!("Hash verification failed: {}", e));
550                }
551            }
552        }
553
554        // Instruction Compliance (IC) - evaluate assertions if available
555        let ic_score = if let Some(assertion_set) = self.get_assertions(&envelope.stype) {
556            let eval_ctx = EvaluationContext {
557                stype: Some(envelope.stype.clone()),
558                ..Default::default()
559            };
560
561            match assertion_set.evaluate_with_context(&envelope.payload, &eval_ctx) {
562                Ok(assertion_result) => {
563                    let score = assertion_result.ic_score;
564                    result.ic_score = Some(score);
565
566                    // Add assertion failures to errors
567                    for ar in &assertion_result.results {
568                        if !ar.passed {
569                            match ar.severity {
570                                mpl_core::assertions::AssertionSeverity::Error => {
571                                    result.errors.push(format!("IC error [{}]: {}", ar.id, ar.message));
572                                }
573                                mpl_core::assertions::AssertionSeverity::Warning => {
574                                    debug!("IC warning [{}]: {}", ar.id, ar.message);
575                                }
576                                mpl_core::assertions::AssertionSeverity::Info => {
577                                    debug!("IC info [{}]: {}", ar.id, ar.message);
578                                }
579                            }
580                        }
581                    }
582
583                    // Check if any error-severity assertions failed
584                    if assertion_result.error_count > 0 && self.is_strict() {
585                        result.valid = false;
586                    }
587
588                    Some(score)
589                }
590                Err(e) => {
591                    warn!("Assertion evaluation failed for {}: {}", envelope.stype, e);
592                    None
593                }
594            }
595        } else {
596            None
597        };
598
599        // Ontology Adherence (OA) - check domain constraints
600        let oa_result = self.qom_recorder.check_ontology(&envelope.stype, &envelope.payload).await;
601        let oa_score = if oa_result.constraints_checked > 0 {
602            Some(oa_result.score)
603        } else {
604            None
605        };
606
607        // Determinism Jitter (DJ) - check response stability if response provided
608        let dj_score = if let (Some(resp), Some(ref hash)) = (response, &payload_hash) {
609            let dj_result = self.qom_recorder.check_determinism(&envelope.stype, hash, resp).await;
610            if dj_result.comparison_count > 0 {
611                Some(dj_result.similarity)
612            } else {
613                None
614            }
615        } else {
616            None
617        };
618
619        // QoM evaluation - now includes all computed metrics
620        let profile_name = if let Some(profile) = self.active_profile() {
621            let mut metrics = if result.schema_valid {
622                QomMetrics::schema_valid()
623            } else {
624                QomMetrics::schema_invalid()
625            };
626
627            // Add all computed metrics
628            if let Some(ic) = ic_score {
629                metrics = metrics.with_instruction_compliance(ic);
630            }
631            if let Some(oa) = oa_score {
632                metrics = metrics.with_ontology_adherence(oa);
633            }
634            if let Some(dj) = dj_score {
635                metrics = metrics.with_determinism_jitter(dj);
636            }
637
638            let evaluation = profile.evaluate(&metrics);
639            result.qom_passed = evaluation.meets_profile;
640            result.profile_used = Some(profile.name.clone());
641
642            if !evaluation.meets_profile {
643                result.valid = false;
644                for failure in evaluation.failures {
645                    result
646                        .errors
647                        .push(format!("QoM breach: {} < {}", failure.metric, failure.threshold));
648                }
649            }
650
651            Some(profile.name.clone())
652        } else {
653            None
654        };
655
656        // Update metrics counters
657        self.metrics.inc_requests();
658        if result.schema_valid {
659            self.metrics.inc_schema_pass();
660        } else {
661            self.metrics.inc_schema_fail();
662        }
663        if result.qom_passed {
664            self.metrics.inc_qom_pass();
665        } else {
666            self.metrics.inc_qom_fail();
667        }
668
669        // Record QoM event
670        let scores = QomScores {
671            sf: Some(sf_score),
672            ic: ic_score,
673            toc: None, // TOC is async, handled separately
674            g: None,   // Groundedness requires response content analysis
675            dj: dj_score,
676            oa: oa_score,
677        };
678
679        let failure_reason = if !result.qom_passed && !result.errors.is_empty() {
680            Some(result.errors.join("; "))
681        } else {
682            None
683        };
684
685        let event = self.qom_recorder.create_event(
686            &envelope.stype,
687            &profile_name.unwrap_or_else(|| "none".to_string()),
688            result.qom_passed,
689            scores,
690            failure_reason,
691            payload_hash,
692        );
693        self.qom_recorder.record_event(event).await;
694
695        result
696    }
697
698    /// Forward a request to the upstream server
699    pub async fn forward_request(
700        &self,
701        path: String,
702        request: Request<Body>,
703    ) -> Result<Response<Body>> {
704        use crate::traffic::{StypeInferrer, TrafficRecord};
705
706        let start_time = std::time::Instant::now();
707        let upstream = &self.config.transport.upstream;
708        let uri = format!("http://{}/{}", upstream, path);
709
710        debug!("Forwarding to: {}", uri);
711
712        // Extract headers
713        let method = request.method().clone();
714        let method_str = method.to_string();
715        let headers = request.headers().clone();
716        let stype_header = headers
717            .get(HEADER_STYPE)
718            .and_then(|v| v.to_str().ok())
719            .map(String::from);
720
721        // Read body
722        let body_bytes = axum::body::to_bytes(request.into_body(), usize::MAX).await?;
723
724        // Parse payload for traffic recording
725        let payload: serde_json::Value = serde_json::from_slice(&body_bytes)
726            .unwrap_or(serde_json::Value::Null);
727
728        // Try to parse as MPL envelope or create one from headers
729        let envelope = if let Ok(env) = serde_json::from_slice::<MplEnvelope>(&body_bytes) {
730            Some(env)
731        } else if let Some(stype) = stype_header.clone() {
732            // Create envelope from headers + body
733            if let Ok(payload) = serde_json::from_slice(&body_bytes) {
734                Some(MplEnvelope::new(stype, payload))
735            } else {
736                None
737            }
738        } else {
739            None
740        };
741
742        // Determine SType (from envelope, header, or inferred)
743        let stype = envelope
744            .as_ref()
745            .map(|e| e.stype.clone())
746            .or(stype_header)
747            .unwrap_or_else(|| StypeInferrer::infer(&path, &method_str, &payload));
748
749        // Validate if we have an envelope
750        let validation_result = if let Some(ref env) = envelope {
751            let result = self.validate_request(env).await;
752
753            // In strict mode, block invalid requests
754            if !result.valid && self.is_strict() {
755                let error_response = serde_json::json!({
756                    "error": "E-SCHEMA-FIDELITY",
757                    "message": "Request validation failed",
758                    "details": result.errors,
759                });
760
761                // Record failed request if learning is enabled
762                if self.traffic_recorder.is_enabled() {
763                    let record = TrafficRecord {
764                        id: self.traffic_recorder.next_id(),
765                        timestamp: chrono::Utc::now().to_rfc3339(),
766                        stype: stype.clone(),
767                        method: method_str.clone(),
768                        path: path.clone(),
769                        payload: payload.clone(),
770                        response: Some(error_response.clone()),
771                        status_code: Some(400),
772                        duration_ms: Some(start_time.elapsed().as_millis() as u64),
773                        validation_passed: false,
774                        validation_errors: result.errors.clone(),
775                    };
776                    self.traffic_recorder.record(record);
777                }
778
779                return Ok(Response::builder()
780                    .status(StatusCode::BAD_REQUEST)
781                    .header("content-type", "application/json")
782                    .header(HEADER_QOM_RESULT, if result.qom_passed { "pass" } else { "fail" })
783                    .body(Body::from(serde_json::to_string(&error_response)?))?);
784            }
785
786            Some(result)
787        } else {
788            None
789        };
790
791        // Build upstream request
792        let mut req_builder = self.client.request(method, &uri);
793
794        for (name, value) in headers.iter() {
795            if name != "host" {
796                req_builder = req_builder.header(name, value);
797            }
798        }
799
800        let upstream_response = req_builder.body(body_bytes.to_vec()).send().await?;
801
802        // Convert response back to axum
803        let status = upstream_response.status();
804        let status_code = status.as_u16();
805        let response_headers = upstream_response.headers().clone();
806        let body = upstream_response.bytes().await?;
807
808        // Record traffic if learning is enabled
809        if self.traffic_recorder.is_enabled() {
810            let response_payload: Option<serde_json::Value> = serde_json::from_slice(&body).ok();
811
812            let record = TrafficRecord {
813                id: self.traffic_recorder.next_id(),
814                timestamp: chrono::Utc::now().to_rfc3339(),
815                stype,
816                method: method_str,
817                path,
818                payload,
819                response: response_payload,
820                status_code: Some(status_code),
821                duration_ms: Some(start_time.elapsed().as_millis() as u64),
822                validation_passed: validation_result.as_ref().map(|r| r.valid).unwrap_or(true),
823                validation_errors: validation_result
824                    .as_ref()
825                    .map(|r| r.errors.clone())
826                    .unwrap_or_default(),
827            };
828            self.traffic_recorder.record(record);
829        }
830
831        let mut response = Response::builder().status(status);
832
833        for (name, value) in response_headers.iter() {
834            response = response.header(name, value);
835        }
836
837        // Add MPL headers to response
838        if let Some(ref result) = validation_result {
839            response = response.header(HEADER_QOM_RESULT, if result.qom_passed { "pass" } else { "fail" });
840        }
841
842        Ok(response.body(Body::from(body))?)
843    }
844
845    /// Get the active QoM profile
846    pub fn active_profile(&self) -> Option<&QomProfile> {
847        let profile_name = self.config.mpl.required_profile.as_ref()?;
848        self.profiles.iter().find(|p| &p.name == profile_name)
849    }
850
851    /// Get a profile by name
852    pub fn get_profile(&self, name: &str) -> Option<&QomProfile> {
853        self.profiles.iter().find(|p| p.name == name)
854    }
855
856    /// Get the degradation chain for a profile
857    /// Returns profiles from strictest to most lenient
858    pub fn get_degradation_chain(&self, start_profile: &str) -> Vec<&QomProfile> {
859        // Profile degradation order: comprehensive -> outcome -> strict-argcheck -> basic
860        let order = ["qom-comprehensive", "qom-outcome", "qom-strict-argcheck", "qom-basic"];
861
862        let start_idx = order.iter().position(|&p| p == start_profile).unwrap_or(0);
863
864        order[start_idx..]
865            .iter()
866            .filter_map(|&name| self.get_profile(name))
867            .collect()
868    }
869
870    /// Validate with automatic profile degradation
871    /// Returns (result, final_profile_name, was_degraded)
872    pub async fn validate_with_degradation(
873        &self,
874        envelope: &MplEnvelope,
875    ) -> ValidationResult {
876        let original_profile = self.config.mpl.required_profile.clone();
877
878        if let Some(ref profile_name) = original_profile {
879            let chain = self.get_degradation_chain(profile_name);
880
881            for (idx, profile) in chain.iter().enumerate() {
882                let mut result = self.validate_request_with_profile(envelope, profile);
883
884                if result.qom_passed || idx == chain.len() - 1 {
885                    // Either passed or we're at the end of the chain
886                    result.profile_used = Some(profile.name.clone());
887                    result.degraded = idx > 0;
888                    if idx > 0 {
889                        result.original_profile = Some(profile_name.clone());
890                    }
891                    return result;
892                }
893            }
894        }
895
896        // No profile configured, just validate
897        self.validate_request(envelope).await
898    }
899
900    /// Validate with a specific profile (internal)
901    fn validate_request_with_profile(
902        &self,
903        envelope: &MplEnvelope,
904        profile: &QomProfile,
905    ) -> ValidationResult {
906        let mut result = ValidationResult {
907            stype: Some(envelope.stype.clone()),
908            profile_used: Some(profile.name.clone()),
909            ..Default::default()
910        };
911
912        // Schema validation
913        if self.config.mpl.enforce_schema {
914            match self.validator.validate(&envelope.stype, &envelope.payload) {
915                Ok(validation) => {
916                    result.schema_valid = validation.valid;
917                    if !validation.valid {
918                        result.valid = false;
919                        for err in validation.errors {
920                            result.errors.push(format!("Schema error at {}: {}", err.path, err.message));
921                        }
922                    }
923                }
924                Err(e) => {
925                    if self.is_strict() {
926                        result.valid = false;
927                        result.schema_valid = false;
928                        result.errors.push(format!("Unknown SType: {} ({})", envelope.stype, e));
929                    }
930                }
931            }
932        }
933
934        // IC evaluation
935        let ic_score = if let Some(assertion_set) = self.get_assertions(&envelope.stype) {
936            let eval_ctx = EvaluationContext {
937                stype: Some(envelope.stype.clone()),
938                ..Default::default()
939            };
940
941            match assertion_set.evaluate_with_context(&envelope.payload, &eval_ctx) {
942                Ok(assertion_result) => {
943                    result.ic_score = Some(assertion_result.ic_score);
944                    Some(assertion_result.ic_score)
945                }
946                Err(_) => None,
947            }
948        } else {
949            None
950        };
951
952        // Build metrics and evaluate against profile
953        let mut metrics = if result.schema_valid {
954            QomMetrics::schema_valid()
955        } else {
956            QomMetrics::schema_invalid()
957        };
958
959        if let Some(ic) = ic_score {
960            metrics = metrics.with_instruction_compliance(ic);
961        }
962
963        let evaluation = profile.evaluate(&metrics);
964        result.qom_passed = evaluation.meets_profile;
965
966        if !evaluation.meets_profile {
967            for failure in evaluation.failures {
968                result.errors.push(format!("QoM breach [{}]: {} < {}",
969                    profile.name, failure.metric, failure.threshold));
970            }
971        }
972
973        result
974    }
975
976    /// Check if we're in strict mode
977    pub fn is_strict(&self) -> bool {
978        matches!(self.config.mpl.mode, ProxyMode::Strict)
979    }
980}
981
982/// AI-ALPN Client Hello message
983#[derive(Debug, Clone, Serialize, Deserialize)]
984pub struct AiAlpnClientHello {
985    #[serde(rename = "type")]
986    pub msg_type: String,
987    pub version: String,
988    pub stypes: Vec<String>,
989    #[serde(default)]
990    pub qom_profiles: Vec<String>,
991}
992
993/// AI-ALPN Server Select message
994#[derive(Debug, Clone, Serialize, Deserialize)]
995pub struct AiAlpnServerSelect {
996    #[serde(rename = "type")]
997    pub msg_type: String,
998    pub common_stypes: Vec<String>,
999    pub selected_profile: Option<String>,
1000    #[serde(default)]
1001    pub extensions: serde_json::Value,
1002}
1003
1004impl ProxyState {
1005    /// Handle AI-ALPN handshake
1006    pub fn handle_handshake(&self, hello: AiAlpnClientHello) -> AiAlpnServerSelect {
1007        // Find common STypes
1008        let server_stypes: Vec<String> = self.validator.registered_stypes().iter().map(|s| s.to_string()).collect();
1009        let common_stypes: Vec<String> = hello
1010            .stypes
1011            .iter()
1012            .filter(|s| server_stypes.contains(s))
1013            .cloned()
1014            .collect();
1015
1016        // Select profile
1017        let selected_profile = hello
1018            .qom_profiles
1019            .iter()
1020            .find(|p| self.profiles.iter().any(|sp| &sp.name == *p))
1021            .cloned()
1022            .or_else(|| self.config.mpl.required_profile.clone());
1023
1024        self.metrics.inc_handshakes();
1025
1026        AiAlpnServerSelect {
1027            msg_type: "ai-alpn-select".to_string(),
1028            common_stypes,
1029            selected_profile,
1030            extensions: serde_json::json!({}),
1031        }
1032    }
1033}