use crate::app::CommitType;
use crate::git::FileStatus;
/// A single Conventional Commits suggestion (type, optional scope, description) together with
/// a heuristic confidence score.
#[derive(Debug, Clone, PartialEq)]
pub struct CommitSuggestion {
    /// Commit type; its `name()` is the leading `type` in the rendered header.
    pub commit_type: CommitType,
    /// Optional scope, rendered in parentheses after the type (e.g. `feat(auth): ...`).
    pub scope: Option<String>,
    /// Short description that follows the `type(scope): ` prefix.
    pub message: String,
    /// Heuristic confidence; higher means a better match. `generate_suggestions` sets it to
    /// the fraction of changed files attributed to `commit_type`, or `0.3` for its fallback.
    pub confidence: f32,
}
impl CommitSuggestion {
    /// Renders the suggestion as a Conventional Commits header line:
    /// `type(scope): message` when a scope is present, otherwise `type: message`.
    pub fn full_message(&self) -> String {
        let type_name = self.commit_type.name();
        if let Some(scope) = self.scope.as_deref() {
            format!("{}({}): {}", type_name, scope, self.message)
        } else {
            format!("{}: {}", type_name, self.message)
        }
    }
}
pub fn generate_suggestions(statuses: &[FileStatus]) -> Vec<CommitSuggestion> {
if statuses.is_empty() {
return Vec::new();
}
let paths: Vec<&str> = statuses.iter().map(|s| s.path.as_str()).collect();
let status_refs: Vec<&FileStatus> = statuses.iter().collect();
let mut suggestions = Vec::new();
let type_counts = count_inferred_types_with_stats(&paths, &status_refs);
let mut type_vec: Vec<_> = type_counts.into_iter().collect();
type_vec.sort_by(|a, b| b.1.cmp(&a.1));
for (commit_type, count) in type_vec.iter().take(3) {
let confidence = *count as f32 / paths.len() as f32;
if confidence < 0.2 {
continue;
}
let scope = infer_scope_from_paths(&paths);
let message = generate_message(*commit_type, scope.as_deref(), &paths);
suggestions.push(CommitSuggestion {
commit_type: *commit_type,
scope,
message,
confidence,
});
}
if suggestions.is_empty() {
let scope = infer_scope_from_paths(&paths);
let message = generate_message(CommitType::Chore, scope.as_deref(), &paths);
suggestions.push(CommitSuggestion {
commit_type: CommitType::Chore,
scope,
message,
confidence: 0.3,
});
}
suggestions.sort_by(|a, b| {
b.confidence
.partial_cmp(&a.confidence)
.unwrap_or(std::cmp::Ordering::Equal)
});
suggestions.truncate(3);
suggestions
}
/// Guesses a commit type from a single changed file's path, or `None` when the path gives no
/// useful hint (e.g. ordinary source files).
///
/// Matching is case-insensitive and checked in priority order: tests, then documentation,
/// then build/tooling configuration, then stylesheets. Well-known file names (manifests,
/// `README`, `LICENSE`, `CHANGELOG`, `Dockerfile`, ...) are recognised in any directory,
/// not only at the repository root.
fn infer_type_from_path(path: &str) -> Option<CommitType> {
    // Build, dependency and tooling files recognised by exact (lower-cased) file name.
    const CHORE_FILE_NAMES: &[&str] = &[
        "cargo.toml",
        "package.json",
        "go.mod",
        "requirements.txt",
        "pyproject.toml",
        "tsconfig.json",
        "jest.config.json",
        "eslint.config.json",
        ".eslintrc.json",
        ".prettierrc",
        ".prettierrc.json",
        ".gitignore",
        ".dockerignore",
        "dockerfile",
        "docker-compose.yml",
        "docker-compose.yaml",
        "makefile",
    ];
    // Prefixes of conventional documentation file names (README, LICENSE-MIT, CHANGELOG.txt).
    const DOC_FILE_PREFIXES: &[&str] = &["readme", "license", "changelog"];
    // A file with one of these extensions is code even if its name starts with a doc
    // prefix (e.g. `src/changelog.rs`).
    const SOURCE_EXTENSIONS: &[&str] = &["rs", "go", "py", "js", "ts", "tsx", "jsx"];

    let path_lower = path.to_lowercase();
    let file_name = path_lower.rsplit('/').next().unwrap_or(&path_lower);
    // Text after the last '.' of the file name; empty when the name has no dot.
    let extension = file_name.rsplit_once('.').map_or("", |(_, ext)| ext);

    // `contains("_test.")` also covers `_test.rs`, `_test.go` and `_test.py`.
    let is_test = path_lower.starts_with("tests/")
        || path_lower.contains("/tests/")
        || file_name.starts_with("test_")
        || file_name.contains("_test.")
        || file_name.contains(".test.")
        || file_name.ends_with(".spec.js")
        || file_name.ends_with(".spec.ts");
    if is_test {
        return Some(CommitType::Test);
    }

    // Only the file name is inspected, so directories such as `src/license/` do not make
    // every file beneath them documentation.
    let is_doc_file_name = !SOURCE_EXTENSIONS.contains(&extension)
        && DOC_FILE_PREFIXES
            .iter()
            .any(|&prefix| file_name.starts_with(prefix));
    if is_doc_file_name
        || extension == "md"
        || path_lower.starts_with("docs/")
        || path_lower.contains("/docs/")
    {
        return Some(CommitType::Docs);
    }

    if CHORE_FILE_NAMES.contains(&file_name)
        || path_lower.starts_with(".github/")
        || matches!(extension, "lock" | "yml" | "yaml")
    {
        return Some(CommitType::Chore);
    }

    if matches!(extension, "css" | "scss" | "sass" | "less") {
        return Some(CommitType::Style);
    }

    None
}
/// Counts how many changed files point at each commit type.
///
/// Path hints from `infer_type_from_path` take precedence: when at least one path yields a
/// type, only those per-path counts are returned. Otherwise the staged change kinds in
/// `statuses` decide:
///
/// * staged deletions outnumber both staged additions and staged modifications (this
///   includes a deletion-only change set) → `Refactor`, counted as every status;
/// * only staged modifications (no staged additions or deletions) → `Fix`, counted once per
///   modified file, plus `Refactor` with half that count when more than one file changed;
/// * anything else (including additions only) → `Feat`, counted as every status.
///
/// Statuses of any other kind are not attributed to a bucket but still count toward the
/// total. If there are neither path hints nor statuses, every path is counted as `Feat`.
fn count_inferred_types_with_stats(
    paths: &[&str],
    statuses: &[&FileStatus],
) -> std::collections::HashMap<CommitType, usize> {
    use crate::git::FileStatusKind;

    let mut counts = std::collections::HashMap::new();
    for commit_type in paths.iter().filter_map(|path| infer_type_from_path(path)) {
        *counts.entry(commit_type).or_insert(0) += 1;
    }
    if !counts.is_empty() {
        return counts;
    }
    if statuses.is_empty() {
        counts.insert(CommitType::Feat, paths.len());
        return counts;
    }

    // Single pass over the statuses; each status has exactly one kind.
    let (mut new_count, mut deleted_count, mut modified_count) = (0usize, 0usize, 0usize);
    for status in statuses {
        if status.kind == FileStatusKind::StagedNew {
            new_count += 1;
        } else if status.kind == FileStatusKind::StagedDeleted {
            deleted_count += 1;
        } else if status.kind == FileStatusKind::StagedModified {
            modified_count += 1;
        }
    }
    let total = statuses.len();

    if deleted_count > new_count && deleted_count > modified_count {
        counts.insert(CommitType::Refactor, total);
    } else if modified_count > 0 && new_count == 0 && deleted_count == 0 {
        counts.insert(CommitType::Fix, modified_count);
        if modified_count > 1 {
            counts.insert(CommitType::Refactor, modified_count / 2);
        }
    } else {
        counts.insert(CommitType::Feat, total);
    }
    counts
}
/// Infers a Conventional Commits scope shared by every path in `paths`.
///
/// A path's scope is its first directory, skipping a leading `src/` component (so
/// `src/auth/login.rs` and `auth/mod.rs` both map to `auth`). A scope is returned only when
/// every path maps to the same directory and that name contains no `.` (which rules out
/// hidden directories such as `.github`). Files directly under `src/` or at the repository
/// root have no directory to use, so any such path makes the result `None`.
///
/// Returns `None` for an empty slice.
pub fn infer_scope_from_paths(paths: &[&str]) -> Option<String> {
    let (first, rest) = paths.split_first()?;
    let scope = scope_component(first)?;
    if scope.contains('.') {
        return None;
    }
    if rest.iter().all(|path| scope_component(path) == Some(scope)) {
        Some(scope.to_string())
    } else {
        None
    }
}

