ceres_core/embedding.rs
1//! Standalone embedding service for generating dataset embeddings.
2//!
3//! This service is decoupled from harvesting — it processes datasets that are
4//! already stored in the database with `embedding IS NULL`. This enables:
5//!
6//! - Harvesting metadata without an embedding API key
7//! - Switching embedding providers without re-harvesting
8//! - Backfilling embeddings after outages
9//! - Independent scaling of harvest and embedding workloads
10//!
11//! # Example
12//!
13//! ```ignore
14//! use ceres_core::embedding::EmbeddingService;
15//!
16//! let service = EmbeddingService::new(store, embedding_provider);
17//!
18//! // Embed all pending datasets
19//! let stats = service.embed_pending(None, &reporter, cancel_token).await?;
20//! println!("Embedded {} datasets", stats.embedded);
21//! ```
22
23use tokio_util::sync::CancellationToken;
24use tracing;
25
26use crate::AppError;
27use crate::circuit_breaker::CircuitBreaker;
28use crate::circuit_breaker::CircuitBreakerError;
29use crate::config::EmbeddingServiceConfig;
30use crate::models::NewDataset;
31use crate::progress::{HarvestEvent, ProgressReporter};
32use crate::traits::{DatasetStore, EmbeddingProvider};
33
/// Counters describing the outcome of a single embedding pass.
///
/// `embedded`, `failed` and `skipped` describe what happened to the datasets
/// that were attempted; `total` is how many datasets were pending when the
/// pass started.
#[derive(Clone, Debug, Default)]
pub struct EmbeddingStats {
    /// Datasets whose embedding was generated and written back successfully.
    pub embedded: usize,
    /// Datasets for which embedding did not succeed.
    pub failed: usize,
    /// Datasets left pending because the circuit breaker stayed open.
    pub skipped: usize,
    /// Datasets that needed an embedding when the pass started.
    pub total: usize,
}
46
47impl EmbeddingStats {
48 /// Returns the number of datasets successfully processed (embedded).
49 pub fn successful(&self) -> usize {
50 self.embedded
51 }
52}
53
/// Result of one attempt to embed a single batch.
///
/// Lets the caller tell a transient, retryable condition — the provider's
/// circuit breaker was open (e.g. an Ollama timeout or an external rate
/// limit) — apart from a batch that reached a final state (embedded, failed,
/// or had empty text) and must not be attempted again.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum BatchOutcome {
    /// Final: the batch has been dealt with; continue with the next one.
    Processed,
    /// Retryable: the breaker was open, so the batch is still pending and may
    /// be retried once `retry_after` has elapsed.
    CircuitOpen { retry_after: std::time::Duration },
}
69
70impl std::fmt::Display for EmbeddingStats {
71 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72 write!(
73 f,
74 "embedded: {}, failed: {}, skipped: {}, total: {}",
75 self.embedded, self.failed, self.skipped, self.total
76 )
77 }
78}
79
80/// Standalone service for generating embeddings for datasets already in the database.
81///
82/// Queries datasets with `embedding IS NULL`, generates embeddings in batches
83/// through a circuit breaker, and upserts them back to the database.
84pub struct EmbeddingService<S, E>
85where
86 S: DatasetStore,
87 E: EmbeddingProvider,
88{
89 store: S,
90 embedding: E,
91 config: EmbeddingServiceConfig,
92}
93
94impl<S, E> Clone for EmbeddingService<S, E>
95where
96 S: DatasetStore + Clone,
97 E: EmbeddingProvider + Clone,
98{
99 fn clone(&self) -> Self {
100 Self {
101 store: self.store.clone(),
102 embedding: self.embedding.clone(),
103 config: self.config.clone(),
104 }
105 }
106}
107
108impl<S, E> EmbeddingService<S, E>
109where
110 S: DatasetStore,
111 E: EmbeddingProvider,
112{
113 /// Creates a new embedding service with default configuration.
114 pub fn new(store: S, embedding: E) -> Self {
115 Self {
116 store,
117 embedding,
118 config: EmbeddingServiceConfig::default(),
119 }
120 }
121
122 /// Creates a new embedding service with custom configuration.
123 pub fn with_config(store: S, embedding: E, config: EmbeddingServiceConfig) -> Self {
124 Self {
125 store,
126 embedding,
127 config,
128 }
129 }
130
131 /// Returns a reference to the underlying embedding provider.
132 pub fn embedding_provider(&self) -> &E {
133 &self.embedding
134 }
135
136 /// Embeds all datasets with `embedding IS NULL`.
137 ///
138 /// Fetches pending datasets from the database, generates embeddings in
139 /// batches through the circuit breaker, and upserts them back.
140 ///
141 /// # Arguments
142 ///
143 /// * `portal_filter` - Optional portal URL to scope the embedding pass
144 /// * `reporter` - Progress reporter for UI/logging
145 /// * `cancel_token` - Token for graceful cancellation
146 pub async fn embed_pending(
147 &self,
148 portal_filter: Option<&str>,
149 reporter: &impl ProgressReporter,
150 cancel_token: CancellationToken,
151 ) -> Result<EmbeddingStats, AppError> {
152 let total = self.store.count_pending_embeddings(portal_filter).await? as usize;
153
154 if total == 0 {
155 tracing::info!("No datasets pending embedding");
156 return Ok(EmbeddingStats::default());
157 }
158
159 tracing::info!(
160 total,
161 portal = portal_filter.unwrap_or("all"),
162 provider = self.embedding.name(),
163 "Starting embedding pass"
164 );
165
166 let mut stats = EmbeddingStats {
167 total,
168 ..Default::default()
169 };
170
171 let effective_batch_size =
172 std::cmp::min(self.config.batch_size, self.embedding.max_batch_size()).max(1);
173
174 let circuit_breaker =
175 CircuitBreaker::new(self.embedding.name(), self.config.circuit_breaker.clone());
176
177 let mut processed = 0usize;
178
179 // Page through pending datasets to avoid loading everything into memory.
180 // Each iteration fetches up to `page_size` rows, processes them, then
181 // fetches the next page. This keeps memory bounded even with 350k+ pending.
182 let page_size = effective_batch_size * 10; // ~10 batches per page
183 loop {
184 if cancel_token.is_cancelled() {
185 tracing::info!("Embedding pass cancelled");
186 break;
187 }
188
189 let page = self
190 .store
191 .list_pending_embeddings(portal_filter, Some(page_size))
192 .await?;
193
194 if page.is_empty() {
195 break;
196 }
197
198 let embedded_before = stats.embedded;
199
200 'batches: for batch in page.chunks(effective_batch_size) {
201 if cancel_token.is_cancelled() {
202 tracing::info!("Embedding pass cancelled");
203 break;
204 }
205
206 // Retry a batch deferred by an open circuit: wait out the breaker's
207 // recovery window (cancellable) so the Open -> HalfOpen transition can
208 // re-test the provider, then retry the *same* batch. After
209 // MAX_CIRCUIT_RETRIES we give up on this batch (counted as skipped,
210 // still pending in DB) and move on — the daemon is never blocked.
211 const MAX_CIRCUIT_RETRIES: u32 = 5;
212 let mut attempt = 0;
213 loop {
214 match self
215 .process_batch(batch, &circuit_breaker, &mut stats)
216 .await
217 {
218 BatchOutcome::Processed => break,
219 BatchOutcome::CircuitOpen { retry_after } => {
220 attempt += 1;
221 if attempt > MAX_CIRCUIT_RETRIES {
222 tracing::warn!(
223 batch_size = batch.len(),
224 attempts = attempt - 1,
225 "Circuit still open after retries — leaving batch pending"
226 );
227 stats.skipped += batch.len();
228 break;
229 }
230 tracing::info!(
231 attempt,
232 wait_secs = retry_after.as_secs(),
233 "Circuit open — waiting for recovery before retry"
234 );
235 tokio::select! {
236 _ = tokio::time::sleep(retry_after) => {}
237 _ = cancel_token.cancelled() => {
238 // Cancelled mid-wait: the batch never reached a
239 // terminal outcome, so don't count or report it —
240 // it stays pending. Stop the whole pass.
241 tracing::info!("Embedding pass cancelled during circuit wait");
242 break 'batches;
243 }
244 }
245 }
246 }
247 }
248
249 processed += batch.len();
250
251 reporter.report(HarvestEvent::DatasetProcessed {
252 current: processed,
253 total,
254 created: 0,
255 updated: stats.embedded,
256 unchanged: 0,
257 failed: stats.failed,
258 skipped: stats.skipped,
259 });
260 }
261
262 // If no datasets were successfully embedded this page, stop to avoid
263 // an infinite loop re-fetching the same failing datasets.
264 if stats.embedded == embedded_before {
265 tracing::warn!(
266 "No progress this page — stopping to avoid infinite loop \
267 ({} failed, {} skipped)",
268 stats.failed,
269 stats.skipped
270 );
271 break;
272 }
273 }
274
275 tracing::info!(
276 embedded = stats.embedded,
277 failed = stats.failed,
278 skipped = stats.skipped,
279 total = stats.total,
280 "Embedding pass complete"
281 );
282
283 Ok(stats)
284 }
285
286 /// Processes a batch of datasets: generates embeddings via circuit breaker,
287 /// then upserts them back to the database.
288 ///
289 /// Returns [`BatchOutcome::CircuitOpen`] without mutating `stats` when the
290 /// breaker is open, so the caller can wait for recovery and retry the same
291 /// batch. All terminal outcomes update `stats` and return
292 /// [`BatchOutcome::Processed`].
293 async fn process_batch(
294 &self,
295 datasets: &[crate::Dataset],
296 circuit_breaker: &CircuitBreaker,
297 stats: &mut EmbeddingStats,
298 ) -> BatchOutcome {
299 // Compute text to embed for each dataset, filtering out empty text
300 let embeddable: Vec<(&crate::Dataset, String)> = datasets
301 .iter()
302 .filter_map(|d| {
303 let text = format!(
304 "{} {}",
305 d.title,
306 d.description.as_deref().unwrap_or_default()
307 );
308 if text.trim().is_empty() {
309 None
310 } else {
311 Some((d, text))
312 }
313 })
314 .collect();
315
316 let skipped_empty = datasets.len() - embeddable.len();
317 if skipped_empty > 0 {
318 tracing::debug!(skipped_empty, "Skipped datasets with empty text");
319 stats.failed += skipped_empty;
320 }
321
322 if embeddable.is_empty() {
323 return BatchOutcome::Processed;
324 }
325
326 let texts: Vec<String> = embeddable.iter().map(|(_, t)| t.clone()).collect();
327 let batch_size = texts.len();
328
329 match circuit_breaker
330 .call(|| self.embedding.generate_batch(&texts))
331 .await
332 {
333 Ok(embeddings) => {
334 if embeddings.len() != batch_size {
335 tracing::warn!(
336 expected = batch_size,
337 got = embeddings.len(),
338 "Batch embedding count mismatch, failing batch"
339 );
340 stats.failed += batch_size;
341 return BatchOutcome::Processed;
342 }
343
344 // Build NewDataset items with embeddings for upsert.
345 // Use existing content_hash from DB — it's always present for stored datasets.
346 let upsert_datasets: Vec<NewDataset> = embeddable
347 .iter()
348 .zip(embeddings)
349 .map(|((d, _), emb)| {
350 let content_hash = match &d.content_hash {
351 Some(h) => h.clone(),
352 None => {
353 tracing::info!(
354 original_id = %d.original_id,
355 "Dataset missing content_hash, automatically generating one"
356 );
357 NewDataset::compute_content_hash(&d.title, d.description.as_deref())
358 }
359 };
360 NewDataset {
361 original_id: d.original_id.clone(),
362 source_portal: d.source_portal.clone(),
363 url: d.url.clone(),
364 title: d.title.clone(),
365 description: d.description.clone(),
366 embedding: Some(emb),
367 metadata: d.metadata.clone(),
368 content_hash,
369 }
370 })
371 .collect();
372
373 let skipped_no_hash = batch_size - upsert_datasets.len();
374 stats.failed += skipped_no_hash;
375 let upsert_count = upsert_datasets.len();
376
377 match self.store.batch_upsert(&upsert_datasets).await {
378 Ok(_) => {
379 stats.embedded += upsert_count;
380 }
381 Err(e) => {
382 tracing::warn!(
383 count = upsert_count,
384 error = %e,
385 "Failed to batch upsert datasets with embeddings"
386 );
387 stats.failed += upsert_count;
388 }
389 }
390 BatchOutcome::Processed
391 }
392 Err(CircuitBreakerError::Open { retry_after, .. }) => {
393 // Recoverable: the provider is temporarily down. Don't count the
394 // batch yet — the caller waits for recovery and retries it.
395 tracing::debug!(
396 batch_size,
397 retry_after_secs = retry_after.as_secs(),
398 "Circuit breaker open - batch deferred for retry"
399 );
400 BatchOutcome::CircuitOpen { retry_after }
401 }
402 Err(CircuitBreakerError::Inner(e)) => {
403 tracing::warn!(
404 batch_size,
405 error = %e,
406 "Batch embedding generation failed"
407 );
408 stats.failed += batch_size;
409 BatchOutcome::Processed
410 }
411 }
412 }
413}