triviumdb 0.4.90

A high-performance memory-mmap hybrid search engine built for AI, combining dense vector, sparse text, graph relations, and JSON metadata.
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
use crate::VectorType;
use crate::database::StorageMode;
use crate::error::{Result, TriviumError};
use crate::node::{Edge, NodeId};
use crate::storage::memtable::MemTable;
use crate::storage::vec_pool::VecPool;
use memmap2::Mmap;
use std::fs::File;
use std::io::{BufWriter, Write};
use std::path::Path;

/// Atomic rename that tolerates transient file locks taken by antivirus software on Windows.
///
/// Antivirus products (Windows Defender, Huorong, etc.) may scan a file in exclusive mode at
/// the instant it is closed, so a rename issued right afterwards can fail with
/// ERROR_SHARING_VIOLATION (32) or ERROR_ACCESS_DENIED (5). On Windows this function makes up
/// to 10 attempts, sleeping between them with an exponential backoff that starts at 1 ms and
/// is capped at 50 ms. Any other error, or a failure on the final attempt, is returned as-is.
///
/// On non-Windows platforms this is a direct call to `std::fs::rename`.
fn robust_rename(from: &Path, to: &Path) -> std::io::Result<()> {
    #[cfg(not(windows))]
    {
        return std::fs::rename(from, to);
    }

    #[cfg(windows)]
    {
        const MAX_ATTEMPTS: u32 = 10;
        const BACKOFF_CAP: std::time::Duration = std::time::Duration::from_millis(50);

        let mut backoff = std::time::Duration::from_millis(1);
        let mut attempt: u32 = 1;
        loop {
            let err = match std::fs::rename(from, to) {
                Ok(()) => return Ok(()),
                Err(err) => err,
            };

            // Only ERROR_ACCESS_DENIED (5) and ERROR_SHARING_VIOLATION (32) are retried.
            let code = err.raw_os_error();
            let transient = matches!(code, Some(5) | Some(32));
            if attempt == MAX_ATTEMPTS || !transient {
                return Err(err);
            }

            tracing::debug!(
                "robust_rename: attempt {} failed (os_error={:?}), retrying in {:?}",
                attempt, code, backoff
            );
            std::thread::sleep(backoff);
            backoff = (backoff * 2).min(BACKOFF_CAP);
            attempt += 1;
        }
    }
}

// ══════ File header constants ══════
/// Magic bytes that open every `.tdb` file.
const MAGIC: &[u8; 4] = b"TVDB";
/// On-disk format version written into the header.
const VERSION: u16 = 3; // v3: adds the ERPC lock-free acceleration index metadata
/// Header size in bytes: magic (4) + version (2) + dim (4) + next id (8) + slot count (8)
/// + payload / vector / edge / ERPC offsets (4 × 8).
const HEADER_SIZE: u64 = 58;

use crate::index::erpc::{ErpcIndex, ErpcParams, SeqEntry};

/// Path of the vector file that accompanies a database file.
///
/// The suffix `.vec` is appended to the full `db_path` (e.g. `data.tdb` → `data.tdb.vec`).
fn vec_path_from_db(db_path: &str) -> String {
    [db_path, ".vec"].concat()
}

/// Path of the flush marker file that accompanies a database file.
///
/// The suffix `.flush_ok` is appended to the full `db_path`. The marker is the "commit point"
/// of the two-file Mmap write and stores the sizes of the `.tdb` and `.vec` files.
fn flush_ok_path_from_db(db_path: &str) -> String {
    let mut marker = String::with_capacity(db_path.len() + ".flush_ok".len());
    marker.push_str(db_path);
    marker.push_str(".flush_ok");
    marker
}

pub fn save<T: VectorType>(
    memtable: &mut MemTable<T>,
    path: &str,
    mode: StorageMode,
) -> Result<()> {
    match mode {
        StorageMode::Mmap => save_mmap(memtable, path),
        StorageMode::Rom => save_rom(memtable, path),
    }
}

