use crate::callable::Callable;
use std::sync::Arc;
/// Boxed predicate over the input string.
///
/// The `Send + Sync` bounds let a stored condition be shared across threads.
pub type Condition = Box<dyn Fn(&str) -> bool + Send + Sync>;
/// A named route: a predicate paired with the callable that
/// `ConditionalFlow::execute` runs when the predicate holds for the input.
pub struct Branch<C: Callable> {
/// Predicate evaluated against the input by [`Branch::matches`].
pub condition: Condition,
/// Target callable, held behind an `Arc`.
pub callable: Arc<C>,
/// Label supplied when the branch was constructed.
pub name: String,
}
impl<C: Callable> Branch<C> {
pub fn new(
name: impl Into<String>,
condition: impl Fn(&str) -> bool + Send + Sync + 'static,
callable: Arc<C>,
) -> Self {
Self {
condition: Box::new(condition),
callable,
name: name.into(),
}
}
pub fn default(name: impl Into<String>, callable: Arc<C>) -> Self {
Self {
condition: Box::new(|_| true),
callable,
name: name.into(),
}
}
pub fn matches(&self, input: &str) -> bool {
(self.condition)(input)
}
}
/// An ordered list of branches plus an optional fallback callable.
pub struct ConditionalFlow<C: Callable> {
/// Branches in the order they were added; `execute` checks them in this order.
branches: Vec<Branch<C>>,
/// Identifier returned by `name()`.
name: String,
/// Callable run by `execute` when no branch matches; when `None`, that case is an error.
default: Option<Arc<C>>,
}
impl<C: Callable> ConditionalFlow<C> {
    /// Creates an empty flow called `name`, with no branches and no fallback.
    pub fn new(name: impl Into<String>) -> Self {
        let name: String = name.into();
        Self {
            branches: Vec::new(),
            name,
            default: None,
        }
    }

    /// Appends an already-built branch and hands the flow back for chaining.
    pub fn add_branch(mut self, branch: Branch<C>) -> Self {
        self.branches.push(branch);
        self
    }

    /// Builds a branch from its parts and appends it; shorthand for
    /// `add_branch(Branch::new(..))`.
    pub fn when(
        self,
        name: impl Into<String>,
        condition: impl Fn(&str) -> bool + Send + Sync + 'static,
        callable: Arc<C>,
    ) -> Self {
        self.add_branch(Branch::new(name, condition, callable))
    }

    /// Sets the fallback callable, replacing any previously configured one.
    pub fn otherwise(self, callable: Arc<C>) -> Self {
        Self {
            default: Some(callable),
            ..self
        }
    }

    /// Routes `input` to a callable and returns that callable's result.
    ///
    /// Branches are checked in the order they were added and the first one
    /// whose condition holds is run. If none matches, the fallback set via
    /// `otherwise` is run instead.
    ///
    /// # Errors
    ///
    /// Propagates whatever error the selected callable returns. Fails with
    /// "No matching branch and no default" when no branch matches and no
    /// fallback has been configured.
    pub async fn execute(&self, input: &str) -> anyhow::Result<String> {
        let matched = self
            .branches
            .iter()
            .find(|candidate| candidate.matches(input))
            .map(|candidate| &candidate.callable);

        match matched.or(self.default.as_ref()) {
            Some(target) => target.run(input).await,
            None => anyhow::bail!("No matching branch and no default"),
        }
    }

    /// Returns the name this flow was created with.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Returns how many branches have been added; the fallback is not counted.
    pub fn branch_count(&self) -> usize {
        let Self { branches, .. } = self;
        branches.len()
    }
}
/// Builds a predicate that holds when the input contains `needle` as a substring.
///
/// The needle is converted to an owned `String` once and moved into the
/// returned closure.
pub fn contains_condition(needle: impl Into<String>) -> Box<dyn Fn(&str) -> bool + Send + Sync> {
    let owned: String = needle.into();
    let predicate = move |input: &str| input.contains(owned.as_str());
    Box::new(predicate)
}
/// Builds a predicate that holds when the input begins with `prefix`.
///
/// The prefix is converted to an owned `String` once and moved into the
/// returned closure.
pub fn starts_with_condition(prefix: impl Into<String>) -> Box<dyn Fn(&str) -> bool + Send + Sync> {
    let owned: String = prefix.into();
    let predicate = move |input: &str| input.starts_with(owned.as_str());
    Box::new(predicate)
}
/// Builds a predicate that holds when the input ends with `suffix`.
///
/// The suffix is converted to an owned `String` once and moved into the
/// returned closure.
pub fn ends_with_condition(suffix: impl Into<String>) -> Box<dyn Fn(&str) -> bool + Send + Sync> {
    let owned: String = suffix.into();
    let predicate = move |input: &str| input.ends_with(owned.as_str());
    Box::new(predicate)
}
#[cfg(test)]
mod tests {
    //! Unit tests for branch matching, flow routing and the condition helpers.
    //!
    //! Only tests that await `execute` run on a tokio runtime; purely
    //! synchronous checks use the plain `#[test]` harness.

    use super::*;
    use async_trait::async_trait;

    /// Callable that ignores its input and always returns a fixed response.
    struct MockCallable {
        name: String,
        response: String,
    }

    impl MockCallable {
        fn new(name: &str, response: &str) -> Self {
            Self {
                name: name.to_string(),
                response: response.to_string(),
            }
        }
    }

    #[async_trait]
    impl Callable for MockCallable {
        fn name(&self) -> &str {
            &self.name
        }

