use std::{
any::{Any, TypeId as StdTypeId},
collections::HashMap,
path::Path,
sync::Arc,
};
use parking_lot::ReentrantMutex;
use serde::Serialize;
use tracing::instrument;
use zerocopy::{FromBytes, IntoBytes};
use ahash::{AHashMap, AHashSet};
use crate::matrices::MatrixSet;
use crate::{
csr::{CsrCache, CsrSnapshot},
error::Error,
schema::{
AdjEntry, DirectedNeighborEntry, EdgeId, EdgeRecord, LabelId, Language, NeighborEntry,
NodeId, NodeRecord, PropKeyId, PropValue, TypeId, WeightedPath,
},
storage::{
fts,
ids::{
adjust_label_count, adjust_type_count, alloc_edge_id, alloc_node_id, get_label,
get_or_create_label, get_or_create_prop_key, get_or_create_type, get_prop_key,
get_prop_key_name, get_type,
},
lmdb::Storage,
props,
},
};
pub mod algo;
pub mod edge;
pub mod fts_mod;
pub mod graphblas;
pub mod index;
pub mod node;
pub mod txn;
pub mod vector;
/// Direction selector with three variants: incoming, outgoing, or both.
///
/// NOTE(review): judging by the name, this chooses which incident edges are
/// counted when computing a node's degree; the consumers are not in this
/// part of the file — confirm against the algorithm modules.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub enum DegreeDirection {
/// Presumably: only edges arriving at the node.
In,
/// Presumably: only edges leaving the node.
Out,
/// Presumably: edges in either direction.
Both,
}
/// Optional name filters for a triangle count: three relationship-type slots
/// and three label slots, each `None` by default.
///
/// NOTE(review): presumably slot `i` constrains the i-th edge / i-th vertex of
/// the triangle and `None` means "unconstrained" — confirm against the code
/// that consumes this spec.
#[derive(Debug, Clone, Default)]
pub struct TriangleCountSpec<'a> {
/// Optional relationship-type name per slot.
pub rel_types: [Option<&'a str>; 3],
/// Optional label name per slot.
pub labels: [Option<&'a str>; 3],
}
/// Builds a fixed 12-byte key: `prefix` as 4 big-endian bytes followed by
/// `id` as 8 big-endian bytes.
///
/// Big-endian encoding means byte-wise comparison of two keys orders them by
/// `prefix` first and then by `id`.
pub(super) fn composite_key(prefix: u32, id: u64) -> [u8; 12] {
    let head = prefix.to_be_bytes();
    let tail = id.to_be_bytes();
    // Bytes 0..4 come from the prefix, bytes 4..12 from the id.
    std::array::from_fn(|i| if i < 4 { head[i] } else { tail[i - 4] })
}
/// Single-byte tag that `encode_property_value` emits for JSON `null`.
pub(super) const ENCODED_NULL: u8 = 0x00;
/// Most-significant bit of a `u64`. The numeric encoder/decoder in this file
/// flips it (or inverts all bits for negative floats) so that the big-endian
/// bytes of signed values compare in numeric order.
const SORT_SIGN_BIT: u64 = 0x8000_0000_0000_0000;
pub(super) fn encode_property_value(val: &serde_json::Value) -> Option<Vec<u8>> {
match val {
serde_json::Value::Null => Some(vec![ENCODED_NULL]),
serde_json::Value::Bool(false) => Some(vec![0x01]),
serde_json::Value::Bool(true) => Some(vec![0x02]),
serde_json::Value::Number(num) => {
let float_val = num.as_f64()?;
let bits = float_val.to_bits();
let masked = if (bits & SORT_SIGN_BIT) != 0 {
!bits
} else {
bits ^ SORT_SIGN_BIT
};
let int_disambig: u64 = if let Some(i) = num.as_i64() {
(i as u64) ^ SORT_SIGN_BIT
} else if float_val.fract() == 0.0
&& float_val >= i64::MIN as f64
&& float_val <= i64::MAX as f64
{
((float_val as i64) as u64) ^ SORT_SIGN_BIT
} else {
0
};
let mut buf = Vec::with_capacity(17);
buf.push(0x03);
buf.extend_from_slice(&masked.to_be_bytes());
buf.extend_from_slice(&int_disambig.to_be_bytes());
Some(buf)
}
serde_json::Value::String(s) => {
let mut buf = Vec::with_capacity(1 + s.len() + 1);
buf.push(0x04);
buf.extend_from_slice(s.as_bytes());
buf.push(0x00);
Some(buf)
}
_ => None, }
}
#[allow(dead_code)]
pub(super) fn decode_property_value(bytes: &[u8]) -> Option<serde_json::Value> {
if bytes.is_empty() {
return None;
}
match bytes[0] {
0x00 => Some(serde_json::Value::Null),
0x01 => Some(serde_json::Value::Bool(false)),
0x02 => Some(serde_json::Value::Bool(true)),
0x03 => {
if bytes.len() < 17 {
return None;
}
let mut int_arr = [0u8; 8];
int_arr.copy_from_slice(&bytes[9..17]);
let int_val = (u64::from_be_bytes(int_arr) ^ SORT_SIGN_BIT) as i64;
let mut arr = [0u8; 8];
arr.copy_from_slice(&bytes[1..9]);
let masked = u64::from_be_bytes(arr);
let bits = if (masked & SORT_SIGN_BIT) == 0 {
!masked
} else {
masked ^ SORT_SIGN_BIT
};
let float_val = f64::from_bits(bits);
if (int_val as f64) == float_val {
Some(serde_json::Value::Number(int_val.into()))
} else {
serde_json::Number::from_f64(float_val).map(serde_json::Value::Number)
}
}
0x04 => {
let str_bytes = if bytes.ends_with(&[0x00]) {
&bytes[1..bytes.len() - 1]
} else {
&bytes[1..]
};
String::from_utf8(str_bytes.to_vec())
.ok()
.map(serde_json::Value::String)
}
_ => None,
}
}
/// Concatenates the big-endian bytes of `label_id` and `prop_key_id`, the
/// already-encoded value bytes, and the big-endian bytes of `node_id` into
/// one key, in that order.
pub(super) fn node_prop_index_key(
    label_id: LabelId,
    prop_key_id: PropKeyId,
    encoded_val: &[u8],
    node_id: NodeId,
) -> Vec<u8> {
    let label = label_id.to_be_bytes();
    let prop = prop_key_id.to_be_bytes();
    let node = node_id.to_be_bytes();

    let mut key = Vec::with_capacity(4 + 4 + encoded_val.len() + 8);
    for part in [&label[..], &prop[..], encoded_val, &node[..]] {
        key.extend_from_slice(part);
    }
    key
}
/// Concatenates the big-endian bytes of `type_id` and `prop_key_id`, the
/// already-encoded value bytes, and the big-endian bytes of `edge_id` into
/// one key, in that order.
pub(super) fn edge_prop_index_key(
    type_id: TypeId,
    prop_key_id: PropKeyId,
    encoded_val: &[u8],
    edge_id: EdgeId,
) -> Vec<u8> {
    let ty = type_id.to_be_bytes();
    let prop = prop_key_id.to_be_bytes();
    let edge = edge_id.to_be_bytes();

    let mut key = Vec::with_capacity(4 + 4 + encoded_val.len() + 8);
    for part in [&ty[..], &prop[..], encoded_val, &edge[..]] {
        key.extend_from_slice(part);
    }
    key
}
/// Builds a postings key: big-endian `label_id`, big-endian `prop_key_id`,
/// then the raw UTF-8 bytes of `term`.
pub(super) fn fts_postings_key(label_id: LabelId, prop_key_id: PropKeyId, term: &str) -> Vec<u8> {
    let label = label_id.to_be_bytes();
    let prop = prop_key_id.to_be_bytes();

    let mut key = Vec::with_capacity(8 + term.len());
    for part in [&label[..], &prop[..], term.as_bytes()] {
        key.extend_from_slice(part);
    }
    key
}
/// Packs a posting into 12 bytes: big-endian `node_id` in bytes 0..8 and
/// big-endian `frequency` in bytes 8..12. `parse_fts_posting_val` reads the
/// same layout back.
pub(super) fn fts_posting_val(node_id: NodeId, frequency: u32) -> [u8; 12] {
    let mut val = [0u8; 12];
    let (id_part, freq_part) = val.split_at_mut(8);
    id_part.copy_from_slice(&node_id.to_be_bytes());
    freq_part.copy_from_slice(&frequency.to_be_bytes());
    val
}
pub(super) fn parse_fts_posting_val(bytes: &[u8]) -> Result<(NodeId, u32), Error> {
if bytes.len() != 12 {
return Err(Error::Corrupt("fts posting value must be 12 bytes"));
}
let node_id = NodeId::from_be_bytes(
bytes[0..8]
.try_into()
.map_err(|_| Error::Corrupt("fts posting: node_id slice wrong size"))?,
);
let frequency = u32::from_be_bytes(
bytes[8..12]
.try_into()
.map_err(|_| Error::Corrupt("fts posting: frequency slice wrong size"))?,
);
Ok((node_id, frequency))
}
/// Builds a fixed 16-byte document key: big-endian `label_id` (bytes 0..4),
/// big-endian `prop_key_id` (bytes 4..8), big-endian `node_id` (bytes 8..16).
pub(super) fn fts_doc_key(label_id: LabelId, prop_key_id: PropKeyId, node_id: NodeId) -> [u8; 16] {
    let mut key = [0u8; 16];
    let (label_part, rest) = key.split_at_mut(4);
    let (prop_part, node_part) = rest.split_at_mut(4);
    label_part.copy_from_slice(&label_id.to_be_bytes());
    prop_part.copy_from_slice(&prop_key_id.to_be_bytes());
    node_part.copy_from_slice(&node_id.to_be_bytes());
    key
}
pub(super) fn parse_fts_doc_val(bytes: &[u8]) -> Result<u32, Error> {
if bytes.len() != 4 {
return Err(Error::Corrupt("fts doc val must be 4 bytes"));
}
Ok(u32::from_be_bytes(bytes.try_into().map_err(|_| {
Error::Corrupt("fts doc val: slice wrong size")
})?))
}
/// Formats the string key `fts_stats:node:l:<label_id>:p:<prop_key_id>:N`.
///
/// NOTE(review): the `N` suffix presumably denotes a document count for the
/// (label, property) pair — confirm with the code that reads this key.
pub(super) fn fts_stats_n_key(label_id: LabelId, prop_key_id: PropKeyId) -> String {
    format!("fts_stats:node:l:{}:p:{}:N", label_id, prop_key_id)
}
/// Formats the string key `fts_stats:node:l:<label_id>:p:<prop_key_id>:sum_dl`.
///
/// NOTE(review): `sum_dl` presumably means "sum of document lengths" for the
/// (label, property) pair — confirm with the code that reads this key.
pub(super) fn fts_stats_sum_dl_key(label_id: LabelId, prop_key_id: PropKeyId) -> String {
    format!("fts_stats:node:l:{}:p:{}:sum_dl", label_id, prop_key_id)
}
/// Handle to an opened graph database.
///
/// Every field is an `Arc`, so the derived `Clone` yields another handle that
/// shares the same underlying state.
#[derive(Clone)]
pub struct Graph {
/// Storage backend; its `env` is used to start transactions and to copy backups.
pub(super) storage: Arc<Storage>,
/// Re-entrant lock held by `update` and `with_write_lock` while they run.
pub(super) _write_lock: Arc<ReentrantMutex<()>>,
/// CSR snapshot cache; `update` records each committed delta into it and
/// `rebuild_csr` installs a freshly built snapshot.
pub(super) csr_cache: Arc<CsrCache>,
/// Matrix set materialized from a CSR snapshot; replaced by `rebuild_csr`.
pub(super) matrices: Arc<parking_lot::RwLock<Option<MatrixSet>>>,
/// Property column cache behind the `node_prop_*` and `estimate_*` accessors.
pub(super) prop_columns: Arc<crate::columns::ColumnsCache>,
/// Thread count stored by `set_thread_count` and passed to
/// `MatrixSet::materialize` in `rebuild_csr`; starts at 0.
pub(super) n_threads: Arc<std::sync::atomic::AtomicI32>,
/// Type-indexed extension slots: at most one boxed `Arc<T>` per type `T`,
/// keyed by `TypeId::of::<T>()`.
pub(crate) extensions: Arc<parking_lot::Mutex<AHashMap<StdTypeId, Box<dyn Any + Send + Sync>>>>,
}
/// Read-only transaction wrapper; `Graph::view` builds one and lends it to
/// the caller's closure.
pub struct ReadTxn<'a> {
/// The graph this transaction was opened on.
pub(super) graph: &'a Graph,
/// Underlying heed read transaction.
pub(super) rtxn: heed::RoTxn<'a, heed::WithTls>,
}
/// Read-write transaction wrapper; `Graph::update` builds one, lends it to
/// the caller's closure, and commits or aborts it based on the closure result.
pub struct WriteTxn<'a> {
/// The graph this transaction was opened on.
pub(super) graph: &'a Graph,
/// Underlying heed write transaction.
pub(super) wtxn: heed::RwTxn<'a>,
/// Starts at 0; after a successful commit `Graph::update` passes it to
/// `maybe_spawn_rebuild_n` when it is greater than zero.
pub(super) mutations_count: usize,
/// Changes accumulated during the transaction; after commit `Graph::update`
/// feeds it to the property-column cache and the CSR cache.
pub(super) delta: crate::csr::GraphDelta,
}
impl Graph {
/// Opens the storage at `path`, builds the initial CSR snapshot and matrix
/// set from it, and returns a ready-to-use handle.
///
/// Any `csr_snapshot.bin` file under `path` is deleted first; failure to
/// delete it (including the file not existing) is deliberately ignored.
///
/// # Errors
/// Propagates failures from opening the storage, building the snapshot, or
/// materializing the matrices.
pub fn open(path: &Path, map_size_gb: usize) -> Result<Self, Error> {
    let raw_storage = Storage::open(path, map_size_gb)?;

    // Best-effort cleanup; the result is intentionally discarded.
    let _ = std::fs::remove_file(path.join("csr_snapshot.bin"));

    let first_snapshot = CsrSnapshot::build(&raw_storage)?;
    let storage = Arc::new(raw_storage);
    let csr_cache = Arc::new(CsrCache::new(first_snapshot));

    // Materialize with thread count 0, matching the initial `n_threads`.
    let loaded = csr_cache.snapshot.load();
    let initial_matrices = MatrixSet::materialize(&loaded, 0)?;
    drop(loaded);
    let matrices = Arc::new(parking_lot::RwLock::new(Some(initial_matrices)));

    let graph = Self {
        storage,
        _write_lock: Arc::new(ReentrantMutex::new(())),
        csr_cache,
        matrices,
        prop_columns: Arc::new(crate::columns::ColumnsCache::default()),
        n_threads: Arc::new(std::sync::atomic::AtomicI32::new(0)),
        extensions: Arc::new(parking_lot::Mutex::new(AHashMap::new())),
    };
    Ok(graph)
}
pub fn set_thread_count(&self, n: i32) -> Result<(), Error> {
self.n_threads
.store(n, std::sync::atomic::Ordering::Release);
issundb_graphblas::set_global_threads(n).map_err(|e| Error::GraphBLAS(e.to_string()))?;
Ok(())
}
/// Looks up property `prop` of node `id` in the property column cache.
///
/// Returns `Ok(None)` when `id` has no dense index in the cache. When the
/// node is known but the column is missing or has no value at that index,
/// returns `Ok(Some(Value::Null))`.
///
/// # Errors
/// Propagates any error from `with_fresh` on the column cache.
pub fn node_prop_json(
    &self,
    id: NodeId,
    prop: &str,
) -> Result<Option<serde_json::Value>, Error> {
    self.prop_columns.with_fresh(&self.storage, |cols| {
        let dense = cols.id_to_dense.get(&id).copied();
        dense.map(|slot| {
            let column = cols.cols.get(prop);
            column
                .and_then(|c| c.get_json_opt(slot as usize))
                .unwrap_or(serde_json::Value::Null)
        })
    })
}
/// Returns the result of `props_table(ids, props)` on the property column
/// cache.
///
/// # Errors
/// Propagates an error from `with_fresh`, or the error returned by
/// `props_table` itself.
pub fn node_props_json_table(
    &self,
    ids: &[NodeId],
    props: &[&str],
) -> Result<Vec<Vec<serde_json::Value>>, Error> {
    // `with_fresh` wraps the closure's own `Result`; `?` peels the outer one.
    let nested = self
        .prop_columns
        .with_fresh(&self.storage, |cols| cols.props_table(ids, props));
    nested?
}
/// Returns the result of `prop_column(ids, prop)` on the property column
/// cache.
///
/// # Errors
/// Propagates an error from `with_fresh`, or the error returned by
/// `prop_column` itself.
pub fn node_prop_json_column(
    &self,
    ids: &[NodeId],
    prop: &str,
) -> Result<Vec<serde_json::Value>, Error> {
    // `with_fresh` wraps the closure's own `Result`; `?` peels the outer one.
    let nested = self
        .prop_columns
        .with_fresh(&self.storage, |cols| cols.prop_column(ids, prop));
    nested?
}
/// Returns the result of `group_codes(ids, prop)` on the property column
/// cache: a vector of `u32` codes plus a vector of JSON values.
///
/// NOTE(review): presumably each code indexes into the returned value list
/// (dictionary encoding) — confirm in the column cache implementation.
///
/// # Errors
/// Propagates an error from `with_fresh`, or the error returned by
/// `group_codes` itself.
pub fn node_prop_group_codes(
    &self,
    ids: &[NodeId],
    prop: &str,
) -> Result<(Vec<u32>, Vec<serde_json::Value>), Error> {
    // `with_fresh` wraps the closure's own `Result`; `?` peels the outer one.
    let nested = self
        .prop_columns
        .with_fresh(&self.storage, |cols| cols.group_codes(ids, prop));
    nested?
}
/// Returns clones of the `min` and `max` recorded in the statistics for
/// `prop`, or `Ok(None)` when `prop_stats` has nothing for that property.
///
/// # Errors
/// Propagates any error from `with_fresh_mut` on the column cache.
pub fn node_prop_min_max(
    &self,
    prop: &str,
) -> Result<Option<(serde_json::Value, serde_json::Value)>, Error> {
    self.prop_columns.with_fresh_mut(&self.storage, |cols| {
        let stats = cols.prop_stats(prop);
        stats.map(|found| (found.min.clone(), found.max.clone()))
    })
}
/// Asks the histogram in the statistics for `prop` to estimate the
/// selectivity of the range `[lower, upper]`; either bound may be absent.
///
/// Returns `Ok(None)` when `prop_stats` has nothing for that property.
///
/// # Errors
/// Propagates any error from `with_fresh_mut` on the column cache.
pub fn estimate_range_selectivity(
    &self,
    prop: &str,
    lower: Option<&serde_json::Value>,
    upper: Option<&serde_json::Value>,
) -> Result<Option<f64>, Error> {
    self.prop_columns.with_fresh_mut(&self.storage, |cols| {
        let stats = cols.prop_stats(prop);
        stats.map(|found| found.histogram.estimate_range_selectivity(lower, upper))
    })
}
/// Asks the statistics for `prop` to estimate the selectivity of an equality
/// match against `val`.
///
/// Returns `Ok(None)` when `prop_stats` has nothing for that property.
///
/// # Errors
/// Propagates any error from `with_fresh_mut` on the column cache.
pub fn estimate_equality_selectivity(
    &self,
    prop: &str,
    val: &serde_json::Value,
) -> Result<Option<f64>, Error> {
    self.prop_columns.with_fresh_mut(&self.storage, |cols| {
        let stats = cols.prop_stats(prop);
        stats.map(|found| found.equality_selectivity(val))
    })
}
/// Stores `val` in the extension slot for type `T`, replacing any value
/// previously stored for that type.
pub fn set_extension<T: Any + Send + Sync>(&self, val: Arc<T>) {
    let mut registry = self.extensions.lock();
    let slot_key = StdTypeId::of::<T>();
    // The previous occupant, if any, is returned by `insert` and dropped here.
    registry.insert(slot_key, Box::new(val));
}
/// Returns a clone of the `Arc<T>` stored in the extension slot for `T`, or
/// `None` when nothing has been stored for that type.
pub fn get_extension<T: Any + Send + Sync>(&self) -> Option<Arc<T>> {
    let registry = self.extensions.lock();
    let boxed = registry.get(&StdTypeId::of::<T>())?;
    boxed.downcast_ref::<Arc<T>>().cloned()
}
/// Returns the extension stored for `T`, creating it with `init` when absent.
///
/// `init` runs without the extension lock held. The slot is checked again
/// under the lock afterwards; if a value has appeared in the meantime, that
/// value is returned and the freshly built one is discarded.
///
/// # Errors
/// Returns whatever error `init` produces; nothing is stored in that case.
pub fn get_or_init_extension_with<T, E, F>(&self, init: F) -> Result<Arc<T>, E>
where
    T: Any + Send + Sync,
    F: FnOnce() -> Result<Arc<T>, E>,
{
    // Fast path: already present.
    if let Some(found) = self.get_extension::<T>() {
        return Ok(found);
    }

    let fresh = init()?;
    let slot_key = StdTypeId::of::<T>();
    let mut registry = self.extensions.lock();

    // Re-check under the lock: another caller may have stored a value while
    // `init` was running.
    let raced = registry
        .get(&slot_key)
        .and_then(|boxed| boxed.downcast_ref::<Arc<T>>())
        .cloned();
    match raced {
        Some(winner) => Ok(winner),
        None => {
            registry.insert(slot_key, Box::new(Arc::clone(&fresh)));
            Ok(fresh)
        }
    }
}
/// Runs `f` inside a read transaction and returns its result.
///
/// # Errors
/// Propagates a failure to start the read transaction, or the error
/// returned by `f`.
pub fn view<F, T>(&self, f: F) -> Result<T, Error>
where
    F: FnOnce(&ReadTxn) -> Result<T, Error>,
{
    let txn = ReadTxn {
        graph: self,
        rtxn: self.storage.env.read_txn()?,
    };
    f(&txn)
}
/// Runs `f` inside a write transaction while holding the graph's write lock.
///
/// When `f` succeeds, the transaction is committed and then, in order:
/// 1. the property-column cache is told about the change (a forced full
///    refresh when the delta says so, otherwise the added and updated nodes);
/// 2. the delta is recorded in the CSR cache;
/// 3. `maybe_spawn_rebuild_n` is called when at least one mutation was counted.
///
/// When `f` fails, the transaction is aborted and the error is returned.
///
/// # Errors
/// Propagates a failure to start or commit the transaction, or the error
/// returned by `f`.
pub fn update<F, T>(&self, f: F) -> Result<T, Error>
where
    F: FnOnce(&mut WriteTxn) -> Result<T, Error>,
{
    let _guard = self._write_lock.lock();
    let mut txn = WriteTxn {
        graph: self,
        wtxn: self.storage.env.write_txn()?,
        mutations_count: 0,
        delta: crate::csr::GraphDelta::default(),
    };

    let val = match f(&mut txn) {
        Ok(val) => val,
        Err(err) => {
            txn.wtxn.abort();
            return Err(err);
        }
    };

    let WriteTxn {
        wtxn,
        mutations_count,
        delta,
        ..
    } = txn;
    wtxn.commit()?;

    // Caches are only notified once the commit has succeeded.
    if delta.force_full {
        self.prop_columns.record_force_full();
    } else {
        self.prop_columns.record_touched_many(&delta.added_nodes);
        self.prop_columns.record_touched_many(&delta.updated_nodes);
    }
    self.csr_cache.record_batch(delta);

    if mutations_count > 0 {
        self.maybe_spawn_rebuild_n(mutations_count);
    }
    Ok(val)
}
/// Runs `f` while holding the graph's re-entrant write lock and returns its
/// result. The lock is released after `f` returns.
pub fn with_write_lock<F, R>(&self, f: F) -> R
where
    F: FnOnce() -> R,
{
    let guard = self._write_lock.lock();
    let outcome = f();
    drop(guard);
    outcome
}
/// Rebuilds the CSR snapshot and matrix set from storage and installs both.
///
/// Steps, in order: read the cache's current generation, clear its delta,
/// build a new snapshot, materialize matrices from it using the stored
/// thread count, swap the matrices in, then install the snapshot together
/// with the generation read at the start.
///
/// # Errors
/// Propagates a failure from building the snapshot or materializing the
/// matrices.
///
/// NOTE(review): the delta is cleared before the fallible build steps, and
/// this method does not take the write lock itself — confirm both are
/// intended and that callers handle a failed rebuild.
#[instrument(skip(self))]
pub fn rebuild_csr(&self) -> Result<(), Error> {
    let generation = self.csr_cache.current_gen();
    self.csr_cache.clear_delta();

    let snapshot = CsrSnapshot::build(&self.storage)?;
    let threads = self.n_threads.load(std::sync::atomic::Ordering::Acquire);
    let rebuilt = MatrixSet::materialize(&snapshot, threads)?;

    {
        let mut slot = self.matrices.write();
        *slot = Some(rebuilt);
    }
    self.csr_cache.install_full(snapshot, generation);
    Ok(())
}
pub fn backup(&self, destination: &Path) -> Result<(), Error> {
self.storage
.env
.copy_to_path(destination, heed::CompactionOption::Disabled)
.map(|_| ())
.map_err(Error::Storage)
}
pub fn backup_compact(&self, destination: &Path) -> Result<(), Error> {
self.storage
.env
.copy_to_path(destination, heed::CompactionOption::Enabled)
.map(|_| ())
.map_err(Error::Storage)
}
/// Creates `dst_dir` (and any missing parents) and copies `snapshot_file`
/// into it as `data.mdb`, overwriting an existing file of that name.
///
/// # Errors
/// Propagates I/O failures from creating the directory or copying the file.
pub fn restore(snapshot_file: &Path, dst_dir: &Path) -> Result<(), Error> {
    std::fs::create_dir_all(dst_dir)?;
    let target = dst_dir.join("data.mdb");
    std::fs::copy(snapshot_file, &target)?;
    Ok(())
}
}
/// Tests for the type-indexed extension slots on `Graph`.
#[cfg(test)]
mod extension_tests {
    use std::sync::Arc;
    use tempfile::TempDir;
    use super::Graph;

