1use std::fmt;
8use std::time::Duration;
9
10use async_trait::async_trait;
11use serde::{Deserialize, Serialize};
12
/// Errors that can occur while producing embeddings.
///
/// Marked `#[non_exhaustive]` so new failure modes can be added without
/// breaking downstream `match` statements.
#[non_exhaustive]
#[derive(Debug)]
pub enum EmbeddingError {
    /// The embedding model has not been loaded.
    ModelNotLoaded,
    /// A returned vector's length did not match the configured dimensionality.
    DimensionMismatch {
        /// Number of dimensions the caller configured/expected.
        expected: usize,
        /// Number of dimensions actually produced by the backend.
        got: usize,
    },
    /// The underlying backend reported an error; the message is passed through verbatim.
    BackendError(String),
    /// The backend rate-limited the request.
    RateLimited {
        /// Suggested wait time before retrying.
        retry_after: Duration,
    },
}
38
39impl fmt::Display for EmbeddingError {
40 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41 match self {
42 Self::ModelNotLoaded => write!(f, "embedding model not loaded"),
43 Self::DimensionMismatch { expected, got } => {
44 write!(f, "dimension mismatch: expected {expected}, got {got}")
45 }
46 Self::BackendError(msg) => write!(f, "embedding backend error: {msg}"),
47 Self::RateLimited { retry_after } => {
48 write!(f, "rate limited, retry after {}ms", retry_after.as_millis())
49 }
50 }
51 }
52}
53
// Marker impl: no wrapped source error, so the default `Error` methods suffice.
impl std::error::Error for EmbeddingError {}
55
/// Abstraction over embedding backends (mock, LLM API, ONNX, ...).
///
/// Implementors must be `Send + Sync` so a provider can be shared behind
/// `Box<dyn EmbeddingProvider>` across async tasks.
#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
    /// Embed a single text into a vector of `dimensions()` f32 components.
    async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError>;

    /// Embed a batch of texts.
    ///
    /// The default implementation calls `embed` sequentially and stops at the
    /// first error; backends with a native batch endpoint should override it.
    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
        let mut results = Vec::with_capacity(texts.len());
        for text in texts {
            results.push(self.embed(text).await?);
        }
        Ok(results)
    }

    /// Length of the vectors produced by `embed`.
    fn dimensions(&self) -> usize;

    /// Identifier of the underlying model (useful for logging/telemetry).
    fn model_name(&self) -> &str;
}
88
/// Deterministic, model-free embedding provider for tests and fallback use.
///
/// Vectors are derived from a SHA-256 hash of the input text, so equal inputs
/// always yield equal vectors without any network or model file.
pub struct MockEmbeddingProvider {
    /// Number of components in each produced vector.
    pub dims: usize,
}
101
102impl MockEmbeddingProvider {
103 pub fn new(dims: usize) -> Self {
105 Self { dims }
106 }
107
108 fn hash_embed(&self, text: &str) -> Vec<f32> {
110 use sha2::{Digest, Sha256};
111 let mut hasher = Sha256::new();
112 hasher.update(text.as_bytes());
113 let hash = hasher.finalize();
114
115 let mut vec = Vec::with_capacity(self.dims);
116 for i in 0..self.dims {
117 let byte = hash[i % 32];
119 vec.push((byte as f32 / 128.0) - 1.0);
120 }
121 vec
122 }
123}
124
125#[async_trait]
126impl EmbeddingProvider for MockEmbeddingProvider {
127 async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
128 Ok(self.hash_embed(text))
129 }
130
131 async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
132 Ok(texts.iter().map(|t| self.hash_embed(t)).collect())
133 }
134
135 fn dimensions(&self) -> usize {
136 self.dims
137 }
138
139 fn model_name(&self) -> &str {
140 "mock-sha256"
141 }
142}
143
/// Configuration for `LlmEmbeddingProvider`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmEmbeddingConfig {
    /// Model identifier sent to the embedding API.
    pub model: String,
    /// Expected vector length; responses of any other length are rejected.
    pub dimensions: usize,
    /// Maximum number of texts per batch request.
    pub batch_size: usize,
    /// Whether the remote API is configured; when false the provider uses
    /// the deterministic mock fallback.
    pub api_available: bool,
}
160
161impl Default for LlmEmbeddingConfig {
162 fn default() -> Self {
163 Self {
164 model: "text-embedding-3-small".to_string(),
165 dimensions: 384,
166 batch_size: 16,
167 api_available: false,
168 }
169 }
170}
171
/// Embedding provider backed by a remote LLM embedding API, with a
/// deterministic mock fallback when the API is unavailable or errors.
pub struct LlmEmbeddingProvider {
    // Runtime configuration (model, dimensions, batching, availability).
    config: LlmEmbeddingConfig,
    // Hash-based fallback sized to `config.dimensions`.
    fallback: MockEmbeddingProvider,
}
181
182impl LlmEmbeddingProvider {
183 pub fn new(config: LlmEmbeddingConfig) -> Self {
185 let fallback = MockEmbeddingProvider::new(config.dimensions);
186 Self { config, fallback }
187 }
188
189 pub fn from_config(table: &std::collections::HashMap<String, String>) -> Self {
195 let model = table
196 .get("model")
197 .cloned()
198 .unwrap_or_else(|| "text-embedding-3-small".to_string());
199 let dimensions = table
200 .get("dimensions")
201 .and_then(|d| d.parse::<usize>().ok())
202 .unwrap_or(384);
203 let batch_size = table
204 .get("batch_size")
205 .and_then(|b| b.parse::<usize>().ok())
206 .unwrap_or(16);
207 let api_available = table
208 .get("api_available")
209 .map(|v| v == "true")
210 .unwrap_or(false);
211
212 Self::new(LlmEmbeddingConfig {
213 model,
214 dimensions,
215 batch_size,
216 api_available,
217 })
218 }
219
220 pub fn is_api_available(&self) -> bool {
222 self.config.api_available
223 }
224
225 pub fn config(&self) -> &LlmEmbeddingConfig {
227 &self.config
228 }
229
230 async fn call_llm_api(&self, _text: &str) -> Result<Vec<f32>, EmbeddingError> {
236 if !self.config.api_available {
237 return Err(EmbeddingError::BackendError(
238 "LLM API not configured; using fallback".to_string(),
239 ));
240 }
241 Err(EmbeddingError::BackendError(
245 "LLM API call not yet wired to clawft-llm provider".to_string(),
246 ))
247 }
248
249 async fn call_llm_api_batch(
251 &self,
252 _texts: &[&str],
253 ) -> Result<Vec<Vec<f32>>, EmbeddingError> {
254 if !self.config.api_available {
255 return Err(EmbeddingError::BackendError(
256 "LLM API not configured; using fallback".to_string(),
257 ));
258 }
259 Err(EmbeddingError::BackendError(
260 "LLM API batch call not yet wired to clawft-llm provider".to_string(),
261 ))
262 }
263}
264
265#[async_trait]
266impl EmbeddingProvider for LlmEmbeddingProvider {
267 async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
268 match self.call_llm_api(text).await {
270 Ok(vec) => {
271 if vec.len() != self.config.dimensions {
272 return Err(EmbeddingError::DimensionMismatch {
273 expected: self.config.dimensions,
274 got: vec.len(),
275 });
276 }
277 Ok(vec)
278 }
279 Err(_) => self.fallback.embed(text).await,
280 }
281 }
282
283 async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
284 match self.call_llm_api_batch(texts).await {
286 Ok(vecs) => {
287 for v in &vecs {
288 if v.len() != self.config.dimensions {
289 return Err(EmbeddingError::DimensionMismatch {
290 expected: self.config.dimensions,
291 got: v.len(),
292 });
293 }
294 }
295 Ok(vecs)
296 }
297 Err(_) => self.fallback.embed_batch(texts).await,
298 }
299 }
300
301 fn dimensions(&self) -> usize {
302 self.config.dimensions
303 }
304
305 fn model_name(&self) -> &str {
306 &self.config.model
307 }
308}
309
310pub fn select_embedding_provider(
321 llm_config: Option<LlmEmbeddingConfig>,
322) -> Box<dyn EmbeddingProvider> {
323 let onnx_paths = onnx_model_search_paths();
325 for path in &onnx_paths {
326 if path.exists() {
327 let provider = crate::embedding_onnx::OnnxEmbeddingProvider::new(path);
328 if provider.is_runtime_available() {
329 tracing::info!("Using ONNX embedding provider from {}", path.display());
330 return Box::new(provider);
331 }
332 }
333 }
334
335 if let Some(config) = llm_config {
336 return Box::new(LlmEmbeddingProvider::new(config));
337 }
338 Box::new(MockEmbeddingProvider::new(64))
339}
340
341fn onnx_model_search_paths() -> Vec<std::path::PathBuf> {
348 let model_name = "all-MiniLM-L6-v2.onnx";
349 let mut paths = Vec::new();
350
351 paths.push(std::path::PathBuf::from(format!(".weftos/models/{model_name}")));
353
354 if let Some(home) = dirs_home() {
356 paths.push(home.join(format!(".weftos/models/{model_name}")));
357 }
358
359 if let Ok(env_path) = std::env::var("WEFTOS_MODEL_PATH") {
361 paths.push(std::path::PathBuf::from(env_path));
362 }
363
364 paths
365}
366
/// Best-effort home-directory lookup via the `HOME` environment variable
/// (Unix convention); `None` when unset or not valid Unicode.
fn dirs_home() -> Option<std::path::PathBuf> {
    match std::env::var("HOME") {
        Ok(home) => Some(std::path::PathBuf::from(home)),
        Err(_) => None,
    }
}
373
#[cfg(test)]
mod tests {
    //! Unit tests covering the mock provider, the LLM provider's fallback
    //! behavior, config parsing, and provider selection.
    use super::*;

