Skip to main content

enact_core/runner/
protected_runner.rs

//! Protected Runner
//!
//! Wraps `Runner` with an input processor pipeline that validates or transforms
//! input before execution begins.
//!
//! @see docs/TECHNICAL/17-GUARDRAILS-PROTECTION.md
//! @see docs/TECHNICAL/25-STREAM-PROCESSORS.md
8
9use super::execution_runner::Runner;
10use crate::callable::Callable;
11use crate::graph::{CheckpointStore, CompiledGraph, NodeState};
12use crate::kernel::ExecutionId;
13use crate::policy::{
14    InputProcessor, InputProcessorPipeline, InputProcessorResult, PolicyAction, PolicyContext,
15};
16use crate::streaming::{EventEmitter, ProtectedEventEmitter};
17use std::sync::Arc;
18
/// Error returned when input is blocked by processors.
///
/// `ProtectedRunner` reports this inside an `anyhow::Error`; callers can recover
/// the concrete value with `err.downcast_ref::<InputBlockedError>()`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InputBlockedError {
    /// Explanation supplied by the processor that blocked the input.
    pub reason: String,
    /// Identifier of the processor that reported the block.
    pub processor: String,
}

impl std::fmt::Display for InputBlockedError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Input blocked by {}: {}", self.processor, self.reason)
    }
}

