use crate::classify::taxonomy::TopLevelCategory;
use crate::classify::tiers::ClassificationResult;
use crate::core::models::ClassificationMethod;
/// Internal score buckets used by the weighted-sum classifier.
///
/// Each variant's explicit discriminant is its slot in every
/// `[f32; NUM_CATS]` score array (see `Cat::index`), so the variant order
/// must stay in sync with `Cat::ALL`, `KEYWORD_BAGS` and `MERGE_WEIGHTS`.
/// `Merge` is scored as its own bucket but reported under the Maintenance
/// top-level category (see `Cat::to_verdict`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum Cat {
Feature = 0,
Bugfix = 1,
Ktlo = 2,
Integrations = 3,
PlatformWork = 4,
Content = 5,
Maintenance = 6,
Merge = 7,
}
/// Number of `Cat` variants, i.e. the length of every per-category score array.
pub(crate) const NUM_CATS: usize = 8;
impl Cat {
    /// Every category in discriminant order, so `Cat::ALL[c.index()] == c`
    /// holds for every variant `c`.
    const ALL: [Cat; NUM_CATS] = [
        Self::Feature,
        Self::Bugfix,
        Self::Ktlo,
        Self::Integrations,
        Self::PlatformWork,
        Self::Content,
        Self::Maintenance,
        Self::Merge,
    ];

    /// Slot of this category inside a `[f32; NUM_CATS]` score array.
    pub(crate) fn index(self) -> usize {
        self as usize
    }

    /// Maps an internal bucket to the (label, top-level category) pair that
    /// ends up in the `ClassificationResult`.
    fn to_verdict(self) -> (&'static str, TopLevelCategory) {
        let label = match self {
            Self::Feature => "feature",
            Self::Bugfix => "bugfix",
            Self::Ktlo => "chore",
            Self::Integrations => "integration",
            Self::PlatformWork => "platform",
            Self::Content => "docs",
            Self::Maintenance => "refactor",
            Self::Merge => "merge",
        };
        let top_level = match self {
            Self::Feature => TopLevelCategory::Feature,
            Self::Bugfix => TopLevelCategory::Bugfix,
            Self::Ktlo => TopLevelCategory::Ktlo,
            Self::Integrations => TopLevelCategory::Integrations,
            Self::PlatformWork => TopLevelCategory::PlatformWork,
            Self::Content => TopLevelCategory::Content,
            // Merges have no dedicated top-level bucket; they roll up into
            // Maintenance alongside refactors.
            Self::Maintenance | Self::Merge => TopLevelCategory::Maintenance,
        };
        (label, top_level)
    }
}
/// Score adjustment returned by `score_merge_indicator` for merge commits:
/// a strong boost for the last slot (`Cat::Merge`, index 7) and a small
/// penalty for every other category. Indexed by `Cat::index()`.
static MERGE_WEIGHTS: &[f32; NUM_CATS] = &[
-0.05, -0.05, -0.05, -0.05, -0.05, -0.05, -0.05, 0.65,
];
// Per-category keyword lists, looked up through `KEYWORD_BAGS` by
// `score_keywords`. Keywords must be lowercase: the classifier lowercases
// the commit message before matching.