/// Saves in Mmap mode: vectors are flushed to the companion `.vec` file and the `.tdb`
/// holds metadata only.
///
/// After both files have been replaced, a `.flush_ok` marker is written (via a temp file
/// and rename). It contains two little-endian `u64` values: the size of the `.tdb` file
/// followed by the size of the `.vec` file. `load` compares these against the actual file
/// sizes to detect a torn write across the two files.
///
/// # Errors
/// Returns an error if flushing the vector pool, writing the `.tdb`, reading either file's
/// metadata, or writing/renaming the marker fails. A metadata failure is reported instead of
/// recording a size of 0, because a marker with a wrong size would make the next load treat
/// a perfectly good database as torn.
fn save_mmap<T: VectorType>(memtable: &mut MemTable<T>, path: &str) -> Result<()> {
    let vec_file_path = vec_path_from_db(path);
    let vec_count = memtable.vec_pool_mut().flush(Path::new(&vec_file_path))?;
    save_tdb(memtable, path, vec_count, true)?;

    // ═══ Cross-file consistency marker (commit point) ═══
    // Written only after both .vec and .tdb have been atomically replaced.
    let tdb_size = std::fs::metadata(path)?.len();
    let vec_size = std::fs::metadata(&vec_file_path)?.len();

    let mut marker = [0u8; 16];
    marker[..8].copy_from_slice(&tdb_size.to_le_bytes());
    marker[8..].copy_from_slice(&vec_size.to_le_bytes());

    let marker_path = flush_ok_path_from_db(path);
    let marker_tmp = format!("{}.tmp", &marker_path);
    {
        // Scope the handle so it is closed before the rename.
        let mut f = File::create(&marker_tmp)?;
        f.write_all(&marker)?;
        f.sync_all()?;
    }
    robust_rename(Path::new(&marker_tmp), Path::new(&marker_path))?;

    Ok(())
}

/// Saves in Rom mode: vectors are merged into the single `.tdb` file and the companion
/// `.vec` / `.flush_ok` files are discarded.
///
/// # Errors
/// Returns an error if writing the `.tdb` fails. Failures while deleting the leftover
/// companion files are ignored.
fn save_rom<T: VectorType>(memtable: &mut MemTable<T>, path: &str) -> Result<()> {
    // 1. Make sure the complete merged vector array is available in memory.
    memtable.ensure_vectors_cache();
    let slot_count = memtable.internal_indices().len();

    // 2. Write everything into the single file.
    save_tdb(memtable, path, slot_count, false)?;

    // 3. Detach any existing mmap into the in-memory delta so the .vec file that is about
    //    to be deleted is no longer held open.
    memtable.vec_pool_mut().detach_mmap();

    // 4. Best-effort cleanup of the leftover .vec and .flush_ok files.
    for stale in [vec_path_from_db(path), flush_ok_path_from_db(path)] {
        if Path::new(&stale).exists() {
            std::fs::remove_file(&stale).ok();
        }
    }

    Ok(())
}

