use std::collections::HashSet;
use std::sync::Arc;
use manifoldb_core::{Entity, EntityId, ScoredEntity, VectorData};
use manifoldb_graph::traversal::{Direction, PathPattern, PathStep};
use manifoldb_storage::backends::RedbEngine;
use crate::collection::CollectionHandle;
use crate::collection::Vector as CollectionVector;
use crate::error::Result;
use crate::filter::Filter;
use crate::Error;
/// Restricts an entity search to entities reached from a start entity via a path pattern.
///
/// During search execution the pattern is evaluated from `start`, and only entities
/// that are targets of a match are kept in the results.
#[derive(Debug, Clone)]
pub struct TraversalConstraint {
    /// Entity the traversal begins from.
    start: EntityId,
    /// Edge pattern followed from `start`.
    pattern: PathPattern,
}

impl TraversalConstraint {
    /// Creates a constraint that follows `pattern` starting at `start`.
    #[must_use]
    pub fn new(start: EntityId, pattern: PathPattern) -> Self {
        Self { start, pattern }
    }

    /// Returns the entity the traversal starts from.
    #[must_use]
    pub fn start(&self) -> EntityId {
        self.start
    }

    /// Returns the path pattern followed from the start entity.
    #[must_use]
    pub fn pattern(&self) -> &PathPattern {
        &self.pattern
    }
}
/// Fluent builder for the [`PathPattern`] used by traversal-constrained searches.
///
/// Every `edge_*` / `any_*` call appends one step; [`build`](Self::build) yields the pattern.
#[derive(Debug, Clone, Default)]
pub struct TraversalPatternBuilder {
    pattern: PathPattern,
}

impl TraversalPatternBuilder {
    /// Starts a builder with an empty pattern.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends a step over outgoing edges of `edge_type`.
    #[must_use]
    pub fn edge_out(self, edge_type: impl Into<manifoldb_core::EdgeType>) -> Self {
        self.step(PathStep::outgoing(edge_type))
    }

    /// Appends a step over incoming edges of `edge_type`.
    #[must_use]
    pub fn edge_in(self, edge_type: impl Into<manifoldb_core::EdgeType>) -> Self {
        self.step(PathStep::incoming(edge_type))
    }

    /// Appends a step over edges of `edge_type` in either direction.
    #[must_use]
    pub fn edge_both(self, edge_type: impl Into<manifoldb_core::EdgeType>) -> Self {
        self.step(PathStep::both(edge_type))
    }

    /// Appends a step over outgoing edges of any type.
    #[must_use]
    pub fn any_out(self) -> Self {
        self.step(PathStep::any(Direction::Outgoing))
    }

    /// Appends a step over incoming edges of any type.
    #[must_use]
    pub fn any_in(self) -> Self {
        self.step(PathStep::any(Direction::Incoming))
    }

    /// Appends a step over edges of any type in either direction.
    #[must_use]
    pub fn any_both(self) -> Self {
        self.step(PathStep::any(Direction::Both))
    }

    /// Makes the most recently added step repeat between `min` and `max` times.
    ///
    /// The last step is rebuilt from its direction and filter only, then widened with
    /// `PathStep::variable_length(min, max)`. Calling this on an empty pattern is a no-op.
    #[must_use]
    pub fn variable_length(mut self, min: usize, max: usize) -> Self {
        let steps = self.pattern.steps();
        let count = steps.len();
        if count == 0 {
            return self;
        }

        let last = steps[count - 1].clone();
        let widened = PathStep::new(last.direction, last.filter.clone()).variable_length(min, max);

        // Keep every earlier step verbatim, then append the widened replacement.
        let rebuilt = steps
            .iter()
            .take(count - 1)
            .cloned()
            .fold(PathPattern::new(), |acc, kept| acc.add_step(kept))
            .add_step(widened);

        self.pattern = rebuilt;
        self
    }

    /// Appends an arbitrary, pre-built step.
    #[must_use]
    pub fn step(mut self, step: PathStep) -> Self {
        self.pattern = self.pattern.add_step(step);
        self
    }