/// Keywords suggesting new functionality (`Cat::Feature`).
const FEATURE_KEYWORDS: &[&str] = &[
"add",
"implement",
"feature",
"support",
"introduce",
"create",
"build",
"new",
"extend",
"enable",
];
/// Keywords suggesting a defect fix (`Cat::Bugfix`).
const BUGFIX_KEYWORDS: &[&str] = &[
"fix",
"bug",
"issue",
"broken",
"regression",
"hotfix",
"patch",
"resolve",
"repair",
"correct",
];
/// Keywords suggesting routine upkeep, release and CI work (`Cat::Ktlo`).
/// NOTE: "build" also appears in `FEATURE_KEYWORDS`, so it counts toward both.
const KTLO_KEYWORDS: &[&str] = &[
"chore", "ci", "build", "ops", "release", "version", "bump", "update", "upgrade", "automate",
];
/// Keywords suggesting work on external integrations (`Cat::Integrations`).
const INTEGRATION_KEYWORDS: &[&str] = &[
"integrate",
"api",
"webhook",
"sdk",
"plugin",
"connector",
"bridge",
"endpoint",
"external",
"third-party",
];
/// Keywords suggesting performance / infrastructure work (`Cat::PlatformWork`).
const PLATFORM_KEYWORDS: &[&str] = &[
"perf",
"performance",
"infra",
"infrastructure",
"architecture",
"devops",
"deploy",
"scale",
"optimize",
"database",
];
/// Keywords suggesting documentation or content changes (`Cat::Content`).
const CONTENT_KEYWORDS: &[&str] = &[
"docs",
"readme",
"documentation",
"comment",
"typo",
"copy",
"translation",
"locale",
"i18n",
"asset",
];
/// Keywords suggesting refactoring, dependency, lint and test housekeeping
/// (`Cat::Maintenance`).
const MAINTENANCE_KEYWORDS: &[&str] = &[
"refactor",
"cleanup",
"rename",
"deps",
"dependency",
"style",
"lint",
"format",
"test",
"remove",
];
static KEYWORD_BAGS: &[&[&str]] = &[
FEATURE_KEYWORDS, BUGFIX_KEYWORDS, KTLO_KEYWORDS, INTEGRATION_KEYWORDS, PLATFORM_KEYWORDS, CONTENT_KEYWORDS, MAINTENANCE_KEYWORDS, &[], ];
/// Scores every category by how many of its keywords appear in `lower`.
///
/// `lower` is the already-lowercased commit message. A keyword only counts
/// when it starts at a word boundary: the start of the text, or right after
/// a non-alphanumeric character. It may run on into a longer word, so
/// "fix" matches "fixes"/"fixed", "test" matches "tests" and "bump" matches
/// "bumps". Short keywords no longer fire inside unrelated words ("ci" in
/// "specific", "bug" in "debug", "version" in "conversion", "patch" in
/// "dispatch"). Compound words count only for the keyword they start with
/// ("bugfix" counts as "bug").
///
/// Each keyword counts at most once. 1 / 2 / 3+ distinct hits map to
/// 0.40 / 0.60 / 0.75. Categories with an empty bag (`Cat::Merge`) stay 0.0.
pub(crate) fn score_keywords(lower: &str) -> [f32; NUM_CATS] {
    /// True when `needle` occurs in `haystack` starting at a word boundary.
    fn occurs_at_word_start(haystack: &str, needle: &str) -> bool {
        let mut prev_is_word_char = false;
        for (idx, ch) in haystack.char_indices() {
            if !prev_is_word_char && haystack[idx..].starts_with(needle) {
                return true;
            }
            prev_is_word_char = ch.is_alphanumeric();
        }
        false
    }

    let mut out = [0.0f32; NUM_CATS];
    for cat in Cat::ALL {
        let idx = cat.index();
        let bag = KEYWORD_BAGS[idx];
        if bag.is_empty() {
            continue;
        }
        let matched = bag
            .iter()
            .filter(|&&kw| occurs_at_word_start(lower, kw))
            .count();
        out[idx] = match matched {
            0 => 0.0,
            1 => 0.40,
            2 => 0.60,
            _ => 0.75,
        };
    }
    out
}
/// Adds a small bump to every category when `message` opens with a
/// Jira-style ticket key such as `ABC-123:`.
///
/// The same weight (0.05) goes to every slot, so on its own this signal does
/// not favour any particular category; otherwise all slots are 0.0.
pub(crate) fn score_ticket_prefix(message: &str) -> [f32; NUM_CATS] {
    let weight: f32 = if has_jira_style_prefix(message) { 0.05 } else { 0.0 };
    [weight; NUM_CATS]
}
/// Nudges categories based on the length of the (trimmed) commit message.
///
/// * Fewer than 12 characters: favours Ktlo and Merge (+0.10) and
///   Maintenance (+0.05), penalises Feature (-0.05) and Bugfix (-0.03).
/// * More than 80 characters: favours Feature and PlatformWork (+0.10),
///   Maintenance and Bugfix (+0.05).
/// * Anything in between contributes nothing.
pub(crate) fn score_message_length(trimmed: &str) -> [f32; NUM_CATS] {
    // Measure characters, not UTF-8 bytes. Otherwise non-ASCII messages
    // (e.g. CJK text at 3 bytes per character) are treated as 2-3x longer
    // than they are, which skews both thresholds.
    let len = trimmed.chars().count();
    let mut out = [0.0f32; NUM_CATS];
    if len < 12 {
        out[Cat::Ktlo.index()] = 0.10;
        out[Cat::Merge.index()] = 0.10;
        out[Cat::Maintenance.index()] = 0.05;
        out[Cat::Feature.index()] = -0.05;
        out[Cat::Bugfix.index()] = -0.03;
    } else if len > 80 {
        out[Cat::Feature.index()] = 0.10;
        out[Cat::PlatformWork.index()] = 0.10;
        out[Cat::Maintenance.index()] = 0.05;
        out[Cat::Bugfix.index()] = 0.05;
    }
    out
}
/// Returns `MERGE_WEIGHTS` for merge commits and all zeros otherwise.
///
/// A commit counts as a merge when the caller flags it (`is_merge`) or when
/// the lowercased message opens with one of the prefixes below. The final
/// `"merge "` entry matches any subject that starts with the word "merge"
/// followed by a space.
pub(crate) fn score_merge_indicator(is_merge: bool, lower: &str) -> [f32; NUM_CATS] {
    const MERGE_PREFIXES: [&str; 4] = [
        "merge pull request",
        "merge branch",
        "merge remote-tracking",
        "merge ",
    ];
    let message_says_merge = MERGE_PREFIXES
        .iter()
        .any(|&prefix| lower.starts_with(prefix));
    if is_merge || message_says_merge {
        *MERGE_WEIGHTS
    } else {
        [0.0; NUM_CATS]
    }
}
/// Scores categories from the set of paths a commit touches.
///
/// Three kinds of path are recognised. Each contributes only when it makes
/// up at least half of the changed paths, and its boost scales with that
/// share:
///
/// * tests: a `test/`, `tests/` or `spec/` directory, or a `_test.rs`,
///   `_spec.rb`, `.test.ts` or `.spec.ts` suffix. Boosts Maintenance
///   (+0.20 * ratio) and Bugfix (+0.10 * ratio).
/// * documentation: a `doc/` or `docs/` directory, or a `.md`, `.rst` or
///   `.txt` file. Boosts Content (+0.20 * ratio).
/// * dependency/build manifests (`Cargo.toml`, `package.json`, ...). Boosts
///   Ktlo (+0.20 * ratio) and Maintenance (+0.10 * ratio).
///
/// Directory markers must begin at the start of the path or right after a
/// non-alphanumeric character. So `src/tests/`, `unit-tests/` and
/// `api_test/` count, but `latest/`, `contest/` and `redoc/` do not. A
/// manifest is never counted as documentation just because of its
/// extension, so `requirements.txt` boosts Ktlo/Maintenance only.
///
/// An empty `paths` slice yields all zeros.
pub(crate) fn score_file_paths(paths: &[String]) -> [f32; NUM_CATS] {
    /// True when `marker` (a directory name ending in `/`) occurs in `path`
    /// at the very start or right after a non-alphanumeric character.
    fn has_dir_marker(path: &str, marker: &str) -> bool {
        // Markers start with a letter and end with '/', so they cannot
        // overlap themselves; non-overlapping `match_indices` sees every hit.
        path.match_indices(marker).any(|(idx, _)| {
            !matches!(path[..idx].chars().next_back(), Some(c) if c.is_alphanumeric())
        })
    }

    /// True for well-known dependency/build manifest file names.
    fn is_manifest(path: &str) -> bool {
        let name = path.split('/').next_back().unwrap_or(path);
        matches!(
            name,
            "Cargo.toml"
                | "package.json"
                | "pyproject.toml"
                | "requirements.txt"
                | "Gemfile"
                | "pom.xml"
                | "build.gradle"
                | "go.mod"
                | "Pipfile"
                | "setup.py"
                | "composer.json"
        )
    }

    /// True for paths that look like test code.
    fn is_test(path: &str) -> bool {
        ["tests/", "test/", "spec/"]
            .iter()
            .any(|&marker| has_dir_marker(path, marker))
            || ["_test.rs", "_spec.rb", ".test.ts", ".spec.ts"]
                .iter()
                .any(|&suffix| path.ends_with(suffix))
    }

    /// True for paths that look like documentation.
    fn is_doc(path: &str) -> bool {
        has_dir_marker(path, "docs/")
            || has_dir_marker(path, "doc/")
            || (!is_manifest(path)
                && [".md", ".rst", ".txt"]
                    .iter()
                    .any(|&ext| path.ends_with(ext)))
    }

    if paths.is_empty() {
        return [0.0; NUM_CATS];
    }
    let total = paths.len() as f32;
    // Fraction of changed paths for which `is_kind` holds.
    let share = |is_kind: fn(&str) -> bool| -> f32 {
        paths.iter().filter(|p| is_kind(p.as_str())).count() as f32 / total
    };

    let mut out = [0.0f32; NUM_CATS];
    let test_ratio = share(is_test);
    if test_ratio >= 0.5 {
        out[Cat::Maintenance.index()] += test_ratio * 0.20;
        out[Cat::Bugfix.index()] += test_ratio * 0.10;
    }
    let docs_ratio = share(is_doc);
    if docs_ratio >= 0.5 {
        out[Cat::Content.index()] += docs_ratio * 0.20;
    }
    let manifest_ratio = share(is_manifest);
    if manifest_ratio >= 0.5 {
        out[Cat::Ktlo.index()] += manifest_ratio * 0.20;
        out[Cat::Maintenance.index()] += manifest_ratio * 0.10;
    }
    out
}
/// Reports whether the first whitespace-separated token of `message` is a
/// Jira-style ticket key.
///
/// A ticket key is a project key (an ASCII uppercase letter followed by
/// uppercase letters or digits), exactly one `-`, and a non-empty run of
/// ASCII digits. Trailing `:`, `-` and `,` on the token are ignored. So
/// `ABC-123`, `ABC-123:` and `PROJ2-7,` qualify, while `abc-123`, `ABC123`
/// and `ABC-12-3` do not.
fn has_jira_style_prefix(message: &str) -> bool {
    let Some(token) = message.split_whitespace().next() else {
        return false;
    };
    let token = token.trim_end_matches([':', '-', ',']);
    let Some((project, number)) = token.split_once('-') else {
        return false;
    };
    // A second dash means this is not a plain KEY-NUMBER pair.
    if number.contains('-') {
        return false;
    }
    let mut project_chars = project.chars();
    let starts_upper = matches!(project_chars.next(), Some(c) if c.is_ascii_uppercase());
    starts_upper
        && project_chars.all(|c| c.is_ascii_uppercase() || c.is_ascii_digit())
        && !number.is_empty()
        && number.bytes().all(|b| b.is_ascii_digit())
}
/// Configuration for [`WeightedSumClassifier`].
///
/// Deserialisation rejects unknown keys (`deny_unknown_fields`). Missing keys
/// fall back to `default_weighted_sum_enabled` (`true`) and
/// `default_weighted_sum_min_confidence` (`0.55`).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(deny_unknown_fields)]
pub struct WeightedSumConfig {
/// When `false`, `WeightedSumClassifier::classify` returns `None` immediately.
#[serde(default = "default_weighted_sum_enabled")]
pub enabled: bool,
/// Minimum summed score the winning category needs before a verdict is
/// returned. It is also the floor of the reported confidence.
#[serde(default = "default_weighted_sum_min_confidence")]
pub min_confidence: f32,
}
/// Serde default for `WeightedSumConfig::enabled`: the classifier is on
/// unless explicitly disabled.
fn default_weighted_sum_enabled() -> bool {
    const ENABLED_BY_DEFAULT: bool = true;
    ENABLED_BY_DEFAULT
}