/// 核心通用写入逻辑:将 MemTable (Payload & Edge) 写入 .tdb
fn save_tdb<T: VectorType>(
    memtable: &mut MemTable<T>,
    path: &str,
    vec_count: usize,
    is_mmap_mode: bool,
) -> Result<()> {
    if !is_mmap_mode {
        memtable.ensure_vectors_cache();
    }

    let tmp_path = format!("{}.tmp", path);
    let file = File::create(&tmp_path)?;
    let mut w = BufWriter::new(file);

    let dim = memtable.dim();

    // 我们必须按照内部索引数组以防止在重载时 NodeID/Vector 错位
    let internal_indices = memtable.internal_indices();
    // 实际写入的记录数量等于从向量池生成的记录数(包括空洞 Tombstones)
    let node_count = internal_indices.len() as u64;

    let mut all_edges: Vec<(NodeId, &Edge)> = Vec::new();
    let mut payload_size: u64 = 0;

    // 计算 Payload 块大小并收集边
    for &nid in internal_indices {
        if nid != 0 {
            // 有效节点
            if let Some(p) = memtable.get_payload(nid) {
                let json_bytes = serde_json::to_vec(p).unwrap_or_default();
                payload_size += 8 + 4 + json_bytes.len() as u64;
            } else {
                // tombstone 占位符结构:NodeId (0) + len (0) = 12 bytes
                payload_size += 12;
            }
            if let Some(edges) = memtable.get_edges(nid) {
                for edge in edges {
                    all_edges.push((nid, edge));
                }
            }
        } else {
            // 空洞(由于节点被彻底移除,保留内部索引占位)
            payload_size += 12;
        }
    }

    let payload_offset = HEADER_SIZE;
    let vector_offset = if is_mmap_mode {
        0
    } else {
        payload_offset + payload_size
    };
    let vector_size = if is_mmap_mode {
        0
    } else {
        node_count * (dim as u64) * (std::mem::size_of::<T>() as u64)
    };
    let edge_offset = payload_offset + payload_size + vector_size;

    let edge_size = all_edges.iter().map(|(_, e)| (8 + 8 + 2 + e.label.len() + 4) as u64).sum::<u64>();
    let erpc_offset = edge_offset + edge_size;

    // 1. Header
    w.write_all(MAGIC)?;
    w.write_all(&VERSION.to_le_bytes())?;
    w.write_all(&(dim as u32).to_le_bytes())?;
    w.write_all(&memtable.next_id_value().to_le_bytes())?;
    w.write_all(&node_count.to_le_bytes())?;
    w.write_all(&payload_offset.to_le_bytes())?;
    w.write_all(&vector_offset.to_le_bytes())?;
    w.write_all(&edge_offset.to_le_bytes())?;
    w.write_all(&erpc_offset.to_le_bytes())?;

    // 2. Payload Block 包含 Tombstones
    for &nid in internal_indices {
        if nid != 0
            && let Some(p) = memtable.get_payload(nid) {
                let json_bytes = serde_json::to_vec(p).unwrap_or_default();
                w.write_all(&nid.to_le_bytes())?;
                w.write_all(&(json_bytes.len() as u32).to_le_bytes())?;
                w.write_all(&json_bytes)?;
                continue;
            }
        // Tombstone
        w.write_all(&0u64.to_le_bytes())?;
        w.write_all(&0u32.to_le_bytes())?;
    }

    // 3. Vector Block (Rom 用)
    if !is_mmap_mode {
        let flat = memtable.flat_vectors();
        w.write_all(bytemuck::cast_slice(flat))?;
    }

    // 4. Edge Block
    for (src_id, edge) in &all_edges {
        w.write_all(&src_id.to_le_bytes())?;
        w.write_all(&edge.target_id.to_le_bytes())?;
        let label_bytes = edge.label.as_bytes();
        w.write_all(&(label_bytes.len() as u16).to_le_bytes())?;
        w.write_all(label_bytes)?;
        w.write_all(&edge.weight.to_le_bytes())?;
    }

    // 5. ERPC Metadata Block (If exists)
    if let Some(erpc) = &memtable.erpc_index {
        w.write_all(bytemuck::bytes_of(&erpc.params))?;
        for chunk_centers in &erpc.pq_centers {
            for c in chunk_centers {
                for &fv in c { w.write_all(&fv.to_le_bytes())?; }
            }
        }
        for c in &erpc.centers {
            // 中心点扁平化
            for &fv in c { w.write_all(&fv.to_le_bytes())?; }
        }
        w.write_all(bytemuck::cast_slice(&erpc.sequence))?;
    }

    w.flush()?;
    let file = w
        .into_inner()
        .map_err(|e| TriviumError::Io(e.into_error()))?;
    file.sync_all()?;
    drop(file);

    robust_rename(Path::new(&tmp_path), Path::new(path))?;

    tracing::info!(
        "持久化完成: {} 个槽位(含删除), {} 个向量, Mode: {}",
        node_count,
        vec_count,
        if is_mmap_mode { "Mmap" } else { "Rom" }
    );

    Ok(())
}

