arrow_graph_git/
checkout.rs1use crate::commit::{CommitError, CommitsTable};
7use crate::object_store::GitObjectStore;
8use arrow_graph_core::schema::normalize_to_current;
9use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
10use std::fs;
11
12pub type Result<T> = std::result::Result<T, CommitError>;
13
14pub fn checkout(
18 obj_store: &mut GitObjectStore,
19 commits_table: &CommitsTable,
20 commit_id: &str,
21) -> Result<()> {
22 let _commit = commits_table
24 .get(commit_id)
25 .ok_or_else(|| CommitError::NotFound(commit_id.to_string()))?;
26
27 obj_store.store.clear();
29
30 for ns in obj_store.store.namespaces().to_vec() {
32 let path = obj_store.namespace_parquet_path(commit_id, &ns);
33 if !path.exists() {
34 continue;
35 }
36
37 let file = fs::File::open(&path)?;
38 let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;
39
40 let version = builder
42 .metadata()
43 .file_metadata()
44 .key_value_metadata()
45 .and_then(|kv| {
46 kv.iter()
47 .find(|e| e.key == "schema_version" || e.key == "nusy_schema_version")
48 .and_then(|e| e.value.clone())
49 })
50 .unwrap_or_else(|| "1.0.0".to_string());
51
52 let reader = builder.build()?;
53
54 let mut batches = Vec::new();
55 for batch_result in reader {
56 let batch = batch_result?;
57 let normalized = normalize_to_current(&batch, &version)?;
59 batches.push(normalized);
60 }
61
62 obj_store.store.set_namespace_batches(&ns, batches);
63 }
64
65 Ok(())
66}
67
#[cfg(test)]
mod tests {
    use super::*;
    use crate::commit::create_commit;
    use arrow_graph_core::{QuerySpec, Triple};

    /// Builds an `rdf:type Thing` triple for `subj` with confidence 0.9 and every
    /// optional provenance field unset.
    fn sample_triple(subj: &str) -> Triple {
        Triple {
            subject: subj.to_string(),
            predicate: "rdf:type".to_string(),
            object: "Thing".to_string(),
            graph: None,
            confidence: Some(0.9),
            source_document: None,
            source_chunk_id: None,
            extracted_by: None,
            caused_by: None,
            derived_from: None,
            consolidated_at: None,
        }
    }

    /// Query that filters on `subject` only; every other field keeps its default.
    fn by_subject(subject: &str) -> QuerySpec {
        QuerySpec {
            subject: Some(subject.to_string()),
            ..Default::default()
        }
    }

    /// Commit 1000 triples, add 500 more, then check out the commit: the store must
    /// hold exactly the committed 1000 again.
    #[test]
    fn test_commit_checkout_roundtrip() {
        let tmp = tempfile::tempdir().unwrap();
        let mut obj = GitObjectStore::with_snapshot_dir(tmp.path());
        let mut commits = CommitsTable::new();

        let mut initial: Vec<Triple> = Vec::new();
        for i in 0..1000 {
            initial.push(sample_triple(&format!("s{i}")));
        }
        obj.store.add_batch(&initial, "world", Some(1u8)).unwrap();

        let first = create_commit(&obj, &mut commits, vec![], "with 1K", "test").unwrap();
        assert_eq!(obj.store.len(), 1000);

        let mut extra: Vec<Triple> = Vec::new();
        for i in 1000..1500 {
            extra.push(sample_triple(&format!("s{i}")));
        }
        obj.store.add_batch(&extra, "world", Some(1u8)).unwrap();
        assert_eq!(obj.store.len(), 1500);

        checkout(&mut obj, &commits, &first.commit_id).unwrap();
        assert_eq!(obj.store.len(), 1000);

        // A subject from the committed batch is back...
        let count: usize = obj
            .store
            .query(&by_subject("s0"))
            .unwrap()
            .iter()
            .map(|b| b.num_rows())
            .sum();
        assert_eq!(count, 1, "s0 should exist after checkout");

        // ...and one that was only added after the commit is gone.
        let count2: usize = obj
            .store
            .query(&by_subject("s1000"))
            .unwrap()
            .iter()
            .map(|b| b.num_rows())
            .sum();
        assert_eq!(count2, 0, "s1000 should NOT exist after checkout");
    }

    /// Checking out an id that was never committed must return an error.
    #[test]
    fn test_checkout_nonexistent_commit_fails() {
        let tmp = tempfile::tempdir().unwrap();
        let mut obj = GitObjectStore::with_snapshot_dir(tmp.path());
        let commits = CommitsTable::new();

        assert!(checkout(&mut obj, &commits, "nonexistent").is_err());
    }

    /// Triples stored under two different namespaces both come back after a checkout.
    #[test]
    fn test_commit_checkout_multiple_namespaces() {
        let tmp = tempfile::tempdir().unwrap();
        let mut obj = GitObjectStore::with_snapshot_dir(tmp.path());
        let mut commits = CommitsTable::new();

        let world = sample_triple("world-s");
        obj.store.add_triple(&world, "world", Some(1u8)).unwrap();
        let work = sample_triple("work-s");
        obj.store.add_triple(&work, "work", Some(5u8)).unwrap();

        let snapshot = create_commit(&obj, &mut commits, vec![], "multi-ns", "test").unwrap();

        // Start from an empty store so everything counted afterwards came from checkout.
        obj.store.clear();
        assert_eq!(obj.store.len(), 0);

        checkout(&mut obj, &commits, &snapshot.commit_id).unwrap();
        assert_eq!(obj.store.len(), 2);
    }

    /// Times a commit plus checkout of 10 000 triples and requires the pair to finish
    /// in under 500 ms.
    #[test]
    fn test_commit_checkout_benchmark_10k() {
        let tmp = tempfile::tempdir().unwrap();
        let mut obj = GitObjectStore::with_snapshot_dir(tmp.path());
        let mut commits = CommitsTable::new();

        let mut triples: Vec<Triple> = Vec::new();
        for i in 0..10_000 {
            triples.push(sample_triple(&format!("bench{i}")));
        }
        obj.store.add_batch(&triples, "world", Some(1u8)).unwrap();

        let commit_timer = std::time::Instant::now();
        let snapshot = create_commit(&obj, &mut commits, vec![], "bench", "test").unwrap();
        let commit_ms = commit_timer.elapsed().as_millis();

        obj.store.clear();
        let checkout_timer = std::time::Instant::now();
        checkout(&mut obj, &commits, &snapshot.commit_id).unwrap();
        let checkout_ms = checkout_timer.elapsed().as_millis();

        assert_eq!(obj.store.len(), 10_000);

        let total = commit_ms + checkout_ms;
        eprintln!("10K commit: {commit_ms}ms, checkout: {checkout_ms}ms, total: {total}ms");
        assert!(
            total < 500,
            "Round-trip took {total}ms — should be well under 500ms"
        );
    }
}