Skip to main content

arrow_graph_git/
revert.rs

1//! Revert — create a new commit that undoes the changes from a target commit.
2//!
3//! Unlike `checkout`, revert does not go back in time — it creates a NEW commit
4//! on the current branch that applies the inverse of the target commit's diff.
5
6use crate::checkout;
7use crate::commit::{CommitError, CommitsTable, create_commit};
8use crate::diff;
9use crate::object_store::GitObjectStore;
10use arrow_graph_core::{Triple, col};
11
12/// Errors from revert operations.
13#[derive(Debug, thiserror::Error)]
14pub enum RevertError {
15    #[error("Commit error: {0}")]
16    Commit(#[from] CommitError),
17
18    #[error("Store error: {0}")]
19    Store(#[from] arrow_graph_core::StoreError),
20
21    #[error("Cannot revert merge commit {0} (has {1} parents) — specify parent")]
22    MergeCommit(String, usize),
23
24    #[error("Commit has no parent: {0}")]
25    NoParent(String),
26}
27
28/// Revert a commit by creating a new commit that undoes its changes.
29///
30/// 1. Find the target commit's parent
31/// 2. Diff parent -> target to get what the commit changed
32/// 3. Apply the inverse (add removed triples, remove added triples) to HEAD
33/// 4. Create a new commit with the inverted changes
34///
35/// Returns the new revert commit's ID.
36pub fn revert(
37    obj_store: &mut GitObjectStore,
38    commits_table: &mut CommitsTable,
39    commit_id: &str,
40    head_commit_id: &str,
41    author: &str,
42) -> Result<String, RevertError> {
43    let target = commits_table
44        .get(commit_id)
45        .ok_or_else(|| CommitError::NotFound(commit_id.to_string()))?;
46
47    // Cannot revert merge commits (multiple parents)
48    if target.parent_ids.len() > 1 {
49        return Err(RevertError::MergeCommit(
50            commit_id.to_string(),
51            target.parent_ids.len(),
52        ));
53    }
54
55    // Must have a parent to compute the diff
56    if target.parent_ids.is_empty() {
57        return Err(RevertError::NoParent(commit_id.to_string()));
58    }
59
60    let parent_id = target.parent_ids[0].clone();
61    let target_message = target.message.clone();
62
63    // Compute what the target commit changed: diff parent -> target
64    let commit_diff = diff::diff(obj_store, commits_table, &parent_id, commit_id)?;
65
66    // Restore HEAD state
67    checkout::checkout(obj_store, commits_table, head_commit_id)?;
68
69    // Apply the INVERSE:
70    // - What was added by the commit should be removed
71    // - What was removed by the commit should be re-added
72
73    // Remove the added triples
74    for entry in &commit_diff.added {
75        let batches = obj_store.store.get_namespace_batches(&entry.namespace);
76        let mut ids_to_delete = Vec::new();
77        for batch in batches {
78            let id_col = batch
79                .column(col::TRIPLE_ID)
80                .as_any()
81                .downcast_ref::<arrow::array::StringArray>()
82                .expect("triple_id column");
83            let subj_col = batch
84                .column(col::SUBJECT)
85                .as_any()
86                .downcast_ref::<arrow::array::StringArray>()
87                .expect("subject column");
88            let pred_col = batch
89                .column(col::PREDICATE)
90                .as_any()
91                .downcast_ref::<arrow::array::StringArray>()
92                .expect("predicate column");
93            let obj_col = batch
94                .column(col::OBJECT)
95                .as_any()
96                .downcast_ref::<arrow::array::StringArray>()
97                .expect("object column");
98
99            for i in 0..batch.num_rows() {
100                if subj_col.value(i) == entry.subject
101                    && pred_col.value(i) == entry.predicate
102                    && obj_col.value(i) == entry.object
103                {
104                    ids_to_delete.push(id_col.value(i).to_string());
105                }
106            }
107        }
108        for id in &ids_to_delete {
109            let _ = obj_store.store.delete(id);
110        }
111    }
112
113    // Re-add the removed triples
114    for entry in &commit_diff.removed {
115        let triple = Triple {
116            subject: entry.subject.clone(),
117            predicate: entry.predicate.clone(),
118            object: entry.object.clone(),
119            graph: entry.graph.clone(),
120            confidence: entry.confidence,
121            source_document: entry.source_document.clone(),
122            source_chunk_id: entry.source_chunk_id.clone(),
123            extracted_by: Some(format!("revert by {author}")),
124            caused_by: entry.caused_by.clone(),
125            derived_from: entry.derived_from.clone(),
126            consolidated_at: entry.consolidated_at,
127        };
128        obj_store
129            .store
130            .add_triple(&triple, &entry.namespace, Some(entry.y_layer))?;
131    }
132
133    // Create the revert commit
134    let revert_commit = create_commit(
135        obj_store,
136        commits_table,
137        vec![head_commit_id.to_string()],
138        &format!("Revert: {target_message}"),
139        author,
140    )?;
141
142    Ok(revert_commit.commit_id)
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use crate::commit::create_commit;
    use arrow_graph_core::Triple;

    /// Build an `rdf:type` triple with confidence 0.9 and every optional
    /// provenance field left empty.
    fn sample_triple(subj: &str, obj: &str) -> Triple {
        Triple {
            subject: subj.to_string(),
            predicate: "rdf:type".to_string(),
            object: obj.to_string(),
            graph: None,
            confidence: Some(0.9),
            source_document: None,
            source_chunk_id: None,
            extracted_by: None,
            caused_by: None,
            derived_from: None,
            consolidated_at: None,
        }
    }

    /// Insert a sample triple into the "world" namespace at layer 1.
    fn insert(store: &mut GitObjectStore, subj: &str, object: &str) {
        store
            .store
            .add_triple(&sample_triple(subj, object), "world", Some(1u8))
            .unwrap();
    }

    /// Commit the current store state as author "test" and return the id of
    /// the new commit.
    fn commit(
        store: &GitObjectStore,
        commits: &mut CommitsTable,
        parents: Vec<String>,
        message: &str,
    ) -> String {
        create_commit(store, commits, parents, message, "test")
            .unwrap()
            .commit_id
    }

    /// Build the history `A (adds s1) -> B (adds s2)` and return both ids as
    /// `(a, b)`.
    fn two_commit_history(
        store: &mut GitObjectStore,
        commits: &mut CommitsTable,
    ) -> (String, String) {
        insert(store, "s1", "A");
        let a = commit(store, commits, vec![], "commit A");

        insert(store, "s2", "B");
        let b = commit(store, commits, vec![a.clone()], "commit B");

        (a, b)
    }

    #[test]
    fn test_revert_restores_previous_state() {
        let tmp = tempfile::tempdir().unwrap();
        let mut obj = GitObjectStore::with_snapshot_dir(tmp.path());
        let mut commits = CommitsTable::new();

        let (_a, b) = two_commit_history(&mut obj, &mut commits);

        let revert_id = revert(&mut obj, &mut commits, &b, &b, "test").unwrap();

        // Only the triple from commit A is left in the store.
        assert_eq!(obj.store.len(), 1);

        // The revert commit is a child of B and is labelled as a revert.
        let rc = commits.get(&revert_id).unwrap();
        assert!(rc.message.starts_with("Revert:"));
        assert_eq!(rc.parent_ids, vec![b.clone()]);
    }

    #[test]
    fn test_revert_of_revert_restores_original() {
        let tmp = tempfile::tempdir().unwrap();
        let mut obj = GitObjectStore::with_snapshot_dir(tmp.path());
        let mut commits = CommitsTable::new();

        let (_a, b) = two_commit_history(&mut obj, &mut commits);

        // First revert drops the triple added by B.
        let revert_id = revert(&mut obj, &mut commits, &b, &b, "test").unwrap();
        assert_eq!(obj.store.len(), 1);

        // Reverting the revert puts it back.
        let _revert2_id = revert(&mut obj, &mut commits, &revert_id, &revert_id, "test").unwrap();
        assert_eq!(obj.store.len(), 2);
    }

    #[test]
    fn test_revert_merge_commit_errors() {
        let tmp = tempfile::tempdir().unwrap();
        let mut obj = GitObjectStore::with_snapshot_dir(tmp.path());
        let mut commits = CommitsTable::new();

        insert(&mut obj, "s1", "A");
        let c1 = commit(&obj, &mut commits, vec![], "c1");
        let c2 = commit(&obj, &mut commits, vec![c1.clone()], "c2");
        // A commit with two parents stands in for a merge.
        let merge_c = commit(&obj, &mut commits, vec![c1.clone(), c2.clone()], "merge");

        let result = revert(&mut obj, &mut commits, &merge_c, &merge_c, "test");
        assert!(result.is_err());

        let err = result.unwrap_err();
        if let RevertError::MergeCommit(_, n) = err {
            assert_eq!(n, 2);
        } else {
            panic!("Expected MergeCommit error, got: {err:?}");
        }
    }

    #[test]
    fn test_revert_root_commit_errors() {
        let tmp = tempfile::tempdir().unwrap();
        let mut obj = GitObjectStore::with_snapshot_dir(tmp.path());
        let mut commits = CommitsTable::new();

        insert(&mut obj, "s1", "A");
        let c1 = commit(&obj, &mut commits, vec![], "root");

        let result = revert(&mut obj, &mut commits, &c1, &c1, "test");
        assert!(result.is_err());

        let err = result.unwrap_err();
        if !matches!(err, RevertError::NoParent(_)) {
            panic!("Expected NoParent error, got: {err:?}");
        }
    }
}