// clawft_kernel/embedding.rs
//! Pluggable embedding backends for ECC vector operations (K3c-G2).
//!
//! Provides the [`EmbeddingProvider`] trait that the [`WeaverEngine`] uses
//! to convert text into vector embeddings for HNSW storage and similarity
//! search. Ships with [`MockEmbeddingProvider`] for deterministic testing.

use std::fmt;
use std::time::Duration;

use async_trait::async_trait;
use serde::{Deserialize, Serialize};

// ---------------------------------------------------------------------------
// EmbeddingError
// ---------------------------------------------------------------------------

/// Errors that embedding backends may produce.
///
/// Derives `Clone`/`PartialEq`/`Eq` so callers and tests can compare and
/// propagate errors by value; all payloads (`usize`, `String`, `Duration`)
/// support these derives cheaply.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EmbeddingError {
    /// The underlying model has not been loaded yet.
    ModelNotLoaded,
    /// Vector dimensionality does not match the expected value.
    DimensionMismatch {
        /// Expected dimensionality.
        expected: usize,
        /// Actual dimensionality returned.
        got: usize,
    },
    /// Generic backend failure.
    BackendError(String),
    /// Rate-limited; caller should retry after the given duration.
    RateLimited {
        /// How long to wait before retrying.
        retry_after: Duration,
    },
}

impl fmt::Display for EmbeddingError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::ModelNotLoaded => write!(f, "embedding model not loaded"),
            Self::DimensionMismatch { expected, got } => {
                write!(f, "dimension mismatch: expected {expected}, got {got}")
            }
            Self::BackendError(msg) => write!(f, "embedding backend error: {msg}"),
            Self::RateLimited { retry_after } => {
                write!(f, "rate limited, retry after {}ms", retry_after.as_millis())
            }
        }
    }
}

impl std::error::Error for EmbeddingError {}

// ---------------------------------------------------------------------------
// EmbeddingProvider trait
// ---------------------------------------------------------------------------

60/// Trait for pluggable embedding backends.
61///
62/// Implementations convert text into fixed-dimensionality float vectors
63/// suitable for HNSW indexing and cosine similarity search.
64#[async_trait]
65pub trait EmbeddingProvider: Send + Sync {
66    /// Embed a single text chunk into a vector.
67    async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError>;
68
69    /// Embed a batch of text chunks.
70    ///
71    /// The default implementation calls [`embed`](Self::embed) in a loop.
72    /// Backends that support native batching should override this for
73    /// efficiency.
74    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
75        let mut results = Vec::with_capacity(texts.len());
76        for text in texts {
77            results.push(self.embed(text).await?);
78        }
79        Ok(results)
80    }
81
82    /// Dimensionality of the output vectors.
83    fn dimensions(&self) -> usize;
84
85    /// Name of the embedding model (for metadata tracking).
86    fn model_name(&self) -> &str;
87}
88
// ---------------------------------------------------------------------------
// MockEmbeddingProvider
// ---------------------------------------------------------------------------

93/// Deterministic embedding provider for testing.
94///
95/// Produces vectors derived from a SHA-256 hash of the input text,
96/// ensuring reproducible results without any external model dependency.
97pub struct MockEmbeddingProvider {
98    /// Output vector dimensionality.
99    pub dims: usize,
100}
101
102impl MockEmbeddingProvider {
103    /// Create a mock provider with the given output dimensionality.
104    pub fn new(dims: usize) -> Self {
105        Self { dims }
106    }
107
108    /// Deterministic hash-based embedding generation.
109    fn hash_embed(&self, text: &str) -> Vec<f32> {
110        use sha2::{Digest, Sha256};
111        let mut hasher = Sha256::new();
112        hasher.update(text.as_bytes());
113        let hash = hasher.finalize();
114
115        let mut vec = Vec::with_capacity(self.dims);
116        for i in 0..self.dims {
117            // Cycle through hash bytes, normalise to [-1, 1]
118            let byte = hash[i % 32];
119            vec.push((byte as f32 / 128.0) - 1.0);
120        }
121        vec
122    }
123}
124
125#[async_trait]
126impl EmbeddingProvider for MockEmbeddingProvider {
127    async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
128        Ok(self.hash_embed(text))
129    }
130
131    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
132        Ok(texts.iter().map(|t| self.hash_embed(t)).collect())
133    }
134
135    fn dimensions(&self) -> usize {
136        self.dims
137    }
138
139    fn model_name(&self) -> &str {
140        "mock-sha256"
141    }
142}
143
// ---------------------------------------------------------------------------
// LlmEmbeddingProvider
// ---------------------------------------------------------------------------

