Skip to main content

nusy_arrow_git/
checkout.rs

1//! Checkout — load a Parquet snapshot back into the ArrowGraphStore.
2//!
3//! Restores the graph state from a previous commit by reading the
4//! namespace Parquet files and replacing the live store contents.
5
6use crate::commit::{CommitError, CommitsTable};
7use crate::object_store::GitObjectStore;
8use nusy_arrow_core::Namespace;
9use nusy_arrow_core::schema::normalize_to_current;
10use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
11use std::fs;
12
13pub type Result<T> = std::result::Result<T, CommitError>;
14
15/// Checkout a previous commit: load its Parquet snapshots into the live store.
16///
17/// This replaces the current store contents with the committed state.
18pub fn checkout(
19    obj_store: &mut GitObjectStore,
20    commits_table: &CommitsTable,
21    commit_id: &str,
22) -> Result<()> {
23    // Verify the commit exists
24    let _commit = commits_table
25        .get(commit_id)
26        .ok_or_else(|| CommitError::NotFound(commit_id.to_string()))?;
27
28    // Clear the current store
29    obj_store.store.clear();
30
31    // Load each namespace's Parquet file if it exists
32    for ns in Namespace::ALL {
33        let path = obj_store.namespace_parquet_path(commit_id, ns.as_str());
34        if !path.exists() {
35            continue;
36        }
37
38        let file = fs::File::open(&path)?;
39        let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;
40
41        // Extract schema version from Parquet metadata (default to "1.0.0" if absent)
42        let version = builder
43            .metadata()
44            .file_metadata()
45            .key_value_metadata()
46            .and_then(|kv| {
47                kv.iter()
48                    .find(|e| e.key == "nusy_schema_version")
49                    .and_then(|e| e.value.clone())
50            })
51            .unwrap_or_else(|| "1.0.0".to_string());
52
53        let reader = builder.build()?;
54
55        let mut batches = Vec::new();
56        for batch_result in reader {
57            let batch = batch_result?;
58            // Normalize to current schema version on read
59            let normalized = normalize_to_current(&batch, &version)?;
60            batches.push(normalized);
61        }
62
63        obj_store.store.set_namespace_batches(ns, batches);
64    }
65
66    Ok(())
67}
68
#[cfg(test)]
mod tests {
    //! Round-trip tests for `checkout`: commit a store to Parquet snapshots,
    //! mutate or clear it, then restore and verify the contents.

    use super::*;
    use crate::commit::create_commit;
    use nusy_arrow_core::{Namespace, QuerySpec, Triple, YLayer};

    /// Upper bound (in milliseconds) for the 10K commit + checkout round-trip.
    /// The real target is <50ms; the margin is deliberately generous for CI.
    const ROUND_TRIP_BUDGET_MS: u128 = 500;

    /// Build a minimal `rdf:type Thing` triple for the given subject.
    fn sample_triple(subj: &str) -> Triple {
        Triple {
            subject: subj.to_string(),
            predicate: "rdf:type".to_string(),
            object: "Thing".to_string(),
            graph: None,
            confidence: Some(0.9),
            source_document: None,
            source_chunk_id: None,
            extracted_by: None,
            caused_by: None,
            derived_from: None,
            consolidated_at: None,
            certifiability_class: None,
        }
    }

