use super::*;
use crate::classify::classifier::{ClassificationEngine, ClassificationEngineConfig};
use crate::classify::rules::default_rules;
use crate::core::config::Config;
use crate::core::db::Database;
use rusqlite::params;
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
/// Builds a [`ClassificationEngine`] from the default rule set with `use_llm`
/// switched on, then redirects it to `endpoint` via `with_test_llm_endpoint`.
///
/// # Panics
/// Panics with "build engine" if the engine cannot be constructed.
fn engine_with_mock_llm(endpoint: &str) -> ClassificationEngine {
    let llm_enabled = ClassificationEngineConfig {
        use_llm: true,
        ..Default::default()
    };
    let engine = ClassificationEngine::new(default_rules(), llm_enabled).expect("build engine");
    engine.with_test_llm_endpoint(endpoint)
}
/// Starts a wiremock server that answers `POST /v1/chat/completions` with a
/// canned HTTP 200 response.
///
/// The response body has the shape `{"choices":[{"message":{"content": ...}}]}`,
/// where `content` is a JSON *string* encoding an object with the keys
/// `category`, `subcategory` (always `null`), `confidence` and `complexity`.
///
/// The inner object is serialized with `serde_json` rather than interpolated
/// into a string template, so a `category` containing quotes or backslashes
/// (or a non-finite `confidence`) still produces well-formed JSON.
///
/// Returns the running [`MockServer`]; the caller must keep it alive for as
/// long as the endpoint is needed.
async fn mock_llm_server(category: &str, confidence: f64, complexity: u8) -> MockServer {
    let server = MockServer::start().await;

    // The verdict travels as a string inside `message.content`, so it is
    // serialized on its own before being embedded in the outer envelope.
    let verdict = serde_json::json!({
        "category": category,
        "subcategory": null,
        "confidence": confidence,
        "complexity": complexity,
    })
    .to_string();

    let body = serde_json::json!({
        "choices": [{
            "message": {
                "content": verdict
            }
        }]
    });

    Mock::given(method("POST"))
        .and(path("/v1/chat/completions"))
        .respond_with(ResponseTemplate::new(200).set_body_json(body))
        .mount(&server)
        .await;
    server
}
/// Inserts a single row into `commits` with the given `sha` and `message`
/// (author, timestamp and repository are fixed placeholder values) and returns
/// the connection's last inserted rowid.
///
/// # Panics
/// Panics with "insert commit" if the INSERT fails.
fn insert_commit(db: &Database, sha: &str, message: &str) -> i64 {
    const INSERT_COMMIT_SQL: &str = "INSERT INTO commits \
         (sha, author_name, author_email, timestamp, message, repository) \
         VALUES (?1, 'a', 'a@x', '2024-01-01T00:00:00Z', ?2, 'acme/widgets')";

    let conn = db.connection();
    conn.execute(INSERT_COMMIT_SQL, params![sha, message])
        .expect("insert commit");
    conn.last_insert_rowid()
}
/// Checks what `with_force(true)` changes for a commit that is already classified.
///
/// A commit with the message "fix: handle null user" is seeded with a linked
/// 'feature' classification. A default run must leave the category at
/// 'feature'. A forced run must change it to 'bugfix' while the commit keeps
/// pointing at the same `classifications` row and the table still holds
/// exactly one row.
#[tokio::test]
async fn pipeline_force_reclassifies_existing_rows_in_place() {
    const SEED_CLASSIFICATION_SQL: &str =
        "INSERT INTO classifications (category, subcategory, confidence, method) \
         VALUES ('feature', NULL, 0.5, 'regex_rule')";
    const LINK_SQL: &str = "UPDATE commits SET classification_id = ?1 WHERE id = ?2";
    const CATEGORY_SQL: &str = "SELECT cl.category FROM classifications cl \
         JOIN commits c ON c.classification_id = cl.id WHERE c.sha = 'sha-fix-1'";
    const CATEGORY_AND_ID_SQL: &str = "SELECT cl.category, cl.id FROM classifications cl \
         JOIN commits c ON c.classification_id = cl.id WHERE c.sha = 'sha-fix-1'";
    const COUNT_SQL: &str = "SELECT COUNT(*) FROM classifications";

    let mut db = Database::open_in_memory().expect("db");

    // Seed: a pre-existing classification row, then a commit linked to it.
    let seeded_cls_id = {
        let conn = db.connection();
        conn.execute(SEED_CLASSIFICATION_SQL, [])
            .expect("insert classification");
        conn.last_insert_rowid()
    };
    let seeded_commit_id = insert_commit(&db, "sha-fix-1", "fix: handle null user");
    db.connection()
        .execute(LINK_SQL, params![seeded_cls_id, seeded_commit_id])
        .expect("link cls");

    // Pass 1: default pipeline, no force.
    let default_pipeline = ClassificationPipeline::new(Config::default());
    let default_engine =
        ClassificationEngine::new(default_rules(), ClassificationEngineConfig::default())
            .expect("engine");
    default_pipeline
        .run_with_engine(&mut db, default_engine)
        .await
        .expect("default run");

    let category_after_default = db
        .connection()
        .query_row(CATEGORY_SQL, [], |row| row.get::<_, String>(0))
        .expect("query 1");
    assert_eq!(
        category_after_default, "feature",
        "default flow must NOT re-classify already-classified commits"
    );

    // Pass 2: same data, forced pipeline.
    let forced_pipeline = ClassificationPipeline::new(Config::default()).with_force(true);
    let forced_engine =
        ClassificationEngine::new(default_rules(), ClassificationEngineConfig::default())
            .expect("engine");
    forced_pipeline
        .run_with_engine(&mut db, forced_engine)
        .await
        .expect("force run");

    let (category_after_force, cls_id_after_force) = db
        .connection()
        .query_row(CATEGORY_AND_ID_SQL, [], |row| {
            let category: String = row.get(0)?;
            let cls_id: i64 = row.get(1)?;
            Ok((category, cls_id))
        })
        .expect("query 2");
    assert_eq!(category_after_force, "bugfix");
    assert_eq!(
        cls_id_after_force, seeded_cls_id,
        "force must update in place, not orphan"
    );

    let classification_rows = db
        .connection()
        .query_row(COUNT_SQL, [], |row| row.get::<_, i64>(0))
        .expect("count");
    assert_eq!(
        classification_rows, 1,
        "force must not duplicate the classifications row"
    );
}
#[tokio::test]
async fn pipeline_force_since_bounds_rewrite_window() {
let mut db = Database::open_in_memory().expect("db");
for (sha, ts) in [
("sha-old", "2023-06-01T00:00:00Z"),
("sha-new", "2025-06-01T00:00:00Z"),
] {
db.connection()
.execute(
"INSERT INTO classifications (category, confidence, method) \
VALUES ('feature', 0.5, 'regex_rule')",
[],
)
.expect("insert cls");
let cls_id = db.connection().last_insert_rowid();
db.connection()
.execute(
"INSERT INTO commits (sha, author_name, author_email, timestamp, message, repository, classification_id) \
VALUES (?1, 'a', 'a@x', ?2, 'fix: handle null user', 'r', ?3)",
params![sha, ts, cls_id],
)
.expect("insert commit");
}
let pipeline = ClassificationPipeline::new(Config::default())
.with_force(true)
.with_since(Some("2025-01-01".to_string()));
let engine = ClassificationEngine::new(default_rules(), ClassificationEngineConfig::default())
.expect("engine");
pipeline
.run_with_engine(&mut db, engine)
.await
.expect("force+since");
let new_cat: String = db
.connection()
.query_row(
"SELECT cl.category FROM classifications cl \
JOIN commits c ON c.classification_id = cl.id WHERE c.sha = 'sha-new'",
[],
|row| row.get(0),
)
.expect("query new");
assert_eq!(new_cat, "bugfix");
let old_cat: String = db
.connection()
.query_row(
"SELECT cl.category FROM classifications cl \
JOIN commits c ON c.classification_id = cl.id WHERE c.sha = 'sha-old'",
[],
|row| row.get(0),
)
.expect("query old");
assert_eq!(
old_cat, "feature",
"--since must exclude commits older than the bound"
);
}
/// Exercises `pipeline_db::read_candidate_commits` with three argument
/// combinations against a database holding two classified commits
/// ('old', 'new') and one unclassified commit ('null'):
///
/// * force off, no since: only 'null' is returned;
/// * force on, no since: all three are returned (compared after sorting);
/// * force on, since "2025-01-01": only 'new' is returned.
#[test]
fn read_candidate_commits_branches_select_correctly() {
    use super::pipeline_db::read_candidate_commits;

    let db = Database::open_in_memory().expect("db");

    // Seed inside a scope so the connection handle is released before the reads.
    {
        let conn = db.connection();
        conn.execute(
            "INSERT INTO classifications (category, confidence, method) \
             VALUES ('feature', 0.5, 'regex_rule')",
            [],
        )
        .expect("cls");
        let shared_cls_id = conn.last_insert_rowid();

        conn.execute(
            "INSERT INTO commits (sha, author_name, author_email, timestamp, message, repository, classification_id) \
             VALUES ('old', 'a', 'a@x', '2023-01-01T00:00:00Z', 'm', 'r', ?1)",
            params![shared_cls_id],
        )
        .expect("insert old");
        conn.execute(
            "INSERT INTO commits (sha, author_name, author_email, timestamp, message, repository, classification_id) \
             VALUES ('new', 'a', 'a@x', '2025-06-01T00:00:00Z', 'm', 'r', ?1)",
            params![shared_cls_id],
        )
        .expect("insert new");
        conn.execute(
            "INSERT INTO commits (sha, author_name, author_email, timestamp, message, repository) \
             VALUES ('null', 'a', 'a@x', '2024-01-01T00:00:00Z', 'm', 'r')",
            [],
        )
        .expect("insert unclassified");
    }

    // Default: only the commit without a classification.
    let default_rows = read_candidate_commits(&db, false, None, None, &[]).expect("default");
    let default_shas: Vec<&str> = default_rows.iter().map(|c| c.sha.as_str()).collect();
    assert_eq!(default_shas, vec!["null"]);

    // Force: every commit; the result is sorted here before comparing.
    let forced_rows = read_candidate_commits(&db, true, None, None, &[]).expect("force");
    let mut forced_shas: Vec<&str> = forced_rows.iter().map(|c| c.sha.as_str()).collect();
    forced_shas.sort_unstable();
    assert_eq!(forced_shas, vec!["new", "null", "old"]);

    // Force + since: only the commit dated after the bound.
    let since_rows = read_candidate_commits(&db, true, Some("2025-01-01"), None, &[])
        .expect("force+since");
    let since_shas: Vec<&str> = since_rows.iter().map(|c| c.sha.as_str()).collect();
    assert_eq!(since_shas, vec!["new"]);
}
#[tokio::test]
async fn pipeline_sets_is_revert_for_revert_verdicts() {
let mut db = Database::open_in_memory().expect("db");
let revert_id = insert_commit(&db, "sha-revert", "Revert \"feat: add login\"");
let feature_id = insert_commit(&db, "sha-feat", "feat: add login form");
let config = Config::default();
let pipeline = ClassificationPipeline::new(config);
let engine = ClassificationEngine::new(default_rules(), ClassificationEngineConfig::default())
.expect("engine builds");
pipeline
.run_with_engine(&mut db, engine)
.await
.expect("run pipeline");
let revert_flag: i64 = db
.connection()
.query_row(
"SELECT is_revert FROM commits WHERE id = ?1",
params![revert_id],
|row| row.get(0),
)
.expect("query revert");
assert_eq!(
revert_flag, 1,
"revert verdict must set commits.is_revert=1"
);
let feat_flag: i64 = db
.connection()
.query_row(
"SELECT is_revert FROM commits WHERE id = ?1",
params![feature_id],
|row| row.get(0),
)
.expect("query feature");
assert_eq!(
feat_flag, 0,
"non-revert verdict must leave commits.is_revert at 0"
);
}
/// Table-driven check of `pipeline_db::is_revert_verdict`.
///
/// Accepted: "revert" / "rollback" as the category in either lower or upper
/// case shown below, or as the subcategory of a "merge" category.
/// Rejected: unrelated categories and the near-miss "reverted".
#[test]
fn is_revert_verdict_recognizes_canonical_markers() {
    use super::pipeline_db::is_revert_verdict;

    let accepted: [(&str, Option<&str>); 6] = [
        ("revert", None),
        ("Revert", None),
        ("rollback", None),
        ("ROLLBACK", None),
        ("merge", Some("revert")),
        ("merge", Some("rollback")),
    ];
    for (category, subcategory) in accepted {
        assert!(
            is_revert_verdict(category, subcategory),
            "expected ({category:?}, {subcategory:?}) to be a revert verdict"
        );
    }

    let rejected: [(&str, Option<&str>); 3] = [
        ("feature", None),
        ("bugfix", Some("hotfix")),
        ("reverted", None),
    ];
    for (category, subcategory) in rejected {
        assert!(
            !is_revert_verdict(category, subcategory),
            "expected ({category:?}, {subcategory:?}) NOT to be a revert verdict"
        );
    }
}
/// Runs the pipeline with an LLM-enabled engine pointed at a mock server that
/// reports complexity 2, then checks that the classification linked to the
/// commit stores `complexity = 2`.
///
/// NOTE(review): the nonsense commit message and `llm_fallback_threshold: 1.0`
/// presumably ensure the LLM tier is consulted. Confirm against the engine.
#[tokio::test]
async fn pipeline_writes_complexity_to_db() {
    use crate::core::config::ClassificationConfig;

    let server = mock_llm_server("feature", 0.9, 2).await;
    let base = server.uri();
    let endpoint = format!("{base}/v1/chat/completions");

    let mut db = Database::open_in_memory().expect("db");
    insert_commit(&db, "sha-a", "zzz qqq vvv www yyy uuu");

    let config = Config {
        classification: Some(ClassificationConfig {
            use_llm: true,
            llm_fallback_threshold: 1.0,
            ..ClassificationConfig::default()
        }),
        ..Config::default()
    };
    let pipeline = ClassificationPipeline::new(config);
    let engine = engine_with_mock_llm(&endpoint);
    pipeline
        .run_with_engine(&mut db, engine)
        .await
        .expect("run pipeline");

    let stored_complexity = db
        .connection()
        .query_row(
            "SELECT cl.complexity FROM classifications cl \
             JOIN commits c ON c.classification_id = cl.id \
             WHERE c.sha = 'sha-a'",
            [],
            |row| row.get::<_, Option<i64>>(0),
        )
        .expect("query complexity");
    assert_eq!(stored_complexity, Some(2));
}
#[tokio::test]
async fn backfill_complexity_updates_only_null_rows() {
let server = mock_llm_server("feature", 0.9, 4).await;
let endpoint = format!("{}/v1/chat/completions", server.uri());
let mut db = Database::open_in_memory().expect("db");
db.connection()
.execute(
"INSERT INTO classifications (category, confidence, method, complexity) \
VALUES ('feature', 0.5, 'regex_rule', NULL)",
[],
)
.expect("insert cl 1");
let cl1 = db.connection().last_insert_rowid();
let c1 = insert_commit(&db, "sha-null", "needs scoring");
db.connection()
.execute(
"UPDATE commits SET classification_id = ?1 WHERE id = ?2",
params![cl1, c1],
)
.expect("link 1");
db.connection()
.execute(
"INSERT INTO classifications (category, confidence, method, complexity) \
VALUES ('bugfix', 0.8, 'regex_rule', 3)",
[],
)
.expect("insert cl 2");
let cl2 = db.connection().last_insert_rowid();
let c2 = insert_commit(&db, "sha-scored", "already scored");
db.connection()
.execute(
"UPDATE commits SET classification_id = ?1 WHERE id = ?2",
params![cl2, c2],
)
.expect("link 2");
let engine = engine_with_mock_llm(&endpoint);
let updated = ClassificationPipeline::backfill_complexity_with_engine(&mut db, &engine)
.await
.expect("backfill");
assert_eq!(updated, 1, "only the NULL row should be updated");
let filled: Option<i64> = db
.connection()
.query_row(
"SELECT complexity FROM classifications WHERE id = ?1",
params![cl1],
|row| row.get(0),
)
.expect("query filled");
assert_eq!(filled, Some(4), "NULL row backfilled to the LLM score");
let unchanged: Option<i64> = db
.connection()
.query_row(
"SELECT complexity FROM classifications WHERE id = ?1",
params![cl2],
|row| row.get(0),
)
.expect("query unchanged");
assert_eq!(unchanged, Some(3), "already-scored row must be unchanged");
}
/// Checks model resolution for each `LlmSource`: with `model: None` the
/// per-source default is chosen, and an explicit `model` always wins.
///
/// NOTE(review): the fallback table (`default_for`) is defined inside this
/// test, so the assertions exercise that local helper. Confirm it mirrors the
/// production resolution logic.
#[test]
fn source_aware_default_model_selection() {
    use crate::classify::tiers::bedrock::DEFAULT_BEDROCK_MODEL;
    use crate::classify::tiers::llm::ANTHROPIC_DEFAULT_MODEL;
    use crate::core::config::{LlmConfig, LlmSource};

    /// Default model name for each source.
    fn default_for(source: LlmSource) -> &'static str {
        match source {
            LlmSource::Openrouter => "gpt-4o-mini",
            LlmSource::AnthropicApi => ANTHROPIC_DEFAULT_MODEL,
            LlmSource::Bedrock => DEFAULT_BEDROCK_MODEL,
        }
    }

    /// Explicit model if present, otherwise the default for the config's source.
    fn resolve(cfg: &LlmConfig) -> &str {
        match cfg.model.as_deref() {
            Some(explicit) => explicit,
            None => default_for(cfg.source.clone()),
        }
    }

    // No explicit model: each source falls back to its own default.
    let fallback_cases = [
        (
            LlmSource::Bedrock,
            DEFAULT_BEDROCK_MODEL,
            "bedrock source with no model must fall back to DEFAULT_BEDROCK_MODEL",
        ),
        (
            LlmSource::AnthropicApi,
            ANTHROPIC_DEFAULT_MODEL,
            "anthropic-api source with no model must fall back to ANTHROPIC_DEFAULT_MODEL",
        ),
        (
            LlmSource::Openrouter,
            "gpt-4o-mini",
            "openrouter source with no model must fall back to gpt-4o-mini",
        ),
    ];
    for (source, expected, why) in fallback_cases {
        let cfg = LlmConfig {
            source,
            model: None,
            ..LlmConfig::default()
        };
        assert_eq!(resolve(&cfg), expected, "{why}");
    }

    // Explicit model: overrides the default regardless of source.
    let every_source = [
        LlmSource::Bedrock,
        LlmSource::AnthropicApi,
        LlmSource::Openrouter,
    ];
    for source in every_source {
        let cfg = LlmConfig {
            source,
            model: Some("my-custom-model".to_string()),
            ..LlmConfig::default()
        };
        assert_eq!(
            resolve(&cfg),
            "my-custom-model",
            "explicit model must override the source default"
        );
    }
}