1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
use crate::vectordb::error::Result;
use crate::vectordb::provider::{EmbeddingProvider};
// Explicitly import the concrete provider type
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use crate::vectordb::error::VectorDBError;
use std::fmt;
use crate::config::AppConfig;
#[cfg(feature = "ort")]
use crate::vectordb::provider::onnx::OnnxEmbeddingModel;
// Use the embedding dimensions from the providers
// use crate::vectordb::provider::fast::FAST_EMBEDDING_DIM;
/// Selects which embedding backend an `EmbeddingModel` is built on.
///
/// `Default` is the variant produced by `EmbeddingModelType::default()`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum EmbeddingModelType {
    /// The default model (returned by `Default::default()`).
    #[default]
    Default,
    /// An ONNX model loaded from a model file plus a tokenizer file.
    Onnx,
    // Add other model types here in the future (e.g., SentenceTransformers)
}
impl EmbeddingModelType {
/// Get the expected dimension for the model type.
pub fn dimension(&self) -> usize {
match self {
EmbeddingModelType::Onnx => 384,
EmbeddingModelType::Default => 384, // Use default dimension for now
// Add other model types here
}
}
}
impl fmt::Display for EmbeddingModelType {
    /// Writes the human-readable model name: `"ONNX"` or `"Default"`.
    ///
    /// Uses [`fmt::Formatter::pad`] so width/fill/alignment/precision flags
    /// (e.g. `{:<8}`) are honored, matching how `str` itself is displayed; a
    /// plain `write!` of a literal silently ignores them.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            EmbeddingModelType::Onnx => "ONNX",
            EmbeddingModelType::Default => "Default",
        };
        f.pad(name)
    }
}
/// Represents an embedding model (currently ONNX-based).
///
/// Wraps a shared [`EmbeddingProvider`] together with metadata about how it was
/// built. Clones share the same provider instance because it is held in an [`Arc`].
#[derive(Clone, Debug)]
pub struct EmbeddingModel {
    /// Backend that computes embeddings; `Send + Sync` so the model can be shared
    /// across threads.
    provider: Arc<dyn EmbeddingProvider + Send + Sync>,
    /// Which kind of backend `provider` is.
    model_type: EmbeddingModelType,
    /// ONNX model file the provider was loaded from (populated by `new_onnx`).
    // NOTE(review): not read by any code visible here — confirm it is still needed.
    onnx_model_path: Option<PathBuf>,
    /// Tokenizer file the provider was loaded with (populated by `new_onnx`).
    // NOTE(review): not read by any code visible here — confirm it is still needed.
    onnx_tokenizer_path: Option<PathBuf>,
}
impl EmbeddingModel {
    /// Builds an `EmbeddingModel` backed by an ONNX provider.
    ///
    /// `model_path` points at the ONNX model file and `tokenizer_path` at the
    /// tokenizer definition; both paths are recorded on the returned model.
    ///
    /// # Errors
    ///
    /// Returns `VectorDBError::EmbeddingError` if the ONNX provider cannot be
    /// constructed from the given files.
    pub fn new_onnx<P: AsRef<Path>>(model_path: P, tokenizer_path: P) -> Result<Self> {
        let model_file: &Path = model_path.as_ref();
        let tokenizer_file: &Path = tokenizer_path.as_ref();

        let backend =
            crate::vectordb::provider::onnx::OnnxEmbeddingModel::new(model_file, tokenizer_file)
                .map_err(|err| {
                    VectorDBError::EmbeddingError(format!(
                        "Failed to create ONNX provider: {}",
                        err
                    ))
                })?;

        Ok(Self {
            provider: Arc::new(backend),
            model_type: EmbeddingModelType::Onnx,
            onnx_model_path: Some(model_file.to_path_buf()),
            onnx_tokenizer_path: Some(tokenizer_file.to_path_buf()),
        })
    }

    /// Returns which kind of backend this model wraps.
    pub fn model_type(&self) -> EmbeddingModelType {
        self.model_type
    }

    /// Returns the length of the vectors produced by the underlying provider.
    pub fn dim(&self) -> usize {
        let backend: &dyn EmbeddingProvider = &*self.provider;
        backend.dimension()
    }

    /// Embeds every string in `texts`, delegating to the underlying provider.
    ///
    /// # Errors
    ///
    /// Propagates the provider's error, converted into this module's error type.
    pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        let backend: &dyn EmbeddingProvider = &*self.provider;
        let vectors = backend.embed_batch(texts)?;
        Ok(vectors)
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `EmbeddingModelType` and `EmbeddingModel`.
    //!
    //! TODO: a mock `EmbeddingProvider` would allow testing successful
    //! `EmbeddingModel` construction and `embed_batch` without real ONNX files.
    use super::*;
    use std::path::PathBuf;