    // The mock must honor the requested dimensionality exactly.
    #[tokio::test]
    async fn mock_embed_returns_correct_dimensions() {
        let provider = MockEmbeddingProvider::new(64);
        let vec = provider.embed("hello world").await.unwrap();
        assert_eq!(vec.len(), 64);
    }

    // Hash-based embedding: identical input, identical vector.
    #[tokio::test]
    async fn mock_embed_deterministic() {
        let provider = MockEmbeddingProvider::new(32);
        let v1 = provider.embed("test input").await.unwrap();
        let v2 = provider.embed("test input").await.unwrap();
        assert_eq!(v1, v2);
    }

    // Different inputs should hash to different vectors.
    #[tokio::test]
    async fn mock_embed_different_inputs_differ() {
        let provider = MockEmbeddingProvider::new(32);
        let v1 = provider.embed("alpha").await.unwrap();
        let v2 = provider.embed("beta").await.unwrap();
        assert_ne!(v1, v2);
    }

    // Batch embedding returns one vector per input, each correctly sized.
    #[tokio::test]
    async fn mock_embed_batch() {
        let provider = MockEmbeddingProvider::new(16);
        let results = provider.embed_batch(&["a", "b", "c"]).await.unwrap();
        assert_eq!(results.len(), 3);
        for v in &results {
            assert_eq!(v.len(), 16);
        }
    }

