1use std::path::Path;
8
9use kyu_catalog::CatalogContent;
10use kyu_common::id::TableId;
11use kyu_common::{KyuError, KyuResult};
12use kyu_types::{LogicalType, TypedValue};
13use smol_str::SmolStr;
14
15use crate::storage::NodeGroupStorage;
16
/// File name, relative to the directory passed to `save_catalog` / `load_catalog`,
/// under which the catalog JSON is stored.
const CATALOG_FILE: &str = "catalog.json";
/// Subdirectory, relative to the directory passed to `save_storage` / `load_storage`,
/// that holds one `<table_id>.bin` file per table.
const DATA_DIR: &str = "data";
19
20pub fn save_catalog(dir: &Path, catalog: &CatalogContent) -> KyuResult<()> {
26 let json = catalog.serialize_json();
27 let path = dir.join(CATALOG_FILE);
28 std::fs::write(&path, json.as_bytes()).map_err(|e| {
29 KyuError::Storage(format!("cannot write catalog to '{}': {e}", path.display()))
30 })
31}
32
33pub fn load_catalog(dir: &Path) -> KyuResult<Option<CatalogContent>> {
35 let path = dir.join(CATALOG_FILE);
36 if !path.exists() {
37 return Ok(None);
38 }
39 let json = std::fs::read_to_string(&path).map_err(|e| {
40 KyuError::Storage(format!("cannot read catalog from '{}': {e}", path.display()))
41 })?;
42 let content = CatalogContent::deserialize_json(&json).map_err(|e| {
43 KyuError::Storage(format!("cannot parse catalog JSON: {e}"))
44 })?;
45 Ok(Some(content))
46}
47
48pub fn save_storage(
54 dir: &Path,
55 storage: &NodeGroupStorage,
56 catalog: &CatalogContent,
57) -> KyuResult<()> {
58 let data_dir = dir.join(DATA_DIR);
59 std::fs::create_dir_all(&data_dir).map_err(|e| {
60 KyuError::Storage(format!("cannot create data dir '{}': {e}", data_dir.display()))
61 })?;
62
63 if let Ok(entries) = std::fs::read_dir(&data_dir) {
65 for entry in entries.flatten() {
66 let name = entry.file_name();
67 let name_str = name.to_string_lossy();
68 if name_str.ends_with(".bin")
69 && let Ok(tid) = name_str.trim_end_matches(".bin").parse::<u64>()
70 && !storage.has_table(TableId(tid))
71 {
72 let _ = std::fs::remove_file(entry.path());
73 }
74 }
75 }
76
77 for nt in catalog.node_tables() {
79 save_table(&data_dir, nt.table_id, storage)?;
80 }
81
82 for rt in catalog.rel_tables() {
84 save_table(&data_dir, rt.table_id, storage)?;
85 }
86
87 Ok(())
88}
89
90fn save_table(
91 data_dir: &Path,
92 table_id: TableId,
93 storage: &NodeGroupStorage,
94) -> KyuResult<()> {
95 let rows = storage.scan_rows(table_id)?;
96 let path = data_dir.join(format!("{}.bin", table_id.0));
97
98 let mut buf = Vec::new();
99
100 buf.extend_from_slice(&(rows.len() as u64).to_le_bytes());
102
103 for (_row_idx, values) in &rows {
105 buf.extend_from_slice(&(values.len() as u32).to_le_bytes());
106 for val in values {
107 serialize_typed_value(&mut buf, val);
108 }
109 }
110
111 std::fs::write(&path, &buf).map_err(|e| {
112 KyuError::Storage(format!("cannot write table data to '{}': {e}", path.display()))
113 })
114}
115
116pub fn load_storage(dir: &Path, catalog: &CatalogContent) -> KyuResult<NodeGroupStorage> {
118 let mut storage = NodeGroupStorage::new();
119 let data_dir = dir.join(DATA_DIR);
120
121 for nt in catalog.node_tables() {
123 let schema: Vec<LogicalType> = nt.properties.iter().map(|p| p.data_type.clone()).collect();
124 storage.create_table(nt.table_id, schema);
125 }
126 for rt in catalog.rel_tables() {
127 let from_key_type = catalog
130 .find_by_id(rt.from_table_id)
131 .and_then(|e| e.as_node_table())
132 .map(|n| n.primary_key_property().data_type.clone())
133 .unwrap_or(LogicalType::Int64);
134 let to_key_type = catalog
135 .find_by_id(rt.to_table_id)
136 .and_then(|e| e.as_node_table())
137 .map(|n| n.primary_key_property().data_type.clone())
138 .unwrap_or(LogicalType::Int64);
139 let mut schema = vec![from_key_type, to_key_type];
140 schema.extend(rt.properties.iter().map(|p| p.data_type.clone()));
141 storage.create_table(rt.table_id, schema);
142 }
143
144 if !data_dir.exists() {
145 return Ok(storage);
146 }
147
148 if let Ok(entries) = std::fs::read_dir(&data_dir) {
150 for entry in entries.flatten() {
151 let name = entry.file_name();
152 let name_str = name.to_string_lossy();
153 if !name_str.ends_with(".bin") {
154 continue;
155 }
156 if let Ok(tid) = name_str.trim_end_matches(".bin").parse::<u64>() {
157 let table_id = TableId(tid);
158 if storage.has_table(table_id) {
159 load_table_rows(&entry.path(), table_id, &mut storage)?;
160 }
161 }
162 }
163 }
164
165 Ok(storage)
166}
167
168fn load_table_rows(
169 path: &Path,
170 table_id: TableId,
171 storage: &mut NodeGroupStorage,
172) -> KyuResult<()> {
173 let data = std::fs::read(path).map_err(|e| {
174 KyuError::Storage(format!("cannot read table data from '{}': {e}", path.display()))
175 })?;
176
177 if data.len() < 8 {
178 return Ok(()); }
180
181 let mut offset = 0;
182 let num_rows = read_u64_le(&data, &mut offset);
183
184 for _ in 0..num_rows {
185 if offset + 4 > data.len() {
186 break;
187 }
188 let num_cols = read_u32_le(&data, &mut offset) as usize;
189 let mut values = Vec::with_capacity(num_cols);
190 for _ in 0..num_cols {
191 let val = deserialize_typed_value(&data, &mut offset)?;
192 values.push(val);
193 }
194 storage.insert_row(table_id, &values)?;
195 }
196
197 Ok(())
198}
199
// One-byte type tags that prefix every value written by `serialize_typed_value`.
// `deserialize_typed_value` dispatches on these same values, so they are part
// of the on-disk format of the table files.

