use super::{
cache::VectorCache,
entity_db::{EdgeDb, NodeDb},
template::DocumentTemplate,
vectorised_graph::VectorisedGraph,
};
use crate::{
db::api::view::StaticGraphViewOps,
errors::{GraphError, GraphResult},
vectors::{
embeddings::ModelConfig,
vector_collection::{lancedb::LanceDb, VectorCollectionFactory},
},
};
use async_openai::config::{OpenAIConfig, OPENAI_API_BASE};
use serde::{Deserialize, Serialize};
use std::{
fs::File,
path::{Path, PathBuf},
sync::Arc,
};
use tokio::sync::OnceCell;
/// Settings from which an `OpenAIConfig` is built for an embedding model.
///
/// Only `model` is mandatory. Every optional field left as `None` is replaced
/// by a fallback value in `resolve_config`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash)]
pub struct OpenAIEmbeddings {
/// Model identifier, stored verbatim from the constructor argument.
pub model: String,
/// Base URL of the API; `resolve_config` falls back to `OPENAI_API_BASE` when `None`.
pub api_base: Option<String>,
/// Name of the environment variable that holds the API key;
/// `resolve_config` falls back to `"OPENAI_API_KEY"` when `None`.
pub api_key_env: Option<String>,
/// Organisation id; `resolve_config` passes an empty string when `None`.
pub org_id: Option<String>,
/// Project id; `resolve_config` passes an empty string when `None`.
pub project_id: Option<String>,
/// NOTE(review): presumably the embedding dimension — it is not read by any
/// code in this part of the file; confirm against the users of this struct.
pub dim: Option<usize>,
}
impl OpenAIEmbeddings {
    /// Creates a configuration that only names the model.
    ///
    /// All optional settings (`api_base`, `api_key_env`, `org_id`,
    /// `project_id`, `dim`) start out as `None`.
    pub fn empty(name: impl AsRef<str>) -> Self {
        let model = name.as_ref().to_owned();
        Self {
            model,
            api_base: None,
            api_key_env: None,
            org_id: None,
            project_id: None,
            dim: None,
        }
    }

    /// Creates a configuration for `model` with an explicit `api_base`.
    ///
    /// Every other optional setting is left as `None`.
    pub fn new(model: impl AsRef<str>, api_base: impl AsRef<str>) -> Self {
        let mut embeddings = Self::empty(model);
        embeddings.api_base = Some(api_base.as_ref().to_owned());
        embeddings
    }

    /// Turns these settings into an `OpenAIConfig`.
    ///
    /// Fallbacks applied for unset fields:
    /// - the API key is read from the variable named by `api_key_env`, or
    ///   from `OPENAI_API_KEY` when no name is configured;
    /// - the API base defaults to `OPENAI_API_BASE`;
    /// - organisation and project ids default to the empty string.
    pub(super) fn resolve_config(&self) -> OpenAIConfig {
        let key_variable = self.api_key_env.as_deref().unwrap_or("OPENAI_API_KEY");
        // A failed lookup is not treated as an error: the key is simply empty.
        let api_key = std::env::var(key_variable).unwrap_or_default();
        let api_base = self
            .api_base
            .clone()
            .unwrap_or_else(|| OPENAI_API_BASE.to_owned());
        let org_id = self.org_id.clone().unwrap_or_default();
        let project_id = self.project_id.clone().unwrap_or_default();

        OpenAIConfig::new()
            .with_api_base(api_base)
            .with_api_key(api_key)
            .with_org_id(org_id)
            .with_project_id(project_id)
    }
}
/// Metadata persisted next to the vector database.
///
/// It is written as JSON by `write_to_path` and parsed back by
/// `read_from_path`.
#[derive(Serialize, Deserialize, Debug)]
pub(super) struct VectorMeta {
/// Document template stored with the vectors.
pub(super) template: DocumentTemplate,
/// Configuration of the embedding model stored with the vectors.
pub(super) model: ModelConfig,
}
impl VectorMeta {
    /// Writes this metadata as JSON to the `meta` file inside the directory
    /// `path`.
    ///
    /// The JSON is rendered in memory first and then written with a single
    /// `std::fs::write` call. This way an existing metadata file is not
    /// truncated when serialisation fails, and the output is not pushed to the
    /// file through many small unbuffered writes.
    ///
    /// # Errors
    /// Returns an error when serialisation fails or when the file cannot be
    /// created or written.
    pub(super) fn write_to_path(&self, path: &Path) -> Result<(), GraphError> {
        let encoded = serde_json::to_vec(self)?;
        std::fs::write(meta_path(path), encoded)?;
        Ok(())
    }

    /// Reads metadata from the JSON file located at `path`.
    ///
    /// Note the asymmetry with `write_to_path`: here `path` is the metadata
    /// file itself, not the directory that contains it.
    ///
    /// NOTE(review): the file is read with blocking `std::fs` I/O even though
    /// this function is `async`.
    ///
    /// # Errors
    /// Returns an error when the file cannot be read or does not contain valid
    /// JSON for `VectorMeta`.
    pub(super) async fn read_from_path(path: &Path) -> GraphResult<Self> {
        let raw = std::fs::read_to_string(path)?;
        let meta: VectorMeta = serde_json::from_str(&raw)?;
        Ok(meta)
    }
}
/// A `VectorCache` backed by `path` that is only opened when first requested.
///
/// `resolve` initialises the cell through `VectorCache::on_disk` and hands out
/// a reference to the stored cache afterwards.
#[derive(Clone)]
pub struct LazyDiskVectorCache {
/// Location handed to `VectorCache::on_disk` when the cache is opened.
path: PathBuf,
/// Holds the opened cache once `resolve` has succeeded.
cache: OnceCell<VectorCache>,
}
impl LazyDiskVectorCache {
pub fn new(path: PathBuf) -> Self {
Self {
path,
cache: Default::default(),
}
}
pub async fn resolve(&self) -> GraphResult<&VectorCache> {
self.cache
.get_or_try_init(async || VectorCache::on_disk(&self.path.clone()).await)
.await
}
}
impl<G: StaticGraphViewOps> VectorisedGraph<G> {
pub async fn read_from_path(
path: &Path,
graph: G,
cache: &LazyDiskVectorCache,
) -> GraphResult<Self> {
let meta = VectorMeta::read_from_path(&meta_path(path)).await?;
let factory = LanceDb;
let db_path = Arc::new(db_path(path));
let resolved = cache.resolve().await?;
let model = resolved.validate_and_set_dim(meta.model).await?;
let dim = model.dim().ok_or_else(|| GraphError::UnresolvedModel)?;
let node_db = NodeDb(factory.from_path(db_path.clone(), "nodes", dim).await?);
let edge_db = EdgeDb(factory.from_path(db_path, "edges", dim).await?);
Ok(VectorisedGraph {
template: meta.template,
source_graph: graph,
model,
node_db,
edge_db,
})
}
}
/// Returns the location of the `meta` entry inside the directory `path`.
fn meta_path(path: &Path) -> PathBuf {
    let mut location = path.to_path_buf();
    location.push("meta");
    location
}
/// Returns the location of the `db` entry inside the directory `path`.
pub(super) fn db_path(path: &Path) -> PathBuf {
    let mut location = path.to_path_buf();
    location.push("db");
    location
}
#[cfg(test)]
mod vector_storage_tests {
// NOTE(review): this test module is currently empty, so nothing in this file
// exercises the metadata round trip or the path helpers.
}