Skip to main content

arrow_graph_git/
cherry_pick.rs

1//! Cherry-pick — apply a specific commit's changes to the current branch.
2//!
3//! Computes the diff introduced by the source commit, then applies those
4//! changes to HEAD, creating a new commit with the current HEAD as parent.
5
use crate::checkout;
use crate::commit::{CommitError, CommitsTable, create_commit};
use crate::diff;
use crate::object_store::GitObjectStore;
use arrow_graph_core::{Triple, col};
use std::collections::{HashMap, HashSet};
12
13/// Errors from cherry-pick operations.
14#[derive(Debug, thiserror::Error)]
15pub enum CherryPickError {
16    #[error("Commit error: {0}")]
17    Commit(#[from] CommitError),
18
19    #[error("Store error: {0}")]
20    Store(#[from] arrow_graph_core::StoreError),
21
22    #[error("Commit has no parent: {0}")]
23    NoParent(String),
24
25    #[error("Cherry-pick conflict: {0} triples conflict with HEAD")]
26    Conflict(usize),
27}
28
29/// Cherry-pick a commit onto the current HEAD.
30///
31/// 1. Compute what the source commit changed (diff parent -> source)
32/// 2. Check if any of those changes conflict with HEAD
33/// 3. If clean, apply the diff and create a new commit
34///
35/// Returns the new commit's ID.
36pub fn cherry_pick(
37    obj_store: &mut GitObjectStore,
38    commits_table: &mut CommitsTable,
39    source_commit_id: &str,
40    head_commit_id: &str,
41    author: &str,
42) -> Result<String, CherryPickError> {
43    let source = commits_table
44        .get(source_commit_id)
45        .ok_or_else(|| CommitError::NotFound(source_commit_id.to_string()))?;
46
47    if source.parent_ids.is_empty() {
48        return Err(CherryPickError::NoParent(source_commit_id.to_string()));
49    }
50
51    let parent_id = source.parent_ids[0].clone();
52    let source_message = source.message.clone();
53
54    // Compute what the source commit changed
55    let source_diff = diff::diff(obj_store, commits_table, &parent_id, source_commit_id)?;
56
57    // Restore HEAD state
58    checkout::checkout(obj_store, commits_table, head_commit_id)?;
59
60    // Check for conflicts: if HEAD already has a triple with the same
61    // (subject, predicate, namespace) but different object, that's a conflict
62    let mut conflict_count = 0;
63    for entry in &source_diff.added {
64        let batches = obj_store.store.get_namespace_batches(&entry.namespace);
65        for batch in batches {
66            let subj_col = batch
67                .column(col::SUBJECT)
68                .as_any()
69                .downcast_ref::<arrow::array::StringArray>()
70                .expect("subject column");
71            let pred_col = batch
72                .column(col::PREDICATE)
73                .as_any()
74                .downcast_ref::<arrow::array::StringArray>()
75                .expect("predicate column");
76            let obj_col = batch
77                .column(col::OBJECT)
78                .as_any()
79                .downcast_ref::<arrow::array::StringArray>()
80                .expect("object column");
81
82            for i in 0..batch.num_rows() {
83                if subj_col.value(i) == entry.subject
84                    && pred_col.value(i) == entry.predicate
85                    && obj_col.value(i) != entry.object
86                {
87                    conflict_count += 1;
88                }
89            }
90        }
91    }
92
93    if conflict_count > 0 {
94        return Err(CherryPickError::Conflict(conflict_count));
95    }
96
97    // Apply additions
98    for entry in &source_diff.added {
99        // Skip if HEAD already has this exact triple
100        let already_exists = {
101            let batches = obj_store.store.get_namespace_batches(&entry.namespace);
102            let mut found = false;
103            for batch in batches {
104                let subj_col = batch
105                    .column(col::SUBJECT)
106                    .as_any()
107                    .downcast_ref::<arrow::array::StringArray>()
108                    .expect("subject column");
109                let pred_col = batch
110                    .column(col::PREDICATE)
111                    .as_any()
112                    .downcast_ref::<arrow::array::StringArray>()
113                    .expect("predicate column");
114                let obj_col = batch
115                    .column(col::OBJECT)
116                    .as_any()
117                    .downcast_ref::<arrow::array::StringArray>()
118                    .expect("object column");
119
120                for i in 0..batch.num_rows() {
121                    if subj_col.value(i) == entry.subject
122                        && pred_col.value(i) == entry.predicate
123                        && obj_col.value(i) == entry.object
124                    {
125                        found = true;
126                        break;
127                    }
128                }
129                if found {
130                    break;
131                }
132            }
133            found
134        };
135
136        if !already_exists {
137            let triple = Triple {
138                subject: entry.subject.clone(),
139                predicate: entry.predicate.clone(),
140                object: entry.object.clone(),
141                graph: entry.graph.clone(),
142                confidence: entry.confidence,
143                source_document: entry.source_document.clone(),
144                source_chunk_id: entry.source_chunk_id.clone(),
145                extracted_by: Some(format!("cherry-pick by {author}")),
146                caused_by: entry.caused_by.clone(),
147                derived_from: entry.derived_from.clone(),
148                consolidated_at: entry.consolidated_at,
149            };
150            obj_store
151                .store
152                .add_triple(&triple, &entry.namespace, Some(entry.y_layer))?;
153        }
154    }
155
156    // Apply removals — collect IDs to delete first to avoid borrow issues
157    let removals: HashSet<(String, String, String, String)> = source_diff
158        .removed
159        .iter()
160        .map(|e| {
161            (
162                e.subject.clone(),
163                e.predicate.clone(),
164                e.object.clone(),
165                e.namespace.clone(),
166            )
167        })
168        .collect();
169
170    for ns in obj_store.store.namespaces().to_vec() {
171        let batches = obj_store.store.get_namespace_batches(&ns);
172        let mut ids_to_delete = Vec::new();
173        for batch in batches {
174            let id_col = batch
175                .column(col::TRIPLE_ID)
176                .as_any()
177                .downcast_ref::<arrow::array::StringArray>()
178                .expect("triple_id column");
179            let subj_col = batch
180                .column(col::SUBJECT)
181                .as_any()
182                .downcast_ref::<arrow::array::StringArray>()
183                .expect("subject column");
184            let pred_col = batch
185                .column(col::PREDICATE)
186                .as_any()
187                .downcast_ref::<arrow::array::StringArray>()
188                .expect("predicate column");
189            let obj_col = batch
190                .column(col::OBJECT)
191                .as_any()
192                .downcast_ref::<arrow::array::StringArray>()
193                .expect("object column");
194
195            for i in 0..batch.num_rows() {
196                let key = (
197                    subj_col.value(i).to_string(),
198                    pred_col.value(i).to_string(),
199                    obj_col.value(i).to_string(),
200                    ns.clone(),
201                );
202                if removals.contains(&key) {
203                    ids_to_delete.push(id_col.value(i).to_string());
204                }
205            }
206        }
207        for id in &ids_to_delete {
208            let _ = obj_store.store.delete(id);
209        }
210    }
211
212    // Create cherry-pick commit with HEAD as parent
213    let cp_commit = create_commit(
214        obj_store,
215        commits_table,
216        vec![head_commit_id.to_string()],
217        &format!("Cherry-pick: {source_message}"),
218        author,
219    )?;
220
221    Ok(cp_commit.commit_id)
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use crate::checkout::checkout as git_checkout;
    use crate::commit::create_commit;
    use arrow_graph_core::Triple;

    /// `subj rdf:type obj` with confidence 0.9 and every provenance field unset.
    fn sample_triple(subj: &str, obj: &str) -> Triple {
        Triple {
            subject: subj.to_owned(),
            predicate: String::from("rdf:type"),
            object: obj.to_owned(),
            confidence: Some(0.9),
            graph: None,
            source_document: None,
            source_chunk_id: None,
            extracted_by: None,
            caused_by: None,
            derived_from: None,
            consolidated_at: None,
        }
    }

    /// Inserts `sample_triple(subj, obj)` into the `world` namespace at layer 1.
    fn put(store: &mut GitObjectStore, subj: &str, obj: &str) {
        store
            .store
            .add_triple(&sample_triple(subj, obj), "world", Some(1u8))
            .unwrap();
    }

    /// Commits the store's current state as author `test` and returns the commit id.
    fn commit(
        store: &GitObjectStore,
        commits: &mut CommitsTable,
        parents: Vec<String>,
        message: &str,
    ) -> String {
        create_commit(store, commits, parents, message, "test")
            .unwrap()
            .commit_id
    }

    #[test]
    fn test_clean_cherry_pick() {
        let tmp = tempfile::tempdir().unwrap();
        let mut obj = GitObjectStore::with_snapshot_dir(tmp.path());
        let mut commits = CommitsTable::new();

        put(&mut obj, "s1", "Base");
        let base = commit(&obj, &mut commits, vec![], "base");

        put(&mut obj, "s2", "Feature");
        let feature = commit(&obj, &mut commits, vec![base.clone()], "add s2");

        git_checkout(&mut obj, &commits, &base).unwrap();
        put(&mut obj, "s3", "Main");
        let main_head = commit(&obj, &mut commits, vec![base.clone()], "add s3");

        let cp_id = cherry_pick(&mut obj, &mut commits, &feature, &main_head, "test").unwrap();

        assert_eq!(obj.store.len(), 3);
        let cp = commits.get(&cp_id).unwrap();
        assert!(cp.message.starts_with("Cherry-pick:"));
        assert_eq!(cp.parent_ids, vec![main_head]);
    }

    #[test]
    fn test_cherry_pick_with_conflict() {
        let tmp = tempfile::tempdir().unwrap();
        let mut obj = GitObjectStore::with_snapshot_dir(tmp.path());
        let mut commits = CommitsTable::new();

        put(&mut obj, "s1", "Base");
        let base = commit(&obj, &mut commits, vec![], "base");

        put(&mut obj, "conflict-subj", "TypeA");
        let feature = commit(&obj, &mut commits, vec![base.clone()], "feature");

        git_checkout(&mut obj, &commits, &base).unwrap();
        put(&mut obj, "conflict-subj", "TypeB");
        let main_head = commit(&obj, &mut commits, vec![base.clone()], "main");

        let err = cherry_pick(&mut obj, &mut commits, &feature, &main_head, "test")
            .expect_err("conflicting cherry-pick must fail");
        let CherryPickError::Conflict(n) = err else {
            panic!("Expected Conflict, got: {err:?}");
        };
        assert!(n > 0);
    }

    #[test]
    fn test_cherry_pick_preserves_content() {
        let tmp = tempfile::tempdir().unwrap();
        let mut obj = GitObjectStore::with_snapshot_dir(tmp.path());
        let mut commits = CommitsTable::new();

        put(&mut obj, "s1", "Base");
        let base = commit(&obj, &mut commits, vec![], "base");

        put(&mut obj, "feat1", "F1");
        put(&mut obj, "feat2", "F2");
        let feature = commit(&obj, &mut commits, vec![base.clone()], "feature");

        git_checkout(&mut obj, &commits, &base).unwrap();
        let main_head = commit(&obj, &mut commits, vec![base.clone()], "main");

        cherry_pick(&mut obj, &mut commits, &feature, &main_head, "test").unwrap();

        assert_eq!(obj.store.len(), 3);
    }
}