1use std::path::Path;
8
9use kyu_catalog::CatalogContent;
10use kyu_common::id::TableId;
11use kyu_common::{KyuError, KyuResult};
12use kyu_types::{LogicalType, TypedValue};
13use smol_str::SmolStr;
14
15use crate::storage::NodeGroupStorage;
16
/// Subdirectory of the database directory holding one `<table_id>.bin` file per table.
const DATA_DIR: &str = "data";
/// File name of the JSON-serialized catalog inside the database directory.
const CATALOG_FILE: &str = "catalog.json";
19
20pub fn save_catalog(dir: &Path, catalog: &CatalogContent) -> KyuResult<()> {
26 let json = catalog.serialize_json();
27 let path = dir.join(CATALOG_FILE);
28 std::fs::write(&path, json.as_bytes()).map_err(|e| {
29 KyuError::Storage(format!("cannot write catalog to '{}': {e}", path.display()))
30 })
31}
32
33pub fn load_catalog(dir: &Path) -> KyuResult<Option<CatalogContent>> {
35 let path = dir.join(CATALOG_FILE);
36 if !path.exists() {
37 return Ok(None);
38 }
39 let json = std::fs::read_to_string(&path).map_err(|e| {
40 KyuError::Storage(format!(
41 "cannot read catalog from '{}': {e}",
42 path.display()
43 ))
44 })?;
45 let content = CatalogContent::deserialize_json(&json)
46 .map_err(|e| KyuError::Storage(format!("cannot parse catalog JSON: {e}")))?;
47 Ok(Some(content))
48}
49
50pub fn save_storage(
56 dir: &Path,
57 storage: &NodeGroupStorage,
58 catalog: &CatalogContent,
59) -> KyuResult<()> {
60 let data_dir = dir.join(DATA_DIR);
61 std::fs::create_dir_all(&data_dir).map_err(|e| {
62 KyuError::Storage(format!(
63 "cannot create data dir '{}': {e}",
64 data_dir.display()
65 ))
66 })?;
67
68 if let Ok(entries) = std::fs::read_dir(&data_dir) {
70 for entry in entries.flatten() {
71 let name = entry.file_name();
72 let name_str = name.to_string_lossy();
73 if name_str.ends_with(".bin")
74 && let Ok(tid) = name_str.trim_end_matches(".bin").parse::<u64>()
75 && !storage.has_table(TableId(tid))
76 {
77 let _ = std::fs::remove_file(entry.path());
78 }
79 }
80 }
81
82 for nt in catalog.node_tables() {
84 save_table(&data_dir, nt.table_id, storage)?;
85 }
86
87 for rt in catalog.rel_tables() {
89 save_table(&data_dir, rt.table_id, storage)?;
90 }
91
92 Ok(())
93}
94
95fn save_table(data_dir: &Path, table_id: TableId, storage: &NodeGroupStorage) -> KyuResult<()> {
96 let rows = storage.scan_rows(table_id)?;
97 let path = data_dir.join(format!("{}.bin", table_id.0));
98
99 let mut buf = Vec::new();
100
101 buf.extend_from_slice(&(rows.len() as u64).to_le_bytes());
103
104 for (_row_idx, values) in &rows {
106 buf.extend_from_slice(&(values.len() as u32).to_le_bytes());
107 for val in values {
108 serialize_typed_value(&mut buf, val);
109 }
110 }
111
112 std::fs::write(&path, &buf).map_err(|e| {
113 KyuError::Storage(format!(
114 "cannot write table data to '{}': {e}",
115 path.display()
116 ))
117 })
118}
119
120pub fn load_storage(dir: &Path, catalog: &CatalogContent) -> KyuResult<NodeGroupStorage> {
122 let mut storage = NodeGroupStorage::new();
123 let data_dir = dir.join(DATA_DIR);
124
125 for nt in catalog.node_tables() {
127 let schema: Vec<LogicalType> = nt.properties.iter().map(|p| p.data_type.clone()).collect();
128 storage.create_table(nt.table_id, schema);
129 }
130 for rt in catalog.rel_tables() {
131 let from_key_type = catalog
134 .find_by_id(rt.from_table_id)
135 .and_then(|e| e.as_node_table())
136 .map(|n| n.primary_key_property().data_type.clone())
137 .unwrap_or(LogicalType::Int64);
138 let to_key_type = catalog
139 .find_by_id(rt.to_table_id)
140 .and_then(|e| e.as_node_table())
141 .map(|n| n.primary_key_property().data_type.clone())
142 .unwrap_or(LogicalType::Int64);
143 let mut schema = vec![from_key_type, to_key_type];
144 schema.extend(rt.properties.iter().map(|p| p.data_type.clone()));
145 storage.create_table(rt.table_id, schema);
146 }
147
148 if !data_dir.exists() {
149 return Ok(storage);
150 }
151
152 if let Ok(entries) = std::fs::read_dir(&data_dir) {
154 for entry in entries.flatten() {
155 let name = entry.file_name();
156 let name_str = name.to_string_lossy();
157 if !name_str.ends_with(".bin") {
158 continue;
159 }
160 if let Ok(tid) = name_str.trim_end_matches(".bin").parse::<u64>() {
161 let table_id = TableId(tid);
162 if storage.has_table(table_id) {
163 load_table_rows(&entry.path(), table_id, &mut storage)?;
164 }
165 }
166 }
167 }
168
169 Ok(storage)
170}
171
172fn load_table_rows(
173 path: &Path,
174 table_id: TableId,
175 storage: &mut NodeGroupStorage,
176) -> KyuResult<()> {
177 let data = std::fs::read(path).map_err(|e| {
178 KyuError::Storage(format!(
179 "cannot read table data from '{}': {e}",
180 path.display()
181 ))
182 })?;
183
184 if data.len() < 8 {
185 return Ok(()); }
187
188 let mut offset = 0;
189 let num_rows = read_u64_le(&data, &mut offset);
190
191 for _ in 0..num_rows {
192 if offset + 4 > data.len() {
193 break;
194 }
195 let num_cols = read_u32_le(&data, &mut offset) as usize;
196 let mut values = Vec::with_capacity(num_cols);
197 for _ in 0..num_cols {
198 let val = deserialize_typed_value(&data, &mut offset)?;
199 values.push(val);
200 }
201 storage.insert_row(table_id, &values)?;
202 }
203
204 Ok(())
205}
206
// Tag bytes prefixing every encoded `TypedValue` in table data files.
// These numbers are part of the on-disk format: never renumber or reuse them.