    /// Opens a graph in a fresh temporary directory. The `TempDir` is
    /// returned alongside so the directory outlives the graph in each test.
    fn open_tmp() -> (TempDir, Graph) {
        let dir = TempDir::new().unwrap();
        let graph = Graph::open(dir.path(), 1).unwrap();
        (dir, graph)
    }

    #[test]
    fn extension_roundtrip_by_type() {
        let (_dir, graph) = open_tmp();
        assert!(graph.get_extension::<String>().is_none());

        graph.set_extension(Arc::new(String::from("cache")));
        let stored = graph
            .get_extension::<String>()
            .expect("extension must exist");
        assert_eq!(*stored, "cache");
        assert!(graph.get_extension::<u64>().is_none(), "distinct type slot");

        // Storing again for the same type replaces the earlier value.
        graph.set_extension(Arc::new(String::from("replaced")));
        assert_eq!(*graph.get_extension::<String>().unwrap(), "replaced");
    }

    #[test]
    fn get_or_init_extension_initializes_once() {
        let (_dir, graph) = open_tmp();

        let first = graph
            .get_or_init_extension_with::<u64, std::convert::Infallible, _>(|| Ok(Arc::new(7)))
            .unwrap();
        assert_eq!(*first, 7);

        let second = graph
            .get_or_init_extension_with::<u64, std::convert::Infallible, _>(|| Ok(Arc::new(9)))
            .unwrap();
        assert_eq!(*second, 7, "second init must not replace the stored value");
    }

