use super::{
entity_db::{EdgeDb, EntityDb, NodeDb},
utils::apply_window,
vector_selection::VectorSelection,
};
use crate::{
core::entities::nodes::node_ref::AsNodeRef,
db::api::view::{DynamicGraph, IntoDynamic, StaticGraphViewOps},
errors::GraphResult,
prelude::GraphViewOps,
vectors::{
cache::CachedEmbeddingModel,
template::DocumentTemplate,
utils::find_top_k,
vector_collection::{lancedb::LanceDbCollection, VectorCollection},
Embedding, VectorsQuery,
},
};
/// A graph bundled with everything needed to embed and search its entities.
///
/// It pairs the source graph with the template that renders nodes and edges
/// into documents, the embedding model that turns those documents into
/// vectors, and one vector collection each for nodes and for edges.
#[derive(Clone)]
pub struct VectorisedGraph<G>
where
    G: StaticGraphViewOps,
{
    /// Graph in which nodes and edges are looked up, and which is windowed
    /// before similarity searches.
    pub(crate) source_graph: G,
    /// Template used to render a node or an edge into a document to embed.
    pub(crate) template: DocumentTemplate,
    /// Embedding model used for documents and for free-text queries.
    pub(crate) model: CachedEmbeddingModel,
    /// Vector collection holding node embeddings, keyed by node index.
    pub(super) node_db: NodeDb<LanceDbCollection>,
    /// Vector collection holding edge embeddings, keyed by edge id.
    pub(super) edge_db: EdgeDb<LanceDbCollection>,
}
impl<G: StaticGraphViewOps + IntoDynamic> VectorisedGraph<G> {
    /// Converts this vectorised graph into one backed by a [`DynamicGraph`].
    ///
    /// Consumes `self`: the source graph is converted with
    /// `IntoDynamic::into_dynamic`, while the template, the embedding model and
    /// both vector databases are moved over unchanged.
    pub fn into_dynamic(self) -> VectorisedGraph<DynamicGraph> {
        let Self {
            source_graph,
            template,
            model,
            node_db,
            edge_db,
        } = self;
        VectorisedGraph {
            // `self` is owned here, so the graph can be moved into the
            // conversion directly instead of being cloned first.
            source_graph: source_graph.into_dynamic(),
            template,
            model,
            node_db,
            edge_db,
        }
    }
}
impl<G: StaticGraphViewOps> VectorisedGraph<G> {
pub async fn update_nodes<T: AsNodeRef>(&self, nodes: Vec<T>) -> GraphResult<()> {
let (ids, docs): (Vec<_>, Vec<_>) = nodes
.iter()
.filter_map(|node| {
self.source_graph.node(node).and_then(|node| {
let id = node.node.index() as u64;
self.template.node(node).map(|doc| (id, doc))
})
})
.unzip();
let vectors = self.model.get_embeddings(docs).await?;
self.node_db.insert_vectors(ids, vectors).await?;
Ok(())
}
pub async fn update_edges<T: AsNodeRef>(&self, edges: Vec<(T, T)>) -> GraphResult<()> {
let (ids, docs): (Vec<_>, Vec<_>) = edges
.iter()
.filter_map(|(src, dst)| {
self.source_graph.edge(src, dst).and_then(|edge| {
let id = edge.edge.pid().0 as u64;
self.template.edge(edge).map(|doc| (id, doc))
})
})
.unzip();
let vectors = self.model.get_embeddings(docs).await?;
self.edge_db.insert_vectors(ids, vectors).await?;
Ok(())
}
pub async fn optimize_index(&self) -> GraphResult<()> {
self.node_db.create_or_update_index().await?;
self.edge_db.create_or_update_index().await?;
Ok(())
}
pub fn empty_selection(&self) -> VectorSelection<G> {
VectorSelection::empty(self.clone())
}
pub fn entities_by_similarity(
&self,
query: &Embedding,
limit: usize,
window: Option<(i64, i64)>,
) -> VectorsQuery<GraphResult<VectorSelection<G>>> {
let view = apply_window(&self.source_graph, window);
let node_query = self.node_db.top_k(query, limit, view.clone(), None);
let edge_query = self.edge_db.top_k(query, limit, view, None);
let cloned = self.clone();
VectorsQuery::new(Box::pin(async move {
println!("executing node similarity query");
let nodes = node_query.execute().await?;
let edges = edge_query.execute().await?;
let docs = find_top_k(nodes.into_iter().chain(edges), limit).collect::<Vec<_>>();
Ok(VectorSelection::new(cloned, docs))
}))
}
pub fn nodes_by_similarity(
&self,
query: &Embedding,
limit: usize,
window: Option<(i64, i64)>,
) -> VectorsQuery<GraphResult<VectorSelection<G>>> {
let view = apply_window(&self.source_graph, window);
let query = self.node_db.top_k(query, limit, view, None);
let cloned = self.clone();
VectorsQuery::new(Box::pin(async move {
let docs = query.execute().await?;
Ok(VectorSelection::new(cloned, docs))
}))
}
pub fn edges_by_similarity(
&self,
query: &Embedding,
limit: usize,
window: Option<(i64, i64)>,
) -> VectorsQuery<GraphResult<VectorSelection<G>>> {
let view = apply_window(&self.source_graph, window);
let query = self.edge_db.top_k(query, limit, view, None);
let cloned = self.clone();
VectorsQuery::new(Box::pin(async move {
let docs = query.execute().await?;
Ok(VectorSelection::new(cloned, docs))
}))
}
pub async fn embed_text<T: Into<String>>(&self, text: T) -> GraphResult<Embedding> {
self.model.get_single(text.into()).await
}
pub fn model(&self) -> &CachedEmbeddingModel {
&self.model
}
}