    /// Consumes the builder and returns the assembled pattern.
    #[must_use]
    pub fn build(self) -> PathPattern {
        self.pattern
    }
}
/// Builder for a vector similarity search over a collection, optionally restricted to
/// entities reachable through a graph traversal.
///
/// Created with defaults of `limit = 10`, `offset = 0`, and no filter, score threshold,
/// or traversal constraint. A query vector must be supplied before [`execute`](Self::execute).
pub struct EntitySearchBuilder {
    handle: CollectionHandle<Arc<RedbEngine>>,
    engine: Arc<RedbEngine>,
    vector_name: String,
    query: Option<VectorData>,
    limit: usize,
    offset: usize,
    filter: Option<Filter>,
    score_threshold: Option<f32>,
    traversal_constraint: Option<TraversalConstraint>,
}

impl EntitySearchBuilder {
    /// Creates a builder that searches the vector named `vector_name` via `handle`.
    ///
    /// `engine` is used to open a read transaction when a traversal constraint is set.
    pub(crate) fn new(
        handle: CollectionHandle<Arc<RedbEngine>>,
        engine: Arc<RedbEngine>,
        vector_name: impl Into<String>,
    ) -> Self {
        Self {
            handle,
            engine,
            vector_name: vector_name.into(),
            query: None,
            limit: 10,
            offset: 0,
            filter: None,
            score_threshold: None,
            traversal_constraint: None,
        }
    }

    /// Sets the query vector (required).
    #[must_use]
    pub fn query(mut self, vector: impl Into<VectorData>) -> Self {
        self.query = Some(vector.into());
        self
    }

    /// Sets the maximum number of results returned.
    #[must_use]
    pub fn limit(mut self, limit: usize) -> Self {
        self.limit = limit;
        self
    }

    /// Sets how many leading results to skip (for pagination).
    ///
    /// With a traversal constraint, the offset counts only reachable results.
    #[must_use]
    pub fn offset(mut self, offset: usize) -> Self {
        self.offset = offset;
        self
    }

    /// Sets a payload filter that is forwarded to the collection search.
    #[must_use]
    pub fn filter(mut self, filter: Filter) -> Self {
        self.filter = Some(filter);
        self
    }

    /// Sets a score threshold that is forwarded to the collection search.
    #[must_use]
    pub fn score_threshold(mut self, threshold: f32) -> Self {
        self.score_threshold = Some(threshold);
        self
    }

    /// Restricts results to entities reached from `start` via a pattern built by
    /// `pattern_builder`.
    #[must_use]
    pub fn within_traversal<F>(mut self, start: EntityId, pattern_builder: F) -> Self
    where
        F: FnOnce(TraversalPatternBuilder) -> TraversalPatternBuilder,
    {
        let builder = TraversalPatternBuilder::new();
        let pattern = pattern_builder(builder).build();
        self.traversal_constraint = Some(TraversalConstraint::new(start, pattern));
        self
    }

    /// Restricts results using an already-built traversal constraint.
    #[must_use]
    pub fn with_traversal_constraint(mut self, constraint: TraversalConstraint) -> Self {
        self.traversal_constraint = Some(constraint);
        self
    }