    #[test]
    fn get_or_init_extension_propagates_init_error() {
        let (_dir, graph) = open_tmp();

        let failure = graph
            .get_or_init_extension_with::<u64, &str, _>(|| Err("init failed"))
            .unwrap_err();
        assert_eq!(failure, "init failed");
        // A failed init must leave the slot empty so a later init can succeed.
        assert!(graph.get_extension::<u64>().is_none());

        let recovered = graph
            .get_or_init_extension_with::<u64, &str, _>(|| Ok(Arc::new(7)))
            .unwrap();
        assert_eq!(*recovered, 7);
    }
}
/// Tests for the property-value byte encoding and its decoder.
#[cfg(test)]
mod encode_tests {
    use serde_json::json;
    use super::{decode_property_value, encode_property_value};

    #[test]
    fn large_integers_do_not_collide() {
        // 2^53 and 2^53 + 1 round to the same f64; the encodings must still differ.
        let lower = encode_property_value(&json!(9_007_199_254_740_992_i64)).unwrap();
        let upper = encode_property_value(&json!(9_007_199_254_740_993_i64)).unwrap();
        assert_ne!(lower, upper, "distinct large integers must encode distinctly");
    }

    #[test]
    fn integer_and_equal_float_unify() {
        let thirty_int = encode_property_value(&json!(30)).unwrap();
        let thirty_float = encode_property_value(&json!(30.0)).unwrap();
        assert_eq!(thirty_int, thirty_float);

        let zero_int = encode_property_value(&json!(0)).unwrap();
        let zero_float = encode_property_value(&json!(0.0)).unwrap();
        assert_eq!(zero_int, zero_float);
    }

