use crate::VectorType;
use crate::error::{Result, TriviumError};
use crate::index::bq::BqSignature;
use crate::index::int8::Int8Pool;
use crate::index::text::TextIndex;
use crate::node::{Edge, NodeId};
use crate::storage::vec_pool::VecPool;
use std::collections::HashMap;
/// Computes the 64-bit signature of a JSON payload.
///
/// Starts from an all-zero mask and lets `flatten_and_hash_json` OR in one
/// bit per scalar leaf, beginning at the empty root path.
fn calculate_json_signature(value: &serde_json::Value) -> u64 {
    let mut mask: u64 = 0;
    flatten_and_hash_json("", value, &mut mask);
    mask
}
fn flatten_and_hash_json(prefix: &str, value: &serde_json::Value, sig: &mut u64) {
use std::hash::{Hash, Hasher};
match value {
serde_json::Value::Object(map) => {
for (k, v) in map {
let new_prefix = if prefix.is_empty() { k.clone() } else { format!("{}.{}", prefix, k) };
flatten_and_hash_json(&new_prefix, v, sig);
}
}
serde_json::Value::Array(arr) => {
for v in arr {
flatten_and_hash_json(prefix, v, sig);
}
}
serde_json::Value::String(s) => {
let mut hasher = std::collections::hash_map::DefaultHasher::new();
format!("{}:{}", prefix, s).hash(&mut hasher);
*sig |= 1u64 << (hasher.finish() % 64);
}
serde_json::Value::Bool(b) => {
let mut hasher = std::collections::hash_map::DefaultHasher::new();
format!("{}:{}", prefix, b).hash(&mut hasher);
*sig |= 1u64 << (hasher.finish() % 64);
}
serde_json::Value::Number(n) => {
let mut hasher = std::collections::hash_map::DefaultHasher::new();
format!("{}:{}", prefix, n).hash(&mut hasher);
*sig |= 1u64 << (hasher.finish() % 64);
}
serde_json::Value::Null => {}
}
}
/// In-memory table of nodes: one vector slot, one JSON payload and a set of
/// graph edges per node, plus caches derived from the vectors and payloads.
///
/// A node "exists" when its id is a key of `payloads`; the slot tables
/// (`indices_to_ids`, `fast_tags`) are indexed by slot and may contain empty
/// slots marked with id 0.
pub struct MemTable<T: VectorType> {
/// Required length of every vector accepted by the insert/update methods.
dim: usize,
/// Id handed out by the next `insert` call.
next_id: NodeId,
/// Vector storage, addressed by slot index.
vec_pool: VecPool<T>,
/// Cached `BqSignature` values rebuilt by `ensure_vectors_cache`.
bq_signatures: Vec<BqSignature>,
/// Set when vectors change in place so the caches get rebuilt.
bq_dirty: bool,
/// Text index fed through `index_keyword` / `index_text`.
text_index: TextIndex,
/// Node id -> JSON payload.
payloads: HashMap<NodeId, serde_json::Value>,
/// Source id -> outgoing edges.
edges: HashMap<NodeId, Vec<Edge>>,
/// Target id -> number of incoming edges.
in_degrees: HashMap<NodeId, usize>,
/// Target id -> source ids (one entry per `link` call).
incoming_edges: HashMap<NodeId, Vec<NodeId>>,
/// Per-node fatigue value; behind a lock so `&self` methods can mutate it.
fatigue_map: std::sync::RwLock<HashMap<NodeId, u8>>,
/// Slot index -> node id (0 for empty slots).
indices_to_ids: Vec<NodeId>,
/// Node id -> slot index.
ids_to_indices: HashMap<NodeId, usize>,
/// Slot index -> payload signature from `calculate_json_signature`.
fast_tags: Vec<u64>,
/// Slot indices available for reuse by the insert methods.
free_slots: Vec<usize>,
/// Int8 pool built from the flat vectors; `None` when there are none.
int8_pool: Option<Int8Pool>,
}
impl<T: VectorType> MemTable<T> {
/// Rejects vectors that contain a non-finite component.
///
/// Every element is converted with `to_f32()`; scanning stops at the first
/// NaN or infinite value.
///
/// # Errors
/// Returns `TriviumError::Generic` when any component is NaN or infinite.
#[inline]
fn validate_vector(vector: &[T]) -> Result<()> {
    let all_finite = vector.iter().all(|elem| elem.to_f32().is_finite());
    if all_finite {
        return Ok(());
    }
    Err(TriviumError::Generic(
        "Vector contains NaN or Infinity; insert rejected to prevent silent search corruption"
            .into(),
    ))
}
pub fn new(dim: usize) -> Self {
Self {
dim,
next_id: 1, vec_pool: VecPool::new(dim),
bq_signatures: Vec::new(),
bq_dirty: false,
text_index: TextIndex::new(),
payloads: HashMap::new(),
edges: HashMap::new(),
in_degrees: HashMap::new(),
incoming_edges: HashMap::new(),
fatigue_map: std::sync::RwLock::new(HashMap::new()),
indices_to_ids: Vec::new(),
ids_to_indices: HashMap::new(),
fast_tags: Vec::new(),
free_slots: Vec::new(),
int8_pool: None,
}
}
pub fn new_with_next_id(dim: usize, next_id: NodeId) -> Self {
let mut mt = Self::new(dim);
mt.next_id = next_id;
mt
}
/// Creates a table around an existing vector pool.
///
/// `dim`, `next_id` and `vec_pool` are stored as given; the slot tables,
/// payloads, graph maps and caches all start out empty, so slots for vectors
/// already present in `vec_pool` still have to be registered by the caller.
pub fn new_with_vec_pool(dim: usize, next_id: NodeId, vec_pool: VecPool<T>) -> Self {
    Self {
        // Caller-supplied state.
        dim,
        next_id,
        vec_pool,
        // Vector-derived caches.
        bq_signatures: Vec::new(),
        bq_dirty: false,
        int8_pool: None,
        // Payloads and text search.
        payloads: HashMap::new(),
        text_index: TextIndex::new(),
        // Graph bookkeeping.
        edges: HashMap::new(),
        in_degrees: HashMap::new(),
        incoming_edges: HashMap::new(),
        fatigue_map: std::sync::RwLock::new(HashMap::new()),
        // Slot tables.
        indices_to_ids: Vec::new(),
        ids_to_indices: HashMap::new(),
        fast_tags: Vec::new(),
        free_slots: Vec::new(),
    }
}
/// Returns the id that the next `insert` call will assign.
pub fn next_id_value(&self) -> NodeId {
    let upcoming = self.next_id;
    upcoming
}
/// Raises the id counter to `candidate` if that is larger than the current
/// value; the counter never moves backwards.
#[inline]
pub fn advance_next_id(&mut self, candidate: NodeId) {
    self.next_id = self.next_id.max(candidate);
}
/// Gives mutable access to the underlying vector pool.
///
/// Changes made through this handle do not set `bq_dirty`.
pub fn vec_pool_mut(&mut self) -> &mut VecPool<T> {
    let pool = &mut self.vec_pool;
    pool
}
/// Gives shared access to the underlying vector pool.
pub fn vec_pool(&self) -> &VecPool<T> {
    let pool = &self.vec_pool;
    pool
}
/// Stores `vector` and `payload` under the caller-chosen `id`.
///
/// Unlike `insert` / `insert_with_id`, this does not check the vector for
/// NaN/Infinity, does not check whether `id` already exists and does not
/// touch the id counter. A free slot is recycled when one is available;
/// otherwise a new slot is appended.
///
/// # Errors
/// Returns `TriviumError::DimensionMismatch` when `vector.len()` differs from
/// the table dimension.
pub fn raw_insert(
    &mut self,
    id: NodeId,
    vector: &[T],
    payload: serde_json::Value,
) -> Result<()> {
    if vector.len() != self.dim {
        return Err(TriviumError::DimensionMismatch {
            expected: self.dim,
            got: vector.len(),
        });
    }
    let sig = calculate_json_signature(&payload);
    let idx = if let Some(free_idx) = self.free_slots.pop() {
        self.vec_pool.update(free_idx, vector);
        self.indices_to_ids[free_idx] = id;
        self.fast_tags[free_idx] = sig;
        // Recycling a slot changes a vector without changing the slot count,
        // so the length check in `ensure_vectors_cache` cannot notice it.
        // Flag the derived caches explicitly, as `update_vector` does.
        self.bq_dirty = true;
        free_idx
    } else {
        let i = self.indices_to_ids.len();
        self.vec_pool.push(vector);
        self.indices_to_ids.push(id);
        self.fast_tags.push(sig);
        i
    };
    self.payloads.insert(id, payload);
    self.ids_to_indices.insert(id, idx);
    Ok(())
}
/// Registers `id` with `payload` in a newly appended slot without writing to
/// the vector pool.
///
/// Free slots are not consulted; the slot is always appended at the end of
/// the slot tables. Currently always returns `Ok(())`.
// NOTE(review): presumably meant for nodes whose vectors already live in a
// pool handed to `new_with_vec_pool` — confirm against the callers.
pub fn register_node(&mut self, id: NodeId, payload: serde_json::Value) -> Result<()> {
    let slot = self.indices_to_ids.len();
    let tag = calculate_json_signature(&payload);

    self.indices_to_ids.push(id);
    self.fast_tags.push(tag);
    self.ids_to_indices.insert(id, slot);
    self.payloads.insert(id, payload);
    Ok(())
}
/// Appends an empty slot (id 0, tag 0) and records it as reusable.
///
/// The vector pool is not touched. Currently always returns `Ok(())`.
pub fn register_tombstone(&mut self) -> Result<()> {
    let slot = self.indices_to_ids.len();
    self.free_slots.push(slot);
    self.indices_to_ids.push(0);
    self.fast_tags.push(0);
    Ok(())
}
/// Inserts a new node and returns its freshly assigned id.
///
/// The id is taken from the internal counter, which is then incremented.
/// A free slot is recycled when one is available; otherwise a new slot is
/// appended.
///
/// # Errors
/// * `TriviumError::DimensionMismatch` when `vector.len()` differs from the
///   table dimension.
/// * `TriviumError::Generic` when the vector contains NaN or Infinity.
///
/// Neither error consumes an id.
pub fn insert(&mut self, vector: &[T], payload: serde_json::Value) -> Result<NodeId> {
    if vector.len() != self.dim {
        return Err(TriviumError::DimensionMismatch {
            expected: self.dim,
            got: vector.len(),
        });
    }
    Self::validate_vector(vector)?;
    let id = self.next_id;
    self.next_id += 1;
    let sig = calculate_json_signature(&payload);
    let idx = if let Some(free_idx) = self.free_slots.pop() {
        self.vec_pool.update(free_idx, vector);
        self.indices_to_ids[free_idx] = id;
        self.fast_tags[free_idx] = sig;
        // Recycling a slot changes a vector without changing the slot count,
        // so the length check in `ensure_vectors_cache` cannot notice it.
        // Flag the derived caches explicitly, as `update_vector` does.
        self.bq_dirty = true;
        free_idx
    } else {
        let i = self.indices_to_ids.len();
        self.vec_pool.push(vector);
        self.indices_to_ids.push(id);
        self.fast_tags.push(sig);
        i
    };
    self.payloads.insert(id, payload);
    self.ids_to_indices.insert(id, idx);
    Ok(id)
}
/// Inserts a new node under the caller-chosen `id`.
///
/// A free slot is recycled when one is available; otherwise a new slot is
/// appended. Afterwards the id counter is moved past `id` if necessary so
/// that later `insert` calls cannot hand out the same id.
///
/// # Errors
/// * `TriviumError::Generic` when a node with `id` already exists.
/// * `TriviumError::DimensionMismatch` when `vector.len()` differs from the
///   table dimension.
/// * `TriviumError::Generic` when the vector contains NaN or Infinity.
// NOTE(review): the slot tables use id 0 to mark empty slots (see `delete`),
// and `active_entries` treats every slot whose id has a payload as live.
// Inserting id 0 here would make empty slots look live — confirm callers
// never pass 0.
pub fn insert_with_id(
    &mut self,
    id: NodeId,
    vector: &[T],
    payload: serde_json::Value,
) -> Result<()> {
    if self.payloads.contains_key(&id) {
        return Err(TriviumError::Generic(format!("Node {} already exists", id)));
    }
    if vector.len() != self.dim {
        return Err(TriviumError::DimensionMismatch {
            expected: self.dim,
            got: vector.len(),
        });
    }
    Self::validate_vector(vector)?;
    let sig = calculate_json_signature(&payload);
    let idx = if let Some(free_idx) = self.free_slots.pop() {
        self.vec_pool.update(free_idx, vector);
        self.indices_to_ids[free_idx] = id;
        self.fast_tags[free_idx] = sig;
        // Recycling a slot changes a vector without changing the slot count,
        // so the length check in `ensure_vectors_cache` cannot notice it.
        // Flag the derived caches explicitly, as `update_vector` does.
        self.bq_dirty = true;
        free_idx
    } else {
        let i = self.indices_to_ids.len();
        self.vec_pool.push(vector);
        self.indices_to_ids.push(id);
        self.fast_tags.push(sig);
        i
    };
    self.payloads.insert(id, payload);
    self.ids_to_indices.insert(id, idx);
    if id >= self.next_id {
        self.next_id = id + 1;
    }
    Ok(())
}
/// Adds a directed, labelled, weighted edge from `src` to `dst`.
///
/// Parallel edges are allowed: every call appends another edge, bumps the
/// in-degree of `dst` and records `src` once more among its incoming sources.
///
/// # Errors
/// Returns `TriviumError::NodeNotFound` for the first of `src`, `dst` that
/// does not exist.
pub fn link(&mut self, src: NodeId, dst: NodeId, label: String, weight: f32) -> Result<()> {
    for endpoint in [src, dst] {
        if !self.payloads.contains_key(&endpoint) {
            return Err(TriviumError::NodeNotFound(endpoint));
        }
    }

    self.edges.entry(src).or_default().push(Edge {
        target_id: dst,
        label,
        weight,
    });
    *self.in_degrees.entry(dst).or_default() += 1;
    self.incoming_edges.entry(dst).or_default().push(src);
    Ok(())
}
/// Sets the fatigue value of every id in `ids` to 1.
///
/// Ids are not checked for existence. If the lock is poisoned the call does
/// nothing.
pub fn mark_fatigued(&self, ids: &[NodeId]) {
    let Ok(mut levels) = self.fatigue_map.write() else {
        return;
    };
    levels.extend(ids.iter().map(|&id| (id, 1)));
}
/// Returns the fatigue value recorded for `id`.
///
/// Yields 0 when nothing is recorded for `id` or when the lock is poisoned.
pub fn get_fatigue(&self, id: NodeId) -> u8 {
    match self.fatigue_map.read() {
        Ok(levels) => levels.get(&id).copied().unwrap_or(0),
        Err(_) => 0,
    }
}
/// Resets the fatigue value of `id` to 0 if one is recorded.
///
/// The entry itself is kept. If the lock is poisoned the call does nothing.
pub fn consume_fatigue(&self, id: NodeId) {
    let Ok(mut levels) = self.fatigue_map.write() else {
        return;
    };
    if let Some(level) = levels.get_mut(&id) {
        *level = 0;
    }
}
/// Resets the fatigue value of every id in `ids` that has one recorded.
///
/// The write lock is taken once for the whole batch. If the lock is poisoned
/// the call does nothing.
pub fn consume_fatigue_batch(&self, ids: &[NodeId]) {
    let Ok(mut levels) = self.fatigue_map.write() else {
        return;
    };
    ids.iter().for_each(|id| {
        if let Some(level) = levels.get_mut(id) {
            *level = 0;
        }
    });
}
/// Brings the vector-derived caches up to date.
///
/// Calls `ensure_cache()` on the vector pool, then rebuilds the BQ signatures
/// and the int8 pool when the number of cached signatures differs from the
/// pool's `total_count()` or when `bq_dirty` is set. The flag is cleared
/// after a rebuild.
#[inline]
pub fn ensure_vectors_cache(&mut self) {
    self.vec_pool.ensure_cache();
    let total = self.vec_pool.total_count();

    let up_to_date = !self.bq_dirty && self.bq_signatures.len() == total;
    if up_to_date {
        return;
    }

    self.rebuild_bq_signatures(total);
    self.rebuild_int8_pool();
    self.bq_dirty = false;
}
fn rebuild_bq_signatures(&mut self, total: usize) {
let dim = self.dim();
let flat = self.vec_pool.flat_vectors();
let mut new_bq = Vec::with_capacity(total);
for chunk in flat.chunks(dim) {
new_bq.push(BqSignature::from_vector(chunk));
}
while new_bq.len() < total {
new_bq.push(BqSignature::empty());
}
self.bq_signatures = new_bq;
}
/// Rebuilds the int8 pool from the flat vector buffer.
///
/// The pool is dropped (`None`) when the buffer is empty.
fn rebuild_int8_pool(&mut self) {
    let dim = self.dim;
    let flat = self.vec_pool.flat_vectors();
    self.int8_pool = if flat.is_empty() {
        None
    } else {
        Some(Int8Pool::from_generic_vectors(flat, dim))
    };
}
/// Returns a copy of the cached BQ signature at position `index`, or `None`
/// when `index` is past the end of the cache.
pub fn get_bq_signature(&self, index: usize) -> Option<BqSignature> {
    let cached = self.bq_signatures.get(index);
    cached.copied()
}
/// Returns all cached BQ signatures as a slice.
#[inline]
pub fn bq_signatures_slice(&self) -> &[BqSignature] {
    self.bq_signatures.as_slice()
}
/// Returns the per-slot payload signatures as a slice indexed by slot.
#[inline]
pub fn fast_tags_slice(&self) -> &[u64] {
    self.fast_tags.as_slice()
}
/// Replaces the cached BQ signatures with `sigs` and clears `bq_dirty`.
///
/// The int8 pool is left as it is.
pub fn set_bq_signatures(&mut self, sigs: Vec<BqSignature>) {
    self.bq_dirty = false;
    self.bq_signatures = sigs;
}
/// Returns the int8 pool if one has been built.
#[inline]
pub fn int8_pool(&self) -> Option<&Int8Pool> {
    let pool = &self.int8_pool;
    pool.as_ref()
}
/// Returns the vector pool's flat element buffer.
#[inline]
pub fn flat_vectors(&self) -> &[T] {
    let pool = &self.vec_pool;
    pool.flat_vectors()
}
/// Returns the vector dimension this table was created with.
#[inline]
pub fn dim(&self) -> usize {
    let dim = self.dim;
    dim
}
/// Returns the node id stored at slot `idx` (0 for an empty slot).
///
/// # Panics
/// Panics when `idx` is out of bounds of the slot table.
#[inline]
pub fn get_id_by_index(&self, idx: usize) -> NodeId {
    let slots = &self.indices_to_ids;
    slots[idx]
}
/// Returns the payload stored for `id`, or `None` when the node does not
/// exist.
pub fn get_payload(&self, id: NodeId) -> Option<&serde_json::Value> {
    let payloads = &self.payloads;
    payloads.get(&id)
}
/// Returns the outgoing edges of `id`, or `None` when no edge list has ever
/// been created for it.
pub fn get_edges(&self, id: NodeId) -> Option<&[Edge]> {
    self.edges.get(&id).map(Vec::as_slice)
}
/// Removes node `id` together with every edge that starts or ends at it.
///
/// The node's slot is zeroed in the vector pool, marked empty (id 0, tag 0)
/// and queued for reuse. Its fatigue entry is dropped as well, so an id that
/// is later re-created through `insert_with_id` starts without inherited
/// fatigue. The vector-derived caches are flagged for rebuild. The text
/// index is not modified here.
///
/// # Errors
/// Returns `TriviumError::NodeNotFound` when `id` does not exist.
pub fn delete(&mut self, id: NodeId) -> Result<()> {
    if self.payloads.remove(&id).is_none() {
        return Err(TriviumError::NodeNotFound(id));
    }

    if let Some(idx) = self.ids_to_indices.remove(&id) {
        self.vec_pool.zero_out(idx);
        self.indices_to_ids[idx] = 0;
        // Keep the freed slot consistent with `register_tombstone`, which
        // gives empty slots tag 0; a stale tag would let the dead slot pass
        // tag-based pre-filtering.
        self.fast_tags[idx] = 0;
        self.free_slots.push(idx);
    }

    // `&mut self` guarantees exclusive access, so no locking is needed.
    // A poisoned lock is ignored, matching the other fatigue methods.
    if let Ok(fatigue) = self.fatigue_map.get_mut() {
        fatigue.remove(&id);
    }

    // Edges leaving the node: fix up the bookkeeping of each target.
    if let Some(outgoing_edges) = self.edges.remove(&id) {
        for edge in outgoing_edges {
            let target = edge.target_id;
            if let Some(in_deg) = self.in_degrees.get_mut(&target) {
                *in_deg = in_deg.saturating_sub(1);
            }
            if let Some(incoming) = self.incoming_edges.get_mut(&target) {
                incoming.retain(|&src| src != id);
            }
        }
    }

    // Edges pointing at the node: drop them from each source's edge list.
    if let Some(incoming) = self.incoming_edges.remove(&id) {
        for src_id in incoming {
            if let Some(edge_list) = self.edges.get_mut(&src_id) {
                edge_list.retain(|e| e.target_id != id);
            }
        }
    }
    self.in_degrees.remove(&id);

    self.bq_dirty = true;
    Ok(())
}
/// Removes every edge from `src` to `dst`.
///
/// When at least one edge was removed, the in-degree of `dst` is reduced by
/// the number of removed edges (saturating at 0) and `src` is dropped from
/// its incoming sources. Removing nothing is not an error.
///
/// # Errors
/// Returns `TriviumError::NodeNotFound(src)` when `src` has no edge list at
/// all.
pub fn unlink(&mut self, src: NodeId, dst: NodeId) -> Result<()> {
    let Some(edge_list) = self.edges.get_mut(&src) else {
        return Err(TriviumError::NodeNotFound(src));
    };

    let before = edge_list.len();
    edge_list.retain(|edge| edge.target_id != dst);
    let removed = before - edge_list.len();

    if removed > 0 {
        if let Some(in_deg) = self.in_degrees.get_mut(&dst) {
            *in_deg = in_deg.saturating_sub(removed);
        }
        if let Some(sources) = self.incoming_edges.get_mut(&dst) {
            sources.retain(|&source| source != src);
        }
    }
    Ok(())
}
/// Returns the ids of all existing nodes, in the payload map's iteration
/// order.
pub fn get_all_ids(&self) -> Vec<NodeId> {
    let ids = self.payloads.keys().copied();
    ids.collect()
}
/// Replaces the payload of node `id` and refreshes its slot's signature.
///
/// # Errors
/// Returns `TriviumError::NodeNotFound` when `id` does not exist.
pub fn update_payload(&mut self, id: NodeId, payload: serde_json::Value) -> Result<()> {
    let Some(stored) = self.payloads.get_mut(&id) else {
        return Err(TriviumError::NodeNotFound(id));
    };

    // Compute the signature before the payload is moved into the map.
    let sig = calculate_json_signature(&payload);
    *stored = payload;

    if let Some(&idx) = self.ids_to_indices.get(&id) {
        self.fast_tags[idx] = sig;
    }
    Ok(())
}
/// Overwrites the vector of node `id` and flags the vector-derived caches
/// for rebuild.
///
/// # Errors
/// Checked in this order:
/// * `TriviumError::DimensionMismatch` when `vector.len()` differs from the
///   table dimension.
/// * `TriviumError::Generic` when the vector contains NaN or Infinity.
/// * `TriviumError::NodeNotFound` when `id` has no payload or no slot.
pub fn update_vector(&mut self, id: NodeId, vector: &[T]) -> Result<()> {
    let got = vector.len();
    if got != self.dim {
        return Err(TriviumError::DimensionMismatch {
            expected: self.dim,
            got,
        });
    }
    Self::validate_vector(vector)?;

    // The node must be known to both the payload map and the slot map.
    let slot = self
        .ids_to_indices
        .get(&id)
        .copied()
        .filter(|_| self.payloads.contains_key(&id));
    let Some(idx) = slot else {
        return Err(TriviumError::NodeNotFound(id));
    };

    self.vec_pool.update(idx, vector);
    self.bq_dirty = true;
    Ok(())
}
/// Returns the vector stored for `id`, or `None` when `id` has no slot or
/// the pool has nothing at that slot.
pub fn get_vector(&self, id: NodeId) -> Option<&[T]> {
    let idx = *self.ids_to_indices.get(&id)?;
    self.vec_pool.get(idx)
}
/// Returns the number of existing nodes (entries in the payload map).
pub fn node_count(&self) -> usize {
    let payloads = &self.payloads;
    payloads.len()
}
/// Returns the number of slots in the slot table, including empty ones.
#[inline]
pub fn internal_slot_count(&self) -> usize {
    let slots = &self.indices_to_ids;
    slots.len()
}
/// Returns the recorded number of incoming edges of `id` (0 when none is
/// recorded).
pub fn get_in_degree(&self, id: NodeId) -> usize {
    match self.in_degrees.get(&id) {
        Some(&count) => count,
        None => 0,
    }
}
/// Reports whether node `id` exists, i.e. has a payload entry.
pub fn contains(&self, id: NodeId) -> bool {
    let payloads = &self.payloads;
    payloads.contains_key(&id)
}
/// Returns the ids of all existing nodes, in the payload map's iteration
/// order. Same result as `get_all_ids`.
pub fn all_node_ids(&self) -> Vec<NodeId> {
    let ids = self.payloads.keys().copied();
    ids.collect()
}
/// Returns the slot table: the node id stored at each slot, with 0 marking
/// empty slots.
pub fn internal_indices(&self) -> &[NodeId] {
    self.indices_to_ids.as_slice()
}
/// Iterates over `(slot index, node id)` for every slot whose id currently
/// has a payload, in slot order.
pub fn active_entries(&self) -> impl Iterator<Item = (usize, NodeId)> + '_ {
    self.indices_to_ids
        .iter()
        .copied()
        .enumerate()
        .filter(move |&(_, nid)| self.payloads.contains_key(&nid))
}
/// Returns a rough estimate of the table's memory use in bytes.
///
/// Adds up the vector pool's reported heap size, the length of every
/// payload's JSON text, `size_of::<Edge>()` per stored edge and the raw entry
/// sizes of the two slot maps. The text index, the signature caches, the
/// int8 pool and container overhead are not counted.
pub fn estimated_memory_bytes(&self) -> usize {
    use std::mem::size_of;

    let id_size = size_of::<NodeId>();

    let vectors = self.vec_pool.heap_memory_bytes();
    // Serialises each payload just to measure its textual length.
    let payloads: usize = self
        .payloads
        .values()
        .map(|payload| payload.to_string().len())
        .sum();
    let edges = self.edges.values().map(Vec::len).sum::<usize>() * size_of::<Edge>();
    let slot_table = self.indices_to_ids.len() * id_size;
    let id_lookup = self.ids_to_indices.len() * (id_size + size_of::<usize>());

    vectors + payloads + edges + slot_table + id_lookup
}
/// Adds `keyword` to the text index for node `id`.
///
/// Does nothing when the node does not exist.
pub fn index_keyword(&mut self, id: NodeId, keyword: &str) {
    if !self.payloads.contains_key(&id) {
        return;
    }
    self.text_index.add_keyword(id, keyword);
}
/// Adds `text` to the text index for node `id`.
///
/// Does nothing when the node does not exist.
pub fn index_text(&mut self, id: NodeId, text: &str) {
    if !self.payloads.contains_key(&id) {
        return;
    }
    self.text_index.add_text(id, text);
}
/// Calls `build()` on the text index.
pub fn build_text_index(&mut self) {
    let index = &mut self.text_index;
    index.build();
}
/// Gives shared access to the text index.
pub fn text_engine(&self) -> &TextIndex {
    let index = &self.text_index;
    index
}
pub fn rebuild_text_index_from_payloads(&mut self) {
self.text_index.clear();
for (&id, payload) in &self.payloads {
if let serde_json::Value::Object(map) = payload {
for (_key, value) in map {
if let serde_json::Value::String(text) = value
&& !text.is_empty() {
self.text_index.add_text(id, text);
}
}
}
}
self.text_index.build();
if !self.payloads.is_empty() {
tracing::info!(
"TextIndex 从 {} 个节点的 payload 自动重建完成",
self.payloads.len()
);
}
}
}