use std::io::{self, BufRead, Write};
use clap::{Args, Subcommand};
use rusqlite::params;
use tga::core::config::Config;
use tga::core::db::Database;
// CLI arguments for managing classification overrides; the operation itself is
// chosen by the nested `OverrideSubcommand`. Plain `//` comments are used here on
// purpose: `///` doc comments on a clap-derived type would be rendered as help text.
#[derive(Args, Debug)]
pub struct OverrideArgs {
// The override operation to run (add / list / remove); dispatched by `run`.
#[command(subcommand)]
pub subcommand: OverrideSubcommand,
}
// Operations on the `classification_overrides` table. The `///` comments on the
// variants and fields below are rendered by clap as `--help` text, so keep them
// short and user-facing.
#[derive(Subcommand, Debug)]
pub enum OverrideSubcommand {
    /// Add (or replace) a classification override for a commit.
    Add {
        /// SHA of the commit to override.
        sha: String,
        /// Work type to record for the commit.
        work_type: String,
        /// Change type to record for the commit.
        change_type: String,
        /// Optional free-form note stored with the override.
        #[arg(long)]
        notes: Option<String>,
        /// Repository path of the commit; looked up in the commits table when omitted.
        #[arg(long)]
        repo: Option<String>,
    },
    /// List stored overrides, newest first.
    List {
        /// Only show overrides for this repository path.
        #[arg(long)]
        repo: Option<String>,
    },
    /// Remove every override stored for a commit SHA (in all repositories).
    Remove {
        /// SHA whose overrides should be deleted.
        sha: String,
        /// Skip the confirmation prompt.
        #[arg(long, default_value_t = false)]
        yes: bool,
    },
}
pub fn run(_config: Config, db: &mut Database, args: OverrideArgs) -> anyhow::Result<()> {
match args.subcommand {
OverrideSubcommand::Add {
sha,
work_type,
change_type,
notes,
repo,
} => add(
db,
&sha,
&work_type,
&change_type,
notes.as_deref(),
repo.as_deref(),
),
OverrideSubcommand::List { repo } => list(db, repo.as_deref()),
OverrideSubcommand::Remove { sha, yes } => remove(db, &sha, yes, &mut io::stdin().lock()),
}
}
/// Record (or replace) a manual classification override for one commit.
///
/// The override's repository is `repo` when given. Otherwise it is resolved from the
/// `commits` table by [`repo_for_sha`]. Before writing, the commit's current
/// classification (if any) is printed so the user can see what is being overridden.
/// The row is written with `INSERT OR REPLACE`, so a row that conflicts on the table's
/// unique key is replaced. The key is defined by the schema, which is not shown here.
///
/// # Errors
///
/// Fails when `repo` is `None` and the SHA is absent from `commits` or is recorded
/// under more than one repository. Also fails when any database statement fails.
fn add(
    db: &Database,
    sha: &str,
    work_type: &str,
    change_type: &str,
    notes: Option<&str>,
    repo: Option<&str>,
) -> anyhow::Result<()> {
    let repo_path = match repo {
        Some(r) => r.to_string(),
        None => repo_for_sha(db, sha)?,
    };
    match lookup_current_classification(db, sha)? {
        Some((cat, sub, conf, method)) => {
            println!(
                "Current classification for {sha}: category={cat}, subcategory={}, \
                 confidence={conf:.2}, method={method}",
                sub.as_deref().unwrap_or("-")
            );
        }
        None => println!("(no existing classification for {sha} — adding fresh override)"),
    }
    db.connection().execute(
        "INSERT OR REPLACE INTO classification_overrides \
         (commit_sha, repo_path, work_type, change_type, notes) \
         VALUES (?1, ?2, ?3, ?4, ?5)",
        params![sha, repo_path, work_type, change_type, notes],
    )?;
    println!(
        "Override added: {sha} ({repo_path}) -> work_type={work_type}, change_type={change_type}"
    );
    Ok(())
}