impl std::error::Error for InputBlockedError {}
33
34/// Protected Runner
35///
36/// Wraps a Runner with input processor pipeline and optional protected emitter.
37/// Input is validated/transformed BEFORE execution begins.
38///
39/// ## Usage
40///
41/// ```ignore
42/// use enact_core::runner::{ProtectedRunner, DefaultRunner};
43/// use enact_core::policy::{PiiInputProcessor, PiiInputMode};
44///
45/// let runner = DefaultRunner::default_new();
46/// let protected = ProtectedRunner::new(runner)
47///     .with_input_processor(Arc::new(
48///         PiiInputProcessor::new().with_mode(PiiInputMode::BlockDirect)
49///     ));
50///
51/// // Input will be validated before callable runs
52/// let result = protected.run_callable(&my_callable, "user input").await;
53/// ```
54pub struct ProtectedRunner<S: CheckpointStore> {
55    inner: Runner<S>,
56    input_pipeline: InputProcessorPipeline,
57    protected_emitter: Option<ProtectedEventEmitter>,
58}
59
60impl<S: CheckpointStore> ProtectedRunner<S> {
61    /// Create a new protected runner wrapping an existing runner
62    pub fn new(runner: Runner<S>) -> Self {
63        Self {
64            inner: runner,
65            input_pipeline: InputProcessorPipeline::new(),
66            protected_emitter: None,
67        }
68    }
69
70    /// Add an input processor to the pipeline
71    pub fn with_input_processor(mut self, processor: Arc<dyn InputProcessor>) -> Self {
72        self.input_pipeline = self.input_pipeline.add(processor);
73        self
74    }
75
76    /// Set a protected emitter for output processing
77    pub fn with_protected_emitter(mut self, emitter: ProtectedEventEmitter) -> Self {
78        self.protected_emitter = Some(emitter);
79        self
80    }
81
82    /// Get the execution ID
83    pub fn execution_id(&self) -> &ExecutionId {
84        self.inner.execution_id()
85    }
86
87    /// Get the event emitter (from inner runner)
88    pub fn emitter(&self) -> &EventEmitter {
89        self.inner.emitter()
90    }
91
92    /// Cancel the run
93    pub fn cancel(&self) {
94        self.inner.cancel();
95    }
96
97    /// Check if cancelled
98    pub fn is_cancelled(&self) -> bool {
99        self.inner.is_cancelled()
100    }
101
102    /// Pause the run
103    pub async fn pause(&self) -> anyhow::Result<()> {
104        self.inner.pause().await
105    }
106
107    /// Resume the run
108    pub fn resume(&self) {
109        self.inner.resume();
110    }
111
112    /// Check if paused
113    pub fn is_paused(&self) -> bool {
114        self.inner.is_paused()
115    }
116
117    /// Create policy context for input processing
118    fn create_policy_context(&self) -> PolicyContext {
119        PolicyContext {
120            tenant_id: None, // Could be set from runner context
121            user_id: None,   // Could be set from runner context
122            action: PolicyAction::StartExecution { graph_id: None },
123            metadata: std::collections::HashMap::new(),
124        }
125    }
126
127    /// Process input through the pipeline
128    async fn process_input(&self, input: &str) -> anyhow::Result<String> {
129        if self.input_pipeline.is_empty() {
130            return Ok(input.to_string());
131        }
132
133        let ctx = self.create_policy_context();
134        let result = self.input_pipeline.process(input, &ctx).await?;
135
136        match result {
137            InputProcessorResult::Pass => Ok(input.to_string()),
138            InputProcessorResult::Block { reason, processor } => {
139                Err(InputBlockedError { reason, processor }.into())
140            }
141            InputProcessorResult::Modify { modified, .. } => Ok(modified),
142        }
143    }
144
145    /// Run a callable with input validation
146    pub async fn run_callable<A: Callable>(
147        &mut self,
148        callable: &A,
149        input: &str,
150    ) -> anyhow::Result<String> {
151        // Process input through pipeline
152        let processed_input = self.process_input(input).await?;
153
154        // Run with processed input
155        self.inner.run_callable(callable, &processed_input).await
156    }
157
158    /// Run a graph with input validation
159    pub async fn run_graph(
160        &mut self,
161        graph: &CompiledGraph,
162        input: &str,
163    ) -> anyhow::Result<NodeState> {
164        // Process input through pipeline
165        let processed_input = self.process_input(input).await?;
166
167        // Run with processed input
168        self.inner.run_graph(graph, &processed_input).await
169    }
170}
171
172/// Protected Runner with in-memory checkpoint store (default)
173pub type DefaultProtectedRunner = ProtectedRunner<crate::graph::InMemoryCheckpointStore>;
174
175impl DefaultProtectedRunner {
176    /// Create a new protected runner with in-memory checkpoint store
177    pub fn default_new() -> Self {
178        Self::new(crate::runner::DefaultRunner::default_new())
179    }
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runner::DefaultRunner;
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Callable that ignores its input, returns a fixed response, and counts
    /// how many times it was invoked.
    struct MockCallable {
        response: String,
        calls: AtomicUsize,
    }

    impl MockCallable {
        fn new(response: &str) -> Self {
            Self {
                response: response.to_string(),
                calls: AtomicUsize::new(0),
            }
        }

        fn call_count(&self) -> usize {
            self.calls.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl Callable for MockCallable {
        fn name(&self) -> &str {
            "mock"
        }
        async fn run(&self, _input: &str) -> anyhow::Result<String> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Ok(self.response.clone())
        }
    }

    /// Input processor that rejects every input.
    struct BlockingProcessor;

    #[async_trait]
    impl InputProcessor for BlockingProcessor {
        fn name(&self) -> &str {
            "blocker"
        }

        async fn process(
            &self,
            _input: &str,
            _ctx: &PolicyContext,
        ) -> anyhow::Result<InputProcessorResult> {
            Ok(InputProcessorResult::Block {
                reason: "Always blocks".to_string(),
                processor: "blocker".to_string(),
            })
        }
    }

    #[tokio::test]
    async fn test_protected_runner_no_processors() {
        let runner = DefaultRunner::default_new();
        let mut protected = ProtectedRunner::new(runner);
        let callable = MockCallable::new("response");

        let result = protected.run_callable(&callable, "input").await;
        assert_eq!(result.expect("unprotected run should succeed"), "response");
        assert!(callable.call_count() >= 1);
    }

    #[tokio::test]
    async fn test_protected_runner_blocked_input() {
        let runner = DefaultRunner::default_new();
        let mut protected =
            ProtectedRunner::new(runner).with_input_processor(Arc::new(BlockingProcessor));
        let callable = MockCallable::new("response");

        let err = protected
            .run_callable(&callable, "input")
            .await
            .expect_err("blocking processor must reject the input");
        assert!(err.to_string().contains("blocked"));

        // The concrete error type must be recoverable from the anyhow wrapper.
        let blocked = err
            .downcast_ref::<InputBlockedError>()
            .expect("error should be an InputBlockedError");
        assert_eq!(blocked.processor, "blocker");
        assert_eq!(blocked.reason, "Always blocks");

        // Blocking happens before execution: the callable must never run.
        assert_eq!(callable.call_count(), 0);
    }

    #[tokio::test]
    async fn test_protected_runner_execution_id() {
        let runner = DefaultRunner::default_new();
        let protected = ProtectedRunner::new(runner);

        // Should have a valid execution ID
        assert!(!protected.execution_id().as_str().is_empty());
    }

    #[tokio::test]
    async fn test_protected_runner_cancel() {
        let runner = DefaultRunner::default_new();
        let protected = ProtectedRunner::new(runner);

        assert!(!protected.is_cancelled());
        protected.cancel();
        assert!(protected.is_cancelled());
    }

    #[tokio::test]
    async fn test_default_protected_runner() {
        let mut protected = DefaultProtectedRunner::default_new();
        let callable = MockCallable::new("hello");

        let result = protected.run_callable(&callable, "test").await;
        assert_eq!(result.expect("default runner should succeed"), "hello");
    }
}