    #[test]
    fn numeric_encoding_is_fixed_length() {
        let samples = [
            json!(1),
            json!(-1),
            json!(0),
            json!(i64::MAX),
            json!(i64::MIN),
            json!(3.5),
            json!(-2.5e10),
        ];
        for v in &samples {
            assert_eq!(encode_property_value(v).unwrap().len(), 17, "value {v}");
        }
    }

    #[test]
    fn numeric_ordering_preserved() {
        let ascending: Vec<i64> = vec![
            i64::MIN,
            -1_000,
            -1,
            0,
            1,
            1_000,
            1 << 53,
            (1 << 53) + 1,
            i64::MAX,
        ];
        let mut encoded: Vec<Vec<u8>> = Vec::with_capacity(ascending.len());
        for v in &ascending {
            encoded.push(encode_property_value(&json!(v)).unwrap());
        }

        let mut sorted = encoded.clone();
        sorted.sort();
        assert_eq!(encoded, sorted, "encodings must sort in numeric order");
    }

    #[test]
    fn decode_round_trips_large_integer() {
        let samples = [
            json!(0),
            json!(-1),
            json!(9_007_199_254_740_993_i64),
            json!(i64::MAX),
        ];
        for v in &samples {
            let enc = encode_property_value(v).unwrap();
            assert_eq!(decode_property_value(&enc), Some(v.clone()), "value {v}");
        }
    }
}