Skip to main content

kyu_api/
persistence.rs

//! Persistence — save/load catalog and storage to/from disk.
//!
//! Layout in database directory:
//!   catalog.json   — serialized CatalogContent
//!   data/<tid>.bin — per-table row data in a simple binary format

7use std::path::Path;
8
9use kyu_catalog::CatalogContent;
10use kyu_common::id::TableId;
11use kyu_common::{KyuError, KyuResult};
12use kyu_types::{LogicalType, TypedValue};
13use smol_str::SmolStr;
14
15use crate::storage::NodeGroupStorage;
16
/// Name of the serialized catalog file, relative to the database directory.
const CATALOG_FILE: &str = "catalog.json";

/// Sub-directory (relative to the database directory) holding per-table `.bin` files.
const DATA_DIR: &str = "data";
19
// ---------------------------------------------------------------------------
// Catalog persistence
// ---------------------------------------------------------------------------
23
24/// Save catalog content to `dir/catalog.json`.
25pub fn save_catalog(dir: &Path, catalog: &CatalogContent) -> KyuResult<()> {
26    let json = catalog.serialize_json();
27    let path = dir.join(CATALOG_FILE);
28    std::fs::write(&path, json.as_bytes()).map_err(|e| {
29        KyuError::Storage(format!("cannot write catalog to '{}': {e}", path.display()))
30    })
31}
32
33/// Load catalog content from `dir/catalog.json`. Returns `None` if not found.
34pub fn load_catalog(dir: &Path) -> KyuResult<Option<CatalogContent>> {
35    let path = dir.join(CATALOG_FILE);
36    if !path.exists() {
37        return Ok(None);
38    }
39    let json = std::fs::read_to_string(&path).map_err(|e| {
40        KyuError::Storage(format!(
41            "cannot read catalog from '{}': {e}",
42            path.display()
43        ))
44    })?;
45    let content = CatalogContent::deserialize_json(&json)
46        .map_err(|e| KyuError::Storage(format!("cannot parse catalog JSON: {e}")))?;
47    Ok(Some(content))
48}
49
// ---------------------------------------------------------------------------
// Storage persistence
// ---------------------------------------------------------------------------
53
54/// Save all table data to `dir/data/<table_id>.bin`.
55pub fn save_storage(
56    dir: &Path,
57    storage: &NodeGroupStorage,
58    catalog: &CatalogContent,
59) -> KyuResult<()> {
60    let data_dir = dir.join(DATA_DIR);
61    std::fs::create_dir_all(&data_dir).map_err(|e| {
62        KyuError::Storage(format!(
63            "cannot create data dir '{}': {e}",
64            data_dir.display()
65        ))
66    })?;
67
68    // Remove stale files for tables that no longer exist.
69    if let Ok(entries) = std::fs::read_dir(&data_dir) {
70        for entry in entries.flatten() {
71            let name = entry.file_name();
72            let name_str = name.to_string_lossy();
73            if name_str.ends_with(".bin")
74                && let Ok(tid) = name_str.trim_end_matches(".bin").parse::<u64>()
75                && !storage.has_table(TableId(tid))
76            {
77                let _ = std::fs::remove_file(entry.path());
78            }
79        }
80    }
81
82    // Save node tables.
83    for nt in catalog.node_tables() {
84        save_table(&data_dir, nt.table_id, storage)?;
85    }
86
87    // Save rel tables.
88    for rt in catalog.rel_tables() {
89        save_table(&data_dir, rt.table_id, storage)?;
90    }
91
92    Ok(())
93}
94
95fn save_table(data_dir: &Path, table_id: TableId, storage: &NodeGroupStorage) -> KyuResult<()> {
96    let rows = storage.scan_rows(table_id)?;
97    let path = data_dir.join(format!("{}.bin", table_id.0));
98
99    let mut buf = Vec::new();
100
101    // Header: num_rows as u64 LE
102    buf.extend_from_slice(&(rows.len() as u64).to_le_bytes());
103
104    // Each row: num_cols u32 LE, then each value
105    for (_row_idx, values) in &rows {
106        buf.extend_from_slice(&(values.len() as u32).to_le_bytes());
107        for val in values {
108            serialize_typed_value(&mut buf, val);
109        }
110    }
111
112    std::fs::write(&path, &buf).map_err(|e| {
113        KyuError::Storage(format!(
114            "cannot write table data to '{}': {e}",
115            path.display()
116        ))
117    })
118}
119
120/// Load all table data from `dir/data/`.
121pub fn load_storage(dir: &Path, catalog: &CatalogContent) -> KyuResult<NodeGroupStorage> {
122    let mut storage = NodeGroupStorage::new();
123    let data_dir = dir.join(DATA_DIR);
124
125    // Create tables from catalog schema.
126    for nt in catalog.node_tables() {
127        let schema: Vec<LogicalType> = nt.properties.iter().map(|p| p.data_type.clone()).collect();
128        storage.create_table(nt.table_id, schema);
129    }
130    for rt in catalog.rel_tables() {
131        // Rel tables store: src_key, dst_key, then user properties.
132        // We need to reconstruct the storage schema the same way connection.rs does.
133        let from_key_type = catalog
134            .find_by_id(rt.from_table_id)
135            .and_then(|e| e.as_node_table())
136            .map(|n| n.primary_key_property().data_type.clone())
137            .unwrap_or(LogicalType::Int64);
138        let to_key_type = catalog
139            .find_by_id(rt.to_table_id)
140            .and_then(|e| e.as_node_table())
141            .map(|n| n.primary_key_property().data_type.clone())
142            .unwrap_or(LogicalType::Int64);
143        let mut schema = vec![from_key_type, to_key_type];
144        schema.extend(rt.properties.iter().map(|p| p.data_type.clone()));
145        storage.create_table(rt.table_id, schema);
146    }
147
148    if !data_dir.exists() {
149        return Ok(storage);
150    }
151
152    // Load row data from .bin files.
153    if let Ok(entries) = std::fs::read_dir(&data_dir) {
154        for entry in entries.flatten() {
155            let name = entry.file_name();
156            let name_str = name.to_string_lossy();
157            if !name_str.ends_with(".bin") {
158                continue;
159            }
160            if let Ok(tid) = name_str.trim_end_matches(".bin").parse::<u64>() {
161                let table_id = TableId(tid);
162                if storage.has_table(table_id) {
163                    load_table_rows(&entry.path(), table_id, &mut storage)?;
164                }
165            }
166        }
167    }
168
169    Ok(storage)
170}
171
172fn load_table_rows(
173    path: &Path,
174    table_id: TableId,
175    storage: &mut NodeGroupStorage,
176) -> KyuResult<()> {
177    let data = std::fs::read(path).map_err(|e| {
178        KyuError::Storage(format!(
179            "cannot read table data from '{}': {e}",
180            path.display()
181        ))
182    })?;
183
184    if data.len() < 8 {
185        return Ok(()); // Empty or corrupt file
186    }
187
188    let mut offset = 0;
189    let num_rows = read_u64_le(&data, &mut offset);
190
191    for _ in 0..num_rows {
192        if offset + 4 > data.len() {
193            break;
194        }
195        let num_cols = read_u32_le(&data, &mut offset) as usize;
196        let mut values = Vec::with_capacity(num_cols);
197        for _ in 0..num_cols {
198            let val = deserialize_typed_value(&data, &mut offset)?;
199            values.push(val);
200        }
201        storage.insert_row(table_id, &values)?;
202    }
203
204    Ok(())
205}
206
// ---------------------------------------------------------------------------
// TypedValue binary serialization
// ---------------------------------------------------------------------------
210
// One-byte tags that prefix every encoded `TypedValue` in a `.bin` table file.
// These numbers are written to disk, so an existing tag must never be renumbered.

