use std::collections::HashMap;
use futures::stream::StreamExt;
use indicatif::{ProgressBar, ProgressStyle};
use rusqlite::params;
use tracing::{info, warn};
use crate::classify::classifier::{ClassificationEngine, ClassificationEngineConfig};
use crate::classify::errors::Result;
use crate::classify::rules::{default_rules, load_rules};
use crate::classify::sources::ExternalSourceResolver;
use crate::classify::tiers::ClassificationResult;
use crate::core::config::Config;
use crate::core::db::{CheckpointMode, Database};
use crate::core::models::ClassificationMethod;
/// Coverage threshold (in percent) used by `ClassificationPipeline::min_coverage_pct`
/// when the configuration has no `classification` section.
// NOTE(review): this constant is read by `min_coverage_pct`, so the
// `dead_code` allowance looks unnecessary — confirm before removing it.
#[allow(dead_code)]
const DEFAULT_MIN_COVERAGE_PCT: f64 = 20.0;
/// Aggregate counters produced by one classification run.
#[derive(Debug, Clone, Default)]
pub struct ClassificationStats {
/// Number of candidate commits handed to the writer for this run.
pub total_commits: usize,
/// Commits whose verdict has a non-empty category other than
/// `uncategorized` and a confidence greater than zero.
pub classified: usize,
/// Number of verdicts per classification method name.
pub by_method: HashMap<String, usize>,
/// Number of verdicts per category.
pub by_category: HashMap<String, usize>,
/// `classified / total_commits` as a percentage; `0.0` when there are no commits.
pub coverage_pct: f64,
/// Per-repository coverage, keyed by the commit's `repository` value.
pub coverage_by_repo: HashMap<String, RepoCoverage>,
}
/// Coverage counters for a single repository.
#[derive(Debug, Clone, Default)]
pub struct RepoCoverage {
/// Commits of this repository that were written in the run.
pub total: usize,
/// Subset of `total` that counted as classified.
pub classified: usize,
/// `classified / total` as a percentage; `0.0` when `total` is zero.
pub coverage_pct: f64,
}
/// Orchestrates a classification run: selects candidate commits, classifies
/// them, and writes the verdicts and coverage back to the database.
pub struct ClassificationPipeline {
/// Application configuration; the `classification`, `jira` and `llm`
/// sections are consulted when building the engine and running.
config: Config,
/// When `false`, only commits whose `classification_id` is NULL are selected.
force: bool,
/// Optional inclusive lower bound compared against `commits.timestamp`.
since: Option<String>,
/// Optional inclusive upper bound compared against `commits.timestamp`.
until: Option<String>,
/// Repositories to restrict the run to; empty means no restriction.
repos: Vec<String>,
}
impl ClassificationPipeline {
/// Creates a pipeline for `config` with `force` disabled, no time bounds
/// and no repository restriction.
pub fn new(config: Config) -> Self {
    let force = false;
    let (since, until) = (None, None);
    let repos = Vec::new();
    Self {
        config,
        force,
        since,
        until,
        repos,
    }
}
pub fn with_force(mut self, force: bool) -> Self {
self.force = force;
self
}
pub fn with_since(mut self, since: Option<String>) -> Self {
self.since = since;
self
}
pub fn with_until(mut self, until: Option<String>) -> Self {
self.until = until;
self
}
pub fn with_repos(mut self, repos: Vec<String>) -> Self {
self.repos = repos;
self
}
/// Builds the classification engine described by the configuration.
///
/// The rule set comes from `classification.rules_file` when set (optionally
/// layered over the defaults, with custom rules replacing defaults that
/// share an id), otherwise from `default_rules()`. The engine itself is
/// always constructed with `use_llm: false`; an LLM classifier is attached
/// afterwards only when `classification.use_llm` is true.
///
/// # Errors
/// Returns an error when the rules file cannot be loaded, the engine cannot
/// be constructed, the LLM provider fails to initialise, or the LLM tier is
/// enabled but the classifier reports no API key.
async fn build_engine(&self) -> Result<ClassificationEngine> {
    use crate::classify::errors::ClassifyError;
    use crate::classify::tiers::llm::LlmClassifier;

    let classification = self.config.classification.as_ref();
    let jira = self.config.jira.as_ref();

    let ruleset = if let Some(path) = classification.and_then(|c| c.rules_file.as_ref()) {
        let custom = load_rules(path)?;
        if custom.extend_defaults {
            // Custom rules win over built-in rules that share the same id.
            let mut merged = default_rules();
            let overridden: std::collections::HashSet<String> =
                custom.rules.iter().map(|rule| rule.id.clone()).collect();
            merged.rules.retain(|rule| !overridden.contains(&rule.id));
            merged.rules.extend(custom.rules);
            merged
        } else {
            custom
        }
    } else {
        default_rules()
    };

    let use_llm = classification.map_or(false, |c| c.use_llm);
    let engine_cfg = classification.map_or_else(ClassificationEngineConfig::default, |c| {
        ClassificationEngineConfig {
            use_llm: c.use_llm,
            llm_model: c.llm_model.clone().unwrap_or_else(|| "gpt-4o-mini".into()),
            llm_provider: c.llm_provider.clone(),
            openrouter_api_key: c.openrouter_api_key.clone(),
            confidence_threshold: c.confidence_threshold,
            weighted_sum: c.weighted_sum.clone(),
        }
    });
    let custom_taxonomy = classification
        .map(|c| c.custom_categories.clone())
        .unwrap_or_default();
    let jira_mappings = jira
        .map(|j| j.jira_project_mappings.clone())
        .unwrap_or_default();
    let jira_confidence = jira.and_then(|j| j.jira_project_mapping_confidence);

    // The engine is built without the LLM tier; the classifier is attached
    // explicitly below once it has been initialised and validated.
    let engine_cfg_no_llm = ClassificationEngineConfig {
        use_llm: false,
        ..engine_cfg.clone()
    };
    let mut engine = ClassificationEngine::with_taxonomy_mappings_and_confidence(
        ruleset,
        engine_cfg_no_llm,
        custom_taxonomy,
        jira_mappings,
        jira_confidence,
        None,
    )?;
    if !use_llm {
        return Ok(engine);
    }

    let llm_classifier = match self.config.llm.as_ref() {
        // Preferred path: the top-level `llm:` section.
        Some(llm_cfg) => {
            let model = llm_cfg
                .model
                .as_deref()
                .or_else(|| classification.and_then(|c| c.llm_model.as_deref()))
                .unwrap_or("gpt-4o-mini");
            LlmClassifier::from_llm_config(llm_cfg, model)
                .await
                .map_err(|e| {
                    ClassifyError::Config(format!(
                        "LLM provider init failed (llm: section): {e}"
                    ))
                })?
        }
        // Legacy path: provider settings under `classification:`.
        None => {
            let uses_legacy_keys = classification.map_or(false, |c| {
                c.openrouter_api_key.is_some() || c.llm_provider != "auto"
            });
            if uses_legacy_keys {
                warn!(
                    "classification.openrouter_api_key / classification.llm_provider \
                     are deprecated. Migrate to the top-level `llm:` section: \
                     'llm:\\n source: openrouter\\n api_key_env: OPENROUTER_API_KEY'"
                );
            }
            LlmClassifier::from_provider_async(
                &engine_cfg.llm_provider,
                &engine_cfg.llm_model,
                engine_cfg.openrouter_api_key.clone(),
            )
            .await
            .map_err(|e| ClassifyError::Config(format!("LLM provider init failed: {e}")))?
        }
    };

    if !llm_classifier.has_api_key() {
        return Err(ClassifyError::Config(
            "LLM tier is enabled (use_llm: true) but no API key or credentials \
             could be resolved. Ensure the environment variable named by \
             llm.api_key_env is set and non-empty (for openrouter/anthropic-api), \
             or that valid AWS credentials are present in the credential chain \
             (for bedrock). No database writes will occur."
                .to_string(),
        ));
    }
    engine.attach_llm(llm_classifier);
    Ok(engine)
}
/// Builds the external-source resolver, if one is configured.
///
/// Returns `None` when there is no `classification` section, when
/// `no_external` is set, or when the list of sources is empty.
fn build_resolver(&self) -> Option<ExternalSourceResolver> {
    let cfg = self.config.classification.as_ref()?;
    if cfg.no_external || cfg.sources.is_empty() {
        return None;
    }
    Some(ExternalSourceResolver::new(cfg.sources.as_slice()))
}
/// Runs the full pipeline against `db` using an engine and resolver built
/// from the configuration.
///
/// # Errors
/// Propagates engine construction failures and any error from the run.
pub async fn run(&self, db: &mut Database) -> Result<ClassificationStats> {
    // The engine is built first so a configuration error aborts the run
    // before the resolver is constructed.
    let engine = self.build_engine().await?;
    self.run_with_engine_and_resolver(db, engine, self.build_resolver())
        .await
}
#[allow(dead_code)]
pub(crate) async fn run_with_engine(
&self,
db: &mut Database,
engine: ClassificationEngine,
) -> Result<ClassificationStats> {
self.run_with_engine_and_resolver(db, engine, None).await
}
/// Classifies every candidate commit with `engine`, optionally refines the
/// verdicts, and persists them.
///
/// Stages, in order:
/// 1. Select candidates with `read_candidate_commits`; an empty selection
///    returns default stats immediately.
/// 2. Classify the whole batch with the engine, then replace the verdict of
///    every commit that has a row in `classification_overrides`.
/// 3. If `resolver` is supplied, let it replace the verdict of each
///    non-overridden commit it returns a signal for.
/// 4. If the engine has the LLM tier enabled, re-classify non-overridden
///    verdicts at or below the fallback threshold and keep the LLM verdict
///    only when it is strictly more confident.
/// 5. Write the results, compute and persist coverage, and report it.
///
/// Manual overrides are authoritative: stages 3 and 4 both skip them.
///
/// # Errors
/// Propagates errors from reading candidates and overrides, from writing
/// results, and from persisting repository status.
pub(crate) async fn run_with_engine_and_resolver(
    &self,
    db: &mut Database,
    engine: ClassificationEngine,
    resolver: Option<ExternalSourceResolver>,
) -> Result<ClassificationStats> {
    let commits = read_candidate_commits(
        db,
        self.force,
        self.since.as_deref(),
        self.until.as_deref(),
        &self.repos,
    )?;
    let total = commits.len();
    info!(
        total,
        force = self.force,
        since = ?self.since,
        until = ?self.until,
        repos = ?self.repos,
        "starting classification"
    );
    if commits.is_empty() {
        return Ok(ClassificationStats::default());
    }

    // Stage 2: rule-based batch classification, then manual overrides.
    let overrides = read_overrides(db, &commits)?;
    let pairs: Vec<(&str, bool)> = commits
        .iter()
        .map(|c| (c.message.as_str(), c.is_merge))
        .collect();
    let mut results = engine.classify_batch(&pairs);
    for (idx, commit) in commits.iter().enumerate() {
        if let Some(r) = overrides.get(&commit.id) {
            results[idx] = r.clone();
        }
    }

    // Stage 3: external sources (never applied on top of a manual override).
    if let Some(res) = &resolver {
        let pb = make_progress(commits.len() as u64, "External sources");
        for (idx, commit) in commits.iter().enumerate() {
            if overrides.contains_key(&commit.id) {
                pb.inc(1);
                continue;
            }
            if let Some(signal) = res.resolve(&commit.message).await {
                let top_level = engine.taxonomy().resolve(&signal.category);
                results[idx] = ClassificationResult {
                    category: signal.category,
                    subcategory: None,
                    top_level,
                    confidence: signal.confidence,
                    method: ClassificationMethod::ExternalSource,
                    ticket_id:
                        crate::classify::tiers::regex_tier::RegexMatcher::extract_ticket_id(
                            &commit.message,
                        ),
                    complexity: None,
                };
            }
            pb.inc(1);
        }
        pb.finish_and_clear();
    }

    // Stage 4: LLM fallback for low-confidence verdicts.
    if engine.config().use_llm {
        if matches!(engine.llm_has_api_key(), Some(false)) {
            warn!(
                "LLM tier enabled but no API key resolved \
                 (OPENAI_API_KEY / OPENROUTER_API_KEY unset); \
                 fallback will short-circuit silently"
            );
        }
        let fallback_threshold = self
            .config
            .classification
            .as_ref()
            .map(|c| c.llm_fallback_threshold)
            .unwrap_or(0.65);
        let concurrency = self
            .config
            .classification
            .as_ref()
            .map(|c| c.llm_fallback_concurrency.max(1))
            .unwrap_or(8);
        // Manual overrides are excluded here for the same reason the
        // external-source stage skips them: with a threshold >= 1.0 they
        // would otherwise be sent to the LLM, and a verdict reporting a
        // confidence above 1.0 could replace the manual decision.
        let pending: Vec<(usize, String, f64)> = commits
            .iter()
            .enumerate()
            .filter(|(idx, commit)| {
                !overrides.contains_key(&commit.id)
                    && results[*idx].confidence <= fallback_threshold
            })
            .map(|(idx, commit)| (idx, commit.message.clone(), results[idx].confidence))
            .collect();
        let pb = make_progress(pending.len() as u64, "LLM fallback");
        let engine_ref = &engine;
        let pb_ref = &pb;
        let new_results: Vec<(usize, ClassificationResult, f64)> =
            futures::stream::iter(pending.into_iter().map(
                |(idx, message, original_conf)| async move {
                    let r = engine_ref
                        .llm_classify_only(&message)
                        .await
                        .unwrap_or_else(ClassificationResult::unclassified);
                    pb_ref.inc(1);
                    (idx, r, original_conf)
                },
            ))
            .buffer_unordered(concurrency)
            .collect()
            .await;
        pb.finish_and_clear();
        for (idx, r, original_conf) in new_results {
            if r.confidence > original_conf {
                results[idx] = r;
            } else {
                warn!(
                    commit_idx = idx,
                    original_conf,
                    new_conf = r.confidence,
                    "LLM fallback did not improve confidence; keeping original verdict"
                );
            }
        }
    }

    // Stage 5: persist verdicts and coverage.
    let checkpoint_every = self
        .config
        .classification
        .as_ref()
        .map(|c| c.checkpoint_every)
        .unwrap_or(0);
    let mut stats = write_results(db, &commits, &results, checkpoint_every)?;
    compute_coverage(&mut stats);
    persist_repository_status(db, &stats)?;
    report_coverage(&stats, self.min_coverage_pct());
    info!(
        total = stats.total_commits,
        classified = stats.classified,
        coverage_pct = stats.coverage_pct,
        "classification complete"
    );
    Ok(stats)
}
/// Coverage percentage below which a warning is reported: the configured
/// `classification.min_coverage_pct`, or `DEFAULT_MIN_COVERAGE_PCT` when
/// there is no `classification` section.
fn min_coverage_pct(&self) -> f64 {
    match self.config.classification.as_ref() {
        Some(classification) => classification.min_coverage_pct,
        None => DEFAULT_MIN_COVERAGE_PCT,
    }
}
/// Builds an engine from the configuration and uses it to backfill missing
/// complexity scores; returns the number of rows updated.
///
/// # Errors
/// Propagates engine construction failures and backfill errors.
pub async fn backfill_complexity(&self, db: &mut Database) -> Result<usize> {
    let configured_engine = self.build_engine().await?;
    Self::backfill_complexity_with_engine(db, &configured_engine).await
}
/// Fills in `classifications.complexity` for the rows returned by
/// `read_complexity_backfill_candidates`, using the LLM verdict for the
/// linked commit message.
///
/// The work is split into two phases so that no SQLite transaction or
/// prepared statement is held open across LLM requests:
/// 1. ask the LLM for each candidate, one at a time, and collect the scores;
/// 2. write all collected scores in a single short transaction.
///
/// Candidates for which the LLM yields no complexity are left NULL.
///
/// Returns the number of rows updated.
///
/// # Errors
/// Propagates database errors from reading candidates and from the write
/// transaction; on a write error nothing is committed.
pub(crate) async fn backfill_complexity_with_engine(
    db: &mut Database,
    engine: &ClassificationEngine,
) -> Result<usize> {
    let candidates = read_complexity_backfill_candidates(db)?;
    let total = candidates.len();
    info!(total, "starting complexity backfill");
    if candidates.is_empty() {
        return Ok(0);
    }

    // Phase 1: network only — no database handle is live across the awaits.
    let pb = make_progress(total as u64, "Complexity backfill");
    let mut scored = Vec::with_capacity(total);
    for cand in &candidates {
        let complexity = engine
            .llm_classify_only(&cand.message)
            .await
            .and_then(|r| r.complexity);
        match complexity {
            Some(score) => scored.push((cand, score)),
            None => {
                warn!(
                    commit_sha = %cand.commit_sha,
                    "LLM returned no complexity score; leaving NULL"
                );
            }
        }
        pb.inc(1);
    }
    pb.finish_and_clear();

    // Phase 2: one short transaction for all resolved scores.
    let updated = scored.len();
    if updated > 0 {
        let conn = db.connection_mut();
        let tx = conn.transaction().map_err(crate::core::TgaError::from)?;
        {
            let mut update_stmt = tx
                .prepare("UPDATE classifications SET complexity = ?1 WHERE id = ?2")
                .map_err(crate::core::TgaError::from)?;
            for (cand, score) in scored {
                update_stmt
                    .execute(params![score as i64, cand.classification_id])
                    .map_err(crate::core::TgaError::from)?;
                info!(
                    commit_sha = %cand.commit_sha,
                    score,
                    "backfilled complexity"
                );
            }
        }
        tx.commit().map_err(crate::core::TgaError::from)?;
    }
    info!(updated, total, "complexity backfill complete");
    Ok(updated)
}
}
/// One classification row that lacks a complexity score, together with the
/// commit it is linked to.
struct ComplexityBackfillCandidate {
/// `classifications.id` of the row to update.
classification_id: i64,
/// `commits.sha` of the linked commit (used for logging).
commit_sha: String,
/// `commits.message` of the linked commit (sent to the LLM).
message: String,
}
/// Loads every classification whose `complexity` is NULL and whose method is
/// not `exact_rule`, joined with the commit that references it.
///
/// # Errors
/// Returns an error when the query cannot be prepared, executed, or a row
/// cannot be decoded.
fn read_complexity_backfill_candidates(db: &Database) -> Result<Vec<ComplexityBackfillCandidate>> {
    const CANDIDATE_SQL: &str = "SELECT cl.id, c.sha, c.message \
         FROM classifications cl \
         JOIN commits c ON c.classification_id = cl.id \
         WHERE cl.complexity IS NULL AND cl.method != 'exact_rule'";

    let mut query = db
        .connection()
        .prepare(CANDIDATE_SQL)
        .map_err(crate::core::TgaError::from)?;
    let mapped = query
        .query_map([], |row| {
            Ok(ComplexityBackfillCandidate {
                classification_id: row.get(0)?,
                commit_sha: row.get(1)?,
                message: row.get(2)?,
            })
        })
        .map_err(crate::core::TgaError::from)?;
    // Stops at the first row that fails to decode, like an explicit loop with `?`.
    let candidates: Result<Vec<ComplexityBackfillCandidate>> = mapped
        .map(|row| -> Result<ComplexityBackfillCandidate> {
            Ok(row.map_err(crate::core::TgaError::from)?)
        })
        .collect();
    candidates
}
fn read_overrides(
db: &Database,
commits: &[CommitRow],
) -> Result<HashMap<i64, ClassificationResult>> {
use crate::classify::taxonomy::TaxonomyRegistry;
use crate::core::models::ClassificationMethod;
let mut out: HashMap<i64, ClassificationResult> = HashMap::new();
if commits.is_empty() {
return Ok(out);
}
let taxonomy = TaxonomyRegistry::with_builtins();
let conn = db.connection();
let mut stmt = conn
.prepare(
"SELECT work_type, change_type FROM classification_overrides \
WHERE commit_sha = ?1 AND repo_path = ?2",
)
.map_err(crate::core::TgaError::from)?;
for commit in commits {
let row = stmt.query_row(params![commit.sha, commit.repository], |row| {
let work_type: String = row.get(0)?;
let change_type: String = row.get(1)?;
Ok((work_type, change_type))
});
match row {
Ok((work_type, change_type)) => {
let top_level = taxonomy
.resolve(&change_type)
.or_else(|| taxonomy.resolve(&work_type));
out.insert(
commit.id,
ClassificationResult {
category: work_type,
subcategory: Some(change_type),
top_level,
confidence: 1.0,
method: ClassificationMethod::Manual,
ticket_id: None,
complexity: None,
},
);
}
Err(rusqlite::Error::QueryReturnedNoRows) => {}
Err(e) => return Err(crate::core::TgaError::from(e).into()),
}
}
Ok(out)
}
/// Fills in `coverage_pct` on `stats` and on each per-repository entry from
/// the already-populated `classified` / total counters.
///
/// A total of zero yields `0.0` rather than a division by zero.
fn compute_coverage(stats: &mut ClassificationStats) {
    /// `part / whole` as a percentage, or `0.0` when `whole` is zero.
    fn percentage(part: usize, whole: usize) -> f64 {
        match whole {
            0 => 0.0,
            _ => (part as f64 / whole as f64) * 100.0,
        }
    }

    stats.coverage_pct = percentage(stats.classified, stats.total_commits);
    stats.coverage_by_repo.values_mut().for_each(|repo| {
        repo.coverage_pct = percentage(repo.classified, repo.total);
    });
}
/// Upserts one `repository_analysis_status` row per repository in
/// `stats.coverage_by_repo`, all inside a single transaction.
///
/// Does nothing when there is no per-repository coverage.
///
/// # Errors
/// Returns an error when the transaction cannot be started, a statement
/// fails, or the commit fails.
fn persist_repository_status(db: &mut Database, stats: &ClassificationStats) -> Result<()> {
    const UPSERT_SQL: &str = "INSERT INTO repository_analysis_status \
         (repo_name, last_analyzed_at, classification_coverage_pct, \
         total_commits, classified_commits) \
         VALUES (?1, datetime('now'), ?2, ?3, ?4) \
         ON CONFLICT(repo_name) DO UPDATE SET \
         last_analyzed_at = datetime('now'), \
         classification_coverage_pct = excluded.classification_coverage_pct, \
         total_commits = excluded.total_commits, \
         classified_commits = excluded.classified_commits";

    let per_repo = &stats.coverage_by_repo;
    if per_repo.is_empty() {
        return Ok(());
    }

    let connection = db.connection_mut();
    let tx = connection
        .transaction()
        .map_err(crate::core::TgaError::from)?;
    {
        // The statement borrows the transaction, so it lives in its own
        // scope and is dropped before the commit.
        let mut upsert = tx.prepare(UPSERT_SQL).map_err(crate::core::TgaError::from)?;
        for (repo_name, coverage) in per_repo {
            let total = coverage.total as i64;
            let classified = coverage.classified as i64;
            upsert
                .execute(params![repo_name, coverage.coverage_pct, total, classified])
                .map_err(crate::core::TgaError::from)?;
        }
    }
    tx.commit().map_err(crate::core::TgaError::from)?;
    Ok(())
}
/// Logs the overall coverage and, when at least one commit was processed and
/// coverage is below `threshold_pct`, emits a warning.
fn report_coverage(stats: &ClassificationStats, threshold_pct: f64) {
    let coverage = stats.coverage_pct;
    info!(
        "Classification coverage: {:.1}% ({} / {})",
        coverage, stats.classified, stats.total_commits
    );

    let below_threshold = stats.total_commits > 0 && coverage < threshold_pct;
    if !below_threshold {
        return;
    }
    warn!(
        coverage_pct = coverage,
        threshold_pct, "classification coverage below configured threshold"
    );
}
/// One candidate row read from the `commits` table.
// NOTE(review): every field appears to be read somewhere in this file, so
// the `dead_code` allowance may no longer be needed — confirm before removing.
#[allow(dead_code)]
struct CommitRow {
/// `commits.id`.
id: i64,
/// `commits.sha`; used to look up manual overrides.
sha: String,
/// `commits.message`; the text that gets classified.
message: String,
/// True when the `is_merge` column is non-zero.
is_merge: bool,
/// `commits.repository`; used for override lookup and per-repo coverage.
repository: String,
/// `commits.classification_id`; when present, that classifications row is
/// updated in place instead of inserting a new one.
existing_classification_id: Option<i64>,
}
/// Returns `true` when the category, or the subcategory if present, is
/// exactly `revert` or `rollback`, compared ASCII case-insensitively.
fn is_revert_verdict(category: &str, subcategory: Option<&str>) -> bool {
    const MARKERS: [&str; 2] = ["revert", "rollback"];

    std::iter::once(category).chain(subcategory).any(|label| {
        MARKERS
            .iter()
            .any(|marker| label.eq_ignore_ascii_case(marker))
    })
}
/// Selects the commits a classification run should process.
///
/// Filters are combined with `AND`:
/// - unless `force` is set, only commits with a NULL `classification_id`;
/// - `since` / `until`, when given, as inclusive bounds compared against the
///   `timestamp` column;
/// - `repos`, when non-empty, as a `repository IN (...)` restriction.
///
/// All values are bound as numbered parameters. Rows are returned ordered by
/// `id`: without an `ORDER BY` SQLite leaves the row order unspecified, which
/// would make chunk boundaries and logged commit indices vary between runs.
///
/// # Errors
/// Returns an error when the query cannot be prepared, executed, or a row
/// cannot be decoded.
fn read_candidate_commits(
    db: &Database,
    force: bool,
    since: Option<&str>,
    until: Option<&str>,
    repos: &[String],
) -> Result<Vec<CommitRow>> {
    use rusqlite::types::Value;

    let mut predicates: Vec<String> = Vec::new();
    let mut bind_values: Vec<Value> = Vec::new();

    if !force {
        predicates.push("classification_id IS NULL".to_string());
    }
    for (operator, bound) in [(">=", since), ("<=", until)] {
        if let Some(bound) = bound {
            bind_values.push(Value::Text(bound.to_string()));
            // The placeholder number is the 1-based position just pushed.
            predicates.push(format!("timestamp {operator} ?{}", bind_values.len()));
        }
    }
    if !repos.is_empty() {
        let first = bind_values.len() + 1;
        bind_values.extend(repos.iter().cloned().map(Value::Text));
        let placeholders: Vec<String> = (first..=bind_values.len())
            .map(|i| format!("?{i}"))
            .collect();
        predicates.push(format!("repository IN ({})", placeholders.join(", ")));
    }

    let where_clause = if predicates.is_empty() {
        String::new()
    } else {
        format!(" WHERE {}", predicates.join(" AND "))
    };
    let sql = format!(
        "SELECT id, sha, message, is_merge, repository, classification_id \
         FROM commits{where_clause} ORDER BY id"
    );

    let mut stmt = db
        .connection()
        .prepare(&sql)
        .map_err(crate::core::TgaError::from)?;
    let rows = stmt
        .query_map(rusqlite::params_from_iter(bind_values.iter()), |row| {
            Ok(CommitRow {
                id: row.get(0)?,
                sha: row.get(1)?,
                message: row.get(2)?,
                is_merge: row.get::<_, i64>(3)? != 0,
                repository: row.get(4)?,
                existing_classification_id: row.get(5)?,
            })
        })
        .map_err(crate::core::TgaError::from)?;
    let mut out = Vec::new();
    for r in rows {
        out.push(r.map_err(crate::core::TgaError::from)?);
    }
    Ok(out)
}
/// Persists `results` for `commits` and returns the accumulated statistics.
///
/// Work is split into chunks, each written in its own transaction by
/// `write_results_chunk`. With `checkpoint_every == 0` everything goes into
/// a single chunk; otherwise chunks hold `checkpoint_every` commits and a
/// passive WAL checkpoint is attempted after every full chunk. A failed
/// checkpoint is logged and otherwise ignored.
///
/// # Errors
/// Propagates the first error returned while writing a chunk.
fn write_results(
    db: &mut Database,
    commits: &[CommitRow],
    results: &[ClassificationResult],
    checkpoint_every: usize,
) -> Result<ClassificationStats> {
    let periodic_checkpoints = checkpoint_every > 0;
    let chunk_size = match checkpoint_every {
        // `chunks` rejects a size of zero, hence the floor of one.
        0 => commits.len().max(1),
        every => every,
    };
    let mut stats = ClassificationStats {
        total_commits: commits.len(),
        ..Default::default()
    };

    let pb = make_progress(commits.len() as u64, "Writing results");
    let commit_chunks = commits.chunks(chunk_size);
    let result_chunks = results.chunks(chunk_size);
    for (chunk_commits, chunk_results) in commit_chunks.zip(result_chunks) {
        write_results_chunk(db, chunk_commits, chunk_results, &mut stats, &pb)?;

        let full_chunk = chunk_commits.len() == chunk_size;
        if !(periodic_checkpoints && full_chunk) {
            continue;
        }
        if let Err(e) = db.wal_checkpoint(CheckpointMode::Passive) {
            warn!(error = %e, "periodic WAL PASSIVE checkpoint failed (non-fatal)");
        }
    }
    pb.finish_and_clear();

    let unclassified = stats.total_commits.saturating_sub(stats.classified);
    if unclassified > 0 {
        warn!(unclassified, "some commits remained uncategorized");
    }
    Ok(stats)
}
/// Writes one chunk of verdicts inside a single transaction and folds the
/// outcome into `stats`.
///
/// For each commit/verdict pair:
/// - the commit's existing classification row is updated in place when it
///   has one, otherwise a new row is inserted;
/// - the commit is pointed at that row and its `confidence` and `is_revert`
///   columns are refreshed;
/// - the overall, per-repository, per-method and per-category counters are
///   bumped and the progress bar advances by one.
///
/// # Errors
/// Returns an error when the transaction, any statement, or the commit
/// fails; the transaction is not committed in that case.
fn write_results_chunk(
    db: &mut Database,
    commits: &[CommitRow],
    results: &[ClassificationResult],
    stats: &mut ClassificationStats,
    pb: &indicatif::ProgressBar,
) -> Result<()> {
    let conn = db.connection_mut();
    let tx = conn.transaction().map_err(crate::core::TgaError::from)?;
    {
        let mut insert_classification = tx
            .prepare(
                "INSERT INTO classifications \
                 (category, subcategory, ticket_id, confidence, method, complexity) \
                 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
            )
            .map_err(crate::core::TgaError::from)?;
        let mut update_existing_classification = tx
            .prepare(
                "UPDATE classifications \
                 SET category = ?1, subcategory = ?2, ticket_id = ?3, \
                 confidence = ?4, method = ?5, complexity = ?6 \
                 WHERE id = ?7",
            )
            .map_err(crate::core::TgaError::from)?;
        let mut update_commit = tx
            .prepare(
                "UPDATE commits SET classification_id = ?1, confidence = ?2, is_revert = ?3 \
                 WHERE id = ?4",
            )
            .map_err(crate::core::TgaError::from)?;

        for (commit, verdict) in commits.iter().zip(results) {
            let method = verdict.method.as_str();
            let complexity = verdict.complexity.map(|v| v as i64);

            let classification_id = match commit.existing_classification_id {
                // Re-classification: overwrite the row the commit already points at.
                Some(existing) => {
                    update_existing_classification
                        .execute(params![
                            verdict.category,
                            verdict.subcategory,
                            verdict.ticket_id,
                            verdict.confidence,
                            method,
                            complexity,
                            existing,
                        ])
                        .map_err(crate::core::TgaError::from)?;
                    existing
                }
                // First classification: create the row and use its rowid.
                None => {
                    insert_classification
                        .execute(params![
                            verdict.category,
                            verdict.subcategory,
                            verdict.ticket_id,
                            verdict.confidence,
                            method,
                            complexity,
                        ])
                        .map_err(crate::core::TgaError::from)?;
                    tx.last_insert_rowid()
                }
            };

            let revert_flag = i64::from(is_revert_verdict(
                &verdict.category,
                verdict.subcategory.as_deref(),
            ));
            update_commit
                .execute(params![
                    classification_id,
                    verdict.confidence,
                    revert_flag,
                    commit.id,
                ])
                .map_err(crate::core::TgaError::from)?;

            let is_classified = !verdict.category.is_empty()
                && verdict.category != "uncategorized"
                && verdict.confidence > 0.0;
            let repo_entry = stats
                .coverage_by_repo
                .entry(commit.repository.clone())
                .or_default();
            repo_entry.total += 1;
            if is_classified {
                repo_entry.classified += 1;
                stats.classified += 1;
            }
            *stats.by_method.entry(method.to_string()).or_default() += 1;
            *stats
                .by_category
                .entry(verdict.category.clone())
                .or_default() += 1;
            pb.inc(1);
        }
    }
    tx.commit().map_err(crate::core::TgaError::from)?;
    Ok(())
}
/// Creates a progress bar of length `len` prefixed with `label`.
///
/// The custom style is applied only when the template parses; otherwise the
/// bar keeps its default style.
fn make_progress(len: u64, label: &str) -> ProgressBar {
    const TEMPLATE: &str = "{prefix:.bold} [{bar:40.cyan/blue}] {pos}/{len} ({percent}%)";

    let bar = ProgressBar::new(len);
    let styled = ProgressStyle::with_template(TEMPLATE).map(|style| style.progress_chars("##-"));
    if let Ok(style) = styled {
        bar.set_style(style);
    }
    bar.set_prefix(label.to_string());
    bar
}
#[cfg(test)]
mod tests {
use super::*;
use crate::classify::classifier::{ClassificationEngine, ClassificationEngineConfig};
use crate::classify::rules::default_rules;
use crate::core::config::Config;
use rusqlite::params;
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
/// Builds a default-rules engine with the LLM tier enabled and pointed at
/// the given test endpoint.
fn engine_with_mock_llm(endpoint: &str) -> ClassificationEngine {
    let mut cfg = ClassificationEngineConfig::default();
    cfg.use_llm = true;
    let engine = ClassificationEngine::new(default_rules(), cfg).expect("build engine");
    engine.with_test_llm_endpoint(endpoint)
}
/// Starts a mock server that answers `POST /v1/chat/completions` with a
/// chat-completion body whose message content is a JSON verdict built from
/// the given category, confidence and complexity.
async fn mock_llm_server(category: &str, confidence: f64, complexity: u8) -> MockServer {
    let server = MockServer::start().await;
    let verdict_json = format!(
        "{{\"category\":\"{category}\",\"subcategory\":null,\
         \"confidence\":{confidence},\"complexity\":{complexity}}}"
    );
    let body = serde_json::json!({
        "choices": [{ "message": { "content": verdict_json } }]
    });
    let completion = ResponseTemplate::new(200).set_body_json(body);
    Mock::given(method("POST"))
        .and(path("/v1/chat/completions"))
        .respond_with(completion)
        .mount(&server)
        .await;
    server
}
/// Inserts a commit with fixed author, timestamp and repository
/// (`acme/widgets`) and returns its row id.
fn insert_commit(db: &Database, sha: &str, message: &str) -> i64 {
    let conn = db.connection();
    conn.execute(
        "INSERT INTO commits \
         (sha, author_name, author_email, timestamp, message, repository) \
         VALUES (?1, 'a', 'a@x', '2024-01-01T00:00:00Z', ?2, 'acme/widgets')",
        params![sha, message],
    )
    .expect("insert commit");
    conn.last_insert_rowid()
}
/// A default run must leave already-classified commits untouched, while a
/// forced run must rewrite the existing classifications row in place
/// (same row id, no extra rows).
#[tokio::test]
async fn pipeline_force_reclassifies_existing_rows_in_place() {
let mut db = Database::open_in_memory().expect("db");
// Seed: a 'feature' classification linked to a commit whose message starts with "fix:".
db.connection()
.execute(
"INSERT INTO classifications (category, subcategory, confidence, method) \
VALUES ('feature', NULL, 0.5, 'regex_rule')",
[],
)
.expect("insert classification");
let pre_cls_id = db.connection().last_insert_rowid();
let commit_id = insert_commit(&db, "sha-fix-1", "fix: handle null user");
db.connection()
.execute(
"UPDATE commits SET classification_id = ?1 WHERE id = ?2",
params![pre_cls_id, commit_id],
)
.expect("link cls");
// First pass: no force, so the commit is not a candidate.
let pipeline_no_force = ClassificationPipeline::new(Config::default());
let engine =
ClassificationEngine::new(default_rules(), ClassificationEngineConfig::default())
.expect("engine");
pipeline_no_force
.run_with_engine(&mut db, engine)
.await
.expect("default run");
let still_feature: String = db
.connection()
.query_row(
"SELECT cl.category FROM classifications cl \
JOIN commits c ON c.classification_id = cl.id WHERE c.sha = 'sha-fix-1'",
[],
|row| row.get(0),
)
.expect("query 1");
assert_eq!(
still_feature, "feature",
"default flow must NOT re-classify already-classified commits"
);
// Second pass: force, so the existing row is re-classified.
let pipeline_forced = ClassificationPipeline::new(Config::default()).with_force(true);
let engine =
ClassificationEngine::new(default_rules(), ClassificationEngineConfig::default())
.expect("engine");
pipeline_forced
.run_with_engine(&mut db, engine)
.await
.expect("force run");
let (new_cat, new_cls_id): (String, i64) = db
.connection()
.query_row(
"SELECT cl.category, cl.id FROM classifications cl \
JOIN commits c ON c.classification_id = cl.id WHERE c.sha = 'sha-fix-1'",
[],
|row| Ok((row.get(0)?, row.get(1)?)),
)
.expect("query 2");
assert_eq!(new_cat, "bugfix");
assert_eq!(
new_cls_id, pre_cls_id,
"force must update in place, not orphan"
);
// The forced run must not have inserted a second classifications row.
let total_rows: i64 = db
.connection()
.query_row("SELECT COUNT(*) FROM classifications", [], |row| row.get(0))
.expect("count");
assert_eq!(
total_rows, 1,
"force must not duplicate the classifications row"
);
}
/// With `force` and a `since` bound, only commits at or after the bound are
/// re-classified; older ones keep their previous verdict.
#[tokio::test]
async fn pipeline_force_since_bounds_rewrite_window() {
let mut db = Database::open_in_memory().expect("db");
// Seed two commits with identical messages, one before and one after the bound,
// each linked to its own 'feature' classification.
for (sha, ts) in [
("sha-old", "2023-06-01T00:00:00Z"),
("sha-new", "2025-06-01T00:00:00Z"),
] {
db.connection()
.execute(
"INSERT INTO classifications (category, confidence, method) \
VALUES ('feature', 0.5, 'regex_rule')",
[],
)
.expect("insert cls");
let cls_id = db.connection().last_insert_rowid();
db.connection()
.execute(
"INSERT INTO commits (sha, author_name, author_email, timestamp, message, repository, classification_id) \
VALUES (?1, 'a', 'a@x', ?2, 'fix: handle null user', 'r', ?3)",
params![sha, ts, cls_id],
)
.expect("insert commit");
}
let pipeline = ClassificationPipeline::new(Config::default())
.with_force(true)
.with_since(Some("2025-01-01".to_string()));
let engine =
ClassificationEngine::new(default_rules(), ClassificationEngineConfig::default())
.expect("engine");
pipeline
.run_with_engine(&mut db, engine)
.await
.expect("force+since");
// The commit inside the window is re-classified.
let new_cat: String = db
.connection()
.query_row(
"SELECT cl.category FROM classifications cl \
JOIN commits c ON c.classification_id = cl.id WHERE c.sha = 'sha-new'",
[],
|row| row.get(0),
)
.expect("query new");
assert_eq!(new_cat, "bugfix");
// The commit before the bound keeps its seeded verdict.
let old_cat: String = db
.connection()
.query_row(
"SELECT cl.category FROM classifications cl \
JOIN commits c ON c.classification_id = cl.id WHERE c.sha = 'sha-old'",
[],
|row| row.get(0),
)
.expect("query old");
assert_eq!(
old_cat, "feature",
"--since must exclude commits older than the bound"
);
}
/// Exercises the candidate query in three modes: default (unclassified
/// only), force (everything), and force with a `since` bound.
#[test]
fn read_candidate_commits_branches_select_correctly() {
let db = Database::open_in_memory().expect("db");
// Seed: two classified commits ('old', 'new') sharing one classification row,
// plus one unclassified commit ('null').
db.connection()
.execute(
"INSERT INTO classifications (category, confidence, method) \
VALUES ('feature', 0.5, 'regex_rule')",
[],
)
.expect("cls");
let cls_id = db.connection().last_insert_rowid();
db.connection()
.execute(
"INSERT INTO commits (sha, author_name, author_email, timestamp, message, repository, classification_id) \
VALUES ('old', 'a', 'a@x', '2023-01-01T00:00:00Z', 'm', 'r', ?1)",
params![cls_id],
)
.expect("insert old");
db.connection()
.execute(
"INSERT INTO commits (sha, author_name, author_email, timestamp, message, repository, classification_id) \
VALUES ('new', 'a', 'a@x', '2025-06-01T00:00:00Z', 'm', 'r', ?1)",
params![cls_id],
)
.expect("insert new");
db.connection()
.execute(
"INSERT INTO commits (sha, author_name, author_email, timestamp, message, repository) \
VALUES ('null', 'a', 'a@x', '2024-01-01T00:00:00Z', 'm', 'r')",
[],
)
.expect("insert unclassified");
// Default: only the commit without a classification.
let v = super::read_candidate_commits(&db, false, None, None, &[]).expect("default");
let shas: Vec<&str> = v.iter().map(|c| c.sha.as_str()).collect();
assert_eq!(shas, vec!["null"]);
// Force: every commit; sorted here so the assertion does not depend on row order.
let v = super::read_candidate_commits(&db, true, None, None, &[]).expect("force");
let mut shas: Vec<&str> = v.iter().map(|c| c.sha.as_str()).collect();
shas.sort();
assert_eq!(shas, vec!["new", "null", "old"]);
// Force + since: only the commit at or after the bound.
let v = super::read_candidate_commits(&db, true, Some("2025-01-01"), None, &[])
.expect("force+since");
let shas: Vec<&str> = v.iter().map(|c| c.sha.as_str()).collect();
assert_eq!(shas, vec!["new"]);
}
/// After a run, `commits.is_revert` must be 1 for a commit classified as a
/// revert and 0 for a regular feature commit.
#[tokio::test]
async fn pipeline_sets_is_revert_for_revert_verdicts() {
let mut db = Database::open_in_memory().expect("db");
let revert_id = insert_commit(&db, "sha-revert", "Revert \"feat: add login\"");
let feature_id = insert_commit(&db, "sha-feat", "feat: add login form");
let config = Config::default();
let pipeline = ClassificationPipeline::new(config);
let engine =
ClassificationEngine::new(default_rules(), ClassificationEngineConfig::default())
.expect("engine builds");
pipeline
.run_with_engine(&mut db, engine)
.await
.expect("run pipeline");
// The revert commit is flagged.
let revert_flag: i64 = db
.connection()
.query_row(
"SELECT is_revert FROM commits WHERE id = ?1",
params![revert_id],
|row| row.get(0),
)
.expect("query revert");
assert_eq!(
revert_flag, 1,
"revert verdict must set commits.is_revert=1"
);
// The feature commit is not.
let feat_flag: i64 = db
.connection()
.query_row(
"SELECT is_revert FROM commits WHERE id = ?1",
params![feature_id],
|row| row.get(0),
)
.expect("query feature");
assert_eq!(
feat_flag, 0,
"non-revert verdict must leave commits.is_revert at 0"
);
}
/// `is_revert_verdict` accepts `revert` / `rollback` in either slot,
/// case-insensitively, and rejects everything else (including near-misses
/// such as `reverted`).
#[test]
fn is_revert_verdict_recognizes_canonical_markers() {
    let accepted: [(&str, Option<&str>); 6] = [
        ("revert", None),
        ("Revert", None),
        ("rollback", None),
        ("ROLLBACK", None),
        ("merge", Some("revert")),
        ("merge", Some("rollback")),
    ];
    for (category, subcategory) in accepted {
        assert!(super::is_revert_verdict(category, subcategory));
    }

    let rejected: [(&str, Option<&str>); 3] = [
        ("feature", None),
        ("bugfix", Some("hotfix")),
        ("reverted", None),
    ];
    for (category, subcategory) in rejected {
        assert!(!super::is_revert_verdict(category, subcategory));
    }
}
/// The complexity score returned by the (mocked) LLM must end up in
/// `classifications.complexity` for the classified commit.
#[tokio::test]
async fn pipeline_writes_complexity_to_db() {
// The mock LLM answers with category 'feature', confidence 0.9, complexity 2.
let server = mock_llm_server("feature", 0.9, 2).await;
let endpoint = format!("{}/v1/chat/completions", server.uri());
let mut db = Database::open_in_memory().expect("db");
insert_commit(&db, "sha-a", "zzz qqq vvv www yyy uuu");
// A fallback threshold of 1.0 sends every verdict with confidence <= 1.0 to the LLM stage.
let classification = crate::core::config::ClassificationConfig {
use_llm: true,
llm_fallback_threshold: 1.0,
..crate::core::config::ClassificationConfig::default()
};
let config = Config {
classification: Some(classification),
..Config::default()
};
let pipeline = ClassificationPipeline::new(config);
let engine = engine_with_mock_llm(&endpoint);
pipeline
.run_with_engine(&mut db, engine)
.await
.expect("run pipeline");
let complexity: Option<i64> = db
.connection()
.query_row(
"SELECT cl.complexity FROM classifications cl \
JOIN commits c ON c.classification_id = cl.id \
WHERE c.sha = 'sha-a'",
[],
|row| row.get(0),
)
.expect("query complexity");
assert_eq!(complexity, Some(2));
}
/// The backfill must score only classifications whose complexity is NULL
/// and leave already-scored rows untouched.
#[tokio::test]
async fn backfill_complexity_updates_only_null_rows() {
// The mock LLM always answers with complexity 4.
let server = mock_llm_server("feature", 0.9, 4).await;
let endpoint = format!("{}/v1/chat/completions", server.uri());
let mut db = Database::open_in_memory().expect("db");
// Row 1: complexity NULL, linked to commit 'sha-null' — should be backfilled.
db.connection()
.execute(
"INSERT INTO classifications (category, confidence, method, complexity) \
VALUES ('feature', 0.5, 'regex_rule', NULL)",
[],
)
.expect("insert cl 1");
let cl1 = db.connection().last_insert_rowid();
let c1 = insert_commit(&db, "sha-null", "needs scoring");
db.connection()
.execute(
"UPDATE commits SET classification_id = ?1 WHERE id = ?2",
params![cl1, c1],
)
.expect("link 1");
// Row 2: complexity already 3, linked to commit 'sha-scored' — must stay as is.
db.connection()
.execute(
"INSERT INTO classifications (category, confidence, method, complexity) \
VALUES ('bugfix', 0.8, 'regex_rule', 3)",
[],
)
.expect("insert cl 2");
let cl2 = db.connection().last_insert_rowid();
let c2 = insert_commit(&db, "sha-scored", "already scored");
db.connection()
.execute(
"UPDATE commits SET classification_id = ?1 WHERE id = ?2",
params![cl2, c2],
)
.expect("link 2");
let engine = engine_with_mock_llm(&endpoint);
let updated = ClassificationPipeline::backfill_complexity_with_engine(&mut db, &engine)
.await
.expect("backfill");
assert_eq!(updated, 1, "only the NULL row should be updated");
let filled: Option<i64> = db
.connection()
.query_row(
"SELECT complexity FROM classifications WHERE id = ?1",
params![cl1],
|row| row.get(0),
)
.expect("query filled");
assert_eq!(filled, Some(4), "NULL row backfilled to the LLM score");
let unchanged: Option<i64> = db
.connection()
.query_row(
"SELECT complexity FROM classifications WHERE id = ?1",
params![cl2],
|row| row.get(0),
)
.expect("query unchanged");
assert_eq!(unchanged, Some(3), "already-scored row must be unchanged");
}
}