use crate::classify::taxonomy::TopLevelCategory;
use crate::classify::tiers::ClassificationResult;
use crate::core::models::ClassificationMethod;
/// Internal scoring category used by the weighted-sum classifier.
///
/// The explicit discriminant of each variant is its slot in every
/// `[f32; NUM_CATS]` score array (see `Cat::index`), so the numbering must stay
/// dense and in step with `Cat::ALL` and `KEYWORD_BAGS`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Cat {
Feature = 0,
Bugfix = 1,
Ktlo = 2,
Integrations = 3,
PlatformWork = 4,
Content = 5,
Maintenance = 6,
// Merge has an empty keyword table; it is scored through `MERGE_WEIGHTS`
// and the short-message nudge in `score_message_length`.
Merge = 7,
}
/// Number of `Cat` variants; the length of every per-category score array.
const NUM_CATS: usize = 8;
impl Cat {
    /// Every category, listed in discriminant order so that
    /// `Cat::ALL[c.index()] == c` for each variant `c`.
    const ALL: [Cat; NUM_CATS] = [
        Self::Feature,
        Self::Bugfix,
        Self::Ktlo,
        Self::Integrations,
        Self::PlatformWork,
        Self::Content,
        Self::Maintenance,
        Self::Merge,
    ];

    /// Slot of this category inside a `[f32; NUM_CATS]` score array.
    fn index(self) -> usize {
        self as usize
    }