/// Returns the directory component naming `path`'s scope, or `None` when the path is a file
/// directly under `src/` or at the root.
fn scope_component(path: &str) -> Option<&str> {
    let mut components = path.split('/');
    let mut dir = components.next()?;
    if dir == "src" {
        dir = components.next()?;
    }
    // `dir` names a directory only if another component follows it; otherwise it is the file
    // itself (e.g. `src/main.rs`, `src/Makefile`).
    components.next()?;
    if dir.is_empty() {
        None
    } else {
        Some(dir)
    }
}
/// Produces the description part of a commit message for `commit_type`.
///
/// `scope` is woven into the text for `feat`, `fix`, `refactor` and `perf`. For `test`,
/// `docs`, `chore` and scope-less `feat`, a single changed file is named explicitly, while
/// zero or several files fall back to a generic phrase.
fn generate_message(commit_type: CommitType, scope: Option<&str>, paths: &[&str]) -> String {
    // Present only when exactly one file changed.
    let lone_path: Option<&str> = match paths {
        [only] => Some(*only),
        _ => None,
    };

    match commit_type {
        CommitType::Test => match lone_path {
            Some(path) => format!("add tests for {}", extract_module_name(path)),
            None => "add tests".to_string(),
        },
        CommitType::Docs => match lone_path {
            Some(path) if path.to_lowercase().starts_with("readme") => {
                "update README".to_string()
            }
            Some(path) => format!("update {}", extract_file_name(path)),
            None => "update documentation".to_string(),
        },
        CommitType::Chore => match lone_path {
            Some(path) => format!("update {}", extract_file_name(path)),
            None => "update configuration".to_string(),
        },
        CommitType::Style => "update styles".to_string(),
        CommitType::Feat => match (scope, lone_path) {
            (Some(s), _) => format!("add {} feature", s),
            (None, Some(path)) => format!("add {}", extract_module_name(path)),
            (None, None) => "add new feature".to_string(),
        },
        CommitType::Fix => scope.map_or_else(
            || "fix issue".to_string(),
            |s| format!("fix {} issue", s),
        ),
        CommitType::Refactor => scope.map_or_else(
            || "refactor code".to_string(),
            |s| format!("refactor {}", s),
        ),
        CommitType::Perf => scope.map_or_else(
            || "improve performance".to_string(),
            |s| format!("improve {} performance", s),
        ),
    }
}
/// Returns the last path component with a single recognised source extension
/// (`rs`, `go`, `py`, `js`, `ts`, `tsx`, `jsx`) removed; other names are returned unchanged.
fn extract_module_name(path: &str) -> String {
    let file_name = path.rsplit('/').next().unwrap_or(path);
    match file_name.rsplit_once('.') {
        Some((stem, "rs" | "go" | "py" | "js" | "ts" | "tsx" | "jsx")) => stem.to_string(),
        _ => file_name.to_string(),
    }
}
/// Returns the final `/`-separated component of `path` (the whole path if it has no `/`).
fn extract_file_name(path: &str) -> String {
    let name = match path.rsplit_once('/') {
        Some((_, last)) => last,
        None => path,
    };
    name.to_string()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::git::FileStatusKind;

    /// Builds a `FileStatus` for a newly staged file at `path`.
    fn staged_new(path: &str) -> FileStatus {
        FileStatus {
            path: path.to_string(),
            kind: FileStatusKind::StagedNew,
        }
    }

    /// Runs `generate_suggestions` over a change set in which every path is newly staged.
    fn suggest(paths: &[&str]) -> Vec<CommitSuggestion> {
        let statuses: Vec<FileStatus> = paths.iter().map(|path| staged_new(path)).collect();
        generate_suggestions(&statuses)
    }

    /// Asserts that `infer_type_from_path` classifies each path as paired.
    fn assert_inferred(cases: &[(&str, Option<CommitType>)]) {
        for &(path, expected) in cases {
            assert_eq!(infer_type_from_path(path), expected, "path: {}", path);
        }
    }

    #[test]
    fn test_infer_type_from_test_file() {
        assert_inferred(&[
            ("src/app_test.rs", Some(CommitType::Test)),
            ("tests/integration_test.rs", Some(CommitType::Test)),
            ("src/utils.spec.js", Some(CommitType::Test)),
        ]);
    }

    #[test]
    fn test_infer_type_from_readme() {
        assert_inferred(&[
            ("README.md", Some(CommitType::Docs)),
            ("readme.txt", Some(CommitType::Docs)),
        ]);
    }

    #[test]
    fn test_infer_type_from_docs() {
        assert_inferred(&[
            ("docs/api.md", Some(CommitType::Docs)),
            ("CHANGELOG.md", Some(CommitType::Docs)),
        ]);
    }

    #[test]
    fn test_infer_type_from_cargo_toml() {
        assert_inferred(&[
            ("Cargo.toml", Some(CommitType::Chore)),
            ("cargo.toml", Some(CommitType::Chore)),
        ]);
    }

    #[test]
    fn test_infer_type_from_package_json() {
        assert_inferred(&[("package.json", Some(CommitType::Chore))]);
    }

    #[test]
    fn test_infer_type_from_regular_json_is_none() {
        assert_inferred(&[("src/data.json", None), ("config/settings.json", None)]);
    }

    #[test]
    fn test_infer_type_from_regular_toml_is_none() {
        assert_inferred(&[("src/config.toml", None)]);
    }

    #[test]
    fn test_infer_type_from_github_workflow() {
        assert_inferred(&[(".github/workflows/ci.yml", Some(CommitType::Chore))]);
    }

    #[test]
    fn test_infer_type_from_css() {
        assert_inferred(&[
            ("styles/main.css", Some(CommitType::Style)),
            ("app.scss", Some(CommitType::Style)),
        ]);
    }

    #[test]
    fn test_infer_type_from_regular_source() {
        assert_inferred(&[("src/main.rs", None), ("src/app.rs", None)]);
    }

    #[test]
    fn test_infer_scope_from_src_auth() {
        let scope = infer_scope_from_paths(&["src/auth/login.rs", "src/auth/logout.rs"]);
        assert_eq!(scope.as_deref(), Some("auth"));
    }

    #[test]
    fn test_infer_scope_from_src_tui() {
        let scope = infer_scope_from_paths(&["src/tui/ui.rs", "src/tui/render.rs"]);
        assert_eq!(scope.as_deref(), Some("tui"));
    }

    #[test]
    fn test_infer_scope_mixed_paths() {
        let scope = infer_scope_from_paths(&["src/auth/login.rs", "src/tui/ui.rs"]);
        assert!(scope.is_none());
    }

    #[test]
    fn test_infer_scope_single_file() {
        assert!(infer_scope_from_paths(&["src/main.rs"]).is_none());
    }

    #[test]
    fn test_generate_suggestions_empty() {
        assert!(suggest(&[]).is_empty());
    }

    #[test]
    fn test_generate_suggestions_test_files() {
        let suggestions = suggest(&["src/app_test.rs", "src/utils_test.rs"]);
        let top = suggestions.first().expect("expected at least one suggestion");
        assert_eq!(top.commit_type, CommitType::Test);
    }

    #[test]
    fn test_generate_suggestions_readme() {
        let suggestions = suggest(&["README.md"]);
        let top = suggestions.first().expect("expected at least one suggestion");
        assert_eq!(top.commit_type, CommitType::Docs);
        assert!(top.message.contains("README"));
    }

    #[test]
    fn test_generate_suggestions_cargo_toml() {
        let suggestions = suggest(&["Cargo.toml"]);
        let top = suggestions.first().expect("expected at least one suggestion");
        assert_eq!(top.commit_type, CommitType::Chore);
    }

    #[test]
    fn test_generate_suggestions_max_three() {
        let suggestions = suggest(&["src/a.rs", "src/b.rs", "src/c.rs", "src/d.rs", "src/e.rs"]);
        assert!(suggestions.len() <= 3);
    }

    #[test]
    fn test_commit_suggestion_full_message_with_scope() {
        let suggestion = CommitSuggestion {
            commit_type: CommitType::Feat,
            scope: Some("auth".into()),
            message: "add login".into(),
            confidence: 0.8,
        };
        assert_eq!(suggestion.full_message(), "feat(auth): add login");
    }

    #[test]
    fn test_commit_suggestion_full_message_without_scope() {
        let suggestion = CommitSuggestion {
            commit_type: CommitType::Fix,
            scope: None,
            message: "fix bug".into(),
            confidence: 0.7,
        };
        assert_eq!(suggestion.full_message(), "fix: fix bug");
    }

    #[test]
    fn test_extract_module_name() {
        for &(path, module) in &[("src/app.rs", "app"), ("main.go", "main"), ("utils.py", "utils")] {
            assert_eq!(extract_module_name(path), module, "path: {}", path);
        }
    }

    #[test]
    fn test_extract_file_name() {
        for &(path, name) in &[("src/app.rs", "app.rs"), ("Cargo.toml", "Cargo.toml")] {
            assert_eq!(extract_file_name(path), name, "path: {}", path);
        }
    }
}