use std::collections::BTreeMap;
use regex::Regex;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use thiserror::Error;
/// Highest `version` the parser accepts; a file declaring `0` or anything above this is rejected.
pub const MAX_VERSION: u64 = 3;
/// NOTE(review): not referenced in the visible part of this file; presumably caps how deep
/// `metadata.extends` chains may be followed — confirm against the code that resolves `extends`.
pub const MAX_EXTENDS_DEPTH: usize = 4;
/// Errors reported while parsing an invariants file or expanding its `for_each_tool` blocks.
#[derive(Debug, Error)]
pub enum DslError {
/// The YAML could not be (de)serialized; wraps the underlying `serde_yaml` error.
#[error("failed to parse invariants YAML: {0}")]
Parse(#[from] serde_yaml::Error),
/// An invariant set both or neither of `generate` / `fixed`; carries the invariant's name.
#[error("invariant `{0}` must define exactly one of `generate` or `fixed`")]
InvalidInputMode(String),
/// The file's `version` was `0` or greater than `MAX_VERSION`; carries the declared version.
#[error("invariants file declares unsupported version `{0}`; expected ≤ {MAX_VERSION}")]
UnsupportedVersion(u64),
/// The source referenced `{{name}}` placeholders with no value; carries each name once,
/// in order of first appearance.
#[error("undefined template parameter(s): {0:?}")]
UndefinedParameters(Vec<String>),
/// A caller-supplied override named a parameter that `metadata.parameters` does not declare.
#[error("override key `{0}` is not declared in metadata.parameters")]
UnknownParameterOverride(String),
/// A `where.name_matches` / `where.description_matches` pattern failed to compile;
/// carries the regex error rendered as text.
#[error("invalid `where` regex: {0}")]
InvalidWhereRegex(String),
}
/// Result alias used throughout this module, with `DslError` as the error type.
pub type Result<T> = std::result::Result<T, DslError>;
/// Top-level document of an invariants YAML file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InvariantFile {
/// Declared DSL version; the parser rejects `0` and values above `MAX_VERSION`.
pub version: u64,
/// Optional pack metadata; omitted from serialized output when absent.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub metadata: Option<PackMetadata>,
/// Explicitly listed invariants. This key is required (no serde default), but may be empty.
pub invariants: Vec<Invariant>,
/// Templates expanded into one invariant per matching tool; optional, defaults to empty.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub for_each_tool: Vec<ForEachToolBlock>,
}
/// One `for_each_tool` entry: a tool filter plus the invariant template applied to each match.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ForEachToolBlock {
/// Name template for generated invariants; may contain the `{{tool_name}}` placeholder.
pub name: String,
/// Tool filter; spelled `where` in the YAML because `where` is a Rust keyword.
#[serde(rename = "where")]
pub matches: ToolMatch,
/// Invariant body instantiated for every tool that passes the filter.
pub apply: ApplyTemplate,
}
/// Filter selecting which tools a `for_each_tool` block applies to.
/// Every configured criterion must hold; an all-default value matches every tool.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ToolMatch {
/// Required annotation hint values; unset hints impose no constraint.
#[serde(default, skip_serializing_if = "ToolAnnotationMatch::is_empty")]
pub annotations: ToolAnnotationMatch,
/// Regex source tested against the tool name (compiled by `expand_for_each_tool`).
#[serde(default, skip_serializing_if = "Option::is_none")]
pub name_matches: Option<String>,
/// Regex source tested against the tool description (a missing description is treated as "").
#[serde(default, skip_serializing_if = "Option::is_none")]
pub description_matches: Option<String>,
}
/// Annotation hints a tool must carry to match. Keys use camelCase in YAML
/// (`readOnlyHint`, ...); a `None` field means "don't care".
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ToolAnnotationMatch {
/// Required value of the tool's `readOnlyHint` annotation, if constrained.
#[serde(
default,
rename = "readOnlyHint",
skip_serializing_if = "Option::is_none"
)]
pub read_only_hint: Option<bool>,
/// Required value of the tool's `destructiveHint` annotation, if constrained.
#[serde(
default,
rename = "destructiveHint",
skip_serializing_if = "Option::is_none"
)]
pub destructive_hint: Option<bool>,
/// Required value of the tool's `idempotentHint` annotation, if constrained.
#[serde(
default,
rename = "idempotentHint",
skip_serializing_if = "Option::is_none"
)]
pub idempotent_hint: Option<bool>,
/// Required value of the tool's `openWorldHint` annotation, if constrained.
#[serde(
default,
rename = "openWorldHint",
skip_serializing_if = "Option::is_none"
)]
pub open_world_hint: Option<bool>,
}
impl ToolAnnotationMatch {
    /// Returns `true` when none of the four hints is constrained.
    ///
    /// Used as the `skip_serializing_if` predicate for `ToolMatch::annotations`.
    pub fn is_empty(&self) -> bool {
        [
            self.read_only_hint,
            self.destructive_hint,
            self.idempotent_hint,
            self.open_world_hint,
        ]
        .iter()
        .all(Option::is_none)
    }
}
/// Body of a `for_each_tool` block: the same fields as `Invariant` minus `name` and `tool`,
/// which are filled in per matching tool during expansion.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ApplyTemplate {
/// Per-argument generation specs, keyed by argument name.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub generate: Option<BTreeMap<String, ValueSpec>>,
/// Literal argument values, keyed by argument name.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub fixed: Option<BTreeMap<String, Value>>,
/// NOTE(review): presumably the number of generated cases to run — confirm with the runner.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub cases: Option<u32>,
/// Assertions; spelled `assert` in YAML. Required key (no serde default), may be empty.
#[serde(rename = "assert")]
pub assertions: Vec<Assertion>,
/// Optional canned fixtures; omitted from serialized output when empty.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub test_fixtures: Vec<TestFixture>,
}
/// Optional descriptive header of a pack. Every field is optional and is omitted from
/// serialized output when empty.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PackMetadata {
/// Pack name.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
/// Free-form description of the pack.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
/// Author names.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub authors: Vec<String>,
/// Free-form tags.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub tags: Vec<String>,
/// Template parameters; each key can be referenced as `{{key}}` in the source and its
/// `default` is substituted unless the caller overrides it.
#[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
pub parameters: BTreeMap<String, Parameter>,
/// NOTE(review): presumably names of parent packs to inherit from; resolution is not
/// visible in this part of the file — confirm.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub extends: Vec<String>,
}
/// Declaration of one template parameter under `metadata.parameters`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Parameter {
/// Optional human-readable description.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
/// Declared type, spelled `type` in YAML; defaults to `string` when omitted.
#[serde(default = "default_param_kind", rename = "type")]
pub kind: ParamKind,
/// Value substituted for `{{name}}` when no override is supplied. Required key.
pub default: Value,
}
/// Serde default for `Parameter::kind`: parameters without a `type` key are strings.
fn default_param_kind() -> ParamKind {
ParamKind::String
}
/// Declared type of a template parameter; written in lowercase in YAML (`string`, `integer`, ...).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ParamKind {
String,
Integer,
Number,
Boolean,
}
/// A single invariant: which tool to call, how its input is produced, and what to assert.
/// The parser requires exactly one of `generate` / `fixed` to be present.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Invariant {
/// Invariant name (reported in `DslError::InvalidInputMode`).
pub name: String,
/// Name of the tool this invariant targets.
pub tool: String,
/// Per-argument generation specs, keyed by argument name.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub generate: Option<BTreeMap<String, ValueSpec>>,
/// Literal argument values, keyed by argument name.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub fixed: Option<BTreeMap<String, Value>>,
/// NOTE(review): presumably the number of generated cases to run — confirm with the runner.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub cases: Option<u32>,
/// Assertions; spelled `assert` in YAML. Required key (no serde default), may be empty.
#[serde(rename = "assert")]
pub assertions: Vec<Assertion>,
/// Optional canned fixtures; omitted from serialized output when empty.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub test_fixtures: Vec<TestFixture>,
}
/// A canned input/response pair with the outcome the assertions are expected to produce.
/// NOTE(review): how fixtures are evaluated is not visible in this part of the file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TestFixture {
/// Fixture name.
pub name: String,
/// Optional input value.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub input: Option<Value>,
/// Response value the assertions are checked against.
pub response: Value,
/// Whether the assertions should pass or fail for this fixture.
pub expect: FixtureExpect,
}
/// Expected outcome of a `TestFixture`; written as `pass` / `fail` in YAML.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum FixtureExpect {
Pass,
Fail,
}
/// Description of one generated argument under `generate`.
/// Only `type` is required; every bound is optional and omitted from output when unset.
/// NOTE(review): which bounds apply to which `type` is decided by the generator, which is
/// not visible here — the per-field notes below are inferred from the names.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValueSpec {
/// Kind of value to generate; spelled `type` in YAML.
#[serde(rename = "type")]
pub kind: ValueKind,
/// Presumably the minimum string length.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub min_length: Option<usize>,
/// Presumably the maximum string length.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub max_length: Option<usize>,
/// Presumably the lower numeric bound.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub min: Option<i64>,
/// Presumably the upper numeric bound.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub max: Option<i64>,
/// Nested spec, presumably for array elements; boxed because the type is recursive.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub items: Option<Box<ValueSpec>>,
/// Presumably the minimum array length.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub min_items: Option<usize>,
/// Presumably the maximum array length.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub max_items: Option<usize>,
}
/// Kind of value a `ValueSpec` describes; written in lowercase in YAML.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ValueKind {
String,
Integer,
Number,
Boolean,
Array,
}
/// An assertion operand. The enum is untagged, so serde tries the variants in declaration
/// order: a mapping with `path`, then a mapping with `value`, then any other value as-is.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Operand {
/// Explicit path form: `{ path: "$.response.text" }`.
Path {
path: String,
},
/// Explicit literal form: `{ value: hello }`.
Literal {
value: Value,
},
/// Bare value (the version-1 form), e.g. the plain string `"$.response.text"`.
/// NOTE(review): whether a bare string is later treated as a path is decided by the
/// evaluator, which is not visible here.
Direct(Value),
}
/// One assertion. Internally tagged: the YAML mapping carries a `kind` key holding the
/// snake_case variant name (`equals`, `at_most`, `all_of`, ...) next to the variant's fields.
/// NOTE(review): evaluation semantics live outside this part of the file; the per-variant
/// notes below describe the data shape and what the names suggest.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum Assertion {
/// Compares two operands (`equals` / `not_equals`).
Equals { lhs: Operand, rhs: Operand },
NotEquals { lhs: Operand, rhs: Operand },
/// Bound checks of the value at `path` against an operand.
AtMost { path: String, value: Operand },
AtLeast { path: String, value: Operand },
/// Length checks of the value at `path` against an operand.
LengthEq { path: String, value: Operand },
LengthAtMost { path: String, value: Operand },
LengthAtLeast { path: String, value: Operand },
/// Type check; the expected type is spelled `type` in YAML.
IsType {
path: String,
#[serde(rename = "type")]
expected: JsonType,
},
/// Regex check; `pattern` is stored as an uncompiled string.
MatchesRegex { path: String, pattern: String },
/// Combinator over nested assertions, listed under `assert`.
AllOf {
#[serde(rename = "assert")]
assertions: Vec<Assertion>,
},
/// Combinator over nested assertions, listed under `assert`.
AnyOf {
#[serde(rename = "assert")]
assertions: Vec<Assertion>,
},
/// Wraps a single nested assertion; boxed because the type is recursive.
Not {
assertion: Box<Assertion>,
},
/// Nested assertions (under `assert`) associated with the elements selected by `path`.
ForEach {
path: String,
#[serde(rename = "assert")]
assertions: Vec<Assertion>,
},
/// Carries an inline schema document for the value at `path`.
MatchesSchema {
path: String,
schema: Value,
},
}
/// Type names accepted by the `is_type` assertion; written in lowercase in YAML.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum JsonType {
String,
Number,
Integer,
Boolean,
Array,
Object,
Null,
}
/// Parses an invariants file using only the defaults declared in `metadata.parameters`.
///
/// Equivalent to [`parse_with_overrides`] with an empty override map.
///
/// # Errors
/// Returns whatever [`parse_with_overrides`] returns for `source`.
pub fn parse(source: &str) -> Result<InvariantFile> {
    let no_overrides = BTreeMap::new();
    parse_with_overrides(source, &no_overrides)
}
/// Parses an invariants file, substituting `{{name}}` template placeholders first.
///
/// Steps:
/// 1. Parse `source` as untyped YAML to discover `metadata.parameters`.
/// 2. Reject any override whose key is not a declared parameter.
/// 3. Build the substitution table: parameter defaults, then `overrides` on top.
///    When the file has a non-empty `for_each_tool` list, `tool_name` maps to the
///    literal `{{tool_name}}` (unless declared) so the placeholder survives until expansion.
/// 4. Render the template over the raw text and deserialize the result.
/// 5. Validate the version and the `generate` / `fixed` input mode.
///
/// # Errors
/// * [`DslError::Parse`] if either YAML pass fails.
/// * [`DslError::UnknownParameterOverride`] for an undeclared override key.
/// * [`DslError::UndefinedParameters`] if a placeholder has no value.
/// * [`DslError::UnsupportedVersion`] if `version` is `0` or above [`MAX_VERSION`].
/// * [`DslError::InvalidInputMode`] if an invariant or a `for_each_tool` apply template
///   does not define exactly one of `generate` / `fixed`.
pub fn parse_with_overrides(
    source: &str,
    overrides: &BTreeMap<String, String>,
) -> Result<InvariantFile> {
    let raw: serde_yaml::Value = serde_yaml::from_str(source)?;
    let parameters = extract_parameters(&raw);

    if let Some(unknown) = overrides
        .keys()
        .find(|key| !parameters.contains_key(*key))
    {
        return Err(DslError::UnknownParameterOverride(unknown.clone()));
    }

    let mut subst: BTreeMap<String, String> = parameters
        .iter()
        .map(|(name, param)| (name.clone(), stringify_default(&param.default)))
        .collect();
    // Overrides win over declared defaults.
    subst.extend(
        overrides
            .iter()
            .map(|(key, value)| (key.clone(), value.clone())),
    );
    if has_for_each_tool(&raw) {
        // Map the placeholder to itself so it is left for per-tool expansion
        // instead of being reported as undefined.
        subst
            .entry("tool_name".to_string())
            .or_insert_with(|| "{{tool_name}}".to_string());
    }

    let substituted = render_template(source, &subst)?;
    let file: InvariantFile = serde_yaml::from_str(&substituted)?;

    if file.version == 0 || file.version > MAX_VERSION {
        return Err(DslError::UnsupportedVersion(file.version));
    }

    // Exactly one of `generate` / `fixed` must be present. Apply templates are checked too:
    // expansion copies both fields verbatim into the generated invariants, so an
    // invalid template would otherwise slip past this validation.
    let invariant_modes = file
        .invariants
        .iter()
        .map(|inv| (&inv.name, inv.generate.is_some(), inv.fixed.is_some()));
    let template_modes = file.for_each_tool.iter().map(|block| {
        (
            &block.name,
            block.apply.generate.is_some(),
            block.apply.fixed.is_some(),
        )
    });
    for (name, has_generate, has_fixed) in invariant_modes.chain(template_modes) {
        if has_generate == has_fixed {
            return Err(DslError::InvalidInputMode(name.clone()));
        }
    }

    Ok(file)
}
/// Replaces both accepted spellings of the `tool_name` placeholder in `text`.
fn substitute_tool_name(text: &str, tool_name: &str) -> String {
    text.replace("{{tool_name}}", tool_name)
        .replace("{{ tool_name }}", tool_name)
}

