1use crate::commit::CommitsTable;
20use crate::object_store::GitObjectStore;
21use crate::refs::RefsTable;
22use crate::save::{deserialize_commits, deserialize_refs, serialize_commits, serialize_refs};
23use bytes::Bytes;
24use parquet::arrow::ArrowWriter;
25use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
26use std::io::Cursor;
27
/// Errors returned by the snapshot functions in this file.
///
/// The `Parquet` and `Arrow` variants are produced through `?` conversions
/// (`#[from]`); `InvalidSnapshot` is constructed explicitly when serialized
/// snapshot bytes cannot be decoded.
#[derive(Debug, thiserror::Error)]
pub enum RemoteError {
    /// A Parquet reader or writer call failed.
    #[error("Parquet error: {0}")]
    Parquet(#[from] parquet::errors::ParquetError),

    /// An Arrow operation failed (for example an item yielded by the
    /// record-batch reader was an error).
    #[error("Arrow error: {0}")]
    Arrow(#[from] arrow::error::ArrowError),

    /// The serialized snapshot is malformed; the message describes the problem
    /// (for example the input is too short or the manifest is not valid UTF-8).
    #[error("Invalid snapshot: {0}")]
    InvalidSnapshot(String),
}
44
/// A self-contained copy of the store state: Parquet-encoded namespace data
/// plus the serialized commits and refs tables.
///
/// Derives `Debug`, `Clone`, `Default`, `PartialEq` and `Eq` so snapshots can
/// be logged, duplicated, and compared in tests.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Snapshot {
    /// `(namespace name, Parquet file bytes)` pairs.
    pub namespaces: Vec<(String, Vec<u8>)>,
    /// Serialized commits table. An empty string is treated as "no commits"
    /// when the snapshot is restored.
    pub commits_json: String,
    /// Serialized refs table. An empty string is treated as "no refs" when the
    /// snapshot is restored.
    pub refs_json: String,
}
57
58pub fn snapshot_state(
63 obj_store: &GitObjectStore,
64 commits_table: &CommitsTable,
65 refs_table: &RefsTable,
66) -> Result<Snapshot, RemoteError> {
67 let mut namespaces = Vec::new();
68
69 for ns in obj_store.store.namespaces() {
70 let batches = obj_store.store.get_namespace_batches(ns);
71 if batches.is_empty() {
72 continue;
73 }
74
75 let schema = obj_store.store.schema().clone();
76 let mut buf = Vec::new();
77 {
78 let cursor = Cursor::new(&mut buf);
79 let mut writer = ArrowWriter::try_new(cursor, schema, None)?;
80 for batch in batches {
81 writer.write(batch)?;
82 }
83 writer.close()?;
84 }
85
86 namespaces.push((ns.to_string(), buf));
87 }
88
89 let commits_json = serialize_commits(commits_table);
91 let refs_json = serialize_refs(refs_table);
92
93 Ok(Snapshot {
94 namespaces,
95 commits_json,
96 refs_json,
97 })
98}
99
100pub fn restore_snapshot(
106 obj_store: &mut GitObjectStore,
107 snapshot: &Snapshot,
108) -> Result<(CommitsTable, RefsTable), RemoteError> {
109 obj_store.store.clear();
110
111 for (ns_name, parquet_bytes) in &snapshot.namespaces {
113 let bytes = Bytes::from(parquet_bytes.clone());
114 let reader = ParquetRecordBatchReaderBuilder::try_new(bytes)?.build()?;
115
116 let mut batches = Vec::new();
117 for batch_result in reader {
118 batches.push(batch_result?);
119 }
120
121 obj_store.store.set_namespace_batches(ns_name, batches);
122 }
123
124 let commits = if snapshot.commits_json.is_empty() {
126 CommitsTable::new()
127 } else {
128 deserialize_commits(&snapshot.commits_json)
129 };
130
131 let refs = if snapshot.refs_json.is_empty() {
132 RefsTable::new()
133 } else {
134 deserialize_refs(&snapshot.refs_json)
135 };
136
137 Ok((commits, refs))
138}
139
140pub fn snapshot_to_bytes(snapshot: &Snapshot) -> Vec<u8> {
145 let mut ns_entries = Vec::new();
147 let mut offset = 0u64;
148 for (name, data) in &snapshot.namespaces {
149 ns_entries.push(format!(
150 "{{\"name\":\"{}\",\"offset\":{},\"len\":{}}}",
151 name,
152 offset,
153 data.len()
154 ));
155 offset += data.len() as u64;
156 }
157
158 let manifest = format!(
159 "{{\"commits\":{},\"refs\":{},\"namespaces\":[{}]}}",
160 &snapshot.commits_json,
161 &snapshot.refs_json,
162 ns_entries.join(",")
163 );
164
165 let manifest_bytes = manifest.as_bytes();
166 let manifest_len = (manifest_bytes.len() as u64).to_le_bytes();
167
168 let mut result = Vec::new();
169 result.extend_from_slice(&manifest_len);
170 result.extend_from_slice(manifest_bytes);
171 for (_, data) in &snapshot.namespaces {
172 result.extend_from_slice(data);
173 }
174
175 result
176}
177
178pub fn bytes_to_snapshot(bytes: &[u8]) -> Result<Snapshot, RemoteError> {
180 if bytes.len() < 8 {
181 return Err(RemoteError::InvalidSnapshot("Too short".into()));
182 }
183
184 let manifest_len = u64::from_le_bytes(bytes[..8].try_into().unwrap()) as usize;
185 if bytes.len() < 8 + manifest_len {
186 return Err(RemoteError::InvalidSnapshot("Manifest truncated".into()));
187 }
188
189 let manifest_str = std::str::from_utf8(&bytes[8..8 + manifest_len])
190 .map_err(|e| RemoteError::InvalidSnapshot(format!("Invalid UTF-8: {e}")))?;
191
192 let commits_json = extract_json_value(manifest_str, "commits").unwrap_or_default();
194 let refs_json = extract_json_value(manifest_str, "refs").unwrap_or_default();
195
196 let data_start = 8 + manifest_len;
198 let ns_section = extract_json_value(manifest_str, "namespaces").unwrap_or_default();
199
200 let mut namespaces = Vec::new();
201 for entry in extract_json_objects(&ns_section) {
203 let name = extract_json_string_field(&entry, "name").unwrap_or_default();
204 let offset = extract_json_number_field(&entry, "offset").unwrap_or(0) as usize;
205 let len = extract_json_number_field(&entry, "len").unwrap_or(0) as usize;
206
207 if data_start + offset + len <= bytes.len() {
208 let data = bytes[data_start + offset..data_start + offset + len].to_vec();
209 namespaces.push((name, data));
210 }
211 }
212
213 Ok(Snapshot {
214 namespaces,
215 commits_json,
216 refs_json,
217 })
218}
219
/// Returns the raw text of the value that follows the first occurrence of
/// `"key":` in `json`.
///
/// - Arrays and objects are returned including their outer delimiters. Bracket
///   matching skips over string literals (honouring backslash escapes), so a
///   `]` or `}` inside a string does not end the value early.
/// - String values are returned including their surrounding quotes.
/// - Any other value is returned trimmed, up to the next `,` or `}` (or the end
///   of input).
///
/// Returns `None` if the key is absent or if an array, object, or string value
/// is never closed.
///
/// NOTE(review): the key is located with a plain substring search, so a nested
/// key with the same name that appears earlier in the text would be matched
/// first.
fn extract_json_value(json: &str, key: &str) -> Option<String> {
    let pattern = format!("\"{}\":", key);
    let start = json.find(&pattern)? + pattern.len();
    let rest = json[start..].trim_start();

    let end = match rest.chars().next() {
        Some('[') => json_container_end(rest, '[', ']')?,
        Some('{') => json_container_end(rest, '{', '}')?,
        Some('"') => json_string_end(rest)?,
        _ => {
            let end = rest.find([',', '}']).unwrap_or(rest.len());
            return Some(rest[..end].trim().to_string());
        }
    };
    Some(rest[..end].to_string())
}

/// Byte index just past the `close` delimiter matching the `open` delimiter
/// that `text` starts with, ignoring delimiters inside string literals.
fn json_container_end(text: &str, open: char, close: char) -> Option<usize> {
    let mut depth = 0usize;
    let mut in_string = false;
    let mut escaped = false;

    for (i, ch) in text.char_indices() {
        if in_string {
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == '"' {
                in_string = false;
            }
            continue;
        }

        if ch == '"' {
            in_string = true;
        } else if ch == open {
            depth += 1;
        } else if ch == close {
            depth = depth.checked_sub(1)?;
            if depth == 0 {
                return Some(i + close.len_utf8());
            }
        }
    }
    None
}

/// Byte index just past the closing quote of the string literal that `text`
/// starts with, treating a backslash as escaping the next character.
fn json_string_end(text: &str) -> Option<usize> {
    let mut escaped = false;
    for (i, ch) in text.char_indices().skip(1) {
        if escaped {
            escaped = false;
        } else if ch == '\\' {
            escaped = true;
        } else if ch == '"' {
            return Some(i + 1);
        }
    }
    None
}
266
/// Splits `json` into the text of each outermost `{...}` object it contains,
/// in order of appearance. Nested objects stay inside their parent.
///
/// Braces inside string literals are ignored (backslash escapes are honoured),
/// and a stray `}` with no matching `{` is skipped instead of corrupting the
/// nesting depth for the objects that follow.
fn extract_json_objects(json: &str) -> Vec<String> {
    let mut objects = Vec::new();
    let mut depth = 0usize;
    let mut start = None;
    let mut in_string = false;
    let mut escaped = false;

    for (i, ch) in json.char_indices() {
        if in_string {
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == '"' {
                in_string = false;
            }
            continue;
        }

        match ch {
            '"' => in_string = true,
            '{' => {
                if depth == 0 {
                    start = Some(i);
                }
                depth += 1;
            }
            // Only a closer that has a matching opener changes the depth.
            '}' if depth > 0 => {
                depth -= 1;
                if depth == 0 {
                    if let Some(s) = start.take() {
                        objects.push(json[s..=i].to_string());
                    }
                }
            }
            _ => {}
        }
    }
    objects
}
293
/// Returns the text between `"key":"` and the next `"` in `obj`.
///
/// Yields `None` when the `"key":"` prefix is not present or no closing quote
/// follows it. The value is taken verbatim: escape sequences are not
/// interpreted, so the first `"` always ends the value.
fn extract_json_string_field(obj: &str, key: &str) -> Option<String> {
    let opener = format!("\"{}\":\"", key);
    let (_, after_opener) = obj.split_once(opener.as_str())?;
    let (value, _) = after_opener.split_once('"')?;
    Some(value.to_owned())
}
301
/// Parses the integer that follows the first `"key":` in `obj`.
///
/// Leading whitespace after the colon is skipped, then the longest run of
/// ASCII digits and `-` characters is parsed as an `i64`. Returns `None` when
/// the key is absent or that run is not a valid `i64` (including when it is
/// empty).
fn extract_json_number_field(obj: &str, key: &str) -> Option<i64> {
    let needle = format!("\"{}\":", key);
    let (_, tail) = obj.split_once(needle.as_str())?;
    let tail = tail.trim_start();

    // `split` always yields a first piece: everything before the first
    // character that cannot belong to the number.
    let literal = tail
        .split(|c: char| !(c.is_ascii_digit() || c == '-'))
        .next()?;
    literal.parse::<i64>().ok()
}
311
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{CommitsTable, GitObjectStore, RefsTable, create_commit};
    use arrow_graph_core::Triple;

