use crate::context::Context;
use crate::errors::{ErrorCode, ModuleError};
/// Default upper bound on the length of a call chain.
///
/// NOTE(review): nothing in this file references this constant — the guard
/// functions take an explicit `max_depth: u32` — so confirm where (and whether)
/// callers apply this default.
pub const DEFAULT_MAX_CALL_DEPTH: usize = 32;
/// Default maximum number of times one module may appear in a call chain;
/// `guard_call_chain` passes this as the repeat limit.
pub const DEFAULT_MAX_MODULE_REPEAT: usize = 3;
/// Returns `true` when `value` matches the wildcard `pattern`.
///
/// `*` is the only special character; it matches any (possibly empty) run of
/// characters. A pattern without `*` must equal `value` exactly, and the lone
/// pattern `"*"` matches everything, including the empty string.
///
/// Matching is anchored at both ends:
/// * the text before the first `*` must be a prefix of `value`;
/// * the literal pieces between stars must appear in order, without overlap;
/// * the text after the last `*` must be a suffix of `value`.
#[must_use]
pub fn match_pattern(pattern: &str, value: &str) -> bool {
    if pattern == "*" {
        return true;
    }
    if !pattern.contains('*') {
        return pattern == value;
    }

    let pieces: Vec<&str> = pattern.split('*').collect();
    let head = pieces[0];
    let tail = pieces[pieces.len() - 1];

    // An empty head (pattern starts with `*`) is a trivial prefix at offset 0.
    if value.strip_prefix(head).is_none() {
        return false;
    }
    // Byte offset in `value` up to which pieces have been consumed.
    let mut cursor = head.len();

    let middle_in_order = pieces[1..]
        .iter()
        .filter(|piece| !piece.is_empty())
        .all(|piece| match value[cursor..].find(piece) {
            Some(offset) => {
                cursor += offset + piece.len();
                true
            }
            None => false,
        });
    if !middle_in_order {
        return false;
    }

    // When the pattern ends with `*`, `tail` is empty and always matches.
    value.ends_with(tail)
}
/// Validates `ctx.call_chain` before running `module_name`, using
/// [`DEFAULT_MAX_MODULE_REPEAT`] as the per-module repeat limit.
///
/// # Errors
///
/// Returns exactly the errors produced by [`guard_call_chain_with_repeat`]
/// (depth exceeded, circular call, or call frequency exceeded).
pub fn guard_call_chain(
    ctx: &Context<serde_json::Value>,
    module_name: &str,
    max_depth: u32,
) -> Result<(), ModuleError> {
    let repeat_limit = DEFAULT_MAX_MODULE_REPEAT;
    guard_call_chain_with_repeat(ctx, module_name, max_depth, repeat_limit)
}
/// Validates the call chain in `ctx` before `module_name` is executed.
///
/// Three checks run in this order; the first failing one is reported:
///
/// 1. **Depth** — the chain may contain at most `max_depth` entries.
/// 2. **Cycles** — ignoring the chain's last entry, if `module_name` appears
///    earlier with at least one *other* entry after its last occurrence
///    (e.g. `A -> B -> A`), the call is circular. Back-to-back self calls
///    (`A -> A`) are not treated as cycles; only the repeat check limits them.
/// 3. **Repeats** — `module_name` may occur at most `max_module_repeat` times
///    in the whole chain.
///
/// NOTE(review): excluding the last entry from the cycle check presumably
/// means callers push the module being entered onto `call_chain` before
/// guarding — confirm against call sites.
///
/// # Errors
///
/// * [`ErrorCode::CallDepthExceeded`] if the chain is longer than `max_depth`.
/// * [`ErrorCode::CircularCall`] if a cycle through `module_name` is found.
/// * [`ErrorCode::CallFrequencyExceeded`] if `module_name` occurs more than
///   `max_module_repeat` times.
pub fn guard_call_chain_with_repeat(
    ctx: &Context<serde_json::Value>,
    module_name: &str,
    max_depth: u32,
    max_module_repeat: usize,
) -> Result<(), ModuleError> {
    let chain_len = ctx.call_chain.len();

    // Widen both sides to u64 rather than narrowing the length to u32: a
    // `len() as u32` cast wraps for chains of 2^32+ entries on 64-bit targets
    // and would let them slip under the limit. usize -> u64 is lossless on
    // every supported target.
    if chain_len as u64 > u64::from(max_depth) {
        return Err(ModuleError::new(
            ErrorCode::CallDepthExceeded,
            format!("Call depth exceeded: chain length {chain_len} > max_depth {max_depth}"),
        ));
    }

    // Every entry except the last one (empty when the chain is empty).
    let prior = &ctx.call_chain[..chain_len.saturating_sub(1)];
    if let Some(last_idx) = prior.iter().rposition(|n| n.as_str() == module_name) {
        // Entries called after the most recent earlier occurrence; a non-empty
        // run means control left `module_name` and is now coming back to it.
        let subsequence = &prior[last_idx + 1..];
        if !subsequence.is_empty() {
            return Err(ModuleError::new(
                ErrorCode::CircularCall,
                format!(
                    "Circular call detected: '{}' already in call chain {:?}",
                    module_name, ctx.call_chain
                ),
            ));
        }
    }

    let count = ctx
        .call_chain
        .iter()
        .filter(|name| name.as_str() == module_name)
        .count();
    if count > max_module_repeat {
        return Err(ModuleError::new(
            ErrorCode::CallFrequencyExceeded,
            format!(
                "Module '{module_name}' called {count} times, exceeds max repeat limit of {max_module_repeat}"
            ),
        ));
    }

    Ok(())
}
/// Converts one identifier segment from camelCase / PascalCase to snake_case.
///
/// An underscore is inserted before an uppercase character when either:
/// * it follows a lowercase letter or an ASCII digit
///   (`getValue` -> `get_value`, `log2Base` -> `log2_base`), or
/// * it follows another uppercase character and is itself followed by a
///   lowercase one, which splits a leading acronym from the next word
///   (`HTTPClient` -> `http_client`).
///
/// Every character is lowercased using its full Unicode mapping. Finally, each
/// non-overlapping `__` in the result is replaced by `_` (a single pass, so a
/// run of three underscores becomes two).
fn to_snake_case(segment: &str) -> String {
    let chars: Vec<char> = segment.chars().collect();
    let mut result = String::with_capacity(segment.len() + 4);
    for (i, &ch) in chars.iter().enumerate() {
        if i > 0 {
            let prev = chars[i - 1];
            let after_lower_or_digit =
                (prev.is_lowercase() || prev.is_ascii_digit()) && ch.is_uppercase();
            let acronym_end = prev.is_uppercase()
                && ch.is_uppercase()
                && matches!(chars.get(i + 1), Some(next) if next.is_lowercase());
            if after_lower_or_digit || acronym_end {
                result.push('_');
            }
        }
        // `char::to_lowercase` may yield several chars (e.g. 'İ' -> "i\u{307}");
        // keep the whole mapping instead of silently dropping its tail.
        result.extend(ch.to_lowercase());
    }
    result.replace("__", "_")
}
/// Returns the namespace separator used in local identifiers for `language`:
/// `"::"` for exactly `"rust"`, and `"."` for any other language string.
fn separator_for_language(language: &str) -> &'static str {
    if language == "rust" {
        "::"
    } else {
        "."
    }
}
/// Converts a language-local identifier into the canonical dotted snake_case form.
///
/// `local_id` is split on the separator for `language` (`"::"` for `"rust"`,
/// `"."` otherwise). Each segment is converted to snake_case, and the segments
/// are re-joined with `"."`.
///
/// ```text
/// normalize_to_canonical_id("MyModule::SendEmail", "rust") == "my_module.send_email"
/// normalize_to_canonical_id("HTTPClient", "python")        == "http_client"
/// ```
#[must_use]
pub fn normalize_to_canonical_id(local_id: &str, language: &str) -> String {
    let separator = separator_for_language(language);
    let segments: Vec<String> = local_id.split(separator).map(to_snake_case).collect();
    segments.join(".")
}
/// Scores how specific a dotted wildcard `pattern` is; higher is more specific.
///
/// Each `.`-separated segment contributes:
/// * `2` when it contains no `*` (a fully literal segment),
/// * `1` when it mixes literal text with `*` (e.g. `ba*`),
/// * `0` when it is exactly `*`.
///
/// Consequently the catch-all pattern `"*"` scores `0`.
#[must_use]
pub fn calculate_specificity(pattern: &str) -> u32 {
    pattern
        .split('.')
        .map(|segment| match segment {
            "*" => 0,
            partial if partial.contains('*') => 1,
            _ => 2,
        })
        .sum()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::Context;
    use crate::errors::ErrorCode;

    /// Builds an anonymous context whose call chain is exactly `names`, in order.
    fn context_with_chain(names: &[&str]) -> Context<serde_json::Value> {
        let mut ctx = Context::<serde_json::Value>::anonymous();
        ctx.call_chain = names.iter().map(|&name| name.into()).collect();
        ctx
    }

    /// Asserts `match_pattern(pattern, value) == expected` for every case.
    fn check_pattern(pattern: &str, cases: &[(&str, bool)]) {
        for &(value, expected) in cases {
            assert_eq!(
                match_pattern(pattern, value),
                expected,
                "match_pattern({pattern:?}, {value:?})"
            );
        }
    }

    // ---- match_pattern ----------------------------------------------------

    #[test]
    fn test_match_pattern_wildcard_matches_everything() {
        check_pattern("*", &[("anything", true), ("", true), ("a.b.c", true)]);
    }

    #[test]
    fn test_match_pattern_exact_match() {
        check_pattern(
            "foo.bar",
            &[("foo.bar", true), ("foo.baz", false), ("foo.bar.baz", false)],
        );
    }

    #[test]
    fn test_match_pattern_no_wildcards_no_match() {
        check_pattern("abc", &[("def", false)]);
    }

    #[test]
    fn test_match_pattern_prefix_wildcard() {
        check_pattern(
            "foo.*",
            &[("foo.bar", true), ("foo.anything", true), ("bar.baz", false)],
        );
    }

    #[test]
    fn test_match_pattern_suffix_wildcard() {
        check_pattern(
            "*.bar",
            &[("foo.bar", true), ("x.y.bar", true), ("foo.baz", false)],
        );
    }

    #[test]
    fn test_match_pattern_middle_wildcard() {
        check_pattern(
            "a.*.c",
            &[("a.b.c", true), ("a.xyz.c", true), ("a.b.d", false)],
        );
    }

    #[test]
    fn test_match_pattern_multiple_wildcards() {
        check_pattern("a.*.*.d", &[("a.b.c.d", true)]);
    }

    // ---- guard_call_chain / guard_call_chain_with_repeat -------------------

    #[test]
    fn test_guard_call_chain_empty_chain_passes() {
        let ctx = Context::<serde_json::Value>::anonymous();
        let outcome = guard_call_chain(&ctx, "mod.a", 10);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_guard_call_chain_depth_exceeded() {
        let ctx = context_with_chain(&["a", "b", "c", "d"]);
        let outcome = guard_call_chain(&ctx, "e", 3);
        assert!(outcome.is_err());
        assert_eq!(outcome.unwrap_err().code, ErrorCode::CallDepthExceeded);
    }

    #[test]
    fn test_guard_call_chain_circular_detection() {
        let ctx = context_with_chain(&["mod.a", "mod.b", "mod.a"]);
        let outcome = guard_call_chain(&ctx, "mod.a", 100);
        assert!(outcome.is_err());
        assert_eq!(outcome.unwrap_err().code, ErrorCode::CircularCall);
    }

    #[test]
    fn test_guard_call_chain_frequency_at_default_limit_passes() {
        let ctx = context_with_chain(&["mod.a"; 3]);
        let result = guard_call_chain(&ctx, "mod.a", 100);
        assert!(
            result.is_ok(),
            "exactly max_module_repeat (3) occurrences must pass, got {result:?}"
        );
    }

    #[test]
    fn test_guard_call_chain_frequency_exceeded() {
        let ctx = context_with_chain(&["mod.a"; 4]);
        let outcome = guard_call_chain(&ctx, "mod.a", 100);
        assert!(outcome.is_err());
        assert_eq!(outcome.unwrap_err().code, ErrorCode::CallFrequencyExceeded);
    }

    #[test]
    fn test_guard_call_chain_with_repeat_custom_limit() {
        let ctx = context_with_chain(&["mod.a", "mod.a"]);
        let outcome = guard_call_chain_with_repeat(&ctx, "mod.a", 100, 1);
        assert!(outcome.is_err());
        assert_eq!(outcome.unwrap_err().code, ErrorCode::CallFrequencyExceeded);
    }

    #[test]
    fn test_guard_call_chain_with_repeat_single_self_within_limit() {
        let ctx = context_with_chain(&["mod.a"]);
        let result = guard_call_chain_with_repeat(&ctx, "mod.a", 100, 1);
        assert!(
            result.is_ok(),
            "count==max_module_repeat must pass: {result:?}"
        );
    }

    #[test]
    fn test_guard_call_chain_ok_within_limits() {
        let ctx = context_with_chain(&["mod.a", "mod.b"]);
        assert!(guard_call_chain(&ctx, "mod.c", 10).is_ok());
    }

    // ---- normalize_to_canonical_id -----------------------------------------

    #[test]
    fn test_normalize_python_dotted() {
        let canonical = normalize_to_canonical_id("MyModule.SendEmail", "python");
        assert_eq!(canonical, "my_module.send_email");
    }

    #[test]
    fn test_normalize_rust_double_colon() {
        let canonical = normalize_to_canonical_id("MyModule::SendEmail", "rust");
        assert_eq!(canonical, "my_module.send_email");
    }

    #[test]
    fn test_normalize_already_snake_case() {
        let canonical = normalize_to_canonical_id("my_module.send_email", "python");
        assert_eq!(canonical, "my_module.send_email");
    }

    #[test]
    fn test_normalize_acronym_handling() {
        for (local, expected) in [("HTTPClient", "http_client"), ("HTMLParser", "html_parser")] {
            assert_eq!(normalize_to_canonical_id(local, "python"), expected);
        }
    }

    #[test]
    fn test_normalize_camel_case_boundary() {
        let canonical = normalize_to_canonical_id("getValue", "python");
        assert_eq!(canonical, "get_value");
    }

    #[test]
    fn test_normalize_digit_boundary() {
        let canonical = normalize_to_canonical_id("log2Base", "python");
        assert_eq!(canonical, "log2_base");
    }

    // ---- calculate_specificity ---------------------------------------------

    #[test]
    fn test_specificity_wildcard_only() {
        let score = calculate_specificity("*");
        assert_eq!(score, 0);
    }

    #[test]
    fn test_specificity_exact_segments() {
        let score = calculate_specificity("foo.bar");
        assert_eq!(score, 4);
    }

    #[test]
    fn test_specificity_partial_wildcard() {
        let score = calculate_specificity("foo.*");
        assert_eq!(score, 2);
    }

    #[test]
    fn test_specificity_partial_wildcard_in_segment() {
        let score = calculate_specificity("foo.ba*");
        assert_eq!(score, 3);
    }

    #[test]
    fn test_specificity_single_exact() {
        let score = calculate_specificity("executor");
        assert_eq!(score, 2);
    }

    #[test]
    fn test_specificity_all_wildcards() {
        let score = calculate_specificity("*.*.*");
        assert_eq!(score, 0);
    }

    #[test]
    fn test_specificity_mixed() {
        let score = calculate_specificity("a.*.b.c*");
        assert_eq!(score, 5);
    }
}