use std::collections::HashSet;
/// Lower-case keywords that mark a prompt as a request to create something new.
///
/// Entries are lower-case because prompts are lower-cased before comparison.
/// Note that `"set up"` is a two-word phrase containing a space.
const ACTION_KEYWORDS: &[&str] = &[
    "create", "add", "implement", "write", "make",
    "build", "generate", "set up", "initialize", "configure",
];
/// Lower-case keywords that mark a prompt as a request to change or repair
/// something that already exists.
///
/// Entries are lower-case because prompts are lower-cased before comparison.
const MODIFICATION_KEYWORDS: &[&str] = &[
    "fix",
    "bug",
    "issue",
    "problem",
    "error",
    "broken",
    "repair",
    "patch",
    "update",
    "change",
    "modify",
    "refactor",
];
/// Lower-case keywords that mark a prompt as a request to inspect or explain
/// something without necessarily changing it.
///
/// Entries are lower-case because prompts are lower-cased before comparison.
const ANALYSIS_KEYWORDS: &[&str] = &[
    "analyze", "review", "check", "examine", "inspect",
    "audit", "find", "search", "debug", "explain",
];
/// Returns the action keywords as a set.
///
/// A fresh `HashSet` is built on every call from the `ACTION_KEYWORDS` table.
pub fn get_action_keywords() -> HashSet<&'static str> {
    let mut keywords = HashSet::new();
    for &keyword in ACTION_KEYWORDS {
        keywords.insert(keyword);
    }
    keywords
}
/// Returns the modification keywords as a set.
///
/// A fresh `HashSet` is built on every call from the `MODIFICATION_KEYWORDS` table.
pub fn get_modification_keywords() -> HashSet<&'static str> {
    let mut keywords = HashSet::new();
    for &keyword in MODIFICATION_KEYWORDS {
        keywords.insert(keyword);
    }
    keywords
}
/// Returns the analysis keywords as a set.
///
/// A fresh `HashSet` is built on every call from the `ANALYSIS_KEYWORDS` table.
pub fn get_analysis_keywords() -> HashSet<&'static str> {
    let mut keywords = HashSet::new();
    for &keyword in ANALYSIS_KEYWORDS {
        keywords.insert(keyword);
    }
    keywords
}
/// Returns `true` when `prompt` contains at least one action keyword.
///
/// The prompt is lower-cased first, so the comparison is case-insensitive.
/// Matching is a plain substring test: a keyword embedded in a longer word
/// also counts (for example, "address" contains "add").
pub fn is_action_prompt(prompt: &str) -> bool {
    let haystack = prompt.to_lowercase();
    for keyword in ACTION_KEYWORDS {
        if haystack.contains(keyword) {
            return true;
        }
    }
    false
}
/// Returns `true` when `prompt` contains a modification keyword that starts a word.
///
/// The prompt is lower-cased first, so the comparison is case-insensitive.
/// A keyword only counts when the character immediately before it is not
/// alphanumeric (or when it sits at the very start of the prompt). Inflected
/// forms such as "fixed", "bugs" or "updating" therefore still match, while
/// keywords buried inside unrelated words do not: "debug" (contains "bug"),
/// "dispatch" (contains "patch"), "prefix" (contains "fix") and "exchange"
/// (contains "change") are no longer reported as modification requests.
pub fn is_modification_prompt(prompt: &str) -> bool {
    let prompt_lower = prompt.to_lowercase();
    MODIFICATION_KEYWORDS.iter().any(|kw| {
        prompt_lower.match_indices(*kw).any(|(start, _)| {
            // `start` is a char boundary reported by `match_indices`, so the
            // slice is valid; inspect the character just before the hit.
            prompt_lower[..start]
                .chars()
                .next_back()
                .map_or(true, |prev| !prev.is_alphanumeric())
        })
    })
}
/// Returns `true` when `prompt` contains at least one analysis keyword.
///
/// The prompt is lower-cased first, so the comparison is case-insensitive.
/// Matching is a plain substring test: a keyword embedded in a longer word
/// also counts (for example, "research" contains "search").
pub fn is_analysis_prompt(prompt: &str) -> bool {
    let haystack = prompt.to_lowercase();
    for keyword in ANALYSIS_KEYWORDS {
        if haystack.contains(keyword) {
            return true;
        }
    }
    false
}
/// Classifies `prompt` into one of four fixed labels.
///
/// The categories are tested in a fixed order and the first match wins:
///
/// 1. `"modification"` if [`is_modification_prompt`] accepts the prompt,
/// 2. `"action"` if [`is_action_prompt`] accepts it,
/// 3. `"analysis"` if [`is_analysis_prompt`] accepts it,
/// 4. `"unknown"` otherwise (including for an empty prompt).
///
/// The returned label is always a string literal, so it is `'static` and does
/// not borrow from `prompt`; callers may keep it after the prompt is dropped.
pub fn classify_prompt(prompt: &str) -> &'static str {
    if is_modification_prompt(prompt) {
        "modification"
    } else if is_action_prompt(prompt) {
        "action"
    } else if is_analysis_prompt(prompt) {
        "analysis"
    } else {
        "unknown"
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Action keywords are detected regardless of letter case, and prompts
    /// without any keyword (including the empty prompt) are rejected.
    #[test]
    fn test_action_keyword() {
        for prompt in &["Create a new file", "Implement feature", "SET UP the project"] {
            assert!(is_action_prompt(prompt), "expected an action prompt: {:?}", prompt);
        }
        for prompt in &["Hello there", ""] {
            assert!(!is_action_prompt(prompt), "unexpected action prompt: {:?}", prompt);
        }
    }

    /// Modification keywords are detected regardless of letter case, and
    /// prompts without any keyword (including the empty prompt) are rejected.
    #[test]
    fn test_modification_keyword() {
        for prompt in &["Fix the bug", "Update the code", "REFACTOR this module"] {
            assert!(
                is_modification_prompt(prompt),
                "expected a modification prompt: {:?}",
                prompt
            );
        }
        for prompt in &["Hello there", ""] {
            assert!(
                !is_modification_prompt(prompt),
                "unexpected modification prompt: {:?}",
                prompt
            );
        }
    }

    /// Every label is reachable, modification outranks action when both
    /// kinds of keyword are present, and keyword-free prompts are "unknown".
    #[test]
    fn test_classification() {
        let cases = [
            ("Fix this issue", "modification"),
            ("Create new file", "action"),
            ("Analyze code", "analysis"),
            // Contains both "create" (action) and "fix" (modification).
            ("Create a fix for the parser", "modification"),
            ("Hello there", "unknown"),
            ("", "unknown"),
        ];
        for &(prompt, expected) in &cases {
            assert_eq!(classify_prompt(prompt), expected, "prompt: {:?}", prompt);
        }
    }
}