#[cfg(feature = "python")]
pub mod python {
use crate::database::Database as GenericDatabase;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList};
/// Type-erased database handle: one variant per vector element type that the
/// Python bindings can open.
enum DbBackend {
/// Database whose vectors are stored as `f32`.
F32(GenericDatabase<f32>),
/// Database whose vectors are stored as `half::f16`.
F16(GenericDatabase<half::f16>),
/// Database whose vectors are stored as `u64`.
U64(GenericDatabase<u64>),
}
/// Database object exposed to Python under the class name `TriviumDB`.
#[pyclass(name = "TriviumDB")]
pub struct PyTriviumDB {
// Concrete database; the variant is chosen from the constructor's `dtype`.
inner: DbBackend,
/// The `dtype` string given to the constructor (read-only from Python).
#[pyo3(get)]
dtype: String,
}
// Evaluates one expression against whichever concrete database `inner` holds.
//
// * `dispatch!(obj, name => expr)` matches on `&obj.inner`, binds the wrapped
//   database to `name` and evaluates `expr`.
// * `dispatch!(obj, mut name => expr)` does the same over `&mut obj.inner`.
//
// `expr` is expanded once per variant, so it has to type-check for every
// element type and yield the same type in all three arms.
macro_rules! dispatch {
    ($owner:expr, mut $handle:ident => $body:expr) => {
        match &mut $owner.inner {
            DbBackend::F32($handle) => $body,
            DbBackend::F16($handle) => $body,
            DbBackend::U64($handle) => $body,
        }
    };
    ($owner:expr, $handle:ident => $body:expr) => {
        match &$owner.inner {
            DbBackend::F32($handle) => $body,
            DbBackend::F16($handle) => $body,
            DbBackend::U64($handle) => $body,
        }
    };
}
/// One search result, exposed to Python as `SearchHit`.
#[pyclass(name = "SearchHit")]
pub struct PySearchHit {
/// Id of the matching node.
#[pyo3(get)]
pub id: u64,
/// Score reported by the search for this node.
#[pyo3(get)]
pub score: f32,
/// Node payload, already converted from JSON to Python objects.
#[pyo3(get)]
pub payload: PyObject,
}
/// Graph edge attached to a node, exposed to Python as `Edge`.
#[pyclass(name = "Edge")]
#[derive(Clone)]
pub struct PyEdge {
/// Id of the node this edge points at.
#[pyo3(get)]
pub target_id: u64,
/// Edge label.
#[pyo3(get)]
pub label: String,
/// Edge weight.
#[pyo3(get)]
pub weight: f32,
}
/// Snapshot of a single node as returned by `TriviumDB.get`, exposed to Python
/// as `NodeView`.
#[pyclass(name = "NodeView")]
pub struct PyNodeView {
    /// Node id.
    #[pyo3(get)]
    pub id: u64,
    /// Node vector converted to a Python object (f16 vectors are widened to f32).
    #[pyo3(get)]
    pub vector: PyObject,
    /// Node payload converted from JSON to Python objects.
    #[pyo3(get)]
    pub payload: PyObject,
    /// Edges attached to the node.
    #[pyo3(get)]
    pub edges: Vec<PyEdge>,
    /// Number of edges attached to the node.
    #[pyo3(get)]
    pub num_edges: usize,
}
/// One result row of a TQL query, exposed to Python as `QueryRow`.
#[pyclass(name = "QueryRow")]
pub struct PyQueryRow {
/// Dict mapping each query variable name to a dict describing the bound node.
#[pyo3(get)]
pub row: PyObject,
}
/// Hook-pipeline context returned alongside search results, exposed to Python
/// as `HookContext`.
#[pyclass(name = "HookContext")]
pub struct PyHookContext {
/// Dict of per-stage timings, in milliseconds.
#[pyo3(get)]
pub timings: PyObject,
/// Custom data recorded in the hook context, converted from JSON.
#[pyo3(get)]
pub custom_data: PyObject,
/// Copy of the hook context's abort flag.
#[pyo3(get)]
pub aborted: bool,
}
#[pymethods]
impl PyHookContext {
    /// Debug representation: `HookContext(aborted=<bool>, timings=<repr of timings>)`.
    ///
    /// If taking the Python `repr` of `timings` fails, an empty string is shown
    /// in its place.
    fn __repr__(&self, py: Python<'_>) -> String {
        let timings = self
            .timings
            .bind(py)
            .repr()
            .map(|r| r.to_string())
            .unwrap_or_default();
        // `timings` already is a repr; formatting it with `{:?}` would wrap it in
        // extra double quotes and escape every quote inside it.
        format!(
            "HookContext(aborted={}, timings={})",
            self.aborted, timings
        )
    }
}
#[pymethods]
impl PyQueryRow {
    /// Debug representation: `QueryRow(<repr of the row dict>)`.
    ///
    /// If taking the Python `repr` of the row fails, an empty string is shown
    /// in its place.
    fn __repr__(&self, py: Python<'_>) -> String {
        let row = self
            .row
            .bind(py)
            .repr()
            .map(|r| r.to_string())
            .unwrap_or_default();
        // `row` already is a repr; formatting it with `{:?}` would wrap it in
        // extra double quotes and escape every quote inside it.
        format!("QueryRow({})", row)
    }
}
/// Recursively converts a `serde_json::Value` into the matching Python object.
///
/// Mapping: null -> `None`, bool -> `bool`, number -> `int` when it fits in
/// `i64` or `u64` and `float` otherwise, string -> `str`, array -> `list`,
/// object -> `dict`.
///
/// Failures while inserting into a dict are ignored (the key is simply missing
/// from the result); failures while creating Python objects panic.
fn json_to_pyobject(py: Python<'_>, val: &serde_json::Value) -> PyObject {
    match val {
        serde_json::Value::Null => py.None(),
        serde_json::Value::Bool(b) => (*b)
            .into_pyobject(py)
            .unwrap()
            .to_owned()
            .into_any()
            .unbind(),
        serde_json::Value::Number(n) => {
            if let Some(i) = n.as_i64() {
                i.into_pyobject(py).unwrap().into_any().unbind()
            } else if let Some(u) = n.as_u64() {
                // Integers above i64::MAX must stay exact; going through f64
                // would round them.
                u.into_pyobject(py).unwrap().into_any().unbind()
            } else {
                n.as_f64()
                    .unwrap_or(0.0)
                    .into_pyobject(py)
                    .unwrap()
                    .into_any()
                    .unbind()
            }
        }
        serde_json::Value::String(s) => s.into_pyobject(py).unwrap().into_any().unbind(),
        serde_json::Value::Array(arr) => {
            let list = PyList::new(py, arr.iter().map(|v| json_to_pyobject(py, v))).unwrap();
            list.into_any().unbind()
        }
        serde_json::Value::Object(map) => {
            let dict = PyDict::new(py);
            for (k, v) in map {
                let _ = dict.set_item(k, json_to_pyobject(py, v));
            }
            dict.into_any().unbind()
        }
    }
}
/// Recursively converts a Python object into a `serde_json::Value`.
///
/// Conversions are tried in this order: `None`, `bool`, integer (`i64`, then
/// `u64`), `float`, `str`, `dict`, `list`. Dict entries whose key is not a
/// string are dropped, and any object of another type becomes JSON `null`.
fn pyobject_to_json(py: Python<'_>, obj: &Bound<'_, PyAny>) -> serde_json::Value {
    if obj.is_none() {
        serde_json::Value::Null
    } else if let Ok(b) = obj.extract::<bool>() {
        serde_json::Value::Bool(b)
    } else if let Ok(i) = obj.extract::<i64>() {
        serde_json::json!(i)
    } else if let Ok(u) = obj.extract::<u64>() {
        // Ints in (i64::MAX, u64::MAX] would otherwise match the float branch
        // below and be rounded.
        serde_json::json!(u)
    } else if let Ok(f) = obj.extract::<f64>() {
        serde_json::json!(f)
    } else if let Ok(s) = obj.extract::<String>() {
        serde_json::Value::String(s)
    } else if let Ok(dict) = obj.downcast::<PyDict>() {
        let mut map = serde_json::Map::new();
        for (k, v) in dict.iter() {
            if let Ok(key) = k.extract::<String>() {
                map.insert(key, pyobject_to_json(py, &v));
            }
        }
        serde_json::Value::Object(map)
    } else if let Ok(list) = obj.downcast::<PyList>() {
        let arr: Vec<serde_json::Value> = list
            .iter()
            .map(|item| pyobject_to_json(py, &item))
            .collect();
        serde_json::Value::Array(arr)
    } else {
        serde_json::Value::Null
    }
}
use crate::filter::Filter;
/// Builds a payload `Filter` from a Python dict.
///
/// The dict is first converted to JSON and then handed to `Filter::from_json`;
/// a rejection by the parser is raised in Python as `ValueError`.
fn dict_to_filter(py: Python<'_>, dict: &Bound<'_, PyDict>) -> PyResult<Filter> {
    let spec = pyobject_to_json(py, dict.as_any());
    match Filter::from_json(&spec) {
        Ok(filter) => Ok(filter),
        Err(reason) => Err(pyo3::exceptions::PyValueError::new_err(reason)),
    }
}
/// Parses a WAL sync-mode name with `SyncMode::parse`, raising `ValueError` in
/// Python when the name is rejected.
fn parse_sync_mode(s: &str) -> PyResult<crate::storage::wal::SyncMode> {
    match crate::storage::wal::SyncMode::parse(s) {
        Ok(mode) => Ok(mode),
        Err(reason) => Err(pyo3::exceptions::PyValueError::new_err(reason)),
    }
}
#[pymethods]
impl PyTriviumDB {
#[new]
#[pyo3(signature = (path, dim=1536, dtype="f32", sync_mode="normal"))]
fn new(path: &str, dim: usize, dtype: &str, sync_mode: &str) -> PyResult<Self> {
let sm = parse_sync_mode(sync_mode)?;
let inner = match dtype {
"f32" => DbBackend::F32(
GenericDatabase::<f32>::open_with_sync(path, dim, sm).map_err(
|e: crate::error::TriviumError| {
pyo3::exceptions::PyRuntimeError::new_err(e.to_string())
},
)?,
),
"f16" => DbBackend::F16(
GenericDatabase::<half::f16>::open_with_sync(path, dim, sm).map_err(
|e: crate::error::TriviumError| {
pyo3::exceptions::PyRuntimeError::new_err(e.to_string())
},
)?,
),
"u64" => DbBackend::U64(
GenericDatabase::<u64>::open_with_sync(path, dim, sm).map_err(
|e: crate::error::TriviumError| {
pyo3::exceptions::PyRuntimeError::new_err(e.to_string())
},
)?,
),
_ => {
return Err(pyo3::exceptions::PyValueError::new_err(
"Unsupported dtype. Use 'f32', 'f16', or 'u64'",
));
}
};
Ok(Self {
inner,
dtype: dtype.to_string(),
})
}
fn set_sync_mode(&mut self, mode: &str) -> PyResult<()> {
let sm = parse_sync_mode(mode)?;
dispatch!(self, mut db => db.set_sync_mode(sm));
Ok(())
}
/// Loads a search hook from the shared library at `lib_path` and installs it.
///
/// Raises `RuntimeError` when `FfiHook::load` fails; the hook currently
/// installed is left untouched in that case.
fn load_ffi_hook(&mut self, lib_path: &str) -> PyResult<()> {
    match crate::hook::FfiHook::load(lib_path) {
        Ok(loaded) => {
            dispatch!(self, mut db => db.set_hook(loaded));
            Ok(())
        }
        Err(e) => Err(pyo3::exceptions::PyRuntimeError::new_err(format!(
            "加载 FFI Hook 失败: {}",
            e
        ))),
    }
}
/// Removes the search hook currently installed on the database.
fn clear_hook(&mut self) {
    dispatch!(self, mut backend => backend.clear_hook());
}
fn set_hook(&mut self, hook: PyObject) {
let wrapper = PySearchHookWrapper { py_hook: hook };
dispatch!(self, mut db => db.set_hook(wrapper));
}
/// Vector search that also returns the hook-pipeline context.
///
/// Returns `(hits, context)`, where `context.timings` maps each pipeline stage
/// to its duration in milliseconds.
///
/// * `query_vector` - sequence of numbers; for an `f16` database it is read as
///   `f32` and narrowed, for a `u64` database it must hold integers.
/// * `payload_filter` - optional dict parsed by `dict_to_filter`.
///
/// Raises `ValueError` for an invalid filter, `TypeError` from vector
/// extraction, and `RuntimeError` when the search itself fails.
#[pyo3(signature = (query_vector, top_k=5, expand_depth=2, min_score=0.1, payload_filter=None))]
fn search_with_context(
    &self,
    py: Python<'_>,
    query_vector: Bound<'_, PyAny>,
    top_k: usize,
    expand_depth: usize,
    min_score: f32,
    payload_filter: Option<&Bound<'_, PyDict>>,
) -> PyResult<(Vec<PySearchHit>, PyHookContext)> {
    let rust_filter = payload_filter
        .map(|spec| dict_to_filter(py, spec))
        .transpose()?;
    let config = crate::database::SearchConfig {
        top_k,
        expand_depth,
        min_score,
        payload_filter: rust_filter,
        ..Default::default()
    };
    let (results, hook_ctx) = match &self.inner {
        DbBackend::F32(db) => {
            let vec: Vec<f32> = query_vector.extract()?;
            db.search_hybrid_with_context(None, Some(&vec), &config)
        }
        DbBackend::F16(db) => {
            let vec: Vec<f32> = query_vector.extract()?;
            let vec16: Vec<half::f16> = vec.into_iter().map(half::f16::from_f32).collect();
            db.search_hybrid_with_context(None, Some(&vec16), &config)
        }
        DbBackend::U64(db) => {
            let vec: Vec<u64> = query_vector.extract()?;
            db.search_hybrid_with_context(None, Some(&vec), &config)
        }
    }
    .map_err(|e: crate::error::TriviumError| {
        pyo3::exceptions::PyRuntimeError::new_err(e.to_string())
    })?;
    let hits: Vec<PySearchHit> = results
        .into_iter()
        .map(|h| PySearchHit {
            id: h.id,
            score: h.score,
            payload: json_to_pyobject(py, &h.payload),
        })
        .collect();
    let timings_dict = PyDict::new(py);
    for (stage, dur) in &hook_ctx.stage_timings {
        // Durations are reported in milliseconds. A failed insertion is surfaced
        // instead of silently producing an incomplete timings dict.
        timings_dict.set_item(stage, dur.as_secs_f64() * 1000.0)?;
    }
    let ctx = PyHookContext {
        timings: timings_dict.into_any().unbind(),
        custom_data: json_to_pyobject(py, &hook_ctx.custom_data),
        aborted: hook_ctx.abort,
    };
    Ok((hits, ctx))
}
fn insert(
&mut self,
py: Python<'_>,
vector: Bound<'_, PyAny>,
payload: &Bound<'_, PyAny>,
) -> PyResult<u64> {
let json = pyobject_to_json(py, payload);
match &mut self.inner {
DbBackend::F32(db) => {
let vec: Vec<f32> = vector.extract()?;
db.insert(&vec, json)
.map_err(|e: crate::error::TriviumError| {
pyo3::exceptions::PyRuntimeError::new_err(e.to_string())
})
}
DbBackend::F16(db) => {
let vec: Vec<f32> = vector.extract()?;
let vec16: Vec<half::f16> = vec.into_iter().map(half::f16::from_f32).collect();
db.insert(&vec16, json)
.map_err(|e: crate::error::TriviumError| {
pyo3::exceptions::PyRuntimeError::new_err(e.to_string())
})
}
DbBackend::U64(db) => {
let vec: Vec<u64> = vector.extract()?;
db.insert(&vec, json)
.map_err(|e: crate::error::TriviumError| {
pyo3::exceptions::PyRuntimeError::new_err(e.to_string())
})
}
}
}
fn insert_with_id(
&mut self,
py: Python<'_>,
id: u64,
vector: Bound<'_, PyAny>,
payload: &Bound<'_, PyAny>,
) -> PyResult<()> {
let json = pyobject_to_json(py, payload);
match &mut self.inner {
DbBackend::F32(db) => {
let vec: Vec<f32> = vector.extract()?;
db.insert_with_id(id, &vec, json)
.map_err(|e: crate::error::TriviumError| {
pyo3::exceptions::PyRuntimeError::new_err(e.to_string())
})
}
DbBackend::F16(db) => {
let vec: Vec<f32> = vector.extract()?;
let vec16: Vec<half::f16> = vec.into_iter().map(half::f16::from_f32).collect();
db.insert_with_id(id, &vec16, json)
.map_err(|e: crate::error::TriviumError| {
pyo3::exceptions::PyRuntimeError::new_err(e.to_string())
})
}
DbBackend::U64(db) => {
let vec: Vec<u64> = vector.extract()?;
db.insert_with_id(id, &vec, json)
.map_err(|e: crate::error::TriviumError| {
pyo3::exceptions::PyRuntimeError::new_err(e.to_string())
})
}
}
}
#[pyo3(signature = (src, dst, label="related", weight=1.0))]
fn link(&mut self, src: u64, dst: u64, label: &str, weight: f32) -> PyResult<()> {
dispatch!(self, mut db => db.link(src, dst, label, weight)).map_err(
|e: crate::error::TriviumError| {
pyo3::exceptions::PyRuntimeError::new_err(e.to_string())
},
)
}
#[pyo3(signature = (query_vector, top_k=5, expand_depth=0, min_score=0.5, payload_filter=None))]
fn search(
&self,
py: Python<'_>,
query_vector: Bound<'_, PyAny>,
top_k: usize,
expand_depth: usize,
min_score: f32,
payload_filter: Option<&Bound<'_, PyDict>>,
) -> PyResult<Vec<PySearchHit>> {
let rust_filter = match payload_filter {
Some(dict) => Some(dict_to_filter(py, dict)?),
None => None,
};
let config = crate::database::SearchConfig {
top_k,
expand_depth,
min_score,
enable_advanced_pipeline: false,
payload_filter: rust_filter,
..Default::default()
};
let results = match &self.inner {
DbBackend::F32(db) => {
let vec: Vec<f32> = query_vector.extract()?;
db.search_hybrid(None, Some(&vec), &config)
}
DbBackend::F16(db) => {
let vec: Vec<f32> = query_vector.extract()?;
let vec16: Vec<half::f16> = vec.into_iter().map(half::f16::from_f32).collect();
db.search_hybrid(None, Some(&vec16), &config)
}
DbBackend::U64(db) => {
let vec: Vec<u64> = query_vector.extract()?;
db.search_hybrid(None, Some(&vec), &config)
}
}
.map_err(|e: crate::error::TriviumError| {
pyo3::exceptions::PyRuntimeError::new_err(e.to_string())
})?;
Ok(results
.into_iter()
.map(|h| PySearchHit {
id: h.id,
score: h.score,
payload: json_to_pyobject(py, &h.payload),
})
.collect())
}
#[pyo3(signature = (
query_vector,
top_k=5,
expand_depth=2,
min_score=0.1,
teleport_alpha=0.0,
enable_advanced_pipeline=true,
enable_sparse_residual=false,
fista_lambda=0.1,
fista_threshold=0.3,
enable_dpp=false,
dpp_quality_weight=1.0,
enable_refractory_fatigue=false,
enable_text_hybrid_search=false,
text_boost=1.5,
custom_query_text=None,
payload_filter=None,
force_brute_force=false
))]
fn search_advanced(
&self,
py: Python<'_>,
query_vector: Bound<'_, PyAny>,
top_k: usize,
expand_depth: usize,
min_score: f32,
teleport_alpha: f32,
enable_advanced_pipeline: bool,
enable_sparse_residual: bool,
fista_lambda: f32,
fista_threshold: f32,
enable_dpp: bool,
dpp_quality_weight: f32,
enable_refractory_fatigue: bool,
enable_text_hybrid_search: bool,
text_boost: f32,
custom_query_text: Option<String>,
payload_filter: Option<&Bound<'_, PyDict>>,
force_brute_force: bool,
) -> PyResult<Vec<PySearchHit>> {
let rust_filter = match payload_filter {
Some(dict) => Some(dict_to_filter(py, dict)?),
None => None,
};
let config = crate::database::SearchConfig {
top_k,
expand_depth,
min_score,
teleport_alpha,
enable_advanced_pipeline,
enable_sparse_residual,
fista_lambda,
fista_threshold,
enable_dpp,
dpp_quality_weight,
enable_refractory_fatigue,
enable_text_hybrid_search,
text_boost,
force_brute_force,
payload_filter: rust_filter,
..Default::default()
};
let q_text = custom_query_text.as_deref();
let results = match &self.inner {
DbBackend::F32(db) => {
let vec: Vec<f32> = query_vector.extract()?;
db.search_hybrid(q_text, Some(&vec), &config)
}
DbBackend::F16(db) => {
let vec: Vec<f32> = query_vector.extract()?;
let vec16: Vec<half::f16> = vec.into_iter().map(half::f16::from_f32).collect();
db.search_hybrid(q_text, Some(&vec16), &config)
}
DbBackend::U64(db) => {
let vec: Vec<u64> = query_vector.extract()?;
db.search_hybrid(q_text, Some(&vec), &config)
}
}
.map_err(|e: crate::error::TriviumError| {
pyo3::exceptions::PyRuntimeError::new_err(e.to_string())
})?;
Ok(results
.into_iter()
.map(|h| PySearchHit {
id: h.id,
score: h.score,
payload: json_to_pyobject(py, &h.payload),
})
.collect())
}
#[pyo3(signature = (query_vector, query_text, top_k=5, expand_depth=2, min_score=0.1, hybrid_alpha=0.7, payload_filter=None))]
fn search_hybrid(
&self,
py: Python<'_>,
query_vector: Bound<'_, PyAny>,
query_text: &str,
top_k: usize,
expand_depth: usize,
min_score: f32,
hybrid_alpha: f32,
payload_filter: Option<&Bound<'_, PyDict>>,
) -> PyResult<Vec<PySearchHit>> {
let rust_filter = match payload_filter {
Some(dict) => Some(dict_to_filter(py, dict)?),
None => None,
};
let boost = (1.0 - hybrid_alpha).max(0.1) * 3.0;
let config = crate::database::SearchConfig {
top_k,
expand_depth,
min_score,
enable_text_hybrid_search: true,
text_boost: boost,
payload_filter: rust_filter,
..Default::default()
};
let results = match &self.inner {
DbBackend::F32(db) => {
let vec: Vec<f32> = query_vector.extract()?;
db.search_hybrid(Some(query_text), Some(&vec), &config)
}
DbBackend::F16(db) => {
let vec: Vec<f32> = query_vector.extract()?;
let vec16: Vec<half::f16> = vec.into_iter().map(half::f16::from_f32).collect();
db.search_hybrid(Some(query_text), Some(&vec16), &config)
}
DbBackend::U64(db) => {
let vec: Vec<u64> = query_vector.extract()?;
db.search_hybrid(Some(query_text), Some(&vec), &config)
}
}
.map_err(|e: crate::error::TriviumError| {
pyo3::exceptions::PyRuntimeError::new_err(e.to_string())
})?;
Ok(results
.into_iter()
.map(|h| PySearchHit {
id: h.id,
score: h.score,
payload: json_to_pyobject(py, &h.payload),
})
.collect())
}
fn delete(&mut self, id: u64) -> PyResult<()> {
dispatch!(self, mut db => db.delete(id)).map_err(|e: crate::error::TriviumError| {
pyo3::exceptions::PyRuntimeError::new_err(e.to_string())
})
}
fn unlink(&mut self, src: u64, dst: u64) -> PyResult<()> {
dispatch!(self, mut db => db.unlink(src, dst)).map_err(
|e: crate::error::TriviumError| {
pyo3::exceptions::PyRuntimeError::new_err(e.to_string())
},
)
}
fn update_payload(
&mut self,
py: Python<'_>,
id: u64,
payload: &Bound<'_, PyAny>,
) -> PyResult<()> {
let json = pyobject_to_json(py, payload);
dispatch!(self, mut db => db.update_payload(id, json)).map_err(
|e: crate::error::TriviumError| {
pyo3::exceptions::PyRuntimeError::new_err(e.to_string())
},
)
}
/// Replaces the vector of node `id`.
///
/// The declared positional order is `(vector, id)`, which is the reverse of
/// every other id-taking method on this class (`update_payload`,
/// `insert_with_id`) and of `Transaction.update_vector`. To stay compatible
/// with existing callers while also accepting the conventional order, the two
/// positional arguments may be given either way round: whichever one is an
/// integer is taken as the id. Keyword calls (`id=..., vector=...`) are
/// unaffected.
///
/// Raises `TypeError` when neither argument is a valid id or the vector cannot
/// be extracted, and `RuntimeError` when the database rejects the update.
fn update_vector(&mut self, vector: Bound<'_, PyAny>, id: Bound<'_, PyAny>) -> PyResult<()> {
    let (vector, id): (Bound<'_, PyAny>, u64) = match id.extract::<u64>() {
        Ok(node_id) => (vector, node_id),
        Err(id_err) => match vector.extract::<u64>() {
            // Called as `update_vector(id, vector)`: swap the two back.
            Ok(node_id) => (id, node_id),
            Err(_) => return Err(id_err),
        },
    };
    match &mut self.inner {
        DbBackend::F32(db) => {
            let vec: Vec<f32> = vector.extract()?;
            db.update_vector(id, &vec)
                .map_err(|e: crate::error::TriviumError| {
                    pyo3::exceptions::PyRuntimeError::new_err(e.to_string())
                })
        }
        DbBackend::F16(db) => {
            let vec: Vec<f32> = vector.extract()?;
            let vec16: Vec<half::f16> = vec.into_iter().map(half::f16::from_f32).collect();
            db.update_vector(id, &vec16)
                .map_err(|e: crate::error::TriviumError| {
                    pyo3::exceptions::PyRuntimeError::new_err(e.to_string())
                })
        }
        DbBackend::U64(db) => {
            let vec: Vec<u64> = vector.extract()?;
            db.update_vector(id, &vec)
                .map_err(|e: crate::error::TriviumError| {
                    pyo3::exceptions::PyRuntimeError::new_err(e.to_string())
                })
        }
    }
}
fn index_text(&mut self, id: u64, text: &str) -> PyResult<()> {
dispatch!(self, mut db => db.index_text(id, text)).map_err(
|e: crate::error::TriviumError| {
pyo3::exceptions::PyRuntimeError::new_err(e.to_string())
},
)
}
fn index_keyword(&mut self, id: u64, keyword: &str) -> PyResult<()> {
dispatch!(self, mut db => db.index_keyword(id, keyword)).map_err(
|e: crate::error::TriviumError| {
pyo3::exceptions::PyRuntimeError::new_err(e.to_string())
},
)
}
fn build_text_index(&mut self) {
let _ = dispatch!(self, mut db => db.build_text_index());
}
/// Creates an index on the payload field `field`.
fn create_index(&mut self, field: &str) {
    dispatch!(self, mut backend => backend.create_index(field));
}
/// Drops the index on the payload field `field`.
fn drop_index(&mut self, field: &str) {
    dispatch!(self, mut backend => backend.drop_index(field));
}
/// Returns the payload of node `id` as Python objects, or `None` when the
/// database has no payload for that id.
fn get_payload(&self, py: Python<'_>, id: u64) -> Option<PyObject> {
    let stored = dispatch!(self, backend => backend.get_payload(id));
    match stored {
        Some(p) => Some(json_to_pyobject(py, &p)),
        None => None,
    }
}
/// Returns the edges the database reports for node `id` as `Edge` objects.
fn get_edges(&self, id: u64) -> Vec<PyEdge> {
    let mut edges = Vec::new();
    for edge in dispatch!(self, backend => backend.get_edges(id)) {
        edges.push(PyEdge {
            target_id: edge.target_id,
            label: edge.label,
            weight: edge.weight,
        });
    }
    edges
}
fn get(&self, py: Python<'_>, id: u64) -> PyResult<Option<PyNodeView>> {
match &self.inner {
DbBackend::F32(db) => {
if let Some(n) = db.get(id) {
return Ok(Some(PyNodeView {
id: n.id,
vector: n.vector.into_pyobject(py).unwrap().into_any().unbind(),
payload: json_to_pyobject(py, &n.payload),
edges: n
.edges
.iter()
.map(|e| PyEdge {
target_id: e.target_id,
label: e.label.clone(),
weight: e.weight,
})
.collect(),
num_edges: n.edges.len(),
}));
}
}
DbBackend::F16(db) => {
if let Some(n) = db.get(id) {
let f32_vec: Vec<f32> = n.vector.into_iter().map(|f| f.to_f32()).collect();
return Ok(Some(PyNodeView {
id: n.id,
vector: f32_vec.into_pyobject(py).unwrap().into_any().unbind(),
payload: json_to_pyobject(py, &n.payload),
edges: n
.edges
.iter()
.map(|e| PyEdge {
target_id: e.target_id,
label: e.label.clone(),
weight: e.weight,
})
.collect(),
num_edges: n.edges.len(),
}));
}
}
DbBackend::U64(db) => {
if let Some(n) = db.get(id) {
return Ok(Some(PyNodeView {
id: n.id,
vector: n.vector.into_pyobject(py).unwrap().into_any().unbind(),
payload: json_to_pyobject(py, &n.payload),
edges: n
.edges
.iter()
.map(|e| PyEdge {
target_id: e.target_id,
label: e.label.clone(),
weight: e.weight,
})
.collect(),
num_edges: n.edges.len(),
}));
}
}
}
Ok(None)
}
/// Returns the ids the database reports as neighbours of `id` within `depth`
/// (default 1).
#[pyo3(signature = (id, depth=1))]
fn neighbors(&self, id: u64, depth: usize) -> Vec<u64> {
    dispatch!(self, backend => backend.neighbors(id, depth))
}
/// Number of nodes reported by the database.
fn node_count(&self) -> usize {
    dispatch!(self, backend => backend.node_count())
}
fn flush(&mut self) -> PyResult<()> {
dispatch!(self, mut db => db.flush()).map_err(|e: crate::error::TriviumError| {
pyo3::exceptions::PyRuntimeError::new_err(e.to_string())
})
}
/// Vector dimension reported by the database.
fn dim(&self) -> usize {
    dispatch!(self, backend => backend.dim())
}
/// Returns every node id the database reports.
fn all_node_ids(&self) -> Vec<u64> {
    dispatch!(self, backend => backend.all_node_ids())
}
fn migrate(&self, new_path: &str, new_dim: usize) -> PyResult<Vec<u64>> {
match &self.inner {
DbBackend::F32(db) => {
let (_new_db, ids) = db.migrate_to(new_path, new_dim).map_err(
|e: crate::error::TriviumError| {
pyo3::exceptions::PyRuntimeError::new_err(e.to_string())
},
)?;
Ok(ids)
}
DbBackend::F16(db) => {
let (_new_db, ids) = db.migrate_to(new_path, new_dim).map_err(
|e: crate::error::TriviumError| {
pyo3::exceptions::PyRuntimeError::new_err(e.to_string())
},
)?;
Ok(ids)
}
DbBackend::U64(db) => {
let (_new_db, ids) = db.migrate_to(new_path, new_dim).map_err(
|e: crate::error::TriviumError| {
pyo3::exceptions::PyRuntimeError::new_err(e.to_string())
},
)?;
Ok(ids)
}
}
}
/// Enables automatic compaction every `interval_secs` seconds (default 7200).
#[pyo3(signature = (interval_secs=7200))]
fn enable_auto_compaction(&mut self, interval_secs: u64) {
    let every = std::time::Duration::from_secs(interval_secs);
    dispatch!(self, mut backend => backend.enable_auto_compaction(every));
}
/// Disables automatic compaction.
fn disable_auto_compaction(&mut self) {
    dispatch!(self, mut backend => backend.disable_auto_compaction());
}
fn compact(&mut self) -> PyResult<()> {
dispatch!(self, mut db => db.compact()).map_err(|e: crate::error::TriviumError| {
pyo3::exceptions::PyRuntimeError::new_err(e.to_string())
})
}
/// Sets the database memory limit to `mb` mebibytes (default 0).
///
/// The value is converted to bytes before it is passed on. The conversion
/// saturates at `usize::MAX`, so an absurdly large `mb` can no longer wrap
/// around to a tiny limit (release builds) or panic (debug builds).
// NOTE(review): 0 presumably means "no limit" - confirm against
// `Database::set_memory_limit`.
#[pyo3(signature = (mb=0))]
fn set_memory_limit(&mut self, mb: usize) {
    let bytes = mb.saturating_mul(1024 * 1024);
    dispatch!(self, mut db => db.set_memory_limit(bytes));
}
/// Memory estimate reported by the database.
fn estimated_memory(&self) -> usize {
    dispatch!(self, backend => backend.estimated_memory())
}
fn __len__(&self) -> usize {
self.node_count()
}
/// `id in db`: whether the database contains node `id`.
fn __contains__(&self, id: u64) -> bool {
    dispatch!(self, backend => backend.contains(id))
}
/// Debug representation showing dtype, node count and dimension.
fn __repr__(&self) -> String {
    let nodes = self.node_count();
    let dim = self.dim();
    format!(
        "TriviumDB(dtype={}, nodes={}, dim={})",
        self.dtype, nodes, dim
    )
}
/// Context-manager entry: returns the database object itself.
fn __enter__(slf: Py<Self>) -> Py<Self> {
slf
}
/// Context-manager exit: flushes the database.
///
/// Always returns `False`, so an exception raised inside the `with` block is
/// never suppressed. A flush failure is raised as `RuntimeError`.
#[pyo3(signature = (_exc_type=None, _exc_val=None, _exc_tb=None))]
fn __exit__(
    &mut self,
    _exc_type: Option<&Bound<'_, PyAny>>,
    _exc_val: Option<&Bound<'_, PyAny>>,
    _exc_tb: Option<&Bound<'_, PyAny>>,
) -> PyResult<bool> {
    self.flush().map(|()| false)
}
/// Inserts `vectors[i]` with `payloads[i]` for every index and returns the new
/// ids in input order.
///
/// Raises `ValueError` when the two lists differ in length. Rows are inserted
/// one at a time: if row `k` fails, rows before `k` have already been inserted.
fn batch_insert(
    &mut self,
    py: Python<'_>,
    vectors: Bound<'_, PyList>,
    payloads: &Bound<'_, PyList>,
) -> PyResult<Vec<u64>> {
    if vectors.len() != payloads.len() {
        return Err(pyo3::exceptions::PyValueError::new_err(
            "vectors and payloads must have the same length",
        ));
    }
    match &mut self.inner {
        DbBackend::F32(db) => payloads
            .iter()
            .enumerate()
            .map(|(idx, payload_obj)| -> PyResult<u64> {
                let row: Vec<f32> = vectors.get_item(idx)?.extract()?;
                let doc = pyobject_to_json(py, &payload_obj);
                db.insert(&row, doc)
                    .map_err(|e: crate::error::TriviumError| {
                        pyo3::exceptions::PyRuntimeError::new_err(e.to_string())
                    })
            })
            .collect(),
        DbBackend::F16(db) => payloads
            .iter()
            .enumerate()
            .map(|(idx, payload_obj)| -> PyResult<u64> {
                let wide: Vec<f32> = vectors.get_item(idx)?.extract()?;
                let row: Vec<half::f16> = wide.into_iter().map(half::f16::from_f32).collect();
                let doc = pyobject_to_json(py, &payload_obj);
                db.insert(&row, doc)
                    .map_err(|e: crate::error::TriviumError| {
                        pyo3::exceptions::PyRuntimeError::new_err(e.to_string())
                    })
            })
            .collect(),
        DbBackend::U64(db) => payloads
            .iter()
            .enumerate()
            .map(|(idx, payload_obj)| -> PyResult<u64> {
                let row: Vec<u64> = vectors.get_item(idx)?.extract()?;
                let doc = pyobject_to_json(py, &payload_obj);
                db.insert(&row, doc)
                    .map_err(|e: crate::error::TriviumError| {
                        pyo3::exceptions::PyRuntimeError::new_err(e.to_string())
                    })
            })
            .collect(),
    }
}
fn batch_insert_with_ids(
&mut self,
py: Python<'_>,
ids: Vec<u64>,
vectors: Bound<'_, PyList>,
payloads: &Bound<'_, PyList>,
) -> PyResult<()> {
if vectors.len() != payloads.len() || ids.len() != vectors.len() {
return Err(pyo3::exceptions::PyValueError::new_err(
"ids, vectors and payloads must have the same length",
));
}
match &mut self.inner {
DbBackend::F32(db) => {
for (i, payload_obj) in payloads.iter().enumerate() {
let vec_obj = vectors.get_item(i)?;
let vec: Vec<f32> = vec_obj.extract()?;
let json = pyobject_to_json(py, &payload_obj);
db.insert_with_id(ids[i], &vec, json).map_err(
|e: crate::error::TriviumError| {
pyo3::exceptions::PyRuntimeError::new_err(e.to_string())
},
)?;
}
Ok(())
}
DbBackend::F16(db) => {
for (i, payload_obj) in payloads.iter().enumerate() {
let vec_obj = vectors.get_item(i)?;
let vec: Vec<f32> = vec_obj.extract()?;
let vec16: Vec<half::f16> =
vec.into_iter().map(half::f16::from_f32).collect();
let json = pyobject_to_json(py, &payload_obj);
db.insert_with_id(ids[i], &vec16, json).map_err(
|e: crate::error::TriviumError| {
pyo3::exceptions::PyRuntimeError::new_err(e.to_string())
},
)?;
}
Ok(())
}
DbBackend::U64(db) => {
for (i, payload_obj) in payloads.iter().enumerate() {
let vec_obj = vectors.get_item(i)?;
let vec: Vec<u64> = vec_obj.extract()?;
let json = pyobject_to_json(py, &payload_obj);
db.insert_with_id(ids[i], &vec, json).map_err(
|e: crate::error::TriviumError| {
pyo3::exceptions::PyRuntimeError::new_err(e.to_string())
},
)?;
}
Ok(())
}
}
}
fn tql(&self, py: Python<'_>, query: &str) -> PyResult<Vec<PyQueryRow>> {
fn convert_rows<T: crate::VectorType>(
py: Python<'_>,
rows: Vec<std::collections::HashMap<String, crate::node::Node<T>>>,
) -> PyResult<Vec<PyQueryRow>> {
let mut out = Vec::with_capacity(rows.len());
for row in rows {
let py_row = PyDict::new(py);
for (var_name, node) in &row {
let node_dict = PyDict::new(py);
let _ = node_dict.set_item("id", node.id);
let _ = node_dict.set_item("payload", json_to_pyobject(py, &node.payload));
let _ = node_dict.set_item("num_edges", node.edges.len());
let _ = py_row.set_item(var_name, node_dict);
}
out.push(PyQueryRow {
row: py_row.into_any().unbind(),
});
}
Ok(out)
}
match &self.inner {
DbBackend::F32(db) => {
let rows = db.tql(query).map_err(|e: crate::error::TriviumError| {
pyo3::exceptions::PyRuntimeError::new_err(e.to_string())
})?;
convert_rows(py, rows)
}
DbBackend::F16(db) => {
let rows = db.tql(query).map_err(|e: crate::error::TriviumError| {
pyo3::exceptions::PyRuntimeError::new_err(e.to_string())
})?;
convert_rows(py, rows)
}
DbBackend::U64(db) => {
let rows = db.tql(query).map_err(|e: crate::error::TriviumError| {
pyo3::exceptions::PyRuntimeError::new_err(e.to_string())
})?;
convert_rows(py, rows)
}
}
}
/// Runs a mutating TQL statement.
///
/// Returns a dict with the keys `"affected"` and `"created_ids"` taken from
/// the statement result. Raises `RuntimeError` when the statement fails.
fn tql_mut(&mut self, py: Python<'_>, query: &str) -> PyResult<PyObject> {
    let result = dispatch!(self, mut db => db.tql_mut(query)).map_err(
        |e: crate::error::TriviumError| {
            pyo3::exceptions::PyRuntimeError::new_err(e.to_string())
        },
    )?;
    let dict = PyDict::new(py);
    // Insertion errors are propagated rather than returning a partial dict.
    dict.set_item("affected", result.affected)?;
    let created: Vec<u64> = result.created_ids;
    dict.set_item("created_ids", created)?;
    Ok(dict.into_any().unbind())
}
/// Runs Leiden clustering on the graph.
///
/// Returns a dict with:
/// * `"communities"` - list of node-id lists, ordered by ascending cluster id,
///   each list sorted ascending;
/// * `"centroids"` - dict from cluster id to centroid, left empty when
///   `compute_centroids` is false;
/// * `"num_clusters"` - cluster count reported by the algorithm.
///
/// Raises `RuntimeError` when clustering fails.
#[pyo3(signature = (min_community_size=3, max_iterations=15, compute_centroids=true))]
fn leiden_cluster(
    &self,
    py: Python<'_>,
    min_community_size: usize,
    max_iterations: usize,
    compute_centroids: bool,
) -> PyResult<PyObject> {
    let result = dispatch!(self, db => db.leiden_cluster(
        min_community_size,
        Some(max_iterations),
        Some(compute_centroids),
    ))
    .map_err(|e: crate::error::TriviumError| {
        pyo3::exceptions::PyRuntimeError::new_err(e.to_string())
    })?;
    // Invert node -> cluster into cluster -> nodes.
    let mut clusters: std::collections::HashMap<u32, Vec<u64>> =
        std::collections::HashMap::new();
    for (&node_id, &cluster_id) in &result.node_to_cluster {
        clusters.entry(cluster_id).or_default().push(node_id);
    }
    // Sorting keys and members makes the output independent of hash-map order.
    let mut sorted_keys: Vec<u32> = clusters.keys().copied().collect();
    sorted_keys.sort_unstable();
    let communities = PyList::new(
        py,
        sorted_keys.iter().map(|k| {
            let mut ids = clusters.get(k).cloned().unwrap_or_default();
            ids.sort_unstable();
            ids
        }),
    )?;
    let centroids_dict = PyDict::new(py);
    if compute_centroids {
        for &k in &sorted_keys {
            if let Some(centroid) = result.centroids.get(&k) {
                centroids_dict.set_item(k, centroid.clone())?;
            }
        }
    }
    // Insertion errors are propagated rather than returning a partial dict.
    let out = PyDict::new(py);
    out.set_item("communities", communities)?;
    out.set_item("centroids", centroids_dict)?;
    out.set_item("num_clusters", result.num_clusters)?;
    Ok(out.into_any().unbind())
}
fn transaction(slf: Py<Self>, py: Python<'_>) -> PyResult<PyTransaction> {
let dtype = slf.borrow(py).dtype.clone();
let builder = match dtype.as_str() {
"f32" => TxBuilderBackend::F32(crate::database::TxBuilder::new()),
"f16" => TxBuilderBackend::F16(crate::database::TxBuilder::new()),
"u64" => TxBuilderBackend::U64(crate::database::TxBuilder::new()),
_ => {
return Err(pyo3::exceptions::PyValueError::new_err(
format!("不支持的 dtype: {}", dtype),
));
}
};
Ok(PyTransaction {
db: slf,
builder: Some(builder),
finished: false,
})
}
/// Flushes the database. This only calls `flush`; nothing else is released here.
fn close(&mut self) -> PyResult<()> {
self.flush()
}
}
/// Type-erased transaction builder: one variant per supported vector element
/// type, mirroring `DbBackend`.
enum TxBuilderBackend {
/// Builder for an `f32` database.
F32(crate::database::TxBuilder<f32>),
/// Builder for an `f16` database.
F16(crate::database::TxBuilder<half::f16>),
/// Builder for a `u64` database.
U64(crate::database::TxBuilder<u64>),
}
/// Buffered transaction, exposed to Python as `Transaction`.
#[pyclass(name = "Transaction")]
struct PyTransaction {
// Database the transaction will be committed to.
db: Py<PyTriviumDB>,
// Pending operations; taken (set to `None`) on commit or rollback.
builder: Option<TxBuilderBackend>,
// True once the transaction has been committed or rolled back.
finished: bool,
}
// Guard used at the top of every buffering method of `PyTransaction`: returns
// a `RuntimeError` from the enclosing function when the transaction has already
// been committed or rolled back.
macro_rules! check_finished {
    ($tx:expr) => {
        if $tx.finished {
            return Err(pyo3::exceptions::PyRuntimeError::new_err(
                "事务已结束(已提交或已回滚),不能继续添加操作",
            ));
        }
    };
}
#[pymethods]
impl PyTransaction {
/// Buffers an insert of `vector` with `payload`.
///
/// For `f32`/`f16` databases the vector is read as floats and narrowed. For a
/// `u64` database the vector is read as exact integers, so components above
/// 2^53 are preserved; a sequence that is not purely in-range integers falls
/// back to the float-then-cast conversion.
///
/// Raises `RuntimeError` when the transaction is already finished and
/// `TypeError` when the vector cannot be extracted.
fn insert(&mut self, py: Python<'_>, vector: Bound<'_, PyAny>, payload: &Bound<'_, PyAny>) -> PyResult<()> {
    check_finished!(self);
    let json = pyobject_to_json(py, payload);
    match self.builder.as_mut().expect("TxBuilder missing") {
        TxBuilderBackend::F32(b) => {
            let v: Vec<f32> = vector.extract::<Vec<f64>>()?.into_iter().map(|x| x as f32).collect();
            b.insert(&v, json);
        }
        TxBuilderBackend::F16(b) => {
            let v: Vec<half::f16> = vector
                .extract::<Vec<f64>>()?
                .into_iter()
                .map(|x| half::f16::from_f32(x as f32))
                .collect();
            b.insert(&v, json);
        }
        TxBuilderBackend::U64(b) => {
            // Going through f64 rounds every integer above 2^53; extract
            // exact u64 values first and fall back to the float path otherwise.
            let v: Vec<u64> = match vector.extract::<Vec<u64>>() {
                Ok(exact) => exact,
                Err(_) => vector.extract::<Vec<f64>>()?.into_iter().map(|x| x as u64).collect(),
            };
            b.insert(&v, json);
        }
    }
    Ok(())
}
/// Buffers an insert of `vector` with `payload` under a caller-chosen `id`.
///
/// For `f32`/`f16` databases the vector is read as floats and narrowed. For a
/// `u64` database the vector is read as exact integers, so components above
/// 2^53 are preserved; a sequence that is not purely in-range integers falls
/// back to the float-then-cast conversion.
///
/// Raises `RuntimeError` when the transaction is already finished and
/// `TypeError` when the vector cannot be extracted.
fn insert_with_id(&mut self, py: Python<'_>, id: u64, vector: Bound<'_, PyAny>, payload: &Bound<'_, PyAny>) -> PyResult<()> {
    check_finished!(self);
    let json = pyobject_to_json(py, payload);
    match self.builder.as_mut().expect("TxBuilder missing") {
        TxBuilderBackend::F32(b) => {
            let v: Vec<f32> = vector.extract::<Vec<f64>>()?.into_iter().map(|x| x as f32).collect();
            b.insert_with_id(id, &v, json);
        }
        TxBuilderBackend::F16(b) => {
            let v: Vec<half::f16> = vector
                .extract::<Vec<f64>>()?
                .into_iter()
                .map(|x| half::f16::from_f32(x as f32))
                .collect();
            b.insert_with_id(id, &v, json);
        }
        TxBuilderBackend::U64(b) => {
            // Going through f64 rounds every integer above 2^53; extract
            // exact u64 values first and fall back to the float path otherwise.
            let v: Vec<u64> = match vector.extract::<Vec<u64>>() {
                Ok(exact) => exact,
                Err(_) => vector.extract::<Vec<f64>>()?.into_iter().map(|x| x as u64).collect(),
            };
            b.insert_with_id(id, &v, json);
        }
    }
    Ok(())
}
/// Buffers the creation of an edge from `src` to `dst`.
///
/// `label` defaults to `"related"` and `weight` to `1.0`. Raises
/// `RuntimeError` when the transaction is already finished.
#[pyo3(signature = (src, dst, label="related", weight=1.0))]
fn link(&mut self, src: u64, dst: u64, label: &str, weight: f32) -> PyResult<()> {
    check_finished!(self);
    // The builder is only taken on commit/rollback, both of which set `finished`.
    let pending = self.builder.as_mut().expect("TxBuilder missing");
    match pending {
        TxBuilderBackend::F32(tx) => tx.link(src, dst, label, weight),
        TxBuilderBackend::F16(tx) => tx.link(src, dst, label, weight),
        TxBuilderBackend::U64(tx) => tx.link(src, dst, label, weight),
    }
    Ok(())
}
/// Buffers the deletion of node `id`; raises `RuntimeError` when the
/// transaction is already finished.
fn delete(&mut self, id: u64) -> PyResult<()> {
    check_finished!(self);
    // The builder is only taken on commit/rollback, both of which set `finished`.
    let pending = self.builder.as_mut().expect("TxBuilder missing");
    match pending {
        TxBuilderBackend::F32(tx) => tx.delete(id),
        TxBuilderBackend::F16(tx) => tx.delete(id),
        TxBuilderBackend::U64(tx) => tx.delete(id),
    }
    Ok(())
}
/// Buffers the removal of the edge from `src` to `dst`; raises `RuntimeError`
/// when the transaction is already finished.
fn unlink(&mut self, src: u64, dst: u64) -> PyResult<()> {
    check_finished!(self);
    // The builder is only taken on commit/rollback, both of which set `finished`.
    let pending = self.builder.as_mut().expect("TxBuilder missing");
    match pending {
        TxBuilderBackend::F32(tx) => tx.unlink(src, dst),
        TxBuilderBackend::F16(tx) => tx.unlink(src, dst),
        TxBuilderBackend::U64(tx) => tx.unlink(src, dst),
    }
    Ok(())
}
/// Buffers a payload replacement for node `id` (payload converted to JSON);
/// raises `RuntimeError` when the transaction is already finished.
fn update_payload(&mut self, py: Python<'_>, id: u64, payload: &Bound<'_, PyAny>) -> PyResult<()> {
    check_finished!(self);
    let doc = pyobject_to_json(py, payload);
    // The builder is only taken on commit/rollback, both of which set `finished`.
    let pending = self.builder.as_mut().expect("TxBuilder missing");
    match pending {
        TxBuilderBackend::F32(tx) => tx.update_payload(id, doc),
        TxBuilderBackend::F16(tx) => tx.update_payload(id, doc),
        TxBuilderBackend::U64(tx) => tx.update_payload(id, doc),
    }
    Ok(())
}
fn update_vector(&mut self, id: u64, vector: Vec<f64>) -> PyResult<()> {
check_finished!(self);
match self.builder.as_mut().expect("TxBuilder missing") {
TxBuilderBackend::F32(b) => {
let v: Vec<f32> = vector.iter().map(|&x| x as f32).collect();
b.update_vector(id, &v);
}
TxBuilderBackend::F16(b) => {
let v: Vec<half::f16> = vector.iter().map(|&x| half::f16::from_f32(x as f32)).collect();
b.update_vector(id, &v);
}
TxBuilderBackend::U64(b) => {
let v: Vec<u64> = vector.iter().map(|&x| x as u64).collect();
b.update_vector(id, &v);
}
}
Ok(())
}
/// Number of buffered operations; 0 once the builder has been taken by a
/// commit or rollback.
fn pending_count(&self) -> usize {
    self.builder.as_ref().map_or(0, |pending| match pending {
        TxBuilderBackend::F32(tx) => tx.pending_count(),
        TxBuilderBackend::F16(tx) => tx.pending_count(),
        TxBuilderBackend::U64(tx) => tx.pending_count(),
    })
}
/// Applies all buffered operations to the database and returns the ids
/// reported by `commit_tx`.
///
/// Raises `RuntimeError` when the transaction is already finished, when the
/// database object is currently borrowed elsewhere, or when the commit fails.
/// A borrow conflict leaves the transaction open so the commit can be retried;
/// after a failed `commit_tx` the transaction is finished.
fn commit(&mut self, py: Python<'_>) -> PyResult<Vec<u64>> {
    if self.finished {
        return Err(pyo3::exceptions::PyRuntimeError::new_err(
            "事务已结束(已提交或已回滚),不能重复提交",
        ));
    }
    // Acquire the database before consuming the builder: `try_borrow_mut`
    // turns a borrow conflict into a Python exception instead of a panic, and
    // failing here must not discard the buffered operations.
    let mut db_ref = self.db.try_borrow_mut(py)?;
    self.finished = true;
    let builder = self.builder.take().expect("TxBuilder missing");
    match (&mut db_ref.inner, builder) {
        (DbBackend::F32(db), TxBuilderBackend::F32(b)) => {
            db.commit_tx(b).map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
        }
        (DbBackend::F16(db), TxBuilderBackend::F16(b)) => {
            db.commit_tx(b).map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
        }
        (DbBackend::U64(db), TxBuilderBackend::U64(b)) => {
            db.commit_tx(b).map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
        }
        _ => Err(pyo3::exceptions::PyRuntimeError::new_err("dtype 不匹配")),
    }
}
/// Discards every staged operation and marks the transaction as finished.
///
/// # Errors
/// Raises `RuntimeError` if the transaction was already committed or
/// rolled back.
fn rollback(&mut self) -> PyResult<()> {
    if !self.finished {
        self.finished = true;
        // Dropping the builder throws away everything that was staged.
        self.builder = None;
        Ok(())
    } else {
        Err(pyo3::exceptions::PyRuntimeError::new_err(
            "事务已结束(已提交或已回滚),不能重复回滚",
        ))
    }
}
/// Context-manager entry: returns the transaction object itself, so that
/// `with ... as tx:` binds `tx` to this same transaction.
fn __enter__(slf: Py<Self>) -> Py<Self> {
slf
}
/// Context-manager exit.
///
/// * Already finished (explicit commit/rollback inside the block): no-op.
/// * An exception is propagating (`exc_type` is set): staged operations are
///   discarded.
/// * Otherwise the transaction is committed; a commit failure is raised.
///
/// Always returns `false`, so an exception from the `with` body is never
/// suppressed.
#[pyo3(signature = (exc_type=None, _exc_val=None, _exc_tb=None))]
fn __exit__(
    &mut self,
    py: Python<'_>,
    exc_type: Option<&Bound<'_, PyAny>>,
    _exc_val: Option<&Bound<'_, PyAny>>,
    _exc_tb: Option<&Bound<'_, PyAny>>,
) -> PyResult<bool> {
    if !self.finished {
        match exc_type {
            Some(_) => {
                // Body raised: roll back by dropping the builder.
                self.finished = true;
                self.builder = None;
            }
            None => {
                self.commit(py)?;
            }
        }
    }
    Ok(false)
}
/// Debug representation showing the pending-operation count and whether the
/// transaction has finished.
fn __repr__(&self) -> String {
    let pending = self.pending_count();
    let finished = self.finished;
    format!("Transaction(pending={}, finished={})", pending, finished)
}
}
/// Adapter that lets a Python object be used as a Rust search hook.
///
/// `py_hook` is the Python object whose hook methods are looked up by name
/// at call time.
///
/// The type is `Send + Sync` through the auto traits: its only field is a
/// `PyObject` (`Py<PyAny>`), which PyO3 already marks `Send + Sync`. No
/// manual `unsafe impl` is needed, and omitting it means a future field that
/// is not thread-safe is rejected by the compiler instead of being silently
/// accepted.
struct PySearchHookWrapper {
    py_hook: PyObject,
}