/// NULL — no payload.
const TAG_NULL: u8 = 0x00;
/// BOOL — one payload byte (0 = false, non-zero = true).
const TAG_BOOL: u8 = 0x01;
/// INT8 — one payload byte.
const TAG_INT8: u8 = 0x02;
/// INT16 — 2 little-endian bytes.
const TAG_INT16: u8 = 0x03;
/// INT32 — 4 little-endian bytes.
const TAG_INT32: u8 = 0x04;
/// INT64 — 8 little-endian bytes.
const TAG_INT64: u8 = 0x05;
/// FLOAT (f32) — 4 little-endian bytes.
const TAG_FLOAT: u8 = 0x06;
/// DOUBLE (f64) — 8 little-endian bytes.
const TAG_DOUBLE: u8 = 0x07;
/// STRING — little-endian `u32` byte length followed by UTF-8 bytes.
const TAG_STRING: u8 = 0x08;
221
222fn serialize_typed_value(buf: &mut Vec<u8>, val: &TypedValue) {
223 match val {
224 TypedValue::Null => buf.push(TAG_NULL),
225 TypedValue::Bool(v) => {
226 buf.push(TAG_BOOL);
227 buf.push(if *v { 1 } else { 0 });
228 }
229 TypedValue::Int8(v) => {
230 buf.push(TAG_INT8);
231 buf.extend_from_slice(&v.to_le_bytes());
232 }
233 TypedValue::Int16(v) => {
234 buf.push(TAG_INT16);
235 buf.extend_from_slice(&v.to_le_bytes());
236 }
237 TypedValue::Int32(v) => {
238 buf.push(TAG_INT32);
239 buf.extend_from_slice(&v.to_le_bytes());
240 }
241 TypedValue::Int64(v) => {
242 buf.push(TAG_INT64);
243 buf.extend_from_slice(&v.to_le_bytes());
244 }
245 TypedValue::Float(v) => {
246 buf.push(TAG_FLOAT);
247 buf.extend_from_slice(&v.to_le_bytes());
248 }
249 TypedValue::Double(v) => {
250 buf.push(TAG_DOUBLE);
251 buf.extend_from_slice(&v.to_le_bytes());
252 }
253 TypedValue::String(s) => {
254 buf.push(TAG_STRING);
255 let bytes = s.as_bytes();
256 buf.extend_from_slice(&(bytes.len() as u32).to_le_bytes());
257 buf.extend_from_slice(bytes);
258 }
259 _ => buf.push(TAG_NULL),
261 }
262}
263
264fn deserialize_typed_value(data: &[u8], offset: &mut usize) -> KyuResult<TypedValue> {
265 if *offset >= data.len() {
266 return Err(KyuError::Storage("unexpected end of table data".into()));
267 }
268 let tag = data[*offset];
269 *offset += 1;
270
271 match tag {
272 TAG_NULL => Ok(TypedValue::Null),
273 TAG_BOOL => {
274 ensure_remaining(data, *offset, 1)?;
275 let v = data[*offset] != 0;
276 *offset += 1;
277 Ok(TypedValue::Bool(v))
278 }
279 TAG_INT8 => {
280 ensure_remaining(data, *offset, 1)?;
281 let v = data[*offset] as i8;
282 *offset += 1;
283 Ok(TypedValue::Int8(v))
284 }
285 TAG_INT16 => {
286 ensure_remaining(data, *offset, 2)?;
287 let v = i16::from_le_bytes(data[*offset..*offset + 2].try_into().unwrap());
288 *offset += 2;
289 Ok(TypedValue::Int16(v))
290 }
291 TAG_INT32 => {
292 ensure_remaining(data, *offset, 4)?;
293 let v = i32::from_le_bytes(data[*offset..*offset + 4].try_into().unwrap());
294 *offset += 4;
295 Ok(TypedValue::Int32(v))
296 }
297 TAG_INT64 => {
298 ensure_remaining(data, *offset, 8)?;
299 let v = i64::from_le_bytes(data[*offset..*offset + 8].try_into().unwrap());
300 *offset += 8;
301 Ok(TypedValue::Int64(v))
302 }
303 TAG_FLOAT => {
304 ensure_remaining(data, *offset, 4)?;
305 let v = f32::from_le_bytes(data[*offset..*offset + 4].try_into().unwrap());
306 *offset += 4;
307 Ok(TypedValue::Float(v))
308 }
309 TAG_DOUBLE => {
310 ensure_remaining(data, *offset, 8)?;
311 let v = f64::from_le_bytes(data[*offset..*offset + 8].try_into().unwrap());
312 *offset += 8;
313 Ok(TypedValue::Double(v))
314 }
315 TAG_STRING => {
316 ensure_remaining(data, *offset, 4)?;
317 let len = u32::from_le_bytes(data[*offset..*offset + 4].try_into().unwrap()) as usize;
318 *offset += 4;
319 ensure_remaining(data, *offset, len)?;
320 let s = std::str::from_utf8(&data[*offset..*offset + len])
321 .map_err(|e| KyuError::Storage(format!("invalid UTF-8 in table data: {e}")))?;
322 *offset += len;
323 Ok(TypedValue::String(SmolStr::new(s)))
324 }
325 _ => Err(KyuError::Storage(format!("unknown TypedValue tag: {tag}"))),
326 }
327}
328
329fn read_u64_le(data: &[u8], offset: &mut usize) -> u64 {
330 let v = u64::from_le_bytes(data[*offset..*offset + 8].try_into().unwrap());
331 *offset += 8;
332 v
333}
334
/// Decodes a little-endian `u32` at `*offset` and advances the cursor past it.
///
/// # Panics
/// Panics if fewer than 4 bytes remain at `*offset`; callers check the length first.
fn read_u32_le(data: &[u8], offset: &mut usize) -> u32 {
    let start = *offset;
    let mut raw = [0u8; 4];
    raw.copy_from_slice(&data[start..start + 4]);
    *offset = start + 4;
    u32::from_le_bytes(raw)
}
340
341fn ensure_remaining(data: &[u8], offset: usize, needed: usize) -> KyuResult<()> {
342 if offset + needed > data.len() {
343 Err(KyuError::Storage("unexpected end of table data".into()))
344 } else {
345 Ok(())
346 }
347}
348
#[cfg(test)]
mod tests {
    use super::*;
    use kyu_catalog::{NodeTableEntry, Property};

