use super::execution_runner::Runner;
use crate::callable::Callable;
use crate::graph::{CheckpointStore, CompiledGraph, NodeState};
use crate::kernel::ExecutionId;
use crate::policy::{
InputProcessor, InputProcessorPipeline, InputProcessorResult, PolicyAction, PolicyContext,
};
use crate::streaming::{EventEmitter, ProtectedEventEmitter};
use std::sync::Arc;
/// Error produced when a processor in the input pipeline blocks the input.
///
/// `ProtectedRunner` converts this into an `anyhow::Error`, so callers can recover
/// the details with `err.downcast_ref::<InputBlockedError>()`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InputBlockedError {
    /// Explanation reported by the processor that blocked the input.
    pub reason: String,
    /// Identifies the processor that blocked the input.
    pub processor: String,
}

impl std::fmt::Display for InputBlockedError {
    /// Renders as `Input blocked by <processor>: <reason>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Input blocked by {}: {}", self.processor, self.reason)
    }
}

impl std::error::Error for InputBlockedError {}
/// Runner wrapper that passes every input through an [`InputProcessorPipeline`]
/// before handing execution to the wrapped [`Runner`].
pub struct ProtectedRunner<S>
where
    S: CheckpointStore,
{
    /// Wrapped runner that performs the actual execution.
    inner: Runner<S>,
    /// Processors consulted for each input before it reaches `inner`.
    input_pipeline: InputProcessorPipeline,
    /// Optional emitter supplied through `ProtectedRunner::with_protected_emitter`.
    protected_emitter: Option<ProtectedEventEmitter>,
}
impl<S: CheckpointStore> ProtectedRunner<S> {
    /// Wraps `runner` with an empty input pipeline and no protected emitter.
    ///
    /// Until a processor is registered, inputs are forwarded to `runner` unchanged.
    #[must_use]
    pub fn new(runner: Runner<S>) -> Self {
        Self {
            inner: runner,
            input_pipeline: InputProcessorPipeline::new(),
            protected_emitter: None,
        }
    }

    /// Adds `processor` to the input pipeline and returns the updated runner.
    ///
    /// Ordering relative to earlier processors is whatever
    /// `InputProcessorPipeline::add` defines.
    #[must_use = "builder methods return the configured runner; dropping it discards the configuration"]
    pub fn with_input_processor(mut self, processor: Arc<dyn InputProcessor>) -> Self {
        self.input_pipeline = self.input_pipeline.add(processor);
        self
    }

    /// Stores `emitter` as the protected event emitter, replacing any previous one.
    #[must_use = "builder methods return the configured runner; dropping it discards the configuration"]
    pub fn with_protected_emitter(mut self, emitter: ProtectedEventEmitter) -> Self {
        self.protected_emitter = Some(emitter);
        self
    }

    /// Returns the emitter configured via [`Self::with_protected_emitter`], if any.
    ///
    /// NOTE(review): no method of this type emits through it yet (e.g. when input is
    /// blocked); this accessor at least lets callers retrieve what they configured.
    pub fn protected_emitter(&self) -> Option<&ProtectedEventEmitter> {
        self.protected_emitter.as_ref()
    }

    /// Identifier of the wrapped runner's execution.
    pub fn execution_id(&self) -> &ExecutionId {
        self.inner.execution_id()
    }

    /// Event emitter of the wrapped runner.
    pub fn emitter(&self) -> &EventEmitter {
        self.inner.emitter()
    }

    /// Delegates cancellation to the wrapped runner.
    pub fn cancel(&self) {
        self.inner.cancel();
    }

    /// Reports the wrapped runner's cancellation state.
    pub fn is_cancelled(&self) -> bool {
        self.inner.is_cancelled()
    }

    /// Delegates to the wrapped runner's `pause`.
    ///
    /// # Errors
    /// Propagates any error returned by the wrapped runner.
    pub async fn pause(&self) -> anyhow::Result<()> {
        self.inner.pause().await
    }

    /// Delegates to the wrapped runner's `resume`.
    pub fn resume(&self) {
        self.inner.resume();
    }

    /// Reports the wrapped runner's paused state.
    pub fn is_paused(&self) -> bool {
        self.inner.is_paused()
    }

    /// Builds the policy context handed to input processors.
    ///
    /// NOTE(review): tenant, user and graph id are always `None` here; populate them
    /// if processors need identity-aware decisions.
    fn create_policy_context(&self) -> PolicyContext {
        PolicyContext {
            tenant_id: None,
            user_id: None,
            action: PolicyAction::StartExecution { graph_id: None },
            metadata: std::collections::HashMap::new(),
        }
    }

    /// Runs `input` through the pipeline and returns the text to execute.
    ///
    /// Returns the original text on `Pass` and the replacement text on `Modify`.
    ///
    /// # Errors
    /// Returns an [`InputBlockedError`] (wrapped in `anyhow::Error`) on `Block`, and
    /// propagates any error raised by the pipeline itself.
    async fn process_input(&self, input: &str) -> anyhow::Result<String> {
        // Fast path: skip building a policy context when nothing would inspect it.
        if self.input_pipeline.is_empty() {
            return Ok(input.to_owned());
        }

        let ctx = self.create_policy_context();
        match self.input_pipeline.process(input, &ctx).await? {
            InputProcessorResult::Pass => Ok(input.to_owned()),
            InputProcessorResult::Block { reason, processor } => {
                Err(InputBlockedError { reason, processor }.into())
            }
            InputProcessorResult::Modify { modified, .. } => Ok(modified),
        }
    }