    /// Returns the `(label, top-level category)` pair reported for this
    /// category in a classification verdict.
    fn to_verdict(self) -> (&'static str, TopLevelCategory) {
        let label = match self {
            Self::Feature => "feature",
            Self::Bugfix => "bugfix",
            Self::Ktlo => "chore",
            Self::Integrations => "integration",
            Self::PlatformWork => "platform",
            Self::Content => "docs",
            Self::Maintenance => "refactor",
            Self::Merge => "merge",
        };
        let top_level = match self {
            Self::Feature => TopLevelCategory::Feature,
            Self::Bugfix => TopLevelCategory::Bugfix,
            Self::Ktlo => TopLevelCategory::Ktlo,
            Self::Integrations => TopLevelCategory::Integrations,
            Self::PlatformWork => TopLevelCategory::PlatformWork,
            Self::Content => TopLevelCategory::Content,
            // Merge verdicts are reported under the Maintenance top-level category.
            Self::Maintenance | Self::Merge => TopLevelCategory::Maintenance,
        };
        (label, top_level)
    }
}
/// Per-category adjustment returned by `score_merge_indicator` for a merge
/// commit: +0.65 in the `Cat::Merge` slot (index 7) and -0.05 in every other
/// slot.
static MERGE_WEIGHTS: &[f32; NUM_CATS] = &[
-0.05, -0.05, -0.05, -0.05, -0.05, -0.05, -0.05, 0.65,
];
/// Lowercase keywords counted toward `Cat::Feature` (slot 0 of `KEYWORD_BAGS`).
///
/// `score_keywords` matches these as substrings of the lowercased message.
// NOTE(review): "build" is also listed in `KTLO_KEYWORDS`, so a message
// containing it scores for both categories — confirm this is intended.
const FEATURE_KEYWORDS: &[&str] = &[
"add",
"implement",
"feature",
"support",
"introduce",
"create",
"build",
"new",
"extend",
"enable",
];
/// Lowercase keywords counted toward `Cat::Bugfix` (slot 1 of `KEYWORD_BAGS`).
///
/// `score_keywords` matches these as substrings of the lowercased message.
const BUGFIX_KEYWORDS: &[&str] = &[
"fix",
"bug",
"issue",
"broken",
"regression",
"hotfix",
"patch",
"resolve",
"repair",
"correct",
];
/// Lowercase keywords counted toward `Cat::Ktlo` (slot 2 of `KEYWORD_BAGS`).
///
/// `score_keywords` matches these as substrings of the lowercased message.
// NOTE(review): short entries such as "ci" and "ops" also match inside longer
// words because matching is by substring — confirm this is acceptable.
const KTLO_KEYWORDS: &[&str] = &[
"chore", "ci", "build", "ops", "release", "version", "bump", "update", "upgrade", "automate",
];
/// Lowercase keywords counted toward `Cat::Integrations` (slot 3 of
/// `KEYWORD_BAGS`).
///
/// `score_keywords` matches these as substrings of the lowercased message.
const INTEGRATION_KEYWORDS: &[&str] = &[
"integrate",
"api",
"webhook",
"sdk",
"plugin",
"connector",
"bridge",
"endpoint",
"external",
"third-party",
];
/// Lowercase keywords counted toward `Cat::PlatformWork` (slot 4 of
/// `KEYWORD_BAGS`).
///
/// `score_keywords` matches these as substrings of the lowercased message, so
/// a message containing "performance" also matches "perf" and counts two hits
/// (likewise "infrastructure" and "infra").
const PLATFORM_KEYWORDS: &[&str] = &[
"perf",
"performance",
"infra",
"infrastructure",
"architecture",
"devops",
"deploy",
"scale",
"optimize",
"database",
];
/// Lowercase keywords counted toward `Cat::Content` (slot 5 of `KEYWORD_BAGS`).
///
/// `score_keywords` matches these as substrings of the lowercased message.
const CONTENT_KEYWORDS: &[&str] = &[
"docs",
"readme",
"documentation",
"comment",
"typo",
"copy",
"translation",
"locale",
"i18n",
"asset",
];
/// Lowercase keywords counted toward `Cat::Maintenance` (slot 6 of
/// `KEYWORD_BAGS`).
///
/// `score_keywords` matches these as substrings of the lowercased message.
const MAINTENANCE_KEYWORDS: &[&str] = &[
"refactor",
"cleanup",
"rename",
"deps",
"dependency",
"style",
"lint",
"format",
"test",
"remove",
];
/// Keyword tables indexed by `Cat::index()`: entry `i` holds the keywords for
/// `Cat::ALL[i]`. The last entry (for `Cat::Merge`) is empty, so
/// `score_keywords` never scores that category.
// NOTE(review): this is an unsized slice, so its length is not tied to
// `NUM_CATS` by the type; a missing entry would only surface as an
// out-of-bounds panic in `score_keywords`. Consider `&[&[&str]; NUM_CATS]`,
// matching how `MERGE_WEIGHTS` is declared.
static KEYWORD_BAGS: &[&[&str]] = &[
FEATURE_KEYWORDS, BUGFIX_KEYWORDS, KTLO_KEYWORDS, INTEGRATION_KEYWORDS, PLATFORM_KEYWORDS, CONTENT_KEYWORDS, MAINTENANCE_KEYWORDS, &[], ];
/// Keyword signal: scores each category by how many of its keywords occur as
/// substrings of `lower` (the already-lowercased commit message).
///
/// One hit scores 0.40, two hits 0.60, three or more 0.75; a category with no
/// hits (or with an empty keyword table) stays at 0.0.
fn score_keywords(lower: &str) -> [f32; NUM_CATS] {
    let mut scores = [0.0f32; NUM_CATS];
    for cat in Cat::ALL {
        let slot = cat.index();
        let hits = KEYWORD_BAGS[slot]
            .iter()
            .filter(|kw| lower.contains(**kw))
            .count();
        // The score saturates: extra hits beyond the third add nothing.
        scores[slot] = match hits {
            0 => 0.0,
            1 => 0.40,
            2 => 0.60,
            _ => 0.75,
        };
    }
    scores
}
/// Ticket-prefix signal: adds the same small boost to every category when the
/// message starts with a Jira-style key (as decided by
/// `has_jira_style_prefix`), and nothing otherwise.
///
/// Because the boost is uniform it never changes which category wins; it only
/// lifts the winning score relative to the confidence threshold.
fn score_ticket_prefix(message: &str) -> [f32; NUM_CATS] {
    const TICKET_WEIGHT: f32 = 0.05;
    let boost = if has_jira_style_prefix(message) {
        TICKET_WEIGHT
    } else {
        0.0
    };
    [boost; NUM_CATS]
}
/// Message-length signal: nudges categories based on how long the (already
/// trimmed) commit message is.
///
/// * fewer than 12 characters: favours Ktlo and Merge (+0.10) and Maintenance
///   (+0.05), and penalises Feature (-0.05) and Bugfix (-0.03);
/// * more than 80 characters: favours Feature and PlatformWork (+0.10) and
///   Maintenance and Bugfix (+0.05);
/// * anything in between contributes nothing.
///
/// Length is measured in Unicode scalar values rather than UTF-8 bytes, so
/// non-ASCII messages are not treated as longer than they read (a four
/// character CJK message is 12 bytes, which would otherwise miss the "short"
/// bucket).
fn score_message_length(trimmed: &str) -> [f32; NUM_CATS] {
    let len = trimmed.chars().count();
    let mut out = [0.0f32; NUM_CATS];
    if len < 12 {
        out[Cat::Ktlo.index()] = 0.10;
        out[Cat::Merge.index()] = 0.10;
        out[Cat::Maintenance.index()] = 0.05;
        out[Cat::Feature.index()] = -0.05;
        out[Cat::Bugfix.index()] = -0.03;
    } else if len > 80 {
        out[Cat::Feature.index()] = 0.10;
        out[Cat::PlatformWork.index()] = 0.10;
        out[Cat::Maintenance.index()] = 0.05;
        out[Cat::Bugfix.index()] = 0.05;
    }
    out
}
/// Merge signal: returns `MERGE_WEIGHTS` when the commit is a merge, and all
/// zeros otherwise.
///
/// A commit counts as a merge when `is_merge` is set or when `lower` (the
/// lowercased message) starts with one of the recognised merge prefixes.
fn score_merge_indicator(is_merge: bool, lower: &str) -> [f32; NUM_CATS] {
    // The bare "merge " entry already covers the three longer prefixes; they
    // are kept so the recognised message shapes stay spelled out.
    const MERGE_PREFIXES: [&str; 4] = [
        "merge pull request",
        "merge branch",
        "merge remote-tracking",
        "merge ",
    ];

    let looks_like_merge = is_merge
        || MERGE_PREFIXES
            .iter()
            .any(|prefix| lower.starts_with(*prefix));

    if looks_like_merge {
        *MERGE_WEIGHTS
    } else {
        [0.0; NUM_CATS]
    }
}
/// File-path signal: nudges categories based on what kind of files the commit
/// touched.
///
/// Each path is tested against three kinds (a path may match more than one):
///
/// * test files: contain `tests/`, `test/` or `spec/`, or end in `_test.rs`,
///   `_spec.rb`, `.test.ts` or `.spec.ts`;
/// * docs files: contain `docs/` or `doc/`, or end in `.md`, `.rst` or `.txt`,
///   and are not a dependency manifest;
/// * manifests: file name is one of the known dependency manifests
///   (`Cargo.toml`, `package.json`, `requirements.txt`, ...).
///
/// A kind only contributes when at least half of `paths` match it, and its
/// contribution is scaled by that ratio. An empty `paths` yields all zeros.
///
/// Manifests are excluded from the docs count: `requirements.txt` would
/// otherwise match the `.txt` docs rule and make a dependency-only change look
/// like a documentation change.
fn score_file_paths(paths: &[String]) -> [f32; NUM_CATS] {
    // Final `/`-separated segment of the path (the whole path when there is no `/`).
    fn file_name(path: &str) -> &str {
        path.split('/').next_back().unwrap_or(path)
    }

    // True when the path's file name is a known dependency manifest.
    fn is_manifest(path: &str) -> bool {
        matches!(
            file_name(path),
            "Cargo.toml"
                | "package.json"
                | "pyproject.toml"
                | "requirements.txt"
                | "Gemfile"
                | "pom.xml"
                | "build.gradle"
                | "go.mod"
                | "Pipfile"
                | "setup.py"
                | "composer.json"
        )
    }

    // True when the path looks like a test or spec file.
    fn is_test(path: &str) -> bool {
        path.contains("tests/")
            || path.contains("test/")
            || path.contains("spec/")
            || path.ends_with("_test.rs")
            || path.ends_with("_spec.rb")
            || path.ends_with(".test.ts")
            || path.ends_with(".spec.ts")
    }

    // True when the path looks like documentation and is not a manifest.
    fn is_docs(path: &str) -> bool {
        if is_manifest(path) {
            return false;
        }
        path.contains("docs/")
            || path.contains("doc/")
            || path.ends_with(".md")
            || path.ends_with(".rst")
            || path.ends_with(".txt")
    }

    // Fraction of `paths` accepted by `matches_kind`; callers guarantee `paths` is non-empty.
    fn share(paths: &[String], matches_kind: fn(&str) -> bool) -> f32 {
        let hits = paths.iter().filter(|p| matches_kind(p.as_str())).count();
        hits as f32 / paths.len() as f32
    }

    if paths.is_empty() {
        return [0.0; NUM_CATS];
    }

    let mut out = [0.0f32; NUM_CATS];

    let test_ratio = share(paths, is_test);
    if test_ratio >= 0.5 {
        out[Cat::Maintenance.index()] += test_ratio * 0.20;
        out[Cat::Bugfix.index()] += test_ratio * 0.10;
    }

    let docs_ratio = share(paths, is_docs);
    if docs_ratio >= 0.5 {
        out[Cat::Content.index()] += docs_ratio * 0.20;
    }

    let manifest_ratio = share(paths, is_manifest);
    if manifest_ratio >= 0.5 {
        out[Cat::Ktlo.index()] += manifest_ratio * 0.20;
        out[Cat::Maintenance.index()] += manifest_ratio * 0.10;
    }

    out
}
/// Returns `true` when the first whitespace-delimited token of `message` is a
/// Jira-style ticket key such as `PROJ-123`.
///
/// Trailing `:`, `-` and `,` characters are stripped from the token first. The
/// remainder must be exactly `<project>-<number>`, where the project part
/// starts with an ASCII uppercase letter and contains only ASCII uppercase
/// letters or digits, and the number part is one or more ASCII digits.
fn has_jira_style_prefix(message: &str) -> bool {
    let Some(first_token) = message.split_whitespace().next() else {
        return false;
    };
    let key = first_token.trim_end_matches([':', '-', ',']);

    // Exactly one hyphen is allowed: one project part and one number part.
    let Some((project, number)) = key.split_once('-') else {
        return false;
    };
    if number.contains('-') || project.is_empty() || number.is_empty() {
        return false;
    }

    let starts_upper = project
        .chars()
        .next()
        .is_some_and(|c| c.is_ascii_uppercase());
    let project_ok = project
        .chars()
        .all(|c| c.is_ascii_uppercase() || c.is_ascii_digit());
    let number_ok = number.chars().all(|c| c.is_ascii_digit());

    starts_upper && project_ok && number_ok
}
/// Settings for `WeightedSumClassifier`.
///
/// Both fields are optional when deserializing and fall back to the defaults
/// below; unknown fields are rejected (`deny_unknown_fields`).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(deny_unknown_fields)]
pub struct WeightedSumConfig {
/// When `false`, `WeightedSumClassifier::classify` always returns `None`.
/// Defaults to `true`.
#[serde(default = "default_weighted_sum_enabled")]
pub enabled: bool,
/// Minimum winning score required to emit a verdict; a lower best score
/// makes `classify` return `None`. Defaults to `0.55`.
#[serde(default = "default_weighted_sum_min_confidence")]
pub min_confidence: f32,
}
/// Default for `WeightedSumConfig::enabled`, used by serde and by `Default`.
fn default_weighted_sum_enabled() -> bool {
true
}
/// Default for `WeightedSumConfig::min_confidence`, used by serde and by
/// `Default`.
fn default_weighted_sum_min_confidence() -> f32 {
0.55
}
impl Default for WeightedSumConfig {
    /// Builds the configuration from the same default functions serde uses for
    /// missing fields, so both paths always agree.
    fn default() -> Self {
        let enabled = default_weighted_sum_enabled();
        let min_confidence = default_weighted_sum_min_confidence();
        Self {
            enabled,
            min_confidence,
        }
    }
}
/// Commit classifier that sums several independent per-category signals
/// (keywords, ticket prefix, message length, merge indicator, file paths) and
/// reports the highest-scoring category.
pub struct WeightedSumClassifier {
// Settings read by `classify`: the enabled flag and the confidence floor.
config: WeightedSumConfig,
}
impl WeightedSumClassifier {
pub fn new(config: WeightedSumConfig) -> Self {
Self { config }
}
pub fn classify(
&self,
message: &str,
is_merge: bool,
paths: &[String],
) -> Option<ClassificationResult> {
if !self.config.enabled {
return None;
}
let trimmed = message.trim();
let lower = trimmed.to_lowercase();
let mut scores = [0.0f32; NUM_CATS];
let keyword_scores = score_keywords(&lower);
for i in 0..NUM_CATS {
scores[i] += keyword_scores[i];
}
let ticket_scores = score_ticket_prefix(trimmed);
for i in 0..NUM_CATS {
scores[i] += ticket_scores[i];
}
let length_scores = score_message_length(trimmed);
for i in 0..NUM_CATS {
scores[i] += length_scores[i];
}
let merge_scores = score_merge_indicator(is_merge, &lower);
for i in 0..NUM_CATS {
scores[i] += merge_scores[i];
}
let path_scores = score_file_paths(paths);
for i in 0..NUM_CATS {
scores[i] += path_scores[i];
}
let (best_cat_idx, &best_score) = scores
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))?;
let tie_count = scores.iter().filter(|&&s| s == best_score).count();
if tie_count > 1 || best_score <= 0.0 {
return None;
}
if (best_score as f64) < self.config.min_confidence as f64 {
return None;
}
let best_cat = Cat::ALL[best_cat_idx];
let (category, top_level) = best_cat.to_verdict();
let confidence = (best_score as f64)
.max(self.config.min_confidence as f64)
.min(0.95);
Some(ClassificationResult {
category: category.to_string(),
subcategory: None,
top_level: Some(top_level),
confidence,
method: ClassificationMethod::WeightedSum,
ticket_id: None,
complexity: None,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Classifier built from `WeightedSumConfig::default()`.
    fn default_classifier() -> WeightedSumClassifier {
        WeightedSumClassifier::new(WeightedSumConfig::default())
    }

    // ---- individual signals -------------------------------------------------

    #[test]
    fn keyword_score_bugfix_keywords_dominate_bugfix_category() {
        let lower = "fix null pointer regression hotfix";
        let scores = score_keywords(lower);
        let bugfix_score = scores[Cat::Bugfix.index()];
        let feature_score = scores[Cat::Feature.index()];
        assert!(
            bugfix_score > feature_score,
            "bugfix keywords should score higher for Bugfix than Feature, got bugfix={bugfix_score:.3} feature={feature_score:.3}"
        );
        assert!(bugfix_score > 0.0, "bugfix score must be positive");
    }

    #[test]
    fn keyword_score_feature_keywords_dominate_feature_category() {
        let lower = "add implement feature support";
        let scores = score_keywords(lower);
        let feature_score = scores[Cat::Feature.index()];
        let bugfix_score = scores[Cat::Bugfix.index()];
        assert!(
            feature_score > bugfix_score,
            "feature keywords should score higher for Feature, got feature={feature_score:.3} bugfix={bugfix_score:.3}"
        );
    }

    #[test]
    fn ticket_prefix_signal_fires_for_jira_prefix() {
        let msg = "PROJ-123: update auth module";
        let scores = score_ticket_prefix(msg);
        for (i, &s) in scores.iter().enumerate() {
            assert!(s > 0.0, "category {i} should get a ticket-prefix boost");
        }
    }

    #[test]
    fn ticket_prefix_signal_zero_for_no_prefix() {
        let msg = "update auth module";
        let scores = score_ticket_prefix(msg);
        for (i, &s) in scores.iter().enumerate() {
            assert_eq!(
                s, 0.0,
                "category {i} should score 0.0 without ticket prefix"
            );
        }
    }

    #[test]
    fn length_signal_short_message_nudges_ktlo_not_feature() {
        let scores = score_message_length("wip");
        assert!(
            scores[Cat::Ktlo.index()] > 0.0,
            "short message should nudge KTLO"
        );
        assert!(
            scores[Cat::Feature.index()] < 0.0,
            "short message should penalise Feature"
        );
    }

    #[test]
    fn length_signal_long_message_nudges_feature() {
        let long = "add new payment integration with Stripe — supports 3DS, refunds, webhooks, and idempotency keys";
        assert!(long.len() > 80, "test message must be >80 chars");
        let scores = score_message_length(long);
        assert!(
            scores[Cat::Feature.index()] > 0.0,
            "long message should nudge Feature"
        );
    }

    #[test]
    fn merge_indicator_signal_fires_for_is_merge_flag() {
        let scores = score_merge_indicator(true, "some message");
        assert!(
            scores[Cat::Merge.index()] > 0.40,
            "merge indicator should give large Merge score"
        );
        assert!(
            scores[Cat::Feature.index()] < 0.0,
            "merge indicator should penalise Feature"
        );
    }

    #[test]
    fn merge_indicator_signal_zero_for_non_merge() {
        let scores = score_merge_indicator(false, "fix null pointer");
        for (i, &s) in scores.iter().enumerate() {
            assert_eq!(s, 0.0, "non-merge commit should produce 0 for cat {i}");
        }
    }

    #[test]
    fn file_paths_signal_zero_when_empty() {
        let scores = score_file_paths(&[]);
        for (i, &s) in scores.iter().enumerate() {
            assert_eq!(s, 0.0, "empty paths should produce 0 for cat {i}");
        }
    }

    #[test]
    fn file_paths_signal_tests_heavy_nudges_maintenance() {
        let paths: Vec<String> = vec![
            "tests/auth_test.rs".to_string(),
            "tests/payment_test.rs".to_string(),
            "tests/webhook_test.rs".to_string(),
            "src/lib.rs".to_string(),
        ];
        let scores = score_file_paths(&paths);
        assert!(
            scores[Cat::Maintenance.index()] > 0.0,
            "tests-heavy paths should boost Maintenance"
        );
    }

    #[test]
    fn file_paths_signal_docs_heavy_nudges_content() {
        let paths: Vec<String> = vec![
            "docs/api.md".to_string(),
            "docs/setup.md".to_string(),
            "README.md".to_string(),
        ];
        let scores = score_file_paths(&paths);
        assert!(
            scores[Cat::Content.index()] > 0.0,
            "docs-heavy paths should boost Content"
        );
    }

    // ---- end-to-end classification -------------------------------------------

    #[test]
    fn integration_fix_message_classifies_as_bugfix() {
        let clf = default_classifier();
        let result = clf.classify("fix: handle null user — fixes regression", false, &[]);
        assert!(result.is_some(), "expected a verdict for a bugfix message");
        let r = result.unwrap();
        assert_eq!(r.category, "bugfix", "expected bugfix category");
        assert!(
            r.confidence >= 0.55,
            "confidence should be >= 0.55, got {}",
            r.confidence
        );
        assert_eq!(r.method, ClassificationMethod::WeightedSum);
    }

    #[test]
    fn integration_merge_commit_classifies_as_merge() {
        let clf = default_classifier();
        let result = clf.classify("Merge pull request #42 from main", true, &[]);
        assert!(result.is_some(), "expected a verdict for a merge commit");
        let r = result.unwrap();
        assert_eq!(r.category, "merge");
        assert_eq!(r.method, ClassificationMethod::WeightedSum);
    }

    #[test]
    fn integration_feature_message_classifies_as_feature() {
        let clf = default_classifier();
        let result = clf.classify(
            "add new payment feature support with webhook integration",
            false,
            &[],
        );
        assert!(result.is_some(), "expected a verdict for a feature message");
        let r = result.unwrap();
        assert_eq!(r.category, "feature");
        assert!(r.confidence >= 0.55);
    }

    #[test]
    fn fall_through_when_no_signal_dominates() {
        let clf = default_classifier();
        // No keyword, ticket prefix, length bucket, merge marker or path fires
        // for this message, so every category stays at 0.0.
        let result = clf.classify("zzz qqq vvv www yyy uuu ppp rrr", false, &[]);
        assert!(
            result.is_none(),
            "a message with no signal must fall through without a verdict"
        );
    }

    #[test]
    fn argmax_tie_does_not_emit_verdict() {
        let clf = default_classifier();

        // All categories tie at 0.0.
        let result = clf.classify("xyzxyzxyz blah blah blah nothing here", false, &[]);
        assert!(
            result.is_none(),
            "an all-zero score vector must not emit a verdict"
        );

        // Two feature keywords and two integration keywords give both
        // categories the same score, which is above min_confidence; only the
        // tie check can suppress this verdict.
        let tied = clf.classify("implement feature webhook plugin", false, &[]);
        assert!(
            tied.is_none(),
            "a tie between the top categories must not emit a verdict"
        );
    }

    #[test]
    fn disabled_classifier_always_returns_none() {
        let clf = WeightedSumClassifier::new(WeightedSumConfig {
            enabled: false,
            ..WeightedSumConfig::default()
        });
        let result = clf.classify("fix: handle null pointer — critical bug", false, &[]);
        assert!(
            result.is_none(),
            "disabled classifier must always return None"
        );
    }

    #[test]
    fn integration_fix_with_test_paths_produces_bugfix_or_maintenance() {
        let clf = default_classifier();
        let paths = vec![
            "tests/auth_test.rs".to_string(),
            "tests/null_test.rs".to_string(),
        ];
        let result = clf.classify("fix bug: handle null pointer in auth module", false, &paths);
        assert!(result.is_some(), "expected a verdict");
        let r = result.unwrap();
        assert!(
            r.category == "bugfix" || r.category == "refactor",
            "expected bugfix or refactor, got: {}",
            r.category
        );
        assert!(r.confidence >= 0.55);
        assert_eq!(r.method, ClassificationMethod::WeightedSum);
    }

    #[test]
    fn emitted_confidence_stays_within_bounds() {
        let clf = default_classifier();
        // Every bugfix keyword is present, so the keyword signal saturates and
        // a verdict is expected; the bounds are then checked on that verdict.
        let result = clf.classify(
            "fix bug issue broken regression hotfix patch resolve repair correct",
            false,
            &[],
        );
        assert!(
            result.is_some(),
            "expected a verdict for a keyword-saturated message"
        );
        let r = result.unwrap();
        assert!(r.confidence >= 0.55, "below min_confidence floor");
        assert!(r.confidence <= 0.95, "above max confidence ceiling");
    }
}