/// Null value; no payload follows.
const TAG_NULL: u8 = 0;
/// Boolean; followed by one byte (written as 0 or 1).
const TAG_BOOL: u8 = 1;
/// 8-bit signed integer; followed by 1 byte.
const TAG_INT8: u8 = 2;
/// 16-bit signed integer; followed by 2 little-endian bytes.
const TAG_INT16: u8 = 3;
/// 32-bit signed integer; followed by 4 little-endian bytes.
const TAG_INT32: u8 = 4;
/// 64-bit signed integer; followed by 8 little-endian bytes.
const TAG_INT64: u8 = 5;
/// 32-bit float; followed by 4 little-endian bytes.
const TAG_FLOAT: u8 = 6;
/// 64-bit float; followed by 8 little-endian bytes.
const TAG_DOUBLE: u8 = 7;
/// String; followed by a little-endian `u32` byte length and the UTF-8 bytes.
const TAG_STRING: u8 = 8;
214
215fn serialize_typed_value(buf: &mut Vec<u8>, val: &TypedValue) {
216 match val {
217 TypedValue::Null => buf.push(TAG_NULL),
218 TypedValue::Bool(v) => {
219 buf.push(TAG_BOOL);
220 buf.push(if *v { 1 } else { 0 });
221 }
222 TypedValue::Int8(v) => {
223 buf.push(TAG_INT8);
224 buf.extend_from_slice(&v.to_le_bytes());
225 }
226 TypedValue::Int16(v) => {
227 buf.push(TAG_INT16);
228 buf.extend_from_slice(&v.to_le_bytes());
229 }
230 TypedValue::Int32(v) => {
231 buf.push(TAG_INT32);
232 buf.extend_from_slice(&v.to_le_bytes());
233 }
234 TypedValue::Int64(v) => {
235 buf.push(TAG_INT64);
236 buf.extend_from_slice(&v.to_le_bytes());
237 }
238 TypedValue::Float(v) => {
239 buf.push(TAG_FLOAT);
240 buf.extend_from_slice(&v.to_le_bytes());
241 }
242 TypedValue::Double(v) => {
243 buf.push(TAG_DOUBLE);
244 buf.extend_from_slice(&v.to_le_bytes());
245 }
246 TypedValue::String(s) => {
247 buf.push(TAG_STRING);
248 let bytes = s.as_bytes();
249 buf.extend_from_slice(&(bytes.len() as u32).to_le_bytes());
250 buf.extend_from_slice(bytes);
251 }
252 _ => buf.push(TAG_NULL),
254 }
255}
256
257fn deserialize_typed_value(data: &[u8], offset: &mut usize) -> KyuResult<TypedValue> {
258 if *offset >= data.len() {
259 return Err(KyuError::Storage("unexpected end of table data".into()));
260 }
261 let tag = data[*offset];
262 *offset += 1;
263
264 match tag {
265 TAG_NULL => Ok(TypedValue::Null),
266 TAG_BOOL => {
267 ensure_remaining(data, *offset, 1)?;
268 let v = data[*offset] != 0;
269 *offset += 1;
270 Ok(TypedValue::Bool(v))
271 }
272 TAG_INT8 => {
273 ensure_remaining(data, *offset, 1)?;
274 let v = data[*offset] as i8;
275 *offset += 1;
276 Ok(TypedValue::Int8(v))
277 }
278 TAG_INT16 => {
279 ensure_remaining(data, *offset, 2)?;
280 let v = i16::from_le_bytes(data[*offset..*offset + 2].try_into().unwrap());
281 *offset += 2;
282 Ok(TypedValue::Int16(v))
283 }
284 TAG_INT32 => {
285 ensure_remaining(data, *offset, 4)?;
286 let v = i32::from_le_bytes(data[*offset..*offset + 4].try_into().unwrap());
287 *offset += 4;
288 Ok(TypedValue::Int32(v))
289 }
290 TAG_INT64 => {
291 ensure_remaining(data, *offset, 8)?;
292 let v = i64::from_le_bytes(data[*offset..*offset + 8].try_into().unwrap());
293 *offset += 8;
294 Ok(TypedValue::Int64(v))
295 }
296 TAG_FLOAT => {
297 ensure_remaining(data, *offset, 4)?;
298 let v = f32::from_le_bytes(data[*offset..*offset + 4].try_into().unwrap());
299 *offset += 4;
300 Ok(TypedValue::Float(v))
301 }
302 TAG_DOUBLE => {
303 ensure_remaining(data, *offset, 8)?;
304 let v = f64::from_le_bytes(data[*offset..*offset + 8].try_into().unwrap());
305 *offset += 8;
306 Ok(TypedValue::Double(v))
307 }
308 TAG_STRING => {
309 ensure_remaining(data, *offset, 4)?;
310 let len = u32::from_le_bytes(data[*offset..*offset + 4].try_into().unwrap()) as usize;
311 *offset += 4;
312 ensure_remaining(data, *offset, len)?;
313 let s = std::str::from_utf8(&data[*offset..*offset + len]).map_err(|e| {
314 KyuError::Storage(format!("invalid UTF-8 in table data: {e}"))
315 })?;
316 *offset += len;
317 Ok(TypedValue::String(SmolStr::new(s)))
318 }
319 _ => Err(KyuError::Storage(format!("unknown TypedValue tag: {tag}"))),
320 }
321}
322
/// Reads a little-endian `u64` from `data` at `*offset` and advances `*offset` by 8.
///
/// # Panics
///
/// Panics if fewer than 8 bytes are available at `*offset`; callers are
/// expected to check the length first.
fn read_u64_le(data: &[u8], offset: &mut usize) -> u64 {
    let start = *offset;
    let mut raw = [0u8; 8];
    raw.copy_from_slice(&data[start..start + 8]);
    *offset = start + 8;
    u64::from_le_bytes(raw)
}
328
/// Reads a little-endian `u32` from `data` at `*offset` and advances `*offset` by 4.
///
/// # Panics
///
/// Panics if fewer than 4 bytes are available at `*offset`; callers are
/// expected to check the length first.
fn read_u32_le(data: &[u8], offset: &mut usize) -> u32 {
    let start = *offset;
    let mut raw = [0u8; 4];
    raw.copy_from_slice(&data[start..start + 4]);
    *offset = start + 4;
    u32::from_le_bytes(raw)
}
334
335fn ensure_remaining(data: &[u8], offset: usize, needed: usize) -> KyuResult<()> {
336 if offset + needed > data.len() {
337 Err(KyuError::Storage("unexpected end of table data".into()))
338 } else {
339 Ok(())
340 }
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use kyu_catalog::{NodeTableEntry, Property};

    /// Returns a directory under the system temp dir, creating it if needed.
    fn scratch_dir(name: &str) -> std::path::PathBuf {
        let dir = std::env::temp_dir().join(name);
        let _ = std::fs::create_dir_all(&dir);
        dir
    }

    /// Column types of the `Person` table built by `make_test_catalog`.
    fn person_schema() -> Vec<LogicalType> {
        vec![LogicalType::Int64, LogicalType::String, LogicalType::Double]
    }

    /// Builds a catalog holding a single `Person` node table with the
    /// properties `id` (primary key), `name`, and `score`.
    fn make_test_catalog() -> CatalogContent {
        let mut catalog = CatalogContent::new();
        let table_id = catalog.alloc_table_id();
        let id_prop = catalog.alloc_property_id();
        let name_prop = catalog.alloc_property_id();
        let score_prop = catalog.alloc_property_id();

        let person = NodeTableEntry {
            table_id,
            name: SmolStr::new("Person"),
            properties: vec![
                Property::new(id_prop, "id", LogicalType::Int64, true),
                Property::new(name_prop, "name", LogicalType::String, false),
                Property::new(score_prop, "score", LogicalType::Double, false),
            ],
            primary_key_idx: 0,
            num_rows: 0,
            comment: None,
        };
        catalog.add_node_table(person).unwrap();
        catalog
    }

    #[test]
    fn catalog_save_load_roundtrip() {
        let dir = scratch_dir("kyu_test_persist_catalog");

        let original = make_test_catalog();
        save_catalog(&dir, &original).unwrap();

        let restored = load_catalog(&dir).unwrap().unwrap();
        assert_eq!(restored.num_tables(), 1);
        assert!(restored.find_by_name("Person").is_some());
        assert_eq!(restored.next_table_id, original.next_table_id);

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn load_catalog_missing_returns_none() {
        let dir = scratch_dir("kyu_test_persist_missing");
        // Make sure no catalog file is left over from an earlier run.
        let _ = std::fs::remove_file(dir.join(CATALOG_FILE));

        assert!(load_catalog(&dir).unwrap().is_none());

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn storage_save_load_roundtrip() {
        let dir = scratch_dir("kyu_test_persist_storage");

        let catalog = make_test_catalog();
        let tid = TableId(0);

        let mut storage = NodeGroupStorage::new();
        storage.create_table(tid, person_schema());

        let alice = [
            TypedValue::Int64(1),
            TypedValue::String(SmolStr::new("Alice")),
            TypedValue::Double(95.5),
        ];
        let bob = [
            TypedValue::Int64(2),
            TypedValue::String(SmolStr::new("Bob")),
            TypedValue::Double(87.3),
        ];
        storage.insert_row(tid, &alice).unwrap();
        storage.insert_row(tid, &bob).unwrap();

        save_storage(&dir, &storage, &catalog).unwrap();

        let restored = load_storage(&dir, &catalog).unwrap();
        assert!(restored.has_table(tid));
        assert_eq!(restored.num_rows(tid), 2);

        let rows = restored.scan_rows(tid).unwrap();
        assert_eq!(rows.len(), 2);
        assert_eq!(rows[0].1[0], TypedValue::Int64(1));
        assert_eq!(rows[0].1[1], TypedValue::String(SmolStr::new("Alice")));
        assert_eq!(rows[0].1[2], TypedValue::Double(95.5));
        assert_eq!(rows[1].1[0], TypedValue::Int64(2));

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn storage_empty_table_roundtrip() {
        let dir = scratch_dir("kyu_test_persist_empty");

        let catalog = make_test_catalog();
        let tid = TableId(0);
        let mut storage = NodeGroupStorage::new();
        storage.create_table(tid, person_schema());

        save_storage(&dir, &storage, &catalog).unwrap();

        let restored = load_storage(&dir, &catalog).unwrap();
        assert!(restored.has_table(tid));
        assert_eq!(restored.num_rows(tid), 0);

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn typed_value_binary_roundtrip() {
        let samples = vec![
            TypedValue::Null,
            TypedValue::Bool(true),
            TypedValue::Bool(false),
            TypedValue::Int8(-42),
            TypedValue::Int16(1234),
            TypedValue::Int32(-999999),
            TypedValue::Int64(i64::MAX),
            TypedValue::Float(3.14),
            TypedValue::Double(2.718281828),
            TypedValue::String(SmolStr::new("hello world")),
            TypedValue::String(SmolStr::new("")),
        ];

        let mut encoded = Vec::new();
        samples
            .iter()
            .for_each(|value| serialize_typed_value(&mut encoded, value));

        let mut cursor = 0;
        for expected in &samples {
            let decoded = deserialize_typed_value(&encoded, &mut cursor).unwrap();
            assert_eq!(&decoded, expected);
        }
        // Every encoded byte must have been consumed.
        assert_eq!(cursor, encoded.len());
    }
}