// kyu_api/persistence.rs

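//! Persistence — save/load catalog and storage to/from disk.
//!
//! Layout in database directory:
//!   catalog.json   — serialized `CatalogContent`
//!   data/<tid>.bin — per-table row data in a simple binary format
//!
//! Each `.bin` file holds a little-endian `u64` row count followed, for each
//! row, by a little-endian `u32` column count and that many values, each
//! encoded as a one-byte type tag plus its payload (if any).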
6
7use std::path::Path;
8
9use kyu_catalog::CatalogContent;
10use kyu_common::id::TableId;
11use kyu_common::{KyuError, KyuResult};
12use kyu_types::{LogicalType, TypedValue};
13use smol_str::SmolStr;
14
15use crate::storage::NodeGroupStorage;
16
/// Name of the JSON catalog file inside the database directory.
const CATALOG_FILE: &str = "catalog.json";
/// Sub-directory of the database directory holding one `.bin` file per table.
const DATA_DIR: &str = "data";
19
// ---------------------------------------------------------------------------
// Catalog persistence
// ---------------------------------------------------------------------------
23
24/// Save catalog content to `dir/catalog.json`.
25pub fn save_catalog(dir: &Path, catalog: &CatalogContent) -> KyuResult<()> {
26    let json = catalog.serialize_json();
27    let path = dir.join(CATALOG_FILE);
28    std::fs::write(&path, json.as_bytes()).map_err(|e| {
29        KyuError::Storage(format!("cannot write catalog to '{}': {e}", path.display()))
30    })
31}
32
33/// Load catalog content from `dir/catalog.json`. Returns `None` if not found.
34pub fn load_catalog(dir: &Path) -> KyuResult<Option<CatalogContent>> {
35    let path = dir.join(CATALOG_FILE);
36    if !path.exists() {
37        return Ok(None);
38    }
39    let json = std::fs::read_to_string(&path).map_err(|e| {
40        KyuError::Storage(format!("cannot read catalog from '{}': {e}", path.display()))
41    })?;
42    let content = CatalogContent::deserialize_json(&json).map_err(|e| {
43        KyuError::Storage(format!("cannot parse catalog JSON: {e}"))
44    })?;
45    Ok(Some(content))
46}
47
// ---------------------------------------------------------------------------
// Storage persistence
// ---------------------------------------------------------------------------
51
52/// Save all table data to `dir/data/<table_id>.bin`.
53pub fn save_storage(
54    dir: &Path,
55    storage: &NodeGroupStorage,
56    catalog: &CatalogContent,
57) -> KyuResult<()> {
58    let data_dir = dir.join(DATA_DIR);
59    std::fs::create_dir_all(&data_dir).map_err(|e| {
60        KyuError::Storage(format!("cannot create data dir '{}': {e}", data_dir.display()))
61    })?;
62
63    // Remove stale files for tables that no longer exist.
64    if let Ok(entries) = std::fs::read_dir(&data_dir) {
65        for entry in entries.flatten() {
66            let name = entry.file_name();
67            let name_str = name.to_string_lossy();
68            if name_str.ends_with(".bin")
69                && let Ok(tid) = name_str.trim_end_matches(".bin").parse::<u64>()
70                && !storage.has_table(TableId(tid))
71            {
72                let _ = std::fs::remove_file(entry.path());
73            }
74        }
75    }
76
77    // Save node tables.
78    for nt in catalog.node_tables() {
79        save_table(&data_dir, nt.table_id, storage)?;
80    }
81
82    // Save rel tables.
83    for rt in catalog.rel_tables() {
84        save_table(&data_dir, rt.table_id, storage)?;
85    }
86
87    Ok(())
88}
89
90fn save_table(
91    data_dir: &Path,
92    table_id: TableId,
93    storage: &NodeGroupStorage,
94) -> KyuResult<()> {
95    let rows = storage.scan_rows(table_id)?;
96    let path = data_dir.join(format!("{}.bin", table_id.0));
97
98    let mut buf = Vec::new();
99
100    // Header: num_rows as u64 LE
101    buf.extend_from_slice(&(rows.len() as u64).to_le_bytes());
102
103    // Each row: num_cols u32 LE, then each value
104    for (_row_idx, values) in &rows {
105        buf.extend_from_slice(&(values.len() as u32).to_le_bytes());
106        for val in values {
107            serialize_typed_value(&mut buf, val);
108        }
109    }
110
111    std::fs::write(&path, &buf).map_err(|e| {
112        KyuError::Storage(format!("cannot write table data to '{}': {e}", path.display()))
113    })
114}
115
116/// Load all table data from `dir/data/`.
117pub fn load_storage(dir: &Path, catalog: &CatalogContent) -> KyuResult<NodeGroupStorage> {
118    let mut storage = NodeGroupStorage::new();
119    let data_dir = dir.join(DATA_DIR);
120
121    // Create tables from catalog schema.
122    for nt in catalog.node_tables() {
123        let schema: Vec<LogicalType> = nt.properties.iter().map(|p| p.data_type.clone()).collect();
124        storage.create_table(nt.table_id, schema);
125    }
126    for rt in catalog.rel_tables() {
127        // Rel tables store: src_key, dst_key, then user properties.
128        // We need to reconstruct the storage schema the same way connection.rs does.
129        let from_key_type = catalog
130            .find_by_id(rt.from_table_id)
131            .and_then(|e| e.as_node_table())
132            .map(|n| n.primary_key_property().data_type.clone())
133            .unwrap_or(LogicalType::Int64);
134        let to_key_type = catalog
135            .find_by_id(rt.to_table_id)
136            .and_then(|e| e.as_node_table())
137            .map(|n| n.primary_key_property().data_type.clone())
138            .unwrap_or(LogicalType::Int64);
139        let mut schema = vec![from_key_type, to_key_type];
140        schema.extend(rt.properties.iter().map(|p| p.data_type.clone()));
141        storage.create_table(rt.table_id, schema);
142    }
143
144    if !data_dir.exists() {
145        return Ok(storage);
146    }
147
148    // Load row data from .bin files.
149    if let Ok(entries) = std::fs::read_dir(&data_dir) {
150        for entry in entries.flatten() {
151            let name = entry.file_name();
152            let name_str = name.to_string_lossy();
153            if !name_str.ends_with(".bin") {
154                continue;
155            }
156            if let Ok(tid) = name_str.trim_end_matches(".bin").parse::<u64>() {
157                let table_id = TableId(tid);
158                if storage.has_table(table_id) {
159                    load_table_rows(&entry.path(), table_id, &mut storage)?;
160                }
161            }
162        }
163    }
164
165    Ok(storage)
166}
167
168fn load_table_rows(
169    path: &Path,
170    table_id: TableId,
171    storage: &mut NodeGroupStorage,
172) -> KyuResult<()> {
173    let data = std::fs::read(path).map_err(|e| {
174        KyuError::Storage(format!("cannot read table data from '{}': {e}", path.display()))
175    })?;
176
177    if data.len() < 8 {
178        return Ok(()); // Empty or corrupt file
179    }
180
181    let mut offset = 0;
182    let num_rows = read_u64_le(&data, &mut offset);
183
184    for _ in 0..num_rows {
185        if offset + 4 > data.len() {
186            break;
187        }
188        let num_cols = read_u32_le(&data, &mut offset) as usize;
189        let mut values = Vec::with_capacity(num_cols);
190        for _ in 0..num_cols {
191            let val = deserialize_typed_value(&data, &mut offset)?;
192            values.push(val);
193        }
194        storage.insert_row(table_id, &values)?;
195    }
196
197    Ok(())
198}
199
// ---------------------------------------------------------------------------
// TypedValue binary serialization
// ---------------------------------------------------------------------------
203
// One-byte type tags that prefix every value in a table file. The numeric
// values are part of the on-disk format and must never be renumbered.

/// SQL NULL; no payload.
const TAG_NULL: u8 = 0x00;
/// Boolean; 1-byte payload (0 = false, non-zero = true).
const TAG_BOOL: u8 = 0x01;
/// Signed 8-bit integer.
const TAG_INT8: u8 = 0x02;
/// Signed 16-bit integer, little-endian.
const TAG_INT16: u8 = 0x03;
/// Signed 32-bit integer, little-endian.
const TAG_INT32: u8 = 0x04;
/// Signed 64-bit integer, little-endian.
const TAG_INT64: u8 = 0x05;
/// 32-bit float, little-endian.
const TAG_FLOAT: u8 = 0x06;
/// 64-bit float, little-endian.
const TAG_DOUBLE: u8 = 0x07;
/// UTF-8 string: `u32` little-endian byte length, then the bytes.
const TAG_STRING: u8 = 0x08;
214
215fn serialize_typed_value(buf: &mut Vec<u8>, val: &TypedValue) {
216    match val {
217        TypedValue::Null => buf.push(TAG_NULL),
218        TypedValue::Bool(v) => {
219            buf.push(TAG_BOOL);
220            buf.push(if *v { 1 } else { 0 });
221        }
222        TypedValue::Int8(v) => {
223            buf.push(TAG_INT8);
224            buf.extend_from_slice(&v.to_le_bytes());
225        }
226        TypedValue::Int16(v) => {
227            buf.push(TAG_INT16);
228            buf.extend_from_slice(&v.to_le_bytes());
229        }
230        TypedValue::Int32(v) => {
231            buf.push(TAG_INT32);
232            buf.extend_from_slice(&v.to_le_bytes());
233        }
234        TypedValue::Int64(v) => {
235            buf.push(TAG_INT64);
236            buf.extend_from_slice(&v.to_le_bytes());
237        }
238        TypedValue::Float(v) => {
239            buf.push(TAG_FLOAT);
240            buf.extend_from_slice(&v.to_le_bytes());
241        }
242        TypedValue::Double(v) => {
243            buf.push(TAG_DOUBLE);
244            buf.extend_from_slice(&v.to_le_bytes());
245        }
246        TypedValue::String(s) => {
247            buf.push(TAG_STRING);
248            let bytes = s.as_bytes();
249            buf.extend_from_slice(&(bytes.len() as u32).to_le_bytes());
250            buf.extend_from_slice(bytes);
251        }
252        // Unsupported types stored as null.
253        _ => buf.push(TAG_NULL),
254    }
255}
256
257fn deserialize_typed_value(data: &[u8], offset: &mut usize) -> KyuResult<TypedValue> {
258    if *offset >= data.len() {
259        return Err(KyuError::Storage("unexpected end of table data".into()));
260    }
261    let tag = data[*offset];
262    *offset += 1;
263
264    match tag {
265        TAG_NULL => Ok(TypedValue::Null),
266        TAG_BOOL => {
267            ensure_remaining(data, *offset, 1)?;
268            let v = data[*offset] != 0;
269            *offset += 1;
270            Ok(TypedValue::Bool(v))
271        }
272        TAG_INT8 => {
273            ensure_remaining(data, *offset, 1)?;
274            let v = data[*offset] as i8;
275            *offset += 1;
276            Ok(TypedValue::Int8(v))
277        }
278        TAG_INT16 => {
279            ensure_remaining(data, *offset, 2)?;
280            let v = i16::from_le_bytes(data[*offset..*offset + 2].try_into().unwrap());
281            *offset += 2;
282            Ok(TypedValue::Int16(v))
283        }
284        TAG_INT32 => {
285            ensure_remaining(data, *offset, 4)?;
286            let v = i32::from_le_bytes(data[*offset..*offset + 4].try_into().unwrap());
287            *offset += 4;
288            Ok(TypedValue::Int32(v))
289        }
290        TAG_INT64 => {
291            ensure_remaining(data, *offset, 8)?;
292            let v = i64::from_le_bytes(data[*offset..*offset + 8].try_into().unwrap());
293            *offset += 8;
294            Ok(TypedValue::Int64(v))
295        }
296        TAG_FLOAT => {
297            ensure_remaining(data, *offset, 4)?;
298            let v = f32::from_le_bytes(data[*offset..*offset + 4].try_into().unwrap());
299            *offset += 4;
300            Ok(TypedValue::Float(v))
301        }
302        TAG_DOUBLE => {
303            ensure_remaining(data, *offset, 8)?;
304            let v = f64::from_le_bytes(data[*offset..*offset + 8].try_into().unwrap());
305            *offset += 8;
306            Ok(TypedValue::Double(v))
307        }
308        TAG_STRING => {
309            ensure_remaining(data, *offset, 4)?;
310            let len = u32::from_le_bytes(data[*offset..*offset + 4].try_into().unwrap()) as usize;
311            *offset += 4;
312            ensure_remaining(data, *offset, len)?;
313            let s = std::str::from_utf8(&data[*offset..*offset + len]).map_err(|e| {
314                KyuError::Storage(format!("invalid UTF-8 in table data: {e}"))
315            })?;
316            *offset += len;
317            Ok(TypedValue::String(SmolStr::new(s)))
318        }
319        _ => Err(KyuError::Storage(format!("unknown TypedValue tag: {tag}"))),
320    }
321}
322
/// Read a little-endian `u64` at `*offset` and advance the cursor by 8.
///
/// # Panics
///
/// Panics if fewer than 8 bytes remain; callers check the length first.
fn read_u64_le(data: &[u8], offset: &mut usize) -> u64 {
    let start = *offset;
    let mut word = [0u8; 8];
    word.copy_from_slice(&data[start..start + 8]);
    *offset = start + 8;
    u64::from_le_bytes(word)
}
328
/// Read a little-endian `u32` at `*offset` and advance the cursor by 4.
///
/// # Panics
///
/// Panics if fewer than 4 bytes remain; callers check the length first.
fn read_u32_le(data: &[u8], offset: &mut usize) -> u32 {
    let start = *offset;
    let mut word = [0u8; 4];
    word.copy_from_slice(&data[start..start + 4]);
    *offset = start + 4;
    u32::from_le_bytes(word)
}
334
335fn ensure_remaining(data: &[u8], offset: usize, needed: usize) -> KyuResult<()> {
336    if offset + needed > data.len() {
337        Err(KyuError::Storage("unexpected end of table data".into()))
338    } else {
339        Ok(())
340    }
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use kyu_catalog::{NodeTableEntry, Property};

    /// Return an empty scratch directory for one test.
    ///
    /// The process id is appended so concurrently running test binaries never
    /// share a directory. Any leftover from an earlier run that panicked before
    /// cleaning up is removed first.
    fn fresh_test_dir(name: &str) -> std::path::PathBuf {
        let dir = std::env::temp_dir().join(format!("{name}_{}", std::process::id()));
        let _ = std::fs::remove_dir_all(&dir);
        std::fs::create_dir_all(&dir).expect("cannot create test directory");
        dir
    }

    /// Catalog with a single `Person` node table (columns `id`, `name`,
    /// `score`; primary key at index 0).
    fn make_test_catalog() -> CatalogContent {
        let mut c = CatalogContent::new();
        let tid = c.alloc_table_id();
        let pid0 = c.alloc_property_id();
        let pid1 = c.alloc_property_id();
        let pid2 = c.alloc_property_id();
        c.add_node_table(NodeTableEntry {
            table_id: tid,
            name: SmolStr::new("Person"),
            properties: vec![
                Property::new(pid0, "id", LogicalType::Int64, true),
                Property::new(pid1, "name", LogicalType::String, false),
                Property::new(pid2, "score", LogicalType::Double, false),
            ],
            primary_key_idx: 0,
            num_rows: 0,
            comment: None,
        })
        .unwrap();
        c
    }

    #[test]
    fn catalog_save_load_roundtrip() {
        let dir = fresh_test_dir("kyu_test_persist_catalog");

        let catalog = make_test_catalog();
        save_catalog(&dir, &catalog).unwrap();

        let loaded = load_catalog(&dir).unwrap().unwrap();
        assert_eq!(loaded.num_tables(), 1);
        assert!(loaded.find_by_name("Person").is_some());
        assert_eq!(loaded.next_table_id, catalog.next_table_id);

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn load_catalog_missing_returns_none() {
        // A fresh directory is guaranteed not to contain a catalog file.
        let dir = fresh_test_dir("kyu_test_persist_missing");

        let result = load_catalog(&dir).unwrap();
        assert!(result.is_none());

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn storage_save_load_roundtrip() {
        let dir = fresh_test_dir("kyu_test_persist_storage");

        let catalog = make_test_catalog();
        let tid = TableId(0);
        let schema = vec![LogicalType::Int64, LogicalType::String, LogicalType::Double];

        let mut storage = NodeGroupStorage::new();
        storage.create_table(tid, schema);
        storage
            .insert_row(tid, &[
                TypedValue::Int64(1),
                TypedValue::String(SmolStr::new("Alice")),
                TypedValue::Double(95.5),
            ])
            .unwrap();
        storage
            .insert_row(tid, &[
                TypedValue::Int64(2),
                TypedValue::String(SmolStr::new("Bob")),
                TypedValue::Double(87.3),
            ])
            .unwrap();

        save_storage(&dir, &storage, &catalog).unwrap();

        let loaded = load_storage(&dir, &catalog).unwrap();
        assert!(loaded.has_table(tid));
        assert_eq!(loaded.num_rows(tid), 2);

        let rows = loaded.scan_rows(tid).unwrap();
        assert_eq!(rows.len(), 2);
        assert_eq!(rows[0].1[0], TypedValue::Int64(1));
        assert_eq!(rows[0].1[1], TypedValue::String(SmolStr::new("Alice")));
        assert_eq!(rows[0].1[2], TypedValue::Double(95.5));
        assert_eq!(rows[1].1[0], TypedValue::Int64(2));

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn storage_empty_table_roundtrip() {
        let dir = fresh_test_dir("kyu_test_persist_empty");

        let catalog = make_test_catalog();
        let tid = TableId(0);
        let mut storage = NodeGroupStorage::new();
        storage.create_table(tid, vec![LogicalType::Int64, LogicalType::String, LogicalType::Double]);

        save_storage(&dir, &storage, &catalog).unwrap();
        let loaded = load_storage(&dir, &catalog).unwrap();
        assert!(loaded.has_table(tid));
        assert_eq!(loaded.num_rows(tid), 0);

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn typed_value_binary_roundtrip() {
        let values = vec![
            TypedValue::Null,
            TypedValue::Bool(true),
            TypedValue::Bool(false),
            TypedValue::Int8(-42),
            TypedValue::Int16(1234),
            TypedValue::Int32(-999999),
            TypedValue::Int64(i64::MAX),
            TypedValue::Float(3.14),
            TypedValue::Double(2.718281828),
            TypedValue::String(SmolStr::new("hello world")),
            TypedValue::String(SmolStr::new("")),
        ];

        let mut buf = Vec::new();
        for v in &values {
            serialize_typed_value(&mut buf, v);
        }

        let mut offset = 0;
        for expected in &values {
            let actual = deserialize_typed_value(&buf, &mut offset).unwrap();
            assert_eq!(&actual, expected);
        }
        assert_eq!(offset, buf.len());
    }
}