    /// Builds a triple from subject/predicate/object with confidence `1.0` and
    /// every other optional field left as `None`.
    fn make_triple(s: &str, p: &str, o: &str) -> Triple {
        Triple {
            subject: s.to_owned(),
            predicate: p.to_owned(),
            object: o.to_owned(),
            graph: None,
            confidence: Some(1.0),
            source_document: None,
            source_chunk_id: None,
            extracted_by: None,
            caused_by: None,
            derived_from: None,
            consolidated_at: None,
        }
    }

    /// Snapshot -> bytes -> snapshot -> restore keeps all 50 triples, the
    /// single commit, and a head ref.
    #[test]
    fn test_snapshot_roundtrip() {
        let scratch = tempfile::tempdir().unwrap();
        let mut source = GitObjectStore::with_snapshot_dir(scratch.path().join("snap"));
        let mut commit_log = CommitsTable::new();
        let mut branch_refs = RefsTable::new();

        for idx in 0..50 {
            let subject = format!("entity-{idx}");
            let triple = make_triple(&subject, "rdf:type", "Thing");
            source.store.add_triple(&triple, "world", Some(1u8)).unwrap();
        }

        let first = create_commit(&source, &mut commit_log, vec![], "init", "Mini").unwrap();
        branch_refs.init_main(&first.commit_id);

        assert_eq!(source.store.len(), 50);

        let captured = snapshot_state(&source, &commit_log, &branch_refs).unwrap();
        assert!(!captured.namespaces.is_empty());
        assert!(!captured.commits_json.is_empty());

        let wire = snapshot_to_bytes(&captured);
        assert!(wire.len() > 100);

        let decoded = bytes_to_snapshot(&wire).unwrap();
        assert_eq!(decoded.namespaces.len(), captured.namespaces.len());

        let mut target = GitObjectStore::with_snapshot_dir(scratch.path().join("snap2"));
        let (restored_commits, restored_refs) = restore_snapshot(&mut target, &decoded).unwrap();

        assert_eq!(target.store.len(), 50);
        assert_eq!(restored_commits.len(), 1);
        assert!(restored_refs.head().is_some());
    }

