// ceres_core/src/embedding.rs
1//! Standalone embedding service for generating dataset embeddings.
2//!
3//! This service is decoupled from harvesting — it processes datasets that are
4//! already stored in the database with `embedding IS NULL`. This enables:
5//!
6//! - Harvesting metadata without an embedding API key
7//! - Switching embedding providers without re-harvesting
8//! - Backfilling embeddings after outages
9//! - Independent scaling of harvest and embedding workloads
10//!
11//! # Example
12//!
13//! ```ignore
14//! use ceres_core::embedding::EmbeddingService;
15//!
16//! let service = EmbeddingService::new(store, embedding_provider);
17//!
18//! // Embed all pending datasets
19//! let stats = service.embed_pending(None, &reporter, cancel_token).await?;
20//! println!("Embedded {} datasets", stats.embedded);
21//! ```
22
23use tokio_util::sync::CancellationToken;
24use tracing;
25
26use crate::AppError;
27use crate::circuit_breaker::CircuitBreaker;
28use crate::circuit_breaker::CircuitBreakerError;
29use crate::config::EmbeddingServiceConfig;
30use crate::models::NewDataset;
31use crate::progress::{HarvestEvent, ProgressReporter};
32use crate::traits::{DatasetStore, EmbeddingProvider};
33
/// Statistics from an embedding run.
///
/// All counters are plain `usize`s, so the type is `Copy` and supports
/// equality comparison (handy for tests and idempotence checks).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct EmbeddingStats {
    /// Number of datasets successfully embedded.
    pub embedded: usize,
    /// Number of datasets that failed embedding.
    pub failed: usize,
    /// Number of datasets skipped (circuit breaker open).
    pub skipped: usize,
    /// Total number of datasets that needed embedding.
    pub total: usize,
}

impl EmbeddingStats {
    /// Returns the number of datasets successfully processed (embedded).
    pub fn successful(&self) -> usize {
        self.embedded
    }
}

