use std::sync::Arc;
use crate::storage::query::unified::ExecutionError;
use super::super::super::entity::{EntityId, UnifiedEntity};
use super::super::super::store::UnifiedStore;
use super::super::execution::execute_three_way_join;
use super::super::filters::{Filter, FilterAcceptor, WhereClause};
use super::super::types::QueryResult;
use super::graph::TraversalDirection;
/// How a three-way join obtains its initial candidate set.
///
/// Set by the `ThreeWayJoinBuilder::start_*` methods; each call replaces any
/// previously configured start phase.
#[derive(Debug, Clone, PartialEq)]
pub enum JoinPhase {
    /// Start from a vector query (`start_vector`): the query vector is stored
    /// by value together with `k` (presumably the number of nearest results to
    /// seed with — confirm against the executor).
    VectorStart { vector: Vec<f32>, k: usize },
    /// Start from a single graph node identified by `node_id` (`start_node`).
    GraphStart { node_id: String },
    /// Start from the table named `table` (`start_table`).
    TableStart { table: String },
}
/// A single stage of the join pipeline that runs after the start phase.
///
/// `ThreeWayJoinBuilder` appends these in the order its methods are called.
#[derive(Clone, Debug)]
pub enum JoinStep {
    /// A graph traversal step.
    Traverse {
        /// Edge label to follow. `traverse_any` stores `None` here
        /// (presumably meaning "any label" — confirm against the executor).
        edge_label: Option<String>,
        /// Traversal depth for this step.
        depth: u32,
        /// Edge direction to follow.
        direction: TraversalDirection,
    },
    /// A join against the table named `table`.
    JoinTable {
        /// Name of the table to join.
        table: String,
        /// Join field; `join_table` leaves it `None`, `join_table_on` sets it.
        on_field: Option<String>,
    },
    /// A similarity-expansion step parameterised by `k`
    /// (presumably the number of similar items to add — confirm).
    VectorExpand { k: usize },
}
/// Fluent builder that describes a join across vector, graph and table data.
///
/// Configure a start phase, append pipeline steps and filters, then call
/// `execute` to run it against a `UnifiedStore`.
#[derive(Clone, Debug)]
pub struct ThreeWayJoinBuilder {
    /// Where the join begins; `None` until a `start_*` method is called.
    pub(crate) start: Option<JoinPhase>,
    /// Pipeline steps in the order they were added.
    pub(crate) pipeline: Vec<JoinStep>,
    /// Filters accumulated through the `FilterAcceptor` implementation.
    pub(crate) filters: Vec<Filter>,
    /// Optional result limit; `None` means no limit was set.
    pub(crate) limit: Option<usize>,
    /// Minimum score threshold (`0.0` when built with `new`).
    pub(crate) min_score: f32,
    /// Per-modality weights used when scoring matches.
    pub(crate) weights: CrossModalWeights,
}
/// Per-modality weights for a cross-modal join.
///
/// This is plain numeric data, so it is `Copy` and comparable. By default the
/// vector and graph signals are weighted equally (`0.4` each) and the table
/// signal gets `0.2`. `ThreeWayJoinBuilder::with_weights` overrides all three.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct CrossModalWeights {
    /// Weight given to the vector modality.
    pub vector: f32,
    /// Weight given to the graph modality.
    pub graph: f32,
    /// Weight given to the table modality.
    pub table: f32,
}

impl Default for CrossModalWeights {
    /// Returns `vector = 0.4`, `graph = 0.4`, `table = 0.2`.
    fn default() -> Self {
        Self {
            vector: 0.4,
            graph: 0.4,
            table: 0.2,
        }
    }
}
impl ThreeWayJoinBuilder {
pub fn new() -> Self {
Self {
start: None,
pipeline: Vec::new(),
filters: Vec::new(),
limit: None,
min_score: 0.0,
weights: CrossModalWeights::default(),
}
}
pub fn start_vector(mut self, vector: &[f32], k: usize) -> Self {
self.start = Some(JoinPhase::VectorStart {
vector: vector.to_vec(),
k,
});
self
}
pub fn start_node(mut self, node_id: impl Into<String>) -> Self {
self.start = Some(JoinPhase::GraphStart {
node_id: node_id.into(),
});
self
}
pub fn start_table(mut self, table: impl Into<String>) -> Self {
self.start = Some(JoinPhase::TableStart {
table: table.into(),
});
self
}
pub fn traverse(mut self, edge_label: impl Into<String>, depth: u32) -> Self {
self.pipeline.push(JoinStep::Traverse {
edge_label: Some(edge_label.into()),
depth,
direction: TraversalDirection::Out,
});
self
}
pub fn traverse_any(mut self, depth: u32) -> Self {
self.pipeline.push(JoinStep::Traverse {
edge_label: None,
depth,
direction: TraversalDirection::Both,
});
self
}
pub fn traverse_in(mut self, edge_label: impl Into<String>, depth: u32) -> Self {
self.pipeline.push(JoinStep::Traverse {
edge_label: Some(edge_label.into()),
depth,
direction: TraversalDirection::In,
});
self
}
pub fn join_table(mut self, table: impl Into<String>) -> Self {
self.pipeline.push(JoinStep::JoinTable {
table: table.into(),
on_field: None,
});
self
}
pub fn join_table_on(mut self, table: impl Into<String>, field: impl Into<String>) -> Self {
self.pipeline.push(JoinStep::JoinTable {
table: table.into(),
on_field: Some(field.into()),
});
self
}
pub fn expand_similar(mut self, k: usize) -> Self {
self.pipeline.push(JoinStep::VectorExpand { k });
self
}
pub fn where_(self, field: impl Into<String>) -> WhereClause<Self> {
WhereClause::new(self, field.into())
}
pub fn limit(mut self, n: usize) -> Self {
self.limit = Some(n);
self
}
pub fn min_score(mut self, score: f32) -> Self {
self.min_score = score;
self
}
pub fn with_weights(mut self, vector: f32, graph: f32, table: f32) -> Self {
self.weights = CrossModalWeights {
vector,
graph,
table,
};
self
}
pub fn execute(self, store: &Arc<UnifiedStore>) -> Result<QueryResult, ExecutionError> {
execute_three_way_join(self, store)
}
}
impl Default for ThreeWayJoinBuilder {
fn default() -> Self {
Self::new()
}
}
impl FilterAcceptor for ThreeWayJoinBuilder {
    /// Appends `filter` to this builder's filter list.
    ///
    /// Filters are kept in the order they were added.
    fn add_filter(&mut self, filter: Filter) {
        let filters = &mut self.filters;
        filters.push(filter);
    }
}
/// One result of a cross-modal join: a matched entity with its per-modality
/// scores and an entity path.
pub struct CrossModalMatch {
    /// The matched entity.
    pub entity: UnifiedEntity,
    /// Score contributed by the vector modality.
    pub vector_score: f32,
    /// Score contributed by the graph modality.
    pub graph_score: f32,
    /// Score contributed by the table modality.
    pub table_score: f32,
    /// Sequence of entity ids (presumably the route by which the match was
    /// reached — confirm against the executor).
    pub path: Vec<EntityId>,
}