enact-core 0.0.2

Core agent runtime for Enact - Graph-Native AI agents
Documentation
//! Protected Runner
//!
//! Wraps Runner with input processor pipeline to validate/transform input
//! before execution begins.
//!
//! @see docs/TECHNICAL/17-GUARDRAILS-PROTECTION.md
//! @see docs/TECHNICAL/25-STREAM-PROCESSORS.md

use super::execution_runner::Runner;
use crate::callable::Callable;
use crate::graph::{CheckpointStore, CompiledGraph, NodeState};
use crate::kernel::ExecutionId;
use crate::policy::{
    InputProcessor, InputProcessorPipeline, InputProcessorResult, PolicyAction, PolicyContext,
};
use crate::streaming::{EventEmitter, ProtectedEventEmitter};
use std::sync::Arc;

/// Error returned when input is blocked by an input processor.
///
/// Produced by [`ProtectedRunner`] when its input pipeline yields
/// `InputProcessorResult::Block`. It reaches callers wrapped in an
/// `anyhow::Error`, so recover the typed value with
/// `err.downcast_ref::<InputBlockedError>()`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InputBlockedError {
    /// Explanation supplied by the blocking processor.
    pub reason: String,
    /// Identifier of the processor that blocked the input, as reported in
    /// the `Block` result.
    pub processor: String,
}

impl std::fmt::Display for InputBlockedError {
    /// Formats as `Input blocked by <processor>: <reason>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Input blocked by {}: {}", self.processor, self.reason)
    }
}

// No underlying cause is carried, so the default `source()` (None) is correct.
impl std::error::Error for InputBlockedError {}

/// Protected Runner
///
/// Wraps a [`Runner`] with an input processor pipeline and an optional
/// protected emitter. Input is validated/transformed BEFORE execution begins:
/// if a processor blocks the input, the wrapped callable or graph is never
/// run and an error wrapping [`InputBlockedError`] is returned instead.
///
/// ## Usage
///
/// ```ignore
/// use std::sync::Arc;
/// use enact_core::runner::{ProtectedRunner, DefaultRunner};
/// use enact_core::policy::{PiiInputProcessor, PiiInputMode};
///
/// let runner = DefaultRunner::default_new();
/// let mut protected = ProtectedRunner::new(runner)
///     .with_input_processor(Arc::new(
///         PiiInputProcessor::new().with_mode(PiiInputMode::BlockDirect)
///     ));
///
/// // Input will be validated before callable runs
/// let result = protected.run_callable(&my_callable, "user input").await;
/// ```
pub struct ProtectedRunner<S: CheckpointStore> {
    /// The wrapped runner that performs the actual execution.
    inner: Runner<S>,
    /// Processors that inspect (and may block or rewrite) input before execution.
    input_pipeline: InputProcessorPipeline,
    /// Emitter configured via `with_protected_emitter`, if any.
    protected_emitter: Option<ProtectedEventEmitter>,
}

impl<S: CheckpointStore> ProtectedRunner<S> {
    /// Create a new protected runner wrapping an existing runner.
    ///
    /// The input pipeline starts empty, so input is passed through unchanged
    /// until processors are added, and no protected emitter is set.
    pub fn new(runner: Runner<S>) -> Self {
        Self {
            inner: runner,
            input_pipeline: InputProcessorPipeline::new(),
            protected_emitter: None,
        }
    }

    /// Add an input processor to the pipeline.
    ///
    /// Builder-style: consumes `self` and returns the updated runner, so the
    /// return value must be used.
    #[must_use]
    pub fn with_input_processor(mut self, processor: Arc<dyn InputProcessor>) -> Self {
        self.input_pipeline = self.input_pipeline.add(processor);
        self
    }

    /// Set a protected emitter for output processing.
    ///
    /// The emitter is stored on the runner and exposed through
    /// [`ProtectedRunner::protected_emitter`]; `run_callable` and `run_graph`
    /// in this type do not consult it themselves.
    #[must_use]
    pub fn with_protected_emitter(mut self, emitter: ProtectedEventEmitter) -> Self {
        self.protected_emitter = Some(emitter);
        self
    }

