use crate::entity_type::EntityType;
use crate::errors::AppError;
use crate::output::{self, OutputFormat};
use crate::paths::AppPaths;
use crate::storage::connection::open_rw;
use rusqlite::params;
use serde::Serialize;
// Command-line arguments for the `reclassify-relation` subcommand.
//
// Two modes share this struct: single mode renames the relation on one
// `--source`/`--target` edge, batch mode (`--batch`) renames every edge in the
// namespace carrying `--from-relation`.
//
// NOTE(review): this header is a plain `//` comment on purpose. clap derive turns
// `///` doc comments into help text, and a struct-level doc comment could change
// the command description; the field-level `///` lines below are intended to be
// shown as per-flag help.
#[derive(clap::Args)]
#[command(after_long_help = "EXAMPLES:\n \
    # Rename a single edge from 'mentions' to 'related'\n \
    sqlite-graphrag reclassify-relation --source tokio --target axum \\\n \
    --from-relation mentions --to-relation related\n\n \
    # Rename every 'mentions' edge in the namespace to 'related'\n \
    sqlite-graphrag reclassify-relation \\\n \
    --from-relation mentions --to-relation related --batch\n\n \
    # Dry-run to preview what would change\n \
    sqlite-graphrag reclassify-relation \\\n \
    --from-relation mentions --to-relation related --batch --dry-run\n\n \
    # Batch rename only edges whose source is a 'tool' entity\n \
    sqlite-graphrag reclassify-relation \\\n \
    --from-relation uses --to-relation depends_on --batch \\\n \
    --filter-source-type tool\n\n\
    NOTE:\n \
    Single mode requires --source, --target, --from-relation and --to-relation.\n \
    Batch mode requires --from-relation, --to-relation and --batch.\n \
    --filter-source-type and --filter-target-type are only effective in batch mode.")]
