1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
use crate::vectordb::error::Result;
use crate::vectordb::provider::{EmbeddingProvider};
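// The concrete ONNX provider type is referenced by its full path in `EmbeddingModel::new_onnx`,
// so only the `EmbeddingProvider` trait is imported here.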
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use crate::vectordb::error::VectorDBError;
use std::fmt;
// Runtime embedding dimensions are reported by the provider (see `EmbeddingModel::dim`);
// `EmbeddingModelType::default_dimension` is only a fallback.

/// The embedding model backends that can be selected.
///
/// With no `#[serde(...)]` attributes, serde (de)serializes each unit variant by its name.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum EmbeddingModelType {
    /// An ONNX model run through the ONNX embedding provider; this is the default variant.
    #[default]
    Onnx,
}
impl EmbeddingModelType {
/// Returns the default embedding dimension for this model type.
/// Used as a fallback when loading an index without an explicit dimension stored.
pub fn default_dimension(&self) -> usize {
match self {
// TODO: Make this dynamically configurable or read from a default ONNX model?
// For now, assume the default ONNX is MiniLM with 384 dims.
EmbeddingModelType::Onnx => 384,
}
}
}
impl fmt::Display for EmbeddingModelType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
EmbeddingModelType::Onnx => write!(f, "ONNX"),
}
}
}
/// An embedding model backed by a shared [`EmbeddingProvider`] (currently always ONNX).
///
/// Clones share the same provider instance, because it is held behind an `Arc`.
#[derive(Debug, Clone)]
pub struct EmbeddingModel {
    /// The provider that actually computes embeddings; shared between clones.
    provider: Arc<dyn EmbeddingProvider + Send + Sync>,
    /// Which backend `provider` is.
    model_type: EmbeddingModelType,
    /// Path of the ONNX model file, when the model was built from one.
    onnx_model_path: Option<PathBuf>,
    /// Path of the tokenizer file, when the model was built from ONNX files.
    onnx_tokenizer_path: Option<PathBuf>,
}
impl EmbeddingModel {
    /// Creates an ONNX-backed embedding model from a model file and a tokenizer file.
    ///
    /// Both paths are stored on the model and can be read back with
    /// [`EmbeddingModel::onnx_model_path`] and [`EmbeddingModel::onnx_tokenizer_path`].
    ///
    /// # Errors
    ///
    /// Returns [`VectorDBError::EmbeddingError`] if the ONNX provider cannot be created from
    /// the given files, for example when a path does not exist.
    pub fn new_onnx<P: AsRef<Path>>(model_path: P, tokenizer_path: P) -> Result<Self> {
        let model_path = model_path.as_ref();
        let tokenizer_path = tokenizer_path.as_ref();
        // The concrete provider type is named by full path; only the trait is imported.
        let onnx_provider =
            crate::vectordb::provider::onnx::OnnxEmbeddingProvider::new(model_path, tokenizer_path)
                .map_err(|e| {
                    VectorDBError::EmbeddingError(format!("Failed to create ONNX provider: {}", e))
                })?;
        Ok(Self {
            provider: Arc::new(onnx_provider),
            model_type: EmbeddingModelType::Onnx,
            onnx_model_path: Some(model_path.to_path_buf()),
            onnx_tokenizer_path: Some(tokenizer_path.to_path_buf()),
        })
    }

    /// Returns the backend type of this model.
    pub fn model_type(&self) -> EmbeddingModelType {
        self.model_type
    }

    /// Returns the ONNX model file this model was created from, if any.
    pub fn onnx_model_path(&self) -> Option<&Path> {
        self.onnx_model_path.as_deref()
    }

    /// Returns the tokenizer file this model was created from, if any.
    pub fn onnx_tokenizer_path(&self) -> Option<&Path> {
        self.onnx_tokenizer_path.as_deref()
    }

    /// Returns the embedding dimension reported by the underlying provider.
    pub fn dim(&self) -> usize {
        self.provider.dimension()
    }

    /// Generates an embedding for a single text.
    ///
    /// # Errors
    ///
    /// Propagates the provider's error, converted into the crate's error type.
    pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
        self.provider.embed(text).map_err(Into::into)
    }

    /// Generates embeddings for a batch of texts using the provider's batch API.
    ///
    /// # Errors
    ///
    /// Propagates the provider's error, converted into the crate's error type.
    pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        self.provider.embed_batch(texts).map_err(Into::into)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // TODO: add a mock `EmbeddingProvider` so `embed`/`embed_batch` and the successful
    // `new_onnx` path can be tested without real model files.

    /// Builds the ONNX model from local files and checks the embedding length, that the
    /// embedding is not all zeros, and that a clone produces an identical embedding.
    #[test]
    fn test_onnx_embedding_fallback() {
        let model_path = Path::new("onnx/all-minilm-l12-v2.onnx");
        let tokenizer_path = Path::new("onnx/minilm_tokenizer.json");
        // Skip instead of failing when the model files are not present locally.
        if !model_path.exists() || !tokenizer_path.exists() {
            println!("Skipping ONNX test because model files aren't available");
            return;
        }

        let model = EmbeddingModel::new_onnx(model_path, tokenizer_path)
            .expect("failed to create ONNX embedding model");
        let expected_dim = model.dim();

        let text = "fn main() { let x = 42; }";
        let embedding = model.embed(text).expect("embedding failed");
        assert_eq!(embedding.len(), expected_dim);
        assert!(embedding.iter().any(|&x| x != 0.0), "embedding is all zeros");

        let cloned_model = model.clone();
        assert_eq!(cloned_model.dim(), expected_dim);
        let cloned_embedding = cloned_model
            .embed(text)
            .expect("embedding with cloned model failed");
        assert_eq!(embedding, cloned_embedding);
    }

    #[test]
    fn test_embedding_model_type_display() {
        assert_eq!(EmbeddingModelType::Onnx.to_string(), "ONNX");
    }

    #[test]
    fn test_embedding_model_type_default_and_dimension() {
        assert_eq!(EmbeddingModelType::default(), EmbeddingModelType::Onnx);
        assert_eq!(EmbeddingModelType::Onnx.default_dimension(), 384);
    }

    /// `new_onnx` must surface a provider construction failure as an `EmbeddingError`.
    #[test]
    fn test_embedding_model_new_onnx_invalid_path() {
        // Paths that are not expected to exist.
        let model_path = PathBuf::from("./nonexistent/model.onnx");
        let tokenizer_path = PathBuf::from("./nonexistent/tokenizer.json");
        let result = EmbeddingModel::new_onnx(&model_path, &tokenizer_path);
        assert!(matches!(result, Err(VectorDBError::EmbeddingError(_))));
    }
}