    /// Get the protected emitter set via `with_protected_emitter`, if any.
    pub fn protected_emitter(&self) -> Option<&ProtectedEventEmitter> {
        self.protected_emitter.as_ref()
    }

    /// Get the execution ID (delegates to the inner runner).
    pub fn execution_id(&self) -> &ExecutionId {
        self.inner.execution_id()
    }

    /// Get the event emitter (from inner runner).
    pub fn emitter(&self) -> &EventEmitter {
        self.inner.emitter()
    }

    /// Cancel the run (delegates to the inner runner).
    pub fn cancel(&self) {
        self.inner.cancel();
    }

    /// Check if cancelled (delegates to the inner runner).
    pub fn is_cancelled(&self) -> bool {
        self.inner.is_cancelled()
    }

    /// Pause the run (delegates to the inner runner).
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the inner runner's `pause`.
    pub async fn pause(&self) -> anyhow::Result<()> {
        self.inner.pause().await
    }

    /// Resume the run (delegates to the inner runner).
    pub fn resume(&self) {
        self.inner.resume();
    }

    /// Check if paused (delegates to the inner runner).
    pub fn is_paused(&self) -> bool {
        self.inner.is_paused()
    }

    /// Build the policy context used for input processing.
    ///
    /// Tenant and user are not populated yet, and the action is always
    /// `StartExecution` with no graph id — including for `run_graph`.
    fn create_policy_context(&self) -> PolicyContext {
        PolicyContext {
            tenant_id: None, // Could be set from runner context
            user_id: None,   // Could be set from runner context
            action: PolicyAction::StartExecution { graph_id: None },
            metadata: std::collections::HashMap::new(),
        }
    }

    /// Run `input` through the processor pipeline.
    ///
    /// Returns the input unchanged when the pipeline is empty or reports
    /// `Pass`, or the rewritten text when it reports `Modify`.
    ///
    /// # Errors
    ///
    /// Returns an [`InputBlockedError`] (wrapped in `anyhow::Error`) when the
    /// pipeline reports `Block`, and propagates any error from the pipeline.
    async fn process_input(&self, input: &str) -> anyhow::Result<String> {
        // Skip building a policy context when there is nothing to run.
        if self.input_pipeline.is_empty() {
            return Ok(input.to_string());
        }

        let ctx = self.create_policy_context();
        let result = self.input_pipeline.process(input, &ctx).await?;

        match result {
            InputProcessorResult::Pass => Ok(input.to_string()),
            InputProcessorResult::Block { reason, processor } => {
                Err(InputBlockedError { reason, processor }.into())
            }
            InputProcessorResult::Modify { modified, .. } => Ok(modified),
        }
    }

    /// Run a callable with input validation.
    ///
    /// The input is processed first; the inner runner is only invoked (with
    /// the processed input) if the pipeline does not block it.
    ///
    /// # Errors
    ///
    /// Fails with an [`InputBlockedError`] (recoverable via
    /// `downcast_ref`) if the input is blocked, or with any error from the
    /// pipeline or the inner runner.
    pub async fn run_callable<A: Callable>(
        &mut self,
        callable: &A,
        input: &str,
    ) -> anyhow::Result<String> {
        let processed_input = self.process_input(input).await?;
        self.inner.run_callable(callable, &processed_input).await
    }

    /// Run a graph with input validation.
    ///
    /// The input is processed first; the inner runner is only invoked (with
    /// the processed input) if the pipeline does not block it.
    ///
    /// # Errors
    ///
    /// Fails with an [`InputBlockedError`] (recoverable via
    /// `downcast_ref`) if the input is blocked, or with any error from the
    /// pipeline or the inner runner.
    pub async fn run_graph(
        &mut self,
        graph: &CompiledGraph,
        input: &str,
    ) -> anyhow::Result<NodeState> {
        let processed_input = self.process_input(input).await?;
        self.inner.run_graph(graph, &processed_input).await
    }
}

/// Protected Runner with in-memory checkpoint store (default)
pub type DefaultProtectedRunner = ProtectedRunner<crate::graph::InMemoryCheckpointStore>;

