1use super::PolicyContext;
32use async_trait::async_trait;
33use std::sync::Arc;
34
/// Outcome of running an input through a single [`InputProcessor`] or through
/// a whole [`InputProcessorPipeline`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum InputProcessorResult {
    /// The input is accepted as-is.
    Pass,

    /// The input is rejected and must not be processed further.
    Block {
        /// Explanation of why the input was rejected.
        reason: String,
        /// Identifies the processor that issued the block.
        processor: String,
    },

    /// The input is accepted, but in a rewritten form.
    Modify {
        /// The text as it was before rewriting.
        original: String,
        /// The text to use from here on.
        modified: String,
        /// Notes describing what was changed.
        changes: Vec<String>,
    },
}

impl InputProcessorResult {
    /// Returns `true` for [`Pass`](Self::Pass) and [`Modify`](Self::Modify),
    /// i.e. whenever the input was not blocked.
    #[must_use]
    pub fn should_proceed(&self) -> bool {
        // Spelled out per variant so that adding a variant forces a decision here.
        match self {
            Self::Pass | Self::Modify { .. } => true,
            Self::Block { .. } => false,
        }
    }

    /// Returns `true` only for [`Block`](Self::Block).
    #[must_use]
    pub fn is_blocked(&self) -> bool {
        matches!(self, Self::Block { .. })
    }

