use std::collections::{BTreeMap, HashSet};

use arrow_graph_core::{Triple, col};

use crate::checkout;
use crate::commit::{CommitError, CommitsTable, create_commit};
use crate::diff;
use crate::object_store::GitObjectStore;
11
12#[derive(Debug, thiserror::Error)]
14pub enum RevertError {
15 #[error("Commit error: {0}")]
16 Commit(#[from] CommitError),
17
18 #[error("Store error: {0}")]
19 Store(#[from] arrow_graph_core::StoreError),
20
21 #[error("Cannot revert merge commit {0} (has {1} parents) — specify parent")]
22 MergeCommit(String, usize),
23
24 #[error("Commit has no parent: {0}")]
25 NoParent(String),
26}
27
28pub fn revert(
37 obj_store: &mut GitObjectStore,
38 commits_table: &mut CommitsTable,
39 commit_id: &str,
40 head_commit_id: &str,
41 author: &str,
42) -> Result<String, RevertError> {
43 let target = commits_table
44 .get(commit_id)
45 .ok_or_else(|| CommitError::NotFound(commit_id.to_string()))?;
46
47 if target.parent_ids.len() > 1 {
49 return Err(RevertError::MergeCommit(
50 commit_id.to_string(),
51 target.parent_ids.len(),
52 ));
53 }
54
55 if target.parent_ids.is_empty() {
57 return Err(RevertError::NoParent(commit_id.to_string()));
58 }
59
60 let parent_id = target.parent_ids[0].clone();
61 let target_message = target.message.clone();
62
63 let commit_diff = diff::diff(obj_store, commits_table, &parent_id, commit_id)?;
65
66 checkout::checkout(obj_store, commits_table, head_commit_id)?;
68
69 for entry in &commit_diff.added {
75 let batches = obj_store.store.get_namespace_batches(&entry.namespace);
76 let mut ids_to_delete = Vec::new();
77 for batch in batches {
78 let id_col = batch
79 .column(col::TRIPLE_ID)
80 .as_any()
81 .downcast_ref::<arrow::array::StringArray>()
82 .expect("triple_id column");
83 let subj_col = batch
84 .column(col::SUBJECT)
85 .as_any()
86 .downcast_ref::<arrow::array::StringArray>()
87 .expect("subject column");
88 let pred_col = batch
89 .column(col::PREDICATE)
90 .as_any()
91 .downcast_ref::<arrow::array::StringArray>()
92 .expect("predicate column");
93 let obj_col = batch
94 .column(col::OBJECT)
95 .as_any()
96 .downcast_ref::<arrow::array::StringArray>()
97 .expect("object column");
98
99 for i in 0..batch.num_rows() {
100 if subj_col.value(i) == entry.subject
101 && pred_col.value(i) == entry.predicate
102 && obj_col.value(i) == entry.object
103 {
104 ids_to_delete.push(id_col.value(i).to_string());
105 }
106 }
107 }
108 for id in &ids_to_delete {
109 let _ = obj_store.store.delete(id);
110 }
111 }
112
113 for entry in &commit_diff.removed {
115 let triple = Triple {
116 subject: entry.subject.clone(),
117 predicate: entry.predicate.clone(),
118 object: entry.object.clone(),
119 graph: entry.graph.clone(),
120 confidence: entry.confidence,
121 source_document: entry.source_document.clone(),
122 source_chunk_id: entry.source_chunk_id.clone(),
123 extracted_by: Some(format!("revert by {author}")),
124 caused_by: entry.caused_by.clone(),
125 derived_from: entry.derived_from.clone(),
126 consolidated_at: entry.consolidated_at,
127 };
128 obj_store
129 .store
130 .add_triple(&triple, &entry.namespace, Some(entry.y_layer))?;
131 }
132
133 let revert_commit = create_commit(
135 obj_store,
136 commits_table,
137 vec![head_commit_id.to_string()],
138 &format!("Revert: {target_message}"),
139 author,
140 )?;
141
142 Ok(revert_commit.commit_id)
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use crate::commit::create_commit;
    use arrow_graph_core::Triple;

    /// An `rdf:type` triple `subj -> obj` with confidence 0.9 and no
    /// provenance metadata.
    fn sample_triple(subj: &str, obj: &str) -> Triple {
        Triple {
            subject: subj.to_string(),
            predicate: "rdf:type".to_string(),
            object: obj.to_string(),
            graph: None,
            confidence: Some(0.9),
            source_document: None,
            source_chunk_id: None,
            extracted_by: None,
            caused_by: None,
            derived_from: None,
            consolidated_at: None,
        }
    }

    /// Add `sample_triple(subj, obj)` to the `world` namespace at layer 1.
    fn seed(repo: &mut GitObjectStore, subj: &str, obj: &str) {
        repo.store
            .add_triple(&sample_triple(subj, obj), "world", Some(1u8))
            .unwrap();
    }

    /// Commit the current store contents as author `test` and return the new id.
    fn snapshot(
        repo: &GitObjectStore,
        table: &mut CommitsTable,
        parents: Vec<String>,
        message: &str,
    ) -> String {
        create_commit(repo, table, parents, message, "test")
            .unwrap()
            .commit_id
    }

    #[test]
    fn test_revert_restores_previous_state() {
        let dir = tempfile::tempdir().unwrap();
        let mut repo = GitObjectStore::with_snapshot_dir(dir.path());
        let mut table = CommitsTable::new();

        seed(&mut repo, "s1", "A");
        let first = snapshot(&repo, &mut table, vec![], "commit A");
        seed(&mut repo, "s2", "B");
        let second = snapshot(&repo, &mut table, vec![first], "commit B");

        let reverted = revert(&mut repo, &mut table, &second, &second, "test").unwrap();

        assert_eq!(repo.store.len(), 1);
        let record = table.get(&reverted).unwrap();
        assert!(record.message.starts_with("Revert:"));
        assert_eq!(record.parent_ids, vec![second]);
    }

    #[test]
    fn test_revert_of_revert_restores_original() {
        let dir = tempfile::tempdir().unwrap();
        let mut repo = GitObjectStore::with_snapshot_dir(dir.path());
        let mut table = CommitsTable::new();

        seed(&mut repo, "s1", "A");
        let first = snapshot(&repo, &mut table, vec![], "commit A");
        seed(&mut repo, "s2", "B");
        let second = snapshot(&repo, &mut table, vec![first], "commit B");

        let undo = revert(&mut repo, &mut table, &second, &second, "test").unwrap();
        assert_eq!(repo.store.len(), 1);

        revert(&mut repo, &mut table, &undo, &undo, "test").unwrap();
        assert_eq!(repo.store.len(), 2);
    }

    #[test]
    fn test_revert_merge_commit_errors() {
        let dir = tempfile::tempdir().unwrap();
        let mut repo = GitObjectStore::with_snapshot_dir(dir.path());
        let mut table = CommitsTable::new();

        seed(&mut repo, "s1", "A");
        let left = snapshot(&repo, &mut table, vec![], "c1");
        let right = snapshot(&repo, &mut table, vec![left.clone()], "c2");
        let merge = snapshot(&repo, &mut table, vec![left, right], "merge");

        let outcome = revert(&mut repo, &mut table, &merge, &merge, "test");
        assert!(
            matches!(outcome, Err(RevertError::MergeCommit(_, 2))),
            "Expected MergeCommit error with 2 parents, got: {outcome:?}"
        );
    }

    #[test]
    fn test_revert_root_commit_errors() {
        let dir = tempfile::tempdir().unwrap();
        let mut repo = GitObjectStore::with_snapshot_dir(dir.path());
        let mut table = CommitsTable::new();

        seed(&mut repo, "s1", "A");
        let root = snapshot(&repo, &mut table, vec![], "root");

        let outcome = revert(&mut repo, &mut table, &root, &root, "test");
        assert!(
            matches!(outcome, Err(RevertError::NoParent(_))),
            "Expected NoParent error, got: {outcome:?}"
        );
    }
}
283}