impl DefaultProtectedRunner {
    /// Create a new protected runner around `DefaultRunner::default_new()`
    /// (in-memory checkpoint store), with an empty input pipeline and no
    /// protected emitter.
    pub fn default_new() -> Self {
        Self::new(crate::runner::DefaultRunner::default_new())
    }
}

impl Default for DefaultProtectedRunner {
    /// Equivalent to [`DefaultProtectedRunner::default_new`].
    fn default() -> Self {
        Self::default_new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::runner::DefaultRunner;
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Callable that always returns a fixed response.
    struct MockCallable {
        response: String,
    }

    impl MockCallable {
        fn new(response: &str) -> Self {
            Self {
                response: response.to_string(),
            }
        }
    }

    #[async_trait]
    impl Callable for MockCallable {
        fn name(&self) -> &str {
            "mock"
        }
        async fn run(&self, _input: &str) -> anyhow::Result<String> {
            Ok(self.response.clone())
        }
    }

    /// Callable that records whether it was ever invoked.
    struct RecordingCallable {
        called: AtomicBool,
    }

    impl RecordingCallable {
        fn new() -> Self {
            Self {
                called: AtomicBool::new(false),
            }
        }

        fn was_called(&self) -> bool {
            self.called.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl Callable for RecordingCallable {
        fn name(&self) -> &str {
            "recording"
        }
        async fn run(&self, _input: &str) -> anyhow::Result<String> {
            self.called.store(true, Ordering::SeqCst);
            Ok("ran".to_string())
        }
    }

    /// Input processor that blocks every input.
    struct BlockingProcessor;

    #[async_trait]
    impl InputProcessor for BlockingProcessor {
        fn name(&self) -> &str {
            "blocker"
        }

        async fn process(
            &self,
            _input: &str,
            _ctx: &PolicyContext,
        ) -> anyhow::Result<InputProcessorResult> {
            Ok(InputProcessorResult::Block {
                reason: "Always blocks".to_string(),
                processor: "blocker".to_string(),
            })
        }
    }

    #[test]
    fn test_input_blocked_error_display() {
        let err = InputBlockedError {
            reason: "contains PII".to_string(),
            processor: "pii".to_string(),
        };
        assert_eq!(err.to_string(), "Input blocked by pii: contains PII");
    }

    #[tokio::test]
    async fn test_protected_runner_no_processors() {
        let runner = DefaultRunner::default_new();
        let mut protected = ProtectedRunner::new(runner);
        let callable = MockCallable::new("response");

        let result = protected.run_callable(&callable, "input").await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "response");
    }

    #[tokio::test]
    async fn test_protected_runner_blocked_input() {
        let runner = DefaultRunner::default_new();
        let mut protected =
            ProtectedRunner::new(runner).with_input_processor(Arc::new(BlockingProcessor));
        let callable = RecordingCallable::new();

        let err = protected
            .run_callable(&callable, "input")
            .await
            .expect_err("blocking processor must reject the input");
        assert!(err.to_string().contains("blocked"));

        // The typed error must survive the anyhow wrapping.
        let blocked = err
            .downcast_ref::<InputBlockedError>()
            .expect("error should be an InputBlockedError");
        assert_eq!(blocked.reason, "Always blocks");
        assert_eq!(blocked.processor, "blocker");

        // Blocked input must never reach the callable.
        assert!(!callable.was_called());
    }

    #[tokio::test]
    async fn test_protected_runner_execution_id() {
        let runner = DefaultRunner::default_new();
        let protected = ProtectedRunner::new(runner);

        // Should have a valid execution ID
        assert!(!protected.execution_id().as_str().is_empty());
    }

    #[tokio::test]
    async fn test_protected_runner_cancel() {
        let runner = DefaultRunner::default_new();
        let protected = ProtectedRunner::new(runner);

        assert!(!protected.is_cancelled());
        protected.cancel();
        assert!(protected.is_cancelled());
    }

    #[tokio::test]
    async fn test_default_protected_runner() {
        let mut protected = DefaultProtectedRunner::default_new();
        let callable = MockCallable::new("hello");

        let result = protected.run_callable(&callable, "test").await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "hello");
    }
}