pub struct ReclassifyRelationArgs {
    /// Source entity of the edge to rename (single mode only)
    #[arg(long, conflicts_with = "batch", value_name = "ENTITY")]
    pub source: Option<String>,
    /// Target entity of the edge to rename (single mode only)
    #[arg(long, conflicts_with = "batch", value_name = "ENTITY")]
    pub target: Option<String>,
    /// Relation currently stored on the edge(s)
    #[arg(long, value_parser = crate::parsers::parse_relation, value_name = "RELATION")]
    pub from_relation: String,
    /// Relation to store instead; must differ from --from-relation
    #[arg(long, value_parser = crate::parsers::parse_relation, value_name = "RELATION")]
    pub to_relation: String,
    /// Rename every matching edge in the namespace instead of a single edge
    #[arg(long, default_value_t = false)]
    pub batch: bool,
    /// Batch mode: only touch edges whose source entity has this type
    #[arg(long, value_enum, value_name = "TYPE", requires = "batch")]
    pub filter_source_type: Option<EntityType>,
    /// Batch mode: only touch edges whose target entity has this type
    #[arg(long, value_enum, value_name = "TYPE", requires = "batch")]
    pub filter_target_type: Option<EntityType>,
    /// Report how many edges match without modifying the database
    #[arg(long, default_value_t = false)]
    pub dry_run: bool,
    /// Namespace to operate on
    #[arg(long)]
    pub namespace: Option<String>,
    /// Output format for the result summary
    #[arg(long, value_enum, default_value = "json")]
    pub format: OutputFormat,
    // Hidden compatibility flag; nothing in this file reads it. Its help text is
    // supplied explicitly in the attribute, so no doc comment is added here.
    #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
    pub json: bool,
    /// Path to the database file
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
}
/// Result summary of a `reclassify-relation` run; serialized to stdout when the
/// JSON output format is selected.
#[derive(Serialize)]
struct ReclassifyRelationResponse {
/// Outcome label; callers in this file pass `"reclassified"` or `"dry_run"`.
action: String,
/// Relation name the edges carried before the run (copied from the CLI args).
from_relation: String,
/// Relation name requested as the replacement (copied from the CLI args).
to_relation: String,
/// Number of edges updated, or the number of matching edges for a dry run.
count: usize,
/// Duplicate-merge tally supplied by the caller; always 0 for a dry run.
merged_duplicates: usize,
/// Namespace the operation ran against.
namespace: String,
/// Milliseconds elapsed since the start instant handed to `emit_response`.
elapsed_ms: u64,
}
/// Entry point for the `reclassify-relation` subcommand.
///
/// Resolves the namespace and database paths, runs the database readiness
/// check, passes both relation names through `warn_if_non_canonical`, rejects a
/// rename whose source and destination relation are equal, and finally opens a
/// read-write connection and hands off to the single-edge or batch handler.
///
/// # Errors
///
/// Returns `AppError::Validation` when `--from-relation` equals `--to-relation`,
/// and propagates any error from namespace/path resolution, the readiness
/// check, opening the connection, or the selected handler.
pub fn run(args: ReclassifyRelationArgs) -> Result<(), AppError> {
    let started = std::time::Instant::now();
    let resolved_ns = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
    let app_paths = AppPaths::resolve(args.db.as_deref())?;
    crate::storage::connection::ensure_db_ready(&app_paths)?;

    for relation in [&args.from_relation, &args.to_relation] {
        crate::parsers::warn_if_non_canonical(relation);
    }

    let is_noop_rename = args.from_relation == args.to_relation;
    if is_noop_rename {
        let message = "--from-relation and --to-relation must be different".to_string();
        return Err(AppError::Validation(message));
    }

    let mut db = open_rw(&app_paths.db)?;
    if !args.batch {
        return run_single(args, started, resolved_ns, &mut db);
    }
    run_batch(args, started, resolved_ns, &mut db)
}
fn run_single(
args: ReclassifyRelationArgs,
inicio: std::time::Instant,
namespace: String,
conn: &mut rusqlite::Connection,
) -> Result<(), AppError> {
let source_name = args.source.as_deref().ok_or_else(|| {
AppError::Validation(
"--source is required in single mode (omit --batch for single-edge rename)".to_string(),
)
})?;
let target_name = args
.target
.as_deref()
.ok_or_else(|| AppError::Validation("--target is required in single mode".to_string()))?;
let source_name_norm = crate::parsers::normalize_entity_name(source_name);
let target_name_norm = crate::parsers::normalize_entity_name(target_name);
let source_id: i64 = conn
.query_row(
"SELECT id FROM entities WHERE name = ?1 AND namespace = ?2",
params![source_name_norm, namespace],
|r| r.get(0),
)
.map_err(|_| {
AppError::NotFound(format!(
"source entity '{source_name}' not found in namespace '{namespace}'"
))
})?;
let target_id: i64 = conn
.query_row(
"SELECT id FROM entities WHERE name = ?1 AND namespace = ?2",
params![target_name_norm, namespace],
|r| r.get(0),
)
.map_err(|_| {
AppError::NotFound(format!(
"target entity '{target_name}' not found in namespace '{namespace}'"
))
})?;
let original_count: i64 = conn.query_row(
"SELECT COUNT(*) FROM relationships
WHERE source_id = ?1 AND target_id = ?2 AND relation = ?3 AND namespace = ?4",
params![source_id, target_id, args.from_relation, namespace],
|r| r.get(0),
)?;
if original_count == 0 {
return Err(AppError::NotFound(format!(
"edge '{source_name}' --[{}]--> '{target_name}' not found in namespace '{namespace}'",
args.from_relation
)));
}
if args.dry_run {
emit_response(
&args,
"dry_run",
original_count as usize,
0,
namespace,
inicio,
)?;
return Ok(());
}
let tx = conn.transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)?;
let updated = tx.execute(
"UPDATE OR IGNORE relationships
SET relation = ?1, updated_at = unixepoch()
WHERE source_id = ?2 AND target_id = ?3 AND relation = ?4 AND namespace = ?5",
params![
args.to_relation,
source_id,
target_id,
args.from_relation,
namespace
],
)?;
let deleted = tx.execute(
"DELETE FROM relationships
WHERE source_id = ?1 AND target_id = ?2 AND relation = ?3 AND namespace = ?4",
params![source_id, target_id, args.from_relation, namespace],
)?;
tx.commit()?;
conn.execute_batch("PRAGMA wal_checkpoint(TRUNCATE);")?;
let merged = (original_count as usize).saturating_sub(updated + deleted);
emit_response(&args, "reclassified", updated, merged, namespace, inicio)
}
/// Renames every edge in `namespace` that carries `--from-relation`, optionally
/// restricted by the type of its source and/or target entity.
///
/// Counts the matching edges first (logging a warning when there are none),
/// stops there for `--dry-run`, and otherwise rewrites the edges inside an
/// immediate transaction, then runs `ANALYZE relationships` and a WAL
/// checkpoint. The outcome is reported through `emit_response`.
///
/// # Errors
///
/// Propagates any database error and any error from `emit_response`.
fn run_batch(
    args: ReclassifyRelationArgs,
    inicio: std::time::Instant,
    namespace: String,
    conn: &mut rusqlite::Connection,
) -> Result<(), AppError> {
    // The filter values come from the `EntityType` enum's `as_str`, not from
    // free-form user input, which is why they are interpolated rather than bound.
    let source_filter = args
        .filter_source_type
        .map(|t| format!(" AND src.type = '{}'", t.as_str()))
        .unwrap_or_default();
    let target_filter = args
        .filter_target_type
        .map(|t| format!(" AND tgt.type = '{}'", t.as_str()))
        .unwrap_or_default();
    let has_filters = !source_filter.is_empty() || !target_filter.is_empty();

    // Shared FROM/JOIN/WHERE tail of every filtered statement; built once so the
    // count, the id selection and the cleanup can never drift apart.
    let filtered_edges = format!(
        "FROM relationships r\n\
         JOIN entities src ON src.id = r.source_id\n\
         JOIN entities tgt ON tgt.id = r.target_id\n\
         WHERE r.relation = ?1 AND r.namespace = ?2{source_filter}{target_filter}"
    );

    let original_count: i64 = if has_filters {
        conn.query_row(
            &format!("SELECT COUNT(*) {filtered_edges}"),
            params![args.from_relation, namespace],
            |r| r.get(0),
        )?
    } else {
        conn.query_row(
            "SELECT COUNT(*) FROM relationships\n\
             WHERE relation = ?1 AND namespace = ?2",
            params![args.from_relation, namespace],
            |r| r.get(0),
        )?
    };
    if original_count == 0 {
        tracing::warn!(
            from_relation = %args.from_relation,
            namespace = %namespace,
            "reclassify-relation batch matched zero edges — verify --from-relation value"
        );
    }

    if args.dry_run {
        // COUNT(*) is never negative, so the cast cannot wrap.
        return emit_response(
            &args,
            "dry_run",
            original_count as usize,
            0,
            namespace,
            inicio,
        );
    }

    let tx = conn.transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)?;

    // OR IGNORE skips a row when the rename would violate a constraint (presumably
    // a uniqueness rule on source/target/relation — confirm against the schema),
    // i.e. when the destination edge already exists.
    let updated = if has_filters {
        let ids: Vec<i64> = {
            let mut select = tx.prepare(&format!("SELECT r.id {filtered_edges}"))?;
            let collected = select
                .query_map(params![args.from_relation, namespace], |r| r.get(0))?
                .collect::<Result<Vec<i64>, _>>()?;
            collected
        };
        // Prepared once and reused for every id instead of re-parsing per row.
        let mut update = tx.prepare(
            "UPDATE OR IGNORE relationships\n\
             SET relation = ?1, updated_at = unixepoch()\n\
             WHERE id = ?2",
        )?;
        let mut moved: usize = 0;
        for id in &ids {
            moved += update.execute(params![args.to_relation, id])?;
        }
        moved
    } else {
        tx.execute(
            "UPDATE OR IGNORE relationships\n\
             SET relation = ?1, updated_at = unixepoch()\n\
             WHERE relation = ?2 AND namespace = ?3",
            params![args.to_relation, args.from_relation, namespace],
        )?
    };

    // Anything still carrying the old relation was skipped by OR IGNORE; drop it
    // so only the pre-existing destination edge remains.
    let deleted = if has_filters {
        tx.execute(
            &format!(
                "DELETE FROM relationships WHERE id IN (\n\
                 SELECT r.id {filtered_edges}\n\
                 )"
            ),
            params![args.from_relation, namespace],
        )?
    } else {
        tx.execute(
            "DELETE FROM relationships WHERE relation = ?1 AND namespace = ?2",
            params![args.from_relation, namespace],
        )?
    };
    tx.commit()?;
    conn.execute_batch("ANALYZE relationships;")?;
    conn.execute_batch("PRAGMA wal_checkpoint(TRUNCATE);")?;

    // Rows deleted here are exactly the ones merged into an existing edge. The
    // previous `original - (updated + deleted)` formula evaluated to 0 whenever
    // no other writer intervened, because every matched row is either updated
    // or deleted.
    emit_response(&args, "reclassified", updated, deleted, namespace, inicio)
}
/// Reports the outcome of a reclassify run in the format chosen by `args.format`.
///
/// For JSON a `ReclassifyRelationResponse` is serialized through
/// `output::emit_json`; for text and markdown a one-line summary is passed to
/// `output::emit_text`. The elapsed time is measured from `inicio` before either
/// branch runs.
///
/// # Errors
///
/// Propagates the error returned by `output::emit_json`.
fn emit_response(
    args: &ReclassifyRelationArgs,
    action: &str,
    count: usize,
    merged_duplicates: usize,
    namespace: String,
    inicio: std::time::Instant,
) -> Result<(), AppError> {
    let elapsed_ms = inicio.elapsed().as_millis() as u64;
    match args.format {
        OutputFormat::Json => {
            let payload = ReclassifyRelationResponse {
                action: action.to_string(),
                from_relation: args.from_relation.clone(),
                to_relation: args.to_relation.clone(),
                count,
                merged_duplicates,
                namespace,
                elapsed_ms,
            };
            output::emit_json(&payload)?;
        }
        OutputFormat::Text | OutputFormat::Markdown => {
            let summary = format!(
                "{action}: {count} edges '{}' → '{}' [{namespace}] (duplicates merged: {merged_duplicates})",
                args.from_relation, args.to_relation
            );
            output::emit_text(&summary);
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a response with fixed relation names, namespace and elapsed time.
    fn sample(action: &str, count: usize, merged: usize) -> ReclassifyRelationResponse {
        let from_relation = "mentions".to_string();
        let to_relation = "related".to_string();
        let namespace = "global".to_string();
        ReclassifyRelationResponse {
            action: action.to_string(),
            from_relation,
            to_relation,
            count,
            merged_duplicates: merged,
            namespace,
            elapsed_ms: 1,
        }
    }

    /// Serializes a response to a JSON value, panicking with `msg` on failure.
    fn as_json(resp: &ReclassifyRelationResponse, msg: &str) -> serde_json::Value {
        serde_json::to_value(resp).expect(msg)
    }

    #[test]
    fn response_serializes_all_fields() {
        let value = as_json(&sample("reclassified", 5, 0), "serialization failed");
        assert_eq!(value["action"], "reclassified");
        assert_eq!(value["from_relation"], "mentions");
        assert_eq!(value["to_relation"], "related");
        assert_eq!(value["count"], 5);
        assert_eq!(value["merged_duplicates"], 0);
        assert_eq!(value["namespace"], "global");
        assert!(value["elapsed_ms"].is_number());
    }

    #[test]
    fn response_action_dry_run() {
        let value = as_json(&sample("dry_run", 10, 0), "serialization failed");
        assert_eq!(value["action"], "dry_run");
        assert_eq!(value["count"], 10);
        assert_eq!(value["merged_duplicates"], 0);
    }

    #[test]
    fn response_merged_duplicates_nonzero() {
        let value = as_json(&sample("reclassified", 7, 3), "serialization failed");
        assert_eq!(value["count"], 7);
        assert_eq!(value["merged_duplicates"], 3);
    }

    #[test]
    fn response_count_zero_when_nothing_matched() {
        let value = as_json(&sample("reclassified", 0, 0), "serialization failed");
        assert_eq!(value["count"], 0);
        assert_eq!(value["merged_duplicates"], 0);
    }

    #[test]
    fn response_action_values_exhaustive() {
        for &label in &["reclassified", "dry_run"] {
            let value = as_json(&sample(label, 1, 0), "serialization");
            assert_eq!(value["action"], label);
        }
    }

    #[test]
    fn response_from_and_to_relation_present() {
        let custom = ReclassifyRelationResponse {
            action: "reclassified".to_string(),
            from_relation: "uses".to_string(),
            to_relation: "depends_on".to_string(),
            count: 3,
            merged_duplicates: 1,
            namespace: "my-project".to_string(),
            elapsed_ms: 5,
        };
        let value = as_json(&custom, "serialization failed");
        assert_eq!(value["from_relation"], "uses");
        assert_eq!(value["to_relation"], "depends_on");
    }

    // NOTE(review): this only compares two identical literals; it does not call
    // `run`, so it documents the rule rather than exercising it.
    #[test]
    fn same_relation_value_rejected_at_logic_level() {
        let from = String::from("mentions");
        let to = String::from("mentions");
        assert!(
            from == to,
            "same-value rename must be caught before DB access"
        );
    }
}