1use tokio_util::sync::CancellationToken;
24use tracing;
25
26use crate::AppError;
27use crate::circuit_breaker::CircuitBreaker;
28use crate::circuit_breaker::CircuitBreakerError;
29use crate::config::EmbeddingServiceConfig;
30use crate::models::NewDataset;
31use crate::progress::{HarvestEvent, ProgressReporter};
32use crate::traits::{DatasetStore, EmbeddingProvider};
33
/// Counters describing the outcome of one embedding pass.
///
/// `PartialEq`/`Eq` are derived so that results can be compared directly
/// (for example with `assert_eq!`).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct EmbeddingStats {
    /// Number of datasets whose embedding was generated and stored.
    pub embedded: usize,
    /// Number of datasets that could not be embedded or stored.
    pub failed: usize,
    /// Number of datasets skipped during the pass (for example while the
    /// circuit breaker was open).
    pub skipped: usize,
    /// Number of datasets that were pending when the pass started.
    pub total: usize,
}
46
47impl EmbeddingStats {
48 pub fn successful(&self) -> usize {
50 self.embedded
51 }
52}
53
54impl std::fmt::Display for EmbeddingStats {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 write!(
57 f,
58 "embedded: {}, failed: {}, skipped: {}, total: {}",
59 self.embedded, self.failed, self.skipped, self.total
60 )
61 }
62}
63
64pub struct EmbeddingService<S, E>
69where
70 S: DatasetStore,
71 E: EmbeddingProvider,
72{
73 store: S,
74 embedding: E,
75 config: EmbeddingServiceConfig,
76}
77
78impl<S, E> Clone for EmbeddingService<S, E>
79where
80 S: DatasetStore + Clone,
81 E: EmbeddingProvider + Clone,
82{
83 fn clone(&self) -> Self {
84 Self {
85 store: self.store.clone(),
86 embedding: self.embedding.clone(),
87 config: self.config.clone(),
88 }
89 }
90}
91
92impl<S, E> EmbeddingService<S, E>
93where
94 S: DatasetStore,
95 E: EmbeddingProvider,
96{
97 pub fn new(store: S, embedding: E) -> Self {
99 Self {
100 store,
101 embedding,
102 config: EmbeddingServiceConfig::default(),
103 }
104 }
105
106 pub fn with_config(store: S, embedding: E, config: EmbeddingServiceConfig) -> Self {
108 Self {
109 store,
110 embedding,
111 config,
112 }
113 }
114
115 pub fn embedding_provider(&self) -> &E {
117 &self.embedding
118 }
119
120 pub async fn embed_pending(
131 &self,
132 portal_filter: Option<&str>,
133 reporter: &impl ProgressReporter,
134 cancel_token: CancellationToken,
135 ) -> Result<EmbeddingStats, AppError> {
136 let total = self.store.count_pending_embeddings(portal_filter).await? as usize;
137
138 if total == 0 {
139 tracing::info!("No datasets pending embedding");
140 return Ok(EmbeddingStats::default());
141 }
142
143 tracing::info!(
144 total,
145 portal = portal_filter.unwrap_or("all"),
146 provider = self.embedding.name(),
147 "Starting embedding pass"
148 );
149
150 let mut stats = EmbeddingStats {
151 total,
152 ..Default::default()
153 };
154
155 let effective_batch_size =
156 std::cmp::min(self.config.batch_size, self.embedding.max_batch_size()).max(1);
157
158 let circuit_breaker =
159 CircuitBreaker::new(self.embedding.name(), self.config.circuit_breaker.clone());
160
161 let mut processed = 0usize;
162
163 let page_size = effective_batch_size * 10; loop {
168 if cancel_token.is_cancelled() {
169 tracing::info!("Embedding pass cancelled");
170 break;
171 }
172
173 let page = self
174 .store
175 .list_pending_embeddings(portal_filter, Some(page_size))
176 .await?;
177
178 if page.is_empty() {
179 break;
180 }
181
182 let embedded_before = stats.embedded;
183
184 for batch in page.chunks(effective_batch_size) {
185 if cancel_token.is_cancelled() {
186 tracing::info!("Embedding pass cancelled");
187 break;
188 }
189
190 self.process_batch(batch, &circuit_breaker, &mut stats)
191 .await;
192
193 processed += batch.len();
194
195 reporter.report(HarvestEvent::DatasetProcessed {
196 current: processed,
197 total,
198 created: 0,
199 updated: stats.embedded,
200 unchanged: 0,
201 failed: stats.failed,
202 skipped: stats.skipped,
203 });
204 }
205
206 if stats.embedded == embedded_before {
209 tracing::warn!(
210 "No progress this page — stopping to avoid infinite loop \
211 ({} failed, {} skipped)",
212 stats.failed,
213 stats.skipped
214 );
215 break;
216 }
217 }
218
219 tracing::info!(
220 embedded = stats.embedded,
221 failed = stats.failed,
222 skipped = stats.skipped,
223 total = stats.total,
224 "Embedding pass complete"
225 );
226
227 Ok(stats)
228 }
229
230 async fn process_batch(
233 &self,
234 datasets: &[crate::Dataset],
235 circuit_breaker: &CircuitBreaker,
236 stats: &mut EmbeddingStats,
237 ) {
238 let embeddable: Vec<(&crate::Dataset, String)> = datasets
240 .iter()
241 .filter_map(|d| {
242 let text = format!(
243 "{} {}",
244 d.title,
245 d.description.as_deref().unwrap_or_default()
246 );
247 if text.trim().is_empty() {
248 None
249 } else {
250 Some((d, text))
251 }
252 })
253 .collect();
254
255 let skipped_empty = datasets.len() - embeddable.len();
256 if skipped_empty > 0 {
257 tracing::debug!(skipped_empty, "Skipped datasets with empty text");
258 stats.failed += skipped_empty;
259 }
260
261 if embeddable.is_empty() {
262 return;
263 }
264
265 let texts: Vec<String> = embeddable.iter().map(|(_, t)| t.clone()).collect();
266 let batch_size = texts.len();
267
268 match circuit_breaker
269 .call(|| self.embedding.generate_batch(&texts))
270 .await
271 {
272 Ok(embeddings) => {
273 if embeddings.len() != batch_size {
274 tracing::warn!(
275 expected = batch_size,
276 got = embeddings.len(),
277 "Batch embedding count mismatch, failing batch"
278 );
279 stats.failed += batch_size;
280 return;
281 }
282
283 let upsert_datasets: Vec<NewDataset> = embeddable
286 .iter()
287 .zip(embeddings)
288 .map(|((d, _), emb)| {
289 let content_hash = match &d.content_hash {
290 Some(h) => h.clone(),
291 None => {
292 tracing::info!(
293 original_id = %d.original_id,
294 "Dataset missing content_hash, automatically generating one"
295 );
296 NewDataset::compute_content_hash(&d.title, d.description.as_deref())
297 }
298 };
299 NewDataset {
300 original_id: d.original_id.clone(),
301 source_portal: d.source_portal.clone(),
302 url: d.url.clone(),
303 title: d.title.clone(),
304 description: d.description.clone(),
305 embedding: Some(emb),
306 metadata: d.metadata.clone(),
307 content_hash,
308 }
309 })
310 .collect();
311
312 let skipped_no_hash = batch_size - upsert_datasets.len();
313 stats.failed += skipped_no_hash;
314 let upsert_count = upsert_datasets.len();
315
316 match self.store.batch_upsert(&upsert_datasets).await {
317 Ok(_) => {
318 stats.embedded += upsert_count;
319 }
320 Err(e) => {
321 tracing::warn!(
322 count = upsert_count,
323 error = %e,
324 "Failed to batch upsert datasets with embeddings"
325 );
326 stats.failed += upsert_count;
327 }
328 }
329 }
330 Err(CircuitBreakerError::Open { retry_after, .. }) => {
331 tracing::debug!(
332 batch_size,
333 retry_after_secs = retry_after.as_secs(),
334 "Skipping batch - circuit breaker open"
335 );
336 stats.skipped += batch_size;
337 }
338 Err(CircuitBreakerError::Inner(e)) => {
339 tracing::warn!(
340 batch_size,
341 error = %e,
342 "Batch embedding generation failed"
343 );
344 stats.failed += batch_size;
345 }
346 }
347 }
348}