use super::{
entity_db::{EdgeDb, NodeDb},
storage::{db_path, VectorMeta},
};
use crate::{
db::api::view::{internal::IntoDynamic, StaticGraphViewOps},
errors::{GraphError, GraphResult},
prelude::GraphViewOps,
vectors::{
cache::CachedEmbeddingModel,
embeddings::compute_embeddings,
entity_db::EntityDb,
template::DocumentTemplate,
vector_collection::{
lancedb::LanceDb, CollectionPath, VectorCollection, VectorCollectionFactory,
},
vectorised_graph::VectorisedGraph,
},
};
use async_trait::async_trait;
use std::{path::Path, sync::Arc};
use tracing::info;
/// Types that can be turned into a [`VectorisedGraph`] by embedding their nodes and edges.
///
/// `model` produces the embeddings and `template` renders the documents that get embedded.
/// `path`, when given, names the on-disk location to store the result. `verbose` enables
/// progress logging.
#[async_trait]
pub trait Vectorisable<G>
where
    G: StaticGraphViewOps,
{
    /// Embeds this graph and returns the resulting [`VectorisedGraph`].
    async fn vectorise(
        &self,
        model: CachedEmbeddingModel,
        template: DocumentTemplate,
        path: Option<&Path>,
        verbose: bool,
    ) -> GraphResult<VectorisedGraph<G>>;
}
#[async_trait]
impl<G: StaticGraphViewOps + IntoDynamic + Send> Vectorisable<G> for G {
async fn vectorise(
&self,
model: CachedEmbeddingModel,
template: DocumentTemplate,
path: Option<&Path>,
verbose: bool,
) -> GraphResult<VectorisedGraph<G>> {
let db_path = path
.map(|path| Ok::<CollectionPath, std::io::Error>(Arc::new(db_path(path))))
.unwrap_or_else(|| Ok(Arc::new(tempfile::tempdir()?)))?;
let factory = LanceDb;
let dim = model.dim().ok_or_else(|| GraphError::UnresolvedModel)?;
if verbose {
info!("computing embeddings for nodes");
}
let nodes = self.nodes();
let node_docs = nodes
.iter()
.filter_map(|node| template.node(node).map(|doc| (node.node.0 as u64, doc)));
let node_vectors = compute_embeddings(node_docs, &model);
let node_db = NodeDb(
factory
.new_collection(db_path.clone(), "nodes", dim)
.await?,
);
node_db.insert_vector_stream(node_vectors).await.unwrap();
node_db.create_or_update_index().await?;
if verbose {
info!("computing embeddings for edges");
}
let edges = self.edges();
let edge_docs = edges.iter().filter_map(|edge| {
template
.edge(edge)
.map(|doc| (edge.edge.pid().0 as u64, doc))
});
let edge_vectors = compute_embeddings(edge_docs, &model);
let edge_db = EdgeDb(factory.new_collection(db_path, "edges", dim).await?);
edge_db.insert_vector_stream(edge_vectors).await.unwrap();
edge_db.create_or_update_index().await?;
if let Some(path) = path {
let meta = VectorMeta {
template: template.clone(),
model: model.model.clone(),
};
meta.write_to_path(path)?;
}
Ok(VectorisedGraph {
source_graph: self.clone(),
template,
model,
node_db,
edge_db,
})
}
}