    /// Creates an empty scratch directory for one test.
    ///
    /// The process id is appended so that concurrent test processes (e.g. two CI
    /// jobs on one machine) never share a directory. Leftovers from an earlier
    /// aborted run are removed first so they cannot leak into the assertions.
    fn fresh_dir(name: &str) -> std::path::PathBuf {
        let dir = std::env::temp_dir().join(format!("{name}_{}", std::process::id()));
        let _ = std::fs::remove_dir_all(&dir);
        std::fs::create_dir_all(&dir).unwrap();
        dir
    }

    /// Builds a catalog with a single `Person` node table (`id`, `name`, `score`).
    fn make_test_catalog() -> CatalogContent {
        let mut c = CatalogContent::new();
        let tid = c.alloc_table_id();
        let pid0 = c.alloc_property_id();
        let pid1 = c.alloc_property_id();
        let pid2 = c.alloc_property_id();
        c.add_node_table(NodeTableEntry {
            table_id: tid,
            name: SmolStr::new("Person"),
            properties: vec![
                Property::new(pid0, "id", LogicalType::Int64, true),
                Property::new(pid1, "name", LogicalType::String, false),
                Property::new(pid2, "score", LogicalType::Double, false),
            ],
            primary_key_idx: 0,
            num_rows: 0,
            comment: None,
        })
        .unwrap();
        c
    }