pub fn load<T: VectorType>(path: &str, _mode: StorageMode) -> Result<MemTable<T>> {
    let file = File::open(path).map_err(TriviumError::Io)?;

    let mmap = unsafe { Mmap::map(&file) }.map_err(TriviumError::Io)?;

    if mmap.len() < HEADER_SIZE as usize {
        return Err(TriviumError::Generic("File too small for header".into()));
    }

    let bytes = &mmap[..];
    if &bytes[0..4] != MAGIC {
        return Err(TriviumError::Generic(format!(
            "Invalid file magic: expected TVDB, got {:?}",
            &bytes[0..4]
        )));
    }

    let dim = u32::from_le_bytes(bytes[6..10].try_into().unwrap()) as usize;
    let next_id = u64::from_le_bytes(bytes[10..18].try_into().unwrap());
    let node_count = u64::from_le_bytes(bytes[18..26].try_into().unwrap()) as usize;
    let payload_offset = u64::from_le_bytes(bytes[26..34].try_into().unwrap()) as usize;
    let vector_offset = u64::from_le_bytes(bytes[34..42].try_into().unwrap()) as usize;
    let edge_offset = u64::from_le_bytes(bytes[42..50].try_into().unwrap()) as usize;
    
    // 兼容旧版 V2 的情况下补齐 erpc_offset 游标
    let erpc_offset = if mmap.len() >= 58 {
        u64::from_le_bytes(bytes[50..58].try_into().unwrap()) as usize
    } else {
        mmap.len()
    };

    let vec_file_path = vec_path_from_db(path);

    // 如果 vector_offset 是 0 说明是分离架构,且存在 .vec 则按 Mmap 加载
    // 无论目前 config 设置的模式是什么,如果在初始化加载时已经存在可用的 .vec 结构,应当正确恢复它
    // 由下一次 flush 再按照最新的 StorageMode 决定写出格式
    if vector_offset == 0 && Path::new(&vec_file_path).exists() {
        // ═══ 跨文件一致性校验 ═══
        // 检查 .flush_ok 标记是否存在且文件大小吻合,防止撕裂写入
        let marker_path = flush_ok_path_from_db(path);
        let flush_ok_valid = (|| -> Option<bool> {
            let marker_bytes = std::fs::read(&marker_path).ok()?;
            if marker_bytes.len() < 16 {
                return Some(false);
            }
            let stored_tdb = u64::from_le_bytes(marker_bytes[0..8].try_into().ok()?);
            let stored_vec = u64::from_le_bytes(marker_bytes[8..16].try_into().ok()?);
            let actual_tdb = std::fs::metadata(path).ok()?.len();
            let actual_vec = std::fs::metadata(&vec_file_path).ok()?.len();
            Some(stored_tdb == actual_tdb && stored_vec == actual_vec)
        })()
        .unwrap_or(false);

        if flush_ok_valid {
            load_v2(
                bytes,
                dim,
                next_id,
                node_count,
                payload_offset,
                edge_offset,
                erpc_offset,
                &vec_file_path,
                &mmap,
            )
        } else {
            tracing::warn!(
                "检测到 .tdb/.vec 跨文件撕裂(.flush_ok 标记缺失或不匹配),\
                 降级为忽略 .vec 的安全模式加载,增量数据将由 WAL 回放恢复"
            );
            // 降级:忽略 .vec,仅从 .tdb 的 payload/edge 恢复骨架
            // node_count 仍有效,但向量数据丢失,需要 WAL 或下次 flush 重建
            load_v1_rom(
                bytes,
                dim,
                next_id,
                node_count,
                payload_offset,
                vector_offset,
                edge_offset,
                erpc_offset,
                &mmap,
            )
        }
    } else {
        load_v1_rom(
            bytes,
            dim,
            next_id,
            node_count,
            payload_offset,
            vector_offset,
            edge_offset,
            erpc_offset,
            &mmap,
        )
    }
}

/// Loads a database whose vectors live in a separate `.vec` file.
///
/// Opens the vector pool from `vec_file_path`, then restores payloads, edges and the
/// optional ERPC index from the `.tdb` bytes. In this layout the payload block is bounded
/// by `edge_offset`. `_tdb_mmap` is not used.
///
/// # Errors
/// Returns an error if the vector pool cannot be opened or any block fails to parse.
fn load_v2<T: VectorType>(
    bytes: &[u8],
    dim: usize,
    next_id: u64,
    node_count: usize,
    payload_offset: usize,
    edge_offset: usize,
    erpc_offset: usize,
    vec_file_path: &str,
    _tdb_mmap: &Mmap,
) -> Result<MemTable<T>> {
    let pool = VecPool::<T>::open(Path::new(vec_file_path), dim, node_count)?;
    let mut table = MemTable::new_with_vec_pool(dim, next_id, pool);

    load_payloads(&mut table, bytes, node_count, payload_offset, edge_offset)?;
    load_edges(&mut table, bytes, edge_offset, erpc_offset)?;

    let erpc = load_erpc(bytes, erpc_offset, dim)?;
    table.erpc_index = erpc;

    Ok(table)
}

