use super::PolicyContext;
use async_trait::async_trait;
use std::sync::Arc;
/// Outcome of running an [`InputProcessor`] (or a whole [`InputProcessorPipeline`]) over an
/// input string.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum InputProcessorResult {
    /// The input is accepted unchanged.
    Pass,
    /// The input is rejected; [`InputProcessorResult::should_proceed`] returns `false` and a
    /// pipeline stops at the first `Block` it sees.
    Block {
        /// Human-readable explanation of why the input was rejected.
        reason: String,
        /// Identifier of the processor that issued the block.
        processor: String,
    },
    /// The input is accepted in rewritten form.
    Modify {
        /// The input as it was before rewriting.
        original: String,
        /// The rewritten input that should be used instead of `original`.
        modified: String,
        /// Descriptions of the individual changes applied (may be empty).
        changes: Vec<String>,
    },
}
impl InputProcessorResult {
    /// Reports whether downstream work may continue with this outcome.
    ///
    /// `Pass` and `Modify` allow it; only `Block` stops it.
    pub fn should_proceed(&self) -> bool {
        match self {
            Self::Pass | Self::Modify { .. } => true,
            Self::Block { .. } => false,
        }
    }

    /// Reports whether this outcome rejected the input (the `Block` variant).
    pub fn is_blocked(&self) -> bool {
        !self.should_proceed()
    }

    /// Picks the text downstream consumers should see.
    ///
    /// For `Modify` this is the rewritten text; for `Pass` and `Block` it is `original`
    /// unchanged.
    pub fn effective_input<'a>(&'a self, original: &'a str) -> &'a str {
        if let Self::Modify { modified, .. } = self {
            modified.as_str()
        } else {
            original
        }
    }
}
/// A pluggable step that inspects an input string and decides to pass, block, or rewrite it.
///
/// The `Send + Sync` bound allows implementations to be stored as
/// `Arc<dyn InputProcessor>` and shared across threads.
#[async_trait]
pub trait InputProcessor: Send + Sync {
    /// Identifier for this processor.
    fn name(&self) -> &str;

    /// Ordering key: inside an [`InputProcessorPipeline`], processors with smaller values run
    /// earlier. Defaults to `100`.
    fn priority(&self) -> u32 {
        const DEFAULT_PRIORITY: u32 = 100;
        DEFAULT_PRIORITY
    }

    /// Examines `input` under the given policy context and reports the outcome.
    ///
    /// # Errors
    ///
    /// Any error returned here is propagated unchanged by
    /// [`InputProcessorPipeline::process`], which stops at the first failure.
    async fn process(
        &self,
        input: &str,
        ctx: &PolicyContext,
    ) -> anyhow::Result<InputProcessorResult>;
}
/// An ordered chain of [`InputProcessor`]s that are applied one after another to the same
/// input; see [`InputProcessorPipeline::process`].
pub struct InputProcessorPipeline {
// Kept sorted by ascending `priority()` (maintained by `add`); iteration order of this
// vector is the execution order used by `process`.
processors: Vec<Arc<dyn InputProcessor>>,
}
impl InputProcessorPipeline {
    /// Creates a pipeline with no processors; running it always yields
    /// [`InputProcessorResult::Pass`].
    pub fn new() -> Self {
        Self {
            processors: Vec::new(),
        }
    }

    /// Registers `processor` and returns the updated pipeline (consuming builder).
    ///
    /// Processors run in ascending [`InputProcessor::priority`] order; processors with equal
    /// priority run in the order they were added.
    // The method name collides with `std::ops::Add::add`, which clippy flags. This is a
    // consuming builder step, not arithmetic, so implementing `Add` would be misleading.
    #[allow(clippy::should_implement_trait)]
    #[must_use = "`add` consumes the pipeline and returns the updated one"]
    pub fn add(mut self, processor: Arc<dyn InputProcessor>) -> Self {
        let priority = processor.priority();
        // `processors` is always sorted by priority, so inserting after the last entry whose
        // priority is <= ours keeps it sorted and preserves insertion order among ties. This
        // is the same order a push followed by a stable sort would produce, without
        // re-sorting the whole vector on every call.
        let index = self
            .processors
            .partition_point(|existing| existing.priority() <= priority);
        self.processors.insert(index, processor);
        self
    }