/// Applies [`substitute_tool_name`] to every string scalar and mapping key below `node`.
///
/// Working on the parsed value tree (rather than on serialized YAML text) means a tool
/// name containing YAML-significant characters (quotes, `: `, ` #`, newlines) can only
/// ever change string contents, never the document structure.
fn substitute_in_yaml(node: &mut serde_yaml::Value, tool_name: &str) {
    match node {
        serde_yaml::Value::String(text) => *text = substitute_tool_name(text, tool_name),
        serde_yaml::Value::Sequence(items) => {
            for item in items.iter_mut() {
                substitute_in_yaml(item, tool_name);
            }
        }
        serde_yaml::Value::Mapping(entries) => {
            // Keys may contain the placeholder too, so rebuild the mapping in order.
            let previous = std::mem::take(entries);
            for (mut key, mut entry) in previous {
                substitute_in_yaml(&mut key, tool_name);
                substitute_in_yaml(&mut entry, tool_name);
                entries.insert(key, entry);
            }
        }
        // Nulls, booleans and numbers carry no text to substitute.
        _ => {}
    }
}

/// Builds the concrete invariant for `tool_name` from `block`, given the block's apply
/// template already converted to a YAML value (`template`).
fn instantiate_for_tool(
    block: &ForEachToolBlock,
    template: &serde_yaml::Value,
    tool_name: &str,
) -> Result<Invariant> {
    let mut body = template.clone();
    substitute_in_yaml(&mut body, tool_name);
    let apply: ApplyTemplate = serde_yaml::from_value(body)?;
    Ok(Invariant {
        name: substitute_tool_name(&block.name, tool_name),
        tool: tool_name.to_string(),
        generate: apply.generate,
        fixed: apply.fixed,
        cases: apply.cases,
        assertions: apply.assertions,
        test_fixtures: apply.test_fixtures,
    })
}

