use crate::checkout;
use crate::commit::{CommitError, CommitsTable, create_commit};
use crate::diff;
use crate::object_store::GitObjectStore;
use nusy_arrow_core::{Namespace, Triple, YLayer, col};
use std::collections::HashSet;
/// Failure modes of [`cherry_pick`].
#[derive(Debug, thiserror::Error)]
pub enum CherryPickError {
    /// Wraps a [`CommitError`], e.g. when the commit to pick is not in the commits table.
    #[error("Commit error: {0}")]
    Commit(#[from] CommitError),

    /// Wraps an error reported by the underlying triple store.
    #[error("Store error: {0}")]
    Store(#[from] nusy_arrow_core::StoreError),

    /// The commit to pick has no parent, so there is nothing to diff it against.
    #[error("Commit has no parent: {0}")]
    NoParent(String),

    /// The picked changes clash with HEAD; carries the number of conflicting HEAD triples.
    #[error("Cherry-pick conflict: {0} triples conflict with HEAD")]
    Conflict(usize),
}
pub fn cherry_pick(
obj_store: &mut GitObjectStore,
commits_table: &mut CommitsTable,
source_commit_id: &str,
head_commit_id: &str,
author: &str,
) -> Result<String, CherryPickError> {
let source = commits_table
.get(source_commit_id)
.ok_or_else(|| CommitError::NotFound(source_commit_id.to_string()))?;
if source.parent_ids.is_empty() {
return Err(CherryPickError::NoParent(source_commit_id.to_string()));
}
let parent_id = source.parent_ids[0].clone();
let source_message = source.message.clone();
let source_diff = diff::diff(obj_store, commits_table, &parent_id, source_commit_id)?;
checkout::checkout(obj_store, commits_table, head_commit_id)?;
let mut conflict_count = 0;
for entry in &source_diff.added {
let ns = Namespace::from_str_loose(&entry.namespace).unwrap_or(Namespace::World);
let batches = obj_store.store.get_namespace_batches(ns);
for batch in batches {
let subj_col = batch
.column(col::SUBJECT)
.as_any()
.downcast_ref::<arrow::array::StringArray>()
.expect("subject column");
let pred_col = batch
.column(col::PREDICATE)
.as_any()
.downcast_ref::<arrow::array::StringArray>()
.expect("predicate column");
let obj_col = batch
.column(col::OBJECT)
.as_any()
.downcast_ref::<arrow::array::StringArray>()
.expect("object column");
for i in 0..batch.num_rows() {
if subj_col.value(i) == entry.subject
&& pred_col.value(i) == entry.predicate
&& obj_col.value(i) != entry.object
{
conflict_count += 1;
}
}
}
}
if conflict_count > 0 {
return Err(CherryPickError::Conflict(conflict_count));
}
for entry in &source_diff.added {
let ns = Namespace::from_str_loose(&entry.namespace).unwrap_or(Namespace::World);
let y_layer = YLayer::from_u8(entry.y_layer).unwrap_or(YLayer::Semantic);
let already_exists = {
let batches = obj_store.store.get_namespace_batches(ns);
let mut found = false;
for batch in batches {
let subj_col = batch
.column(col::SUBJECT)
.as_any()
.downcast_ref::<arrow::array::StringArray>()
.expect("subject column");
let pred_col = batch
.column(col::PREDICATE)
.as_any()
.downcast_ref::<arrow::array::StringArray>()
.expect("predicate column");
let obj_col = batch
.column(col::OBJECT)
.as_any()
.downcast_ref::<arrow::array::StringArray>()
.expect("object column");
for i in 0..batch.num_rows() {
if subj_col.value(i) == entry.subject
&& pred_col.value(i) == entry.predicate
&& obj_col.value(i) == entry.object
{
found = true;
break;
}
}
if found {
break;
}
}
found
};
if !already_exists {
let triple = Triple {
subject: entry.subject.clone(),
predicate: entry.predicate.clone(),
object: entry.object.clone(),
graph: entry.graph.clone(),
confidence: entry.confidence,
source_document: entry.source_document.clone(),
source_chunk_id: entry.source_chunk_id.clone(),
extracted_by: Some(format!("cherry-pick by {author}")),
caused_by: entry.caused_by.clone(),
derived_from: entry.derived_from.clone(),
consolidated_at: entry.consolidated_at,
certifiability_class: entry.certifiability_class.clone(),
};
obj_store.store.add_triple(&triple, ns, y_layer)?;
}
}
let removals: HashSet<(String, String, String, String)> = source_diff
.removed
.iter()
.map(|e| {
(
e.subject.clone(),
e.predicate.clone(),
e.object.clone(),
e.namespace.clone(),
)
})
.collect();
for ns in Namespace::ALL {
let ns_str = ns.as_str().to_string();
let batches = obj_store.store.get_namespace_batches(ns);
let mut ids_to_delete = Vec::new();
for batch in batches {
let id_col = batch
.column(col::TRIPLE_ID)
.as_any()
.downcast_ref::<arrow::array::StringArray>()
.expect("triple_id column");
let subj_col = batch
.column(col::SUBJECT)
.as_any()
.downcast_ref::<arrow::array::StringArray>()
.expect("subject column");
let pred_col = batch
.column(col::PREDICATE)
.as_any()
.downcast_ref::<arrow::array::StringArray>()
.expect("predicate column");
let obj_col = batch
.column(col::OBJECT)
.as_any()
.downcast_ref::<arrow::array::StringArray>()
.expect("object column");
for i in 0..batch.num_rows() {
let key = (
subj_col.value(i).to_string(),
pred_col.value(i).to_string(),
obj_col.value(i).to_string(),
ns_str.clone(),
);
if removals.contains(&key) {
ids_to_delete.push(id_col.value(i).to_string());
}
}
}
for id in &ids_to_delete {
let _ = obj_store.store.delete(id);
}
}
let cp_commit = create_commit(
obj_store,
commits_table,
vec![head_commit_id.to_string()],
&format!("Cherry-pick: {source_message}"),
author,
)?;
Ok(cp_commit.commit_id)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::checkout::checkout as git_checkout;
    use crate::commit::create_commit;
    use nusy_arrow_core::{Namespace, Triple, YLayer};

    /// An `rdf:type` triple with confidence 0.9 and no provenance fields set.
    fn sample_triple(subj: &str, obj: &str) -> Triple {
        Triple {
            subject: subj.to_string(),
            predicate: "rdf:type".to_string(),
            object: obj.to_string(),
            graph: None,
            confidence: Some(0.9),
            source_document: None,
            source_chunk_id: None,
            extracted_by: None,
            caused_by: None,
            derived_from: None,
            consolidated_at: None,
            certifiability_class: None,
        }
    }

    /// Store `sample_triple(subj, object)` in the World namespace at the Semantic layer.
    fn add_world(obj: &mut GitObjectStore, subj: &str, object: &str) {
        obj.store
            .add_triple(&sample_triple(subj, object), Namespace::World, YLayer::Semantic)
            .unwrap();
    }

    /// Commit the current store contents as a child of `parent_id`; returns the new commit id.
    fn commit_on(
        obj: &GitObjectStore,
        commits: &mut CommitsTable,
        parent_id: &str,
        message: &str,
    ) -> String {
        create_commit(obj, commits, vec![parent_id.to_string()], message, "DGX")
            .unwrap()
            .commit_id
    }

    #[test]
    fn test_clean_cherry_pick() {
        let tmp = tempfile::tempdir().unwrap();
        let mut obj = GitObjectStore::with_snapshot_dir(tmp.path());
        let mut commits = CommitsTable::new();

        add_world(&mut obj, "s1", "Base");
        let base_id = create_commit(&obj, &mut commits, vec![], "base", "DGX")
            .unwrap()
            .commit_id;

        add_world(&mut obj, "s2", "Feature");
        let feature_id = commit_on(&obj, &mut commits, &base_id, "add s2");

        git_checkout(&mut obj, &commits, &base_id).unwrap();
        add_world(&mut obj, "s3", "Main");
        let main_id = commit_on(&obj, &mut commits, &base_id, "add s3");

        let cp_id = cherry_pick(&mut obj, &mut commits, &feature_id, &main_id, "DGX").unwrap();

        assert_eq!(obj.store.len(), 3);
        let cp = commits.get(&cp_id).unwrap();
        assert!(cp.message.starts_with("Cherry-pick:"));
        assert_eq!(cp.parent_ids, vec![main_id]);
    }

    #[test]
    fn test_cherry_pick_with_conflict() {
        let tmp = tempfile::tempdir().unwrap();
        let mut obj = GitObjectStore::with_snapshot_dir(tmp.path());
        let mut commits = CommitsTable::new();

        add_world(&mut obj, "s1", "Base");
        let base_id = create_commit(&obj, &mut commits, vec![], "base", "DGX")
            .unwrap()
            .commit_id;

        // Feature and main give the same subject/predicate different objects.
        add_world(&mut obj, "conflict-subj", "TypeA");
        let feature_id = commit_on(&obj, &mut commits, &base_id, "feature");

        git_checkout(&mut obj, &commits, &base_id).unwrap();
        add_world(&mut obj, "conflict-subj", "TypeB");
        let main_id = commit_on(&obj, &mut commits, &base_id, "main");

        match cherry_pick(&mut obj, &mut commits, &feature_id, &main_id, "DGX") {
            Err(CherryPickError::Conflict(n)) => assert!(n > 0),
            other => panic!("Expected Conflict, got: {other:?}"),
        }
    }

    #[test]
    fn test_cherry_pick_preserves_content() {
        let tmp = tempfile::tempdir().unwrap();
        let mut obj = GitObjectStore::with_snapshot_dir(tmp.path());
        let mut commits = CommitsTable::new();

        add_world(&mut obj, "s1", "Base");
        let base_id = create_commit(&obj, &mut commits, vec![], "base", "DGX")
            .unwrap()
            .commit_id;

        add_world(&mut obj, "feat1", "F1");
        add_world(&mut obj, "feat2", "F2");
        let feature_id = commit_on(&obj, &mut commits, &base_id, "feature");

        git_checkout(&mut obj, &commits, &base_id).unwrap();
        let main_id = commit_on(&obj, &mut commits, &base_id, "main");

        cherry_pick(&mut obj, &mut commits, &feature_id, &main_id, "DGX").unwrap();

        assert_eq!(obj.store.len(), 3);
    }
}