    /// Runs every registered processor over `input` in priority order.
    ///
    /// Each processor receives the text produced by the most recent `Modify` (or `input`
    /// itself if nothing has modified it yet). Then:
    /// - the first `Block` short-circuits the pipeline and is returned as-is, discarding any
    ///   earlier modifications;
    /// - otherwise, if at least one processor returned `Modify`, the result is a single
    ///   `Modify` whose `original` is `input`, whose `modified` is the final text, and whose
    ///   `changes` concatenate every processor's changes in execution order;
    /// - otherwise the result is `Pass`.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by a processor; later processors are not run.
    pub async fn process(
        &self,
        input: &str,
        ctx: &PolicyContext,
    ) -> anyhow::Result<InputProcessorResult> {
        let mut current_input = input.to_string();
        let mut all_changes: Vec<String> = Vec::new();
        // Tracked separately from `all_changes`: a processor may return `Modify` with an
        // empty change list, and that must still surface as `Modify`.
        let mut was_modified = false;

        for processor in &self.processors {
            match processor.process(&current_input, ctx).await? {
                InputProcessorResult::Pass => {}
                blocked @ InputProcessorResult::Block { .. } => return Ok(blocked),
                InputProcessorResult::Modify {
                    modified, changes, ..
                } => {
                    was_modified = true;
                    all_changes.extend(changes);
                    current_input = modified;
                }
            }
        }

        Ok(if was_modified {
            InputProcessorResult::Modify {
                original: input.to_string(),
                modified: current_input,
                changes: all_changes,
            }
        } else {
            InputProcessorResult::Pass
        })
    }

    /// Returns `true` if no processors are registered.
    pub fn is_empty(&self) -> bool {
        self.processors.is_empty()
    }

    /// Returns the number of registered processors.
    pub fn len(&self) -> usize {
        self.processors.len()
    }
}
impl Default for InputProcessorPipeline {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::policy::PolicyAction;
    use std::collections::HashMap;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Always passes the input through untouched.
    struct MockPassProcessor;

    #[async_trait]
    impl InputProcessor for MockPassProcessor {
        fn name(&self) -> &str {
            "mock-pass"
        }

        async fn process(
            &self,
            _input: &str,
            _ctx: &PolicyContext,
        ) -> anyhow::Result<InputProcessorResult> {
            Ok(InputProcessorResult::Pass)
        }
    }

    /// Always blocks with the configured reason, at the default priority.
    struct MockBlockProcessor {
        reason: String,
    }

    #[async_trait]
    impl InputProcessor for MockBlockProcessor {
        fn name(&self) -> &str {
            "mock-block"
        }

        async fn process(
            &self,
            _input: &str,
            _ctx: &PolicyContext,
        ) -> anyhow::Result<InputProcessorResult> {
            Ok(InputProcessorResult::Block {
                reason: self.reason.clone(),
                processor: self.name().to_string(),
            })
        }
    }

    /// Appends `suffix` to the input, at the default priority.
    struct MockModifyProcessor {
        suffix: String,
    }

    #[async_trait]
    impl InputProcessor for MockModifyProcessor {
        fn name(&self) -> &str {
            "mock-modify"
        }

        async fn process(
            &self,
            input: &str,
            _ctx: &PolicyContext,
        ) -> anyhow::Result<InputProcessorResult> {
            Ok(InputProcessorResult::Modify {
                original: input.to_string(),
                modified: format!("{}{}", input, self.suffix),
                changes: vec![format!("Added suffix: {}", self.suffix)],
            })
        }
    }