/// `Null`. Also emitted for variants the format has no dedicated tag for.
const TAG_NULL: u8 = 0x00;
/// `Bool`, followed by one byte (`0` = false, anything else = true on read).
const TAG_BOOL: u8 = 0x01;
/// `Int8`, followed by 1 byte.
const TAG_INT8: u8 = 0x02;
/// `Int16`, followed by 2 little-endian bytes.
const TAG_INT16: u8 = 0x03;
/// `Int32`, followed by 4 little-endian bytes.
const TAG_INT32: u8 = 0x04;
/// `Int64`, followed by 8 little-endian bytes.
const TAG_INT64: u8 = 0x05;
/// `Float` (f32), followed by 4 little-endian bytes.
const TAG_FLOAT: u8 = 0x06;
/// `Double` (f64), followed by 8 little-endian bytes.
const TAG_DOUBLE: u8 = 0x07;
/// `String`, followed by a `u32` LE byte length and that many UTF-8 bytes.
const TAG_STRING: u8 = 0x08;
221
222fn serialize_typed_value(buf: &mut Vec<u8>, val: &TypedValue) {
223    match val {
224        TypedValue::Null => buf.push(TAG_NULL),
225        TypedValue::Bool(v) => {
226            buf.push(TAG_BOOL);
227            buf.push(if *v { 1 } else { 0 });
228        }
229        TypedValue::Int8(v) => {
230            buf.push(TAG_INT8);
231            buf.extend_from_slice(&v.to_le_bytes());
232        }
233        TypedValue::Int16(v) => {
234            buf.push(TAG_INT16);
235            buf.extend_from_slice(&v.to_le_bytes());
236        }
237        TypedValue::Int32(v) => {
238            buf.push(TAG_INT32);
239            buf.extend_from_slice(&v.to_le_bytes());
240        }
241        TypedValue::Int64(v) => {
242            buf.push(TAG_INT64);
243            buf.extend_from_slice(&v.to_le_bytes());
244        }
245        TypedValue::Float(v) => {
246            buf.push(TAG_FLOAT);
247            buf.extend_from_slice(&v.to_le_bytes());
248        }
249        TypedValue::Double(v) => {
250            buf.push(TAG_DOUBLE);
251            buf.extend_from_slice(&v.to_le_bytes());
252        }
253        TypedValue::String(s) => {
254            buf.push(TAG_STRING);
255            let bytes = s.as_bytes();
256            buf.extend_from_slice(&(bytes.len() as u32).to_le_bytes());
257            buf.extend_from_slice(bytes);
258        }
259        // Unsupported types stored as null.
260        _ => buf.push(TAG_NULL),
261    }
262}
263
264fn deserialize_typed_value(data: &[u8], offset: &mut usize) -> KyuResult<TypedValue> {
265    if *offset >= data.len() {
266        return Err(KyuError::Storage("unexpected end of table data".into()));
267    }
268    let tag = data[*offset];
269    *offset += 1;
270
271    match tag {
272        TAG_NULL => Ok(TypedValue::Null),
273        TAG_BOOL => {
274            ensure_remaining(data, *offset, 1)?;
275            let v = data[*offset] != 0;
276            *offset += 1;
277            Ok(TypedValue::Bool(v))
278        }
279        TAG_INT8 => {
280            ensure_remaining(data, *offset, 1)?;
281            let v = data[*offset] as i8;
282            *offset += 1;
283            Ok(TypedValue::Int8(v))
284        }
285        TAG_INT16 => {
286            ensure_remaining(data, *offset, 2)?;
287            let v = i16::from_le_bytes(data[*offset..*offset + 2].try_into().unwrap());
288            *offset += 2;
289            Ok(TypedValue::Int16(v))
290        }
291        TAG_INT32 => {
292            ensure_remaining(data, *offset, 4)?;
293            let v = i32::from_le_bytes(data[*offset..*offset + 4].try_into().unwrap());
294            *offset += 4;
295            Ok(TypedValue::Int32(v))
296        }
297        TAG_INT64 => {
298            ensure_remaining(data, *offset, 8)?;
299            let v = i64::from_le_bytes(data[*offset..*offset + 8].try_into().unwrap());
300            *offset += 8;
301            Ok(TypedValue::Int64(v))
302        }
303        TAG_FLOAT => {
304            ensure_remaining(data, *offset, 4)?;
305            let v = f32::from_le_bytes(data[*offset..*offset + 4].try_into().unwrap());
306            *offset += 4;
307            Ok(TypedValue::Float(v))
308        }
309        TAG_DOUBLE => {
310            ensure_remaining(data, *offset, 8)?;
311            let v = f64::from_le_bytes(data[*offset..*offset + 8].try_into().unwrap());
312            *offset += 8;
313            Ok(TypedValue::Double(v))
314        }
315        TAG_STRING => {
316            ensure_remaining(data, *offset, 4)?;
317            let len = u32::from_le_bytes(data[*offset..*offset + 4].try_into().unwrap()) as usize;
318            *offset += 4;
319            ensure_remaining(data, *offset, len)?;
320            let s = std::str::from_utf8(&data[*offset..*offset + len])
321                .map_err(|e| KyuError::Storage(format!("invalid UTF-8 in table data: {e}")))?;
322            *offset += len;
323            Ok(TypedValue::String(SmolStr::new(s)))
324        }
325        _ => Err(KyuError::Storage(format!("unknown TypedValue tag: {tag}"))),
326    }
327}
328
/// Read a little-endian `u64` at `*offset` and advance the cursor by 8 bytes.
///
/// # Panics
///
/// Panics if fewer than 8 bytes remain. Callers are expected to bounds-check first.
fn read_u64_le(data: &[u8], offset: &mut usize) -> u64 {
    let start = *offset;
    let mut raw = [0u8; 8];
    raw.copy_from_slice(&data[start..start + 8]);
    *offset = start + 8;
    u64::from_le_bytes(raw)
}