    // The batch path must agree with per-item embedding.
    #[tokio::test]
    async fn mock_embed_batch_matches_individual() {
        let provider = MockEmbeddingProvider::new(8);
        let batch = provider.embed_batch(&["x", "y"]).await.unwrap();
        let x = provider.embed("x").await.unwrap();
        let y = provider.embed("y").await.unwrap();
        assert_eq!(batch[0], x);
        assert_eq!(batch[1], y);
    }

    #[test]
    fn mock_model_name() {
        let provider = MockEmbeddingProvider::new(16);
        assert_eq!(provider.model_name(), "mock-sha256");
    }

    #[test]
    fn mock_dimensions() {
        let provider = MockEmbeddingProvider::new(128);
        assert_eq!(provider.dimensions(), 128);
    }

    // Display output should include both dimension counts and a readable
    // message for the no-model case.
    #[test]
    fn embedding_error_display() {
        let err = EmbeddingError::DimensionMismatch {
            expected: 384,
            got: 256,
        };
        assert!(err.to_string().contains("384"));
        assert!(err.to_string().contains("256"));

        let err2 = EmbeddingError::ModelNotLoaded;
        assert!(err2.to_string().contains("not loaded"));
    }

    // With the API disabled, embed() must transparently use the mock and
    // still produce vectors of the configured size.
    #[tokio::test]
    async fn llm_provider_falls_back_to_mock_when_api_unavailable() {
        let config = LlmEmbeddingConfig {
            api_available: false,
            dimensions: 64,
            ..Default::default()
        };
        let provider = LlmEmbeddingProvider::new(config);
        let vec = provider.embed("hello world").await.unwrap();
        assert_eq!(vec.len(), 64);
    }

