1use crate::checkout;
7use crate::commit::{CommitError, CommitsTable, create_commit};
8use crate::diff;
9use crate::object_store::GitObjectStore;
10use arrow_graph_core::{Triple, col};
11use std::collections::HashSet;
12
/// Errors returned by [`cherry_pick`].
#[derive(Debug, thiserror::Error)]
pub enum CherryPickError {
    /// A [`CommitError`] surfaced while working with commits; converted
    /// automatically through the generated `From` impl (this is how a missing
    /// source commit is reported, as `CommitError::NotFound`).
    #[error("Commit error: {0}")]
    Commit(#[from] CommitError),

    /// An `arrow_graph_core::StoreError` raised by the underlying triple
    /// store; converted automatically through the generated `From` impl.
    #[error("Store error: {0}")]
    Store(#[from] arrow_graph_core::StoreError),

    /// The commit to cherry-pick has no parent, so there is nothing to diff
    /// it against. The payload is the id of that commit.
    #[error("Commit has no parent: {0}")]
    NoParent(String),

    /// The picked changes clash with triples present at HEAD. The payload is
    /// the number of conflicting matches that were counted.
    #[error("Cherry-pick conflict: {0} triples conflict with HEAD")]
    Conflict(usize),
}
28
29pub fn cherry_pick(
37 obj_store: &mut GitObjectStore,
38 commits_table: &mut CommitsTable,
39 source_commit_id: &str,
40 head_commit_id: &str,
41 author: &str,
42) -> Result<String, CherryPickError> {
43 let source = commits_table
44 .get(source_commit_id)
45 .ok_or_else(|| CommitError::NotFound(source_commit_id.to_string()))?;
46
47 if source.parent_ids.is_empty() {
48 return Err(CherryPickError::NoParent(source_commit_id.to_string()));
49 }
50
51 let parent_id = source.parent_ids[0].clone();
52 let source_message = source.message.clone();
53
54 let source_diff = diff::diff(obj_store, commits_table, &parent_id, source_commit_id)?;
56
57 checkout::checkout(obj_store, commits_table, head_commit_id)?;
59
60 let mut conflict_count = 0;
63 for entry in &source_diff.added {
64 let batches = obj_store.store.get_namespace_batches(&entry.namespace);
65 for batch in batches {
66 let subj_col = batch
67 .column(col::SUBJECT)
68 .as_any()
69 .downcast_ref::<arrow::array::StringArray>()
70 .expect("subject column");
71 let pred_col = batch
72 .column(col::PREDICATE)
73 .as_any()
74 .downcast_ref::<arrow::array::StringArray>()
75 .expect("predicate column");
76 let obj_col = batch
77 .column(col::OBJECT)
78 .as_any()
79 .downcast_ref::<arrow::array::StringArray>()
80 .expect("object column");
81
82 for i in 0..batch.num_rows() {
83 if subj_col.value(i) == entry.subject
84 && pred_col.value(i) == entry.predicate
85 && obj_col.value(i) != entry.object
86 {
87 conflict_count += 1;
88 }
89 }
90 }
91 }
92
93 if conflict_count > 0 {
94 return Err(CherryPickError::Conflict(conflict_count));
95 }
96
97 for entry in &source_diff.added {
99 let already_exists = {
101 let batches = obj_store.store.get_namespace_batches(&entry.namespace);
102 let mut found = false;
103 for batch in batches {
104 let subj_col = batch
105 .column(col::SUBJECT)
106 .as_any()
107 .downcast_ref::<arrow::array::StringArray>()
108 .expect("subject column");
109 let pred_col = batch
110 .column(col::PREDICATE)
111 .as_any()
112 .downcast_ref::<arrow::array::StringArray>()
113 .expect("predicate column");
114 let obj_col = batch
115 .column(col::OBJECT)
116 .as_any()
117 .downcast_ref::<arrow::array::StringArray>()
118 .expect("object column");
119
120 for i in 0..batch.num_rows() {
121 if subj_col.value(i) == entry.subject
122 && pred_col.value(i) == entry.predicate
123 && obj_col.value(i) == entry.object
124 {
125 found = true;
126 break;
127 }
128 }
129 if found {
130 break;
131 }
132 }
133 found
134 };
135
136 if !already_exists {
137 let triple = Triple {
138 subject: entry.subject.clone(),
139 predicate: entry.predicate.clone(),
140 object: entry.object.clone(),
141 graph: entry.graph.clone(),
142 confidence: entry.confidence,
143 source_document: entry.source_document.clone(),
144 source_chunk_id: entry.source_chunk_id.clone(),
145 extracted_by: Some(format!("cherry-pick by {author}")),
146 caused_by: entry.caused_by.clone(),
147 derived_from: entry.derived_from.clone(),
148 consolidated_at: entry.consolidated_at,
149 };
150 obj_store
151 .store
152 .add_triple(&triple, &entry.namespace, Some(entry.y_layer))?;
153 }
154 }
155
156 let removals: HashSet<(String, String, String, String)> = source_diff
158 .removed
159 .iter()
160 .map(|e| {
161 (
162 e.subject.clone(),
163 e.predicate.clone(),
164 e.object.clone(),
165 e.namespace.clone(),
166 )
167 })
168 .collect();
169
170 for ns in obj_store.store.namespaces().to_vec() {
171 let batches = obj_store.store.get_namespace_batches(&ns);
172 let mut ids_to_delete = Vec::new();
173 for batch in batches {
174 let id_col = batch
175 .column(col::TRIPLE_ID)
176 .as_any()
177 .downcast_ref::<arrow::array::StringArray>()
178 .expect("triple_id column");
179 let subj_col = batch
180 .column(col::SUBJECT)
181 .as_any()
182 .downcast_ref::<arrow::array::StringArray>()
183 .expect("subject column");
184 let pred_col = batch
185 .column(col::PREDICATE)
186 .as_any()
187 .downcast_ref::<arrow::array::StringArray>()
188 .expect("predicate column");
189 let obj_col = batch
190 .column(col::OBJECT)
191 .as_any()
192 .downcast_ref::<arrow::array::StringArray>()
193 .expect("object column");
194
195 for i in 0..batch.num_rows() {
196 let key = (
197 subj_col.value(i).to_string(),
198 pred_col.value(i).to_string(),
199 obj_col.value(i).to_string(),
200 ns.clone(),
201 );
202 if removals.contains(&key) {
203 ids_to_delete.push(id_col.value(i).to_string());
204 }
205 }
206 }
207 for id in &ids_to_delete {
208 let _ = obj_store.store.delete(id);
209 }
210 }
211
212 let cp_commit = create_commit(
214 obj_store,
215 commits_table,
216 vec![head_commit_id.to_string()],
217 &format!("Cherry-pick: {source_message}"),
218 author,
219 )?;
220
221 Ok(cp_commit.commit_id)
222}
223
224#[cfg(test)]
225mod tests {
226 use super::*;
227 use crate::checkout::checkout as git_checkout;
228 use crate::commit::create_commit;
229 use arrow_graph_core::Triple;
230
231 fn sample_triple(subj: &str, obj: &str) -> Triple {
232 Triple {
233 subject: subj.to_string(),
234 predicate: "rdf:type".to_string(),
235 object: obj.to_string(),
236 graph: None,
237 confidence: Some(0.9),
238 source_document: None,
239 source_chunk_id: None,
240 extracted_by: None,
241 caused_by: None,
242 derived_from: None,
243 consolidated_at: None,
244 }
245 }
246
247 #[test]
248 fn test_clean_cherry_pick() {
249 let tmp = tempfile::tempdir().unwrap();
250 let mut obj = GitObjectStore::with_snapshot_dir(tmp.path());
251 let mut commits = CommitsTable::new();
252
253 obj.store
254 .add_triple(&sample_triple("s1", "Base"), "world", Some(1u8))
255 .unwrap();
256 let base = create_commit(&obj, &mut commits, vec![], "base", "test").unwrap();
257
258 obj.store
259 .add_triple(&sample_triple("s2", "Feature"), "world", Some(1u8))
260 .unwrap();
261 let feature = create_commit(
262 &obj,
263 &mut commits,
264 vec![base.commit_id.clone()],
265 "add s2",
266 "test",
267 )
268 .unwrap();
269
270 git_checkout(&mut obj, &commits, &base.commit_id).unwrap();
271 obj.store
272 .add_triple(&sample_triple("s3", "Main"), "world", Some(1u8))
273 .unwrap();
274 let main_head = create_commit(
275 &obj,
276 &mut commits,
277 vec![base.commit_id.clone()],
278 "add s3",
279 "test",
280 )
281 .unwrap();
282
283 let cp_id = cherry_pick(
284 &mut obj,
285 &mut commits,
286 &feature.commit_id,
287 &main_head.commit_id,
288 "test",
289 )
290 .unwrap();
291
292 assert_eq!(obj.store.len(), 3);
293 let cp = commits.get(&cp_id).unwrap();
294 assert!(cp.message.starts_with("Cherry-pick:"));
295 assert_eq!(cp.parent_ids, vec![main_head.commit_id]);
296 }
297
298 #[test]
299 fn test_cherry_pick_with_conflict() {
300 let tmp = tempfile::tempdir().unwrap();
301 let mut obj = GitObjectStore::with_snapshot_dir(tmp.path());
302 let mut commits = CommitsTable::new();
303
304 obj.store
305 .add_triple(&sample_triple("s1", "Base"), "world", Some(1u8))
306 .unwrap();
307 let base = create_commit(&obj, &mut commits, vec![], "base", "test").unwrap();
308
309 obj.store
310 .add_triple(&sample_triple("conflict-subj", "TypeA"), "world", Some(1u8))
311 .unwrap();
312 let feature = create_commit(
313 &obj,
314 &mut commits,
315 vec![base.commit_id.clone()],
316 "feature",
317 "test",
318 )
319 .unwrap();
320
321 git_checkout(&mut obj, &commits, &base.commit_id).unwrap();
322 obj.store
323 .add_triple(&sample_triple("conflict-subj", "TypeB"), "world", Some(1u8))
324 .unwrap();
325 let main_head = create_commit(
326 &obj,
327 &mut commits,
328 vec![base.commit_id.clone()],
329 "main",
330 "test",
331 )
332 .unwrap();
333
334 let result = cherry_pick(
335 &mut obj,
336 &mut commits,
337 &feature.commit_id,
338 &main_head.commit_id,
339 "test",
340 );
341 assert!(result.is_err());
342 match result.unwrap_err() {
343 CherryPickError::Conflict(n) => assert!(n > 0),
344 other => panic!("Expected Conflict, got: {other:?}"),
345 }
346 }
347
348 #[test]
349 fn test_cherry_pick_preserves_content() {
350 let tmp = tempfile::tempdir().unwrap();
351 let mut obj = GitObjectStore::with_snapshot_dir(tmp.path());
352 let mut commits = CommitsTable::new();
353
354 obj.store
355 .add_triple(&sample_triple("s1", "Base"), "world", Some(1u8))
356 .unwrap();
357 let base = create_commit(&obj, &mut commits, vec![], "base", "test").unwrap();
358
359 obj.store
360 .add_triple(&sample_triple("feat1", "F1"), "world", Some(1u8))
361 .unwrap();
362 obj.store
363 .add_triple(&sample_triple("feat2", "F2"), "world", Some(1u8))
364 .unwrap();
365 let feature = create_commit(
366 &obj,
367 &mut commits,
368 vec![base.commit_id.clone()],
369 "feature",
370 "test",
371 )
372 .unwrap();
373
374 git_checkout(&mut obj, &commits, &base.commit_id).unwrap();
375 let main_head = create_commit(
376 &obj,
377 &mut commits,
378 vec![base.commit_id.clone()],
379 "main",
380 "test",
381 )
382 .unwrap();
383
384 cherry_pick(
385 &mut obj,
386 &mut commits,
387 &feature.commit_id,
388 &main_head.commit_id,
389 "test",
390 )
391 .unwrap();
392
393 assert_eq!(obj.store.len(), 3);
394 }
395}