/// Resolve the single repository that records commit `sha` in the `commits` table.
///
/// Picking one row arbitrarily would silently attach the override to the wrong
/// repository when the same SHA was ingested from several repos. Both the "unknown"
/// and the "ambiguous" cases therefore ask the user to pass `--repo`.
fn repo_for_sha(db: &Database, sha: &str) -> anyhow::Result<String> {
    let conn = db.connection();
    let mut stmt = conn.prepare("SELECT DISTINCT repository FROM commits WHERE sha = ?1")?;
    let repos = stmt
        .query_map(params![sha], |row| row.get::<_, String>(0))?
        .collect::<Result<Vec<_>, _>>()?;
    match repos.as_slice() {
        [only] => Ok(only.clone()),
        [] => anyhow::bail!(
            "no commit with sha {sha} found in DB and --repo not provided; \
             pass --repo <path> to add an override for an unknown SHA"
        ),
        many => anyhow::bail!(
            "sha {sha} is recorded in multiple repositories ({}); \
             pass --repo <path> to choose which one the override applies to",
            many.join(", ")
        ),
    }
}
/// Print stored classification overrides as a table, newest `created_at` first.
///
/// When `repo` is `Some`, only rows whose `repo_path` equals it are shown. SHAs are
/// truncated to their first 8 characters, and a missing note is printed as `-`.
/// If nothing matches, a "(no overrides found)" line follows the header.
///
/// # Errors
///
/// Returns any error from preparing or executing the query or from reading a column.
fn list(db: &Database, repo: Option<&str>) -> anyhow::Result<()> {
    let sql = if repo.is_some() {
        "SELECT commit_sha, repo_path, work_type, change_type, notes, created_at \
         FROM classification_overrides WHERE repo_path = ?1 ORDER BY created_at DESC"
    } else {
        "SELECT commit_sha, repo_path, work_type, change_type, notes, created_at \
         FROM classification_overrides ORDER BY created_at DESC"
    };
    let conn = db.connection();
    let mut stmt = conn.prepare(sql)?;
    // An `Option` yields zero or one value, matching the placeholder count of `sql`.
    let mut rows = stmt.query(rusqlite::params_from_iter(repo.iter()))?;

    println!(
        "{:<10} {:<32} {:<14} {:<14} {:<20} notes",
        "sha", "repo_path", "work_type", "change_type", "created_at"
    );
    println!("{}", "-".repeat(110));

    let mut printed = 0usize;
    while let Some(row) = rows.next()? {
        let sha: String = row.get(0)?;
        let repo_path: String = row.get(1)?;
        let work_type: String = row.get(2)?;
        let change_type: String = row.get(3)?;
        let notes: Option<String> = row.get(4)?;
        let created_at: String = row.get(5)?;
        let short_sha: String = sha.chars().take(8).collect();
        println!(
            "{:<10} {:<32} {:<14} {:<14} {:<20} {}",
            short_sha,
            repo_path,
            work_type,
            change_type,
            created_at,
            notes.as_deref().unwrap_or("-")
        );
        printed += 1;
    }

    if printed == 0 {
        println!("(no overrides found)");
    }
    Ok(())
}
/// Delete every override row stored for `sha`, after optional confirmation.
///
/// Rows are matched on `commit_sha` alone, so overrides for the same SHA in different
/// repositories are all removed. Unless `skip_confirm` is set, a `[y/N]` prompt goes
/// to stdout and one line is read from `reader`. Deletion proceeds only if that line
/// reads as "yes" according to `read_confirmation`. When no override exists, or the
/// user declines, nothing is deleted and `Ok(())` is returned.
///
/// `reader` is generic so tests can supply canned answers instead of stdin.
///
/// # Errors
///
/// Returns database errors, or I/O errors from flushing stdout or reading `reader`.
fn remove<R: BufRead>(
    db: &mut Database,
    sha: &str,
    skip_confirm: bool,
    reader: &mut R,
) -> anyhow::Result<()> {
    let conn = db.connection();
    let existing: i64 = conn.query_row(
        "SELECT COUNT(*) FROM classification_overrides WHERE commit_sha = ?1",
        params![sha],
        |row| row.get(0),
    )?;
    if existing == 0 {
        println!("No override exists for {sha}.");
        return Ok(());
    }

    if !skip_confirm {
        print!("Delete {} override row(s) for {}? [y/N] ", existing, sha);
        // `print!` does not flush, and the prompt must be visible before we block.
        io::stdout().flush()?;
        if !read_confirmation(reader)? {
            println!("Aborted.");
            return Ok(());
        }
    }

    let deleted = conn.execute(
        "DELETE FROM classification_overrides WHERE commit_sha = ?1",
        params![sha],
    )?;
    println!("Removed {deleted} override row(s) for {sha}.");
    Ok(())
}

/// Read one line from `reader` and report whether it is an affirmative answer.
///
/// Only `y` or `yes` count as affirmative. The check is case-insensitive and ignores
/// surrounding whitespace. End of input yields an empty line and therefore reads as "no".
fn read_confirmation<R: BufRead>(reader: &mut R) -> io::Result<bool> {
    let mut answer = String::new();
    reader.read_line(&mut answer)?;
    let answer = answer.trim().to_lowercase();
    Ok(answer == "y" || answer == "yes")
}
/// Classification currently linked to a commit:
/// `(category, subcategory, confidence, method)`.
type CurrentClassification = (String, Option<String>, f64, String);

