ceres-core 0.4.0

Core types, harvesting logic, and services for Ceres
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
//! Standalone embedding service for generating dataset embeddings.
//!
//! This service is decoupled from harvesting — it processes datasets that are
//! already stored in the database with `embedding IS NULL`. This enables:
//!
//! - Harvesting metadata without an embedding API key
//! - Switching embedding providers without re-harvesting
//! - Backfilling embeddings after outages
//! - Independent scaling of harvest and embedding workloads
//!
//! # Example
//!
//! ```ignore
//! use ceres_core::embedding::EmbeddingService;
//!
//! let service = EmbeddingService::new(store, embedding_provider);
//!
//! // Embed all pending datasets
//! let stats = service.embed_pending(None, &reporter, cancel_token).await?;
//! println!("Embedded {} datasets", stats.embedded);
//! ```

use tokio_util::sync::CancellationToken;
use tracing;

use crate::AppError;
use crate::circuit_breaker::CircuitBreaker;
use crate::circuit_breaker::CircuitBreakerError;
use crate::config::EmbeddingServiceConfig;
use crate::models::NewDataset;
use crate::progress::{HarvestEvent, ProgressReporter};
use crate::traits::{DatasetStore, EmbeddingProvider};

/// Counters describing the outcome of one embedding pass.
#[derive(Default, Clone, Debug)]
pub struct EmbeddingStats {
    /// Datasets whose embeddings were generated and written back.
    pub embedded: usize,
    /// Datasets that could not be embedded during the pass.
    pub failed: usize,
    /// Datasets left pending because the circuit breaker stayed open.
    pub skipped: usize,
    /// Datasets that were pending embedding when the pass started.
    pub total: usize,
}

impl EmbeddingStats {
    /// Number of datasets that were successfully embedded.
    ///
    /// Equivalent to reading [`EmbeddingStats::embedded`].
    pub fn successful(&self) -> usize {
        let Self { embedded, .. } = self;
        *embedded
    }
}

/// Result of one attempt at embedding a batch.
///
/// `CircuitOpen` is the only non-terminal result: the circuit breaker rejected
/// the call because the provider is temporarily unavailable (e.g. an Ollama
/// timeout or an external rate limit), nothing was counted, and the very same
/// batch may be retried once the breaker's recovery window has elapsed.
/// `Processed` means the batch is finished for this pass, whether it was
/// embedded, failed, or consisted only of empty text.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum BatchOutcome {
    /// Terminal: the batch was embedded, failed, or was empty; do not retry it.
    Processed,
    /// Deferred: the breaker was open; the batch is still pending.
    CircuitOpen {
        /// How long to wait before retrying the batch.
        retry_after: std::time::Duration,
    },
}

impl std::fmt::Display for EmbeddingStats {
    /// Renders every counter on one line, e.g.
    /// `embedded: 3, failed: 1, skipped: 2, total: 6`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            embedded,
            failed,
            skipped,
            total,
        } = self;
        write!(
            f,
            "embedded: {embedded}, failed: {failed}, skipped: {skipped}, total: {total}"
        )
    }
}

/// Generates embeddings for datasets that are already stored but still have
/// no embedding.
///
/// The service reads rows with `embedding IS NULL` from the [`DatasetStore`],
/// sends their text to the [`EmbeddingProvider`] in batches guarded by a
/// circuit breaker, and writes the resulting vectors back through the store.
pub struct EmbeddingService<S: DatasetStore, E: EmbeddingProvider> {
    /// Persistence backend used to list pending datasets and upsert results.
    store: S,
    /// Provider that turns dataset text into embedding vectors.
    embedding: E,
    /// Tuning for embedding passes (batch size, circuit-breaker settings).
    config: EmbeddingServiceConfig,
}

