// arrow_graph_git/checkout.rs
//! Checkout — load a Parquet snapshot back into the ArrowGraphStore.
//!
//! Restores the graph state from a previous commit by reading the
//! namespace Parquet files and replacing the live store contents.

use crate::commit::{CommitError, CommitsTable};
use crate::object_store::GitObjectStore;
use arrow_graph_core::schema::normalize_to_current;
use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
use std::fs;

12pub type Result<T> = std::result::Result<T, CommitError>;
13
14/// Checkout a previous commit: load its Parquet snapshots into the live store.
15///
16/// This replaces the current store contents with the committed state.
17pub fn checkout(
18    obj_store: &mut GitObjectStore,
19    commits_table: &CommitsTable,
20    commit_id: &str,
21) -> Result<()> {
22    // Verify the commit exists
23    let _commit = commits_table
24        .get(commit_id)
25        .ok_or_else(|| CommitError::NotFound(commit_id.to_string()))?;
26
27    // Clear the current store
28    obj_store.store.clear();
29
30    // Load each namespace's Parquet file if it exists
31    for ns in obj_store.store.namespaces().to_vec() {
32        let path = obj_store.namespace_parquet_path(commit_id, &ns);
33        if !path.exists() {
34            continue;
35        }
36
37        let file = fs::File::open(&path)?;
38        let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;
39
40        // Extract schema version from Parquet metadata (default to "1.0.0" if absent)
41        let version = builder
42            .metadata()
43            .file_metadata()
44            .key_value_metadata()
45            .and_then(|kv| {
46                kv.iter()
47                    .find(|e| e.key == "schema_version" || e.key == "nusy_schema_version")
48                    .and_then(|e| e.value.clone())
49            })
50            .unwrap_or_else(|| "1.0.0".to_string());
51
52        let reader = builder.build()?;
53
54        let mut batches = Vec::new();
55        for batch_result in reader {
56            let batch = batch_result?;
57            // Normalize to current schema version on read
58            let normalized = normalize_to_current(&batch, &version)?;
59            batches.push(normalized);
60        }
61
62        obj_store.store.set_namespace_batches(&ns, batches);
63    }
64
65    Ok(())
66}
67
#[cfg(test)]
mod tests {
    use super::*;
    use crate::commit::create_commit;
    use arrow_graph_core::{QuerySpec, Triple};

    /// Minimal fixture: an `rdf:type Thing` triple with the given subject.
    fn sample_triple(subj: &str) -> Triple {
        Triple {
            subject: subj.to_string(),
            predicate: "rdf:type".to_string(),
            object: "Thing".to_string(),
            graph: None,
            confidence: Some(0.9),
            source_document: None,
            source_chunk_id: None,
            extracted_by: None,
            caused_by: None,
            derived_from: None,
            consolidated_at: None,
        }
    }

    #[test]
    fn test_commit_checkout_roundtrip() {
        let dir = tempfile::tempdir().unwrap();
        let mut obj = GitObjectStore::with_snapshot_dir(dir.path());
        let mut commits = CommitsTable::new();

        // Commit an initial batch of 1K triples.
        let initial: Vec<Triple> = (0..1000)
            .map(|i| sample_triple(&format!("s{i}")))
            .collect();
        obj.store.add_batch(&initial, "world", Some(1u8)).unwrap();
        let first = create_commit(&obj, &mut commits, vec![], "with 1K", "test").unwrap();
        assert_eq!(obj.store.len(), 1000);

        // Grow the store past the committed state.
        let extra: Vec<Triple> = (1000..1500)
            .map(|i| sample_triple(&format!("s{i}")))
            .collect();
        obj.store.add_batch(&extra, "world", Some(1u8)).unwrap();
        assert_eq!(obj.store.len(), 1500);

        // Rewinding to the first commit must drop the extra triples.
        checkout(&mut obj, &commits, &first.commit_id).unwrap();
        assert_eq!(obj.store.len(), 1000);

        // Count rows for a given subject in the restored store.
        let count_subject = |subject: &str| -> usize {
            obj.store
                .query(&QuerySpec {
                    subject: Some(subject.to_string()),
                    ..Default::default()
                })
                .unwrap()
                .iter()
                .map(|b| b.num_rows())
                .sum()
        };

        assert_eq!(count_subject("s0"), 1, "s0 should exist after checkout");
        assert_eq!(count_subject("s1000"), 0, "s1000 should NOT exist after checkout");
    }

    #[test]
    fn test_checkout_nonexistent_commit_fails() {
        let dir = tempfile::tempdir().unwrap();
        let mut obj = GitObjectStore::with_snapshot_dir(dir.path());
        let commits = CommitsTable::new();

        // An unknown commit id must be rejected.
        assert!(checkout(&mut obj, &commits, "nonexistent").is_err());
    }

    #[test]
    fn test_commit_checkout_multiple_namespaces() {
        let dir = tempfile::tempdir().unwrap();
        let mut obj = GitObjectStore::with_snapshot_dir(dir.path());
        let mut commits = CommitsTable::new();

        // One triple in each of two namespaces.
        for (subj, ns, tier) in [("world-s", "world", 1u8), ("work-s", "work", 5u8)] {
            obj.store
                .add_triple(&sample_triple(subj), ns, Some(tier))
                .unwrap();
        }

        let snapshot = create_commit(&obj, &mut commits, vec![], "multi-ns", "test").unwrap();

        // Wipe the store, then restore both namespaces from the commit.
        obj.store.clear();
        assert_eq!(obj.store.len(), 0);

        checkout(&mut obj, &commits, &snapshot.commit_id).unwrap();
        assert_eq!(obj.store.len(), 2);
    }

    #[test]
    fn test_commit_checkout_benchmark_10k() {
        let dir = tempfile::tempdir().unwrap();
        let mut obj = GitObjectStore::with_snapshot_dir(dir.path());
        let mut commits = CommitsTable::new();

        let triples: Vec<Triple> = (0..10_000)
            .map(|i| sample_triple(&format!("bench{i}")))
            .collect();
        obj.store.add_batch(&triples, "world", Some(1u8)).unwrap();

        // Time the commit.
        let commit_start = std::time::Instant::now();
        let snapshot = create_commit(&obj, &mut commits, vec![], "bench", "test").unwrap();
        let commit_ms = commit_start.elapsed().as_millis();

        // Time the checkout from an emptied store.
        obj.store.clear();
        let checkout_start = std::time::Instant::now();
        checkout(&mut obj, &commits, &snapshot.commit_id).unwrap();
        let checkout_ms = checkout_start.elapsed().as_millis();

        assert_eq!(obj.store.len(), 10_000);

        let total = commit_ms + checkout_ms;
        eprintln!("10K commit: {commit_ms}ms, checkout: {checkout_ms}ms, total: {total}ms");
        // Target: <50ms for commit+checkout round-trip
        // Allow generous margin for CI — the important thing is it's fast
        assert!(
            total < 500,
            "Round-trip took {total}ms — should be well under 500ms"
        );
    }
}
203}