skill_runtime/embeddings/mod.rs
1//! Embedding provider abstraction for vector generation
2//!
3//! This module provides a trait-based abstraction for embedding generation,
4//! supporting multiple providers (FastEmbed, OpenAI, Ollama) with a unified interface.
5//!
6//! # Architecture
7//!
8//! ```text
9//! ┌──────────────────────────────────────────────────────────────┐
10//! │ EmbeddingProvider Trait │
11//! │ embed_documents, embed_query, dimensions, model_name │
12//! └──────────────────────────────────────────────────────────────┘
13//! │
14//! ┌───────────────────┼───────────────────┐
15//! ▼ ▼ ▼
16//! ┌─────────────┐ ┌─────────────┐ ┌─────────────┐
17//! │ FastEmbed │ │ OpenAI │ │ Ollama │
18//! │ (local) │ │ (API) │ │ (local) │
19//! └─────────────┘ └─────────────┘ └─────────────┘
20//! ```
21//!
22//! # Example
23//!
24//! ```ignore
25//! use skill_runtime::embeddings::{EmbeddingProvider, FastEmbedProvider, EmbeddingConfig};
26//!
27//! // Create a provider
28//! let provider = FastEmbedProvider::new(FastEmbedModel::AllMiniLM)?;
29//!
30//! // Embed a query
31//! let query_embedding = provider.embed_query("search for kubernetes tools").await?;
32//!
33//! // Embed multiple documents
34//! let texts = vec!["doc1".to_string(), "doc2".to_string()];
35//! let embeddings = provider.embed_documents(texts).await?;
36//! ```
37
38mod types;
39mod fastembed;
40mod openai;
41mod ollama;
42mod factory;
43
44pub use types::*;
45pub use fastembed::FastEmbedProvider;
46pub use openai::OpenAIEmbedProvider;
47pub use ollama::OllamaProvider;
48pub use factory::{EmbeddingProviderFactory, create_provider};
49
50use async_trait::async_trait;
51use anyhow::Result;
52
53/// Trait for embedding generation providers
54///
55/// Implementors generate vector embeddings from text, supporting both
56/// single queries and batch document processing.
57#[async_trait]
58pub trait EmbeddingProvider: Send + Sync {
59 /// Generate embeddings for multiple documents
60 ///
61 /// # Arguments
62 /// * `texts` - List of text documents to embed
63 ///
64 /// # Returns
65 /// Vector of embeddings, one per input document, in the same order
66 async fn embed_documents(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>>;
67
68 /// Generate embedding for a single query
69 ///
70 /// Some providers optimize query embeddings differently than document embeddings.
71 /// By default, this calls embed_documents with a single item.
72 ///
73 /// # Arguments
74 /// * `text` - The query text to embed
75 async fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
76 let results = self.embed_documents(vec![text.to_string()]).await?;
77 results.into_iter().next().ok_or_else(|| {
78 anyhow::anyhow!("embed_documents returned empty result for single query")
79 })
80 }
81
82 /// Get the embedding dimension size
83 fn dimensions(&self) -> usize;
84
85 /// Get the model name/identifier
86 fn model_name(&self) -> &str;
87
88 /// Get the provider name (e.g., "fastembed", "openai", "ollama")
89 fn provider_name(&self) -> &str;
90
91 /// Check if the provider is available (API key set, server running, etc.)
92 async fn health_check(&self) -> Result<bool> {
93 // Default: try to embed a simple query
94 match self.embed_query("test").await {
95 Ok(_) => Ok(true),
96 Err(_) => Ok(false),
97 }
98 }
99
100 /// Get the maximum batch size for embed_documents
101 fn max_batch_size(&self) -> usize {
102 100 // Default, can be overridden
103 }
104
105 /// Embed documents in batches, respecting max_batch_size
106 async fn embed_documents_batched(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
107 let batch_size = self.max_batch_size();
108 if texts.len() <= batch_size {
109 return self.embed_documents(texts).await;
110 }
111
112 let mut all_embeddings = Vec::with_capacity(texts.len());
113 for chunk in texts.chunks(batch_size) {
114 let embeddings = self.embed_documents(chunk.to_vec()).await?;
115 all_embeddings.extend(embeddings);
116 }
117 Ok(all_embeddings)
118 }
119}
120
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal in-memory provider used to exercise the trait's default methods.
    struct MockProvider {
        dims: usize,
    }

    #[async_trait]
    impl EmbeddingProvider for MockProvider {
        async fn embed_documents(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
            let embeddings = texts.into_iter().map(|_| vec![0.1; self.dims]).collect();
            Ok(embeddings)
        }

        fn dimensions(&self) -> usize {
            self.dims
        }

        fn model_name(&self) -> &str {
            "mock-model"
        }

        fn provider_name(&self) -> &str {
            "mock"
        }

        fn max_batch_size(&self) -> usize {
            2
        }
    }

    /// Default `embed_query` delegates to `embed_documents` and yields a
    /// vector of the provider's dimension.
    #[tokio::test]
    async fn test_embed_query_default() {
        let mock = MockProvider { dims: 384 };
        let result = mock.embed_query("test query").await.unwrap();
        assert_eq!(result.len(), 384);
    }

    /// Batched embedding preserves item count and per-item dimensionality
    /// even when the input exceeds `max_batch_size`.
    #[tokio::test]
    async fn test_embed_documents_batched() {
        let mock = MockProvider { dims: 3 };
        let docs: Vec<String> = (0..5).map(|i| format!("doc{}", i)).collect();

        let results = mock.embed_documents_batched(docs).await.unwrap();
        assert_eq!(results.len(), 5);
        assert!(results.iter().all(|e| e.len() == 3));
    }

    /// The default health check reports healthy when embedding succeeds.
    #[tokio::test]
    async fn test_health_check_default() {
        let mock = MockProvider { dims: 3 };
        assert!(mock.health_check().await.unwrap());
    }
}
181}