1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
//! Embedding model management for GrafeoDB.
use std::sync::Arc;
use grafeo_common::utils::error::Result;
impl super::GrafeoDB {
// ── Embedding ────────────────────────────────────────────────────────
/// Loads one of the preset embedding models using the default
/// [`EmbeddingOptions`](crate::embedding::EmbeddingOptions).
///
/// Preset models have their ONNX model and tokenizer fetched from
/// HuggingFace Hub automatically on first use; the files are then cached
/// locally (in `~/.cache/huggingface/`). The loaded model is registered
/// under its display name (e.g., `"all-MiniLM-L6-v2"`).
///
/// This is a convenience wrapper around
/// [`load_embedding_model_with_options()`](Self::load_embedding_model_with_options).
///
/// # Examples
///
/// ```no_run
/// use grafeo_engine::{GrafeoDB, Config, embedding::EmbeddingModelConfig};
///
/// # fn main() -> grafeo_common::utils::error::Result<()> {
/// let db = GrafeoDB::with_config(Config::in_memory())?;
/// db.load_embedding_model(EmbeddingModelConfig::MiniLmL6v2)?;
/// let vecs = db.embed_text("all-MiniLM-L6-v2", &["hello"])?;
/// # Ok(())
/// # }
/// ```
///
/// # Errors
///
/// Returns an error if the model fails to download or load.
#[cfg(feature = "embed")]
pub fn load_embedding_model(
    &self,
    config: crate::embedding::EmbeddingModelConfig,
) -> Result<()> {
    let default_options = crate::embedding::EmbeddingOptions::default();
    self.load_embedding_model_with_options(config, default_options)
}
/// Loads a preset embedding model with caller-supplied options and
/// registers it under the preset's display name.
///
/// Batch size and thread configuration are controlled through
/// [`EmbeddingOptions`](crate::embedding::EmbeddingOptions).
///
/// # Errors
///
/// Returns an error if the model fails to download, load, or initialize.
/// Nothing is registered in that case.
#[cfg(feature = "embed")]
pub fn load_embedding_model_with_options(
    &self,
    config: crate::embedding::EmbeddingModelConfig,
    options: crate::embedding::EmbeddingOptions,
) -> Result<()> {
    // Capture the registry key first: `config` is moved into the constructor below.
    let registry_key = config.display_name();
    let loaded: Arc<dyn crate::embedding::EmbeddingModel> = Arc::new(
        crate::embedding::OnnxEmbeddingModel::from_config_with_options(config, options)?,
    );
    self.register_embedding_model(&registry_key, loaded);
    Ok(())
}
/// Makes `model` available for text-to-vector conversion under `name`.
///
/// After registration, pass the same name to
/// [`embed_text()`](Self::embed_text) and
/// [`vector_search_text()`](Self::vector_search_text) to use the model.
#[cfg(feature = "embed")]
pub fn register_embedding_model(
    &self,
    name: &str,
    model: Arc<dyn crate::embedding::EmbeddingModel>,
) {
    // The registry stores owned keys, so copy the borrowed name.
    let mut registry = self.embedding_models.write();
    registry.insert(name.to_owned(), model);
}
/// Generates embeddings for a batch of texts using a registered model.
///
/// The model is looked up by `model_name` in the registry populated by
/// [`register_embedding_model()`](Self::register_embedding_model). Only the
/// lookup happens under the registry lock: the model handle is cloned out
/// and the lock is released before `embed` runs, so an in-flight embedding
/// call does not keep the registry locked.
///
/// # Errors
///
/// Returns an error if the model is not registered or embedding fails.
#[cfg(feature = "embed")]
pub fn embed_text(&self, model_name: &str, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
    // The read guard is a temporary of this statement and is dropped at its
    // end; cloning the `Arc` is what lets the model outlive the guard.
    let model = self
        .embedding_models
        .read()
        .get(model_name)
        .cloned()
        .ok_or_else(|| {
            grafeo_common::utils::error::Error::Internal(format!(
                "Embedding model '{}' not registered",
                model_name
            ))
        })?;
    model.embed(texts)
}
/// Runs a vector search from a text query, embedding the text on the fly.
///
/// Equivalent to calling [`embed_text()`](Self::embed_text) with
/// `query_text` as the only input and handing the first returned vector to
/// [`vector_search()`](Self::vector_search), along with `label`,
/// `property`, `k`, and `ef`.
///
/// # Errors
///
/// Returns an error if the model is not registered, embedding fails,
/// or the vector index doesn't exist. An error is also returned if the
/// model yields no vectors for the query.
#[cfg(all(feature = "embed", feature = "vector-index"))]
pub fn vector_search_text(
    &self,
    label: &str,
    property: &str,
    model_name: &str,
    query_text: &str,
    k: usize,
    ef: Option<usize>,
) -> Result<Vec<(grafeo_common::types::NodeId, f32)>> {
    let no_vectors = || {
        grafeo_common::utils::error::Error::Internal(
            "Embedding model returned no vectors".to_string(),
        )
    };
    // A single text goes in, so only the first embedding is used.
    let query_embedding = self
        .embed_text(model_name, &[query_text])?
        .into_iter()
        .next()
        .ok_or_else(no_vectors)?;
    // NOTE(review): the trailing `None` is an optional argument of
    // `vector_search` not visible here (presumably a filter) — confirm.
    self.vector_search(label, property, &query_embedding, k, ef, None)
}
}