/// Instantiates `block` for a stand-in tool name, without consulting any tool list or
/// the block's `where` filter.
///
/// `placeholder` replaces `{{tool_name}}` / `{{ tool_name }}` in the block name and in the
/// apply template, and becomes the resulting invariant's `tool`.
///
/// # Errors
/// [`DslError::Parse`] if the apply template cannot be converted to or from YAML.
pub fn synthesize_for_test(block: &ForEachToolBlock, placeholder: &str) -> Result<Invariant> {
    let template = serde_yaml::to_value(&block.apply)?;
    instantiate_for_tool(block, &template, placeholder)
}

/// Expands every `for_each_tool` block into one invariant per matching tool.
///
/// Output order is block order, then tool order within each block. Tool names come from
/// the inspected server, so they are substituted structurally (see `substitute_in_yaml`)
/// rather than spliced into YAML text.
///
/// # Errors
/// * [`DslError::InvalidWhereRegex`] if `name_matches` / `description_matches` does not compile.
/// * [`DslError::Parse`] if an apply template cannot be converted to or from YAML.
pub fn expand_for_each_tool(
    blocks: &[ForEachToolBlock],
    tools: &[rmcp::model::Tool],
) -> Result<Vec<Invariant>> {
    let mut out = Vec::new();
    for block in blocks {
        let name_re = block
            .matches
            .name_matches
            .as_deref()
            .map(Regex::new)
            .transpose()
            .map_err(|err| DslError::InvalidWhereRegex(err.to_string()))?;
        let description_re = block
            .matches
            .description_matches
            .as_deref()
            .map(Regex::new)
            .transpose()
            .map_err(|err| DslError::InvalidWhereRegex(err.to_string()))?;

        let mut matching = tools
            .iter()
            .filter(|tool| {
                block
                    .matches
                    .matches(tool, name_re.as_ref(), description_re.as_ref())
            })
            .peekable();
        // As before, a block that matches nothing never touches its template.
        if matching.peek().is_none() {
            continue;
        }

        // The template is identical for every tool: convert it once per block.
        let template = serde_yaml::to_value(&block.apply)?;
        for tool in matching {
            let tool_name: &str = tool.name.as_ref();
            out.push(instantiate_for_tool(block, &template, tool_name)?);
        }
    }
    Ok(out)
}
impl ToolMatch {
    /// Decides whether `tool` satisfies this filter.
    ///
    /// A tool matches when every configured annotation hint equals the tool's own hint
    /// (a tool without that hint, or without annotations, fails a configured hint), and
    /// each supplied regex matches: `name_re` against the tool name, `description_re`
    /// against the description (an absent description is tested as the empty string).
    /// The caller passes the pre-compiled regexes; this method does not read
    /// `name_matches` / `description_matches` itself.
    pub fn matches(
        &self,
        tool: &rmcp::model::Tool,
        name_re: Option<&Regex>,
        description_re: Option<&Regex>,
    ) -> bool {
        let wanted = &self.annotations;
        let observed = tool.annotations.as_ref();

        // (configured, actual) for each hint; an unconfigured hint never rejects.
        let hint_pairs = [
            (
                wanted.read_only_hint,
                observed.and_then(|a| a.read_only_hint),
            ),
            (
                wanted.destructive_hint,
                observed.and_then(|a| a.destructive_hint),
            ),
            (
                wanted.idempotent_hint,
                observed.and_then(|a| a.idempotent_hint),
            ),
            (
                wanted.open_world_hint,
                observed.and_then(|a| a.open_world_hint),
            ),
        ];
        let hint_mismatch = hint_pairs
            .iter()
            .any(|&(configured, actual)| configured.is_some() && configured != actual);
        if hint_mismatch {
            return false;
        }

        let name_ok = match name_re {
            Some(re) => re.is_match(tool.name.as_ref()),
            None => true,
        };
        if !name_ok {
            return false;
        }

        match description_re {
            Some(re) => re.is_match(tool.description.as_deref().unwrap_or("")),
            None => true,
        }
    }
}
/// Reports whether the untyped document has a top-level `for_each_tool` key holding a
/// non-empty sequence. Anything else (not a mapping, key missing, not a sequence, empty
/// sequence) yields `false`.
fn has_for_each_tool(value: &serde_yaml::Value) -> bool {
    let Some(root) = value.as_mapping() else {
        return false;
    };
    let key = serde_yaml::Value::String("for_each_tool".to_string());
    match root.get(&key).and_then(|entry| entry.as_sequence()) {
        Some(blocks) => !blocks.is_empty(),
        None => false,
    }
}
/// Reads `metadata.parameters` out of the untyped document.
///
/// Returns an empty map when the document is not a mapping, when `metadata` or
/// `parameters` is absent, or when the `parameters` value does not deserialize as a map
/// of [`Parameter`].
///
/// NOTE(review): the last case silently discards the deserialization error. A malformed
/// declaration (e.g. a parameter without `default`) therefore looks like "no parameters"
/// to the caller, which can surface as an unrelated undefined-parameter or
/// unknown-override error. Reporting it would need a `Result` return type.
fn extract_parameters(value: &serde_yaml::Value) -> BTreeMap<String, Parameter> {
    let metadata_key = serde_yaml::Value::String("metadata".to_string());
    let parameters_key = serde_yaml::Value::String("parameters".to_string());

    let Some(metadata) = value.as_mapping().and_then(|root| root.get(&metadata_key)) else {
        return BTreeMap::new();
    };
    let Some(parameters) = metadata
        .as_mapping()
        .and_then(|meta| meta.get(&parameters_key))
    else {
        return BTreeMap::new();
    };
    serde_yaml::from_value(parameters.clone()).unwrap_or_default()
}
/// Turns a parameter's default value into the text spliced into the template.
///
/// Strings are used verbatim (no surrounding quotes), `null` becomes the empty string,
/// booleans and numbers use their display form, and arrays/objects use the value's own
/// `to_string()` rendering.
fn stringify_default(value: &Value) -> String {
    match value {
        Value::Null => String::new(),
        Value::String(text) => text.to_owned(),
        Value::Bool(flag) => flag.to_string(),
        Value::Number(number) => number.to_string(),
        composite => composite.to_string(),
    }
}
#[allow(
clippy::expect_used,
clippy::unwrap_in_result,
reason = "static regex pattern is checked at compile-time review and cannot fail at runtime"
)]
fn render_template(template: &str, vars: &BTreeMap<String, String>) -> Result<String> {
let re =
Regex::new(r"\{\{\s*([A-Za-z_][A-Za-z0-9_]*)\s*\}\}").expect("static regex must compile");
let mut missing: Vec<String> = Vec::new();
let result = re.replace_all(template, |captures: ®ex::Captures<'_>| {
let name = captures.get(1).map(|m| m.as_str()).unwrap_or("");
match vars.get(name) {
Some(value) => value.clone(),
None => {
if !missing.iter().any(|existing| existing == name) {
missing.push(name.to_string());
}
captures
.get(0)
.map(|m| m.as_str().to_string())
.unwrap_or_default()
}
}
});
if !missing.is_empty() {
return Err(DslError::UndefinedParameters(missing));
}
Ok(result.into_owned())
}
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
use super::*;
// A version-1 file with bare-string operands must still parse; bare strings are expected
// to deserialize as `Operand::Direct`, not as `Path` or `Literal`.
#[test]
fn v1_legacy_form_still_parses() {
let source = r#"
version: 1
invariants:
- name: demo
tool: echo
fixed: { text: hello }
assert:
- kind: equals
lhs: "$.response.text"
rhs: "$.input.text"
"#;
let file = parse(source).unwrap();
assert_eq!(file.version, 1);
assert_eq!(file.invariants.len(), 1);
match &file.invariants[0].assertions[0] {
Assertion::Equals { lhs, rhs } => {
assert!(matches!(lhs, Operand::Direct(Value::String(s)) if s == "$.response.text"));
assert!(matches!(rhs, Operand::Direct(Value::String(s)) if s == "$.input.text"));
}
other => panic!("unexpected: {other:?}"),
}
}
// Explicit operand mappings: `{ path: ... }` must become `Operand::Path` and
// `{ value: ... }` must become `Operand::Literal`.
#[test]
fn v2_explicit_operands_parse() {
let source = r#"
version: 2
invariants:
- name: demo
tool: echo
fixed: { text: hello }
assert:
- kind: equals
lhs: { path: "$.response.text" }
rhs: { value: hello }
"#;
let file = parse(source).unwrap();
match &file.invariants[0].assertions[0] {
Assertion::Equals { lhs, rhs } => {
assert!(matches!(lhs, Operand::Path { path } if path == "$.response.text"));
assert!(
matches!(rhs, Operand::Literal { value } if value == &Value::String("hello".into()))
);
}
other => panic!("unexpected: {other:?}"),
}
}
// Nested `all_of` / `any_of` / `not` assertions must survive parse -> serialize -> parse
// with their structure intact.
#[test]
fn combinators_round_trip() {
let source = r#"
version: 2
invariants:
- name: combinators
tool: t
fixed: {}
assert:
- kind: all_of
assert:
- kind: equals
lhs: { path: "$.response.a" }
rhs: { value: 1 }
- kind: any_of
assert:
- kind: at_least
path: "$.response.b"
value: { value: 0 }
- kind: not
assertion:
kind: equals
lhs: { path: "$.response.b" }
rhs: { value: -1 }
"#;
let file = parse(source).unwrap();
// Re-parse our own serialized output to prove the serde representation is symmetric.
let serialized = serde_yaml::to_string(&file).unwrap();
let reparsed = parse(&serialized).unwrap();
assert_eq!(reparsed.invariants.len(), 1);
let Assertion::AllOf { assertions } = &reparsed.invariants[0].assertions[0] else {
panic!("expected all_of");
};
assert_eq!(assertions.len(), 2);
assert!(matches!(assertions[1], Assertion::AnyOf { .. }));
}
// A `for_each` assertion must keep its `path` and its nested assertion list.
#[test]
fn for_each_parses() {
let source = r#"
version: 2
invariants:
- name: items
tool: list
fixed: {}
assert:
- kind: for_each
path: "$.response.items[*]"
assert:
- kind: is_type
path: "$.item.id"
type: integer
"#;
let file = parse(source).unwrap();
let Assertion::ForEach { path, assertions } = &file.invariants[0].assertions[0] else {
panic!("expected for_each");
};
assert_eq!(path, "$.response.items[*]");
assert_eq!(assertions.len(), 1);
}
// The inline `schema` of a `matches_schema` assertion must be carried through as a JSON
// value with its nested keys preserved.
#[test]
fn matches_schema_carries_inline_schema() {
let source = r#"
version: 2
invariants:
- name: shape
tool: t
fixed: {}
assert:
- kind: matches_schema
path: "$.response.user"
schema:
type: object
required: [name]
properties:
name: { type: string }
"#;
let file = parse(source).unwrap();
let Assertion::MatchesSchema { path, schema } = &file.invariants[0].assertions[0] else {
panic!("expected matches_schema");
};
assert_eq!(path, "$.response.user");
assert_eq!(schema["type"], Value::String("object".into()));
let required = schema["required"].as_array().unwrap();
assert_eq!(required[0], Value::String("name".into()));
}
// A version above `MAX_VERSION` must be rejected with `UnsupportedVersion` carrying the
// declared number.
#[test]
fn unsupported_version_is_rejected() {
let source = r#"
version: 99
invariants: []
"#;
let err = parse(source).unwrap_err();
assert!(matches!(err, DslError::UnsupportedVersion(99)));
}
// An invariant that sets both `generate` and `fixed` must be rejected with
// `InvalidInputMode`.
#[test]
fn generate_xor_fixed_is_enforced() {
let source = r#"
version: 2
invariants:
- name: bad
tool: t
generate: { x: { type: integer, min: 0, max: 1 } }
fixed: { x: 0 }
assert: []
"#;
let err = parse(source).unwrap_err();
assert!(matches!(err, DslError::InvalidInputMode(_)));
}
// A version-3 pack with a `metadata` header must parse and expose the metadata fields.
#[test]
fn v3_minimal_pack_parses() {
let source = r#"
version: 3
metadata:
name: demo
description: "demo pack"
authors: ["wallfacer-core"]
tags: [security]
invariants:
- name: t
tool: echo
fixed: {}
assert:
- kind: equals
lhs: { value: 1 }
rhs: { value: 1 }
"#;
let file = parse(source).unwrap();
assert_eq!(file.version, 3);
let meta = file.metadata.as_ref().expect("metadata");
assert_eq!(meta.name.as_deref(), Some("demo"));
assert_eq!(meta.tags, vec!["security".to_string()]);
}
// Without overrides, `{{whoami_tool}}` must be replaced by the parameter's declared default.
#[test]
fn templating_substitutes_defaults() {
let source = r#"
version: 3
metadata:
name: demo
parameters:
whoami_tool:
description: tool returning the current user
type: string
default: whoami
invariants:
- name: t
tool: "{{whoami_tool}}"
fixed: {}
assert: []
"#;
let file = parse(source).unwrap();
assert_eq!(file.invariants[0].tool, "whoami");
}
// A caller-supplied override for a declared parameter must win over its default.
#[test]
fn templating_overrides_take_precedence() {
let source = r#"
version: 3
metadata:
name: demo
parameters:
whoami_tool:
type: string
default: whoami
invariants:
- name: t
tool: "{{whoami_tool}}"
fixed: {}
assert: []
"#;
let mut overrides = BTreeMap::new();
overrides.insert("whoami_tool".to_string(), "getCurrentUser".to_string());
let file = parse_with_overrides(source, &overrides).unwrap();
assert_eq!(file.invariants[0].tool, "getCurrentUser");
}
// Referencing a placeholder that no parameter declares must fail with
// `UndefinedParameters` naming the missing placeholder.
#[test]
fn templating_undeclared_reference_errors() {
let source = r#"
version: 3
metadata:
name: demo
invariants:
- name: t
tool: "{{whoami_tool}}"
fixed: {}
assert: []
"#;
let err = parse(source).unwrap_err();
match err {
DslError::UndefinedParameters(names) => {
assert_eq!(names, vec!["whoami_tool".to_string()]);
}
other => panic!("expected UndefinedParameters, got {other:?}"),
}
}
// An override whose key is not a declared parameter (e.g. a typo) must be rejected with
// `UnknownParameterOverride`, even though the source never references it.
#[test]
fn templating_unknown_override_errors() {
let source = r#"
version: 3
metadata:
name: demo
invariants:
- name: t
tool: echo
fixed: {}
assert: []
"#;
let mut overrides = BTreeMap::new();
overrides.insert("typoed".to_string(), "x".to_string());
let err = parse_with_overrides(source, &overrides).unwrap_err();
assert!(matches!(err, DslError::UnknownParameterOverride(name) if name == "typoed"));
}
// The same parameter referenced more than once, in both the `{{name}}` and `{{ name }}`
// spellings, must parse without error.
#[test]
fn templating_handles_repeated_references() {
let source = r#"
version: 3
metadata:
name: demo
parameters:
user_tool:
type: string
default: whoami
invariants:
- name: same
tool: "{{user_tool}}"
fixed: {}
assert:
- kind: equals
lhs: { path: "$.input" }
rhs: { value: "{{ user_tool }}" }
"#;
let file = parse(source).unwrap();
assert_eq!(file.invariants[0].tool, "whoami");
}
// A version-2 pack without `metadata` must still parse, with `metadata` left as `None`.
#[test]
fn v2_packs_remain_valid_under_v3_parser() {
let source = r#"
version: 2
invariants:
- name: legacy
tool: echo
fixed: { x: 1 }
assert:
- kind: equals
lhs: { path: "$.input.x" }
rhs: { value: 1 }
"#;
let file = parse(source).unwrap();
assert_eq!(file.version, 2);
assert!(file.metadata.is_none());
}
// Serializing a parsed version-3 pack and parsing the output again must preserve the
// invariant count and the metadata (name, tags, extends, number of parameters).
#[test]
fn v3_round_trip_serde_preserves_metadata_and_invariants() {
let source = r#"
version: 3
metadata:
name: roundtrip
description: probe for serde drift
authors: [w]
tags: [t]
parameters:
a: { type: string, default: foo }
extends: [parent]
invariants:
- name: i1
tool: "{{a}}"
fixed: {}
assert: []
"#;
let parsed = parse(source).unwrap();
let yaml = serde_yaml::to_string(&parsed).unwrap();
let reparsed = parse(&yaml).unwrap();
assert_eq!(parsed.invariants.len(), reparsed.invariants.len());
let m1 = parsed.metadata.unwrap();
let m2 = reparsed.metadata.unwrap();
assert_eq!(m1.name, m2.name);
assert_eq!(m1.tags, m2.tags);
assert_eq!(m1.extends, m2.extends);
assert_eq!(m1.parameters.len(), m2.parameters.len());
}
/// Test helper: builds a tool called `name` with the description "test tool" and an empty
/// input schema. When `read_only` is `Some`, the tool is annotated with that
/// `read_only_hint`; otherwise it carries no annotations.
fn make_tool(name: &str, read_only: Option<bool>) -> rmcp::model::Tool {
    let empty_schema = std::sync::Arc::new(serde_json::Map::new());
    let base = rmcp::model::Tool::new(name.to_string(), "test tool".to_string(), empty_schema);
    match read_only {
        Some(flag) => {
            let mut hints = rmcp::model::ToolAnnotations::default();
            hints.read_only_hint = Some(flag);
            base.annotate(hints)
        }
        None => base,
    }
}
// With a `for_each_tool` block present, `{{tool_name}}` needs no parameter declaration:
// parsing must succeed and leave the placeholder in the block name for later expansion.
#[test]
fn for_each_tool_parses_with_auto_injected_tool_name() {
let source = r#"
version: 3
metadata:
name: tool-annotations
for_each_tool:
- name: "tool-annotations.read_only_does_not_mutate.{{tool_name}}"
where:
annotations:
readOnlyHint: true
apply:
fixed: {}
assert:
- kind: matches_schema
path: "$.response.structuredContent"
schema: { type: object }
invariants: []
"#;
let file = parse(source).expect("parse");
assert_eq!(file.for_each_tool.len(), 1);
let block = &file.for_each_tool[0];
assert!(block.name.contains("{{tool_name}}"));
assert_eq!(
block.matches.annotations.read_only_hint,
Some(true),
"where clause didn't deserialise"
);
}
// Expansion must yield one invariant per tool whose `readOnlyHint` is `true`, in tool
// order; tools with the hint set to `false` or with no annotations are skipped.
#[test]
fn for_each_tool_expands_per_matching_tool() {
let source = r#"
version: 3
metadata:
name: tool-annotations
for_each_tool:
- name: "rule.{{tool_name}}"
where:
annotations:
readOnlyHint: true
apply:
fixed: {}
assert:
- kind: equals
lhs: { value: 1 }
rhs: { value: 1 }
invariants: []
"#;
let file = parse(source).unwrap();
let tools = vec![
make_tool("read_user", Some(true)),
make_tool("delete_user", Some(false)),
make_tool("get_status", Some(true)),
make_tool("no_annotations", None),
];
let expanded = expand_for_each_tool(&file.for_each_tool, &tools).unwrap();
let names: Vec<String> = expanded.iter().map(|i| i.name.clone()).collect();
assert_eq!(
names,
vec!["rule.read_user".to_string(), "rule.get_status".to_string()]
);
assert_eq!(expanded[0].tool, "read_user");
}
// `where.name_matches` must filter tools by regex on their name; only the `read_*` tools
// are expected to expand.
#[test]
fn for_each_tool_filter_by_name_regex() {
let source = r#"
version: 3
for_each_tool:
- name: "rule.{{tool_name}}"
where:
name_matches: "^read_"
apply:
fixed: {}
assert: []
invariants: []
"#;
let file = parse(source).unwrap();
let tools = vec![
make_tool("read_user", None),
make_tool("write_user", None),
make_tool("read_post", None),
];
let expanded = expand_for_each_tool(&file.for_each_tool, &tools).unwrap();
let names: Vec<String> = expanded.iter().map(|i| i.name.clone()).collect();
assert_eq!(
names,
vec!["rule.read_user".to_string(), "rule.read_post".to_string()]
);
}
// `{{tool_name}}` must be substituted inside the apply body (here a `fixed` value) as well
// as in the block name; an empty `where` matches every tool.
#[test]
fn for_each_tool_substitutes_in_apply_body() {
let source = r#"
version: 3
for_each_tool:
- name: "{{tool_name}}.contract"
where: {}
apply:
fixed:
echo_back: "{{tool_name}}"
assert:
- kind: equals
lhs: { path: "$.input.echo_back" }
rhs: { value: "{{tool_name}}" }
invariants: []
"#;
let file = parse(source).unwrap();
let tools = vec![make_tool("only_one", None)];
let expanded = expand_for_each_tool(&file.for_each_tool, &tools).unwrap();
assert_eq!(expanded.len(), 1);
assert_eq!(expanded[0].name, "only_one.contract");
let fixed = expanded[0].fixed.as_ref().unwrap();
assert_eq!(fixed["echo_back"], serde_json::json!("only_one"));
}
}