use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use rayon::prelude::*;
use rusqlite::Connection;
use crate::classify::errors::Result;
use crate::classify::rules::RuleSet;
use crate::classify::taxonomy::{SubcategoryDef, TaxonomyRegistry};
use crate::classify::tiers::exact::ExactMatcher;
use crate::classify::tiers::fuzzy::FuzzyClassifier;
use crate::classify::tiers::issue_type_tier::IssueTypeTier;
use crate::classify::tiers::jira_project_tier::JiraProjectTier;
use crate::classify::tiers::llm::LlmClassifier;
use crate::classify::tiers::override_tier::OverrideTier;
use crate::classify::tiers::regex_tier::RegexMatcher;
use crate::classify::tiers::weighted_sum::WeightedSumClassifier;
use crate::classify::tiers::ClassificationResult;
use crate::core::models::ClassificationMethod;
/// Settings consumed when a [`ClassificationEngine`] is assembled.
#[derive(Debug, Clone)]
pub struct ClassificationEngineConfig {
/// When `true`, engine construction also builds an LLM classifier from the
/// `llm_provider`, `llm_model` and `openrouter_api_key` fields; when `false`
/// no LLM classifier is created.
pub use_llm: bool,
/// Model name handed to the LLM classifier (default: `"gpt-4o-mini"`).
pub llm_model: String,
/// Provider selector handed to `LlmClassifier::from_provider` (default: `"auto"`).
pub llm_provider: String,
/// Optional API key forwarded to `LlmClassifier::from_provider`.
// NOTE(review): the name suggests it is only meaningful for the OpenRouter
// provider — confirm against `LlmClassifier`.
pub openrouter_api_key: Option<String>,
/// Confidence cut-off (default: `0.7`).
// NOTE(review): not read by the engine code visible here; presumably consumed
// by callers through `ClassificationEngine::config()` — confirm.
pub confidence_threshold: f64,
/// Configuration cloned into the weighted-sum tier at construction time.
pub weighted_sum: crate::classify::tiers::weighted_sum::WeightedSumConfig,
}
impl Default for ClassificationEngineConfig {
fn default() -> Self {
Self {
use_llm: false,
llm_model: "gpt-4o-mini".to_string(),
llm_provider: "auto".to_string(),
openrouter_api_key: None,
confidence_threshold: 0.7,
weighted_sum: crate::classify::tiers::weighted_sum::WeightedSumConfig::default(),
}
}
}
/// Multi-tier message classifier.
///
/// The synchronous path (`classify_sync_with_context`) consults the tiers in
/// this order and returns the first verdict produced: override, exact,
/// issue type, Jira project, regex, weighted sum, fuzzy. Note that this
/// differs from the field order below (the Jira-project tier runs *before*
/// the regex tier). The LLM classifier is only used by the async methods.
pub struct ClassificationEngine {
// Per-commit override lookup; present only when an override connection was
// supplied at construction.
override_tier: Option<OverrideTier>,
// Exact-rule matcher built from the ruleset's rules.
exact: ExactMatcher,
// Classifies from an externally supplied issue-type string.
issue_type: IssueTypeTier,
// Regex-rule matcher built from the ruleset's rules.
regex: RegexMatcher,
// Project-mapping tier; skipped entirely when it reports `is_empty()`.
jira_project: JiraProjectTier,
// Weighted-sum classifier configured from `config.weighted_sum`.
weighted_sum: WeightedSumClassifier,
// Present only when the ruleset had `extend_defaults` set.
fuzzy: Option<FuzzyClassifier>,
// Present only when `config.use_llm` was set (or injected by the test helper).
llm: Option<LlmClassifier>,
// Registry used to resolve a category to its top-level value.
taxonomy: TaxonomyRegistry,
// The configuration this engine was built with.
config: ClassificationEngineConfig,
}
impl ClassificationEngine {
/// Builds an engine from `ruleset` and `config` with no custom taxonomy
/// entries.
///
/// # Errors
/// Propagates any error returned by [`Self::with_taxonomy`].
pub fn new(ruleset: RuleSet, config: ClassificationEngineConfig) -> Result<Self> {
    let no_custom_taxonomy = Vec::new();
    Self::with_taxonomy(ruleset, config, no_custom_taxonomy)
}
/// Builds an engine with custom taxonomy entries but no Jira project
/// mappings and no override connection.
///
/// # Errors
/// Propagates any error returned by [`Self::with_taxonomy_and_mappings`].
pub fn with_taxonomy(
    ruleset: RuleSet,
    config: ClassificationEngineConfig,
    custom_taxonomy: Vec<SubcategoryDef>,
) -> Result<Self> {
    let no_project_mappings = HashMap::new();
    let no_override_conn = None;
    Self::with_taxonomy_and_mappings(
        ruleset,
        config,
        custom_taxonomy,
        no_project_mappings,
        no_override_conn,
    )
}
/// Builds an engine with custom taxonomy entries, Jira project mappings and
/// an optional override connection, leaving the project-mapping confidence
/// unspecified so the builder falls back to its default.
///
/// # Errors
/// Propagates any error returned by
/// [`Self::with_taxonomy_mappings_and_confidence`].
pub fn with_taxonomy_and_mappings(
    ruleset: RuleSet,
    config: ClassificationEngineConfig,
    custom_taxonomy: Vec<SubcategoryDef>,
    jira_project_mappings: HashMap<String, String>,
    override_conn: Option<Arc<Mutex<Connection>>>,
) -> Result<Self> {
    // `None` means "use DEFAULT_PROJECT_MAPPING_CONFIDENCE" downstream.
    let unspecified_confidence = None;
    Self::with_taxonomy_mappings_and_confidence(
        ruleset,
        config,
        custom_taxonomy,
        jira_project_mappings,
        unspecified_confidence,
        override_conn,
    )
}
/// Fully parameterised constructor; every other constructor funnels here.
///
/// * `ruleset` – its rules feed the exact and regex matchers; the fuzzy tier
///   is enabled only when `ruleset.extend_defaults` is `true`.
/// * `config` – stored on the engine; `config.weighted_sum` configures the
///   weighted-sum tier and `config.use_llm` decides whether an LLM classifier
///   is built.
/// * `custom_taxonomy` – entries handed to the taxonomy registry.
/// * `jira_project_mappings` – mappings handed to the Jira-project tier.
/// * `jira_confidence` – confidence for the Jira-project tier; `None` selects
///   `DEFAULT_PROJECT_MAPPING_CONFIDENCE`.
/// * `override_conn` – when present, enables the override tier.
///
/// # Errors
/// Propagates construction errors from the exact and regex matchers, and
/// returns `ClassifyError::Config` when the LLM classifier cannot be built.
pub fn with_taxonomy_mappings_and_confidence(
    ruleset: RuleSet,
    config: ClassificationEngineConfig,
    custom_taxonomy: Vec<SubcategoryDef>,
    jira_project_mappings: HashMap<String, String>,
    jira_confidence: Option<f64>,
    override_conn: Option<Arc<Mutex<Connection>>>,
) -> Result<Self> {
    // Rule-driven tiers.
    let exact = ExactMatcher::new(&ruleset.rules)?;
    let regex = RegexMatcher::new(&ruleset.rules)?;
    let weighted_sum = WeightedSumClassifier::new(config.weighted_sum.clone());
    let fuzzy = ruleset.extend_defaults.then_some(FuzzyClassifier);

    // Optional LLM tier: a requested-but-unbuildable classifier is a hard
    // configuration error rather than a silent downgrade.
    let llm = if config.use_llm {
        let classifier = LlmClassifier::from_provider(
            &config.llm_provider,
            &config.llm_model,
            config.openrouter_api_key.clone(),
        )
        .map_err(|e| {
            crate::classify::errors::ClassifyError::Config(format!(
                "LLM provider init failed: {e}"
            ))
        })?;
        Some(classifier)
    } else {
        None
    };

    // Taxonomy-aware tiers each receive their own clone of the registry.
    let taxonomy = TaxonomyRegistry::new(custom_taxonomy);
    let project_confidence = jira_confidence
        .unwrap_or(crate::classify::tiers::jira_project_tier::DEFAULT_PROJECT_MAPPING_CONFIDENCE);
    let issue_type = IssueTypeTier::with_taxonomy(taxonomy.clone());
    let jira_project = JiraProjectTier::with_taxonomy_and_confidence(
        jira_project_mappings,
        taxonomy.clone(),
        project_confidence,
    );
    let override_tier =
        override_conn.map(|conn| OverrideTier::with_taxonomy(conn, taxonomy.clone()));

    Ok(Self {
        override_tier,
        exact,
        issue_type,
        regex,
        jira_project,
        weighted_sum,
        fuzzy,
        llm,
        taxonomy,
        config,
    })
}
/// Returns a shared reference to the taxonomy registry held by this engine.
pub fn taxonomy(&self) -> &TaxonomyRegistry {
&self.taxonomy
}
/// Test-only builder step: replaces the engine's LLM classifier with one that
/// uses the configured model name, the fixed key `"sk-test"` and the given
/// `endpoint`, then returns the engine.
#[cfg(test)]
pub(crate) fn with_test_llm_endpoint(mut self, endpoint: &str) -> Self {
    let api_key = Some("sk-test".to_string());
    let stub = LlmClassifier::new(&self.config.llm_model, api_key).with_endpoint(endpoint);
    self.llm = Some(stub);
    self
}
/// Returns a shared reference to the configuration this engine was built with.
pub fn config(&self) -> &ClassificationEngineConfig {
&self.config
}
/// Classifies `message` using only the synchronous tiers and without any
/// commit context (no SHA, repository path or issue type).
///
/// Returns `None` when no synchronous tier produces a verdict.
pub fn classify_sync(&self, message: &str, is_merge: bool) -> Option<ClassificationResult> {
    let (commit_sha, repo_path, issue_type) = (None, None, None);
    self.classify_sync_with_context(message, is_merge, commit_sha, repo_path, issue_type)
}
/// Runs the synchronous tiers in priority order and returns the first
/// verdict, or `None` when every tier declines.
///
/// Order: override (needs an override tier plus both `commit_sha` and
/// `repo_path`) → exact rule → issue type (needs `issue_type`) → Jira project
/// mapping (skipped when the tier is empty) → regex rule → weighted sum →
/// fuzzy (only when enabled).
///
/// Each `or_else` closure runs only if every earlier tier returned `None`,
/// so lower-priority tiers are never evaluated unnecessarily.
pub fn classify_sync_with_context(
    &self,
    message: &str,
    is_merge: bool,
    commit_sha: Option<&str>,
    repo_path: Option<&str>,
    issue_type: Option<&str>,
) -> Option<ClassificationResult> {
    let overridden = match (self.override_tier.as_ref(), commit_sha, repo_path) {
        (Some(tier), Some(sha), Some(repo)) => tier.lookup(sha, repo),
        _ => None,
    };

    overridden
        .or_else(|| {
            self.exact.classify(message).map(|rule| ClassificationResult {
                top_level: self.taxonomy.resolve(&rule.category),
                category: rule.category.clone(),
                subcategory: rule.subcategory.clone(),
                confidence: rule.confidence,
                method: ClassificationMethod::ExactRule,
                ticket_id: RegexMatcher::extract_ticket_id(message),
                complexity: None,
            })
        })
        .or_else(|| {
            let mut verdict = self.issue_type.classify(issue_type?)?;
            // This tier's ticket id is always taken from the message.
            verdict.ticket_id = RegexMatcher::extract_ticket_id(message);
            Some(verdict)
        })
        .or_else(|| {
            if self.jira_project.is_empty() {
                None
            } else {
                self.jira_project.classify(message)
            }
        })
        .or_else(|| {
            self.regex.classify(message).map(|rule| ClassificationResult {
                top_level: self.taxonomy.resolve(&rule.category),
                category: rule.category.clone(),
                subcategory: rule.subcategory.clone(),
                confidence: rule.confidence,
                method: ClassificationMethod::RegexRule,
                ticket_id: RegexMatcher::extract_ticket_id(message),
                complexity: None,
            })
        })
        .or_else(|| {
            self.weighted_sum
                .classify(message, is_merge, &[])
                .map(|verdict| self.backfill_ticket_and_top_level(verdict, message))
        })
        .or_else(|| {
            self.fuzzy
                .as_ref()?
                .classify(message, is_merge)
                .map(|verdict| self.backfill_ticket_and_top_level(verdict, message))
        })
}

/// Shared post-processing for the weighted-sum and fuzzy tiers: fills in a
/// missing ticket id from `message`, and overwrites `top_level` only when the
/// taxonomy can resolve the verdict's category.
fn backfill_ticket_and_top_level(
    &self,
    mut verdict: ClassificationResult,
    message: &str,
) -> ClassificationResult {
    if verdict.ticket_id.is_none() {
        verdict.ticket_id = RegexMatcher::extract_ticket_id(message);
    }
    if let Some(top) = self.taxonomy.resolve(&verdict.category) {
        verdict.top_level = Some(top);
    }
    verdict
}
/// Classifies `message`, always producing a result.
///
/// Tries the synchronous tiers first, then the LLM classifier; if both
/// decline, returns an unclassified result whose ticket id is extracted
/// from the message.
pub async fn classify(&self, message: &str, is_merge: bool) -> ClassificationResult {
    if let Some(local) = self.classify_sync(message, is_merge) {
        return local;
    }
    match self.llm_classify_only(message).await {
        Some(remote) => remote,
        None => {
            let mut unclassified = ClassificationResult::unclassified();
            unclassified.ticket_id = RegexMatcher::extract_ticket_id(message);
            unclassified
        }
    }
}
/// Asks only the LLM classifier about `message`.
///
/// Returns `None` when no LLM classifier is configured or when it yields no
/// verdict. On success, `top_level` is replaced by the taxonomy's resolution
/// of the category (which may be `None`), and a missing ticket id is filled
/// in from the message.
pub async fn llm_classify_only(&self, message: &str) -> Option<ClassificationResult> {
    let classifier = self.llm.as_ref()?;
    let mut verdict = classifier.classify(message).await?;
    verdict.top_level = self.taxonomy.resolve(&verdict.category);
    // Keep an id the classifier supplied; otherwise fall back to extraction.
    verdict.ticket_id = verdict
        .ticket_id
        .take()
        .or_else(|| RegexMatcher::extract_ticket_id(message));
    Some(verdict)
}
pub fn llm_has_api_key(&self) -> Option<bool> {
self.llm.as_ref().map(LlmClassifier::has_api_key)
}
pub fn classify_batch(&self, messages: &[(&str, bool)]) -> Vec<ClassificationResult> {
messages
.par_iter()
.map(|(msg, is_merge)| {
self.classify_sync(msg, *is_merge)
.unwrap_or_else(ClassificationResult::unclassified)
})
.collect()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::classify::rules::default_rules;
// A configured project mapping must win over the generic regex tier: the
// verdict carries the mapped category, the default mapping confidence (0.88)
// and the ticket id taken from the message.
#[test]
fn jira_project_mapping_outranks_generic_ticket_regex() {
    let mappings = HashMap::from([("TQL".to_string(), "bug_fix".to_string())]);
    let built = ClassificationEngine::with_taxonomy_and_mappings(
        default_rules(),
        ClassificationEngineConfig::default(),
        Vec::new(),
        mappings,
        None,
    );
    let engine = built.expect("engine builds");

    let verdict = engine.classify_sync("TQL-1234 fix null pointer", false);
    let verdict = verdict.expect("verdict");

    assert_eq!(verdict.category, "bug_fix");
    let drift = (verdict.confidence - 0.88).abs();
    assert!(drift < 1e-6);
    assert_eq!(verdict.ticket_id.as_deref(), Some("TQL-1234"));
}
// An explicit mapping confidence passed to the builder must be the
// confidence reported on the Jira-project verdict.
#[test]
fn jira_project_mapping_confidence_threads_through_engine_builder() {
    let mappings = HashMap::from([("INFRA".to_string(), "platform".to_string())]);
    let custom_confidence = Some(0.5);
    let built = ClassificationEngine::with_taxonomy_mappings_and_confidence(
        default_rules(),
        ClassificationEngineConfig::default(),
        Vec::new(),
        mappings,
        custom_confidence,
        None,
    );
    let engine = built.expect("engine builds");

    let verdict = engine.classify_sync("INFRA-7 patch", false);
    let verdict = verdict.expect("verdict");

    let drift = (verdict.confidence - 0.5).abs();
    assert!(drift < 1e-6);
}
// With `extend_defaults` off, whatever verdict (if any) a merge message gets
// must not come from the fuzzy tier. A `None` verdict is also acceptable.
#[test]
fn fuzzy_tier_suppressed_when_extend_defaults_false() {
    use crate::classify::rules::{Rule, RuleSet};

    // A single custom rule that does not match the merge message below.
    let only_rule = Rule {
        id: "my-deploy".to_string(),
        category: "deployment".to_string(),
        subcategory: None,
        keywords: vec!["deploy:".to_string()],
        patterns: vec![],
        priority: 110,
        confidence: 0.9,
    };
    let ruleset = RuleSet {
        version: None,
        extend_defaults: false,
        rules: vec![only_rule],
    };
    let engine = ClassificationEngine::new(ruleset, ClassificationEngineConfig::default())
        .expect("engine builds");

    let result = engine.classify_sync("Merge pull request #42 from main", true);
    if let Some(r) = result.as_ref() {
        assert_ne!(
            r.method,
            ClassificationMethod::FuzzyMatch,
            "fuzzy tier must not fire when extend_defaults is false; got: {result:?}"
        );
    }
}
// With `extend_defaults` on, a merge message must be classified as "merge".
// NOTE(review): only the category is asserted, not `method`, so this does not
// by itself prove that the fuzzy tier (rather than an earlier tier) produced
// the verdict.
#[test]
fn fuzzy_tier_active_when_extend_defaults_true() {
    let mut ruleset = default_rules();
    ruleset.extend_defaults = true;
    let engine = ClassificationEngine::new(ruleset, ClassificationEngineConfig::default())
        .expect("engine builds");

    let verdict = engine
        .classify_sync("Merge pull request #42 from main", true)
        .expect("fuzzy tier must fire for merge commits when extend_defaults is true");

    assert_eq!(verdict.category, "merge");
}
// Even when the message's project key has a mapping, an exact-rule match must
// take precedence over the Jira-project tier.
#[test]
fn exact_rule_still_beats_jira_project_mapping() {
    let mappings = HashMap::from([("TQL".to_string(), "platform".to_string())]);
    let built = ClassificationEngine::with_taxonomy_and_mappings(
        default_rules(),
        ClassificationEngineConfig::default(),
        Vec::new(),
        mappings,
        None,
    );
    let engine = built.expect("engine builds");

    let verdict = engine.classify_sync("fix: TQL-1 handle null user", false);
    let verdict = verdict.expect("verdict");

    assert_eq!(verdict.category, "bugfix");
    assert_eq!(verdict.method, ClassificationMethod::ExactRule);
}
}