/// Read a little-endian `u32` at `*offset` and advance the cursor by 4 bytes.
///
/// # Panics
///
/// Panics if fewer than 4 bytes remain. Callers are expected to bounds-check first.
fn read_u32_le(data: &[u8], offset: &mut usize) -> u32 {
    let start = *offset;
    let mut raw = [0u8; 4];
    raw.copy_from_slice(&data[start..start + 4]);
    *offset = start + 4;
    u32::from_le_bytes(raw)
}
340
341fn ensure_remaining(data: &[u8], offset: usize, needed: usize) -> KyuResult<()> {
342    if offset + needed > data.len() {
343        Err(KyuError::Storage("unexpected end of table data".into()))
344    } else {
345        Ok(())
346    }
347}
348
#[cfg(test)]
mod tests {
    use super::*;
    use kyu_catalog::{NodeTableEntry, Property};

    /// Return a fresh, empty scratch directory for a single test.
    ///
    /// The process id in the name keeps concurrently running test binaries from
    /// sharing a directory. Anything left behind by an earlier aborted run is
    /// removed first, so every test starts from a known-empty state.
    fn scratch_dir(name: &str) -> std::path::PathBuf {
        let dir = std::env::temp_dir().join(format!("{name}_{}", std::process::id()));
        let _ = std::fs::remove_dir_all(&dir);
        std::fs::create_dir_all(&dir).expect("cannot create scratch dir");
        dir
    }

