Skip to main content

arrow_graph_git/
diff.rs

1//! Diff — object-level comparison between two commits.
2//!
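//! Compares the triples held in each commit's RecordBatches to find added and removed triples.
//! A triple is identified by (subject, predicate, object, namespace); a triple whose
//! identity is unchanged but whose metadata differs is not reported as a change.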
5
6use crate::checkout;
7use crate::commit::{CommitError, CommitsTable};
8use crate::object_store::GitObjectStore;
9use arrow::array::{Array, Float64Array, StringArray, TimestampMillisecondArray, UInt8Array};
10use arrow_graph_core::{QuerySpec, col};
11use std::collections::HashMap;
12
/// One triple reported by a diff, together with the provenance metadata read from
/// its row so that a merge can carry that metadata along.
///
/// The first four fields form the triple's identity; the remaining fields are
/// metadata. The per-field notes below describe how `extract_triples` fills them.
#[derive(Debug, Clone, PartialEq)]
pub struct DiffEntry {
    /// Subject of the triple (part of its identity).
    pub subject: String,
    /// Predicate of the triple (part of its identity).
    pub predicate: String,
    /// Object of the triple (part of its identity).
    pub object: String,
    /// Namespace the triple was read from (part of its identity).
    pub namespace: String,
    /// Value of the row's layer column.
    pub y_layer: u8,
    /// Value of the confidence column; `None` when the slot is null.
    pub confidence: Option<f64>,
    /// Value of the graph column; `None` when the slot is null.
    pub graph: Option<String>,
    /// Value of the source_document column; `None` when the slot is null.
    pub source_document: Option<String>,
    /// Value of the source_chunk_id column; `None` when the slot is null.
    pub source_chunk_id: Option<String>,
    /// Value of the caused_by column; `None` when the slot is null.
    pub caused_by: Option<String>,
    /// Value of the derived_from column; `None` when the slot is null.
    pub derived_from: Option<String>,
    /// Raw value of the consolidated_at column, which is read as a
    /// millisecond-resolution timestamp array; `None` when the slot is null.
    pub consolidated_at: Option<i64>,
}
29
30/// The result of a diff between two commits.
31#[derive(Debug, Clone, Default)]
32pub struct DiffResult {
33    /// Triples present in `head` but not in `base`.
34    pub added: Vec<DiffEntry>,
35    /// Triples present in `base` but not in `head`.
36    pub removed: Vec<DiffEntry>,
37}
38
39impl DiffResult {
40    pub fn is_empty(&self) -> bool {
41        self.added.is_empty() && self.removed.is_empty()
42    }
43
44    pub fn total_changes(&self) -> usize {
45        self.added.len() + self.removed.len()
46    }
47}
48
/// Identity of a triple, used as the hash-map key when comparing two commits.
///
/// Two rows denote the same triple exactly when subject, predicate, object and
/// namespace are all equal; metadata columns play no part in the comparison.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct TripleKey {
    subject: String,
    predicate: String,
    object: String,
    namespace: String,
}
57
58/// Extract all triples (key -> full DiffEntry) from the store's current state.
59fn extract_triples(store: &arrow_graph_core::ArrowGraphStore) -> HashMap<TripleKey, DiffEntry> {
60    let mut map = HashMap::new();
61
62    let batches = store
63        .query(&QuerySpec {
64            include_deleted: false,
65            ..Default::default()
66        })
67        .unwrap_or_default();
68
69    for batch in &batches {
70        let subjects = batch
71            .column(col::SUBJECT)
72            .as_any()
73            .downcast_ref::<StringArray>()
74            .expect("subject column");
75        let predicates = batch
76            .column(col::PREDICATE)
77            .as_any()
78            .downcast_ref::<StringArray>()
79            .expect("predicate column");
80        let objects = batch
81            .column(col::OBJECT)
82            .as_any()
83            .downcast_ref::<StringArray>()
84            .expect("object column");
85        let graphs = batch
86            .column(col::GRAPH)
87            .as_any()
88            .downcast_ref::<StringArray>()
89            .expect("graph column");
90        let namespaces = batch
91            .column(col::NAMESPACE)
92            .as_any()
93            .downcast_ref::<StringArray>()
94            .expect("namespace column");
95        let layers = batch
96            .column(col::LAYER)
97            .as_any()
98            .downcast_ref::<UInt8Array>()
99            .expect("layer column");
100        let confidences = batch
101            .column(col::CONFIDENCE)
102            .as_any()
103            .downcast_ref::<Float64Array>()
104            .expect("confidence column");
105        let source_docs = batch
106            .column(col::SOURCE_DOCUMENT)
107            .as_any()
108            .downcast_ref::<StringArray>()
109            .expect("source_document column");
110        let source_chunks = batch
111            .column(col::SOURCE_CHUNK_ID)
112            .as_any()
113            .downcast_ref::<StringArray>()
114            .expect("source_chunk_id column");
115        let caused_bys = batch
116            .column(col::CAUSED_BY)
117            .as_any()
118            .downcast_ref::<StringArray>()
119            .expect("caused_by column");
120        let derived_froms = batch
121            .column(col::DERIVED_FROM)
122            .as_any()
123            .downcast_ref::<StringArray>()
124            .expect("derived_from column");
125        let consolidated_ats = batch
126            .column(col::CONSOLIDATED_AT)
127            .as_any()
128            .downcast_ref::<TimestampMillisecondArray>()
129            .expect("consolidated_at column");
130
131        for i in 0..batch.num_rows() {
132            let key = TripleKey {
133                subject: subjects.value(i).to_string(),
134                predicate: predicates.value(i).to_string(),
135                object: objects.value(i).to_string(),
136                namespace: namespaces.value(i).to_string(),
137            };
138            let entry = DiffEntry {
139                subject: key.subject.clone(),
140                predicate: key.predicate.clone(),
141                object: key.object.clone(),
142                namespace: key.namespace.clone(),
143                y_layer: layers.value(i),
144                confidence: if confidences.is_null(i) {
145                    None
146                } else {
147                    Some(confidences.value(i))
148                },
149                graph: if graphs.is_null(i) {
150                    None
151                } else {
152                    Some(graphs.value(i).to_string())
153                },
154                source_document: if source_docs.is_null(i) {
155                    None
156                } else {
157                    Some(source_docs.value(i).to_string())
158                },
159                source_chunk_id: if source_chunks.is_null(i) {
160                    None
161                } else {
162                    Some(source_chunks.value(i).to_string())
163                },
164                caused_by: if caused_bys.is_null(i) {
165                    None
166                } else {
167                    Some(caused_bys.value(i).to_string())
168                },
169                derived_from: if derived_froms.is_null(i) {
170                    None
171                } else {
172                    Some(derived_froms.value(i).to_string())
173                },
174                consolidated_at: if consolidated_ats.is_null(i) {
175                    None
176                } else {
177                    Some(consolidated_ats.value(i))
178                },
179            };
180            map.insert(key, entry);
181        }
182    }
183
184    map
185}
186
187/// Compute the diff between two commits.
188///
189/// `base` is the earlier commit, `head` is the later commit.
190/// Returns triples added in head and triples removed from base.
191///
192/// # Safety
193///
194/// **This function replaces the live store contents** by calling `checkout()` internally.
195/// Any uncommitted changes in `obj_store` will be lost. The store will contain the
196/// `head` commit's state when this function returns. Callers should commit or save
197/// any in-progress work before calling `diff()`.
198pub fn diff(
199    obj_store: &mut GitObjectStore,
200    commits_table: &CommitsTable,
201    base_commit_id: &str,
202    head_commit_id: &str,
203) -> Result<DiffResult, CommitError> {
204    // Load base state
205    checkout::checkout(obj_store, commits_table, base_commit_id)?;
206    let base_triples = extract_triples(&obj_store.store);
207
208    // Load head state
209    checkout::checkout(obj_store, commits_table, head_commit_id)?;
210    let head_triples = extract_triples(&obj_store.store);
211
212    // Added = in head but not in base (with full metadata from head)
213    let added: Vec<DiffEntry> = head_triples
214        .iter()
215        .filter(|(k, _)| !base_triples.contains_key(k))
216        .map(|(_, entry)| entry.clone())
217        .collect();
218
219    // Removed = in base but not in head (with full metadata from base)
220    let removed: Vec<DiffEntry> = base_triples
221        .iter()
222        .filter(|(k, _)| !head_triples.contains_key(k))
223        .map(|(_, entry)| entry.clone())
224        .collect();
225
226    Ok(DiffResult { added, removed })
227}
228
229/// Compute diff without mutating the store — saves and restores current state.
230///
231/// Use this when you have uncommitted changes you want to preserve.
232pub fn diff_nondestructive(
233    obj_store: &mut GitObjectStore,
234    commits_table: &CommitsTable,
235    base_commit_id: &str,
236    head_commit_id: &str,
237) -> Result<DiffResult, CommitError> {
238    // Save current state
239    let saved: Vec<(String, Vec<arrow::array::RecordBatch>)> = obj_store
240        .store
241        .namespaces()
242        .iter()
243        .map(|ns| {
244            let batches = obj_store.store.get_namespace_batches(ns).to_vec();
245            (ns.clone(), batches)
246        })
247        .collect();
248
249    let result = diff(obj_store, commits_table, base_commit_id, head_commit_id);
250
251    // Restore previous state
252    for (ns, batches) in saved {
253        obj_store.store.set_namespace_batches(&ns, batches);
254    }
255
256    result
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use crate::commit::create_commit;
    use arrow_graph_core::Triple;

    /// Build an `rdf:type` triple from `subj` to `obj` with confidence 0.9 and no
    /// other metadata set.
    fn sample_triple(subj: &str, obj: &str) -> Triple {
        Triple {
            subject: subj.to_owned(),
            predicate: String::from("rdf:type"),
            object: obj.to_owned(),
            graph: None,
            confidence: Some(0.9),
            source_document: None,
            source_chunk_id: None,
            extracted_by: None,
            caused_by: None,
            derived_from: None,
            consolidated_at: None,
        }
    }

    /// Fresh object store rooted in a new temp dir, plus an empty commits table.
    /// The `TempDir` is returned so the directory stays alive for the whole test.
    fn fresh_repo() -> (tempfile::TempDir, GitObjectStore, CommitsTable) {
        let tmp = tempfile::tempdir().unwrap();
        let obj = GitObjectStore::with_snapshot_dir(tmp.path());
        (tmp, obj, CommitsTable::new())
    }

    #[test]
    fn test_diff_detects_additions() {
        let (_tmp, mut obj, mut commits) = fresh_repo();

        // First commit: a single triple.
        obj.store
            .add_triple(&sample_triple("s1", "A"), "world", Some(1u8))
            .unwrap();
        let first = create_commit(&obj, &mut commits, vec![], "first", "test").unwrap();

        // Second commit: one more triple on top of the first.
        obj.store
            .add_triple(&sample_triple("s2", "B"), "world", Some(1u8))
            .unwrap();
        let parents = vec![first.commit_id.clone()];
        let second = create_commit(&obj, &mut commits, parents, "second", "test").unwrap();

        let result = diff(&mut obj, &commits, &first.commit_id, &second.commit_id).unwrap();
        assert_eq!(result.added.len(), 1);
        assert!(result.removed.is_empty());

        let entry = &result.added[0];
        assert_eq!(entry.subject, "s2");
        // Metadata must travel with the entry.
        assert_eq!(entry.y_layer, 1u8);
        assert_eq!(entry.confidence, Some(0.9));
    }

    #[test]
    fn test_diff_detects_removals() {
        let (_tmp, mut obj, mut commits) = fresh_repo();

        // First commit: two triples; remember the id of the second one.
        obj.store
            .add_triple(&sample_triple("s1", "A"), "world", Some(1u8))
            .unwrap();
        let doomed_id = obj
            .store
            .add_triple(&sample_triple("s2", "B"), "world", Some(1u8))
            .unwrap();
        let first = create_commit(&obj, &mut commits, vec![], "first", "test").unwrap();

        // Second commit: the remembered triple has been deleted.
        obj.store.delete(&doomed_id).unwrap();
        let parents = vec![first.commit_id.clone()];
        let second = create_commit(&obj, &mut commits, parents, "second", "test").unwrap();

        let result = diff(&mut obj, &commits, &first.commit_id, &second.commit_id).unwrap();
        assert_eq!(result.removed.len(), 1);
        assert_eq!(result.removed[0].subject, "s2");
    }

    #[test]
    fn test_diff_nondestructive_preserves_state() {
        let (_tmp, mut obj, mut commits) = fresh_repo();

        obj.store
            .add_triple(&sample_triple("s1", "A"), "world", Some(1u8))
            .unwrap();
        let first = create_commit(&obj, &mut commits, vec![], "first", "test").unwrap();

        obj.store
            .add_triple(&sample_triple("s2", "B"), "world", Some(1u8))
            .unwrap();
        let parents = vec![first.commit_id.clone()];
        let second = create_commit(&obj, &mut commits, parents, "second", "test").unwrap();

        // Leave one triple uncommitted: s1 + s2 + uncommitted = 3.
        obj.store
            .add_triple(&sample_triple("uncommitted", "X"), "world", Some(1u8))
            .unwrap();
        assert_eq!(obj.store.len(), 3);

        // The diff itself still sees exactly one addition between the two commits.
        let result =
            diff_nondestructive(&mut obj, &commits, &first.commit_id, &second.commit_id).unwrap();
        assert_eq!(result.added.len(), 1);

        // The uncommitted triple must have survived the diff.
        assert_eq!(obj.store.len(), 3);
    }

    #[test]
    fn test_diff_no_changes() {
        let (_tmp, mut obj, mut commits) = fresh_repo();

        obj.store
            .add_triple(&sample_triple("s1", "A"), "world", Some(1u8))
            .unwrap();
        let first = create_commit(&obj, &mut commits, vec![], "first", "test").unwrap();

        // Commit the unchanged state a second time.
        let parents = vec![first.commit_id.clone()];
        let second = create_commit(&obj, &mut commits, parents, "same", "test").unwrap();

        let result = diff(&mut obj, &commits, &first.commit_id, &second.commit_id).unwrap();
        assert!(result.is_empty());
    }
}