use serde_yaml::Value;
use crate::compiler::mcp_ref::{McpRef, try_parse_mcp_tool_name};
use crate::compiler::tool_names::{ParsedToolName, parse_mars_tool_name};
use crate::frontmatter::Frontmatter;
/// Tool permissions with plain tool names separated from `mcp(...)` references.
///
/// [`effective_tool_policy`] builds this value; each list is de-duplicated in
/// first-seen order.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct EffectiveToolPolicy {
    /// Allowed tool names that did not parse as an MCP reference.
    pub allowed: Vec<String>,
    /// Disallowed tool names (from both the `denied` and `disallowed` inputs) that did not
    /// parse as an MCP reference.
    pub disallowed: Vec<String>,
    /// Allowed entries that parsed as MCP references.
    pub(crate) mcp_allowed: Vec<McpRef>,
    /// Denied/disallowed entries that parsed as MCP references.
    pub(crate) mcp_disallowed: Vec<McpRef>,
}
/// Allow and deny lists extracted from a `tools:`-style frontmatter value.
///
/// List-form values only ever populate `allowed`. Map-form values
/// (`Tool: allow` / `Tool: deny`) can populate both.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct ParsedToolsField {
    /// Normalized tool names granted by the field.
    pub allowed: Vec<String>,
    /// Normalized tool names explicitly denied by the field.
    pub denied: Vec<String>,
}
pub fn effective_tool_policy(
allowed: &[String],
denied: &[String],
disallowed: &[String],
) -> EffectiveToolPolicy {
let mut allowed_tools = Vec::new();
let mut mcp_allowed = Vec::new();
for tool in allowed {
if let Some(mcp_ref) = try_parse_mcp_tool_name(tool) {
mcp_allowed.push(mcp_ref);
} else {
allowed_tools.push(tool.clone());
}
}
let mut disallowed_tools = Vec::new();
let mut mcp_disallowed = Vec::new();
for tool in denied.iter().chain(disallowed.iter()) {
if let Some(mcp_ref) = try_parse_mcp_tool_name(tool) {
mcp_disallowed.push(mcp_ref);
} else {
disallowed_tools.push(tool.clone());
}
}
EffectiveToolPolicy {
allowed: dedupe_ordered(allowed_tools),
disallowed: dedupe_ordered(disallowed_tools),
mcp_allowed: dedupe_mcp_refs(mcp_allowed),
mcp_disallowed: dedupe_mcp_refs(mcp_disallowed),
}
}
/// Trims every value, drops blank results, and removes exact duplicates.
///
/// The first occurrence wins and relative order is preserved. Comparison happens after
/// trimming and is case-sensitive.
pub(crate) fn dedupe_ordered(values: Vec<String>) -> Vec<String> {
    // The set borrows from `values`, so only surviving entries are allocated.
    let mut seen = std::collections::HashSet::<&str>::new();
    values
        .iter()
        .map(|value| value.trim())
        .filter(|trimmed| !trimmed.is_empty() && seen.insert(*trimmed))
        .map(str::to_owned)
        .collect()
}
/// Removes MCP references whose canonical spelling has already been seen.
///
/// The first occurrence wins and order is preserved.
fn dedupe_mcp_refs(refs: Vec<McpRef>) -> Vec<McpRef> {
    let mut canonical_seen = std::collections::HashSet::new();
    refs.into_iter()
        .filter(|mcp_ref| canonical_seen.insert(mcp_ref.to_canonical()))
        .collect()
}
pub(crate) fn append_mcp_server_entries_to_tools(fm: &mut Frontmatter, servers: &[String]) -> bool {
let entries: Vec<String> = servers
.iter()
.map(|server| server.trim())
.filter(|server| !server.is_empty())
.map(|server| format!("mcp({server})"))
.collect();
if entries.is_empty() {
return false;
}
let Some(existing) = fm.get("tools").cloned() else {
fm.insert(
"tools",
Value::Sequence(entries.into_iter().map(Value::String).collect()),
);
return true;
};
match existing {
Value::Mapping(mut mapping) => {
let mut changed = false;
for entry in entries {
let key = Value::String(entry);
if !mapping.contains_key(&key) {
mapping.insert(key, Value::String("allow".into()));
changed = true;
}
}
if changed {
fm.insert("tools", Value::Mapping(mapping));
}
changed
}
other => {
let mut tools = yaml_str_list(&other);
let mut changed = false;
for entry in entries {
if !tools.iter().any(|existing| existing == &entry) {
tools.push(entry);
changed = true;
}
}
if changed {
fm.insert(
"tools",
Value::Sequence(tools.into_iter().map(Value::String).collect()),
);
}
changed
}
}
}
/// Reads a YAML value as a list of strings.
///
/// A bare string becomes a one-item list. For a sequence, only its string items are kept
/// (other items are skipped). Any other value yields an empty list.
pub(crate) fn yaml_str_list(val: &Value) -> Vec<String> {
    if let Value::String(single) = val {
        return vec![single.clone()];
    }
    let Value::Sequence(items) = val else {
        return Vec::new();
    };
    let mut strings = Vec::new();
    for item in items {
        if let Some(text) = item.as_str() {
            strings.push(text.to_owned());
        }
    }
    strings
}
fn parse_tool_name_field(
field: &str,
raw: &str,
on_invalid: &mut dyn FnMut(&str, &str, &'static str),
) -> Option<String> {
match parse_mars_tool_name(raw) {
Ok(ParsedToolName { name, .. }) => Some(name),
Err(err) => {
on_invalid(field, raw, err.allowed());
None
}
}
}
/// Parses a list-form tool field into normalized, de-duplicated tool names.
///
/// Each string item is validated individually. Invalid items are reported with an indexed
/// path such as `tools[2]` and then skipped.
pub(crate) fn yaml_tool_list(
    field: &str,
    val: &Value,
    on_invalid: &mut dyn FnMut(&str, &str, &'static str),
) -> Vec<String> {
    let mut normalized = Vec::new();
    for (idx, tool) in yaml_str_list(val).iter().enumerate() {
        let item_path = format!("{field}[{idx}]");
        if let Some(name) = parse_tool_name_field(&item_path, tool, on_invalid) {
            normalized.push(name);
        }
    }
    dedupe_ordered(normalized)
}
/// Parses a `tools:`-style value into allow and deny lists.
///
/// Map form (`Tool: allow|deny`, policy case-insensitive) fills both lists. Any other
/// shape is handled as a list by `yaml_tool_list` and only fills `allowed`.
///
/// Problems are reported through `on_invalid` and the offending entry is skipped:
/// - a non-string key;
/// - a non-string policy value;
/// - an invalid tool name;
/// - a policy other than allow/deny.
///
/// Both output lists are de-duplicated in first-seen order.
pub(crate) fn parse_tools_field(
    field: &str,
    val: &Value,
    on_invalid: &mut dyn FnMut(&str, &str, &'static str),
) -> ParsedToolsField {
    let Value::Mapping(mapping) = val else {
        return ParsedToolsField {
            allowed: yaml_tool_list(field, val, on_invalid),
            denied: vec![],
        };
    };

    let mut allowed: Vec<String> = Vec::new();
    let mut denied: Vec<String> = Vec::new();
    for (key, value) in mapping {
        let Some(tool_name) = key.as_str() else {
            on_invalid(field, &format!("{key:?}"), "string tool keys");
            continue;
        };
        let entry_path = format!("{field}.{tool_name}");
        let Some(policy) = value.as_str() else {
            on_invalid(&entry_path, &format!("{value:?}"), "allow or deny");
            continue;
        };
        // The name is validated before the policy. With a bad name and a bad policy,
        // both problems are reported.
        let normalized = parse_tool_name_field(&entry_path, tool_name, on_invalid);
        let target = if policy.eq_ignore_ascii_case("allow") {
            Some(&mut allowed)
        } else if policy.eq_ignore_ascii_case("deny") {
            Some(&mut denied)
        } else {
            None
        };
        match target {
            Some(list) => list.extend(normalized),
            None => on_invalid(&entry_path, policy, "allow or deny"),
        }
    }
    ParsedToolsField {
        allowed: dedupe_ordered(allowed),
        denied: dedupe_ordered(denied),
    }
}
/// Frontmatter keys for the retired `mcp-tools` field, in both spellings.
pub const REMOVED_MCP_TOOLS_FIELDS: &[&str] = &[
    "mcp-tools", // hyphenated spelling
    "mcp_tools", // underscore spelling
];

/// Human-readable hint describing what replaces the keys in `REMOVED_MCP_TOOLS_FIELDS`.
pub fn removed_mcp_tools_replacement() -> &'static str {
    const HINT: &str =
        "use `mcp(server)` or `mcp(server/tool)` entries in `tools:` / `disallowed-tools:` instead";
    HINT
}
/// Removes every key listed in `REMOVED_MCP_TOOLS_FIELDS` from `fm`.
///
/// Returns `true` if at least one of those keys was present and removed.
pub(crate) fn strip_removed_mcp_tools_fields(fm: &mut Frontmatter) -> bool {
    // `remove` sits on the left of `||` so it runs for every key, even after a hit.
    REMOVED_MCP_TOOLS_FIELDS
        .iter()
        .fold(false, |removed_any, key| fm.remove(key).is_some() || removed_any)
}
/// Non-canonical spellings of tool fields, each paired with a label for its canonical field.
///
/// The label is the canonical key followed by a trailing `:` (e.g. `tools:`).
pub const NON_CANONICAL_TOOL_FIELD_ALIASES: &[(&str, &str)] = &[
    ("allowed-tools", "tools:"),
    ("allowed_tools", "tools:"),
    ("disallowed_tools", "disallowed-tools:"),
];

/// Turns a display label like `tools:` back into its bare key (`tools`).
///
/// Labels without a trailing colon are returned unchanged.
fn canonical_key_from_label(label: &str) -> &str {
    match label.strip_suffix(':') {
        Some(bare) => bare,
        None => label,
    }
}

/// Lists every non-canonical alias that maps to `canonical_key`, in table order.
///
/// `canonical_key` is the bare key without a trailing colon, e.g. `"tools"`.
pub(crate) fn non_canonical_aliases_for(canonical_key: &str) -> Vec<&'static str> {
    let mut aliases = Vec::new();
    for &(alias, label) in NON_CANONICAL_TOOL_FIELD_ALIASES {
        if canonical_key_from_label(label) == canonical_key {
            aliases.push(alias);
        }
    }
    aliases
}
// Unit tests for the tool-policy helpers above. Most policy tests parse a full
// agent or skill document and inspect the resulting EffectiveToolPolicy.
#[cfg(test)]
mod tests {
use super::*;
use crate::compiler::agents::parse_agent_content;
use crate::compiler::mcp_ref::{McpSegment, parse_mcp_ref};
use crate::compiler::skills::parse_skill_content;
use crate::frontmatter::Frontmatter;
// Checks that the alias table groups aliases under the bare canonical key
// (the label's trailing ':' is stripped before comparison).
mod non_canonical_alias_tests {
use super::*;
#[test]
fn aliases_grouped_by_canonical_field() {
assert_eq!(
non_canonical_aliases_for("tools"),
vec!["allowed-tools", "allowed_tools"]
);
assert_eq!(
non_canonical_aliases_for("disallowed-tools"),
vec!["disallowed_tools"]
);
}
}
// Parses an agent document, asserts parsing produced no diagnostics, and
// returns the policy computed for HarnessKind::Claude.
fn agent_policy(yaml: &str) -> EffectiveToolPolicy {
let mut diags = Vec::new();
let (profile, _) = parse_agent_content(yaml, &mut diags).unwrap();
assert!(diags.is_empty(), "agent diags: {diags:?}");
profile.effective_tool_policy(&crate::compiler::agents::HarnessKind::Claude)
}
// Skill counterpart of agent_policy: parse, assert no diagnostics, return the policy.
fn skill_policy(yaml: &str) -> EffectiveToolPolicy {
let mut diags = Vec::new();
let (profile, _) = parse_skill_content(yaml, &mut diags).unwrap();
assert!(diags.is_empty(), "skill diags: {diags:?}");
profile.effective_tool_policy()
}
// `mcp(server)` with no tool segment becomes a server-wide grant
// (tool: McpSegment::Any) and does not appear among plain allowed tools.
#[test]
fn tools_mcp_entry_grants_whole_server() {
let policy = agent_policy("---\nname: a\ndescription: d\ntools: [mcp(context7)]\n---\n");
assert!(policy.allowed.is_empty());
assert_eq!(
policy.mcp_allowed,
vec![McpRef {
server: McpSegment::Named("context7".into()),
tool: McpSegment::Any,
}]
);
}
// Same grant through the skill parser (skills carry a body after the frontmatter).
#[test]
fn skill_tools_mcp_entry_grants_whole_server() {
let policy =
skill_policy("---\nname: a\ndescription: d\ntools: [mcp(context7)]\n---\nbody");
assert_eq!(policy.mcp_allowed.len(), 1);
}
// List-form `tools` gets the new entry appended after existing items.
#[test]
fn append_mcp_server_entries_to_tools_merges_into_existing_list() {
let mut fm = Frontmatter::parse("---\ntools: [Bash]\n---\n").unwrap();
assert!(append_mcp_server_entries_to_tools(
&mut fm,
&["context7".to_string()]
));
assert_eq!(
fm.get("tools").map(yaml_str_list),
Some(vec!["Bash".to_string(), "mcp(context7)".to_string()])
);
}
// An entry that is already present is not re-added, so the call reports "unchanged".
#[test]
fn append_mcp_server_entries_to_tools_is_idempotent() {
let mut fm = Frontmatter::parse("---\ntools: [mcp(context7)]\n---\n").unwrap();
assert!(!append_mcp_server_entries_to_tools(
&mut fm,
&["context7".to_string()]
));
}
// Map-form `tools` stays a map: existing allow/deny values are untouched and new
// servers are added with the value `allow`.
#[test]
fn append_mcp_server_entries_to_tools_preserves_map_form() {
let mut fm = Frontmatter::parse("---\ntools:\n Bash: allow\n Read: deny\n---\n").unwrap();
assert!(append_mcp_server_entries_to_tools(
&mut fm,
&["context7".to_string(), "github".to_string()]
));
let tools = fm.get("tools").unwrap();
let Value::Mapping(mapping) = tools else {
panic!("expected map-form tools, got {tools:?}");
};
assert_eq!(
mapping.get(Value::String("Bash".into())),
Some(&Value::String("allow".into()))
);
assert_eq!(
mapping.get(Value::String("Read".into())),
Some(&Value::String("deny".into()))
);
assert_eq!(
mapping.get(Value::String("mcp(context7)".into())),
Some(&Value::String("allow".into()))
);
assert_eq!(
mapping.get(Value::String("mcp(github)".into())),
Some(&Value::String("allow".into()))
);
}
// A `server/tool` ref in disallowed-tools lands only in mcp_disallowed and
// canonicalizes back to the same spelling.
#[test]
fn disallowed_mcp_ref_round_trips_through_policy() {
let policy = agent_policy(
"---\nname: a\ndescription: d\ndisallowed-tools: [mcp(github/delete_repo)]\n---\n",
);
assert!(policy.allowed.is_empty());
assert!(policy.mcp_allowed.is_empty());
assert!(policy.disallowed.is_empty());
assert_eq!(policy.mcp_disallowed.len(), 1);
assert_eq!(
policy.mcp_disallowed[0].to_canonical(),
"mcp(github/delete_repo)"
);
}
// Colons inside a server name must not be treated as separators: the parsed ref
// must equal what parse_mcp_ref produces for the raw server string.
#[test]
fn plugin_colon_server_names_preserve_verbatim_in_mcp_refs() {
let policy = agent_policy(
"---\nname: a\ndescription: d\ntools: [mcp(plugin:context7:context7)]\n---\n",
);
assert_eq!(policy.mcp_allowed.len(), 1);
assert_eq!(
policy.mcp_allowed[0],
parse_mcp_ref("plugin:context7:context7").unwrap()
);
}
}