        async fn run(&self, _input: &str) -> anyhow::Result<String> {
            Ok(self.response.clone())
        }
    }

    /// Each input is routed to the branch whose condition it satisfies, and
    /// to the fallback when no branch matches.
    #[tokio::test]
    async fn test_conditional_first_match_wins() {
        let flow = ConditionalFlow::new("router")
            .when(
                "branch_a",
                |s| s.contains("foo"),
                Arc::new(MockCallable::new("a", "matched_a")),
            )
            .when(
                "branch_b",
                |s| s.contains("bar"),
                Arc::new(MockCallable::new("b", "matched_b")),
            )
            .otherwise(Arc::new(MockCallable::new("default", "matched_default")));

        let result = flow.execute("foo").await.unwrap();
        assert_eq!(result, "matched_a");

        let result = flow.execute("bar").await.unwrap();
        assert_eq!(result, "matched_b");

        let result = flow.execute("baz").await.unwrap();
        assert_eq!(result, "matched_default");
    }

    /// When several conditions hold, the branch added first is the one run.
    #[tokio::test]
    async fn test_conditional_first_match_priority() {
        let flow = ConditionalFlow::new("priority")
            .when(
                "first",
                |s| !s.is_empty(),
                Arc::new(MockCallable::new("a", "first_wins")),
            )
            .when(
                "second",
                |s| s.contains("x"),
                Arc::new(MockCallable::new("b", "second_wins")),
            );

        let result = flow.execute("xyz").await.unwrap();
        assert_eq!(result, "first_wins");
    }

    /// Without a fallback, an unmatched input is reported as an error.
    #[tokio::test]
    async fn test_conditional_no_match_no_default() {
        let flow: ConditionalFlow<MockCallable> = ConditionalFlow::new("strict").when(
            "only_a",
            |s| s == "a",
            Arc::new(MockCallable::new("a", "matched")),
        );

        let result = flow.execute("b").await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("No matching branch"));
    }

    /// The fallback handles inputs that no branch accepts.
    #[tokio::test]
    async fn test_conditional_with_default() {
        let flow = ConditionalFlow::new("with_default")
            .when(
                "specific",
                |s| s == "specific",
                Arc::new(MockCallable::new("s", "specific_response")),
            )
            .otherwise(Arc::new(MockCallable::new("d", "default_response")));

        let result = flow.execute("anything").await.unwrap();
        assert_eq!(result, "default_response");
    }

    /// Branches added through `add_branch` are routed exactly like those
    /// added through `when`, including a catch-all `Branch::default`.
    #[tokio::test]
    async fn test_add_branch_routes_in_insertion_order() {
        let flow = ConditionalFlow::new("added")
            .add_branch(Branch::new(
                "specific",
                |s| s == "hit",
                Arc::new(MockCallable::new("s", "specific_response")),
            ))
            .add_branch(Branch::default(
                "fallback",
                Arc::new(MockCallable::new("f", "fallback_response")),
            ));

        assert_eq!(flow.branch_count(), 2);

        let result = flow.execute("hit").await.unwrap();
        assert_eq!(result, "specific_response");

        let result = flow.execute("miss").await.unwrap();
        assert_eq!(result, "fallback_response");
    }

    #[test]
    fn test_branch_new_and_matches() {
        let callable = Arc::new(MockCallable::new("test", "response"));
        let branch = Branch::new("test_branch", |s| s.starts_with("hello"), callable);

        assert!(branch.matches("hello world"));
        assert!(!branch.matches("world hello"));
        assert_eq!(branch.name, "test_branch");
    }

    #[test]
    fn test_branch_default_always_matches() {
        let callable = Arc::new(MockCallable::new("test", "response"));
        let branch = Branch::default("default_branch", callable);

        assert!(branch.matches("anything"));
        assert!(branch.matches(""));
        assert!(branch.matches("123"));
    }

    #[test]
    fn test_contains_condition() {
        let condition = contains_condition("needle");
        assert!(condition("haystack needle here"));
        assert!(!condition("no match"));
    }

    #[test]
    fn test_starts_with_condition() {
        let condition = starts_with_condition("prefix");
        assert!(condition("prefix_rest"));
        assert!(!condition("no_prefix"));
    }

    #[test]
    fn test_ends_with_condition() {
        let condition = ends_with_condition("suffix");
        assert!(condition("word_suffix"));
        assert!(!condition("suffix_not"));
    }

    #[test]
    fn test_conditional_flow_properties() {
        let flow = ConditionalFlow::new("test_flow")
            .when("b1", |_| true, Arc::new(MockCallable::new("1", "r1")))
            .when("b2", |_| true, Arc::new(MockCallable::new("2", "r2")));

        assert_eq!(flow.name(), "test_flow");
        assert_eq!(flow.branch_count(), 2);
    }

    /// An error returned by the selected callable reaches the caller unchanged.
    #[tokio::test]
    async fn test_conditional_error_propagation() {
        struct FailingCallable;

        #[async_trait]
        impl Callable for FailingCallable {
            fn name(&self) -> &str {
                "failing"
            }

            async fn run(&self, _input: &str) -> anyhow::Result<String> {
                anyhow::bail!("Branch failed")
            }
        }

        let flow: ConditionalFlow<FailingCallable> =
            ConditionalFlow::new("failing").when("fail", |_| true, Arc::new(FailingCallable));

        let result = flow.execute("any").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Branch failed"));
    }
}