/// Serde default for `WeightedSumConfig::min_confidence`.
fn default_weighted_sum_min_confidence() -> f32 {
    const MIN_CONFIDENCE: f32 = 0.55;
    MIN_CONFIDENCE
}

/// Written by hand rather than derived, so that `Default` agrees with the
/// serde field defaults instead of producing `false` / `0.0`.
impl Default for WeightedSumConfig {
    fn default() -> Self {
        let enabled = default_weighted_sum_enabled();
        let min_confidence = default_weighted_sum_min_confidence();
        Self {
            enabled,
            min_confidence,
        }
    }
}
/// Commit classifier that adds up several independent heuristic signals
/// per category and reports the single highest-scoring category.
///
/// Signals: message keywords, a Jira-style ticket prefix, message length,
/// merge detection, and the paths touched by the commit.
pub struct WeightedSumClassifier {
    pub(crate) config: WeightedSumConfig,
}

impl WeightedSumClassifier {
    /// Builds a classifier from `config`.
    pub fn new(config: WeightedSumConfig) -> Self {
        Self { config }
    }

    /// Classifies a single commit.
    ///
    /// Returns `None` in any of these cases:
    /// * the classifier is disabled;
    /// * two or more categories share the top score;
    /// * the top score is not positive;
    /// * the top score is below `config.min_confidence`.
    ///
    /// Otherwise it returns the winning category. Its confidence is the
    /// winning score, raised to at least `min_confidence` and capped at 0.95.
    pub fn classify(
        &self,
        message: &str,
        is_merge: bool,
        paths: &[String],
    ) -> Option<ClassificationResult> {
        if !self.config.enabled {
            return None;
        }
        let trimmed = message.trim();
        let lower = trimmed.to_lowercase();

        // Per-category contributions, summed in this fixed order (keywords,
        // ticket, length, merge, paths) so the float totals are reproducible.
        let signals = [
            score_keywords(&lower),
            score_ticket_prefix(trimmed),
            score_message_length(trimmed),
            score_merge_indicator(is_merge, &lower),
            score_file_paths(paths),
        ];
        let scores = signals
            .iter()
            .fold([0.0f32; NUM_CATS], |mut acc, signal| {
                for (total, delta) in acc.iter_mut().zip(signal) {
                    *total += delta;
                }
                acc
            });

        // All scores are finite, so a plain max fold finds the top score.
        let best_score = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let mut leaders = scores
            .iter()
            .enumerate()
            .filter(|&(_, &score)| score == best_score);
        let (best_cat_idx, _) = leaders.next()?;
        // A shared top score is ambiguous, and a non-positive one is no signal.
        if leaders.next().is_some() || best_score <= 0.0 {
            return None;
        }

        let floor = f64::from(self.config.min_confidence);
        let best = f64::from(best_score);
        if best < floor {
            return None;
        }

        let (category, top_level) = Cat::ALL[best_cat_idx].to_verdict();
        Some(ClassificationResult {
            category: category.to_string(),
            subcategory: None,
            top_level: Some(top_level),
            confidence: best.max(floor).min(0.95),
            method: ClassificationMethod::WeightedSum,
            ticket_id: None,
            complexity: None,
        })
    }
}