1use crate::commit::CommitsTable;
24use crate::object_store::GitObjectStore;
25use crate::refs::RefsTable;
26use crate::save::{deserialize_commits, deserialize_refs, serialize_commits, serialize_refs};
27use bytes::Bytes;
28use nusy_arrow_core::Namespace;
29use parquet::arrow::ArrowWriter;
30use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
31use std::io::Cursor;
32
/// Errors produced while encoding, decoding, or restoring a repository snapshot.
#[derive(Debug, thiserror::Error)]
pub enum RemoteError {
    /// The Parquet writer or reader failed while encoding or decoding namespace data.
    #[error("Parquet error: {0}")]
    Parquet(#[from] parquet::errors::ParquetError),

    /// Reading a record batch out of a decoded Parquet stream failed.
    #[error("Arrow error: {0}")]
    Arrow(#[from] arrow::error::ArrowError),

    /// The snapshot bytes or contents are malformed (e.g. too short, truncated manifest,
    /// invalid UTF-8, or an unrecognised namespace name).
    #[error("Invalid snapshot: {0}")]
    InvalidSnapshot(String),
}
49
/// An in-memory copy of repository state that can be transferred as a single buffer.
///
/// Produced by [`snapshot_state`], encoded with [`snapshot_to_bytes`], decoded with
/// [`bytes_to_snapshot`], and applied with [`restore_snapshot`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Snapshot {
    /// One `(namespace name, Parquet file bytes)` pair per captured namespace.
    pub namespaces: Vec<(String, Vec<u8>)>,
    /// Serialized commits table (output of `serialize_commits`); may be empty.
    pub commits_json: String,
    /// Serialized refs table (output of `serialize_refs`); may be empty.
    pub refs_json: String,
}
62
63pub fn snapshot_state(
68 obj_store: &GitObjectStore,
69 commits_table: &CommitsTable,
70 refs_table: &RefsTable,
71) -> Result<Snapshot, RemoteError> {
72 let mut namespaces = Vec::new();
73
74 for ns in Namespace::ALL {
75 let batches = obj_store.store.get_namespace_batches(ns);
76 if batches.is_empty() {
77 continue;
78 }
79
80 let schema = obj_store.store.schema().clone();
81 let mut buf = Vec::new();
82 {
83 let cursor = Cursor::new(&mut buf);
84 let mut writer = ArrowWriter::try_new(cursor, schema, None)?;
85 for batch in batches {
86 writer.write(batch)?;
87 }
88 writer.close()?;
89 }
90
91 namespaces.push((ns.as_str().to_string(), buf));
92 }
93
94 let commits_json = serialize_commits(commits_table);
96 let refs_json = serialize_refs(refs_table);
97
98 Ok(Snapshot {
99 namespaces,
100 commits_json,
101 refs_json,
102 })
103}
104
105pub fn restore_snapshot(
111 obj_store: &mut GitObjectStore,
112 snapshot: &Snapshot,
113) -> Result<(CommitsTable, RefsTable), RemoteError> {
114 obj_store.store.clear();
115
116 for (ns_name, parquet_bytes) in &snapshot.namespaces {
118 let ns = Namespace::from_str_loose(ns_name)
119 .ok_or_else(|| RemoteError::InvalidSnapshot(format!("Unknown namespace: {ns_name}")))?;
120
121 let bytes = Bytes::from(parquet_bytes.clone());
122 let reader = ParquetRecordBatchReaderBuilder::try_new(bytes)?.build()?;
123
124 let mut batches = Vec::new();
125 for batch_result in reader {
126 batches.push(batch_result?);
127 }
128
129 obj_store.store.set_namespace_batches(ns, batches);
130 }
131
132 let commits = if snapshot.commits_json.is_empty() {
134 CommitsTable::new()
135 } else {
136 deserialize_commits(&snapshot.commits_json)
137 };
138
139 let refs = if snapshot.refs_json.is_empty() {
140 RefsTable::new()
141 } else {
142 deserialize_refs(&snapshot.refs_json)
143 };
144
145 Ok((commits, refs))
146}
147
148pub fn snapshot_to_bytes(snapshot: &Snapshot) -> Vec<u8> {
153 let mut ns_entries = Vec::new();
155 let mut offset = 0u64;
156 for (name, data) in &snapshot.namespaces {
157 ns_entries.push(format!(
158 "{{\"name\":\"{}\",\"offset\":{},\"len\":{}}}",
159 name,
160 offset,
161 data.len()
162 ));
163 offset += data.len() as u64;
164 }
165
166 let manifest = format!(
167 "{{\"commits\":{},\"refs\":{},\"namespaces\":[{}]}}",
168 &snapshot.commits_json,
169 &snapshot.refs_json,
170 ns_entries.join(",")
171 );
172
173 let manifest_bytes = manifest.as_bytes();
174 let manifest_len = (manifest_bytes.len() as u64).to_le_bytes();
175
176 let mut result = Vec::new();
177 result.extend_from_slice(&manifest_len);
178 result.extend_from_slice(manifest_bytes);
179 for (_, data) in &snapshot.namespaces {
180 result.extend_from_slice(data);
181 }
182
183 result
184}
185
186pub fn bytes_to_snapshot(bytes: &[u8]) -> Result<Snapshot, RemoteError> {
188 if bytes.len() < 8 {
189 return Err(RemoteError::InvalidSnapshot("Too short".into()));
190 }
191
192 let manifest_len = u64::from_le_bytes(bytes[..8].try_into().unwrap()) as usize;
193 if bytes.len() < 8 + manifest_len {
194 return Err(RemoteError::InvalidSnapshot("Manifest truncated".into()));
195 }
196
197 let manifest_str = std::str::from_utf8(&bytes[8..8 + manifest_len])
198 .map_err(|e| RemoteError::InvalidSnapshot(format!("Invalid UTF-8: {e}")))?;
199
200 let commits_json = extract_json_value(manifest_str, "commits").unwrap_or_default();
202 let refs_json = extract_json_value(manifest_str, "refs").unwrap_or_default();
203
204 let data_start = 8 + manifest_len;
206 let ns_section = extract_json_value(manifest_str, "namespaces").unwrap_or_default();
207
208 let mut namespaces = Vec::new();
209 for entry in extract_json_objects(&ns_section) {
211 let name = extract_json_string_field(&entry, "name").unwrap_or_default();
212 let offset = extract_json_number_field(&entry, "offset").unwrap_or(0) as usize;
213 let len = extract_json_number_field(&entry, "len").unwrap_or(0) as usize;
214
215 if data_start + offset + len <= bytes.len() {
216 let data = bytes[data_start + offset..data_start + offset + len].to_vec();
217 namespaces.push((name, data));
218 }
219 }
220
221 Ok(Snapshot {
222 namespaces,
223 commits_json,
224 refs_json,
225 })
226}
227
/// Returns the byte offset just past the closing quote of the JSON string literal whose
/// opening `"` is at byte `start` of `s`, honouring backslash escapes.
///
/// Returns `None` if the string is unterminated. Working on bytes is safe because `"` and `\`
/// are ASCII, and UTF-8 continuation bytes never collide with ASCII.
fn json_string_end(s: &str, start: usize) -> Option<usize> {
    let bytes = s.as_bytes();
    let mut i = start + 1;
    while i < bytes.len() {
        match bytes[i] {
            // Skip the escaped character so `\"` does not end the string.
            b'\\' => i += 2,
            b'"' => return Some(i + 1),
            _ => i += 1,
        }
    }
    None
}

/// Returns the byte offset just past the bracket that closes the `[` or `{` at the start of
/// `s`. Delimiters inside string literals are ignored.
///
/// Returns `None` if `s` is unbalanced or contains an unterminated string.
fn json_composite_end(s: &str) -> Option<usize> {
    let bytes = s.as_bytes();
    let mut depth = 0usize;
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            b'"' => {
                i = json_string_end(s, i)?;
                continue;
            }
            b'[' | b'{' => depth += 1,
            b']' | b'}' => {
                depth = depth.checked_sub(1)?;
                if depth == 0 {
                    return Some(i + 1);
                }
            }
            _ => {}
        }
        i += 1;
    }
    None
}

