// claw_vector/embeddings/client.rs

use std::{sync::Arc, time::Duration};
3
4use async_trait::async_trait;
5use tokio::{sync::Mutex, time::sleep};
6use tonic::{
7 transport::{Channel, Endpoint},
8 Request,
9};
10use tracing::instrument;
11
12use crate::{
13 config::VectorConfig,
14 embeddings::{
15 cache::{EmbeddingCache, EmbeddingCacheStats},
16 types::ModelInfo,
17 },
18 error::{VectorError, VectorResult},
19 grpc::proto::{
20 embedding_service_client::EmbeddingServiceClient, EmbedRequest, HealthRequest,
21 ModelInfoRequest,
22 },
23};
24
25#[async_trait]
27pub trait EmbeddingProvider: Send + Sync {
28 async fn embed(&self, texts: Vec<String>) -> VectorResult<Vec<Vec<f32>>>;
30
31 async fn embed_one(&self, text: &str) -> VectorResult<Vec<f32>>;
33
34 async fn health_check(&self) -> VectorResult<bool>;
36
37 async fn model_info(&self) -> VectorResult<ModelInfo>;
39
40 async fn cache_stats(&self) -> Option<EmbeddingCacheStats> {
42 None
43 }
44}
45
46pub struct EmbeddingClient {
48 pub client: EmbeddingServiceClient<Channel>,
50 pub config: VectorConfig,
52 pub cache: Arc<Mutex<EmbeddingCache>>,
54}
55
56impl EmbeddingClient {
57 pub async fn new(config: &VectorConfig) -> VectorResult<Self> {
59 let timeout = Duration::from_millis(config.embedding_timeout_ms);
60 let endpoint = Endpoint::from_shared(config.embedding_service_url.clone())
61 .map_err(|err| VectorError::Embedding(format!("invalid embedding URL: {err}")))?
62 .connect_timeout(timeout)
63 .timeout(timeout);
64 let channel = endpoint.connect().await.map_err(|err| {
65 VectorError::Embedding(format!("failed to connect to embedding service: {err}"))
66 })?;
67
68 let client = EmbeddingClient {
69 client: EmbeddingServiceClient::new(channel),
70 config: config.clone(),
71 cache: Arc::new(Mutex::new(EmbeddingCache::new(config.cache_size))),
72 };
73
74 let mut delay = Duration::from_millis(100);
75 for attempt in 0..3 {
76 match client.health_check().await {
77 Ok(true) => return Ok(client),
78 Ok(false) if attempt < 2 => sleep(delay).await,
79 Err(_) if attempt < 2 => sleep(delay).await,
80 Ok(false) => {
81 return Err(VectorError::Embedding(
82 "embedding service is reachable but not ready".into(),
83 ))
84 }
85 Err(err) => return Err(err),
86 }
87 delay *= 2;
88 }
89
90 Err(VectorError::Embedding(
91 "embedding service readiness check failed".into(),
92 ))
93 }
94
95 pub async fn connect(config: &VectorConfig) -> VectorResult<Self> {
97 Self::new(config).await
98 }
99
100 pub async fn cache_stats_snapshot(&self) -> EmbeddingCacheStats {
102 self.cache.lock().await.stats()
103 }
104
105 pub async fn embed(&self, texts: Vec<String>) -> VectorResult<Vec<Vec<f32>>> {
107 <Self as EmbeddingProvider>::embed(self, texts).await
108 }
109
110 pub async fn embed_one(&self, text: &str) -> VectorResult<Vec<f32>> {
112 <Self as EmbeddingProvider>::embed_one(self, text).await
113 }
114
115 pub async fn health_check(&self) -> VectorResult<bool> {
117 <Self as EmbeddingProvider>::health_check(self).await
118 }
119
120 pub async fn model_info(&self) -> VectorResult<ModelInfo> {
122 <Self as EmbeddingProvider>::model_info(self).await
123 }
124}
125
126#[async_trait]
127impl EmbeddingProvider for EmbeddingClient {
128 #[instrument(skip(self, texts))]
130 async fn embed(&self, texts: Vec<String>) -> VectorResult<Vec<Vec<f32>>> {
131 if texts.is_empty() {
132 return Ok(Vec::new());
133 }
134
135 let mut outputs: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
136 let mut uncached = Vec::<(usize, String)>::new();
137
138 {
139 let mut cache = self.cache.lock().await;
140 for (index, text) in texts.iter().enumerate() {
141 if let Some(vector) = cache.get(text) {
142 outputs[index] = Some(vector);
143 } else {
144 uncached.push((index, text.clone()));
145 }
146 }
147 }
148
149 let mut fresh_embeddings = Vec::<(String, Vec<f32>)>::new();
150 let batch_size = self.config.batch_size.max(1);
151 let mut client = self.client.clone();
152 for chunk in uncached.chunks(batch_size) {
153 let batch_texts = chunk
154 .iter()
155 .map(|(_, text)| text.clone())
156 .collect::<Vec<_>>();
157 let mut request = Request::new(EmbedRequest {
158 texts: batch_texts.clone(),
159 model_name: String::new(),
160 normalize: true,
161 });
162 request.set_timeout(Duration::from_millis(self.config.embedding_timeout_ms));
163
164 let response = client
165 .embed(request)
166 .await
167 .map_err(VectorError::from)?
168 .into_inner();
169
170 if response.vectors.len() != batch_texts.len() {
171 return Err(VectorError::Embedding(format!(
172 "embedding service returned {} vectors for {} texts",
173 response.vectors.len(),
174 batch_texts.len()
175 )));
176 }
177
178 for ((index, text), vector) in chunk.iter().zip(response.vectors.into_iter()) {
179 outputs[*index] = Some(vector.values.clone());
180 fresh_embeddings.push((text.clone(), vector.values));
181 }
182 }
183
184 if !fresh_embeddings.is_empty() {
185 let mut cache = self.cache.lock().await;
186 for (text, vector) in &fresh_embeddings {
187 cache.insert(text, vector.clone());
188 }
189 }
190
191 outputs
192 .into_iter()
193 .map(|vector| {
194 vector.ok_or_else(|| {
195 VectorError::Embedding("embedding response did not contain all vectors".into())
196 })
197 })
198 .collect()
199 }
200
201 #[instrument(skip(self, text))]
203 async fn embed_one(&self, text: &str) -> VectorResult<Vec<f32>> {
204 self.embed(vec![text.to_string()])
205 .await?
206 .into_iter()
207 .next()
208 .ok_or_else(|| {
209 VectorError::Embedding(
210 "embedding response was empty for single-text request".into(),
211 )
212 })
213 }
214
215 #[instrument(skip(self))]
217 async fn health_check(&self) -> VectorResult<bool> {
218 let mut client = self.client.clone();
219 let mut request = Request::new(HealthRequest {});
220 request.set_timeout(Duration::from_millis(self.config.embedding_timeout_ms));
221 let response = client
222 .health(request)
223 .await
224 .map_err(VectorError::from)?
225 .into_inner();
226 Ok(response.ready)
227 }
228
229 #[instrument(skip(self))]
231 async fn model_info(&self) -> VectorResult<ModelInfo> {
232 let mut client = self.client.clone();
233 let mut request = Request::new(ModelInfoRequest {});
234 request.set_timeout(Duration::from_millis(self.config.embedding_timeout_ms));
235 let response = client
236 .model_info(request)
237 .await
238 .map_err(VectorError::from)?
239 .into_inner();
240 Ok(ModelInfo {
241 model_name: response.model_name,
242 dimensions: response.dimensions as usize,
243 max_sequence_length: response.max_sequence_length as usize,
244 device: response.device,
245 })
246 }
247
248 async fn cache_stats(&self) -> Option<EmbeddingCacheStats> {
250 Some(self.cache.lock().await.stats())
251 }
252}