skill_runtime/embeddings/
fastembed.rs1use super::{EmbeddingProvider, FastEmbedModel};
7use anyhow::{Context, Result};
8use async_trait::async_trait;
9use rig::embeddings::EmbeddingModel as RigEmbeddingModel;
10use rig_fastembed::{Client as FastembedClient, FastembedModel as RigFastembedModel};
11use std::sync::Arc;
12
13pub struct FastEmbedProvider {
18 client: Arc<FastembedClient>,
19 model: FastEmbedModel,
20 rig_model: RigFastembedModel,
21 dims: usize,
22}
23
24impl FastEmbedProvider {
25 pub fn new() -> Result<Self> {
27 Self::with_model(FastEmbedModel::default())
28 }
29
30 pub fn with_model(model: FastEmbedModel) -> Result<Self> {
32 let client = Arc::new(FastembedClient::new());
33 let rig_model = Self::to_rig_model(&model);
34 let dims = model.dimensions();
35
36 Ok(Self {
37 client,
38 model,
39 rig_model,
40 dims,
41 })
42 }
43
44 pub fn from_model_name(name: &str) -> Result<Self> {
46 let model: FastEmbedModel = name.parse()?;
47 Self::with_model(model)
48 }
49
50 fn to_rig_model(model: &FastEmbedModel) -> RigFastembedModel {
52 match model {
53 FastEmbedModel::AllMiniLM => RigFastembedModel::AllMiniLML6V2Q,
54 FastEmbedModel::BGESmallEN => RigFastembedModel::BGESmallENV15Q,
55 FastEmbedModel::BGEBaseEN => RigFastembedModel::BGEBaseENV15,
56 FastEmbedModel::BGELargeEN => RigFastembedModel::BGELargeENV15,
57 }
58 }
59
60}
61
62impl Default for FastEmbedProvider {
63 fn default() -> Self {
64 Self::new().expect("Failed to create default FastEmbed provider")
65 }
66}
67
68#[async_trait]
69impl EmbeddingProvider for FastEmbedProvider {
70 async fn embed_documents(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
71 if texts.is_empty() {
72 return Ok(Vec::new());
73 }
74
75 let embedding_model = self.client.embedding_model(&self.rig_model);
76
77 let embeddings = embedding_model
79 .embed_texts(texts)
80 .await
81 .context("FastEmbed failed to generate embeddings")?;
82
83 let results: Vec<Vec<f32>> = embeddings
85 .into_iter()
86 .map(|emb| emb.vec.into_iter().map(|x| x as f32).collect())
87 .collect();
88
89 Ok(results)
90 }
91
92 fn dimensions(&self) -> usize {
93 self.dims
94 }
95
96 fn model_name(&self) -> &str {
97 self.model.rig_model_name()
98 }
99
100 fn provider_name(&self) -> &str {
101 "fastembed"
102 }
103
104 fn max_batch_size(&self) -> usize {
105 256
107 }
108
109 async fn health_check(&self) -> Result<bool> {
110 match self.embed_query("health check").await {
112 Ok(emb) => Ok(emb.len() == self.dims),
113 Err(_) => Ok(false),
114 }
115 }
116}
117
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the mapping from each project-level model to its rig-fastembed
    /// variant. `matches!` is used so no `PartialEq` impl is required on the
    /// rig enum.
    #[test]
    fn test_model_conversion() {
        assert!(matches!(
            FastEmbedProvider::to_rig_model(&FastEmbedModel::AllMiniLM),
            RigFastembedModel::AllMiniLML6V2Q
        ));
        assert!(matches!(
            FastEmbedProvider::to_rig_model(&FastEmbedModel::BGESmallEN),
            RigFastembedModel::BGESmallENV15Q
        ));
        assert!(matches!(
            FastEmbedProvider::to_rig_model(&FastEmbedModel::BGEBaseEN),
            RigFastembedModel::BGEBaseENV15
        ));
        assert!(matches!(
            FastEmbedProvider::to_rig_model(&FastEmbedModel::BGELargeEN),
            RigFastembedModel::BGELargeENV15
        ));
    }

    /// The default provider reports the expected dimensions and names.
    #[test]
    fn test_provider_creation() {
        let provider = FastEmbedProvider::new().unwrap();
        assert_eq!(provider.dimensions(), 384);
        assert_eq!(provider.model_name(), "all-minilm");
        assert_eq!(provider.provider_name(), "fastembed");
    }

    /// Providers built from a model name report that model's dimensions.
    #[test]
    fn test_from_model_name() {
        let provider = FastEmbedProvider::from_model_name("bge-small").unwrap();
        assert_eq!(provider.dimensions(), 384);

        let provider = FastEmbedProvider::from_model_name("bge-base").unwrap();
        assert_eq!(provider.dimensions(), 768);
    }

    /// End-to-end embedding of two documents; ignored by default because it
    /// needs the model to be downloaded.
    #[tokio::test]
    #[ignore = "requires model download"]
    async fn test_embed_documents() {
        let provider = FastEmbedProvider::new().unwrap();
        let texts = vec![
            "Hello world".to_string(),
            "How are you".to_string(),
        ];

        let embeddings = provider.embed_documents(texts).await.unwrap();
        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings[0].len(), 384);
        assert_eq!(embeddings[1].len(), 384);
    }

    /// Empty input takes the early-return path and yields no embeddings.
    #[tokio::test]
    async fn test_embed_empty() {
        let provider = FastEmbedProvider::new().unwrap();
        let embeddings = provider.embed_documents(vec![]).await.unwrap();
        assert!(embeddings.is_empty());
    }
}