// ceres_core/embedding.rs

//! Standalone embedding service for generating dataset embeddings.
//!
//! This service is decoupled from harvesting — it processes datasets that are
//! already stored in the database with `embedding IS NULL`. This enables:
//!
//! - Harvesting metadata without an embedding API key
//! - Switching embedding providers without re-harvesting
//! - Backfilling embeddings after outages
//! - Independent scaling of harvest and embedding workloads
//!
//! # Example
//!
//! ```ignore
//! use ceres_core::embedding::EmbeddingService;
//!
//! let service = EmbeddingService::new(store, embedding_provider);
//!
//! // Embed all pending datasets
//! let stats = service.embed_pending(None, &reporter, cancel_token).await?;
//! println!("Embedded {} datasets", stats.embedded);
//! ```
22
23use tokio_util::sync::CancellationToken;
24use tracing;
25
26use crate::AppError;
27use crate::circuit_breaker::CircuitBreaker;
28use crate::circuit_breaker::CircuitBreakerError;
29use crate::config::EmbeddingServiceConfig;
30use crate::models::NewDataset;
31use crate::progress::{HarvestEvent, ProgressReporter};
32use crate::traits::{DatasetStore, EmbeddingProvider};
33
/// Statistics from an embedding run.
///
/// The type derives `PartialEq`/`Eq` so two results can be compared directly,
/// for example with `assert_eq!`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct EmbeddingStats {
    /// Number of datasets successfully embedded.
    pub embedded: usize,
    /// Number of datasets that failed embedding.
    pub failed: usize,
    /// Number of datasets skipped (circuit breaker open).
    pub skipped: usize,
    /// Total number of datasets that needed embedding.
    pub total: usize,
}