/// Returns the byte offset in `json` where the value for `key` begins (just after the `:`).
///
/// Only keys at nesting depth <= 1 (the outermost object) are considered, so a same-named key
/// inside a nested value never matches. Keys and delimiters inside string values are skipped.
fn json_value_start(json: &str, key: &str) -> Option<usize> {
    let bytes = json.as_bytes();
    let mut depth = 0usize;
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            b'"' => {
                let end = json_string_end(json, i)?;
                if depth <= 1 && json[i + 1..end - 1] == *key {
                    // A string followed by `:` is a key; anything else is a value.
                    let after = json[end..].trim_start();
                    if let Some(value) = after.strip_prefix(':') {
                        return Some(json.len() - value.len());
                    }
                }
                i = end;
                continue;
            }
            b'[' | b'{' => depth += 1,
            b']' | b'}' => depth = depth.saturating_sub(1),
            _ => {}
        }
        i += 1;
    }
    None
}

/// Returns the raw JSON text of the value stored under top-level `key` in `json`.
///
/// - Arrays and objects are returned including their delimiters. Brackets inside string
///   literals (e.g. commit messages) do not affect the match.
/// - String values keep their surrounding quotes.
/// - Any other value is returned trimmed, up to the next `,` or `}`. This may be empty,
///   e.g. for `"k":,`.
///
/// Returns `None` if the key is absent or a composite/string value is unterminated.
fn extract_json_value(json: &str, key: &str) -> Option<String> {
    let start = json_value_start(json, key)?;
    let rest = json[start..].trim_start();
    let value = match rest.as_bytes().first() {
        Some(b'[' | b'{') => &rest[..json_composite_end(rest)?],
        Some(b'"') => &rest[..json_string_end(rest, 0)?],
        _ => {
            let end = rest.find([',', '}']).unwrap_or(rest.len());
            rest[..end].trim()
        }
    };
    Some(value.to_string())
}

