agcodex_core/context_engine/
embeddings.rs1use crate::config::Config;
8use crate::embeddings::EmbeddingsConfig;
9use crate::embeddings::EmbeddingsManager;
10use thiserror::Error;
11
/// A single embedding vector: one `f32` component per model dimension.
pub type EmbeddingVector = std::vec::Vec<f32>;
13
14#[derive(Debug, Error)]
15pub enum EmbeddingError {
16 #[error("not implemented")]
17 NotImplemented,
18
19 #[error(transparent)]
20 Io(#[from] std::io::Error),
21
22 #[error("provider error: {0}")]
23 Provider(String),
24
25 #[error("embeddings error: {0}")]
26 EmbeddingsError(#[from] crate::embeddings::EmbeddingError),
27}
28
29#[allow(async_fn_in_trait)]
30pub trait EmbeddingModel {
31 fn dimensions(&self) -> usize {
32 1536 }
34
35 async fn embed(&self, _text: &str) -> Result<EmbeddingVector, EmbeddingError> {
36 Err(EmbeddingError::NotImplemented)
37 }
38
39 async fn embed_batch(&self, _texts: &[String]) -> Result<Vec<EmbeddingVector>, EmbeddingError> {
40 Err(EmbeddingError::NotImplemented)
41 }
42}
43
44pub struct EmbeddingModelBridge {
67 manager: Option<EmbeddingsManager>,
69}
70
71impl EmbeddingModelBridge {
72 pub const fn new(manager: Option<EmbeddingsManager>) -> Self {
74 Self { manager }
75 }
76
77 pub fn from_embeddings_config(config: Option<EmbeddingsConfig>) -> Self {
82 let manager = EmbeddingsManager::new(config);
83 Self::new(Some(manager))
84 }
85
86 pub const fn from_config(_config: &Config) -> Self {
92 Self::new(None)
98 }
99
100 pub const fn disabled() -> Self {
102 Self::new(None)
103 }
104
105 pub fn is_enabled(&self) -> bool {
107 self.manager
108 .as_ref()
109 .map(|m| m.is_enabled())
110 .unwrap_or(false)
111 }
112}
113
114impl EmbeddingModel for EmbeddingModelBridge {
115 fn dimensions(&self) -> usize {
116 if let Some(manager) = &self.manager {
117 manager.current_dimensions().unwrap_or(1536)
118 } else {
119 1536 }
121 }
122
123 async fn embed(&self, text: &str) -> Result<EmbeddingVector, EmbeddingError> {
124 match &self.manager {
125 Some(manager) => {
126 if !manager.is_enabled() {
127 return Err(EmbeddingError::NotImplemented);
128 }
129
130 let result = manager.embed(text).await?;
131 match result {
132 Some(vector) => Ok(vector),
133 None => Err(EmbeddingError::NotImplemented),
134 }
135 }
136 None => Err(EmbeddingError::NotImplemented),
137 }
138 }
139
140 async fn embed_batch(&self, texts: &[String]) -> Result<Vec<EmbeddingVector>, EmbeddingError> {
141 match &self.manager {
142 Some(manager) => {
143 if !manager.is_enabled() {
144 return Err(EmbeddingError::NotImplemented);
145 }
146
147 let result = manager.embed_batch(texts).await?;
148 match result {
149 Some(vectors) => Ok(vectors),
150 None => Err(EmbeddingError::NotImplemented),
151 }
152 }
153 None => Err(EmbeddingError::NotImplemented),
154 }
155 }
156}
157
158pub struct NoOpEmbeddingModel;
162
163impl EmbeddingModel for NoOpEmbeddingModel {
164 fn dimensions(&self) -> usize {
165 1536
166 }
167
168 async fn embed(&self, _text: &str) -> Result<EmbeddingVector, EmbeddingError> {
169 Err(EmbeddingError::NotImplemented)
170 }
171
172 async fn embed_batch(&self, _texts: &[String]) -> Result<Vec<EmbeddingVector>, EmbeddingError> {
173 Err(EmbeddingError::NotImplemented)
174 }
175}
176
#[cfg(test)]
mod tests {
    use super::*;

    /// Implements `EmbeddingModel` without overriding anything, so every call
    /// exercises the trait's default methods.
    struct DefaultsOnlyModel;

    impl EmbeddingModel for DefaultsOnlyModel {}

    // Checks the enabled flag and the fallback dimensions of a disabled bridge.
    #[test]
    fn test_disabled_bridge_is_disabled_with_default_dimensions() {
        let bridge = EmbeddingModelBridge::disabled();
        assert!(!bridge.is_enabled());
        assert_eq!(bridge.dimensions(), 1536);
    }

    // `new(None)` should be indistinguishable from `disabled()`.
    #[test]
    fn test_bridge_without_manager_matches_disabled() {
        let bridge = EmbeddingModelBridge::new(None);
        assert!(!bridge.is_enabled());
        assert_eq!(bridge.dimensions(), 1536);
    }

    #[tokio::test]
    async fn test_disabled_bridge_returns_not_implemented() {
        let bridge = EmbeddingModelBridge::disabled();

        let result = bridge.embed("test").await;
        assert!(matches!(result, Err(EmbeddingError::NotImplemented)));

        let batch_result = bridge.embed_batch(&["test".to_string()]).await;
        assert!(matches!(batch_result, Err(EmbeddingError::NotImplemented)));
    }

    #[test]
    fn test_noop_embedding_model() {
        let model = NoOpEmbeddingModel;
        assert_eq!(model.dimensions(), 1536);
    }

    #[tokio::test]
    async fn test_noop_embedding_model_returns_not_implemented() {
        let model = NoOpEmbeddingModel;

        let result = model.embed("test").await;
        assert!(matches!(result, Err(EmbeddingError::NotImplemented)));

        let batch_result = model.embed_batch(&["test".to_string()]).await;
        assert!(matches!(batch_result, Err(EmbeddingError::NotImplemented)));
    }

    // Pins the trait's default method behavior independently of any implementor.
    #[tokio::test]
    async fn test_trait_defaults_return_not_implemented() {
        let model = DefaultsOnlyModel;
        assert_eq!(model.dimensions(), 1536);

        let result = model.embed("test").await;
        assert!(matches!(result, Err(EmbeddingError::NotImplemented)));

        let batch_result = model.embed_batch(&["test".to_string()]).await;
        assert!(matches!(batch_result, Err(EmbeddingError::NotImplemented)));
    }
}