    /// An empty store produces a snapshot with no namespaces, and that
    /// survives the byte round trip.
    #[test]
    fn test_snapshot_empty_store() {
        let scratch = tempfile::tempdir().unwrap();
        let source = GitObjectStore::with_snapshot_dir(scratch.path());
        let commit_log = CommitsTable::new();
        let branch_refs = RefsTable::new();

        let captured = snapshot_state(&source, &commit_log, &branch_refs).unwrap();
        assert!(captured.namespaces.is_empty());

        let wire = snapshot_to_bytes(&captured);
        let decoded = bytes_to_snapshot(&wire).unwrap();
        assert!(decoded.namespaces.is_empty());
    }

    /// Two namespaces are captured separately and both come back on restore.
    #[test]
    fn test_snapshot_multiple_namespaces() {
        let scratch = tempfile::tempdir().unwrap();
        let mut source = GitObjectStore::with_snapshot_dir(scratch.path().join("snap"));
        let mut commit_log = CommitsTable::new();
        let branch_refs = RefsTable::new();

        let world_triple = make_triple("w", "r", "1");
        source
            .store
            .add_triple(&world_triple, "world", Some(1u8))
            .unwrap();
        let work_triple = make_triple("k", "r", "2");
        source
            .store
            .add_triple(&work_triple, "work", Some(5u8))
            .unwrap();

        let _commit = create_commit(&source, &mut commit_log, vec![], "multi-ns", "test").unwrap();

        let captured = snapshot_state(&source, &commit_log, &branch_refs).unwrap();
        assert_eq!(captured.namespaces.len(), 2);

        let wire = snapshot_to_bytes(&captured);
        let decoded = bytes_to_snapshot(&wire).unwrap();

        let mut target = GitObjectStore::with_snapshot_dir(scratch.path().join("snap2"));
        let _ = restore_snapshot(&mut target, &decoded).unwrap();
        assert_eq!(target.store.len(), 2);
    }

    /// Inputs shorter than the 8-byte header are rejected.
    #[test]
    fn test_bytes_to_snapshot_invalid() {
        let empty: [u8; 0] = [];
        assert!(bytes_to_snapshot(&empty).is_err());

        let too_short = [0u8; 4];
        assert!(bytes_to_snapshot(&too_short).is_err());
    }
}