enact_core/runner/
protected_runner.rs1use super::execution_runner::Runner;
10use crate::callable::Callable;
11use crate::graph::{CheckpointStore, CompiledGraph, NodeState};
12use crate::kernel::ExecutionId;
13use crate::policy::{
14 InputProcessor, InputProcessorPipeline, InputProcessorResult, PolicyAction, PolicyContext,
15};
16use crate::streaming::{EventEmitter, ProtectedEventEmitter};
17use std::sync::Arc;
18
/// Error produced when an input processor rejects the input of a protected run.
///
/// [`ProtectedRunner`] turns an [`InputProcessorResult::Block`] into this error
/// (converted into `anyhow::Error`), so callers can recover the details with
/// `err.downcast_ref::<InputBlockedError>()` and compare them directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InputBlockedError {
    /// Explanation supplied with the `Block` result.
    pub reason: String,
    /// Identifier of the processor reported in the `Block` result.
    pub processor: String,
}

impl std::fmt::Display for InputBlockedError {
    /// Formats as `Input blocked by <processor>: <reason>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Input blocked by {}: {}", self.processor, self.reason)
    }
}

impl std::error::Error for InputBlockedError {}
33
34pub struct ProtectedRunner<S: CheckpointStore> {
55 inner: Runner<S>,
56 input_pipeline: InputProcessorPipeline,
57 protected_emitter: Option<ProtectedEventEmitter>,
58}
59
60impl<S: CheckpointStore> ProtectedRunner<S> {
61 pub fn new(runner: Runner<S>) -> Self {
63 Self {
64 inner: runner,
65 input_pipeline: InputProcessorPipeline::new(),
66 protected_emitter: None,
67 }
68 }
69
70 pub fn with_input_processor(mut self, processor: Arc<dyn InputProcessor>) -> Self {
72 self.input_pipeline = self.input_pipeline.add(processor);
73 self
74 }
75
76 pub fn with_protected_emitter(mut self, emitter: ProtectedEventEmitter) -> Self {
78 self.protected_emitter = Some(emitter);
79 self
80 }
81
82 pub fn execution_id(&self) -> &ExecutionId {
84 self.inner.execution_id()
85 }
86
87 pub fn emitter(&self) -> &EventEmitter {
89 self.inner.emitter()
90 }
91
92 pub fn cancel(&self) {
94 self.inner.cancel();
95 }
96
97 pub fn is_cancelled(&self) -> bool {
99 self.inner.is_cancelled()
100 }
101
102 pub async fn pause(&self) -> anyhow::Result<()> {
104 self.inner.pause().await
105 }
106
107 pub fn resume(&self) {
109 self.inner.resume();
110 }
111
112 pub fn is_paused(&self) -> bool {
114 self.inner.is_paused()
115 }
116
117 fn create_policy_context(&self) -> PolicyContext {
119 PolicyContext {
120 tenant_id: None, user_id: None, action: PolicyAction::StartExecution { graph_id: None },
123 metadata: std::collections::HashMap::new(),
124 }
125 }
126
127 async fn process_input(&self, input: &str) -> anyhow::Result<String> {
129 if self.input_pipeline.is_empty() {
130 return Ok(input.to_string());
131 }
132
133 let ctx = self.create_policy_context();
134 let result = self.input_pipeline.process(input, &ctx).await?;
135
136 match result {
137 InputProcessorResult::Pass => Ok(input.to_string()),
138 InputProcessorResult::Block { reason, processor } => {
139 Err(InputBlockedError { reason, processor }.into())
140 }
141 InputProcessorResult::Modify { modified, .. } => Ok(modified),
142 }
143 }
144
145 pub async fn run_callable<A: Callable>(
147 &mut self,
148 callable: &A,
149 input: &str,
150 ) -> anyhow::Result<String> {
151 let processed_input = self.process_input(input).await?;
153
154 self.inner.run_callable(callable, &processed_input).await
156 }
157
158 pub async fn run_graph(
160 &mut self,
161 graph: &CompiledGraph,
162 input: &str,
163 ) -> anyhow::Result<NodeState> {
164 let processed_input = self.process_input(input).await?;
166
167 self.inner.run_graph(graph, &processed_input).await
169 }
170}
171
172pub type DefaultProtectedRunner = ProtectedRunner<crate::graph::InMemoryCheckpointStore>;
174
175impl DefaultProtectedRunner {
176 pub fn default_new() -> Self {
178 Self::new(crate::runner::DefaultRunner::default_new())
179 }
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runner::DefaultRunner;
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Callable that ignores its input and returns a fixed response.
    struct MockCallable {
        response: String,
    }

    impl MockCallable {
        fn new(response: &str) -> Self {
            Self {
                response: response.to_string(),
            }
        }
    }

    #[async_trait]
    impl Callable for MockCallable {
        fn name(&self) -> &str {
            "mock"
        }
        async fn run(&self, _input: &str) -> anyhow::Result<String> {
            Ok(self.response.clone())
        }
    }

    /// Callable that returns its input verbatim, exposing what the runner forwarded.
    struct EchoCallable;

    #[async_trait]
    impl Callable for EchoCallable {
        fn name(&self) -> &str {
            "echo"
        }
        async fn run(&self, input: &str) -> anyhow::Result<String> {
            Ok(input.to_string())
        }
    }

    /// Callable that counts how many times it was invoked.
    struct CountingCallable {
        calls: AtomicUsize,
    }

    #[async_trait]
    impl Callable for CountingCallable {
        fn name(&self) -> &str {
            "counting"
        }
        async fn run(&self, _input: &str) -> anyhow::Result<String> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Ok("ran".to_string())
        }
    }

    /// Processor that rejects every input.
    struct BlockingProcessor;

    #[async_trait]
    impl InputProcessor for BlockingProcessor {
        fn name(&self) -> &str {
            "blocker"
        }

        async fn process(
            &self,
            _input: &str,
            _ctx: &PolicyContext,
        ) -> anyhow::Result<InputProcessorResult> {
            Ok(InputProcessorResult::Block {
                reason: "Always blocks".to_string(),
                processor: "blocker".to_string(),
            })
        }
    }

    /// Processor that accepts every input unchanged.
    struct PassingProcessor;

    #[async_trait]
    impl InputProcessor for PassingProcessor {
        fn name(&self) -> &str {
            "passer"
        }

        async fn process(
            &self,
            _input: &str,
            _ctx: &PolicyContext,
        ) -> anyhow::Result<InputProcessorResult> {
            Ok(InputProcessorResult::Pass)
        }
    }

    #[tokio::test]
    async fn test_protected_runner_no_processors() {
        let runner = DefaultRunner::default_new();
        let mut protected = ProtectedRunner::new(runner);
        let callable = MockCallable::new("response");

        let result = protected.run_callable(&callable, "input").await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "response");
    }

    #[tokio::test]
    async fn test_protected_runner_blocked_input() {
        let runner = DefaultRunner::default_new();
        let mut protected =
            ProtectedRunner::new(runner).with_input_processor(Arc::new(BlockingProcessor));
        let callable = CountingCallable {
            calls: AtomicUsize::new(0),
        };

        let err = protected
            .run_callable(&callable, "input")
            .await
            .expect_err("blocking processor must reject the input");

        // Check the typed error rather than only the message, so an unrelated
        // failure whose text happens to mention "blocked" cannot pass this test.
        let blocked = err
            .downcast_ref::<InputBlockedError>()
            .expect("error should be an InputBlockedError");
        assert_eq!(blocked.processor, "blocker");
        assert!(blocked.reason.contains("Always blocks"));
        assert!(err.to_string().contains("blocked"));

        // A blocked input must never reach the callable.
        assert_eq!(callable.calls.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_protected_runner_pass_forwards_input_unchanged() {
        let runner = DefaultRunner::default_new();
        let mut protected =
            ProtectedRunner::new(runner).with_input_processor(Arc::new(PassingProcessor));

        let result = protected.run_callable(&EchoCallable, "hello world").await;
        assert_eq!(result.unwrap(), "hello world");
    }

    #[tokio::test]
    async fn test_protected_runner_execution_id() {
        let runner = DefaultRunner::default_new();
        let protected = ProtectedRunner::new(runner);

        assert!(!protected.execution_id().as_str().is_empty());
    }

    #[tokio::test]
    async fn test_protected_runner_cancel() {
        let runner = DefaultRunner::default_new();
        let protected = ProtectedRunner::new(runner);

        assert!(!protected.is_cancelled());
        protected.cancel();
        assert!(protected.is_cancelled());
    }

    #[tokio::test]
    async fn test_default_protected_runner() {
        let mut protected = DefaultProtectedRunner::default_new();
        let callable = MockCallable::new("hello");

        let result = protected.run_callable(&callable, "test").await;
        assert!(result.is_ok());
    }
}