148/// Configuration for the LLM API embedding backend.
149#[derive(Debug, Clone, Serialize, Deserialize)]
150pub struct LlmEmbeddingConfig {
151    /// Model identifier (e.g., "text-embedding-3-small").
152    pub model: String,
153    /// Output vector dimensionality (e.g., 384 or 1536).
154    pub dimensions: usize,
155    /// Maximum texts per API call for batching.
156    pub batch_size: usize,
157    /// Whether the API is currently available.
158    pub api_available: bool,
159}
160
161impl Default for LlmEmbeddingConfig {
162    fn default() -> Self {
163        Self {
164            model: "text-embedding-3-small".to_string(),
165            dimensions: 384,
166            batch_size: 16,
167            api_available: false,
168        }
169    }
170}
171
/// LLM-backed embedding provider that calls the clawft-llm provider layer.
///
/// Uses the model's embedding endpoint to produce real semantic vectors.
/// When no API is configured (or the API is unavailable), falls back to
/// [`MockEmbeddingProvider`] for deterministic hash-based embeddings.
pub struct LlmEmbeddingProvider {
    // Backend configuration: model id, dimensionality, batch size, availability.
    config: LlmEmbeddingConfig,
    // Deterministic hash-based provider used whenever the API path errors out.
    fallback: MockEmbeddingProvider,
}