impl EmbeddingStats {
    /// Returns the number of datasets successfully processed (embedded).
    ///
    /// This is the same value as the `embedded` field.
    pub fn successful(&self) -> usize {
        self.embedded
    }
}
53
/// Outcome of attempting to embed a single batch.
///
/// Distinguishes a *recoverable* circuit-open skip (the provider is temporarily
/// unavailable — Ollama timeout, external rate limit — and the batch should be
/// retried after the breaker's recovery window) from a *terminal* outcome where
/// the batch was processed (whether it embedded, failed, or had empty text) and
/// should not be retried.
///
/// Marked `#[must_use]`: discarding a `CircuitOpen` value would silently drop
/// the signal that the batch still needs a retry.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[must_use = "a `CircuitOpen` outcome means the batch is still pending and needs a retry"]
enum BatchOutcome {
    /// The batch reached a terminal state (embedded, failed, or empty); move on.
    Processed,
    /// The circuit breaker was open; the batch is still pending and recoverable
    /// after `retry_after`.
    CircuitOpen { retry_after: std::time::Duration },
}
69
70impl std::fmt::Display for EmbeddingStats {
71    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72        write!(
73            f,
74            "embedded: {}, failed: {}, skipped: {}, total: {}",
75            self.embedded, self.failed, self.skipped, self.total
76        )
77    }
78}
79
80/// Standalone service for generating embeddings for datasets already in the database.
81///
82/// Queries datasets with `embedding IS NULL`, generates embeddings in batches
83/// through a circuit breaker, and upserts them back to the database.
84pub struct EmbeddingService<S, E>
85where
86    S: DatasetStore,
87    E: EmbeddingProvider,
88{
89    store: S,
90    embedding: E,
91    config: EmbeddingServiceConfig,
92}
93
94impl<S, E> Clone for EmbeddingService<S, E>
95where
96    S: DatasetStore + Clone,
97    E: EmbeddingProvider + Clone,
98{
99    fn clone(&self) -> Self {
100        Self {
101            store: self.store.clone(),
102            embedding: self.embedding.clone(),
103            config: self.config.clone(),
104        }
105    }
106}
107
108impl<S, E> EmbeddingService<S, E>
109where
110    S: DatasetStore,
111    E: EmbeddingProvider,
112{
113    /// Creates a new embedding service with default configuration.
114    pub fn new(store: S, embedding: E) -> Self {
115        Self {
116            store,
117            embedding,
118            config: EmbeddingServiceConfig::default(),
119        }
120    }
121
122    /// Creates a new embedding service with custom configuration.
123    pub fn with_config(store: S, embedding: E, config: EmbeddingServiceConfig) -> Self {
124        Self {
125            store,
126            embedding,
127            config,
128        }
129    }
130
131    /// Returns a reference to the underlying embedding provider.
132    pub fn embedding_provider(&self) -> &E {
133        &self.embedding
134    }
135
136    /// Embeds all datasets with `embedding IS NULL`.
137    ///
138    /// Fetches pending datasets from the database, generates embeddings in
139    /// batches through the circuit breaker, and upserts them back.
140    ///
141    /// # Arguments
142    ///
143    /// * `portal_filter` - Optional portal URL to scope the embedding pass
144    /// * `reporter` - Progress reporter for UI/logging
145    /// * `cancel_token` - Token for graceful cancellation
146    pub async fn embed_pending(
147        &self,
148        portal_filter: Option<&str>,
149        reporter: &impl ProgressReporter,
150        cancel_token: CancellationToken,
151    ) -> Result<EmbeddingStats, AppError> {
152        let total = self.store.count_pending_embeddings(portal_filter).await? as usize;
153
154        if total == 0 {
155            tracing::info!("No datasets pending embedding");
156            return Ok(EmbeddingStats::default());
157        }
158
159        tracing::info!(
160            total,
161            portal = portal_filter.unwrap_or("all"),
162            provider = self.embedding.name(),
163            "Starting embedding pass"
164        );
165
166        let mut stats = EmbeddingStats {
167            total,
168            ..Default::default()
169        };
170
171        let effective_batch_size =
172            std::cmp::min(self.config.batch_size, self.embedding.max_batch_size()).max(1);
173
174        let circuit_breaker =
175            CircuitBreaker::new(self.embedding.name(), self.config.circuit_breaker.clone());
176
177        let mut processed = 0usize;
178
179        // Page through pending datasets to avoid loading everything into memory.
180        // Each iteration fetches up to `page_size` rows, processes them, then
181        // fetches the next page. This keeps memory bounded even with 350k+ pending.
182        let page_size = effective_batch_size * 10; // ~10 batches per page
183        loop {
184            if cancel_token.is_cancelled() {
185                tracing::info!("Embedding pass cancelled");
186                break;
187            }
188
189            let page = self
190                .store
191                .list_pending_embeddings(portal_filter, Some(page_size))
192                .await?;
193
194            if page.is_empty() {
195                break;
196            }
197
198            let embedded_before = stats.embedded;
199
200            'batches: for batch in page.chunks(effective_batch_size) {
201                if cancel_token.is_cancelled() {
202                    tracing::info!("Embedding pass cancelled");
203                    break;
204                }
205
206                // Retry a batch deferred by an open circuit: wait out the breaker's
207                // recovery window (cancellable) so the Open -> HalfOpen transition can
208                // re-test the provider, then retry the *same* batch. After
209                // MAX_CIRCUIT_RETRIES we give up on this batch (counted as skipped,
210                // still pending in DB) and move on — the daemon is never blocked.
211                const MAX_CIRCUIT_RETRIES: u32 = 5;
212                let mut attempt = 0;
213                loop {
214                    match self
215                        .process_batch(batch, &circuit_breaker, &mut stats)
216                        .await
217                    {
218                        BatchOutcome::Processed => break,
219                        BatchOutcome::CircuitOpen { retry_after } => {
220                            attempt += 1;
221                            if attempt > MAX_CIRCUIT_RETRIES {
222                                tracing::warn!(
223                                    batch_size = batch.len(),
224                                    attempts = attempt - 1,
225                                    "Circuit still open after retries — leaving batch pending"
226                                );
227                                stats.skipped += batch.len();
228                                break;
229                            }
230                            tracing::info!(
231                                attempt,
232                                wait_secs = retry_after.as_secs(),
233                                "Circuit open — waiting for recovery before retry"
234                            );
235                            tokio::select! {
236                                _ = tokio::time::sleep(retry_after) => {}
237                                _ = cancel_token.cancelled() => {
238                                    // Cancelled mid-wait: the batch never reached a
239                                    // terminal outcome, so don't count or report it —
240                                    // it stays pending. Stop the whole pass.
241                                    tracing::info!("Embedding pass cancelled during circuit wait");
242                                    break 'batches;
243                                }
244                            }
245                        }
246                    }
247                }
248
249                processed += batch.len();
250
251                reporter.report(HarvestEvent::DatasetProcessed {
252                    current: processed,
253                    total,
254                    created: 0,
255                    updated: stats.embedded,
256                    unchanged: 0,
257                    failed: stats.failed,
258                    skipped: stats.skipped,
259                });
260            }
261
262            // If no datasets were successfully embedded this page, stop to avoid
263            // an infinite loop re-fetching the same failing datasets.
264            if stats.embedded == embedded_before {
265                tracing::warn!(
266                    "No progress this page — stopping to avoid infinite loop \
267                     ({} failed, {} skipped)",
268                    stats.failed,
269                    stats.skipped
270                );
271                break;
272            }
273        }
274
275        tracing::info!(
276            embedded = stats.embedded,
277            failed = stats.failed,
278            skipped = stats.skipped,
279            total = stats.total,
280            "Embedding pass complete"
281        );
282
283        Ok(stats)
284    }
285
286    /// Processes a batch of datasets: generates embeddings via circuit breaker,
287    /// then upserts them back to the database.
288    ///
289    /// Returns [`BatchOutcome::CircuitOpen`] without mutating `stats` when the
290    /// breaker is open, so the caller can wait for recovery and retry the same
291    /// batch. All terminal outcomes update `stats` and return
292    /// [`BatchOutcome::Processed`].
293    async fn process_batch(
294        &self,
295        datasets: &[crate::Dataset],
296        circuit_breaker: &CircuitBreaker,
297        stats: &mut EmbeddingStats,
298    ) -> BatchOutcome {
299        // Compute text to embed for each dataset, filtering out empty text
300        let embeddable: Vec<(&crate::Dataset, String)> = datasets
301            .iter()
302            .filter_map(|d| {
303                let text = format!(
304                    "{} {}",
305                    d.title,
306                    d.description.as_deref().unwrap_or_default()
307                );
308                if text.trim().is_empty() {
309                    None
310                } else {
311                    Some((d, text))
312                }
313            })
314            .collect();
315
316        let skipped_empty = datasets.len() - embeddable.len();
317        if skipped_empty > 0 {
318            tracing::debug!(skipped_empty, "Skipped datasets with empty text");
319            stats.failed += skipped_empty;
320        }
321
322        if embeddable.is_empty() {
323            return BatchOutcome::Processed;
324        }
325
326        let texts: Vec<String> = embeddable.iter().map(|(_, t)| t.clone()).collect();
327        let batch_size = texts.len();
328
329        match circuit_breaker
330            .call(|| self.embedding.generate_batch(&texts))
331            .await
332        {
333            Ok(embeddings) => {
334                if embeddings.len() != batch_size {
335                    tracing::warn!(
336                        expected = batch_size,
337                        got = embeddings.len(),
338                        "Batch embedding count mismatch, failing batch"
339                    );
340                    stats.failed += batch_size;
341                    return BatchOutcome::Processed;
342                }
343
344                // Build NewDataset items with embeddings for upsert.
345                // Use existing content_hash from DB — it's always present for stored datasets.
346                let upsert_datasets: Vec<NewDataset> = embeddable
347                    .iter()
348                    .zip(embeddings)
349                    .map(|((d, _), emb)| {
350                        let content_hash = match &d.content_hash {
351                            Some(h) => h.clone(),
352                            None => {
353                                tracing::info!(
354                                    original_id = %d.original_id,
355                                    "Dataset missing content_hash, automatically generating one"
356                                );
357                                NewDataset::compute_content_hash(&d.title, d.description.as_deref())
358                            }
359                        };
360                        NewDataset {
361                            original_id: d.original_id.clone(),
362                            source_portal: d.source_portal.clone(),
363                            url: d.url.clone(),
364                            title: d.title.clone(),
365                            description: d.description.clone(),
366                            embedding: Some(emb),
367                            metadata: d.metadata.clone(),
368                            content_hash,
369                        }
370                    })
371                    .collect();
372
373                let skipped_no_hash = batch_size - upsert_datasets.len();
374                stats.failed += skipped_no_hash;
375                let upsert_count = upsert_datasets.len();
376
377                match self.store.batch_upsert(&upsert_datasets).await {
378                    Ok(_) => {
379                        stats.embedded += upsert_count;
380                    }
381                    Err(e) => {
382                        tracing::warn!(
383                            count = upsert_count,
384                            error = %e,
385                            "Failed to batch upsert datasets with embeddings"
386                        );
387                        stats.failed += upsert_count;
388                    }
389                }
390                BatchOutcome::Processed
391            }
392            Err(CircuitBreakerError::Open { retry_after, .. }) => {
393                // Recoverable: the provider is temporarily down. Don't count the
394                // batch yet — the caller waits for recovery and retries it.
395                tracing::debug!(
396                    batch_size,
397                    retry_after_secs = retry_after.as_secs(),
398                    "Circuit breaker open - batch deferred for retry"
399                );
400                BatchOutcome::CircuitOpen { retry_after }
401            }
402            Err(CircuitBreakerError::Inner(e)) => {
403                tracing::warn!(
404                    batch_size,
405                    error = %e,
406                    "Batch embedding generation failed"
407                );
408                stats.failed += batch_size;
409                BatchOutcome::Processed
410            }
411        }
412    }
413}