impl std::fmt::Display for EmbeddingStats {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "embedded: {}, failed: {}, skipped: {}, total: {}",
            self.embedded, self.failed, self.skipped, self.total
        )
    }
}
63
/// Standalone service for generating embeddings for datasets already in the database.
///
/// Queries datasets with `embedding IS NULL`, generates embeddings in batches
/// through a circuit breaker, and upserts them back to the database.
/// Decoupled from harvesting: it only sees rows already stored by the store.
pub struct EmbeddingService<S, E>
where
    S: DatasetStore,
    E: EmbeddingProvider,
{
    // Persistence backend: lists/counts pending datasets and upserts results.
    store: S,
    // Provider that turns dataset text into embedding vectors.
    embedding: E,
    // Tuning knobs read by `embed_pending`: batch size and circuit-breaker settings.
    config: EmbeddingServiceConfig,
}
77
// NOTE(review): this manual impl matches what `#[derive(Clone)]` would
// generate (derive would add the same `S: Clone, E: Clone` bounds on top of
// the struct's trait bounds) — presumably written out by hand for clarity.
impl<S, E> Clone for EmbeddingService<S, E>
where
    S: DatasetStore + Clone,
    E: EmbeddingProvider + Clone,
{
    fn clone(&self) -> Self {
        Self {
            store: self.store.clone(),
            embedding: self.embedding.clone(),
            config: self.config.clone(),
        }
    }
}
91
92impl<S, E> EmbeddingService<S, E>
93where
94    S: DatasetStore,
95    E: EmbeddingProvider,
96{
97    /// Creates a new embedding service with default configuration.
98    pub fn new(store: S, embedding: E) -> Self {
99        Self {
100            store,
101            embedding,
102            config: EmbeddingServiceConfig::default(),
103        }
104    }
105
106    /// Creates a new embedding service with custom configuration.
107    pub fn with_config(store: S, embedding: E, config: EmbeddingServiceConfig) -> Self {
108        Self {
109            store,
110            embedding,
111            config,
112        }
113    }
114
115    /// Returns a reference to the underlying embedding provider.
116    pub fn embedding_provider(&self) -> &E {
117        &self.embedding
118    }
119
120    /// Embeds all datasets with `embedding IS NULL`.
121    ///
122    /// Fetches pending datasets from the database, generates embeddings in
123    /// batches through the circuit breaker, and upserts them back.
124    ///
125    /// # Arguments
126    ///
127    /// * `portal_filter` - Optional portal URL to scope the embedding pass
128    /// * `reporter` - Progress reporter for UI/logging
129    /// * `cancel_token` - Token for graceful cancellation
130    pub async fn embed_pending(
131        &self,
132        portal_filter: Option<&str>,
133        reporter: &impl ProgressReporter,
134        cancel_token: CancellationToken,
135    ) -> Result<EmbeddingStats, AppError> {
136        let total = self.store.count_pending_embeddings(portal_filter).await? as usize;
137
138        if total == 0 {
139            tracing::info!("No datasets pending embedding");
140            return Ok(EmbeddingStats::default());
141        }
142
143        tracing::info!(
144            total,
145            portal = portal_filter.unwrap_or("all"),
146            provider = self.embedding.name(),
147            "Starting embedding pass"
148        );
149
150        let mut stats = EmbeddingStats {
151            total,
152            ..Default::default()
153        };
154
155        let effective_batch_size =
156            std::cmp::min(self.config.batch_size, self.embedding.max_batch_size()).max(1);
157
158        let circuit_breaker =
159            CircuitBreaker::new(self.embedding.name(), self.config.circuit_breaker.clone());
160
161        let mut processed = 0usize;
162
163        // Page through pending datasets to avoid loading everything into memory.
164        // Each iteration fetches up to `page_size` rows, processes them, then
165        // fetches the next page. This keeps memory bounded even with 350k+ pending.
166        let page_size = effective_batch_size * 10; // ~10 batches per page
167        loop {
168            if cancel_token.is_cancelled() {
169                tracing::info!("Embedding pass cancelled");
170                break;
171            }
172
173            let page = self
174                .store
175                .list_pending_embeddings(portal_filter, Some(page_size))
176                .await?;
177
178            if page.is_empty() {
179                break;
180            }
181
182            let embedded_before = stats.embedded;
183
184            for batch in page.chunks(effective_batch_size) {
185                if cancel_token.is_cancelled() {
186                    tracing::info!("Embedding pass cancelled");
187                    break;
188                }
189
190                self.process_batch(batch, &circuit_breaker, &mut stats)
191                    .await;
192
193                processed += batch.len();
194
195                reporter.report(HarvestEvent::DatasetProcessed {
196                    current: processed,
197                    total,
198                    created: 0,
199                    updated: stats.embedded,
200                    unchanged: 0,
201                    failed: stats.failed,
202                    skipped: stats.skipped,
203                });
204            }
205
206            // If no datasets were successfully embedded this page, stop to avoid
207            // an infinite loop re-fetching the same failing datasets.
208            if stats.embedded == embedded_before {
209                tracing::warn!(
210                    "No progress this page — stopping to avoid infinite loop \
211                     ({} failed, {} skipped)",
212                    stats.failed,
213                    stats.skipped
214                );
215                break;
216            }
217        }
218
219        tracing::info!(
220            embedded = stats.embedded,
221            failed = stats.failed,
222            skipped = stats.skipped,
223            total = stats.total,
224            "Embedding pass complete"
225        );
226
227        Ok(stats)
228    }
229
230    /// Processes a batch of datasets: generates embeddings via circuit breaker,
231    /// then upserts them back to the database.
232    async fn process_batch(
233        &self,
234        datasets: &[crate::Dataset],
235        circuit_breaker: &CircuitBreaker,
236        stats: &mut EmbeddingStats,
237    ) {
238        // Compute text to embed for each dataset, filtering out empty text
239        let embeddable: Vec<(&crate::Dataset, String)> = datasets
240            .iter()
241            .filter_map(|d| {
242                let text = format!(
243                    "{} {}",
244                    d.title,
245                    d.description.as_deref().unwrap_or_default()
246                );
247                if text.trim().is_empty() {
248                    None
249                } else {
250                    Some((d, text))
251                }
252            })
253            .collect();
254
255        let skipped_empty = datasets.len() - embeddable.len();
256        if skipped_empty > 0 {
257            tracing::debug!(skipped_empty, "Skipped datasets with empty text");
258            stats.failed += skipped_empty;
259        }
260
261        if embeddable.is_empty() {
262            return;
263        }
264
265        let texts: Vec<String> = embeddable.iter().map(|(_, t)| t.clone()).collect();
266        let batch_size = texts.len();
267
268        match circuit_breaker
269            .call(|| self.embedding.generate_batch(&texts))
270            .await
271        {
272            Ok(embeddings) => {
273                if embeddings.len() != batch_size {
274                    tracing::warn!(
275                        expected = batch_size,
276                        got = embeddings.len(),
277                        "Batch embedding count mismatch, failing batch"
278                    );
279                    stats.failed += batch_size;
280                    return;
281                }
282
283                // Build NewDataset items with embeddings for upsert.
284                // Use existing content_hash from DB — it's always present for stored datasets.
285                let upsert_datasets: Vec<NewDataset> = embeddable
286                    .iter()
287                    .zip(embeddings)
288                    .map(|((d, _), emb)| {
289                        let content_hash = match &d.content_hash {
290                            Some(h) => h.clone(),
291                            None => {
292                                tracing::info!(
293                                    original_id = %d.original_id,
294                                    "Dataset missing content_hash, automatically generating one"
295                                );
296                                NewDataset::compute_content_hash(&d.title, d.description.as_deref())
297                            }
298                        };
299                        NewDataset {
300                            original_id: d.original_id.clone(),
301                            source_portal: d.source_portal.clone(),
302                            url: d.url.clone(),
303                            title: d.title.clone(),
304                            description: d.description.clone(),
305                            embedding: Some(emb),
306                            metadata: d.metadata.clone(),
307                            content_hash,
308                        }
309                    })
310                    .collect();
311
312                let skipped_no_hash = batch_size - upsert_datasets.len();
313                stats.failed += skipped_no_hash;
314                let upsert_count = upsert_datasets.len();
315
316                match self.store.batch_upsert(&upsert_datasets).await {
317                    Ok(_) => {
318                        stats.embedded += upsert_count;
319                    }
320                    Err(e) => {
321                        tracing::warn!(
322                            count = upsert_count,
323                            error = %e,
324                            "Failed to batch upsert datasets with embeddings"
325                        );
326                        stats.failed += upsert_count;
327                    }
328                }
329            }
330            Err(CircuitBreakerError::Open { retry_after, .. }) => {
331                tracing::debug!(
332                    batch_size,
333                    retry_after_secs = retry_after.as_secs(),
334                    "Skipping batch - circuit breaker open"
335                );
336                stats.skipped += batch_size;
337            }
338            Err(CircuitBreakerError::Inner(e)) => {
339                tracing::warn!(
340                    batch_size,
341                    error = %e,
342                    "Batch embedding generation failed"
343                );
344                stats.failed += batch_size;
345            }
346        }
347    }
348}