    /// Screens `input`, then runs `callable` on the wrapped runner with the
    /// (possibly modified) input.
    ///
    /// # Errors
    /// Fails with an [`InputBlockedError`] without invoking the callable if the input
    /// is blocked. Otherwise propagates pipeline errors and whatever the wrapped
    /// runner returns.
    pub async fn run_callable<A: Callable>(
        &mut self,
        callable: &A,
        input: &str,
    ) -> anyhow::Result<String> {
        let processed_input = self.process_input(input).await?;
        self.inner.run_callable(callable, &processed_input).await
    }

    /// Screens `input`, then runs `graph` on the wrapped runner with the
    /// (possibly modified) input.
    ///
    /// # Errors
    /// Fails with an [`InputBlockedError`] without starting the graph if the input is
    /// blocked. Otherwise propagates pipeline errors and whatever the wrapped runner
    /// returns.
    pub async fn run_graph(
        &mut self,
        graph: &CompiledGraph,
        input: &str,
    ) -> anyhow::Result<NodeState> {
        let processed_input = self.process_input(input).await?;
        self.inner.run_graph(graph, &processed_input).await
    }
}
pub type DefaultProtectedRunner = ProtectedRunner<crate::graph::InMemoryCheckpointStore>;
impl DefaultProtectedRunner {
pub fn default_new() -> Self {
Self::new(crate::runner::DefaultRunner::default_new())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runner::DefaultRunner;
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Callable that ignores its input, returns a fixed response, and counts calls to `run`.
    struct MockCallable {
        response: String,
        calls: AtomicUsize,
    }

    impl MockCallable {
        fn new(response: &str) -> Self {
            Self {
                response: response.to_string(),
                calls: AtomicUsize::new(0),
            }
        }

        /// Number of times `run` has been invoked.
        fn calls(&self) -> usize {
            self.calls.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl Callable for MockCallable {
        fn name(&self) -> &str {
            "mock"
        }

        async fn run(&self, _input: &str) -> anyhow::Result<String> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Ok(self.response.clone())
        }
    }

    /// Processor that blocks every input.
    struct BlockingProcessor;

    #[async_trait]
    impl InputProcessor for BlockingProcessor {
        fn name(&self) -> &str {
            "blocker"
        }

        async fn process(
            &self,
            _input: &str,
            _ctx: &PolicyContext,
        ) -> anyhow::Result<InputProcessorResult> {
            Ok(InputProcessorResult::Block {
                reason: "Always blocks".to_string(),
                processor: "blocker".to_string(),
            })
        }
    }

    /// Processor that lets every input through untouched.
    struct PassThroughProcessor;

    #[async_trait]
    impl InputProcessor for PassThroughProcessor {
        fn name(&self) -> &str {
            "pass"
        }

        async fn process(
            &self,
            _input: &str,
            _ctx: &PolicyContext,
        ) -> anyhow::Result<InputProcessorResult> {
            Ok(InputProcessorResult::Pass)
        }
    }

    #[tokio::test]
    async fn test_protected_runner_no_processors() {
        let runner = DefaultRunner::default_new();
        let mut protected = ProtectedRunner::new(runner);
        let callable = MockCallable::new("response");
        let result = protected.run_callable(&callable, "input").await;
        assert_eq!(result.expect("unprotected run should succeed"), "response");
    }

    #[tokio::test]
    async fn test_protected_runner_passing_processor() {
        let runner = DefaultRunner::default_new();
        let mut protected =
            ProtectedRunner::new(runner).with_input_processor(Arc::new(PassThroughProcessor));
        let callable = MockCallable::new("response");
        let result = protected.run_callable(&callable, "input").await;
        assert_eq!(result.expect("passing input should run"), "response");
    }

    #[tokio::test]
    async fn test_protected_runner_blocked_input() {
        let runner = DefaultRunner::default_new();
        let mut protected =
            ProtectedRunner::new(runner).with_input_processor(Arc::new(BlockingProcessor));
        let callable = MockCallable::new("response");
        let err = protected
            .run_callable(&callable, "input")
            .await
            .unwrap_err();
        assert!(err.to_string().contains("blocked"));
        let blocked = err
            .downcast_ref::<InputBlockedError>()
            .expect("blocked input should surface as InputBlockedError");
        assert_eq!(blocked.processor, "blocker");
        assert_eq!(callable.calls(), 0, "callable must not run when input is blocked");
    }

    #[tokio::test]
    async fn test_protected_runner_execution_id() {
        let runner = DefaultRunner::default_new();
        let protected = ProtectedRunner::new(runner);
        assert!(!protected.execution_id().as_str().is_empty());
    }

    #[tokio::test]
    async fn test_protected_runner_cancel() {
        let runner = DefaultRunner::default_new();
        let protected = ProtectedRunner::new(runner);
        assert!(!protected.is_cancelled());
        protected.cancel();
        assert!(protected.is_cancelled());
    }

    #[tokio::test]
    async fn test_default_protected_runner() {
        let mut protected = DefaultProtectedRunner::default_new();
        let callable = MockCallable::new("hello");
        let result = protected.run_callable(&callable, "test").await;
        assert_eq!(result.expect("default runner should succeed"), "hello");
    }
}