/// Fetch the classification row linked to commit `sha` via `commits.classification_id`.
///
/// Returns `Ok(None)` when no commit with that SHA has a linked classification, since
/// the inner join drops unlinked commits. If several commits share the SHA, the first
/// row SQLite returns is used. There is no ORDER BY, so which row that is is unspecified.
///
/// # Errors
///
/// Returns any error from preparing or running the query or from reading a column.
fn lookup_current_classification(
    db: &Database,
    sha: &str,
) -> anyhow::Result<Option<CurrentClassification>> {
    const QUERY: &str = "SELECT c.category, c.subcategory, c.confidence, c.method \
         FROM commits cm \
         JOIN classifications c ON cm.classification_id = c.id \
         WHERE cm.sha = ?1";

    let conn = db.connection();
    let mut stmt = conn.prepare(QUERY)?;
    let mut rows = stmt.query(params![sha])?;
    match rows.next()? {
        None => Ok(None),
        Some(row) => {
            let category: String = row.get(0)?;
            let subcategory: Option<String> = row.get(1)?;
            let confidence: f64 = row.get(2)?;
            let method: String = row.get(3)?;
            Ok(Some((category, subcategory, confidence, method)))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh in-memory database for each test.
    fn fresh_db() -> Database {
        Database::open_in_memory().expect("open in-memory db")
    }

    /// Insert a minimal `commits` row so `add` can resolve the repository from it.
    fn insert_commit(db: &Database, sha: &str, repo: &str) {
        let sql = "INSERT INTO commits (sha, author_name, author_email, timestamp, message, repository) \
                   VALUES (?1, 'n', 'e', '2024-01-01T00:00:00Z', 'm', ?2)";
        db.connection()
            .execute(sql, params![sha, repo])
            .expect("insert commit");
    }

    /// Number of override rows currently stored for `sha`.
    fn overrides_for(db: &Database, sha: &str) -> i64 {
        db.connection()
            .query_row(
                "SELECT COUNT(*) FROM classification_overrides WHERE commit_sha = ?1",
                params![sha],
                |r| r.get(0),
            )
            .expect("count")
    }

    #[test]
    fn add_inserts_row() {
        let db = fresh_db();
        add(&db, "abc123", "feature", "feature", Some("note"), Some("/tmp/r")).expect("add ok");
        assert_eq!(overrides_for(&db, "abc123"), 1);
    }

    #[test]
    fn add_uses_repo_from_commits_when_omitted() {
        let db = fresh_db();
        insert_commit(&db, "sha-x", "/repo/x");
        add(&db, "sha-x", "feature", "feature", None, None).expect("add ok");
        let repo_path: String = db
            .connection()
            .query_row(
                "SELECT repo_path FROM classification_overrides WHERE commit_sha = ?1",
                params!["sha-x"],
                |r| r.get(0),
            )
            .expect("query");
        assert_eq!(repo_path, "/repo/x");
    }

    #[test]
    fn add_errors_when_repo_unresolvable() {
        let db = fresh_db();
        let message = add(&db, "missing", "feature", "feature", None, None)
            .expect_err("add should fail without a resolvable repo")
            .to_string();
        assert!(message.contains("--repo"), "unexpected error: {message}");
    }

    #[test]
    fn list_filters_by_repo() {
        let db = fresh_db();
        for (sha, repo) in [("a", "/repo/a"), ("b", "/repo/b")] {
            add(&db, sha, "feature", "feature", None, Some(repo)).expect("add");
        }
        let in_repo_a: i64 = db
            .connection()
            .query_row(
                "SELECT COUNT(*) FROM classification_overrides WHERE repo_path = ?1",
                params!["/repo/a"],
                |r| r.get(0),
            )
            .expect("count");
        assert_eq!(in_repo_a, 1);
    }

    #[test]
    fn remove_deletes_row() {
        let mut db = fresh_db();
        add(&db, "del-me", "feature", "feature", None, Some("/r")).expect("add");
        remove(&mut db, "del-me", false, &mut &b"y\n"[..]).expect("remove ok");
        assert_eq!(overrides_for(&db, "del-me"), 0);
    }

    #[test]
    fn remove_aborts_without_confirmation() {
        let mut db = fresh_db();
        add(&db, "keep", "feature", "feature", None, Some("/r")).expect("add");
        remove(&mut db, "keep", false, &mut &b"n\n"[..]).expect("remove returns ok");
        assert_eq!(overrides_for(&db, "keep"), 1);
    }
}