impl<S: DatasetStore + Clone, E: EmbeddingProvider + Clone> Clone for EmbeddingService<S, E> {
    /// Produces an independent service by cloning the store, the embedding
    /// provider, and the configuration, in that order.
    fn clone(&self) -> Self {
        let Self {
            store,
            embedding,
            config,
        } = self;
        Self {
            store: store.clone(),
            embedding: embedding.clone(),
            config: config.clone(),
        }
    }
}

impl<S, E> EmbeddingService<S, E>
where
    S: DatasetStore,
    E: EmbeddingProvider,
{
    /// Maximum number of times a batch deferred by an open circuit breaker is
    /// retried before it is left pending in the database and counted as skipped.
    const MAX_CIRCUIT_RETRIES: u32 = 5;

    /// Creates a new embedding service with default configuration.
    pub fn new(store: S, embedding: E) -> Self {
        Self::with_config(store, embedding, EmbeddingServiceConfig::default())
    }

    /// Creates a new embedding service with custom configuration.
    pub fn with_config(store: S, embedding: E, config: EmbeddingServiceConfig) -> Self {
        Self {
            store,
            embedding,
            config,
        }
    }

    /// Returns a reference to the underlying embedding provider.
    pub fn embedding_provider(&self) -> &E {
        &self.embedding
    }

    /// Embeds all datasets with `embedding IS NULL`.
    ///
    /// Fetches pending datasets from the database one page at a time, generates
    /// embeddings in batches through a circuit breaker, and upserts them back.
    /// Per-batch problems (empty text, provider errors, storage errors, a
    /// circuit that stays open) are counted in the returned [`EmbeddingStats`]
    /// instead of being returned as errors.
    ///
    /// # Arguments
    ///
    /// * `portal_filter` - Optional portal URL to scope the embedding pass
    /// * `reporter` - Progress reporter for UI/logging
    /// * `cancel_token` - Token for graceful cancellation
    ///
    /// # Errors
    ///
    /// Returns an [`AppError`] if counting or listing the pending datasets fails.
    ///
    /// # Cancellation
    ///
    /// Cancellation is checked before each page and each batch, and while
    /// waiting for the circuit breaker to recover. The pass then stops and
    /// returns the stats gathered so far; unprocessed datasets stay pending.
    pub async fn embed_pending(
        &self,
        portal_filter: Option<&str>,
        reporter: &impl ProgressReporter,
        cancel_token: CancellationToken,
    ) -> Result<EmbeddingStats, AppError> {
        let total = self.store.count_pending_embeddings(portal_filter).await? as usize;

        if total == 0 {
            tracing::info!("No datasets pending embedding");
            return Ok(EmbeddingStats::default());
        }

        tracing::info!(
            total,
            portal = portal_filter.unwrap_or("all"),
            provider = self.embedding.name(),
            "Starting embedding pass"
        );

        let mut stats = EmbeddingStats {
            total,
            ..Default::default()
        };

        // Never exceed what the provider accepts in one call, and never use 0
        // (which would make `chunks` panic).
        let effective_batch_size =
            std::cmp::min(self.config.batch_size, self.embedding.max_batch_size()).max(1);

        let circuit_breaker =
            CircuitBreaker::new(self.embedding.name(), self.config.circuit_breaker.clone());

        let mut processed = 0usize;

        // Page through pending datasets to avoid loading everything into memory.
        // Each iteration fetches up to `page_size` rows, processes them, then
        // fetches the next page. This keeps memory bounded even with 350k+ pending.
        //
        // NOTE(review): `list_pending_embeddings` is called with only a filter and
        // a limit, so datasets that failed or were skipped (still `embedding IS
        // NULL`) may be fetched again on a later page. They are then re-attempted
        // and counted again in `stats`, and a page made up entirely of such
        // datasets ends the pass via the no-progress guard below. Avoiding that
        // needs a cursor/offset in the store API.
        let page_size = effective_batch_size * 10; // ~10 batches per page
        'pages: loop {
            if cancel_token.is_cancelled() {
                tracing::info!("Embedding pass cancelled");
                break;
            }

            let page = self
                .store
                .list_pending_embeddings(portal_filter, Some(page_size))
                .await?;

            if page.is_empty() {
                break;
            }

            let embedded_before = stats.embedded;

            for batch in page.chunks(effective_batch_size) {
                if cancel_token.is_cancelled() {
                    // Leave both loops directly so a cancelled pass is not
                    // misreported by the no-progress guard below.
                    tracing::info!("Embedding pass cancelled");
                    break 'pages;
                }

                let completed = self
                    .embed_batch_with_retry(batch, &circuit_breaker, &mut stats, &cancel_token)
                    .await;
                if !completed {
                    // Cancelled while waiting for the breaker: the batch never
                    // reached a terminal outcome, so it is neither counted nor
                    // reported — it stays pending. Stop the whole pass.
                    break 'pages;
                }

                processed += batch.len();

                reporter.report(HarvestEvent::DatasetProcessed {
                    // `processed` counts attempts, which can exceed `total` when
                    // re-fetched datasets are retried; clamp so progress never
                    // reports more than 100%.
                    current: processed.min(total),
                    total,
                    created: 0,
                    updated: stats.embedded,
                    unchanged: 0,
                    failed: stats.failed,
                    skipped: stats.skipped,
                });
            }

            // If no datasets were successfully embedded this page, stop to avoid
            // an infinite loop re-fetching the same failing datasets.
            if stats.embedded == embedded_before {
                tracing::warn!(
                    "No progress this page — stopping to avoid infinite loop \
                     ({} failed, {} skipped)",
                    stats.failed,
                    stats.skipped
                );
                break;
            }
        }

        tracing::info!(
            embedded = stats.embedded,
            failed = stats.failed,
            skipped = stats.skipped,
            total = stats.total,
            "Embedding pass complete"
        );

        Ok(stats)
    }

    /// Runs one batch through [`Self::process_batch`], retrying it while the
    /// circuit breaker reports `Open`.
    ///
    /// Before each retry it waits out the breaker's recovery window (cancellable)
    /// so the breaker can re-test the provider with the *same* batch. After
    /// [`Self::MAX_CIRCUIT_RETRIES`] deferrals the batch is counted as skipped and
    /// left pending in the database, so the pass is never blocked indefinitely.
    ///
    /// Returns `true` once the batch reached a terminal outcome or was counted as
    /// skipped, and `false` if `cancel_token` fired during a wait. In the `false`
    /// case `stats` is untouched for this batch.
    async fn embed_batch_with_retry(
        &self,
        batch: &[crate::Dataset],
        circuit_breaker: &CircuitBreaker,
        stats: &mut EmbeddingStats,
        cancel_token: &CancellationToken,
    ) -> bool {
        let mut attempt: u32 = 0;
        loop {
            let retry_after = match self.process_batch(batch, circuit_breaker, stats).await {
                BatchOutcome::Processed => return true,
                BatchOutcome::CircuitOpen { retry_after } => retry_after,
            };

            attempt += 1;
            if attempt > Self::MAX_CIRCUIT_RETRIES {
                tracing::warn!(
                    batch_size = batch.len(),
                    attempts = attempt - 1,
                    "Circuit still open after retries — leaving batch pending"
                );
                stats.skipped += batch.len();
                return true;
            }

            tracing::info!(
                attempt,
                wait_secs = retry_after.as_secs(),
                "Circuit open — waiting for recovery before retry"
            );
            tokio::select! {
                _ = tokio::time::sleep(retry_after) => {}
                _ = cancel_token.cancelled() => {
                    tracing::info!("Embedding pass cancelled during circuit wait");
                    return false;
                }
            }
        }
    }

    /// Processes a batch of datasets: generates embeddings via circuit breaker,
    /// then upserts them back to the database.
    ///
    /// Datasets whose `"{title} {description}"` text is blank are not sent to the
    /// provider and are counted as failed. A dataset without a stored
    /// `content_hash` gets one computed from its title and description.
    ///
    /// Returns [`BatchOutcome::CircuitOpen`] when the breaker is open, so the
    /// caller can wait for recovery and retry the same batch. In that case the
    /// only `stats` change is the blank-text count for this batch. All other
    /// outcomes update `stats` and return [`BatchOutcome::Processed`].
    async fn process_batch(
        &self,
        datasets: &[crate::Dataset],
        circuit_breaker: &CircuitBreaker,
        stats: &mut EmbeddingStats,
    ) -> BatchOutcome {
        // Split into the datasets to embed and their texts in one pass. The
        // texts are handed to the provider as-is, with no extra clone.
        let (embeddable, texts): (Vec<&crate::Dataset>, Vec<String>) = datasets
            .iter()
            .filter_map(|d| {
                let text = format!(
                    "{} {}",
                    d.title,
                    d.description.as_deref().unwrap_or_default()
                );
                (!text.trim().is_empty()).then_some((d, text))
            })
            .unzip();

        let skipped_empty = datasets.len() - embeddable.len();
        if skipped_empty > 0 {
            tracing::debug!(skipped_empty, "Skipped datasets with empty text");
            stats.failed += skipped_empty;
        }

        if texts.is_empty() {
            return BatchOutcome::Processed;
        }

        let batch_size = texts.len();

        let embeddings = match circuit_breaker
            .call(|| self.embedding.generate_batch(&texts))
            .await
        {
            Ok(embeddings) => embeddings,
            Err(CircuitBreakerError::Open { retry_after, .. }) => {
                // Recoverable: the provider is temporarily down. Don't count the
                // batch yet — the caller waits for recovery and retries it.
                tracing::debug!(
                    batch_size,
                    retry_after_secs = retry_after.as_secs(),
                    "Circuit breaker open - batch deferred for retry"
                );
                return BatchOutcome::CircuitOpen { retry_after };
            }
            Err(CircuitBreakerError::Inner(e)) => {
                tracing::warn!(
                    batch_size,
                    error = %e,
                    "Batch embedding generation failed"
                );
                stats.failed += batch_size;
                return BatchOutcome::Processed;
            }
        };

        // Zipping would silently drop datasets if the provider returned fewer
        // vectors, so treat any count mismatch as a failure of the whole batch.
        if embeddings.len() != batch_size {
            tracing::warn!(
                expected = batch_size,
                got = embeddings.len(),
                "Batch embedding count mismatch, failing batch"
            );
            stats.failed += batch_size;
            return BatchOutcome::Processed;
        }

        // Build one NewDataset per embedded dataset. Reuse the stored
        // content_hash when present; otherwise derive one from title/description
        // so the upsert has a hash to store.
        let upsert_datasets: Vec<NewDataset> = embeddable
            .into_iter()
            .zip(embeddings)
            .map(|(d, emb)| {
                let content_hash = d.content_hash.clone().unwrap_or_else(|| {
                    tracing::info!(
                        original_id = %d.original_id,
                        "Dataset missing content_hash, automatically generating one"
                    );
                    NewDataset::compute_content_hash(&d.title, d.description.as_deref())
                });
                NewDataset {
                    original_id: d.original_id.clone(),
                    source_portal: d.source_portal.clone(),
                    url: d.url.clone(),
                    title: d.title.clone(),
                    description: d.description.clone(),
                    embedding: Some(emb),
                    metadata: d.metadata.clone(),
                    content_hash,
                }
            })
            .collect();

        match self.store.batch_upsert(&upsert_datasets).await {
            Ok(_) => stats.embedded += batch_size,
            Err(e) => {
                tracing::warn!(
                    count = batch_size,
                    error = %e,
                    "Failed to batch upsert datasets with embeddings"
                );
                stats.failed += batch_size;
            }
        }
        BatchOutcome::Processed
    }
}