// enact-core 0.0.2
//
// Core agent runtime for Enact - Graph-Native AI agents
// Documentation
//! Conditional Flow - If/then/else branching
//!
//! Route execution based on conditions evaluated against the input.

use crate::callable::Callable;
use std::sync::Arc;

/// Boxed predicate used to decide whether a branch applies.
///
/// The closure receives the flow input as a `&str` and returns `true` when
/// the associated branch should be taken. The `Send + Sync` bounds are part
/// of the alias, so every stored condition must satisfy them.
pub type Condition = Box<dyn Fn(&str) -> bool + Send + Sync>;

/// One arm of a conditional flow: a predicate paired with the callable to
/// run when that predicate accepts the input.
pub struct Branch<C>
where
    C: Callable,
{
    /// Predicate evaluated against the input to decide whether this arm applies
    pub condition: Condition,
    /// Callable executed when `condition` returns `true`
    pub callable: Arc<C>,
    /// Human-readable label for the arm (intended for debugging/logging)
    pub name: String,
}

impl<C> Branch<C>
where
    C: Callable,
{
    /// Build a branch from a label, a predicate and the callable to run
    /// when the predicate accepts the input.
    pub fn new(
        name: impl Into<String>,
        condition: impl Fn(&str) -> bool + Send + Sync + 'static,
        callable: Arc<C>,
    ) -> Self {
        // Box the predicate first, then convert the label, so the order of
        // evaluation stays the same as the field initialisers below.
        let predicate = Box::new(condition);
        let label = name.into();
        Self {
            condition: predicate,
            callable,
            name: label,
        }
    }

    /// Build a catch-all branch whose predicate accepts every input
    /// (useful as a default/else arm).
    pub fn default(name: impl Into<String>, callable: Arc<C>) -> Self {
        Self::new(name, |_| true, callable)
    }

    /// Evaluate this branch's predicate against `input`.
    ///
    /// Returns `true` when the branch should be taken.
    pub fn matches(&self, input: &str) -> bool {
        let predicate = &self.condition;
        predicate(input)
    }
}

/// Router that runs the callable of the first branch whose predicate
/// accepts the input, with an optional fallback callable.
pub struct ConditionalFlow<C>
where
    C: Callable,
{
    /// Branches in evaluation order; the earliest match is the one executed
    branches: Vec<Branch<C>>,
    /// Name of this flow
    name: String,
    /// Fallback callable used when none of the branches match
    default: Option<Arc<C>>,
}

impl<C> ConditionalFlow<C>
where
    C: Callable,
{
    /// Start an empty flow with the given name: no branches, no fallback.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            default: None,
            branches: Vec::new(),
        }
    }

    /// Append a pre-built branch; branches are evaluated in insertion order.
    pub fn add_branch(self, branch: Branch<C>) -> Self {
        let mut flow = self;
        flow.branches.push(branch);
        flow
    }

    /// Append a branch built from a label, a predicate and a callable.
    pub fn when(
        self,
        name: impl Into<String>,
        condition: impl Fn(&str) -> bool + Send + Sync + 'static,
        callable: Arc<C>,
    ) -> Self {
        self.add_branch(Branch::new(name, condition, callable))
    }

    /// Set the fallback (else) callable, replacing any previously set one.
    pub fn otherwise(self, callable: Arc<C>) -> Self {
        let mut flow = self;
        flow.default = Some(callable);
        flow
    }

    /// Run the flow against `input`.
    ///
    /// Branches are tried in insertion order and the callable of the first
    /// one that matches is run. If none match, the fallback callable is run
    /// when one has been set.
    ///
    /// # Errors
    ///
    /// Returns an error when no branch matches and no fallback is set, and
    /// passes through whatever error the selected callable returns.
    pub async fn execute(&self, input: &str) -> anyhow::Result<String> {
        // Pick the callable to run: earliest matching branch, else the fallback.
        let selected = self
            .branches
            .iter()
            .find(|candidate| candidate.matches(input))
            .map(|candidate| &candidate.callable)
            .or(self.default.as_ref());

        match selected {
            Some(callable) => callable.run(input).await,
            None => anyhow::bail!("No matching branch and no default"),
        }
    }

    /// Name given to this flow at construction.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Number of branches added so far (the fallback is not counted).
    pub fn branch_count(&self) -> usize {
        self.branches.len()
    }
}

