1use crate::client::AxonFlowClient;
24use crate::error::AxonFlowError;
25use crate::types::pep::{
26 DecideRequest, DecideResponse, MCPCheckInputRequest, MCPCheckInputResponse, Obligation,
27};
28
/// Obligation `type` asking the caller to redact PII before forwarding content.
pub const OBLIGATION_REDACT_PII: &str = "redact_pii";

/// Fulfillment phase for obligations applied to the outgoing request; the only
/// phase `AxonFlowClient::fulfill_request` accepts for `redact_pii`.
pub const PHASE_REQUEST: &str = "request";
/// Fulfillment phase for obligations applied to a response. Not fulfilled by
/// this module: `fulfill_request` rejects `redact_pii` obligations in this phase.
pub const PHASE_RESPONSE: &str = "response";

/// Content type sent to the redaction engine. A fulfillment that advertises a
/// non-empty `content_types` list must include it.
pub const CONTENT_TYPE_TEXT: &str = "text/plain";

/// Verdict under which `decide_and_fulfill` performs request redaction.
pub const VERDICT_ALLOW: &str = "allow";
/// Verdict value for a denied request.
pub const VERDICT_DENY: &str = "deny";
/// Verdict value for a request that needs approval (presumably out-of-band
/// review — TODO confirm against the server contract).
pub const VERDICT_NEEDS_APPROVAL: &str = "needs_approval";

/// Path, appended to the client endpoint, of the policy decision call.
pub const DECIDE_PATH: &str = "/api/v1/decide";
/// Path of the engine call that performs request-phase redaction. It is also
/// the only path a request-phase `redact_pii` fulfillment may advertise.
pub const REQUEST_REDACTION_PATH: &str = "/api/v1/mcp/check-input";
/// Response-phase counterpart of `REQUEST_REDACTION_PATH`; not called here.
pub const RESPONSE_REDACTION_PATH: &str = "/api/v1/mcp/check-output";

