Skip to main content

nusy_arrow_git/
cherry_pick.rs

//! Cherry-pick — apply a specific commit's changes to the current branch.
//!
//! Computes the diff introduced by the source commit (relative to its first
//! parent), then applies those changes to HEAD, creating a new commit with
//! the current HEAD as its parent.
5
6use crate::checkout;
7use crate::commit::{CommitError, CommitsTable, create_commit};
8use crate::diff;
9use crate::object_store::GitObjectStore;
10use nusy_arrow_core::{Namespace, Triple, YLayer, col};
11use std::collections::HashSet;
12
13/// Errors from cherry-pick operations.
14#[derive(Debug, thiserror::Error)]
15pub enum CherryPickError {
16    #[error("Commit error: {0}")]
17    Commit(#[from] CommitError),
18
19    #[error("Store error: {0}")]
20    Store(#[from] nusy_arrow_core::StoreError),
21
22    #[error("Commit has no parent: {0}")]
23    NoParent(String),
24
25    #[error("Cherry-pick conflict: {0} triples conflict with HEAD")]
26    Conflict(usize),
27}
28
29/// Cherry-pick a commit onto the current HEAD.
30///
31/// 1. Compute what the source commit changed (diff parent → source)
32/// 2. Check if any of those changes conflict with HEAD
33/// 3. If clean, apply the diff and create a new commit
34///
35/// Returns the new commit's ID.
36pub fn cherry_pick(
37    obj_store: &mut GitObjectStore,
38    commits_table: &mut CommitsTable,
39    source_commit_id: &str,
40    head_commit_id: &str,
41    author: &str,
42) -> Result<String, CherryPickError> {
43    let source = commits_table
44        .get(source_commit_id)
45        .ok_or_else(|| CommitError::NotFound(source_commit_id.to_string()))?;
46
47    if source.parent_ids.is_empty() {
48        return Err(CherryPickError::NoParent(source_commit_id.to_string()));
49    }
50
51    let parent_id = source.parent_ids[0].clone();
52    let source_message = source.message.clone();
53
54    // Compute what the source commit changed
55    let source_diff = diff::diff(obj_store, commits_table, &parent_id, source_commit_id)?;
56
57    // Restore HEAD state
58    checkout::checkout(obj_store, commits_table, head_commit_id)?;
59
60    // Check for conflicts: if HEAD already has a triple with the same
61    // (subject, predicate, namespace) but different object, that's a conflict
62    let mut conflict_count = 0;
63    for entry in &source_diff.added {
64        // Defensive: fall back to World namespace if the diff entry has an
65        // unrecognized namespace string. This can happen if namespaces were
66        // added after the commit was created.
67        let ns = Namespace::from_str_loose(&entry.namespace).unwrap_or(Namespace::World);
68        let batches = obj_store.store.get_namespace_batches(ns);
69        for batch in batches {
70            let subj_col = batch
71                .column(col::SUBJECT)
72                .as_any()
73                .downcast_ref::<arrow::array::StringArray>()
74                .expect("subject column");
75            let pred_col = batch
76                .column(col::PREDICATE)
77                .as_any()
78                .downcast_ref::<arrow::array::StringArray>()
79                .expect("predicate column");
80            let obj_col = batch
81                .column(col::OBJECT)
82                .as_any()
83                .downcast_ref::<arrow::array::StringArray>()
84                .expect("object column");
85
86            for i in 0..batch.num_rows() {
87                if subj_col.value(i) == entry.subject
88                    && pred_col.value(i) == entry.predicate
89                    && obj_col.value(i) != entry.object
90                {
91                    conflict_count += 1;
92                }
93            }
94        }
95    }
96
97    if conflict_count > 0 {
98        return Err(CherryPickError::Conflict(conflict_count));
99    }
100
101    // Apply additions
102    for entry in &source_diff.added {
103        // Defensive fallbacks: if namespace or y_layer values from the diff
104        // are unrecognized (e.g., schema evolved), default to World/Semantic
105        // rather than failing the entire cherry-pick.
106        let ns = Namespace::from_str_loose(&entry.namespace).unwrap_or(Namespace::World);
107        let y_layer = YLayer::from_u8(entry.y_layer).unwrap_or(YLayer::Semantic);
108
109        // Skip if HEAD already has this exact triple
110        let already_exists = {
111            let batches = obj_store.store.get_namespace_batches(ns);
112            let mut found = false;
113            for batch in batches {
114                let subj_col = batch
115                    .column(col::SUBJECT)
116                    .as_any()
117                    .downcast_ref::<arrow::array::StringArray>()
118                    .expect("subject column");
119                let pred_col = batch
120                    .column(col::PREDICATE)
121                    .as_any()
122                    .downcast_ref::<arrow::array::StringArray>()
123                    .expect("predicate column");
124                let obj_col = batch
125                    .column(col::OBJECT)
126                    .as_any()
127                    .downcast_ref::<arrow::array::StringArray>()
128                    .expect("object column");
129
130                for i in 0..batch.num_rows() {
131                    if subj_col.value(i) == entry.subject
132                        && pred_col.value(i) == entry.predicate
133                        && obj_col.value(i) == entry.object
134                    {
135                        found = true;
136                        break;
137                    }
138                }
139                if found {
140                    break;
141                }
142            }
143            found
144        };
145
146        if !already_exists {
147            let triple = Triple {
148                subject: entry.subject.clone(),
149                predicate: entry.predicate.clone(),
150                object: entry.object.clone(),
151                graph: entry.graph.clone(),
152                confidence: entry.confidence,
153                source_document: entry.source_document.clone(),
154                source_chunk_id: entry.source_chunk_id.clone(),
155                extracted_by: Some(format!("cherry-pick by {author}")),
156                caused_by: entry.caused_by.clone(),
157                derived_from: entry.derived_from.clone(),
158                consolidated_at: entry.consolidated_at,
159                certifiability_class: entry.certifiability_class.clone(),
160            };
161            obj_store.store.add_triple(&triple, ns, y_layer)?;
162        }
163    }
164
165    // Apply removals — collect IDs to delete first to avoid borrow issues
166    let removals: HashSet<(String, String, String, String)> = source_diff
167        .removed
168        .iter()
169        .map(|e| {
170            (
171                e.subject.clone(),
172                e.predicate.clone(),
173                e.object.clone(),
174                e.namespace.clone(),
175            )
176        })
177        .collect();
178
179    for ns in Namespace::ALL {
180        let ns_str = ns.as_str().to_string();
181        let batches = obj_store.store.get_namespace_batches(ns);
182        let mut ids_to_delete = Vec::new();
183        for batch in batches {
184            let id_col = batch
185                .column(col::TRIPLE_ID)
186                .as_any()
187                .downcast_ref::<arrow::array::StringArray>()
188                .expect("triple_id column");
189            let subj_col = batch
190                .column(col::SUBJECT)
191                .as_any()
192                .downcast_ref::<arrow::array::StringArray>()
193                .expect("subject column");
194            let pred_col = batch
195                .column(col::PREDICATE)
196                .as_any()
197                .downcast_ref::<arrow::array::StringArray>()
198                .expect("predicate column");
199            let obj_col = batch
200                .column(col::OBJECT)
201                .as_any()
202                .downcast_ref::<arrow::array::StringArray>()
203                .expect("object column");
204
205            for i in 0..batch.num_rows() {
206                let key = (
207                    subj_col.value(i).to_string(),
208                    pred_col.value(i).to_string(),
209                    obj_col.value(i).to_string(),
210                    ns_str.clone(),
211                );
212                if removals.contains(&key) {
213                    ids_to_delete.push(id_col.value(i).to_string());
214                }
215            }
216        }
217        for id in &ids_to_delete {
218            // Best-effort delete: triple may already have been removed by a
219            // prior operation or may not exist in this snapshot. Swallowing
220            // the error is intentional — the removal set comes from a diff
221            // against a different commit, so missing IDs are expected.
222            let _ = obj_store.store.delete(id);
223        }
224    }
225
226    // Create cherry-pick commit with HEAD as parent
227    let cp_commit = create_commit(
228        obj_store,
229        commits_table,
230        vec![head_commit_id.to_string()],
231        &format!("Cherry-pick: {source_message}"),
232        author,
233    )?;
234
235    Ok(cp_commit.commit_id)
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use crate::checkout::checkout as git_checkout;
    use crate::commit::create_commit;
    use nusy_arrow_core::{Namespace, Triple, YLayer};

    /// Author recorded on every commit and cherry-pick made by these tests.
    const AUTHOR: &str = "DGX";

    /// Builds an `rdf:type` triple with fixed confidence and no provenance.
    fn sample_triple(subj: &str, obj: &str) -> Triple {
        Triple {
            subject: subj.to_string(),
            predicate: "rdf:type".to_string(),
            object: obj.to_string(),
            graph: None,
            confidence: Some(0.9),
            source_document: None,
            source_chunk_id: None,
            extracted_by: None,
            caused_by: None,
            derived_from: None,
            consolidated_at: None,
            certifiability_class: None,
        }
    }

    /// Adds `(subj, rdf:type, obj)` to the World namespace, Semantic layer.
    fn add_world_triple(repo: &mut GitObjectStore, subj: &str, obj: &str) {
        repo.store
            .add_triple(
                &sample_triple(subj, obj),
                Namespace::World,
                YLayer::Semantic,
            )
            .unwrap();
    }

    /// Commits the store's current contents and returns the new commit id.
    fn commit(
        repo: &GitObjectStore,
        commits: &mut CommitsTable,
        parents: &[&str],
        message: &str,
    ) -> String {
        let parent_ids: Vec<String> = parents.iter().map(|p| p.to_string()).collect();
        create_commit(repo, commits, parent_ids, message, AUTHOR)
            .unwrap()
            .commit_id
    }

    #[test]
    fn test_clean_cherry_pick() {
        let tmp = tempfile::tempdir().unwrap();
        let mut obj = GitObjectStore::with_snapshot_dir(tmp.path());
        let mut commits = CommitsTable::new();

        // Base: s1
        add_world_triple(&mut obj, "s1", "Base");
        let base = commit(&obj, &mut commits, &[], "base");

        // Feature branch: add s2
        add_world_triple(&mut obj, "s2", "Feature");
        let feature = commit(&obj, &mut commits, &[base.as_str()], "add s2");

        // Main branch: add s3 (diverge from base)
        git_checkout(&mut obj, &commits, &base).unwrap();
        add_world_triple(&mut obj, "s3", "Main");
        let main_head = commit(&obj, &mut commits, &[base.as_str()], "add s3");

        // Cherry-pick the feature commit onto main
        let cp_id = cherry_pick(&mut obj, &mut commits, &feature, &main_head, AUTHOR).unwrap();

        // Main should now have s1 + s3 + s2 (cherry-picked)
        assert_eq!(obj.store.len(), 3);

        // Verify the cherry-pick commit
        let cp = commits.get(&cp_id).unwrap();
        assert!(cp.message.starts_with("Cherry-pick:"));
        assert_eq!(cp.parent_ids, vec![main_head]);
    }

    #[test]
    fn test_cherry_pick_with_conflict() {
        let tmp = tempfile::tempdir().unwrap();
        let mut obj = GitObjectStore::with_snapshot_dir(tmp.path());
        let mut commits = CommitsTable::new();

        // Base
        add_world_triple(&mut obj, "s1", "Base");
        let base = commit(&obj, &mut commits, &[], "base");

        // Feature: add (conflict-subj, rdf:type, TypeA)
        add_world_triple(&mut obj, "conflict-subj", "TypeA");
        let feature = commit(&obj, &mut commits, &[base.as_str()], "feature");

        // Main: add (conflict-subj, rdf:type, TypeB) — conflicts!
        git_checkout(&mut obj, &commits, &base).unwrap();
        add_world_triple(&mut obj, "conflict-subj", "TypeB");
        let main_head = commit(&obj, &mut commits, &[base.as_str()], "main");

        // Cherry-pick should detect the conflict
        match cherry_pick(&mut obj, &mut commits, &feature, &main_head, AUTHOR) {
            Err(CherryPickError::Conflict(n)) => assert!(n > 0),
            other => panic!("Expected Conflict, got: {other:?}"),
        }
    }

    #[test]
    fn test_cherry_pick_preserves_content() {
        let tmp = tempfile::tempdir().unwrap();
        let mut obj = GitObjectStore::with_snapshot_dir(tmp.path());
        let mut commits = CommitsTable::new();

        // Base
        add_world_triple(&mut obj, "s1", "Base");
        let base = commit(&obj, &mut commits, &[], "base");

        // Feature: add multiple triples
        add_world_triple(&mut obj, "feat1", "F1");
        add_world_triple(&mut obj, "feat2", "F2");
        let feature = commit(&obj, &mut commits, &[base.as_str()], "feature");

        // Go back to base for cherry-pick
        git_checkout(&mut obj, &commits, &base).unwrap();
        let main_head = commit(&obj, &mut commits, &[base.as_str()], "main");

        cherry_pick(&mut obj, &mut commits, &feature, &main_head, AUTHOR).unwrap();

        // Should have s1 + feat1 + feat2
        assert_eq!(obj.store.len(), 3);
    }

    #[test]
    fn test_cherry_pick_root_commit_has_no_parent() {
        let tmp = tempfile::tempdir().unwrap();
        let mut obj = GitObjectStore::with_snapshot_dir(tmp.path());
        let mut commits = CommitsTable::new();

        add_world_triple(&mut obj, "s1", "Base");
        let base = commit(&obj, &mut commits, &[], "base");

        // A commit created without parents has nothing to diff against.
        match cherry_pick(&mut obj, &mut commits, &base, &base, AUTHOR) {
            Err(CherryPickError::NoParent(id)) => assert_eq!(id, base),
            other => panic!("Expected NoParent, got: {other:?}"),
        }
    }

    #[test]
    fn test_cherry_pick_unknown_source_commit() {
        let tmp = tempfile::tempdir().unwrap();
        let mut obj = GitObjectStore::with_snapshot_dir(tmp.path());
        let mut commits = CommitsTable::new();

        add_world_triple(&mut obj, "s1", "Base");
        let base = commit(&obj, &mut commits, &[], "base");

        match cherry_pick(&mut obj, &mut commits, "no-such-commit", &base, AUTHOR) {
            Err(CherryPickError::Commit(CommitError::NotFound(id))) => {
                assert_eq!(id, "no-such-commit")
            }
            other => panic!("Expected Commit(NotFound), got: {other:?}"),
        }
    }
}