    // Fallback inherits the mock's determinism.
    #[tokio::test]
    async fn llm_provider_fallback_is_deterministic() {
        let config = LlmEmbeddingConfig {
            api_available: false,
            dimensions: 32,
            ..Default::default()
        };
        let provider = LlmEmbeddingProvider::new(config);
        let v1 = provider.embed("test").await.unwrap();
        let v2 = provider.embed("test").await.unwrap();
        assert_eq!(v1, v2);
    }

    // Batch fallback: one correctly sized vector per input.
    #[tokio::test]
    async fn llm_provider_batch_fallback() {
        let config = LlmEmbeddingConfig {
            api_available: false,
            dimensions: 16,
            ..Default::default()
        };
        let provider = LlmEmbeddingProvider::new(config);
        let results = provider.embed_batch(&["a", "b", "c"]).await.unwrap();
        assert_eq!(results.len(), 3);
        for v in &results {
            assert_eq!(v.len(), 16);
        }
    }

    #[test]
    fn llm_provider_reports_model_name() {
        let config = LlmEmbeddingConfig {
            model: "custom-embed-v1".to_string(),
            ..Default::default()
        };
        let provider = LlmEmbeddingProvider::new(config);
        assert_eq!(provider.model_name(), "custom-embed-v1");
    }

    #[test]
    fn llm_provider_reports_dimensions() {
        let config = LlmEmbeddingConfig {
            dimensions: 1536,
            ..Default::default()
        };
        let provider = LlmEmbeddingProvider::new(config);
        assert_eq!(provider.dimensions(), 1536);
    }

    // Default config has the API off; an explicit flag turns it on.
    #[test]
    fn llm_provider_api_availability_check() {
        let unavailable = LlmEmbeddingProvider::new(LlmEmbeddingConfig::default());
        assert!(!unavailable.is_api_available());

        let available = LlmEmbeddingProvider::new(LlmEmbeddingConfig {
            api_available: true,
            ..Default::default()
        });
        assert!(available.is_api_available());
    }

    // Empty table: every field falls back to the documented defaults.
    #[test]
    fn llm_provider_from_config_defaults() {
        let table = std::collections::HashMap::new();
        let provider = LlmEmbeddingProvider::from_config(&table);
        assert_eq!(provider.dimensions(), 384);
        assert_eq!(provider.model_name(), "text-embedding-3-small");
        assert!(!provider.is_api_available());
    }

    // Fully populated table: every field is parsed from the strings.
    #[test]
    fn llm_provider_from_config_custom() {
        let mut table = std::collections::HashMap::new();
        table.insert("model".to_string(), "my-model".to_string());
        table.insert("dimensions".to_string(), "768".to_string());
        table.insert("batch_size".to_string(), "32".to_string());
        table.insert("api_available".to_string(), "true".to_string());
        let provider = LlmEmbeddingProvider::from_config(&table);
        assert_eq!(provider.model_name(), "my-model");
        assert_eq!(provider.dimensions(), 768);
        assert_eq!(provider.config().batch_size, 32);
        assert!(provider.is_api_available());
    }

    // No ONNX model and no LLM config: selection bottoms out at the mock.
    // NOTE(review): assumes no ONNX model file exists in the test
    // environment's search paths — see onnx_model_search_paths().
    #[test]
    fn select_provider_returns_mock_when_no_config() {
        let provider = select_embedding_provider(None);
        assert_eq!(provider.dimensions(), 64);
        assert_eq!(provider.model_name(), "mock-sha256");
    }

    #[test]
    fn select_provider_returns_llm_when_config_present() {
        let config = LlmEmbeddingConfig {
            model: "test-embed".to_string(),
            dimensions: 256,
            ..Default::default()
        };
        let provider = select_embedding_provider(Some(config));
        assert_eq!(provider.dimensions(), 256);
        assert_eq!(provider.model_name(), "test-embed");
    }

    // The LLM provider's fallback is exactly the mock: same input, same
    // dimensions, same vector.
    #[tokio::test]
    async fn llm_provider_fallback_matches_mock() {
        let config = LlmEmbeddingConfig {
            api_available: false,
            dimensions: 32,
            ..Default::default()
        };
        let llm = LlmEmbeddingProvider::new(config);
        let mock = MockEmbeddingProvider::new(32);
        let llm_vec = llm.embed("same input").await.unwrap();
        let mock_vec = mock.embed("same input").await.unwrap();
        assert_eq!(llm_vec, mock_vec);
    }
}