/// `connector_type` sent with request-redaction engine calls.
pub const GATEWAY_CONNECTOR_TAG: &str = "gateway";
71
72pub fn has_request_redaction(obligations: &[Obligation]) -> bool {
77 obligations.iter().any(|o| {
78 o.r#type == OBLIGATION_REDACT_PII
79 && o.fulfillment
80 .as_ref()
81 .is_some_and(|f| f.phase == PHASE_REQUEST)
82 })
83}
84
/// Reports whether a fulfillment `endpoint` refers to the `expected` path.
///
/// `endpoint` may be either form:
/// - a bare path, e.g. `/api/v1/mcp/check-input`;
/// - an absolute URL, e.g. `https://host:8443/api/v1/mcp/check-input`.
///
/// Surrounding whitespace, a query string and a fragment are ignored. Only the
/// path is compared: the scheme and authority of an absolute URL are not
/// checked.
pub(crate) fn endpoint_path_matches(endpoint: &str, expected: &str) -> bool {
    let e = endpoint.trim();
    if e == expected {
        return true;
    }

    // Treat `e` as an absolute URL only when the text before "://" is a valid
    // RFC 3986 scheme. Otherwise a "://" inside the query of a relative path
    // (e.g. "/x?u=https://h/api/...") would be mistaken for a URL and its tail
    // compared instead of the real path.
    let after_scheme = e.split_once("://").and_then(|(scheme, rest)| {
        let mut chars = scheme.chars();
        let valid = chars.next().is_some_and(|c| c.is_ascii_alphabetic())
            && chars.all(|c| c.is_ascii_alphanumeric() || matches!(c, '+' | '-' | '.'));
        valid.then_some(rest)
    });

    let path = match after_scheme {
        // The authority ends at the first '/', '?' or '#'. Only a '/' starts a
        // path; "https://host?next=/p" has an empty path and must not match.
        Some(rest) => match rest.find(|c: char| matches!(c, '/' | '?' | '#')) {
            Some(i) if rest[i..].starts_with('/') => &rest[i..],
            _ => return false,
        },
        None => e,
    };

    // Drop any query string or fragment, for relative and absolute forms alike.
    let path = path
        .find(|c: char| matches!(c, '?' | '#'))
        .map_or(path, |i| &path[..i]);
    path == expected
}
106
107impl AxonFlowClient {
108 pub async fn decide(&self, request: DecideRequest) -> Result<DecideResponse, AxonFlowError> {
140 let url = format!("{}{}", self.endpoint(), DECIDE_PATH);
141 let resp = self.checked_post_json(&url, &request).await?;
145 let body = resp.text().await?;
146 let parsed: DecideResponse = serde_json::from_str(&body)?;
147 Ok(parsed)
148 }
149
150 pub async fn fulfill_request(
174 &self,
175 decision: &DecideResponse,
176 statement: &str,
177 ) -> Result<(String, bool), AxonFlowError> {
178 let mut redacted = statement.to_string();
179 let mut did_redact = false;
180 for ob in &decision.obligations {
181 if ob.r#type != OBLIGATION_REDACT_PII {
182 continue;
185 }
186 let fulfillment = match &ob.fulfillment {
187 Some(f) if f.phase == PHASE_REQUEST => f,
188 _ => {
189 return Err(AxonFlowError::ObligationNotFulfillable(
192 "redact_pii obligation missing request-phase fulfillment".to_string(),
193 ));
194 }
195 };
196 if let Some(cts) = &fulfillment.content_types {
200 if !cts.is_empty() && !cts.iter().any(|c| c == CONTENT_TYPE_TEXT) {
201 return Err(AxonFlowError::ObligationNotFulfillable(format!(
202 "fulfillment endpoint does not advertise a {CONTENT_TYPE_TEXT} detector"
203 )));
204 }
205 }
206 if !endpoint_path_matches(&fulfillment.endpoint, REQUEST_REDACTION_PATH) {
207 return Err(AxonFlowError::ObligationNotFulfillable(format!(
208 "fulfillment endpoint {:?} is not the request-redaction endpoint",
209 fulfillment.endpoint
210 )));
211 }
212 redacted = self.fulfill_via_check_input(&redacted).await?;
213 if redacted != statement {
214 did_redact = true;
215 }
216 }
217 Ok((redacted, did_redact))
218 }
219
220 async fn fulfill_via_check_input(&self, statement: &str) -> Result<String, AxonFlowError> {
227 let req = MCPCheckInputRequest {
228 connector_type: GATEWAY_CONNECTOR_TAG.to_string(),
229 statement: statement.to_string(),
230 operation: Some("execute".to_string()),
231 tenant_id: None,
232 content_type: Some(CONTENT_TYPE_TEXT.to_string()),
233 };
234 let url = format!("{}{}", self.endpoint(), REQUEST_REDACTION_PATH);
235 let result: MCPCheckInputResponse = match self.checked_post_json(&url, &req).await {
236 Ok(resp) => {
237 let body = resp.text().await?;
238 serde_json::from_str(&body).map_err(|e| {
239 AxonFlowError::ObligationNotFulfillable(format!(
240 "decode request-redaction engine response: {e}"
241 ))
242 })?
243 }
244 Err(e) => {
245 return Err(AxonFlowError::ObligationNotFulfillable(format!(
246 "request-redaction engine call failed: {e}"
247 )));
248 }
249 };
250 if !result.redaction_evaluated {
254 return Err(AxonFlowError::ObligationNotFulfillable(
255 "engine reported the redactor did not run (redaction disabled)".to_string(),
256 ));
257 }
258 match (result.redacted, result.redacted_statement) {
259 (true, Some(masked)) if !masked.is_empty() => Ok(masked),
260 (true, _) => Err(AxonFlowError::ObligationNotFulfillable(
265 "engine reported redacted=true but returned no redacted_statement".to_string(),
266 )),
267 (false, _) => Ok(statement.to_string()),
269 }
270 }
271
272 pub async fn decide_and_fulfill(
288 &self,
289 request: DecideRequest,
290 ) -> Result<(String, String, DecideResponse), AxonFlowError> {
291 let query = request.query.clone();
292 let decision = self.decide(request).await?;
293 if decision.verdict != VERDICT_ALLOW {
294 return Ok((decision.verdict.clone(), query, decision));
295 }
296 let (redacted, _) = self.fulfill_request(&decision, &query).await?;
297 Ok((decision.verdict.clone(), redacted, decision))
298 }
299}
300
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::pep::ObligationFulfillment;
    use crate::{AxonFlowConfig, AxonFlowError};
    use serde_json::json;
    use std::time::Duration;
    use wiremock::matchers::{body_partial_json, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Builds a client against `endpoint` with fixed test credentials and a
    /// short timeout.
    fn make_client(endpoint: String) -> AxonFlowClient {
        let config = AxonFlowConfig {
            endpoint,
            client_id: Some("org-1".into()),
            client_secret: Some("license-1".into()),
            timeout: Duration::from_secs(2),
            ..Default::default()
        };
        AxonFlowClient::new(config).expect("client init")
    }

    /// Starts a mock server whose request-redaction endpoint must not be hit.
    ///
    /// The mock answers 500 if called. Its `expect(0)` is verified by wiremock
    /// when the returned server is dropped, so a test that keeps the server
    /// alive for its whole body fails if any engine call is made. Pointing the
    /// client at an unreachable port cannot catch this: an engine call there
    /// also yields `ObligationNotFulfillable`.
    async fn server_rejecting_engine_calls() -> MockServer {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/api/v1/mcp/check-input"))
            .respond_with(ResponseTemplate::new(500))
            .expect(0)
            .mount(&server)
            .await;
        server
    }

    /// A well-formed request-phase `redact_pii` obligation.
    fn redact_obligation() -> Obligation {
        Obligation {
            r#type: OBLIGATION_REDACT_PII.into(),
            detail: None,
            fulfillment: Some(ObligationFulfillment {
                endpoint: REQUEST_REDACTION_PATH.into(),
                method: "POST".into(),
                phase: PHASE_REQUEST.into(),
                content_types: Some(vec![CONTENT_TYPE_TEXT.into()]),
            }),
        }
    }

    /// An `allow` decision carrying `obligations`.
    fn allow_with(obligations: Vec<Obligation>) -> DecideResponse {
        DecideResponse {
            verdict: VERDICT_ALLOW.into(),
            obligations,
            ..Default::default()
        }
    }

    /// Bare paths (optionally padded) and absolute URLs with a port or query
    /// string all match the request-redaction path.
    #[test]
    fn endpoint_path_matches_exact_and_absolute() {
        assert!(endpoint_path_matches(
            REQUEST_REDACTION_PATH,
            REQUEST_REDACTION_PATH
        ));
        assert!(endpoint_path_matches(
            " /api/v1/mcp/check-input ",
            REQUEST_REDACTION_PATH
        ));
        assert!(endpoint_path_matches(
            "https://pdp.internal:8443/api/v1/mcp/check-input",
            REQUEST_REDACTION_PATH
        ));
        assert!(endpoint_path_matches(
            "https://pdp.internal/api/v1/mcp/check-input?x=1",
            REQUEST_REDACTION_PATH
        ));
    }

    /// Empty, different, foreign, and path-less endpoints do not match.
    #[test]
    fn endpoint_path_matches_rejects_foreign() {
        assert!(!endpoint_path_matches("", REQUEST_REDACTION_PATH));
        assert!(!endpoint_path_matches(
            "/api/v1/mcp/check-output",
            REQUEST_REDACTION_PATH
        ));
        assert!(!endpoint_path_matches(
            "https://evil.example.com/steal",
            REQUEST_REDACTION_PATH
        ));
        assert!(!endpoint_path_matches(
            "https://pdp.internal",
            REQUEST_REDACTION_PATH
        ));
    }

    /// A request-phase `redact_pii` obligation is detected.
    #[test]
    fn has_request_redaction_detects_request_phase() {
        assert!(has_request_redaction(&[redact_obligation()]));
    }

    /// Response-phase, fulfillment-less, and non-`redact_pii` obligations (and
    /// an empty list) are not request redactions.
    #[test]
    fn has_request_redaction_ignores_response_phase_and_no_fulfillment() {
        let resp_phase = Obligation {
            r#type: OBLIGATION_REDACT_PII.into(),
            detail: None,
            fulfillment: Some(ObligationFulfillment {
                endpoint: RESPONSE_REDACTION_PATH.into(),
                method: "POST".into(),
                phase: PHASE_RESPONSE.into(),
                content_types: None,
            }),
        };
        let no_fulfillment = Obligation {
            r#type: OBLIGATION_REDACT_PII.into(),
            detail: None,
            fulfillment: None,
        };
        let other_type = Obligation {
            r#type: "log_only".into(),
            detail: None,
            fulfillment: Some(ObligationFulfillment {
                endpoint: REQUEST_REDACTION_PATH.into(),
                method: "POST".into(),
                phase: PHASE_REQUEST.into(),
                content_types: None,
            }),
        };
        assert!(!has_request_redaction(&[
            resp_phase,
            no_fulfillment,
            other_type
        ]));
        assert!(!has_request_redaction(&[]));
    }

    /// A full `allow` response with a `redact_pii` obligation decodes into the
    /// expected fields.
    #[tokio::test]
    async fn decide_parses_allow_with_obligation() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/api/v1/decide"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "verdict": "allow",
                "decision_id": "dec-1",
                "trace_id": "04110a0b50577bbbdda23a00dcbaf6da",
                "obligations": [{
                    "type": "redact_pii",
                    "fulfillment": {
                        "endpoint": "/api/v1/mcp/check-input",
                        "method": "POST",
                        "phase": "request",
                        "content_types": ["text/plain"],
                    },
                }],
                "evaluated_policies": ["sys_pii_email"],
                "stage": "tool",
                "expires_at": "2026-06-09T05:05:06.8Z",
            })))
            .mount(&server)
            .await;

        let client = make_client(server.uri());
        let d = client
            .decide(DecideRequest::new("tool", "send to a@b.com"))
            .await
            .unwrap();
        assert_eq!(d.verdict, "allow");
        assert_eq!(d.decision_id.as_deref(), Some("dec-1"));
        assert_eq!(d.obligations.len(), 1);
        assert_eq!(d.obligations[0].r#type, "redact_pii");
        assert!(has_request_redaction(&d.obligations));
        assert_eq!(d.evaluated_policies, vec!["sys_pii_email"]);
    }

    /// A `deny` verdict delivered with HTTP 200 is returned as `Ok`.
    #[tokio::test]
    async fn decide_returns_deny_in_body_not_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/api/v1/decide"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "verdict": "deny",
                "error": "stage is required and must be one of: llm, tool, agent",
            })))
            .mount(&server)
            .await;

        let client = make_client(server.uri());
        let d = client
            .decide(DecideRequest::new("", "x"))
            .await
            .expect("deny is a 200 body, not an error");
        assert_eq!(d.verdict, "deny");
        assert!(d.error.is_some());
        assert!(d.obligations.is_empty());
    }

    /// An HTTP 401 from the decide endpoint surfaces as `ApiError { status: 401 }`.
    #[tokio::test]
    async fn decide_maps_401_to_api_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/api/v1/decide"))
            .respond_with(ResponseTemplate::new(401).set_body_string("unauthorized"))
            .mount(&server)
            .await;

        let client = make_client(server.uri());
        let err = client
            .decide(DecideRequest::new("tool", "x"))
            .await
            .unwrap_err();
        match err {
            AxonFlowError::ApiError { status, .. } => assert_eq!(status, 401),
            other => panic!("expected ApiError 401, got {other:?}"),
        }
    }

    /// The decide call sends `stage` and `query` in its JSON body exactly once.
    // NOTE(review): only the body is matched; the Authorization header implied
    // by the test name is not asserted.
    #[tokio::test]
    async fn decide_sends_basic_auth_and_body() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/api/v1/decide"))
            .and(body_partial_json(json!({"stage": "tool", "query": "hi"})))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({"verdict": "allow"})))
            .expect(1)
            .mount(&server)
            .await;

        let client = make_client(server.uri());
        let d = client
            .decide(DecideRequest::new("tool", "hi"))
            .await
            .unwrap();
        assert_eq!(d.verdict, "allow");
    }

    /// When the engine redacts, its masked statement replaces the input.
    #[tokio::test]
    async fn fulfill_request_returns_engine_masked_content() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/api/v1/mcp/check-input"))
            .and(body_partial_json(
                json!({"connector_type": "gateway", "content_type": "text/plain"}),
            ))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "allowed": true,
                "redacted": true,
                "redacted_statement": "Email jo****om and card 4****1",
                "redaction_evaluated": true,
            })))
            .mount(&server)
            .await;

        let client = make_client(server.uri());
        let (content, did_redact) = client
            .fulfill_request(
                &allow_with(vec![redact_obligation()]),
                "Email john and card 4111",
            )
            .await
            .unwrap();
        assert!(did_redact);
        assert_eq!(content, "Email jo****om and card 4****1");
    }

    /// With no obligations the statement passes through without a network call.
    #[tokio::test]
    async fn fulfill_request_no_obligation_is_passthrough() {
        let client = make_client("http://127.0.0.1:1".into());
        let (content, did_redact) = client
            .fulfill_request(&allow_with(vec![]), "untouched")
            .await
            .unwrap();
        assert!(!did_redact);
        assert_eq!(content, "untouched");
    }

    /// An engine that evaluated redaction but found nothing leaves the text as-is.
    #[tokio::test]
    async fn fulfill_request_engine_found_nothing_is_passthrough() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/api/v1/mcp/check-input"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "allowed": true,
                "redacted": false,
                "redaction_evaluated": true,
            })))
            .mount(&server)
            .await;

        let client = make_client(server.uri());
        let (content, did_redact) = client
            .fulfill_request(&allow_with(vec![redact_obligation()]), "no pii here")
            .await
            .unwrap();
        assert!(!did_redact);
        assert_eq!(content, "no pii here");
    }

    /// A `redact_pii` obligation without a fulfillment fails closed, and no
    /// engine call is made.
    #[tokio::test]
    async fn fulfill_fails_closed_on_missing_request_phase_fulfillment() {
        let ob = Obligation {
            r#type: OBLIGATION_REDACT_PII.into(),
            detail: None,
            fulfillment: None,
        };
        let server = server_rejecting_engine_calls().await;
        let client = make_client(server.uri());
        let err = client
            .fulfill_request(&allow_with(vec![ob]), "secret a@b.com")
            .await
            .unwrap_err();
        assert!(matches!(err, AxonFlowError::ObligationNotFulfillable(_)));
    }

    /// A response-phase `redact_pii` obligation cannot be fulfilled on a
    /// request; no engine call is made.
    #[tokio::test]
    async fn fulfill_fails_closed_on_response_phase_obligation() {
        let ob = Obligation {
            r#type: OBLIGATION_REDACT_PII.into(),
            detail: None,
            fulfillment: Some(ObligationFulfillment {
                endpoint: REQUEST_REDACTION_PATH.into(),
                method: "POST".into(),
                phase: PHASE_RESPONSE.into(),
                content_types: None,
            }),
        };
        let server = server_rejecting_engine_calls().await;
        let client = make_client(server.uri());
        let err = client
            .fulfill_request(&allow_with(vec![ob]), "secret a@b.com")
            .await
            .unwrap_err();
        assert!(matches!(err, AxonFlowError::ObligationNotFulfillable(_)));
    }

    /// A fulfillment that advertises only non-text content types fails closed;
    /// no engine call is made.
    #[tokio::test]
    async fn fulfill_fails_closed_on_unadvertised_content_type() {
        let ob = Obligation {
            r#type: OBLIGATION_REDACT_PII.into(),
            detail: None,
            fulfillment: Some(ObligationFulfillment {
                endpoint: REQUEST_REDACTION_PATH.into(),
                method: "POST".into(),
                phase: PHASE_REQUEST.into(),
                content_types: Some(vec!["image/png".into()]),
            }),
        };
        let server = server_rejecting_engine_calls().await;
        let client = make_client(server.uri());
        let err = client
            .fulfill_request(&allow_with(vec![ob]), "secret a@b.com")
            .await
            .unwrap_err();
        assert!(matches!(err, AxonFlowError::ObligationNotFulfillable(_)));
    }

    /// A fulfillment pointing at a foreign endpoint fails closed; no engine
    /// call is made.
    #[tokio::test]
    async fn fulfill_fails_closed_on_foreign_endpoint() {
        let ob = Obligation {
            r#type: OBLIGATION_REDACT_PII.into(),
            detail: None,
            fulfillment: Some(ObligationFulfillment {
                endpoint: "https://evil.example.com/steal".into(),
                method: "POST".into(),
                phase: PHASE_REQUEST.into(),
                content_types: Some(vec![CONTENT_TYPE_TEXT.into()]),
            }),
        };
        let server = server_rejecting_engine_calls().await;
        let client = make_client(server.uri());
        let err = client
            .fulfill_request(&allow_with(vec![ob]), "secret a@b.com")
            .await
            .unwrap_err();
        assert!(matches!(err, AxonFlowError::ObligationNotFulfillable(_)));
    }

    /// An engine HTTP error maps to `ObligationNotFulfillable`.
    #[tokio::test]
    async fn fulfill_fails_closed_on_engine_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/api/v1/mcp/check-input"))
            .respond_with(ResponseTemplate::new(500).set_body_string("boom"))
            .mount(&server)
            .await;

        let client = make_client(server.uri());
        let err = client
            .fulfill_request(&allow_with(vec![redact_obligation()]), "secret a@b.com")
            .await
            .unwrap_err();
        assert!(matches!(err, AxonFlowError::ObligationNotFulfillable(_)));
    }

    /// `redaction_evaluated: false` means the redactor did not run: fail closed.
    #[tokio::test]
    async fn fulfill_fails_closed_when_redaction_evaluated_false() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/api/v1/mcp/check-input"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "allowed": true,
                "redacted": false,
                "redaction_evaluated": false,
            })))
            .mount(&server)
            .await;

        let client = make_client(server.uri());
        let err = client
            .fulfill_request(&allow_with(vec![redact_obligation()]), "secret a@b.com")
            .await
            .unwrap_err();
        assert!(matches!(err, AxonFlowError::ObligationNotFulfillable(_)));
    }

    /// `redacted: true` without a masked statement fails closed rather than
    /// returning the original text.
    #[tokio::test]
    async fn fulfill_fails_closed_when_redacted_true_without_statement() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/api/v1/mcp/check-input"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "allowed": true,
                "redacted": true,
                "redaction_evaluated": true,
            })))
            .mount(&server)
            .await;

        let client = make_client(server.uri());
        let err = client
            .fulfill_request(&allow_with(vec![redact_obligation()]), "secret a@b.com")
            .await
            .unwrap_err();
        assert!(matches!(err, AxonFlowError::ObligationNotFulfillable(_)));
    }

    /// A response missing `redaction_evaluated` is treated as "did not run".
    #[tokio::test]
    async fn fulfill_fails_closed_when_redaction_evaluated_absent() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/api/v1/mcp/check-input"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "allowed": true,
                "redacted": true,
                "redacted_statement": "Email jo****om",
            })))
            .mount(&server)
            .await;

        let client = make_client(server.uri());
        let err = client
            .fulfill_request(&allow_with(vec![redact_obligation()]), "Email john")
            .await
            .unwrap_err();
        assert!(matches!(err, AxonFlowError::ObligationNotFulfillable(_)));
    }

    /// Obligations of other types are ignored, even without a fulfillment.
    #[tokio::test]
    async fn fulfill_ignores_non_redact_obligation_types() {
        let ob = Obligation {
            r#type: "audit_only".into(),
            detail: None,
            fulfillment: None,
        };
        let client = make_client("http://127.0.0.1:1".into());
        let (content, did_redact) = client
            .fulfill_request(&allow_with(vec![ob]), "left alone")
            .await
            .unwrap();
        assert!(!did_redact);
        assert_eq!(content, "left alone");
    }

    /// End to end: an `allow` with a redaction obligation returns masked content.
    #[tokio::test]
    async fn decide_and_fulfill_allow_redacts() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/api/v1/decide"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "verdict": "allow",
                "obligations": [{
                    "type": "redact_pii",
                    "fulfillment": {
                        "endpoint": "/api/v1/mcp/check-input",
                        "phase": "request",
                        "content_types": ["text/plain"],
                    },
                }],
            })))
            .mount(&server)
            .await;
        Mock::given(method("POST"))
            .and(path("/api/v1/mcp/check-input"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "allowed": true,
                "redacted": true,
                "redacted_statement": "card 4****1",
                "redaction_evaluated": true,
            })))
            .mount(&server)
            .await;

        let client = make_client(server.uri());
        let (verdict, content, decision) = client
            .decide_and_fulfill(DecideRequest::new("tool", "card 4111111111111111"))
            .await
            .unwrap();
        assert_eq!(verdict, "allow");
        assert_eq!(content, "card 4****1");
        assert_eq!(decision.verdict, "allow");
        assert!(!content.contains("4111111111111111"));
    }

    /// A `deny` verdict returns the original query, and the `expect(0)` mock
    /// proves the redaction engine is never called.
    #[tokio::test]
    async fn decide_and_fulfill_deny_returns_original_without_engine_call() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/api/v1/decide"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "verdict": "deny",
                "reasons": ["blocked by policy"],
            })))
            .mount(&server)
            .await;
        Mock::given(method("POST"))
            .and(path("/api/v1/mcp/check-input"))
            .respond_with(ResponseTemplate::new(500))
            .expect(0)
            .mount(&server)
            .await;

        let client = make_client(server.uri());
        let (verdict, content, _) = client
            .decide_and_fulfill(DecideRequest::new("tool", "original query"))
            .await
            .unwrap();
        assert_eq!(verdict, "deny");
        assert_eq!(content, "original query");
    }

    /// An unfulfillable obligation on `allow` is an error, not a fall-back to
    /// the original query, and the engine is never called.
    #[tokio::test]
    async fn decide_and_fulfill_unfulfillable_surfaces_error_not_original() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/api/v1/decide"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "verdict": "allow",
                "obligations": [{
                    "type": "redact_pii",
                    "fulfillment": {
                        "endpoint": "https://evil.example.com/steal",
                        "phase": "request",
                    },
                }],
            })))
            .mount(&server)
            .await;
        Mock::given(method("POST"))
            .and(path("/api/v1/mcp/check-input"))
            .respond_with(ResponseTemplate::new(500))
            .expect(0)
            .mount(&server)
            .await;

        let client = make_client(server.uri());
        let err = client
            .decide_and_fulfill(DecideRequest::new("tool", "leak me a@b.com"))
            .await
            .unwrap_err();
        assert!(matches!(err, AxonFlowError::ObligationNotFulfillable(_)));
    }
}