/// Splits `json` into the objects found at the outermost brace level, e.g. the elements of
/// `[{..},{..}]`.
///
/// Braces inside string literals are ignored, and stray closing braces are skipped. Scanning
/// stops at an unterminated string.
fn extract_json_objects(json: &str) -> Vec<String> {
    let bytes = json.as_bytes();
    let mut objects = Vec::new();
    let mut depth = 0usize;
    let mut start = None;
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            b'"' => match json_string_end(json, i) {
                Some(end) => {
                    i = end;
                    continue;
                }
                None => break,
            },
            b'{' => {
                if depth == 0 {
                    start = Some(i);
                }
                depth += 1;
            }
            b'}' if depth > 0 => {
                depth -= 1;
                if depth == 0 {
                    if let Some(s) = start.take() {
                        objects.push(json[s..=i].to_string());
                    }
                }
            }
            _ => {}
        }
        i += 1;
    }
    objects
}

/// Returns the contents (without quotes, escapes left as-is) of the string value stored under
/// top-level `key` in `obj`.
///
/// Returns `None` if the key is absent or its value is not a terminated string.
fn extract_json_string_field(obj: &str, key: &str) -> Option<String> {
    let start = json_value_start(obj, key)?;
    let rest = obj[start..].trim_start();
    if !rest.starts_with('"') {
        return None;
    }
    let end = json_string_end(rest, 0)?;
    Some(rest[1..end - 1].to_string())
}

