sqlite_graphrag/commands/
reclassify.rs1use crate::entity_type::EntityType;
11use crate::errors::AppError;
12use crate::i18n::errors_msg;
13use crate::output::{self, OutputFormat};
14use crate::paths::AppPaths;
15use crate::storage::connection::open_rw;
16use crate::storage::entities;
17use rusqlite::params;
18use serde::Serialize;
19
20#[derive(clap::Args)]
21#[command(after_long_help = "EXAMPLES:\n \
22 # Reclassify a single entity from its current type to 'tool'\n \
23 sqlite-graphrag reclassify --name tokio-runtime --new-type tool\n\n \
24 # Reclassify all 'concept' entities to 'tool' in one shot (batch)\n \
25 sqlite-graphrag reclassify --from-type concept --to-type tool --batch\n\n \
26 # Reclassify in a specific namespace\n \
27 sqlite-graphrag reclassify --name alice --new-type person --namespace my-project\n\n\
28NOTE:\n \
29 Single mode requires --name and --new-type.\n \
30 Batch mode requires --from-type, --to-type and --batch.\n \
31 Providing --name together with --batch is an error.")]
32pub struct ReclassifyArgs {
33 #[arg(long, conflicts_with_all = ["from_type", "batch"])]
35 pub name: Option<String>,
36 #[arg(long, value_enum, value_name = "TYPE")]
38 pub new_type: Option<EntityType>,
39 #[arg(
41 long,
42 value_enum,
43 value_name = "TYPE",
44 requires = "to_type",
45 requires = "batch"
46 )]
47 pub from_type: Option<EntityType>,
48 #[arg(long, value_enum, value_name = "TYPE", requires = "from_type")]
50 pub to_type: Option<EntityType>,
51 #[arg(long, default_value_t = false, requires = "from_type")]
53 pub batch: bool,
54 #[arg(long)]
55 pub namespace: Option<String>,
56 #[arg(long, value_enum, default_value = "json")]
57 pub format: OutputFormat,
58 #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
59 pub json: bool,
60 #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
61 pub db: Option<String>,
62}
63
64#[derive(Serialize)]
65struct ReclassifyResponse {
66 action: String,
67 count: usize,
68 namespace: String,
69 elapsed_ms: u64,
71}
72
73pub fn run(args: ReclassifyArgs) -> Result<(), AppError> {
74 let inicio = std::time::Instant::now();
75 let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
76 let paths = AppPaths::resolve(args.db.as_deref())?;
77
78 crate::storage::connection::ensure_db_ready(&paths)?;
79
80 let mut conn = open_rw(&paths.db)?;
81
82 let count = if args.batch {
83 let from_type = args.from_type.ok_or_else(|| {
85 AppError::Validation("--from-type is required in batch mode".to_string())
86 })?;
87 let to_type = args.to_type.ok_or_else(|| {
88 AppError::Validation("--to-type is required in batch mode".to_string())
89 })?;
90
91 let tx = conn.transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)?;
92 let affected = tx.execute(
93 "UPDATE entities SET type = ?1, updated_at = unixepoch()
94 WHERE type = ?2 AND namespace = ?3",
95 params![to_type.as_str(), from_type.as_str(), namespace],
96 )?;
97 tx.commit()?;
98 affected
99 } else {
100 let entity_name = args
102 .name
103 .as_deref()
104 .ok_or_else(|| AppError::Validation("--name is required in single mode".to_string()))?;
105 let new_type = args.new_type.ok_or_else(|| {
106 AppError::Validation("--new-type is required in single mode".to_string())
107 })?;
108
109 entities::find_entity_id(&conn, &namespace, entity_name)?.ok_or_else(|| {
111 AppError::NotFound(errors_msg::entity_not_found(entity_name, &namespace))
112 })?;
113
114 let tx = conn.transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)?;
115 let affected = tx.execute(
116 "UPDATE entities SET type = ?1, updated_at = unixepoch()
117 WHERE name = ?2 AND namespace = ?3",
118 params![new_type.as_str(), entity_name, namespace],
119 )?;
120 tx.commit()?;
121 affected
122 };
123
124 conn.execute_batch("PRAGMA wal_checkpoint(TRUNCATE);")?;
125
126 let response = ReclassifyResponse {
127 action: "reclassified".to_string(),
128 count,
129 namespace: namespace.clone(),
130 elapsed_ms: inicio.elapsed().as_millis() as u64,
131 };
132
133 match args.format {
134 OutputFormat::Json => output::emit_json(&response)?,
135 OutputFormat::Text | OutputFormat::Markdown => {
136 output::emit_text(&format!(
137 "reclassified: {} entities [{}]",
138 response.count, response.namespace
139 ));
140 }
141 }
142
143 Ok(())
144}
145
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a response carrying the fixed `action` value that `run` emits.
    fn response(count: usize, namespace: &str, elapsed_ms: u64) -> ReclassifyResponse {
        ReclassifyResponse {
            action: "reclassified".to_string(),
            count,
            namespace: namespace.to_string(),
            elapsed_ms,
        }
    }

    #[test]
    fn reclassify_response_serializes_all_fields() {
        let json = serde_json::to_value(response(5, "global", 12)).expect("serialization failed");
        assert_eq!(json["action"], "reclassified");
        assert_eq!(json["count"], 5);
        assert_eq!(json["namespace"], "global");
        assert_eq!(json["elapsed_ms"], 12);
    }

    #[test]
    fn reclassify_response_count_zero_is_valid() {
        let json = serde_json::to_value(response(0, "my-project", 3)).expect("serialization failed");
        assert_eq!(json["count"], 0);
        assert_eq!(json["action"], "reclassified");
        assert_eq!(json["namespace"], "my-project");
    }

    /// Pins the public JSON contract: exactly these keys, nothing more or less.
    #[test]
    fn reclassify_response_has_exactly_expected_keys() {
        let json = serde_json::to_value(response(1, "ns", 1)).expect("serialization failed");
        let object = json
            .as_object()
            .expect("response must serialize to a JSON object");
        let mut keys: Vec<&str> = object.keys().map(String::as_str).collect();
        keys.sort_unstable();
        assert_eq!(keys, ["action", "count", "elapsed_ms", "namespace"]);
    }
}