use regex::Regex;
use std::sync::LazyLock;
/// Regex fragment (a character class followed by `+`, i.e. one or more chars) that is
/// allowed between the double quotes of a free-text argument: a query expression,
/// search pattern, or symbol name.
///
/// Permits ASCII letters and digits, `_ . - / : @ * ? ,` and the plain space character.
/// Everything else is rejected, notably `"`, `$`, `` ` ``, `\`, `;`, `|`, `&`, `<`, `>`,
/// tabs and newlines. A quoted value therefore cannot terminate its own quotes.
/// NOTE(review): presumably this is also meant to rule out expansion if the command is
/// later handed to a shell — confirm with the caller.
const SAFE_QUOTED_CHARS: &str = r"[a-zA-Z0-9_.\-/:@*?, ]+";
/// Regex fragment (a character class followed by `+`, i.e. one or more chars) that is
/// allowed between the double quotes of a `--path` value.
///
/// Permits ASCII letters and digits, `_ . - /` and the glob characters `*` and `?`.
/// It is stricter than `SAFE_QUOTED_CHARS`: spaces, `:`, `@` and `,` are not allowed.
const SAFE_PATH_CHARS: &str = r"[a-zA-Z0-9_.\-/*?]+";
/// Anchored regexes describing every `sqry` invocation this module accepts.
///
/// Each entry matches exactly one subcommand shape, from `^` to `$`:
/// - Free-text arguments must be double-quoted and built from `SAFE_QUOTED_CHARS`.
/// - Path arguments must be double-quoted and built from `SAFE_PATH_CHARS`.
/// - Flags must appear in the order written below.
///
/// # Panics
///
/// Panics on first access if a template fails to compile. That can only happen if the
/// literals below are edited incorrectly, so it signals a programming error.
static ALLOWED_TEMPLATES: LazyLock<Vec<Regex>> = LazyLock::new(|| {
    // Separator between arguments: literal spaces or tabs only. The previous `\s+` also
    // matched `\n`/`\r` (and, with Unicode mode on, characters such as U+2028). That let
    // e.g. `sqry query "x" --path\n"/bin/sh"` validate; if the string is later run by a
    // shell, the newline starts a second command.
    const SEP: &str = r"[ \t]+";

    // Numeric flag values use `[0-9]+` instead of `\d+`, because `\d` is Unicode-aware
    // in the `regex` crate and would also accept non-ASCII digits.
    vec![
        Regex::new(&format!(
            r#"^sqry query "{SAFE_QUOTED_CHARS}"({SEP}--language{SEP}[a-z]+)*({SEP}--path{SEP}"{SAFE_PATH_CHARS}")?({SEP}--limit{SEP}[0-9]+)?$"#
        ))
        .expect("Invalid query template"),
        Regex::new(&format!(
            r#"^sqry search "{SAFE_QUOTED_CHARS}"({SEP}--language{SEP}[a-z]+)*({SEP}--path{SEP}"{SAFE_PATH_CHARS}")?$"#
        ))
        .expect("Invalid search template"),
        Regex::new(&format!(
            r#"^sqry graph trace-path "{SAFE_QUOTED_CHARS}" "{SAFE_QUOTED_CHARS}"({SEP}--max-depth{SEP}[0-9]+)?$"#
        ))
        .expect("Invalid trace-path template"),
        Regex::new(&format!(
            r#"^sqry graph direct-callers "{SAFE_QUOTED_CHARS}"({SEP}--language{SEP}[a-z]+)?$"#
        ))
        .expect("Invalid direct-callers template"),
        Regex::new(&format!(
            r#"^sqry graph direct-callees "{SAFE_QUOTED_CHARS}"({SEP}--language{SEP}[a-z]+)?$"#
        ))
        .expect("Invalid direct-callees template"),
        Regex::new(&format!(
            r#"^sqry visualize --relation{SEP}(call|import|export|inherit|impl){SEP}--symbol{SEP}"{SAFE_QUOTED_CHARS}"({SEP}--format{SEP}(mermaid|dot|json))?$"#
        ))
        .expect("Invalid visualize template"),
        Regex::new(&format!(
            r#"^sqry index --status({SEP}--path{SEP}"{SAFE_PATH_CHARS}")?({SEP}--json)?$"#
        ))
        .expect("Invalid index template"),
    ]
});
/// Returns `true` when `command` is one of the permitted `sqry` invocations.
///
/// Leading and trailing whitespace is stripped first. The remaining text must then be
/// matched in full by at least one entry of `ALLOWED_TEMPLATES`, all of which are
/// anchored with `^…$`. A command that matches no template yields `false`.
#[must_use]
pub fn matches_allowed_template(command: &str) -> bool {
    let candidate = command.trim();
    for template in ALLOWED_TEMPLATES.iter() {
        if template.is_match(candidate) {
            return true;
        }
    }
    false
}
#[cfg(test)]
mod tests {
    //! Unit tests for the `sqry` command allowlist.
    //!
    //! The "accepted" section pins the supported command shapes. The "rejected" section
    //! pins the security-relevant negatives: shell chaining, expansion payloads,
    //! unquoted arguments, unknown flags and out-of-range values.

    use super::*;

    // ---------------------------------------------------------------- accepted

    #[test]
    fn test_query_basic() {
        assert!(matches_allowed_template("sqry query \"foo\""));
    }

    #[test]
    fn test_query_with_language() {
        assert!(matches_allowed_template(
            "sqry query \"foo\" --language rust"
        ));
    }

