use std::sync::{Arc, Mutex};
use rusqlite::{params, Connection, OptionalExtension};
use tracing::warn;
use crate::classify::taxonomy::TaxonomyRegistry;
use crate::classify::tiers::ClassificationResult;
use crate::core::models::ClassificationMethod;
pub struct OverrideTier {
conn: Arc<Mutex<Connection>>,
taxonomy: TaxonomyRegistry,
}
impl OverrideTier {
pub fn new(conn: Arc<Mutex<Connection>>) -> Self {
Self {
conn,
taxonomy: TaxonomyRegistry::with_builtins(),
}
}
pub fn with_taxonomy(conn: Arc<Mutex<Connection>>, taxonomy: TaxonomyRegistry) -> Self {
Self { conn, taxonomy }
}
pub fn lookup(&self, commit_sha: &str, repo_path: &str) -> Option<ClassificationResult> {
let guard = match self.conn.lock() {
Ok(g) => g,
Err(e) => {
warn!(error = %e, "override tier mutex poisoned");
return None;
}
};
let row = guard
.query_row(
"SELECT work_type, change_type FROM classification_overrides \
WHERE commit_sha = ?1 AND repo_path = ?2",
params![commit_sha, repo_path],
|row| {
let work_type: String = row.get(0)?;
let change_type: String = row.get(1)?;
Ok((work_type, change_type))
},
)
.optional();
match row {
Ok(Some((work_type, change_type))) => {
let top_level = self
.taxonomy
.resolve(&change_type)
.or_else(|| self.taxonomy.resolve(&work_type));
Some(ClassificationResult {
category: work_type,
subcategory: Some(change_type),
top_level,
confidence: 1.0,
method: ClassificationMethod::Manual,
ticket_id: None,
complexity: None,
})
}
Ok(None) => None,
Err(e) => {
warn!(error = %e, commit_sha, "override lookup failed");
None
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens an in-memory database with the project schema applied.
    fn fresh_conn() -> Arc<Mutex<Connection>> {
        let mut conn = Connection::open_in_memory().expect("open in-memory db");
        crate::core::db::migrations::run(&mut conn).expect("run migrations");
        Arc::new(Mutex::new(conn))
    }

    /// Inserts a single override row through the shared connection.
    fn insert_override(
        conn: &Mutex<Connection>,
        commit_sha: &str,
        repo_path: &str,
        work_type: &str,
        change_type: &str,
    ) {
        conn.lock()
            .expect("lock")
            .execute(
                "INSERT INTO classification_overrides \
                 (commit_sha, repo_path, work_type, change_type) \
                 VALUES (?1, ?2, ?3, ?4)",
                params![commit_sha, repo_path, work_type, change_type],
            )
            .expect("insert override");
    }

    #[test]
    fn lookup_returns_result_on_hit() {
        let conn = fresh_conn();
        insert_override(&conn, "abc123", "/tmp/repo", "feature", "feature");
        let tier = OverrideTier::new(conn);
        let r = tier.lookup("abc123", "/tmp/repo").expect("hit");
        assert_eq!(r.category, "feature");
        assert_eq!(r.subcategory.as_deref(), Some("feature"));
        assert!((r.confidence - 1.0).abs() < 1e-9);
        assert_eq!(r.method, ClassificationMethod::Manual);
    }

    #[test]
    fn lookup_is_repeatable() {
        let conn = fresh_conn();
        insert_override(&conn, "abc123", "/tmp/repo", "bugfix", "bugfix");
        let tier = OverrideTier::new(conn);
        let first = tier.lookup("abc123", "/tmp/repo").expect("first hit");
        let second = tier.lookup("abc123", "/tmp/repo").expect("second hit");
        assert_eq!(first.category, second.category);
        assert_eq!(first.subcategory, second.subcategory);
    }

    #[test]
    fn lookup_returns_none_on_miss() {
        let conn = fresh_conn();
        let tier = OverrideTier::new(conn);
        assert!(tier.lookup("missing", "/tmp/repo").is_none());
    }

    #[test]
    fn lookup_different_repo_misses() {
        let conn = fresh_conn();
        insert_override(&conn, "sha1", "/repo/a", "bugfix", "bugfix");
        let tier = OverrideTier::new(conn);
        assert!(tier.lookup("sha1", "/repo/b").is_none());
        assert!(tier.lookup("sha1", "/repo/a").is_some());
    }

    #[test]
    fn lookup_returns_none_when_mutex_poisoned() {
        let conn = fresh_conn();
        // A matching row exists, so a `None` below can only come from the
        // poisoned-lock path, not from a miss.
        insert_override(&conn, "abc123", "/tmp/repo", "feature", "feature");

        let poisoner = Arc::clone(&conn);
        let _ = std::thread::spawn(move || {
            let _guard = poisoner.lock().expect("lock");
            panic!("poison the connection mutex");
        })
        .join();
        assert!(conn.is_poisoned());

        let tier = OverrideTier::new(conn);
        assert!(tier.lookup("abc123", "/tmp/repo").is_none());
    }
}