    /// Runs the search and returns the matching scored entities.
    ///
    /// With a traversal constraint, the pattern is evaluated first (in its own read
    /// transaction) to collect the reachable entity IDs. The vector search is then
    /// over-fetched from rank 0 and post-filtered to that set. `offset` and `limit` are
    /// applied to the *filtered* hits, so pages are consistent. Because of the bounded
    /// over-fetch, fewer than `limit` results may be returned even if more reachable
    /// matches exist further down the ranking.
    ///
    /// # Errors
    ///
    /// - [`Error::InvalidInput`] if no query vector was provided.
    /// - [`Error::Execution`] if the read transaction or graph traversal fails.
    /// - [`Error::Collection`] if the collection search fails.
    pub fn execute(self) -> Result<Vec<ScoredEntity>> {
        use manifoldb_storage::StorageEngine;

        // Candidate over-fetch when post-filtering by traversal: fetch
        // `(offset + limit) * FACTOR` hits, but never fewer than `MIN`.
        const TRAVERSAL_OVERFETCH_FACTOR: usize = 10;
        const MIN_TRAVERSAL_FETCH: usize = 100;

        let query =
            self.query.ok_or_else(|| Error::InvalidInput("No query vector provided".into()))?;

        // Resolve the traversal up front; the read transaction is dropped at the end of
        // this branch, before the vector search runs.
        let reachable_ids: Option<HashSet<EntityId>> =
            if let Some(constraint) = &self.traversal_constraint {
                let tx = self.engine.begin_read().map_err(|e| {
                    Error::Execution(format!("Failed to start read transaction: {e}"))
                })?;
                let matches = constraint
                    .pattern()
                    .find_from(&tx, constraint.start())
                    .map_err(|e| Error::Execution(format!("Graph traversal failed: {e}")))?;
                let ids: HashSet<EntityId> = matches.iter().map(|m| m.target()).collect();
                Some(ids)
            } else {
                None
            };

        let collection_query = vector_data_to_collection_vector(&query);
        let collection_filter = self.filter.map(filter_to_collection_filter);

        // Without a traversal, the collection applies offset/limit itself. With one,
        // offset/limit must count only reachable hits, so fetch a padded window starting
        // at rank 0 and paginate after filtering. Saturating math avoids usize overflow.
        let (fetch_limit, fetch_offset) = if reachable_ids.is_some() {
            let window = self.offset.saturating_add(self.limit);
            let padded =
                window.saturating_mul(TRAVERSAL_OVERFETCH_FACTOR).max(MIN_TRAVERSAL_FETCH);
            (padded, 0)
        } else {
            (self.limit, self.offset)
        };

        let scored_points = self
            .handle
            .execute_search(
                &self.vector_name,
                collection_query,
                fetch_limit,
                fetch_offset,
                collection_filter,
                // NOTE(review): these flags presumably select payload (needed below) and
                // vector inclusion — confirm against `CollectionHandle::execute_search`.
                true,
                false,
                self.score_threshold,
                None,
            )
            .map_err(|e| Error::Collection(e.to_string()))?;

        let results: Vec<ScoredEntity> = match reachable_ids {
            Some(allowed) => scored_points
                .into_iter()
                .filter(|sp| allowed.contains(&EntityId::new(sp.id.as_u64())))
                .skip(self.offset)
                .take(self.limit)
                .map(|sp| ScoredEntity::new(scored_point_to_entity(sp.id, sp.payload), sp.score))
                .collect(),
            None => scored_points
                .into_iter()
                .map(|sp| ScoredEntity::new(scored_point_to_entity(sp.id, sp.payload), sp.score))
                .collect(),
        };
        Ok(results)
    }
}
/// Converts a core [`VectorData`] into the collection layer's vector type.
///
/// The variant (dense, sparse, multi) is preserved and the underlying data is cloned.
fn vector_data_to_collection_vector(data: &VectorData) -> CollectionVector {
    match *data {
        VectorData::Dense(ref dense) => CollectionVector::Dense(dense.clone()),
        VectorData::Sparse(ref sparse) => CollectionVector::Sparse(sparse.clone()),
        VectorData::Multi(ref multi) => CollectionVector::Multi(multi.clone()),
    }
}
fn filter_to_collection_filter(filter: Filter) -> crate::collection::Filter {
match filter {
Filter::Eq { field, value } => crate::collection::Filter::Eq { field, value },
Filter::Ne { field, value } => crate::collection::Filter::Ne { field, value },
Filter::Gt { field, value } => crate::collection::Filter::Gt { field, value },
Filter::Gte { field, value } => crate::collection::Filter::Gte { field, value },
Filter::Lt { field, value } => crate::collection::Filter::Lt { field, value },
Filter::Lte { field, value } => crate::collection::Filter::Lte { field, value },
Filter::Range { field, min, max } => crate::collection::Filter::Range { field, min, max },
Filter::In { field, values } => crate::collection::Filter::In { field, values },
Filter::NotIn { field, values } => crate::collection::Filter::NotIn { field, values },
Filter::Contains { field, substring } => {
crate::collection::Filter::Contains { field, substring }
}
Filter::StartsWith { field, prefix } => {
crate::collection::Filter::StartsWith { field, prefix }
}
Filter::ArrayContains { field, value } => {
crate::collection::Filter::ArrayContains { field, value }
}
Filter::Exists { field } => crate::collection::Filter::Exists { field },
Filter::NotExists { field } => crate::collection::Filter::NotExists { field },
Filter::And(filters) => crate::collection::Filter::And(
filters.into_iter().map(filter_to_collection_filter).collect(),
),
Filter::Or(filters) => crate::collection::Filter::Or(
filters.into_iter().map(filter_to_collection_filter).collect(),
),
Filter::Not(filter) => {
crate::collection::Filter::Not(Box::new(filter_to_collection_filter(*filter)))
}
}
}
/// Rebuilds an [`Entity`] from a search hit's point ID and JSON payload.
///
/// The point ID becomes the entity ID. When the payload is a JSON object:
/// - the reserved `_labels` key, if it holds an array of strings, is restored as the
///   entity's labels. This mirrors how `entity_properties_to_json` stores labels.
/// - every other key becomes a property via `json_to_value`. Values it cannot convert
///   (e.g. nested JSON objects) are skipped.
///
/// A missing or non-object payload yields an entity with no labels or properties.
fn scored_point_to_entity(
    id: manifoldb_core::PointId,
    payload: Option<serde_json::Value>,
) -> Entity {
    const LABELS_KEY: &str = "_labels";

    let mut entity = Entity::new(EntityId::new(id.as_u64()));
    if let Some(serde_json::Value::Object(map)) = payload {
        for (key, value) in map {
            // Labels are flattened into the payload under a reserved key on write;
            // restore them as labels instead of leaking a `_labels` property.
            if key == LABELS_KEY {
                if let serde_json::Value::Array(items) = &value {
                    let labels: Option<Vec<&str>> =
                        items.iter().map(serde_json::Value::as_str).collect();
                    if let Some(labels) = labels {
                        for label in labels {
                            entity = entity.with_label(label);
                        }
                        continue;
                    }
                }
            }
            // Anything else (including a non-string-array `_labels`) is a plain property.
            if let Some(prop_value) = json_to_value(&value) {
                entity = entity.with_property(key, prop_value);
            }
        }
    }
    entity
}
fn json_to_value(json: &serde_json::Value) -> Option<manifoldb_core::Value> {
match json {
serde_json::Value::Null => Some(manifoldb_core::Value::Null),
serde_json::Value::Bool(b) => Some(manifoldb_core::Value::Bool(*b)),
serde_json::Value::Number(n) => n
.as_i64()
.map(manifoldb_core::Value::Int)
.or_else(|| n.as_f64().map(manifoldb_core::Value::Float)),
serde_json::Value::String(s) => Some(manifoldb_core::Value::String(s.clone())),
serde_json::Value::Array(arr) => {
let floats: Option<Vec<f32>> =
arr.iter().map(|v| v.as_f64().map(|f| f as f32)).collect();
if let Some(vec) = floats {
Some(manifoldb_core::Value::Vector(vec))
} else {
let values: Option<Vec<manifoldb_core::Value>> =
arr.iter().map(json_to_value).collect();
values.map(manifoldb_core::Value::Array)
}
}
serde_json::Value::Object(_) => {
None
}
}
}
pub fn entity_to_point_struct(
entity: &Entity,
collection_name: &str,
) -> crate::collection::PointStruct {
use crate::collection::PointStruct;
let mut point = PointStruct::new(entity.id.as_u64());
let payload = entity_properties_to_json(entity);
if !payload.as_object().map_or(true, |o| o.is_empty()) {
point = point.with_payload(payload);
}
for (name, vector_data) in &entity.vectors {
let collection_vec = match vector_data {
VectorData::Dense(v) => CollectionVector::Dense(v.clone()),
VectorData::Sparse(v) => CollectionVector::Sparse(v.clone()),
VectorData::Multi(v) => CollectionVector::Multi(v.clone()),
};
point = point.with_vector(name.clone(), collection_vec);
}
let _ = collection_name;
point
}
/// Flattens an entity's labels and properties into a JSON object.
///
/// Labels, when present, are stored as an array of strings under the `_labels` key.
/// Each property is then inserted under its own name via `value_to_json`, so a property
/// literally named `_labels` would overwrite the labels entry.
fn entity_properties_to_json(entity: &Entity) -> serde_json::Value {
    let mut object = serde_json::Map::new();

    if !entity.labels.is_empty() {
        let label_names = entity
            .labels
            .iter()
            .map(|label| serde_json::Value::String(label.as_str().to_string()))
            .collect();
        object.insert("_labels".to_string(), serde_json::Value::Array(label_names));
    }

    object.extend(
        entity.properties.iter().map(|(name, value)| (name.clone(), value_to_json(value))),
    );

    serde_json::Value::Object(object)
}
fn value_to_json(value: &manifoldb_core::Value) -> serde_json::Value {
match value {
manifoldb_core::Value::Null => serde_json::Value::Null,
manifoldb_core::Value::Bool(b) => serde_json::Value::Bool(*b),
manifoldb_core::Value::Int(i) => serde_json::json!(*i),
manifoldb_core::Value::Float(f) => serde_json::Number::from_f64(*f)
.map_or(serde_json::Value::Null, serde_json::Value::Number),
manifoldb_core::Value::String(s) => serde_json::Value::String(s.clone()),
manifoldb_core::Value::Bytes(b) => {
use base64::Engine;
serde_json::Value::String(base64::engine::general_purpose::STANDARD.encode(b))
}
manifoldb_core::Value::Vector(v) => {
serde_json::Value::Array(v.iter().map(|f| serde_json::json!(*f)).collect())
}
manifoldb_core::Value::SparseVector(pairs) => serde_json::Value::Array(
pairs.iter().map(|(idx, val)| serde_json::json!([*idx, *val])).collect(),
),
manifoldb_core::Value::MultiVector(vecs) => serde_json::Value::Array(
vecs.iter()
.map(|v| {
serde_json::Value::Array(v.iter().map(|f| serde_json::json!(*f)).collect())
})
.collect(),
),
manifoldb_core::Value::Array(arr) => {
serde_json::Value::Array(arr.iter().map(value_to_json).collect())
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_json_to_value_primitives() {
        assert_eq!(json_to_value(&serde_json::json!(null)), Some(manifoldb_core::Value::Null));
        assert_eq!(
            json_to_value(&serde_json::json!(true)),
            Some(manifoldb_core::Value::Bool(true))
        );
        assert_eq!(json_to_value(&serde_json::json!(42)), Some(manifoldb_core::Value::Int(42)));
        // 2.5 is exactly representable; `3.14` trips clippy's deny-by-default
        // `approx_constant` lint (it approximates PI).
        assert_eq!(
            json_to_value(&serde_json::json!(2.5)),
            Some(manifoldb_core::Value::Float(2.5))
        );
        assert_eq!(
            json_to_value(&serde_json::json!("hello")),
            Some(manifoldb_core::Value::String("hello".to_string()))
        );
    }

    #[test]
    fn test_json_to_value_arrays() {
        // All-numeric arrays decode as f32 vectors.
        assert_eq!(
            json_to_value(&serde_json::json!([1.0, 2.5])),
            Some(manifoldb_core::Value::Vector(vec![1.0, 2.5]))
        );
        // Any non-numeric element falls back to a generic array.
        assert_eq!(
            json_to_value(&serde_json::json!(["a", "b"])),
            Some(manifoldb_core::Value::Array(vec![
                manifoldb_core::Value::String("a".to_string()),
                manifoldb_core::Value::String("b".to_string()),
            ]))
        );
    }

    #[test]
    fn test_json_to_value_object_is_unsupported() {
        assert_eq!(json_to_value(&serde_json::json!({ "k": 1 })), None);
    }

    #[test]
    fn test_value_to_json_roundtrip() {
        let original = manifoldb_core::Value::String("test".to_string());
        let json = value_to_json(&original);
        let recovered = json_to_value(&json);
        assert_eq!(recovered, Some(original));

        let float = manifoldb_core::Value::Float(2.5);
        assert_eq!(json_to_value(&value_to_json(&float)), Some(float));
    }

    #[test]
    fn test_value_to_json_non_finite_float_is_null() {
        assert_eq!(
            value_to_json(&manifoldb_core::Value::Float(f64::NAN)),
            serde_json::Value::Null
        );
    }

    #[test]
    fn test_entity_properties_to_json() {
        let entity = Entity::new(EntityId::new(1))
            .with_label("Test")
            .with_property("name", "Alice")
            .with_property("age", 30i64);
        let json = entity_properties_to_json(&entity);
        let obj = json.as_object().expect("should be object");
        assert!(obj.contains_key("_labels"));
        assert!(obj.contains_key("name"));
        assert!(obj.contains_key("age"));
    }
}