// Compile-time guard: hooks are expected to be shareable across threads.
const _: fn() = || {
    fn assert_send_sync<T: Send + Sync>() {}
    assert_send_sync::<PySearchHookWrapper>();
};
impl PySearchHookWrapper {
fn hits_to_py(py: Python<'_>, hits: &[crate::node::SearchHit]) -> PyObject {
let list = pyo3::types::PyList::new(
py,
hits.iter().map(|h| {
let d = PyDict::new(py);
let _ = d.set_item("id", h.id);
let _ = d.set_item("score", h.score);
let _ = d.set_item("payload", json_to_pyobject(py, &h.payload));
d
}),
).expect("创建 Python list 失败");
list.into_any().unbind()
}
fn py_to_hits(py: Python<'_>, obj: &PyObject) -> Vec<crate::node::SearchHit> {
let mut hits = Vec::new();
if let Ok(list) = obj.bind(py).downcast::<pyo3::types::PyList>() {
for item in list.iter() {
if let Ok(dict) = item.downcast::<PyDict>() {
let id = dict.get_item("id").ok().flatten()
.and_then(|v| v.extract::<u64>().ok()).unwrap_or(0);
let score = dict.get_item("score").ok().flatten()
.and_then(|v| v.extract::<f32>().ok()).unwrap_or(0.0);
let payload = dict.get_item("payload").ok().flatten()
.map(|v| pyobject_to_json(py, &v))
.unwrap_or(serde_json::Value::Null);
hits.push(crate::node::SearchHit { id, score, payload });
}
}
}
hits
}
}
impl crate::hook::SearchHook for PySearchHookWrapper {
fn on_pre_search(
&self,
query_vector: &mut Vec<f32>,
_config: &mut crate::database::SearchConfig,
ctx: &mut crate::hook::HookContext,
) {
pyo3::Python::with_gil(|py| {
let hook = self.py_hook.bind(py);
if let Ok(method) = hook.getattr("on_pre_search") {
if let Ok(py_vec) = pyo3::types::PyList::new(py, query_vector.iter()) {
let py_ctx = PyDict::new(py);
let _ = py_ctx.set_item("custom_data", json_to_pyobject(py, &ctx.custom_data));
let _ = py_ctx.set_item("abort", ctx.abort);
if let Ok(result) = method.call1((&py_vec, &py_ctx)) {
if let Ok(new_vec) = result.extract::<Vec<f32>>() {
*query_vector = new_vec;
}
if let Ok(Some(abort_val)) = py_ctx.get_item("abort") {
if let Ok(ab) = abort_val.extract::<bool>() {
ctx.abort = ab;
}
}
}
}
}
});
}
fn on_post_recall(&self, hits: &mut Vec<crate::node::SearchHit>, ctx: &mut crate::hook::HookContext) {
pyo3::Python::with_gil(|py| {
let hook = self.py_hook.bind(py);
if let Ok(method) = hook.getattr("on_post_recall") {
let py_hits = Self::hits_to_py(py, hits);
let py_ctx = PyDict::new(py);
let _ = py_ctx.set_item("custom_data", json_to_pyobject(py, &ctx.custom_data));
if let Ok(result) = method.call1((&py_hits, &py_ctx)) {
if !result.is_none() {
let obj = result.unbind();
*hits = Self::py_to_hits(py, &obj);
}
}
}
});
}
fn on_rerank(
&self,
hits: &mut Vec<crate::node::SearchHit>,
ctx: &mut crate::hook::HookContext,
) -> Option<Vec<crate::node::SearchHit>> {
pyo3::Python::with_gil(|py| {
let hook = self.py_hook.bind(py);
if let Ok(method) = hook.getattr("on_rerank") {
let py_hits = Self::hits_to_py(py, hits);
let py_ctx = PyDict::new(py);
let _ = py_ctx.set_item("custom_data", json_to_pyobject(py, &ctx.custom_data));
if let Ok(result) = method.call1((&py_hits, &py_ctx)) {
if !result.is_none() {
let obj = result.unbind();
return Some(Self::py_to_hits(py, &obj));
}
}
}
None
})
}
fn on_post_search(&self, hits: &mut Vec<crate::node::SearchHit>, ctx: &mut crate::hook::HookContext) {
pyo3::Python::with_gil(|py| {
let hook = self.py_hook.bind(py);
if let Ok(method) = hook.getattr("on_post_search") {
let py_hits = Self::hits_to_py(py, hits);
let py_ctx = PyDict::new(py);
let _ = py_ctx.set_item("custom_data", json_to_pyobject(py, &ctx.custom_data));
if let Ok(result) = method.call1((&py_hits, &py_ctx)) {
if !result.is_none() {
let obj = result.unbind();
*hits = Self::py_to_hits(py, &obj);
}
}
}
});
}
}
/// Installs a formatting `tracing` subscriber whose filter is read from the
/// default environment configuration with an `INFO` directive added.
///
/// The result of `try_init` is discarded, so a failure to install the
/// subscriber is silent.
///
/// NOTE(review): confirm whether the added `INFO` directive is meant to
/// override a global level supplied through the environment.
#[pyfunction]
pub fn init_logger() {
    use tracing_subscriber::{fmt, EnvFilter};

    let filter = EnvFilter::from_default_env().add_directive(tracing::Level::INFO.into());
    let _ = fmt().with_env_filter(filter).try_init();
}
#[pymodule]
pub fn triviumdb(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyTriviumDB>()?;
m.add_class::<PySearchHit>()?;
m.add_class::<PyNodeView>()?;
m.add_class::<PyQueryRow>()?;
m.add_class::<PyHookContext>()?;
m.add_class::<PyTransaction>()?;
m.add_function(wrap_pyfunction!(init_logger, m)?)?;
Ok(())
}
}