    /// Picks the text that should be used after this result.
    ///
    /// For [`Modify`](Self::Modify) this is the stored `modified` text; for
    /// every other variant (including [`Block`](Self::Block)) the caller's
    /// `original` argument is handed back unchanged. The `original` field
    /// stored inside `Modify` is not consulted.
    #[must_use]
    pub fn effective_input<'a>(&'a self, original: &'a str) -> &'a str {
        match self {
            Self::Modify { modified, .. } => modified.as_str(),
            Self::Pass | Self::Block { .. } => original,
        }
    }
}
76
77#[async_trait]
82pub trait InputProcessor: Send + Sync {
83 fn name(&self) -> &str;
85
86 fn priority(&self) -> u32 {
88 100 }
90
91 async fn process(
102 &self,
103 input: &str,
104 ctx: &PolicyContext,
105 ) -> anyhow::Result<InputProcessorResult>;
106}
107
108pub struct InputProcessorPipeline {
110 processors: Vec<Arc<dyn InputProcessor>>,
111}
112
113impl InputProcessorPipeline {
114 pub fn new() -> Self {
116 Self { processors: vec![] }
117 }
118
119 #[allow(clippy::should_implement_trait)]
121 pub fn add(mut self, processor: Arc<dyn InputProcessor>) -> Self {
122 self.processors.push(processor);
123 self.processors.sort_by_key(|p| p.priority());
125 self
126 }
127
128 pub async fn process(
133 &self,
134 input: &str,
135 ctx: &PolicyContext,
136 ) -> anyhow::Result<InputProcessorResult> {
137 let mut current_input = input.to_string();
138 let mut all_changes: Vec<String> = vec![];
139 let mut was_modified = false;
140
141 for processor in &self.processors {
142 let result = processor.process(¤t_input, ctx).await?;
143
144 match result {
145 InputProcessorResult::Pass => {
146 continue;
148 }
149 InputProcessorResult::Block { .. } => {
150 return Ok(result);
152 }
153 InputProcessorResult::Modify {
154 modified, changes, ..
155 } => {
156 was_modified = true;
158 all_changes.extend(changes);
159 current_input = modified;
160 }
161 }
162 }
163
164 if was_modified {
166 Ok(InputProcessorResult::Modify {
167 original: input.to_string(),
168 modified: current_input,
169 changes: all_changes,
170 })
171 } else {
172 Ok(InputProcessorResult::Pass)
173 }
174 }
175
176 pub fn is_empty(&self) -> bool {
178 self.processors.is_empty()
179 }
180
181 pub fn len(&self) -> usize {
183 self.processors.len()
184 }
185}
186
187impl Default for InputProcessorPipeline {
188 fn default() -> Self {
189 Self::new()
190 }
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use crate::policy::PolicyAction;
    use std::collections::HashMap;

    /// Processor that accepts every input untouched.
    struct MockPassProcessor;

    #[async_trait]
    impl InputProcessor for MockPassProcessor {
        fn name(&self) -> &str {
            "mock-pass"
        }

        async fn process(
            &self,
            _input: &str,
            _ctx: &PolicyContext,
        ) -> anyhow::Result<InputProcessorResult> {
            Ok(InputProcessorResult::Pass)
        }
    }

    /// Processor that rejects every input with a fixed reason.
    struct MockBlockProcessor {
        reason: String,
    }

    #[async_trait]
    impl InputProcessor for MockBlockProcessor {
        fn name(&self) -> &str {
            "mock-block"
        }

        async fn process(
            &self,
            _input: &str,
            _ctx: &PolicyContext,
        ) -> anyhow::Result<InputProcessorResult> {
            Ok(InputProcessorResult::Block {
                reason: self.reason.clone(),
                processor: self.name().to_string(),
            })
        }
    }

    /// Processor that appends `suffix` to whatever it receives.
    struct MockModifyProcessor {
        suffix: String,
    }

    #[async_trait]
    impl InputProcessor for MockModifyProcessor {
        fn name(&self) -> &str {
            "mock-modify"
        }

        async fn process(
            &self,
            input: &str,
            _ctx: &PolicyContext,
        ) -> anyhow::Result<InputProcessorResult> {
            Ok(InputProcessorResult::Modify {
                original: input.to_string(),
                modified: format!("{}{}", input, self.suffix),
                changes: vec![format!("Added suffix: {}", self.suffix)],
            })
        }
    }

    /// Like `MockModifyProcessor`, but with a caller-chosen priority so that
    /// pipeline ordering can be observed through the resulting text.
    struct MockPriorityProcessor {
        suffix: String,
        priority: u32,
    }

    #[async_trait]
    impl InputProcessor for MockPriorityProcessor {
        fn name(&self) -> &str {
            "mock-priority"
        }

        fn priority(&self) -> u32 {
            self.priority
        }

        async fn process(
            &self,
            input: &str,
            _ctx: &PolicyContext,
        ) -> anyhow::Result<InputProcessorResult> {
            Ok(InputProcessorResult::Modify {
                original: input.to_string(),
                modified: format!("{}{}", input, self.suffix),
                changes: vec![format!("Added suffix: {}", self.suffix)],
            })
        }
    }

    /// Builds the policy context shared by all pipeline tests.
    fn test_context() -> PolicyContext {
        PolicyContext {
            tenant_id: Some("test-tenant".to_string()),
            user_id: Some("test-user".to_string()),
            action: PolicyAction::StartExecution { graph_id: None },
            metadata: HashMap::new(),
        }
    }

    /// Shorthand for a boxed-up suffix-appending processor.
    fn suffixer(suffix: &str) -> Arc<dyn InputProcessor> {
        Arc::new(MockModifyProcessor {
            suffix: suffix.to_string(),
        })
    }

    /// Shorthand for a suffix-appending processor with an explicit priority.
    fn prioritized(suffix: &str, priority: u32) -> Arc<dyn InputProcessor> {
        Arc::new(MockPriorityProcessor {
            suffix: suffix.to_string(),
            priority,
        })
    }

    /// Unpacks a `Modify` result into `(original, modified, changes)` and
    /// fails the test for any other variant.
    fn expect_modify(result: InputProcessorResult) -> (String, String, Vec<String>) {
        match result {
            InputProcessorResult::Modify {
                original,
                modified,
                changes,
            } => (original, modified, changes),
            other => panic!("Expected Modify result, got {:?}", other),
        }
    }

    #[test]
    fn test_input_processor_result_should_proceed() {
        let modify = InputProcessorResult::Modify {
            original: "a".to_string(),
            modified: "b".to_string(),
            changes: vec![],
        };
        let block = InputProcessorResult::Block {
            reason: "test".to_string(),
            processor: "test".to_string(),
        };

        assert!(InputProcessorResult::Pass.should_proceed());
        assert!(modify.should_proceed());
        assert!(!block.should_proceed());
    }

    #[test]
    fn test_input_processor_result_is_blocked() {
        let block = InputProcessorResult::Block {
            reason: "test".to_string(),
            processor: "test".to_string(),
        };

        assert!(!InputProcessorResult::Pass.is_blocked());
        assert!(block.is_blocked());
    }

    #[test]
    fn test_input_processor_result_effective_input() {
        let original = "hello";

        // Pass and Block both fall back to the caller-supplied text.
        let pass = InputProcessorResult::Pass;
        assert_eq!(pass.effective_input(original), "hello");

        let block = InputProcessorResult::Block {
            reason: "blocked".to_string(),
            processor: "test".to_string(),
        };
        assert_eq!(block.effective_input(original), "hello");

        // Modify substitutes its rewritten text.
        let modify = InputProcessorResult::Modify {
            original: "hello".to_string(),
            modified: "hello world".to_string(),
            changes: vec![],
        };
        assert_eq!(modify.effective_input(original), "hello world");
    }

    #[tokio::test]
    async fn test_pipeline_empty() {
        let pipeline = InputProcessorPipeline::new();
        assert!(pipeline.is_empty());
        assert_eq!(pipeline.len(), 0);

        let ctx = test_context();
        let result = pipeline.process("test input", &ctx).await.unwrap();
        assert!(matches!(result, InputProcessorResult::Pass));
    }

    #[tokio::test]
    async fn test_pipeline_pass_through() {
        let pipeline = InputProcessorPipeline::new().add(Arc::new(MockPassProcessor));

        let ctx = test_context();
        let result = pipeline.process("test input", &ctx).await.unwrap();
        assert!(matches!(result, InputProcessorResult::Pass));
    }

    #[tokio::test]
    async fn test_pipeline_block() {
        let pipeline = InputProcessorPipeline::new()
            .add(Arc::new(MockPassProcessor))
            .add(Arc::new(MockBlockProcessor {
                reason: "forbidden".to_string(),
            }));

        let ctx = test_context();
        let result = pipeline.process("test input", &ctx).await.unwrap();

        assert!(result.is_blocked());
        match result {
            InputProcessorResult::Block { reason, processor } => {
                assert_eq!(reason, "forbidden");
                assert_eq!(processor, "mock-block");
            }
            other => panic!("Expected Block result, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_pipeline_modify() {
        let pipeline = InputProcessorPipeline::new().add(suffixer(" [sanitized]"));

        let ctx = test_context();
        let result = pipeline.process("test input", &ctx).await.unwrap();

        let (original, modified, changes) = expect_modify(result);
        assert_eq!(original, "test input");
        assert_eq!(modified, "test input [sanitized]");
        assert_eq!(changes.len(), 1);
    }

    #[tokio::test]
    async fn test_pipeline_chained_modify() {
        let pipeline = InputProcessorPipeline::new()
            .add(suffixer(" [a]"))
            .add(suffixer(" [b]"));

        let ctx = test_context();
        let result = pipeline.process("input", &ctx).await.unwrap();

        // Equal priorities keep insertion order, and the second processor
        // sees the first one's output.
        let (_, modified, changes) = expect_modify(result);
        assert_eq!(modified, "input [a] [b]");
        assert_eq!(changes.len(), 2);
    }

    #[tokio::test]
    async fn test_pipeline_orders_by_priority() {
        // Registered in the "wrong" order on purpose: the lower priority
        // value must still run first.
        let pipeline = InputProcessorPipeline::new()
            .add(prioritized(" [late]", 200))
            .add(prioritized(" [early]", 50));
        assert_eq!(pipeline.len(), 2);

        let ctx = test_context();
        let result = pipeline.process("input", &ctx).await.unwrap();

        let (original, modified, changes) = expect_modify(result);
        assert_eq!(original, "input");
        assert_eq!(modified, "input [early] [late]");
        assert_eq!(changes.len(), 2);
    }

    #[tokio::test]
    async fn test_pipeline_block_short_circuits() {
        // The blocker uses the default priority (100), so it runs after the
        // priority-50 modifier and before the priority-200 one.
        let pipeline = InputProcessorPipeline::new()
            .add(prioritized(" [after]", 200))
            .add(Arc::new(MockBlockProcessor {
                reason: "forbidden".to_string(),
            }))
            .add(prioritized(" [before]", 50));

        let ctx = test_context();
        let result = pipeline.process("input", &ctx).await.unwrap();

        // The block wins outright: earlier rewrites are not reported.
        match result {
            InputProcessorResult::Block { reason, processor } => {
                assert_eq!(reason, "forbidden");
                assert_eq!(processor, "mock-block");
            }
            other => panic!("Expected Block result, got {:?}", other),
        }
    }
}