// Common condition builders as free functions
/// Build a predicate that is `true` when the input contains `needle`
/// anywhere as a substring.
pub fn contains_condition(needle: impl Into<String>) -> Box<dyn Fn(&str) -> bool + Send + Sync> {
    // Own the pattern so the returned closure can outlive the caller's borrow.
    let pattern: String = needle.into();
    let predicate = move |input: &str| -> bool { input.contains(pattern.as_str()) };
    Box::new(predicate)
}

/// Build a predicate that is `true` when the input begins with `prefix`.
pub fn starts_with_condition(prefix: impl Into<String>) -> Box<dyn Fn(&str) -> bool + Send + Sync> {
    // Own the prefix so the returned closure can outlive the caller's borrow.
    let expected: String = prefix.into();
    let predicate = move |input: &str| -> bool { input.starts_with(expected.as_str()) };
    Box::new(predicate)
}

/// Build a predicate that is `true` when the input finishes with `suffix`.
pub fn ends_with_condition(suffix: impl Into<String>) -> Box<dyn Fn(&str) -> bool + Send + Sync> {
    // Own the suffix so the returned closure can outlive the caller's borrow.
    let expected: String = suffix.into();
    let predicate = move |input: &str| -> bool { input.ends_with(expected.as_str()) };
    Box::new(predicate)
}