    /// Loads the checked-in ONNX model and checks that embeddings have the model's
    /// dimension, are not all zeros, and are reproduced exactly by a clone.
    ///
    /// Returns early (skips) when the model files are not present.
    #[test]
    fn test_onnx_embedding_fallback() {
        let model_path = Path::new("onnx/all-minilm-l12-v2.onnx");
        let tokenizer_path = Path::new("onnx/minilm_tokenizer.json");
        if !model_path.exists() || !tokenizer_path.exists() {
            println!("Skipping ONNX test because model files aren't available");
            return;
        }

        let model = EmbeddingModel::new_onnx(model_path, tokenizer_path)
            .expect("ONNX model should load when its files exist");
        let expected_dim = model.dim();

        let text = "fn main() { let x = 42; }";
        let embedding = model
            .embed_batch(&[text])
            .expect("embedding a single text should succeed")
            .into_iter()
            .next()
            .expect("embed_batch should return one vector per input");
        assert_eq!(embedding.len(), expected_dim);
        assert!(!embedding.iter().all(|&x| x == 0.0));

        // A clone shares the same provider (it is behind an Arc), so it must
        // report the same dimension and produce identical output.
        let cloned_model = model.clone();
        assert_eq!(cloned_model.dim(), expected_dim);
        let cloned_embedding = cloned_model
            .embed_batch(&[text])
            .expect("embedding with the cloned model should succeed")
            .into_iter()
            .next()
            .expect("embed_batch should return one vector per input");
        assert_eq!(embedding, cloned_embedding);
    }

    /// Every variant renders to its documented display name.
    #[test]
    fn test_embedding_model_type_display() {
        assert_eq!(EmbeddingModelType::Onnx.to_string(), "ONNX");
        assert_eq!(EmbeddingModelType::Default.to_string(), "Default");
    }

    /// `Default::default()` selects the `Default` variant, and every variant
    /// reports a 384-wide embedding.
    #[test]
    fn test_embedding_model_type_default_and_dimension() {
        assert_eq!(EmbeddingModelType::default(), EmbeddingModelType::Default);
        assert_eq!(EmbeddingModelType::Onnx.dimension(), 384);
        assert_eq!(EmbeddingModelType::Default.dimension(), 384);
    }

    /// A provider construction failure surfaces as `VectorDBError::EmbeddingError`.
    #[test]
    fn test_embedding_model_new_onnx_invalid_path() {
        let model_path = PathBuf::from("./nonexistent/model.onnx");
        let tokenizer_path = PathBuf::from("./nonexistent/tokenizer.json");
        match EmbeddingModel::new_onnx(&model_path, &tokenizer_path) {
            Err(VectorDBError::EmbeddingError(_)) => {}
            other => panic!("expected VectorDBError::EmbeddingError, got {:?}", other),
        }
    }
}
/// Builds the embedding provider selected by the application configuration.
///
/// Only the ONNX backend exists today, so the model type is fixed to
/// `EmbeddingModelType::Onnx`; a future configuration field could select it.
///
/// # Errors
///
/// * `VectorDBError::ConfigurationError` if `config.onnx_model_path` or
///   `config.onnx_tokenizer_path` is unset (builds with the `ort` feature).
/// * Any error from constructing the ONNX provider, converted into `VectorDBError`.
/// * `VectorDBError::FeatureNotEnabled` when built without the `ort` feature.
pub fn initialize_provider(
    config: &AppConfig,
) -> std::result::Result<Arc<dyn EmbeddingProvider + Send + Sync>, VectorDBError> {
    // TODO: read this from configuration (e.g. `config.embedding_model.model_type`).
    let model_type = EmbeddingModelType::Onnx;
    match model_type {
        EmbeddingModelType::Default | EmbeddingModelType::Onnx => {
            #[cfg(feature = "ort")]
            {
                let model_path = config.onnx_model_path.as_deref().ok_or_else(|| {
                    VectorDBError::ConfigurationError(
                        "ONNX model path not set in AppConfig".to_string(),
                    )
                })?;
                let tokenizer_path = config.onnx_tokenizer_path.as_deref().ok_or_else(|| {
                    VectorDBError::ConfigurationError(
                        "ONNX tokenizer path not set in AppConfig".to_string(),
                    )
                })?;
                // `?` converts the provider's error via `From`, like the explicit
                // `VectorDBError::from` this replaces.
                let provider =
                    OnnxEmbeddingModel::new(Path::new(model_path), Path::new(tokenizer_path))?;
                Ok(Arc::new(provider))
            }
            #[cfg(not(feature = "ort"))]
            {
                // `config` is only consulted by the ONNX backend; consume it so builds
                // without `ort` do not emit an `unused_variables` warning.
                let _ = config;
                Err(VectorDBError::FeatureNotEnabled("ort".to_string()))
            }
        }
        // Handle other model types here when they are added.
    }
}