/// Loads a single-file database whose vectors are embedded in the `.tdb`.
///
/// The vector block starts at `vector_offset` and holds `node_count * dim` elements of `T`.
/// The payload block is parsed first and is bounded by `vector_offset`; because it registers
/// one slot per record (tombstones included), the whole vector block can then be pushed
/// into the vector pool in a single call.
///
/// NOTE(review): `load` also calls this with `vector_offset == 0` for split-file databases
/// whose `.vec` is missing or fails the marker check. The payload block is then bounded by
/// offset 0, so payload parsing fails for any `node_count > 0` — confirm intended behavior.
///
/// # Errors
/// Returns an error if the vector block size overflows or extends past the end of the
/// file, or if the payload, edge or ERPC block fails to parse.
fn load_v1_rom<T: VectorType>(
    bytes: &[u8],
    dim: usize,
    next_id: u64,
    node_count: usize,
    payload_offset: usize,
    vector_offset: usize,
    edge_offset: usize,
    erpc_offset: usize,
    tdb_mmap: &Mmap,
) -> Result<MemTable<T>> {
    let mut memtable = MemTable::new_with_next_id(dim, next_id);
    let elem_size = std::mem::size_of::<T>();

    // node_count and dim come from the file header. Use checked arithmetic so a corrupt
    // header cannot wrap the size to a small value and slip past the bounds check.
    let limit = bytes.len().min(tdb_mmap.len());
    let elem_count = node_count.checked_mul(dim);
    let vec_end = elem_count
        .and_then(|n| n.checked_mul(elem_size))
        .and_then(|size| vector_offset.checked_add(size));
    let (elem_count, vec_end) = match (elem_count, vec_end) {
        (Some(n), Some(end)) if end <= limit => (n, end),
        _ => {
            return Err(TriviumError::Generic(
                "Vector block exceeds file size".into(),
            ));
        }
    };

    // Restore slot positions and payloads first.
    load_payloads(
        &mut memtable,
        bytes,
        node_count,
        payload_offset,
        vector_offset,
    )?;

    let vec_block = &bytes[vector_offset..vec_end];

    // Reinterpret the bytes in place when the block is suitably aligned for T; otherwise
    // copy element by element with unaligned reads.
    match bytemuck::try_cast_slice::<u8, T>(vec_block) {
        Ok(typed) => {
            memtable.vec_pool_mut().push(typed);
        }
        Err(_) => {
            let copied: Vec<T> = (0..elem_count)
                .map(|i| {
                    let start = i * elem_size;
                    bytemuck::pod_read_unaligned(&vec_block[start..start + elem_size])
                })
                .collect();
            memtable.vec_pool_mut().push(&copied);
        }
    }

    load_edges(&mut memtable, bytes, edge_offset, erpc_offset)?;
    memtable.erpc_index = load_erpc(bytes, erpc_offset, dim)?;
    Ok(memtable)
}

/// Parses the payload block and registers one slot per record, tombstones included.
///
/// Each record is `node id (u64 LE) + JSON length (u32 LE) + JSON bytes`. A record whose id
/// and length are both 0 is a tombstone and is registered via `register_tombstone`; every
/// other record is parsed as JSON and registered via `register_node`.
///
/// # Arguments
/// * `node_count` - number of records to read.
/// * `offset` - byte offset of the first record in `bytes`.
/// * `end_offset` - byte offset one past the end of the payload block. It is clamped to
///   `bytes.len()`, so a bogus value yields an error rather than an out-of-bounds panic.
///
/// # Errors
/// Returns "Payload block overflow" / "JSON data overflow" when a record would cross the
/// end of the block, a "JSON parse error" for invalid JSON, and any error reported by the
/// MemTable registration calls.
fn load_payloads<T: VectorType>(
    memtable: &mut MemTable<T>,
    bytes: &[u8],
    node_count: usize,
    offset: usize,
    end_offset: usize,
) -> Result<()> {
    // Fixed part of every record: node id (8 bytes) + JSON length (4 bytes).
    const RECORD_HEADER_LEN: usize = 12;

    let end = end_offset.min(bytes.len());
    let mut cursor = offset;

    for _ in 0..node_count {
        let header_end = cursor
            .checked_add(RECORD_HEADER_LEN)
            .filter(|&e| e <= end)
            .ok_or_else(|| TriviumError::Generic("Payload block overflow".into()))?;
        let (nid_bytes, len_bytes) = bytes[cursor..header_end].split_at(8);
        let nid = u64::from_le_bytes(nid_bytes.try_into().expect("record header id is 8 bytes"));
        let json_len =
            u32::from_le_bytes(len_bytes.try_into().expect("record header length is 4 bytes"))
                as usize;
        cursor = header_end;

        if nid == 0 && json_len == 0 {
            memtable.register_tombstone()?;
            continue;
        }

        let json_end = cursor
            .checked_add(json_len)
            .filter(|&e| e <= end)
            .ok_or_else(|| TriviumError::Generic("JSON data overflow".into()))?;
        let payload: serde_json::Value = serde_json::from_slice(&bytes[cursor..json_end])
            .map_err(|e| TriviumError::Generic(format!("JSON parse error: {}", e)))?;
        cursor = json_end;

        memtable.register_node(nid, payload)?;
    }
    Ok(())
}