    #[test]
    fn catalog_save_load_roundtrip() {
        let dir = fresh_dir("kyu_test_persist_catalog");

        let catalog = make_test_catalog();
        save_catalog(&dir, &catalog).unwrap();

        let loaded = load_catalog(&dir).unwrap().unwrap();
        assert_eq!(loaded.num_tables(), 1);
        assert!(loaded.find_by_name("Person").is_some());
        assert_eq!(loaded.next_table_id, catalog.next_table_id);

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn load_catalog_missing_returns_none() {
        let dir = fresh_dir("kyu_test_persist_missing");

        let result = load_catalog(&dir).unwrap();
        assert!(result.is_none());

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn storage_save_load_roundtrip() {
        let dir = fresh_dir("kyu_test_persist_storage");

        let catalog = make_test_catalog();
        let tid = TableId(0);
        let schema = vec![LogicalType::Int64, LogicalType::String, LogicalType::Double];

        let mut storage = NodeGroupStorage::new();
        storage.create_table(tid, schema);
        storage
            .insert_row(
                tid,
                &[
                    TypedValue::Int64(1),
                    TypedValue::String(SmolStr::new("Alice")),
                    TypedValue::Double(95.5),
                ],
            )
            .unwrap();
        storage
            .insert_row(
                tid,
                &[
                    TypedValue::Int64(2),
                    TypedValue::String(SmolStr::new("Bob")),
                    TypedValue::Double(87.3),
                ],
            )
            .unwrap();

        save_storage(&dir, &storage, &catalog).unwrap();

        let loaded = load_storage(&dir, &catalog).unwrap();
        assert!(loaded.has_table(tid));
        assert_eq!(loaded.num_rows(tid), 2);

        let rows = loaded.scan_rows(tid).unwrap();
        assert_eq!(rows.len(), 2);
        assert_eq!(rows[0].1[0], TypedValue::Int64(1));
        assert_eq!(rows[0].1[1], TypedValue::String(SmolStr::new("Alice")));
        assert_eq!(rows[0].1[2], TypedValue::Double(95.5));
        assert_eq!(rows[1].1[0], TypedValue::Int64(2));

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn storage_empty_table_roundtrip() {
        let dir = fresh_dir("kyu_test_persist_empty");

        let catalog = make_test_catalog();
        let tid = TableId(0);
        let mut storage = NodeGroupStorage::new();
        storage.create_table(
            tid,
            vec![LogicalType::Int64, LogicalType::String, LogicalType::Double],
        );

        save_storage(&dir, &storage, &catalog).unwrap();
        let loaded = load_storage(&dir, &catalog).unwrap();
        assert!(loaded.has_table(tid));
        assert_eq!(loaded.num_rows(tid), 0);

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn typed_value_binary_roundtrip() {
        // The float literals are exactly representable and deliberately not close
        // to PI/E, which deny-by-default `clippy::approx_constant` rejects.
        let values = [
            TypedValue::Null,
            TypedValue::Bool(true),
            TypedValue::Bool(false),
            TypedValue::Int8(-42),
            TypedValue::Int16(1234),
            TypedValue::Int32(-999999),
            TypedValue::Int64(i64::MAX),
            TypedValue::Int64(i64::MIN),
            TypedValue::Float(1.5),
            TypedValue::Double(-0.125),
            TypedValue::String(SmolStr::new("hello world")),
            TypedValue::String(SmolStr::new("héllo ✓")),
            TypedValue::String(SmolStr::new("")),
        ];

        let mut buf = Vec::new();
        for v in &values {
            serialize_typed_value(&mut buf, v);
        }

        let mut offset = 0;
        for expected in &values {
            let actual = deserialize_typed_value(&buf, &mut offset).unwrap();
            assert_eq!(&actual, expected);
        }
        assert_eq!(offset, buf.len());
    }
}
501}