// assay-core 3.0.0
//
// High-performance evaluation framework for LLM agents (Core).
// Documentation
use std::collections::BTreeSet;

/// A tool's name together with the set of class labels attached to it.
///
/// Both variants borrow the tool name for `'a`; they differ only in whether
/// the class set is borrowed from the caller or owned by the context itself.
#[derive(Debug, Clone)]
pub enum ToolContext<'a> {
    /// The name and the class set are both borrowed.
    Borrowed {
        tool_name: &'a str,
        tool_classes: &'a BTreeSet<String>,
    },
    /// The name is borrowed; the class set is owned by this value.
    Owned {
        tool_name: &'a str,
        tool_classes: BTreeSet<String>,
    },
}

impl<'a> ToolContext<'a> {
    /// Returns the tool name.
    ///
    /// The result carries the full `'a` lifetime of the stored reference
    /// rather than the (shorter) lifetime of `&self`, so callers may keep the
    /// name after the context value itself has been dropped.
    pub fn tool_name(&self) -> &'a str {
        match self {
            // The binding is `&&'a str`; copy the inner reference out so the
            // returned lifetime is `'a`, not the borrow of `self`.
            Self::Borrowed { tool_name, .. } | Self::Owned { tool_name, .. } => *tool_name,
        }
    }

    /// Returns the set of class labels, regardless of which variant holds it.
    ///
    /// The borrow is tied to `&self` because the `Owned` variant stores the
    /// set inline.
    pub fn tool_classes(&self) -> &BTreeSet<String> {
        match self {
            Self::Borrowed { tool_classes, .. } => tool_classes,
            Self::Owned { tool_classes, .. } => tool_classes,
        }
    }
}

/// Which selector criteria a match result was evaluated on.
///
/// The default is [`MatchBasis::None`].
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum MatchBasis {
    /// No criterion.
    #[default]
    None,
    /// Tool name only.
    Name,
    /// Tool class only.
    Class,
    /// Both tool name and tool class.
    NameAndClass,
}

impl MatchBasis {
    /// Returns the string label for this basis, or `None` for
    /// [`MatchBasis::None`].
    ///
    /// Labels are `"name"`, `"class"` and `"name+class"`.
    pub fn as_str(&self) -> Option<&'static str> {
        let label = match self {
            // The only variant without a label; bail out early.
            Self::None => return None,
            Self::Name => "name",
            Self::Class => "class",
            Self::NameAndClass => "name+class",
        };
        Some(label)
    }
}

/// Outcome of evaluating a `ToolRuleSelector` against a `ToolContext`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MatchResult {
    /// `true` when every criterion present on the selector was satisfied.
    pub matched: bool,
    /// Selector classes that were found in the context's class set.
    pub matched_classes: Vec<String>,
    /// Which selector criteria (name, class, both, or none) were in play.
    pub basis: MatchBasis,
}

/// Selects tools by exact name, by class membership, or by both.
///
/// A criterion left as `None` places no constraint on the tool.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolRuleSelector {
    /// Exact tool name to require, if any.
    pub name: Option<String>,
    /// Class label the tool must carry, if any.
    pub class: Option<String>,
}

impl ToolRuleSelector {
    /// Creates a selector from an optional name and an optional class.
    pub fn new(name: Option<String>, class: Option<String>) -> Self {
        ToolRuleSelector { name, class }
    }

    /// Evaluates this selector against `ctx`.
    ///
    /// * `matched` is `true` only when every criterion that is set holds: the
    ///   name compares equal to `ctx.tool_name()` and the class is contained
    ///   in `ctx.tool_classes()`. Unset criteria are treated as satisfied.
    /// * `matched_classes` holds the selector's class whenever the context
    ///   carries it, independently of the name check.
    /// * `basis` reflects which criteria are set on the selector,
    ///   independently of whether the match succeeded.
    pub fn matches(&self, ctx: &ToolContext<'_>) -> MatchResult {
        // An absent name constraint is vacuously satisfied.
        let name_hit = self
            .name
            .as_deref()
            .map_or(true, |wanted| ctx.tool_name() == wanted);

        // At most one element: the selector's class, if the context has it.
        let matched_classes: Vec<String> = self
            .class
            .iter()
            .filter(|wanted| ctx.tool_classes().contains(*wanted))
            .cloned()
            .collect();

        // With a class constraint set, success means it was collected above.
        let class_hit = self.class.is_none() || !matched_classes.is_empty();

        let basis = match (self.name.is_some(), self.class.is_some()) {
            (true, true) => MatchBasis::NameAndClass,
            (true, false) => MatchBasis::Name,
            (false, true) => MatchBasis::Class,
            (false, false) => MatchBasis::None,
        };

        MatchResult {
            matched: name_hit && class_hit,
            matched_classes,
            basis,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned context from a tool name and a list of class labels.
    fn owned_ctx<'a>(tool_name: &'a str, classes: &[&str]) -> ToolContext<'a> {
        ToolContext::Owned {
            tool_name,
            tool_classes: classes.iter().map(|label| String::from(*label)).collect(),
        }
    }

    /// Builds a selector from borrowed name/class criteria.
    fn selector(name: Option<&str>, class: Option<&str>) -> ToolRuleSelector {
        ToolRuleSelector::new(name.map(String::from), class.map(String::from))
    }

    #[test]
    fn match_by_name_only() {
        let outcome = selector(Some("web_search"), None).matches(&owned_ctx("web_search", &[]));

        assert!(outcome.matched);
        assert_eq!(outcome.basis, MatchBasis::Name);
        assert!(outcome.matched_classes.is_empty());
    }

    #[test]
    fn match_by_class_only() {
        let outcome = selector(None, Some("sink:network"))
            .matches(&owned_ctx("web_search_alt", &["sink:network"]));

        assert!(outcome.matched);
        assert_eq!(outcome.basis, MatchBasis::Class);
        assert_eq!(outcome.matched_classes, vec!["sink:network".to_string()]);
    }

    #[test]
    fn name_and_class_requires_both() {
        let outcome = selector(Some("web_search"), Some("sink:network"))
            .matches(&owned_ctx("web_search", &["sink:network"]));

        assert!(outcome.matched);
        assert_eq!(outcome.basis, MatchBasis::NameAndClass);
    }

    #[test]
    fn missing_class_fails_and_selector() {
        // The name matches, but the context carries no classes at all.
        let outcome = selector(Some("web_search"), Some("sink:network"))
            .matches(&owned_ctx("web_search", &[]));

        assert!(!outcome.matched);
    }
}