skill_runtime/embeddings/
mod.rs

1//! Embedding provider abstraction for vector generation
2//!
3//! This module provides a trait-based abstraction for embedding generation,
4//! supporting multiple providers (FastEmbed, OpenAI, Ollama) with a unified interface.
5//!
6//! # Architecture
7//!
8//! ```text
9//! ┌──────────────────────────────────────────────────────────────┐
10//! │                  EmbeddingProvider Trait                     │
11//! │  embed_documents, embed_query, dimensions, model_name       │
12//! └──────────────────────────────────────────────────────────────┘
13//!                              │
14//!          ┌───────────────────┼───────────────────┐
15//!          ▼                   ▼                   ▼
16//!   ┌─────────────┐    ┌─────────────┐    ┌─────────────┐
17//!   │  FastEmbed  │    │   OpenAI    │    │   Ollama    │
18//!   │  (local)    │    │   (API)     │    │  (local)    │
19//!   └─────────────┘    └─────────────┘    └─────────────┘
20//! ```
21//!
22//! # Example
23//!
24//! ```ignore
25//! use skill_runtime::embeddings::{EmbeddingProvider, FastEmbedProvider, FastEmbedModel, EmbeddingConfig};
26//!
27//! // Create a provider
28//! let provider = FastEmbedProvider::new(FastEmbedModel::AllMiniLM)?;
29//!
30//! // Embed a query
31//! let query_embedding = provider.embed_query("search for kubernetes tools").await?;
32//!
33//! // Embed multiple documents
34//! let texts = vec!["doc1".to_string(), "doc2".to_string()];
35//! let embeddings = provider.embed_documents(texts).await?;
36//! ```
37
38mod types;
39mod fastembed;
40mod openai;
41mod ollama;
42mod factory;
43
44pub use types::*;
45pub use fastembed::FastEmbedProvider;
46pub use openai::OpenAIEmbedProvider;
47pub use ollama::OllamaProvider;
48pub use factory::{EmbeddingProviderFactory, create_provider};
49
50use async_trait::async_trait;
51use anyhow::Result;
52
53/// Trait for embedding generation providers
54///
55/// Implementors generate vector embeddings from text, supporting both
56/// single queries and batch document processing.
57#[async_trait]
58pub trait EmbeddingProvider: Send + Sync {
59    /// Generate embeddings for multiple documents
60    ///
61    /// # Arguments
62    /// * `texts` - List of text documents to embed
63    ///
64    /// # Returns
65    /// Vector of embeddings, one per input document, in the same order
66    async fn embed_documents(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>>;
67
68    /// Generate embedding for a single query
69    ///
70    /// Some providers optimize query embeddings differently than document embeddings.
71    /// By default, this calls embed_documents with a single item.
72    ///
73    /// # Arguments
74    /// * `text` - The query text to embed
75    async fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
76        let results = self.embed_documents(vec![text.to_string()]).await?;
77        results.into_iter().next().ok_or_else(|| {
78            anyhow::anyhow!("embed_documents returned empty result for single query")
79        })
80    }
81
82    /// Get the embedding dimension size
83    fn dimensions(&self) -> usize;
84
85    /// Get the model name/identifier
86    fn model_name(&self) -> &str;
87
88    /// Get the provider name (e.g., "fastembed", "openai", "ollama")
89    fn provider_name(&self) -> &str;
90
91    /// Check if the provider is available (API key set, server running, etc.)
92    async fn health_check(&self) -> Result<bool> {
93        // Default: try to embed a simple query
94        match self.embed_query("test").await {
95            Ok(_) => Ok(true),
96            Err(_) => Ok(false),
97        }
98    }
99
100    /// Get the maximum batch size for embed_documents
101    fn max_batch_size(&self) -> usize {
102        100 // Default, can be overridden
103    }
104
105    /// Embed documents in batches, respecting max_batch_size
106    async fn embed_documents_batched(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
107        let batch_size = self.max_batch_size();
108        if texts.len() <= batch_size {
109            return self.embed_documents(texts).await;
110        }
111
112        let mut all_embeddings = Vec::with_capacity(texts.len());
113        for chunk in texts.chunks(batch_size) {
114            let embeddings = self.embed_documents(chunk.to_vec()).await?;
115            all_embeddings.extend(embeddings);
116        }
117        Ok(all_embeddings)
118    }
119}
120
#[cfg(test)]
mod tests {
    use super::*;

    /// In-memory provider used to exercise the trait's default methods.
    ///
    /// Every embedding it produces is `dims` copies of `0.1`, and it reports a
    /// maximum batch size of 2 so that batching kicks in for small inputs.
    struct MockProvider {
        dims: usize,
    }

    #[async_trait]
    impl EmbeddingProvider for MockProvider {
        async fn embed_documents(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
            // One identical vector per input text.
            let constant = vec![0.1; self.dims];
            Ok(vec![constant; texts.len()])
        }

        fn dimensions(&self) -> usize {
            self.dims
        }

        fn model_name(&self) -> &str {
            "mock-model"
        }

        fn provider_name(&self) -> &str {
            "mock"
        }

        fn max_batch_size(&self) -> usize {
            2
        }
    }

    /// The default `embed_query` yields a vector of the provider's dimension.
    #[tokio::test]
    async fn test_embed_query_default() {
        let mock = MockProvider { dims: 384 };
        let vector = mock.embed_query("test query").await.unwrap();
        assert_eq!(vector.len(), 384);
    }

    /// Five documents with a batch size of 2 still produce five embeddings.
    #[tokio::test]
    async fn test_embed_documents_batched() {
        let mock = MockProvider { dims: 3 };

        let mut docs: Vec<String> = Vec::with_capacity(5);
        for i in 0..5 {
            docs.push(format!("doc{}", i));
        }

        let vectors = mock.embed_documents_batched(docs).await.unwrap();
        assert_eq!(vectors.len(), 5);
        assert!(vectors.iter().all(|v| v.len() == 3));
    }

    /// The default `health_check` reports healthy when embedding succeeds.
    #[tokio::test]
    async fn test_health_check_default() {
        let mock = MockProvider { dims: 3 };
        assert!(mock.health_check().await.unwrap());
    }
}