use tokio_util::sync::CancellationToken;
use tracing;
use crate::AppError;
use crate::circuit_breaker::CircuitBreaker;
use crate::circuit_breaker::CircuitBreakerError;
use crate::config::EmbeddingServiceConfig;
use crate::models::NewDataset;
use crate::progress::{HarvestEvent, ProgressReporter};
use crate::traits::{DatasetStore, EmbeddingProvider};
/// Counters describing the outcome of one embedding pass.
#[derive(Debug, Clone, Default)]
pub struct EmbeddingStats {
    /// Datasets whose embeddings were generated and persisted.
    pub embedded: usize,
    /// Datasets that were handled but not embedded (including ones with no
    /// embeddable text).
    pub failed: usize,
    /// Datasets left pending without being handled, e.g. because the circuit
    /// breaker stayed open.
    pub skipped: usize,
    /// Number of datasets that were pending when the pass started.
    pub total: usize,
}

impl EmbeddingStats {
    /// Number of successfully embedded datasets (same value as `embedded`).
    pub fn successful(&self) -> usize {
        let Self { embedded, .. } = self;
        *embedded
    }
}
/// Result of a single attempt to embed one batch of datasets.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum BatchOutcome {
    /// The batch was handled and its datasets were counted in the stats.
    Processed,
    /// The circuit breaker rejected the call, so the batch was not attempted.
    CircuitOpen {
        /// Delay reported by the circuit breaker before another attempt.
        retry_after: std::time::Duration,
    },
}
// Renders the counters on one line, e.g.
// `embedded: 3, failed: 1, skipped: 0, total: 4`.
impl std::fmt::Display for EmbeddingStats {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            embedded,
            failed,
            skipped,
            total,
        } = self;
        write!(
            f,
            "embedded: {}, failed: {}, skipped: {}, total: {}",
            embedded, failed, skipped, total
        )
    }
}
/// Generates embeddings for datasets that are still pending and writes them
/// back through the dataset store.
pub struct EmbeddingService<S: DatasetStore, E: EmbeddingProvider> {
    /// Backend used to count/list pending datasets and upsert embedded ones.
    store: S,
    /// Provider that turns dataset text into embedding vectors.
    embedding: E,
    /// Tuning settings (batch size, circuit breaker configuration).
    config: EmbeddingServiceConfig,
}
// Manual impl so that cloning requires the store and provider to be `Clone`,
// alongside their service traits.
impl<S, E> Clone for EmbeddingService<S, E>
where
    S: DatasetStore + Clone,
    E: EmbeddingProvider + Clone,
{
    fn clone(&self) -> Self {
        let Self {
            store,
            embedding,
            config,
        } = self;
        Self {
            store: store.clone(),
            embedding: embedding.clone(),
            config: config.clone(),
        }
    }
}
impl<S, E> EmbeddingService<S, E>
where
    S: DatasetStore,
    E: EmbeddingProvider,
{
    /// How many times one batch is re-attempted while the circuit breaker is
    /// open before it is left pending and counted as skipped.
    const MAX_CIRCUIT_RETRIES: u32 = 5;

    /// Number of batches' worth of pending datasets fetched per store page.
    const BATCHES_PER_PAGE: usize = 10;

    /// Creates a service that uses `EmbeddingServiceConfig::default()`.
    pub fn new(store: S, embedding: E) -> Self {
        Self {
            store,
            embedding,
            config: EmbeddingServiceConfig::default(),
        }
    }

    /// Creates a service with an explicit configuration.
    pub fn with_config(store: S, embedding: E, config: EmbeddingServiceConfig) -> Self {
        Self {
            store,
            embedding,
            config,
        }
    }

    /// Returns the embedding provider this service sends text to.
    pub fn embedding_provider(&self) -> &E {
        &self.embedding
    }

    /// Embeds every pending dataset (optionally restricted to one portal).
    ///
    /// The effective batch size is the smaller of the configured batch size
    /// and the provider's maximum, but at least 1. Pending datasets are fetched
    /// in pages of `BATCHES_PER_PAGE` batches. Each batch goes to the provider
    /// through a circuit breaker created for this pass. While the breaker is
    /// open, a batch is retried up to `MAX_CIRCUIT_RETRIES` times, sleeping the
    /// breaker's suggested delay between attempts. After that, the batch is left
    /// pending and counted as skipped. A progress event is sent to `reporter`
    /// after each batch.
    ///
    /// The pass stops when:
    /// - the store returns an empty page,
    /// - a whole page produces no new embeddings, or
    /// - `cancel_token` is cancelled.
    ///
    /// NOTE(review): datasets that fail are not upserted, so presumably later
    /// pages return them again. In that case they are counted in `failed` more
    /// than once and the reported `current` can exceed `total`. Confirm this
    /// against the `list_pending_embeddings` semantics.
    ///
    /// # Errors
    /// Returns an error if counting or listing pending datasets fails.
    /// Embedding and upsert failures are not errors; they are counted in the
    /// returned [`EmbeddingStats`].
    pub async fn embed_pending(
        &self,
        portal_filter: Option<&str>,
        reporter: &impl ProgressReporter,
        cancel_token: CancellationToken,
    ) -> Result<EmbeddingStats, AppError> {
        let total = self.store.count_pending_embeddings(portal_filter).await? as usize;
        if total == 0 {
            tracing::info!("No datasets pending embedding");
            return Ok(EmbeddingStats::default());
        }
        tracing::info!(
            total,
            portal = portal_filter.unwrap_or("all"),
            provider = self.embedding.name(),
            "Starting embedding pass"
        );

        let mut stats = EmbeddingStats {
            total,
            ..Default::default()
        };
        let effective_batch_size =
            std::cmp::min(self.config.batch_size, self.embedding.max_batch_size()).max(1);
        let circuit_breaker =
            CircuitBreaker::new(self.embedding.name(), self.config.circuit_breaker.clone());
        let page_size = effective_batch_size * Self::BATCHES_PER_PAGE;
        let mut processed = 0usize;

        loop {
            if cancel_token.is_cancelled() {
                tracing::info!("Embedding pass cancelled");
                break;
            }
            let page = self
                .store
                .list_pending_embeddings(portal_filter, Some(page_size))
                .await?;
            if page.is_empty() {
                break;
            }

            let embedded_before = stats.embedded;
            'batches: for batch in page.chunks(effective_batch_size) {
                if cancel_token.is_cancelled() {
                    // Logged once below, after leaving the batch loop.
                    break;
                }
                let mut attempt = 0;
                loop {
                    match self
                        .process_batch(batch, &circuit_breaker, &mut stats)
                        .await
                    {
                        BatchOutcome::Processed => break,
                        BatchOutcome::CircuitOpen { retry_after } => {
                            attempt += 1;
                            if attempt > Self::MAX_CIRCUIT_RETRIES {
                                tracing::warn!(
                                    batch_size = batch.len(),
                                    attempts = attempt - 1,
                                    "Circuit still open after retries — leaving batch pending"
                                );
                                stats.skipped += batch.len();
                                break;
                            }
                            tracing::info!(
                                attempt,
                                wait_secs = retry_after.as_secs(),
                                "Circuit open — waiting for recovery before retry"
                            );
                            tokio::select! {
                                _ = tokio::time::sleep(retry_after) => {}
                                _ = cancel_token.cancelled() => {
                                    tracing::info!("Embedding pass cancelled during circuit wait");
                                    break 'batches;
                                }
                            }
                        }
                    }
                }
                processed += batch.len();
                reporter.report(HarvestEvent::DatasetProcessed {
                    current: processed,
                    total,
                    created: 0,
                    updated: stats.embedded,
                    unchanged: 0,
                    failed: stats.failed,
                    skipped: stats.skipped,
                });
            }

            // A page cut short by cancellation is not a genuine "no progress"
            // signal, so stop here instead of emitting that warning.
            if cancel_token.is_cancelled() {
                tracing::info!("Embedding pass cancelled");
                break;
            }
            if stats.embedded == embedded_before {
                tracing::warn!(
                    "No progress this page — stopping to avoid infinite loop \
                     ({} failed, {} skipped)",
                    stats.failed,
                    stats.skipped
                );
                break;
            }
        }

        tracing::info!(
            embedded = stats.embedded,
            failed = stats.failed,
            skipped = stats.skipped,
            total = stats.total,
            "Embedding pass complete"
        );
        Ok(stats)
    }

    /// Embeds one batch of datasets and upserts the results, updating `stats`.
    ///
    /// Each dataset's text is `"<title> <description>"`. Datasets whose text is
    /// blank are not sent to the provider and are counted as failed.
    ///
    /// Returns `BatchOutcome::CircuitOpen` without touching `stats` when the
    /// circuit breaker rejects the call. The caller can then retry the same
    /// batch, or count it all as skipped, without double counting.
    /// Every other path returns `BatchOutcome::Processed` after counting each
    /// dataset in the batch exactly once, as either embedded or failed.
    async fn process_batch(
        &self,
        datasets: &[crate::Dataset],
        circuit_breaker: &CircuitBreaker,
        stats: &mut EmbeddingStats,
    ) -> BatchOutcome {
        // Split into the datasets worth embedding and the text for each one,
        // in matching order, without cloning the text.
        let (embeddable, texts): (Vec<&crate::Dataset>, Vec<String>) = datasets
            .iter()
            .filter_map(|d| {
                let text = format!(
                    "{} {}",
                    d.title,
                    d.description.as_deref().unwrap_or_default()
                );
                if text.trim().is_empty() {
                    None
                } else {
                    Some((d, text))
                }
            })
            .unzip();
        let skipped_empty = datasets.len() - embeddable.len();
        if skipped_empty > 0 {
            tracing::debug!(skipped_empty, "Skipped datasets with empty text");
        }
        if embeddable.is_empty() {
            stats.failed += skipped_empty;
            return BatchOutcome::Processed;
        }

        let batch_size = texts.len();
        let embeddings = match circuit_breaker
            .call(|| self.embedding.generate_batch(&texts))
            .await
        {
            Ok(embeddings) => embeddings,
            Err(CircuitBreakerError::Open { retry_after, .. }) => {
                tracing::debug!(
                    batch_size,
                    retry_after_secs = retry_after.as_secs(),
                    "Circuit breaker open - batch deferred for retry"
                );
                // Leave `stats` untouched. The caller either retries this same
                // batch or records all of it as skipped, so counting the
                // blank-text datasets here would count them more than once.
                return BatchOutcome::CircuitOpen { retry_after };
            }
            Err(CircuitBreakerError::Inner(e)) => {
                tracing::warn!(
                    batch_size,
                    error = %e,
                    "Batch embedding generation failed"
                );
                stats.failed += skipped_empty + batch_size;
                return BatchOutcome::Processed;
            }
        };
        // The batch was actually attempted, so the blank-text datasets are
        // recorded now, exactly once.
        stats.failed += skipped_empty;

        if embeddings.len() != batch_size {
            tracing::warn!(
                expected = batch_size,
                got = embeddings.len(),
                "Batch embedding count mismatch, failing batch"
            );
            stats.failed += batch_size;
            return BatchOutcome::Processed;
        }

        let upsert_datasets: Vec<NewDataset> = embeddable
            .into_iter()
            .zip(embeddings)
            .map(|(d, emb)| {
                let content_hash = match &d.content_hash {
                    Some(h) => h.clone(),
                    None => {
                        tracing::info!(
                            original_id = %d.original_id,
                            "Dataset missing content_hash, automatically generating one"
                        );
                        NewDataset::compute_content_hash(&d.title, d.description.as_deref())
                    }
                };
                NewDataset {
                    original_id: d.original_id.clone(),
                    source_portal: d.source_portal.clone(),
                    url: d.url.clone(),
                    title: d.title.clone(),
                    description: d.description.clone(),
                    embedding: Some(emb),
                    metadata: d.metadata.clone(),
                    content_hash,
                }
            })
            .collect();

        let upsert_count = upsert_datasets.len();
        match self.store.batch_upsert(&upsert_datasets).await {
            Ok(_) => stats.embedded += upsert_count,
            Err(e) => {
                tracing::warn!(
                    count = upsert_count,
                    error = %e,
                    "Failed to batch upsert datasets with embeddings"
                );
                stats.failed += upsert_count;
            }
        }
        BatchOutcome::Processed
    }
}