182impl LlmEmbeddingProvider {
183    /// Create a new LLM embedding provider with the given configuration.
184    pub fn new(config: LlmEmbeddingConfig) -> Self {
185        let fallback = MockEmbeddingProvider::new(config.dimensions);
186        Self { config, fallback }
187    }
188
189    /// Create from a weave.toml-style configuration table.
190    ///
191    /// Expected keys: `model` (string), `dimensions` (int), `batch_size` (int).
192    /// If the table is missing or incomplete, returns a provider with defaults
193    /// that falls back to mock embeddings.
194    pub fn from_config(table: &std::collections::HashMap<String, String>) -> Self {
195        let model = table
196            .get("model")
197            .cloned()
198            .unwrap_or_else(|| "text-embedding-3-small".to_string());
199        let dimensions = table
200            .get("dimensions")
201            .and_then(|d| d.parse::<usize>().ok())
202            .unwrap_or(384);
203        let batch_size = table
204            .get("batch_size")
205            .and_then(|b| b.parse::<usize>().ok())
206            .unwrap_or(16);
207        let api_available = table
208            .get("api_available")
209            .map(|v| v == "true")
210            .unwrap_or(false);
211
212        Self::new(LlmEmbeddingConfig {
213            model,
214            dimensions,
215            batch_size,
216            api_available,
217        })
218    }
219
220    /// Whether the LLM API is available (non-fallback mode).
221    pub fn is_api_available(&self) -> bool {
222        self.config.api_available
223    }
224
225    /// Get the underlying configuration.
226    pub fn config(&self) -> &LlmEmbeddingConfig {
227        &self.config
228    }
229
230    /// Perform an LLM API embedding call.
231    ///
232    /// In a production deployment this would call the clawft-llm provider's
233    /// embed endpoint. Currently returns an error so that the `embed()` method
234    /// falls back to the mock provider.
235    async fn call_llm_api(&self, _text: &str) -> Result<Vec<f32>, EmbeddingError> {
236        if !self.config.api_available {
237            return Err(EmbeddingError::BackendError(
238                "LLM API not configured; using fallback".to_string(),
239            ));
240        }
241        // Production implementation would call:
242        //   provider.embed(EmbedRequest { model, input: vec![text], dimensions })
243        // For now, the API path is gated behind api_available.
244        Err(EmbeddingError::BackendError(
245            "LLM API call not yet wired to clawft-llm provider".to_string(),
246        ))
247    }
248
249    /// Perform a batched LLM API embedding call.
250    async fn call_llm_api_batch(
251        &self,
252        _texts: &[&str],
253    ) -> Result<Vec<Vec<f32>>, EmbeddingError> {
254        if !self.config.api_available {
255            return Err(EmbeddingError::BackendError(
256                "LLM API not configured; using fallback".to_string(),
257            ));
258        }
259        Err(EmbeddingError::BackendError(
260            "LLM API batch call not yet wired to clawft-llm provider".to_string(),
261        ))
262    }
263}
264
265#[async_trait]
266impl EmbeddingProvider for LlmEmbeddingProvider {
267    async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
268        // Try the LLM API first; fall back to mock on any error.
269        match self.call_llm_api(text).await {
270            Ok(vec) => {
271                if vec.len() != self.config.dimensions {
272                    return Err(EmbeddingError::DimensionMismatch {
273                        expected: self.config.dimensions,
274                        got: vec.len(),
275                    });
276                }
277                Ok(vec)
278            }
279            Err(_) => self.fallback.embed(text).await,
280        }
281    }
282
283    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
284        // Try the LLM API batch endpoint; fall back to mock.
285        match self.call_llm_api_batch(texts).await {
286            Ok(vecs) => {
287                for v in &vecs {
288                    if v.len() != self.config.dimensions {
289                        return Err(EmbeddingError::DimensionMismatch {
290                            expected: self.config.dimensions,
291                            got: v.len(),
292                        });
293                    }
294                }
295                Ok(vecs)
296            }
297            Err(_) => self.fallback.embed_batch(texts).await,
298        }
299    }
300
301    fn dimensions(&self) -> usize {
302        self.config.dimensions
303    }
304
305    fn model_name(&self) -> &str {
306        &self.config.model
307    }
308}
309
// ---------------------------------------------------------------------------
// select_embedding_provider
// ---------------------------------------------------------------------------

314/// Select the best available embedding provider based on configuration.
315///
316/// Priority order:
317/// 1. ONNX local model if `onnx_model_path` points to a valid `.onnx` file
318/// 2. LLM API if llm_embedding config is present
319/// 3. Mock (fallback, for testing or when no backend available)
320pub fn select_embedding_provider(
321    llm_config: Option<LlmEmbeddingConfig>,
322) -> Box<dyn EmbeddingProvider> {
323    // Try ONNX first: check standard model locations.
324    let onnx_paths = onnx_model_search_paths();
325    for path in &onnx_paths {
326        if path.exists() {
327            let provider = crate::embedding_onnx::OnnxEmbeddingProvider::new(path);
328            if provider.is_runtime_available() {
329                tracing::info!("Using ONNX embedding provider from {}", path.display());
330                return Box::new(provider);
331            }
332        }
333    }
334
335    if let Some(config) = llm_config {
336        return Box::new(LlmEmbeddingProvider::new(config));
337    }
338    Box::new(MockEmbeddingProvider::new(64))
339}
340
341/// Standard search paths for the ONNX embedding model.
342///
343/// Looks in (in order):
344/// 1. `.weftos/models/all-MiniLM-L6-v2.onnx` (project-local)
345/// 2. `$HOME/.weftos/models/all-MiniLM-L6-v2.onnx` (user-global)
346/// 3. `$WEFTOS_MODEL_PATH` environment variable
347fn onnx_model_search_paths() -> Vec<std::path::PathBuf> {
348    let model_name = "all-MiniLM-L6-v2.onnx";
349    let mut paths = Vec::new();
350
351    // Project-local.
352    paths.push(std::path::PathBuf::from(format!(".weftos/models/{model_name}")));
353
354    // User-global.
355    if let Some(home) = dirs_home() {
356        paths.push(home.join(format!(".weftos/models/{model_name}")));
357    }
358
359    // Env override.
360    if let Ok(env_path) = std::env::var("WEFTOS_MODEL_PATH") {
361        paths.push(std::path::PathBuf::from(env_path));
362    }
363
364    paths
365}
366
/// Get the user's home directory from the `HOME` environment variable.
///
/// Returns `None` when `HOME` is unset or not valid Unicode.
fn dirs_home() -> Option<std::path::PathBuf> {
    match std::env::var("HOME") {
        Ok(home) => Some(std::path::PathBuf::from(home)),
        Err(_) => None,
    }
}