    /// Appends `suffix` to the input, with an explicit priority.
    struct MockPrioritizedModifyProcessor {
        suffix: String,
        priority: u32,
    }

    #[async_trait]
    impl InputProcessor for MockPrioritizedModifyProcessor {
        fn name(&self) -> &str {
            "mock-prioritized-modify"
        }

        fn priority(&self) -> u32 {
            self.priority
        }

        async fn process(
            &self,
            input: &str,
            _ctx: &PolicyContext,
        ) -> anyhow::Result<InputProcessorResult> {
            Ok(InputProcessorResult::Modify {
                original: input.to_string(),
                modified: format!("{}{}", input, self.suffix),
                changes: vec![format!("Added suffix: {}", self.suffix)],
            })
        }
    }

    /// Passes the input through and records how many times it was invoked.
    struct MockCountingProcessor {
        calls: Arc<AtomicUsize>,
        priority: u32,
    }

    #[async_trait]
    impl InputProcessor for MockCountingProcessor {
        fn name(&self) -> &str {
            "mock-counting"
        }

        fn priority(&self) -> u32 {
            self.priority
        }

        async fn process(
            &self,
            _input: &str,
            _ctx: &PolicyContext,
        ) -> anyhow::Result<InputProcessorResult> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Ok(InputProcessorResult::Pass)
        }
    }

    fn test_context() -> PolicyContext {
        PolicyContext {
            tenant_id: Some("test-tenant".to_string()),
            user_id: Some("test-user".to_string()),
            action: PolicyAction::StartExecution { graph_id: None },
            metadata: HashMap::new(),
        }
    }

    #[test]
    fn test_input_processor_result_should_proceed() {
        assert!(InputProcessorResult::Pass.should_proceed());
        assert!(InputProcessorResult::Modify {
            original: "a".to_string(),
            modified: "b".to_string(),
            changes: vec![],
        }
        .should_proceed());
        assert!(!InputProcessorResult::Block {
            reason: "test".to_string(),
            processor: "test".to_string(),
        }
        .should_proceed());
    }

    #[test]
    fn test_input_processor_result_is_blocked() {
        assert!(!InputProcessorResult::Pass.is_blocked());
        assert!(InputProcessorResult::Block {
            reason: "test".to_string(),
            processor: "test".to_string(),
        }
        .is_blocked());
    }

    #[test]
    fn test_input_processor_result_effective_input() {
        let original = "hello";
        assert_eq!(
            InputProcessorResult::Pass.effective_input(original),
            "hello"
        );
        let block = InputProcessorResult::Block {
            reason: "blocked".to_string(),
            processor: "test".to_string(),
        };
        assert_eq!(block.effective_input(original), "hello");
        let modify = InputProcessorResult::Modify {
            original: "hello".to_string(),
            modified: "hello world".to_string(),
            changes: vec![],
        };
        assert_eq!(modify.effective_input(original), "hello world");
    }

    #[tokio::test]
    async fn test_pipeline_empty() {
        let pipeline = InputProcessorPipeline::new();
        assert!(pipeline.is_empty());
        assert_eq!(pipeline.len(), 0);
        let ctx = test_context();
        let result = pipeline.process("test input", &ctx).await.unwrap();
        assert!(matches!(result, InputProcessorResult::Pass));
    }

    #[tokio::test]
    async fn test_pipeline_pass_through() {
        let pipeline = InputProcessorPipeline::new().add(Arc::new(MockPassProcessor));
        let ctx = test_context();
        let result = pipeline.process("test input", &ctx).await.unwrap();
        assert!(matches!(result, InputProcessorResult::Pass));
    }

    #[tokio::test]
    async fn test_pipeline_block() {
        let pipeline = InputProcessorPipeline::new()
            .add(Arc::new(MockPassProcessor))
            .add(Arc::new(MockBlockProcessor {
                reason: "forbidden".to_string(),
            }));
        let ctx = test_context();
        let result = pipeline.process("test input", &ctx).await.unwrap();
        assert!(result.is_blocked());
        match result {
            InputProcessorResult::Block { reason, processor } => {
                assert_eq!(reason, "forbidden");
                assert_eq!(processor, "mock-block");
            }
            other => panic!("Expected Block result, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_pipeline_modify() {
        let pipeline = InputProcessorPipeline::new().add(Arc::new(MockModifyProcessor {
            suffix: " [sanitized]".to_string(),
        }));
        let ctx = test_context();
        let result = pipeline.process("test input", &ctx).await.unwrap();
        if let InputProcessorResult::Modify {
            original,
            modified,
            changes,
        } = result
        {
            assert_eq!(original, "test input");
            assert_eq!(modified, "test input [sanitized]");
            assert_eq!(changes.len(), 1);
        } else {
            panic!("Expected Modify result");
        }
    }

    #[tokio::test]
    async fn test_pipeline_chained_modify() {
        let pipeline = InputProcessorPipeline::new()
            .add(Arc::new(MockModifyProcessor {
                suffix: " [a]".to_string(),
            }))
            .add(Arc::new(MockModifyProcessor {
                suffix: " [b]".to_string(),
            }));
        let ctx = test_context();
        let result = pipeline.process("input", &ctx).await.unwrap();
        if let InputProcessorResult::Modify {
            modified, changes, ..
        } = result
        {
            // Equal priorities: processors run in insertion order.
            assert_eq!(modified, "input [a] [b]");
            assert_eq!(changes.len(), 2);
        } else {
            panic!("Expected Modify result");
        }
    }

    #[tokio::test]
    async fn test_pipeline_runs_in_priority_order() {
        // Added out of order on purpose: execution must follow ascending priority.
        let pipeline = InputProcessorPipeline::new()
            .add(Arc::new(MockPrioritizedModifyProcessor {
                suffix: " [late]".to_string(),
                priority: 200,
            }))
            .add(Arc::new(MockModifyProcessor {
                suffix: " [default]".to_string(),
            }))
            .add(Arc::new(MockPrioritizedModifyProcessor {
                suffix: " [early]".to_string(),
                priority: 10,
            }));
        assert_eq!(pipeline.len(), 3);
        let ctx = test_context();
        let result = pipeline.process("input", &ctx).await.unwrap();
        match result {
            InputProcessorResult::Modify {
                original,
                modified,
                changes,
            } => {
                assert_eq!(original, "input");
                assert_eq!(modified, "input [early] [default] [late]");
                assert_eq!(changes.len(), 3);
            }
            other => panic!("Expected Modify result, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_pipeline_block_discards_earlier_modifications() {
        let pipeline = InputProcessorPipeline::new()
            .add(Arc::new(MockBlockProcessor {
                reason: "forbidden".to_string(),
            }))
            .add(Arc::new(MockPrioritizedModifyProcessor {
                suffix: " [early]".to_string(),
                priority: 10,
            }));
        let ctx = test_context();
        let result = pipeline.process("input", &ctx).await.unwrap();
        match result {
            InputProcessorResult::Block { reason, processor } => {
                assert_eq!(reason, "forbidden");
                assert_eq!(processor, "mock-block");
            }
            other => panic!("Expected Block result, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_pipeline_block_short_circuits_later_processors() {
        let calls = Arc::new(AtomicUsize::new(0));
        // The counter is added first but has a higher priority value, so it is scheduled
        // after the block and must never run.
        let pipeline = InputProcessorPipeline::new()
            .add(Arc::new(MockCountingProcessor {
                calls: Arc::clone(&calls),
                priority: 200,
            }))
            .add(Arc::new(MockBlockProcessor {
                reason: "forbidden".to_string(),
            }));
        let ctx = test_context();
        let result = pipeline.process("input", &ctx).await.unwrap();
        assert!(result.is_blocked());
        assert_eq!(calls.load(Ordering::SeqCst), 0);
    }
}