    #[test]
    fn test_query_with_multiple_languages() {
        assert!(matches_allowed_template(
            "sqry query \"foo\" --language rust --language python"
        ));
    }

    #[test]
    fn test_query_with_kind() {
        assert!(matches_allowed_template("sqry query \"foo kind:function\""));
    }

    #[test]
    fn test_query_with_limit() {
        assert!(matches_allowed_template("sqry query \"foo\" --limit 10"));
    }

    #[test]
    fn test_query_with_path() {
        assert!(matches_allowed_template(
            "sqry query \"foo\" --path \"src/**\""
        ));
    }

    #[test]
    fn test_search_basic() {
        assert!(matches_allowed_template("sqry search \"TODO\""));
    }

    #[test]
    fn test_search_with_language_and_path() {
        assert!(matches_allowed_template(
            "sqry search \"TODO\" --language rust --path \"src/**\""
        ));
    }

    #[test]
    fn test_graph_callers() {
        assert!(matches_allowed_template(
            "sqry graph direct-callers \"authenticate\""
        ));
    }

    #[test]
    fn test_graph_callers_with_language() {
        assert!(matches_allowed_template(
            "sqry graph direct-callers \"authenticate\" --language rust"
        ));
    }

    #[test]
    fn test_graph_callees() {
        assert!(matches_allowed_template(
            "sqry graph direct-callees \"main\""
        ));
    }

    #[test]
    fn test_trace_path() {
        assert!(matches_allowed_template(
            "sqry graph trace-path \"login\" \"database\""
        ));
    }

    #[test]
    fn test_trace_path_with_depth() {
        assert!(matches_allowed_template(
            "sqry graph trace-path \"login\" \"database\" --max-depth 5"
        ));
    }

    #[test]
    fn test_visualize() {
        assert!(matches_allowed_template(
            "sqry visualize --relation call --symbol \"main\""
        ));
    }

    #[test]
    fn test_visualize_with_format() {
        assert!(matches_allowed_template(
            "sqry visualize --relation call --symbol \"main\" --format mermaid"
        ));
    }

    #[test]
    fn test_index_status() {
        assert!(matches_allowed_template("sqry index --status"));
    }

    #[test]
    fn test_index_status_with_json() {
        assert!(matches_allowed_template("sqry index --status --json"));
    }

    #[test]
    fn test_index_status_with_path_and_json() {
        assert!(matches_allowed_template(
            "sqry index --status --path \"src\" --json"
        ));
    }

    #[test]
    fn test_kind_in_query_expression() {
        assert!(matches_allowed_template("sqry query \"kind:function\""));
    }

    #[test]
    fn test_accepts_surrounding_whitespace() {
        // `matches_allowed_template` trims before matching.
        assert!(matches_allowed_template("  sqry query \"foo\"\n"));
    }

    // ---------------------------------------------------------------- rejected

    #[test]
    fn test_reject_unknown_command() {
        assert!(!matches_allowed_template("sqry unknown \"foo\""));
    }

    #[test]
    fn test_reject_shell_in_quotes() {
        assert!(!matches_allowed_template("sqry query \"foo; rm -rf /\""));
    }

    #[test]
    fn test_reject_command_chaining_after_quotes() {
        for cmd in [
            "sqry query \"foo\"; rm -rf /",
            "sqry query \"foo\" && rm -rf /",
            "sqry query \"foo\" | sh",
            "sqry query \"foo\" > /tmp/out",
        ] {
            assert!(!matches_allowed_template(cmd), "accepted: {cmd}");
        }
    }

    #[test]
    fn test_reject_expansion_payloads_in_quotes() {
        for cmd in [
            "sqry query \"$(whoami)\"",
            "sqry query \"`whoami`\"",
            "sqry search \"${HOME}\"",
            // An escaped quote must not let the value break out of its quotes.
            "sqry query \"foo\\\"bar\"",
        ] {
            assert!(!matches_allowed_template(cmd), "accepted: {cmd}");
        }
    }

    #[test]
    fn test_reject_metacharacters_in_path() {
        assert!(!matches_allowed_template(
            "sqry query \"foo\" --path \"src;rm\""
        ));
        assert!(!matches_allowed_template(
            "sqry index --status --path \"$HOME\""
        ));
        assert!(!matches_allowed_template(
            "sqry query \"foo\" --path \"src dir\""
        ));
    }

    #[test]
    fn test_reject_unquoted_arguments() {
        assert!(!matches_allowed_template("sqry query foo"));
        assert!(!matches_allowed_template("sqry graph direct-callers main"));
    }

    #[test]
    fn test_reject_unknown_flags() {
        assert!(!matches_allowed_template(
            "sqry index --status --json --force"
        ));
        assert!(!matches_allowed_template("sqry query \"foo\" --exec \"sh\""));
    }

    #[test]
    fn test_reject_invalid_enum_and_numeric_values() {
        assert!(!matches_allowed_template(
            "sqry visualize --relation delete --symbol \"main\""
        ));
        assert!(!matches_allowed_template(
            "sqry visualize --relation call --symbol \"main\" --format svg"
        ));
        assert!(!matches_allowed_template("sqry query \"foo\" --limit ten"));
        assert!(!matches_allowed_template(
            "sqry graph trace-path \"a\" \"b\" --max-depth -1"
        ));
        assert!(!matches_allowed_template(
            "sqry query \"foo\" --language Rust"
        ));
    }
}