// ── Tests ─────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    // ── MockEmbeddingProvider tests ──────────────────────────────────

    #[tokio::test]
    async fn mock_embed_returns_correct_dimensions() {
        let provider = MockEmbeddingProvider::new(64);
        let vec = provider.embed("hello world").await.unwrap();
        assert_eq!(vec.len(), 64);
    }

    #[tokio::test]
    async fn mock_embed_deterministic() {
        let provider = MockEmbeddingProvider::new(32);
        let v1 = provider.embed("test input").await.unwrap();
        let v2 = provider.embed("test input").await.unwrap();
        assert_eq!(v1, v2);
    }

    #[tokio::test]
    async fn mock_embed_different_inputs_differ() {
        let provider = MockEmbeddingProvider::new(32);
        let v1 = provider.embed("alpha").await.unwrap();
        let v2 = provider.embed("beta").await.unwrap();
        assert_ne!(v1, v2);
    }

    #[tokio::test]
    async fn mock_embed_batch() {
        let provider = MockEmbeddingProvider::new(16);
        let results = provider.embed_batch(&["a", "b", "c"]).await.unwrap();
        assert_eq!(results.len(), 3);
        for v in &results {
            assert_eq!(v.len(), 16);
        }
    }

    // The overridden batch path must agree with one-at-a-time embedding.
    #[tokio::test]
    async fn mock_embed_batch_matches_individual() {
        let provider = MockEmbeddingProvider::new(8);
        let batch = provider.embed_batch(&["x", "y"]).await.unwrap();
        let x = provider.embed("x").await.unwrap();
        let y = provider.embed("y").await.unwrap();
        assert_eq!(batch[0], x);
        assert_eq!(batch[1], y);
    }

    #[test]
    fn mock_model_name() {
        let provider = MockEmbeddingProvider::new(16);
        assert_eq!(provider.model_name(), "mock-sha256");
    }

    #[test]
    fn mock_dimensions() {
        let provider = MockEmbeddingProvider::new(128);
        assert_eq!(provider.dimensions(), 128);
    }

    #[test]
    fn embedding_error_display() {
        let err = EmbeddingError::DimensionMismatch {
            expected: 384,
            got: 256,
        };
        assert!(err.to_string().contains("384"));
        assert!(err.to_string().contains("256"));

        let err2 = EmbeddingError::ModelNotLoaded;
        assert!(err2.to_string().contains("not loaded"));
    }

    // ── LlmEmbeddingProvider tests ───────────────────────────────────
    // All use api_available = false, so they exercise the mock fallback.

    #[tokio::test]
    async fn llm_provider_falls_back_to_mock_when_api_unavailable() {
        let config = LlmEmbeddingConfig {
            api_available: false,
            dimensions: 64,
            ..Default::default()
        };
        let provider = LlmEmbeddingProvider::new(config);
        // Should succeed via fallback, not error.
        let vec = provider.embed("hello world").await.unwrap();
        assert_eq!(vec.len(), 64);
    }

    #[tokio::test]
    async fn llm_provider_fallback_is_deterministic() {
        let config = LlmEmbeddingConfig {
            api_available: false,
            dimensions: 32,
            ..Default::default()
        };
        let provider = LlmEmbeddingProvider::new(config);
        let v1 = provider.embed("test").await.unwrap();
        let v2 = provider.embed("test").await.unwrap();
        assert_eq!(v1, v2);
    }

    #[tokio::test]
    async fn llm_provider_batch_fallback() {
        let config = LlmEmbeddingConfig {
            api_available: false,
            dimensions: 16,
            ..Default::default()
        };
        let provider = LlmEmbeddingProvider::new(config);
        let results = provider.embed_batch(&["a", "b", "c"]).await.unwrap();
        assert_eq!(results.len(), 3);
        for v in &results {
            assert_eq!(v.len(), 16);
        }
    }

    #[test]
    fn llm_provider_reports_model_name() {
        let config = LlmEmbeddingConfig {
            model: "custom-embed-v1".to_string(),
            ..Default::default()
        };
        let provider = LlmEmbeddingProvider::new(config);
        assert_eq!(provider.model_name(), "custom-embed-v1");
    }

    #[test]
    fn llm_provider_reports_dimensions() {
        let config = LlmEmbeddingConfig {
            dimensions: 1536,
            ..Default::default()
        };
        let provider = LlmEmbeddingProvider::new(config);
        assert_eq!(provider.dimensions(), 1536);
    }

    #[test]
    fn llm_provider_api_availability_check() {
        let unavailable = LlmEmbeddingProvider::new(LlmEmbeddingConfig::default());
        assert!(!unavailable.is_api_available());

        let available = LlmEmbeddingProvider::new(LlmEmbeddingConfig {
            api_available: true,
            ..Default::default()
        });
        assert!(available.is_api_available());
    }

    #[test]
    fn llm_provider_from_config_defaults() {
        let table = std::collections::HashMap::new();
        let provider = LlmEmbeddingProvider::from_config(&table);
        assert_eq!(provider.dimensions(), 384);
        assert_eq!(provider.model_name(), "text-embedding-3-small");
        assert!(!provider.is_api_available());
    }

    #[test]
    fn llm_provider_from_config_custom() {
        let mut table = std::collections::HashMap::new();
        table.insert("model".to_string(), "my-model".to_string());
        table.insert("dimensions".to_string(), "768".to_string());
        table.insert("batch_size".to_string(), "32".to_string());
        table.insert("api_available".to_string(), "true".to_string());
        let provider = LlmEmbeddingProvider::from_config(&table);
        assert_eq!(provider.model_name(), "my-model");
        assert_eq!(provider.dimensions(), 768);
        assert_eq!(provider.config().batch_size, 32);
        assert!(provider.is_api_available());
    }

    // ── select_embedding_provider tests ──────────────────────────────
    // NOTE(review): these assume no ONNX model exists at the standard search
    // paths in the test environment — TODO confirm this holds in CI.

    #[test]
    fn select_provider_returns_mock_when_no_config() {
        let provider = select_embedding_provider(None);
        assert_eq!(provider.dimensions(), 64);
        assert_eq!(provider.model_name(), "mock-sha256");
    }

    #[test]
    fn select_provider_returns_llm_when_config_present() {
        let config = LlmEmbeddingConfig {
            model: "test-embed".to_string(),
            dimensions: 256,
            ..Default::default()
        };
        let provider = select_embedding_provider(Some(config));
        assert_eq!(provider.dimensions(), 256);
        assert_eq!(provider.model_name(), "test-embed");
    }

    #[tokio::test]
    async fn llm_provider_fallback_matches_mock() {
        let config = LlmEmbeddingConfig {
            api_available: false,
            dimensions: 32,
            ..Default::default()
        };
        let llm = LlmEmbeddingProvider::new(config);
        let mock = MockEmbeddingProvider::new(32);
        let llm_vec = llm.embed("same input").await.unwrap();
        let mock_vec = mock.embed("same input").await.unwrap();
        // Fallback should produce identical results to mock.
        assert_eq!(llm_vec, mock_vec);
    }
}