/// Parses the (optionally negative) integer stored under top-level `key` in `obj`.
///
/// Returns `None` if the key is absent or the value is not an integer.
fn extract_json_number_field(obj: &str, key: &str) -> Option<i64> {
    let start = json_value_start(obj, key)?;
    let rest = obj[start..].trim_start();
    let end = rest
        .find(|c: char| !c.is_ascii_digit() && c != '-')
        .unwrap_or(rest.len());
    rest[..end].parse().ok()
}
322
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{CommitsTable, GitObjectStore, RefsTable, create_commit};
    use nusy_arrow_core::{Namespace, Triple, YLayer};

    /// Builds a `Triple` from subject/predicate/object with confidence 1.0 and every
    /// optional provenance field set to `None`.
    fn make_triple(s: &str, p: &str, o: &str) -> Triple {
        Triple {
            subject: s.to_string(),
            predicate: p.to_string(),
            object: o.to_string(),
            graph: None,
            confidence: Some(1.0),
            source_document: None,
            source_chunk_id: None,
            extracted_by: None,
            caused_by: None,
            derived_from: None,
            consolidated_at: None,
            certifiability_class: None,
        }
    }

    /// Full pipeline: state -> Snapshot -> bytes -> Snapshot -> restore into a fresh store.
    #[test]
    fn test_snapshot_roundtrip() {
        let tmp = tempfile::tempdir().unwrap();
        let mut obj = GitObjectStore::with_snapshot_dir(tmp.path().join("snap"));
        let mut commits = CommitsTable::new();
        let mut refs = RefsTable::new();

        // Populate the World namespace with 50 distinct triples.
        for i in 0..50 {
            obj.store
                .add_triple(
                    &make_triple(&format!("entity-{i}"), "rdf:type", "Thing"),
                    Namespace::World,
                    YLayer::Semantic,
                )
                .unwrap();
        }

        let c1 = create_commit(&obj, &mut commits, vec![], "init", "Mini").unwrap();
        refs.init_main(&c1.commit_id);

        assert_eq!(obj.store.len(), 50);

        let snapshot = snapshot_state(&obj, &commits, &refs).unwrap();
        // At least one namespace and a non-empty commits payload must be captured.
        assert!(!snapshot.namespaces.is_empty());
        assert!(!snapshot.commits_json.is_empty());

        let bytes = snapshot_to_bytes(&snapshot);
        // Sanity check that the encoded buffer contains more than just the header.
        assert!(bytes.len() > 100);

        let restored_snapshot = bytes_to_snapshot(&bytes).unwrap();
        // Decoding must recover every namespace entry that was encoded.
        assert_eq!(
            restored_snapshot.namespaces.len(),
            snapshot.namespaces.len()
        );

        // Restore into a separate store to prove the data travels via the bytes alone.
        let mut obj2 = GitObjectStore::with_snapshot_dir(tmp.path().join("snap2"));
        let (commits2, refs2) = restore_snapshot(&mut obj2, &restored_snapshot).unwrap();

        assert_eq!(obj2.store.len(), 50);
        assert_eq!(commits2.len(), 1);
        assert!(refs2.head().is_some());
    }

    /// An empty store yields a snapshot with no namespaces that survives encoding.
    #[test]
    fn test_snapshot_empty_store() {
        let tmp = tempfile::tempdir().unwrap();
        let obj = GitObjectStore::with_snapshot_dir(tmp.path());
        let commits = CommitsTable::new();
        let refs = RefsTable::new();

        let snapshot = snapshot_state(&obj, &commits, &refs).unwrap();
        assert!(snapshot.namespaces.is_empty());

        let bytes = snapshot_to_bytes(&snapshot);
        let restored = bytes_to_snapshot(&bytes).unwrap();
        assert!(restored.namespaces.is_empty());
    }

    /// Data in two different namespaces (World and Work) is captured and fully restored.
    #[test]
    fn test_snapshot_multiple_namespaces() {
        let tmp = tempfile::tempdir().unwrap();
        let mut obj = GitObjectStore::with_snapshot_dir(tmp.path().join("snap"));
        let mut commits = CommitsTable::new();
        let refs = RefsTable::new();

        obj.store
            .add_triple(
                &make_triple("w", "r", "1"),
                Namespace::World,
                YLayer::Semantic,
            )
            .unwrap();
        obj.store
            .add_triple(
                &make_triple("k", "r", "2"),
                Namespace::Work,
                YLayer::Procedural,
            )
            .unwrap();

        let _c1 = create_commit(&obj, &mut commits, vec![], "multi-ns", "test").unwrap();

        let snapshot = snapshot_state(&obj, &commits, &refs).unwrap();
        // One entry per non-empty namespace.
        assert_eq!(snapshot.namespaces.len(), 2);

        let bytes = snapshot_to_bytes(&snapshot);
        let restored = bytes_to_snapshot(&bytes).unwrap();

        let mut obj2 = GitObjectStore::with_snapshot_dir(tmp.path().join("snap2"));
        let (_, _) = restore_snapshot(&mut obj2, &restored).unwrap();
        assert_eq!(obj2.store.len(), 2);
    }

    /// Buffers shorter than the 8-byte length prefix are rejected.
    #[test]
    fn test_bytes_to_snapshot_invalid() {
        assert!(bytes_to_snapshot(&[]).is_err());
        assert!(bytes_to_snapshot(&[0; 4]).is_err());
    }
}