    /// Catalog with a single `Person` node table: `id: Int64` (primary key),
    /// `name: String`, `score: Double`.
    fn make_test_catalog() -> CatalogContent {
        let mut c = CatalogContent::new();
        let tid = c.alloc_table_id();
        let pid0 = c.alloc_property_id();
        let pid1 = c.alloc_property_id();
        let pid2 = c.alloc_property_id();
        c.add_node_table(NodeTableEntry {
            table_id: tid,
            name: SmolStr::new("Person"),
            properties: vec![
                Property::new(pid0, "id", LogicalType::Int64, true),
                Property::new(pid1, "name", LogicalType::String, false),
                Property::new(pid2, "score", LogicalType::Double, false),
            ],
            primary_key_idx: 0,
            num_rows: 0,
            comment: None,
        })
        .unwrap();
        c
    }

    #[test]
    fn catalog_save_load_roundtrip() {
        let dir = scratch_dir("kyu_test_persist_catalog");

        let catalog = make_test_catalog();
        save_catalog(&dir, &catalog).unwrap();

        let loaded = load_catalog(&dir).unwrap().unwrap();
        assert_eq!(loaded.num_tables(), 1);
        assert!(loaded.find_by_name("Person").is_some());
        assert_eq!(loaded.next_table_id, catalog.next_table_id);

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn load_catalog_missing_returns_none() {
        // `scratch_dir` guarantees no catalog file is present.
        let dir = scratch_dir("kyu_test_persist_missing");

        let result = load_catalog(&dir).unwrap();
        assert!(result.is_none());

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn storage_save_load_roundtrip() {
        let dir = scratch_dir("kyu_test_persist_storage");

        let catalog = make_test_catalog();
        // NOTE(review): assumes the first allocated table id is 0.
        let tid = TableId(0);
        let schema = vec![LogicalType::Int64, LogicalType::String, LogicalType::Double];

        let mut storage = NodeGroupStorage::new();
        storage.create_table(tid, schema);
        storage
            .insert_row(
                tid,
                &[
                    TypedValue::Int64(1),
                    TypedValue::String(SmolStr::new("Alice")),
                    TypedValue::Double(95.5),
                ],
            )
            .unwrap();
        storage
            .insert_row(
                tid,
                &[
                    TypedValue::Int64(2),
                    TypedValue::String(SmolStr::new("Bob")),
                    TypedValue::Double(87.3),
                ],
            )
            .unwrap();

        save_storage(&dir, &storage, &catalog).unwrap();

        let loaded = load_storage(&dir, &catalog).unwrap();
        assert!(loaded.has_table(tid));
        assert_eq!(loaded.num_rows(tid), 2);

        let rows = loaded.scan_rows(tid).unwrap();
        assert_eq!(rows.len(), 2);
        assert_eq!(rows[0].1[0], TypedValue::Int64(1));
        assert_eq!(rows[0].1[1], TypedValue::String(SmolStr::new("Alice")));
        assert_eq!(rows[0].1[2], TypedValue::Double(95.5));
        assert_eq!(rows[1].1[0], TypedValue::Int64(2));

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn storage_empty_table_roundtrip() {
        let dir = scratch_dir("kyu_test_persist_empty");

        let catalog = make_test_catalog();
        let tid = TableId(0);
        let mut storage = NodeGroupStorage::new();
        storage.create_table(
            tid,
            vec![LogicalType::Int64, LogicalType::String, LogicalType::Double],
        );

        save_storage(&dir, &storage, &catalog).unwrap();
        let loaded = load_storage(&dir, &catalog).unwrap();
        assert!(loaded.has_table(tid));
        assert_eq!(loaded.num_rows(tid), 0);

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn typed_value_binary_roundtrip() {
        let values = vec![
            TypedValue::Null,
            TypedValue::Bool(true),
            TypedValue::Bool(false),
            TypedValue::Int8(-42),
            TypedValue::Int16(1234),
            TypedValue::Int32(-999999),
            TypedValue::Int64(i64::MAX),
            TypedValue::Float(3.14),
            TypedValue::Double(2.718281828),
            TypedValue::String(SmolStr::new("hello world")),
            TypedValue::String(SmolStr::new("")),
        ];

        let mut buf = Vec::new();
        for v in &values {
            serialize_typed_value(&mut buf, v);
        }

        let mut offset = 0;
        for expected in &values {
            let actual = deserialize_typed_value(&buf, &mut offset).unwrap();
            assert_eq!(&actual, expected);
        }
        assert_eq!(offset, buf.len());
    }

    #[test]
    fn typed_value_rejects_corrupt_input() {
        // No tag byte at all.
        assert!(deserialize_typed_value(&[], &mut 0).is_err());
        // A tag that the encoder never emits.
        assert!(deserialize_typed_value(&[0xEE], &mut 0).is_err());
        // INT64 tag with only two of its eight payload bytes.
        assert!(deserialize_typed_value(&[TAG_INT64, 1, 2], &mut 0).is_err());
        // String header claims 10 bytes but only one follows.
        assert!(deserialize_typed_value(&[TAG_STRING, 10, 0, 0, 0, b'a'], &mut 0).is_err());
        // String payload that is not valid UTF-8.
        assert!(deserialize_typed_value(&[TAG_STRING, 1, 0, 0, 0, 0xFF], &mut 0).is_err());
    }
}