fn load_edges<T: VectorType>(
    memtable: &mut MemTable<T>,
    bytes: &[u8],
    edge_offset: usize,
    file_len: usize,
) -> Result<()> {
    let mut cursor = edge_offset;
    while cursor + 18 <= file_len {
        let src_id = u64::from_le_bytes(bytes[cursor..cursor + 8].try_into().unwrap());
        cursor += 8;
        let dst_id = u64::from_le_bytes(bytes[cursor..cursor + 8].try_into().unwrap());
        cursor += 8;
        let label_len = u16::from_le_bytes(bytes[cursor..cursor + 2].try_into().unwrap()) as usize;
        cursor += 2;
        if cursor + label_len + 4 > file_len {
            break;
        }
        let label = String::from_utf8(bytes[cursor..cursor + label_len].to_vec())
            .map_err(|e| TriviumError::Generic(format!("Label decode error: {}", e)))?;
        cursor += label_len;
        let weight = f32::from_le_bytes(bytes[cursor..cursor + 4].try_into().unwrap());
        cursor += 4;
        memtable.link(src_id, dst_id, label, weight)?;
    }
    Ok(())
}

fn load_erpc(bytes: &[u8], erpc_offset: usize, dim: usize) -> Result<Option<ErpcIndex>> {
    if erpc_offset >= bytes.len() || bytes.len() - erpc_offset < 32 {
        return Ok(None);
    }
    let mut cursor = erpc_offset;
    
    // 强制按 byte 对齐拷贝
    let params: ErpcParams = bytemuck::pod_read_unaligned(&bytes[cursor..cursor + 32]);
    cursor += 32;

    let chunks = crate::index::erpc::CHUNKS;
    let pq_k = crate::index::erpc::PQ_K;
    
    let mut chunk_bounds = Vec::with_capacity(chunks);
    let base = dim / chunks;
    let mut remainder = dim % chunks;
    let mut start = 0;
    for _ in 0..chunks {
        let len = base + if remainder > 0 { 1 } else { 0 };
        if remainder > 0 { remainder -= 1; }
        chunk_bounds.push((start, start + len));
        start += len;
    }

    let mut pq_centers = Vec::with_capacity(chunks);
    for c in 0..chunks {
        let (s, e) = chunk_bounds[c];
        let sub_dim = e - s;
        let mut pc = Vec::with_capacity(pq_k);
        for _ in 0..pq_k {
            let mut pt = Vec::with_capacity(sub_dim);
            for _ in 0..sub_dim {
                let offset = cursor;
                pt.push(f32::from_le_bytes(bytes[offset..offset + 4].try_into().unwrap()));
                cursor += 4;
            }
            pc.push(pt);
        }
        pq_centers.push(pc);
    }

    let k_clusters = params.k_clusters as usize;
    let mut centers = Vec::with_capacity(k_clusters);
    for _ in 0..k_clusters {
        let size = dim * 4;
        let mut c = Vec::with_capacity(dim);
        for i in 0..dim {
            let offset = cursor + i * 4;
            let val = f32::from_le_bytes(bytes[offset..offset + 4].try_into().unwrap());
            c.push(val);
        }
        centers.push(c);
        cursor += size;
    }

    let sequence_bytes = &bytes[cursor..];
    let expected_entry_size = std::mem::size_of::<SeqEntry>();
    let count = sequence_bytes.len() / expected_entry_size;
    let mut sequence: Vec<SeqEntry> = vec![bytemuck::Zeroable::zeroed(); count];
    
    unsafe {
        std::ptr::copy_nonoverlapping(
            sequence_bytes.as_ptr(),
            sequence.as_mut_ptr() as *mut u8,
            count * expected_entry_size,
        );
    }

    Ok(Some(ErpcIndex {
        lsh_basis: Vec::new(),
        centers,
        sequence,
        dim,
        params,
        pq_centers,
        chunk_bounds,
    }))
}