#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    /// Callable stub that ignores its input and returns a fixed response.
    struct MockCallable {
        name: String,
        response: String,
    }

    impl MockCallable {
        fn new(name: &str, response: &str) -> Self {
            Self {
                name: name.to_string(),
                response: response.to_string(),
            }
        }
    }

    #[async_trait]
    impl Callable for MockCallable {
        fn name(&self) -> &str {
            &self.name
        }

        async fn run(&self, _input: &str) -> anyhow::Result<String> {
            Ok(self.response.clone())
        }
    }

    /// Each input is routed to the branch whose predicate accepts it, and
    /// unmatched input falls through to the default.
    #[tokio::test]
    async fn test_conditional_first_match_wins() {
        let flow = ConditionalFlow::new("router")
            .when(
                "branch_a",
                |s| s.contains("foo"),
                Arc::new(MockCallable::new("a", "matched_a")),
            )
            .when(
                "branch_b",
                |s| s.contains("bar"),
                Arc::new(MockCallable::new("b", "matched_b")),
            )
            .otherwise(Arc::new(MockCallable::new("default", "matched_default")));

        // Should match branch_a
        let result = flow.execute("foo").await.unwrap();
        assert_eq!(result, "matched_a");

        // Should match branch_b
        let result = flow.execute("bar").await.unwrap();
        assert_eq!(result, "matched_b");

        // Should match default
        let result = flow.execute("baz").await.unwrap();
        assert_eq!(result, "matched_default");
    }

    /// When several predicates accept the input, the earliest branch runs.
    #[tokio::test]
    async fn test_conditional_first_match_priority() {
        let flow = ConditionalFlow::new("priority")
            .when(
                "first",
                |s| !s.is_empty(),
                Arc::new(MockCallable::new("a", "first_wins")),
            )
            .when(
                "second",
                |s| s.contains("x"),
                Arc::new(MockCallable::new("b", "second_wins")),
            );

        // Both conditions match "xyz", but first one wins
        let result = flow.execute("xyz").await.unwrap();
        assert_eq!(result, "first_wins");
    }

    /// Without a default, unmatched input produces an error.
    #[tokio::test]
    async fn test_conditional_no_match_no_default() {
        let flow: ConditionalFlow<MockCallable> = ConditionalFlow::new("strict").when(
            "only_a",
            |s| s == "a",
            Arc::new(MockCallable::new("a", "matched")),
        );

        let result = flow.execute("b").await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("No matching branch"));
    }

    /// Unmatched input is handled by the callable set via `otherwise`.
    #[tokio::test]
    async fn test_conditional_with_default() {
        let flow = ConditionalFlow::new("with_default")
            .when(
                "specific",
                |s| s == "specific",
                Arc::new(MockCallable::new("s", "specific_response")),
            )
            .otherwise(Arc::new(MockCallable::new("d", "default_response")));

        let result = flow.execute("anything").await.unwrap();
        assert_eq!(result, "default_response");
    }

    /// Branches added through `add_branch` are evaluated in insertion order,
    /// and a `Branch::default` arm placed last acts as the else case.
    #[tokio::test]
    async fn test_add_branch_with_default_branch() {
        let flow = ConditionalFlow::new("explicit")
            .add_branch(Branch::new(
                "starts_hello",
                starts_with_condition("hello"),
                Arc::new(MockCallable::new("h", "hello_response")),
            ))
            .add_branch(Branch::default(
                "fallback",
                Arc::new(MockCallable::new("f", "fallback_response")),
            ));

        assert_eq!(flow.branch_count(), 2);

        let result = flow.execute("hello world").await.unwrap();
        assert_eq!(result, "hello_response");

        let result = flow.execute("goodbye").await.unwrap();
        assert_eq!(result, "fallback_response");
    }

    // The tests below never await anything, so they use plain `#[test]`
    // instead of starting an async runtime per test.

    #[test]
    fn test_branch_new_and_matches() {
        let callable = Arc::new(MockCallable::new("test", "response"));
        let branch = Branch::new("test_branch", |s| s.starts_with("hello"), callable);

        assert!(branch.matches("hello world"));
        assert!(!branch.matches("world hello"));
        assert_eq!(branch.name, "test_branch");
    }

    #[test]
    fn test_branch_default_always_matches() {
        let callable = Arc::new(MockCallable::new("test", "response"));
        let branch = Branch::default("default_branch", callable);

        assert!(branch.matches("anything"));
        assert!(branch.matches(""));
        assert!(branch.matches("123"));
    }

    #[test]
    fn test_contains_condition() {
        let condition = contains_condition("needle");
        assert!(condition("haystack needle here"));
        assert!(!condition("no match"));
    }

    #[test]
    fn test_starts_with_condition() {
        let condition = starts_with_condition("prefix");
        assert!(condition("prefix_rest"));
        assert!(!condition("no_prefix"));
    }

    #[test]
    fn test_ends_with_condition() {
        let condition = ends_with_condition("suffix");
        assert!(condition("word_suffix"));
        assert!(!condition("suffix_not"));
    }

    #[test]
    fn test_conditional_flow_properties() {
        let flow = ConditionalFlow::new("test_flow")
            .when("b1", |_| true, Arc::new(MockCallable::new("1", "r1")))
            .when("b2", |_| true, Arc::new(MockCallable::new("2", "r2")));

        assert_eq!(flow.name(), "test_flow");
        assert_eq!(flow.branch_count(), 2);
    }

    /// An error returned by the selected callable reaches the caller of `execute`.
    #[tokio::test]
    async fn test_conditional_error_propagation() {
        struct FailingCallable;

        #[async_trait]
        impl Callable for FailingCallable {
            fn name(&self) -> &str {
                "failing"
            }
            async fn run(&self, _input: &str) -> anyhow::Result<String> {
                anyhow::bail!("Branch failed")
            }
        }

        let flow: ConditionalFlow<FailingCallable> =
            ConditionalFlow::new("failing").when("fail", |_| true, Arc::new(FailingCallable));

        let result = flow.execute("any").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Branch failed"));
    }
}