Skip to main content

claw_vector/embeddings/
client.rs

1// embeddings/client.rs — gRPC client to the Python embedding microservice.
2use std::{sync::Arc, time::Duration};
3
4use async_trait::async_trait;
5use tokio::{sync::Mutex, time::sleep};
6use tonic::{
7    transport::{Channel, Endpoint},
8    Request,
9};
10use tracing::instrument;
11
12use crate::{
13    config::VectorConfig,
14    embeddings::{
15        cache::{EmbeddingCache, EmbeddingCacheStats},
16        types::ModelInfo,
17    },
18    error::{VectorError, VectorResult},
19    grpc::proto::{
20        embedding_service_client::EmbeddingServiceClient, EmbedRequest, HealthRequest,
21        ModelInfoRequest,
22    },
23};
24
/// Abstraction over text embedding providers.
///
/// Implementors must be `Send + Sync` so a provider can be shared across
/// async tasks (e.g. behind an `Arc`).
#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
    /// Embed a list of texts.
    async fn embed(&self, texts: Vec<String>) -> VectorResult<Vec<Vec<f32>>>;

    /// Embed a single text.
    async fn embed_one(&self, text: &str) -> VectorResult<Vec<f32>>;

    /// Check service readiness.
    async fn health_check(&self) -> VectorResult<bool>;

    /// Return metadata about the active model.
    async fn model_info(&self) -> VectorResult<ModelInfo>;

    /// Return cache statistics when available.
    ///
    /// Defaults to `None`; providers that maintain a local cache
    /// (such as [`EmbeddingClient`]) override this.
    async fn cache_stats(&self) -> Option<EmbeddingCacheStats> {
        None
    }
}
45
/// gRPC client that calls the Python embedding service and caches results locally.
pub struct EmbeddingClient {
    /// Underlying tonic client; cloned per call site (tonic channels are
    /// designed to be cloned cheaply for concurrent use).
    pub client: EmbeddingServiceClient<Channel>,
    /// Runtime configuration (service URL, timeouts, batch size, cache size).
    pub config: VectorConfig,
    /// Shared embedding cache, keyed by the input text.
    pub cache: Arc<Mutex<EmbeddingCache>>,
}
55
56impl EmbeddingClient {
57    /// Connect to the embedding service, initialize the cache, and verify readiness.
58    pub async fn new(config: &VectorConfig) -> VectorResult<Self> {
59        let timeout = Duration::from_millis(config.embedding_timeout_ms);
60        let endpoint = Endpoint::from_shared(config.embedding_service_url.clone())
61            .map_err(|err| VectorError::Embedding(format!("invalid embedding URL: {err}")))?
62            .connect_timeout(timeout)
63            .timeout(timeout);
64        let channel = endpoint.connect().await.map_err(|err| {
65            VectorError::Embedding(format!("failed to connect to embedding service: {err}"))
66        })?;
67
68        let client = EmbeddingClient {
69            client: EmbeddingServiceClient::new(channel),
70            config: config.clone(),
71            cache: Arc::new(Mutex::new(EmbeddingCache::new(config.cache_size))),
72        };
73
74        let mut delay = Duration::from_millis(100);
75        for attempt in 0..3 {
76            match client.health_check().await {
77                Ok(true) => return Ok(client),
78                Ok(false) if attempt < 2 => sleep(delay).await,
79                Err(_) if attempt < 2 => sleep(delay).await,
80                Ok(false) => {
81                    return Err(VectorError::Embedding(
82                        "embedding service is reachable but not ready".into(),
83                    ))
84                }
85                Err(err) => return Err(err),
86            }
87            delay *= 2;
88        }
89
90        Err(VectorError::Embedding(
91            "embedding service readiness check failed".into(),
92        ))
93    }
94
95    /// Alias for [`EmbeddingClient::new`].
96    pub async fn connect(config: &VectorConfig) -> VectorResult<Self> {
97        Self::new(config).await
98    }
99
100    /// Return the current cache statistics snapshot.
101    pub async fn cache_stats_snapshot(&self) -> EmbeddingCacheStats {
102        self.cache.lock().await.stats()
103    }
104
105    /// Embed a list of texts.
106    pub async fn embed(&self, texts: Vec<String>) -> VectorResult<Vec<Vec<f32>>> {
107        <Self as EmbeddingProvider>::embed(self, texts).await
108    }
109
110    /// Embed a single text.
111    pub async fn embed_one(&self, text: &str) -> VectorResult<Vec<f32>> {
112        <Self as EmbeddingProvider>::embed_one(self, text).await
113    }
114
115    /// Health-check the embedding service.
116    pub async fn health_check(&self) -> VectorResult<bool> {
117        <Self as EmbeddingProvider>::health_check(self).await
118    }
119
120    /// Return model metadata from the embedding service.
121    pub async fn model_info(&self) -> VectorResult<ModelInfo> {
122        <Self as EmbeddingProvider>::model_info(self).await
123    }
124}
125
126#[async_trait]
127impl EmbeddingProvider for EmbeddingClient {
128    /// Embed a list of texts, serving cached results where available.
129    #[instrument(skip(self, texts))]
130    async fn embed(&self, texts: Vec<String>) -> VectorResult<Vec<Vec<f32>>> {
131        if texts.is_empty() {
132            return Ok(Vec::new());
133        }
134
135        let mut outputs: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
136        let mut uncached = Vec::<(usize, String)>::new();
137
138        {
139            let mut cache = self.cache.lock().await;
140            for (index, text) in texts.iter().enumerate() {
141                if let Some(vector) = cache.get(text) {
142                    outputs[index] = Some(vector);
143                } else {
144                    uncached.push((index, text.clone()));
145                }
146            }
147        }
148
149        let mut fresh_embeddings = Vec::<(String, Vec<f32>)>::new();
150        let batch_size = self.config.batch_size.max(1);
151        let mut client = self.client.clone();
152        for chunk in uncached.chunks(batch_size) {
153            let batch_texts = chunk
154                .iter()
155                .map(|(_, text)| text.clone())
156                .collect::<Vec<_>>();
157            let mut request = Request::new(EmbedRequest {
158                texts: batch_texts.clone(),
159                model_name: String::new(),
160                normalize: true,
161            });
162            request.set_timeout(Duration::from_millis(self.config.embedding_timeout_ms));
163
164            let response = client
165                .embed(request)
166                .await
167                .map_err(VectorError::from)?
168                .into_inner();
169
170            if response.vectors.len() != batch_texts.len() {
171                return Err(VectorError::Embedding(format!(
172                    "embedding service returned {} vectors for {} texts",
173                    response.vectors.len(),
174                    batch_texts.len()
175                )));
176            }
177
178            for ((index, text), vector) in chunk.iter().zip(response.vectors.into_iter()) {
179                outputs[*index] = Some(vector.values.clone());
180                fresh_embeddings.push((text.clone(), vector.values));
181            }
182        }
183
184        if !fresh_embeddings.is_empty() {
185            let mut cache = self.cache.lock().await;
186            for (text, vector) in &fresh_embeddings {
187                cache.insert(text, vector.clone());
188            }
189        }
190
191        outputs
192            .into_iter()
193            .map(|vector| {
194                vector.ok_or_else(|| {
195                    VectorError::Embedding("embedding response did not contain all vectors".into())
196                })
197            })
198            .collect()
199    }
200
201    /// Embed a single text.
202    #[instrument(skip(self, text))]
203    async fn embed_one(&self, text: &str) -> VectorResult<Vec<f32>> {
204        self.embed(vec![text.to_string()])
205            .await?
206            .into_iter()
207            .next()
208            .ok_or_else(|| {
209                VectorError::Embedding(
210                    "embedding response was empty for single-text request".into(),
211                )
212            })
213    }
214
215    /// Health-check the embedding service.
216    #[instrument(skip(self))]
217    async fn health_check(&self) -> VectorResult<bool> {
218        let mut client = self.client.clone();
219        let mut request = Request::new(HealthRequest {});
220        request.set_timeout(Duration::from_millis(self.config.embedding_timeout_ms));
221        let response = client
222            .health(request)
223            .await
224            .map_err(VectorError::from)?
225            .into_inner();
226        Ok(response.ready)
227    }
228
229    /// Return metadata about the currently loaded model.
230    #[instrument(skip(self))]
231    async fn model_info(&self) -> VectorResult<ModelInfo> {
232        let mut client = self.client.clone();
233        let mut request = Request::new(ModelInfoRequest {});
234        request.set_timeout(Duration::from_millis(self.config.embedding_timeout_ms));
235        let response = client
236            .model_info(request)
237            .await
238            .map_err(VectorError::from)?
239            .into_inner();
240        Ok(ModelInfo {
241            model_name: response.model_name,
242            dimensions: response.dimensions as usize,
243            max_sequence_length: response.max_sequence_length as usize,
244            device: response.device,
245        })
246    }
247
248    /// Return cache statistics for the client-local embedding cache.
249    async fn cache_stats(&self) -> Option<EmbeddingCacheStats> {
250        Some(self.cache.lock().await.stats())
251    }
252}