    #[test]
    fn test_commit_checkout_roundtrip() {
        let tmp = tempfile::tempdir().unwrap();
        let mut obj = GitObjectStore::with_snapshot_dir(tmp.path());
        let mut commits = CommitsTable::new();

        // Add 1K triples and commit
        let triples: Vec<Triple> = (0..1000).map(|i| sample_triple(&format!("s{i}"))).collect();
        obj.store
            .add_batch(&triples, Namespace::World, YLayer::Semantic)
            .unwrap();

        let c1 = create_commit(&obj, &mut commits, vec![], "with 1K", "DGX").unwrap();
        assert_eq!(obj.store.len(), 1000);

        // Add 500 more
        let more: Vec<Triple> = (1000..1500)
            .map(|i| sample_triple(&format!("s{i}")))
            .collect();
        obj.store
            .add_batch(&more, Namespace::World, YLayer::Semantic)
            .unwrap();
        assert_eq!(obj.store.len(), 1500);

        // Checkout previous commit — should have only 1K
        checkout(&mut obj, &commits, &c1.commit_id).unwrap();
        assert_eq!(obj.store.len(), 1000);

        // Verify subjects: s0-s999 should exist, s1000+ should not
        let q = obj
            .store
            .query(&QuerySpec {
                subject: Some("s0".to_string()),
                ..Default::default()
            })
            .unwrap();
        let count: usize = q.iter().map(|b| b.num_rows()).sum();
        assert_eq!(count, 1, "s0 should exist after checkout");

        // Upper boundary of the committed range.
        let q_last = obj
            .store
            .query(&QuerySpec {
                subject: Some("s999".to_string()),
                ..Default::default()
            })
            .unwrap();
        let count_last: usize = q_last.iter().map(|b| b.num_rows()).sum();
        assert_eq!(count_last, 1, "s999 should exist after checkout");

        let q2 = obj
            .store
            .query(&QuerySpec {
                subject: Some("s1000".to_string()),
                ..Default::default()
            })
            .unwrap();
        let count2: usize = q2.iter().map(|b| b.num_rows()).sum();
        assert_eq!(count2, 0, "s1000 should NOT exist after checkout");
    }

    #[test]
    fn test_checkout_nonexistent_commit_fails() {
        let tmp = tempfile::tempdir().unwrap();
        let mut obj = GitObjectStore::with_snapshot_dir(tmp.path());
        let commits = CommitsTable::new();

        // Seed the store so we can tell whether a failed checkout touched it.
        obj.store
            .add_triple(
                &sample_triple("keep-me"),
                Namespace::World,
                YLayer::Semantic,
            )
            .unwrap();

        let result = checkout(&mut obj, &commits, "nonexistent");

        // Pin the specific error, not just "some error".
        assert!(
            matches!(result, Err(CommitError::NotFound(_))),
            "unknown commit id should yield CommitError::NotFound"
        );
        assert_eq!(
            obj.store.len(),
            1,
            "a rejected checkout must not modify the live store"
        );
    }

    #[test]
    fn test_commit_checkout_multiple_namespaces() {
        let tmp = tempfile::tempdir().unwrap();
        let mut obj = GitObjectStore::with_snapshot_dir(tmp.path());
        let mut commits = CommitsTable::new();

        // Add to world and work
        obj.store
            .add_triple(
                &sample_triple("world-s"),
                Namespace::World,
                YLayer::Semantic,
            )
            .unwrap();
        obj.store
            .add_triple(
                &sample_triple("work-s"),
                Namespace::Work,
                YLayer::Procedural,
            )
            .unwrap();

        let c1 = create_commit(&obj, &mut commits, vec![], "multi-ns", "DGX").unwrap();

        // Clear and checkout
        obj.store.clear();
        assert_eq!(obj.store.len(), 0);

        checkout(&mut obj, &commits, &c1.commit_id).unwrap();
        assert_eq!(obj.store.len(), 2);
    }

    #[test]
    fn test_commit_checkout_benchmark_10k() {
        let tmp = tempfile::tempdir().unwrap();
        let mut obj = GitObjectStore::with_snapshot_dir(tmp.path());
        let mut commits = CommitsTable::new();

        let triples: Vec<Triple> = (0..10_000)
            .map(|i| sample_triple(&format!("bench{i}")))
            .collect();
        obj.store
            .add_batch(&triples, Namespace::World, YLayer::Semantic)
            .unwrap();

        // Benchmark commit
        let start = std::time::Instant::now();
        let c1 = create_commit(&obj, &mut commits, vec![], "bench", "DGX").unwrap();
        let commit_ms = start.elapsed().as_millis();

        // Benchmark checkout
        obj.store.clear();
        let start = std::time::Instant::now();
        checkout(&mut obj, &commits, &c1.commit_id).unwrap();
        let checkout_ms = start.elapsed().as_millis();

        assert_eq!(obj.store.len(), 10_000);

        let total = commit_ms + checkout_ms;
        eprintln!("10K commit: {commit_ms}ms, checkout: {checkout_ms}ms, total: {total}ms");
        // Wall-clock check with a generous margin for CI — the important
        // thing is that the round-trip stays fast.
        assert!(
            total < ROUND_TRIP_BUDGET_MS,
            "Round-trip took {total}ms — should be well under 500ms"
        );
    }
}