1use std::collections::HashMap;
4use std::path::Path;
5use std::sync::{Arc, RwLock};
6
7use anyhow::Result;
8use axum::body::Body;
9use axum::http::{Request, Response, StatusCode};
10use reqwest::Client;
11use serde::{Deserialize, Serialize};
12use tracing::{debug, info, warn};
13
14use mpl_core::assertions::{AssertionSet, EvaluationContext};
15use mpl_core::envelope::MplEnvelope;
16use mpl_core::hash::{semantic_hash, verify_hash};
17use mpl_core::metrics::{TocMethod, TocResult};
18use mpl_core::ontology::Ontology;
19use mpl_core::qom::{QomMetrics, QomProfile};
20use mpl_core::validation::SchemaValidator;
21
22use crate::config::{ProxyConfig, ProxyMode};
23use crate::metrics::MetricsState;
24use crate::qom_recorder::{QomRecorder, QomRecorderConfig, QomScores};
25use crate::traffic::TrafficRecorder;
26
// HTTP header names used by the MPL proxy protocol.

/// Header carrying the SType identifier of a request payload.
pub const HEADER_STYPE: &str = "X-MPL-SType";
/// Header naming the QoM profile requested for a message.
pub const HEADER_PROFILE: &str = "X-MPL-Profile";
/// Header carrying the semantic hash of the payload.
pub const HEADER_SEM_HASH: &str = "X-MPL-Sem-Hash";
/// Response header reporting the QoM verdict (`"pass"` / `"fail"`).
pub const HEADER_QOM_RESULT: &str = "X-MPL-QoM-Result";
/// Header carrying a TOC verification result.
pub const HEADER_TOC_RESULT: &str = "X-MPL-TOC-Result";
/// Header carrying the callback id of a pending TOC verification.
pub const HEADER_TOC_CALLBACK_ID: &str = "X-MPL-TOC-Callback-Id";
36
37#[derive(Debug, Clone, Serialize)]
39pub struct ValidationResult {
40 pub valid: bool,
41 pub stype: Option<String>,
42 pub schema_valid: bool,
43 pub qom_passed: bool,
44 pub hash_valid: bool,
45 #[serde(skip_serializing_if = "Option::is_none")]
47 pub toc_result: Option<TocResult>,
48 #[serde(skip_serializing_if = "Option::is_none")]
50 pub ic_score: Option<f64>,
51 #[serde(skip_serializing_if = "Option::is_none")]
53 pub profile_used: Option<String>,
54 #[serde(default)]
56 pub degraded: bool,
57 #[serde(skip_serializing_if = "Option::is_none")]
59 pub original_profile: Option<String>,
60 #[serde(default)]
62 pub retry_count: u32,
63 pub errors: Vec<String>,
64}
65
66impl Default for ValidationResult {
67 fn default() -> Self {
68 Self {
69 valid: true,
70 stype: None,
71 schema_valid: true,
72 qom_passed: true,
73 hash_valid: true,
74 toc_result: None,
75 ic_score: None,
76 profile_used: None,
77 degraded: false,
78 original_profile: None,
79 retry_count: 0,
80 errors: Vec::new(),
81 }
82 }
83}
84
85#[derive(Debug, Clone, Serialize, Deserialize)]
87pub struct PendingTocVerification {
88 pub callback_id: String,
90 pub stype: String,
92 pub timestamp: String,
94 pub expected_outcome: Option<String>,
96 pub tool_name: Option<String>,
98}
99
100pub struct ProxyState {
102 pub config: ProxyConfig,
104
105 pub client: Client,
107
108 pub validator: SchemaValidator,
110
111 pub assertions: Arc<RwLock<HashMap<String, AssertionSet>>>,
113
114 pub profiles: Vec<QomProfile>,
116
117 pub metrics: Arc<MetricsState>,
119
120 pub qom_recorder: Arc<QomRecorder>,
122
123 pub traffic_recorder: Arc<TrafficRecorder>,
125
126 pub pending_toc: Arc<RwLock<HashMap<String, PendingTocVerification>>>,
128
129 pub completed_toc: Arc<RwLock<HashMap<String, TocResult>>>,
131
132 toc_counter: std::sync::atomic::AtomicU64,
134}
135
136impl ProxyState {
137 pub async fn new(config: ProxyConfig) -> Result<Self> {
139 Self::with_options(config, None, false).await
140 }
141
142 pub async fn with_options(
144 config: ProxyConfig,
145 data_dir: Option<&str>,
146 learning_enabled: bool,
147 ) -> Result<Self> {
148 let client = Client::builder()
150 .connect_timeout(config.transport.connect_timeout())
151 .timeout(config.transport.request_timeout())
152 .pool_idle_timeout(config.transport.idle_timeout())
153 .build()?;
154
155 let mut validator = SchemaValidator::new();
156 let mut assertions_map: HashMap<String, AssertionSet> = HashMap::new();
157
158 let registry_path = &config.mpl.registry;
160 if Path::new(registry_path).exists() {
161 Self::load_schemas_from_registry(&mut validator, registry_path)?;
162 Self::load_assertions_from_registry(&mut assertions_map, registry_path)?;
163 }
164
165 let profiles = vec![
166 QomProfile::basic(),
167 QomProfile::strict_argcheck(),
168 QomProfile::outcome(),
169 QomProfile::comprehensive(),
170 ];
171
172 let metrics = Arc::new(MetricsState::new());
173
174 let data_path = data_dir
176 .map(Path::new)
177 .unwrap_or_else(|| Path::new("~/.mpl"));
178 let traffic_recorder = Arc::new(TrafficRecorder::new(data_path, learning_enabled));
179
180 if learning_enabled {
181 if let Err(e) = traffic_recorder.load_from_disk() {
183 warn!("Failed to load existing traffic samples: {}", e);
184 }
185 info!("Traffic recording enabled");
186 }
187
188 let qom_data_dir = data_path.join("qom");
190 let qom_recorder = Arc::new(QomRecorder::new(QomRecorderConfig {
191 data_dir: qom_data_dir,
192 ..Default::default()
193 }));
194
195 if let Err(e) = qom_recorder.load_from_disk().await {
197 warn!("Failed to load QoM events: {}", e);
198 }
199
200 Self::load_ontologies_from_registry(&qom_recorder, registry_path).await;
202
203 info!("Proxy state initialized");
204 info!("Mode: {:?}", config.mpl.mode);
205 info!("Loaded {} schemas", validator.registered_stypes().len());
206 info!("Loaded {} assertion sets", assertions_map.len());
207 info!("Loaded {} QoM profiles", profiles.len());
208
209 Ok(Self {
210 config,
211 client,
212 validator,
213 assertions: Arc::new(RwLock::new(assertions_map)),
214 profiles,
215 metrics,
216 qom_recorder,
217 traffic_recorder,
218 pending_toc: Arc::new(RwLock::new(HashMap::new())),
219 completed_toc: Arc::new(RwLock::new(HashMap::new())),
220 toc_counter: std::sync::atomic::AtomicU64::new(0),
221 })
222 }
223
224 pub fn next_toc_callback_id(&self) -> String {
226 let count = self.toc_counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
227 format!("toc-{:016x}", count)
228 }
229
230 pub fn register_pending_toc(&self, verification: PendingTocVerification) {
232 if let Ok(mut pending) = self.pending_toc.write() {
233 pending.insert(verification.callback_id.clone(), verification);
234 }
235 }
236
237 pub fn complete_toc(&self, callback_id: &str, result: TocResult) -> bool {
239 let mut was_pending = false;
241 if let Ok(mut pending) = self.pending_toc.write() {
242 was_pending = pending.remove(callback_id).is_some();
243 }
244
245 if let Ok(mut completed) = self.completed_toc.write() {
247 completed.insert(callback_id.to_string(), result);
248 }
249
250 was_pending
251 }
252
253 pub fn get_toc_result(&self, callback_id: &str) -> Option<TocResult> {
255 self.completed_toc
256 .read()
257 .ok()
258 .and_then(|completed| completed.get(callback_id).cloned())
259 }
260
261 pub fn get_pending_toc(&self, callback_id: &str) -> Option<PendingTocVerification> {
263 self.pending_toc
264 .read()
265 .ok()
266 .and_then(|pending| pending.get(callback_id).cloned())
267 }
268
269 pub fn parse_toc_header(value: &str) -> Option<TocResult> {
271 match value.to_lowercase().as_str() {
272 "verified" | "pass" | "true" | "1" => Some(TocResult::verified(TocMethod::Header)),
273 "failed" | "fail" | "false" | "0" => {
274 Some(TocResult::failed(TocMethod::Header, "Verification failed"))
275 }
276 "pending" => None, "skip" | "na" => Some(TocResult::verified(TocMethod::None)), _ => None,
279 }
280 }
281
282 fn load_schemas_from_registry(validator: &mut SchemaValidator, registry_path: &str) -> Result<()> {
284 let stypes_path = Path::new(registry_path).join("stypes");
285 if !stypes_path.exists() {
286 debug!("Registry stypes path does not exist: {}", stypes_path.display());
287 return Ok(());
288 }
289
290 Self::walk_registry_dir(validator, &stypes_path, Vec::new())?;
292 Ok(())
293 }
294
295 fn walk_registry_dir(
296 validator: &mut SchemaValidator,
297 path: &Path,
298 parts: Vec<String>,
299 ) -> Result<()> {
300 if !path.is_dir() {
301 return Ok(());
302 }
303
304 for entry in std::fs::read_dir(path)? {
305 let entry = entry?;
306 let entry_path = entry.path();
307 let name = entry.file_name().to_string_lossy().to_string();
308
309 if entry_path.is_dir() {
310 let mut new_parts = parts.clone();
311 new_parts.push(name);
312 Self::walk_registry_dir(validator, &entry_path, new_parts)?;
313 } else if name == "schema.json" && parts.len() >= 4 {
314 let version_str = &parts[parts.len() - 1];
317 if let Some(version) = version_str.strip_prefix('v') {
318 if version.parse::<u32>().is_ok() {
319 let namespace = parts[..parts.len() - 2].join(".");
320 let name = &parts[parts.len() - 2];
321 let stype = format!(
322 "{}.{}.{}",
323 namespace,
324 name,
325 version_str
326 );
327
328 if let Ok(schema_content) = std::fs::read_to_string(&entry_path) {
330 if validator.register_json(&stype, &schema_content).is_ok() {
331 debug!("Registered schema for {}", stype);
332 }
333 }
334 }
335 }
336 }
337 }
338 Ok(())
339 }
340
341 fn load_assertions_from_registry(
343 assertions: &mut HashMap<String, AssertionSet>,
344 registry_path: &str,
345 ) -> Result<()> {
346 let stypes_path = Path::new(registry_path).join("stypes");
347 if !stypes_path.exists() {
348 return Ok(());
349 }
350
351 Self::walk_registry_dir_for_assertions(assertions, &stypes_path, Vec::new())?;
353 Ok(())
354 }
355
356 fn walk_registry_dir_for_assertions(
357 assertions: &mut HashMap<String, AssertionSet>,
358 path: &Path,
359 parts: Vec<String>,
360 ) -> Result<()> {
361 if !path.is_dir() {
362 return Ok(());
363 }
364
365 for entry in std::fs::read_dir(path)? {
366 let entry = entry?;
367 let entry_path = entry.path();
368 let name = entry.file_name().to_string_lossy().to_string();
369
370 if entry_path.is_dir() {
371 let mut new_parts = parts.clone();
372 new_parts.push(name);
373 Self::walk_registry_dir_for_assertions(assertions, &entry_path, new_parts)?;
374 } else if name == "assertions.json" && parts.len() >= 4 {
375 let version_str = &parts[parts.len() - 1];
377 if let Some(version) = version_str.strip_prefix('v') {
378 if version.parse::<u32>().is_ok() {
379 let namespace = parts[..parts.len() - 2].join(".");
380 let type_name = &parts[parts.len() - 2];
381 let stype = format!("{}.{}.{}", namespace, type_name, version_str);
382
383 if let Ok(content) = std::fs::read_to_string(&entry_path) {
385 match serde_json::from_str::<AssertionSet>(&content) {
386 Ok(assertion_set) => {
387 debug!(
388 "Loaded {} assertions for {}",
389 assertion_set.assertions.len(),
390 stype
391 );
392 assertions.insert(stype, assertion_set);
393 }
394 Err(e) => {
395 warn!(
396 "Failed to parse assertions for {}: {}",
397 entry_path.display(),
398 e
399 );
400 }
401 }
402 }
403 }
404 }
405 }
406 }
407 Ok(())
408 }
409
410 async fn load_ontologies_from_registry(qom_recorder: &QomRecorder, registry_path: &str) {
412 let stypes_path = Path::new(registry_path).join("stypes");
413 if !stypes_path.exists() {
414 return;
415 }
416
417 let ontologies = Self::collect_ontologies_from_registry(&stypes_path, Vec::new());
419
420 for (stype, spec) in ontologies {
422 debug!("Loaded ontology for {}", stype);
423 qom_recorder.load_ontology(&stype, spec).await;
424 }
425 }
426
427 fn collect_ontologies_from_registry(
428 path: &Path,
429 parts: Vec<String>,
430 ) -> Vec<(String, Ontology)> {
431 let mut result = Vec::new();
432
433 if !path.is_dir() {
434 return result;
435 }
436
437 if let Ok(entries) = std::fs::read_dir(path) {
438 for entry in entries.flatten() {
439 let entry_path = entry.path();
440 let name = entry.file_name().to_string_lossy().to_string();
441
442 if entry_path.is_dir() {
443 let mut new_parts = parts.clone();
444 new_parts.push(name);
445 result.extend(Self::collect_ontologies_from_registry(&entry_path, new_parts));
446 } else if name == "ontology.json" && parts.len() >= 4 {
447 let version_str = &parts[parts.len() - 1];
449 if let Some(version) = version_str.strip_prefix('v') {
450 if version.parse::<u32>().is_ok() {
451 let namespace = parts[..parts.len() - 2].join(".");
452 let type_name = &parts[parts.len() - 2];
453 let stype = format!("{}.{}.{}", namespace, type_name, version_str);
454
455 if let Ok(content) = std::fs::read_to_string(&entry_path) {
457 match serde_json::from_str::<Ontology>(&content) {
458 Ok(spec) => {
459 result.push((stype, spec));
460 }
461 Err(e) => {
462 warn!(
463 "Failed to parse ontology for {}: {}",
464 entry_path.display(),
465 e
466 );
467 }
468 }
469 }
470 }
471 }
472 }
473 }
474 }
475
476 result
477 }
478
479 pub fn get_assertions(&self, stype: &str) -> Option<AssertionSet> {
481 self.assertions
482 .read()
483 .ok()
484 .and_then(|a| a.get(stype).cloned())
485 }
486
487 pub async fn validate_request(&self, envelope: &MplEnvelope) -> ValidationResult {
489 self.validate_request_full(envelope, None).await
490 }
491
492 pub async fn validate_request_full(
494 &self,
495 envelope: &MplEnvelope,
496 response: Option<&serde_json::Value>,
497 ) -> ValidationResult {
498 let mut result = ValidationResult {
499 stype: Some(envelope.stype.clone()),
500 ..Default::default()
501 };
502
503 let payload_hash = semantic_hash(&envelope.payload).ok();
505
506 let sf_score = if self.config.mpl.enforce_schema {
508 match self.validator.validate(&envelope.stype, &envelope.payload) {
509 Ok(validation) => {
510 result.schema_valid = validation.valid;
511 if !validation.valid {
512 result.valid = false;
513 for err in validation.errors {
514 result.errors.push(format!("Schema error at {}: {}", err.path, err.message));
515 }
516 }
517 if validation.valid { 1.0 } else { 0.0 }
518 }
519 Err(e) => {
520 if self.is_strict() {
522 result.valid = false;
523 result.schema_valid = false;
524 result.errors.push(format!("Unknown SType: {} ({})", envelope.stype, e));
525 0.0
526 } else {
527 warn!("Unknown SType: {}, allowing in transparent mode", envelope.stype);
528 1.0
529 }
530 }
531 }
532 } else {
533 1.0
534 };
535
536 if let Some(ref expected_hash) = envelope.sem_hash {
538 match verify_hash(&envelope.payload, expected_hash) {
539 Ok(valid) => {
540 result.hash_valid = valid;
541 if !valid {
542 result.valid = false;
543 result.errors.push("Semantic hash mismatch".to_string());
544 }
545 }
546 Err(e) => {
547 result.hash_valid = false;
548 result.valid = false;
549 result.errors.push(format!("Hash verification failed: {}", e));
550 }
551 }
552 }
553
554 let ic_score = if let Some(assertion_set) = self.get_assertions(&envelope.stype) {
556 let eval_ctx = EvaluationContext {
557 stype: Some(envelope.stype.clone()),
558 ..Default::default()
559 };
560
561 match assertion_set.evaluate_with_context(&envelope.payload, &eval_ctx) {
562 Ok(assertion_result) => {
563 let score = assertion_result.ic_score;
564 result.ic_score = Some(score);
565
566 for ar in &assertion_result.results {
568 if !ar.passed {
569 match ar.severity {
570 mpl_core::assertions::AssertionSeverity::Error => {
571 result.errors.push(format!("IC error [{}]: {}", ar.id, ar.message));
572 }
573 mpl_core::assertions::AssertionSeverity::Warning => {
574 debug!("IC warning [{}]: {}", ar.id, ar.message);
575 }
576 mpl_core::assertions::AssertionSeverity::Info => {
577 debug!("IC info [{}]: {}", ar.id, ar.message);
578 }
579 }
580 }
581 }
582
583 if assertion_result.error_count > 0 && self.is_strict() {
585 result.valid = false;
586 }
587
588 Some(score)
589 }
590 Err(e) => {
591 warn!("Assertion evaluation failed for {}: {}", envelope.stype, e);
592 None
593 }
594 }
595 } else {
596 None
597 };
598
599 let oa_result = self.qom_recorder.check_ontology(&envelope.stype, &envelope.payload).await;
601 let oa_score = if oa_result.constraints_checked > 0 {
602 Some(oa_result.score)
603 } else {
604 None
605 };
606
607 let dj_score = if let (Some(resp), Some(ref hash)) = (response, &payload_hash) {
609 let dj_result = self.qom_recorder.check_determinism(&envelope.stype, hash, resp).await;
610 if dj_result.comparison_count > 0 {
611 Some(dj_result.similarity)
612 } else {
613 None
614 }
615 } else {
616 None
617 };
618
619 let profile_name = if let Some(profile) = self.active_profile() {
621 let mut metrics = if result.schema_valid {
622 QomMetrics::schema_valid()
623 } else {
624 QomMetrics::schema_invalid()
625 };
626
627 if let Some(ic) = ic_score {
629 metrics = metrics.with_instruction_compliance(ic);
630 }
631 if let Some(oa) = oa_score {
632 metrics = metrics.with_ontology_adherence(oa);
633 }
634 if let Some(dj) = dj_score {
635 metrics = metrics.with_determinism_jitter(dj);
636 }
637
638 let evaluation = profile.evaluate(&metrics);
639 result.qom_passed = evaluation.meets_profile;
640 result.profile_used = Some(profile.name.clone());
641
642 if !evaluation.meets_profile {
643 result.valid = false;
644 for failure in evaluation.failures {
645 result
646 .errors
647 .push(format!("QoM breach: {} < {}", failure.metric, failure.threshold));
648 }
649 }
650
651 Some(profile.name.clone())
652 } else {
653 None
654 };
655
656 self.metrics.inc_requests();
658 if result.schema_valid {
659 self.metrics.inc_schema_pass();
660 } else {
661 self.metrics.inc_schema_fail();
662 }
663 if result.qom_passed {
664 self.metrics.inc_qom_pass();
665 } else {
666 self.metrics.inc_qom_fail();
667 }
668
669 let scores = QomScores {
671 sf: Some(sf_score),
672 ic: ic_score,
673 toc: None, g: None, dj: dj_score,
676 oa: oa_score,
677 };
678
679 let failure_reason = if !result.qom_passed && !result.errors.is_empty() {
680 Some(result.errors.join("; "))
681 } else {
682 None
683 };
684
685 let event = self.qom_recorder.create_event(
686 &envelope.stype,
687 &profile_name.unwrap_or_else(|| "none".to_string()),
688 result.qom_passed,
689 scores,
690 failure_reason,
691 payload_hash,
692 );
693 self.qom_recorder.record_event(event).await;
694
695 result
696 }
697
698 pub async fn forward_request(
700 &self,
701 path: String,
702 request: Request<Body>,
703 ) -> Result<Response<Body>> {
704 use crate::traffic::{StypeInferrer, TrafficRecord};
705
706 let start_time = std::time::Instant::now();
707 let upstream = &self.config.transport.upstream;
708 let uri = format!("http://{}/{}", upstream, path);
709
710 debug!("Forwarding to: {}", uri);
711
712 let method = request.method().clone();
714 let method_str = method.to_string();
715 let headers = request.headers().clone();
716 let stype_header = headers
717 .get(HEADER_STYPE)
718 .and_then(|v| v.to_str().ok())
719 .map(String::from);
720
721 let body_bytes = axum::body::to_bytes(request.into_body(), usize::MAX).await?;
723
724 let payload: serde_json::Value = serde_json::from_slice(&body_bytes)
726 .unwrap_or(serde_json::Value::Null);
727
728 let envelope = if let Ok(env) = serde_json::from_slice::<MplEnvelope>(&body_bytes) {
730 Some(env)
731 } else if let Some(stype) = stype_header.clone() {
732 if let Ok(payload) = serde_json::from_slice(&body_bytes) {
734 Some(MplEnvelope::new(stype, payload))
735 } else {
736 None
737 }
738 } else {
739 None
740 };
741
742 let stype = envelope
744 .as_ref()
745 .map(|e| e.stype.clone())
746 .or(stype_header)
747 .unwrap_or_else(|| StypeInferrer::infer(&path, &method_str, &payload));
748
749 let validation_result = if let Some(ref env) = envelope {
751 let result = self.validate_request(env).await;
752
753 if !result.valid && self.is_strict() {
755 let error_response = serde_json::json!({
756 "error": "E-SCHEMA-FIDELITY",
757 "message": "Request validation failed",
758 "details": result.errors,
759 });
760
761 if self.traffic_recorder.is_enabled() {
763 let record = TrafficRecord {
764 id: self.traffic_recorder.next_id(),
765 timestamp: chrono::Utc::now().to_rfc3339(),
766 stype: stype.clone(),
767 method: method_str.clone(),
768 path: path.clone(),
769 payload: payload.clone(),
770 response: Some(error_response.clone()),
771 status_code: Some(400),
772 duration_ms: Some(start_time.elapsed().as_millis() as u64),
773 validation_passed: false,
774 validation_errors: result.errors.clone(),
775 };
776 self.traffic_recorder.record(record);
777 }
778
779 return Ok(Response::builder()
780 .status(StatusCode::BAD_REQUEST)
781 .header("content-type", "application/json")
782 .header(HEADER_QOM_RESULT, if result.qom_passed { "pass" } else { "fail" })
783 .body(Body::from(serde_json::to_string(&error_response)?))?);
784 }
785
786 Some(result)
787 } else {
788 None
789 };
790
791 let mut req_builder = self.client.request(method, &uri);
793
794 for (name, value) in headers.iter() {
795 if name != "host" {
796 req_builder = req_builder.header(name, value);
797 }
798 }
799
800 let upstream_response = req_builder.body(body_bytes.to_vec()).send().await?;
801
802 let status = upstream_response.status();
804 let status_code = status.as_u16();
805 let response_headers = upstream_response.headers().clone();
806 let body = upstream_response.bytes().await?;
807
808 if self.traffic_recorder.is_enabled() {
810 let response_payload: Option<serde_json::Value> = serde_json::from_slice(&body).ok();
811
812 let record = TrafficRecord {
813 id: self.traffic_recorder.next_id(),
814 timestamp: chrono::Utc::now().to_rfc3339(),
815 stype,
816 method: method_str,
817 path,
818 payload,
819 response: response_payload,
820 status_code: Some(status_code),
821 duration_ms: Some(start_time.elapsed().as_millis() as u64),
822 validation_passed: validation_result.as_ref().map(|r| r.valid).unwrap_or(true),
823 validation_errors: validation_result
824 .as_ref()
825 .map(|r| r.errors.clone())
826 .unwrap_or_default(),
827 };
828 self.traffic_recorder.record(record);
829 }
830
831 let mut response = Response::builder().status(status);
832
833 for (name, value) in response_headers.iter() {
834 response = response.header(name, value);
835 }
836
837 if let Some(ref result) = validation_result {
839 response = response.header(HEADER_QOM_RESULT, if result.qom_passed { "pass" } else { "fail" });
840 }
841
842 Ok(response.body(Body::from(body))?)
843 }
844
845 pub fn active_profile(&self) -> Option<&QomProfile> {
847 let profile_name = self.config.mpl.required_profile.as_ref()?;
848 self.profiles.iter().find(|p| &p.name == profile_name)
849 }
850
851 pub fn get_profile(&self, name: &str) -> Option<&QomProfile> {
853 self.profiles.iter().find(|p| p.name == name)
854 }
855
856 pub fn get_degradation_chain(&self, start_profile: &str) -> Vec<&QomProfile> {
859 let order = ["qom-comprehensive", "qom-outcome", "qom-strict-argcheck", "qom-basic"];
861
862 let start_idx = order.iter().position(|&p| p == start_profile).unwrap_or(0);
863
864 order[start_idx..]
865 .iter()
866 .filter_map(|&name| self.get_profile(name))
867 .collect()
868 }
869
870 pub async fn validate_with_degradation(
873 &self,
874 envelope: &MplEnvelope,
875 ) -> ValidationResult {
876 let original_profile = self.config.mpl.required_profile.clone();
877
878 if let Some(ref profile_name) = original_profile {
879 let chain = self.get_degradation_chain(profile_name);
880
881 for (idx, profile) in chain.iter().enumerate() {
882 let mut result = self.validate_request_with_profile(envelope, profile);
883
884 if result.qom_passed || idx == chain.len() - 1 {
885 result.profile_used = Some(profile.name.clone());
887 result.degraded = idx > 0;
888 if idx > 0 {
889 result.original_profile = Some(profile_name.clone());
890 }
891 return result;
892 }
893 }
894 }
895
896 self.validate_request(envelope).await
898 }
899
900 fn validate_request_with_profile(
902 &self,
903 envelope: &MplEnvelope,
904 profile: &QomProfile,
905 ) -> ValidationResult {
906 let mut result = ValidationResult {
907 stype: Some(envelope.stype.clone()),
908 profile_used: Some(profile.name.clone()),
909 ..Default::default()
910 };
911
912 if self.config.mpl.enforce_schema {
914 match self.validator.validate(&envelope.stype, &envelope.payload) {
915 Ok(validation) => {
916 result.schema_valid = validation.valid;
917 if !validation.valid {
918 result.valid = false;
919 for err in validation.errors {
920 result.errors.push(format!("Schema error at {}: {}", err.path, err.message));
921 }
922 }
923 }
924 Err(e) => {
925 if self.is_strict() {
926 result.valid = false;
927 result.schema_valid = false;
928 result.errors.push(format!("Unknown SType: {} ({})", envelope.stype, e));
929 }
930 }
931 }
932 }
933
934 let ic_score = if let Some(assertion_set) = self.get_assertions(&envelope.stype) {
936 let eval_ctx = EvaluationContext {
937 stype: Some(envelope.stype.clone()),
938 ..Default::default()
939 };
940
941 match assertion_set.evaluate_with_context(&envelope.payload, &eval_ctx) {
942 Ok(assertion_result) => {
943 result.ic_score = Some(assertion_result.ic_score);
944 Some(assertion_result.ic_score)
945 }
946 Err(_) => None,
947 }
948 } else {
949 None
950 };
951
952 let mut metrics = if result.schema_valid {
954 QomMetrics::schema_valid()
955 } else {
956 QomMetrics::schema_invalid()
957 };
958
959 if let Some(ic) = ic_score {
960 metrics = metrics.with_instruction_compliance(ic);
961 }
962
963 let evaluation = profile.evaluate(&metrics);
964 result.qom_passed = evaluation.meets_profile;
965
966 if !evaluation.meets_profile {
967 for failure in evaluation.failures {
968 result.errors.push(format!("QoM breach [{}]: {} < {}",
969 profile.name, failure.metric, failure.threshold));
970 }
971 }
972
973 result
974 }
975
976 pub fn is_strict(&self) -> bool {
978 matches!(self.config.mpl.mode, ProxyMode::Strict)
979 }
980}
981
982#[derive(Debug, Clone, Serialize, Deserialize)]
984pub struct AiAlpnClientHello {
985 #[serde(rename = "type")]
986 pub msg_type: String,
987 pub version: String,
988 pub stypes: Vec<String>,
989 #[serde(default)]
990 pub qom_profiles: Vec<String>,
991}
992
993#[derive(Debug, Clone, Serialize, Deserialize)]
995pub struct AiAlpnServerSelect {
996 #[serde(rename = "type")]
997 pub msg_type: String,
998 pub common_stypes: Vec<String>,
999 pub selected_profile: Option<String>,
1000 #[serde(default)]
1001 pub extensions: serde_json::Value,
1002}
1003
1004impl ProxyState {
1005 pub fn handle_handshake(&self, hello: AiAlpnClientHello) -> AiAlpnServerSelect {
1007 let server_stypes: Vec<String> = self.validator.registered_stypes().iter().map(|s| s.to_string()).collect();
1009 let common_stypes: Vec<String> = hello
1010 .stypes
1011 .iter()
1012 .filter(|s| server_stypes.contains(s))
1013 .cloned()
1014 .collect();
1015
1016 let selected_profile = hello
1018 .qom_profiles
1019 .iter()
1020 .find(|p| self.profiles.iter().any(|sp| &sp.name == *p))
1021 .cloned()
1022 .or_else(|| self.config.mpl.required_profile.clone());
1023
1024 self.metrics.inc_handshakes();
1025
1026 AiAlpnServerSelect {
1027 msg_type: "ai-alpn-select".to_string(),
1028 common_stypes,
1029 selected_profile,
1030 extensions: serde_json::json!({}),
1031 }
1032 }
1033}