use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use futures::StreamExt;
use futures::stream::BoxStream;
use queue_channel::{ChannelConsumer, channel};
use queue_core::{AckHandle, Consumer, Delivery, Producer};
use schema_core::{GenericValue, IndexName};
use sinks_core::Sink;
use sources_core::SnapshotTable;
use sources_core::cdc::{Ack, Change, ChangeCapture, ChangeEvent};
use sources_core::document::{Document, DocumentBuilder, DocumentId};
use tokio::time::{Instant, timeout_at};
use crate::error::{EngineError, Result};
use crate::observer::{BatchStats, Observer};
use crate::policy::{BatchPolicy, FailurePolicies, FailurePolicy};
/// Borrowed bundle of collaborators and settings that every stage in this
/// module receives.
///
/// It holds only references and policy values and derives `Copy`, so the
/// stages take it by value.
#[derive(Clone, Copy)]
pub(crate) struct Pipeline<'a> {
/// Supplies index mappings and backfill scopes, resolves a change to the
/// documents it affects, and builds those documents.
pub(crate) documents: &'a dyn DocumentBuilder,
/// Destination for index creation, seeding markers, upserts, deletes, and
/// flushes.
pub(crate) sink: &'a dyn Sink,
/// Receives progress callbacks; kept as an `Arc` so `pump` can clone it
/// into the spawned capture task.
pub(crate) observer: &'a Arc<dyn Observer>,
/// Capacity passed to `channel` when `pump` creates the change queue.
pub(crate) queue_capacity: usize,
/// Batch limits (`max_changes`, `max_delay`) read by `work`.
pub(crate) batch: BatchPolicy,
/// Resolved per rejected document's index in `commit` to decide whether a
/// rejection stops the pipeline.
pub(crate) failure_policies: &'a FailurePolicies,
}
pub(crate) async fn run_inner(
pipeline: Pipeline<'_>,
source: &dyn ChangeCapture,
skip_backfill: bool,
) -> Result<()> {
let mappings = pipeline.documents.index_mappings().await?;
tracing::info!(indexes = mappings.len(), "ensuring target indexes");
for mapping in &mappings {
pipeline.sink.ensure_index(mapping).await?;
}
pipeline.observer.on_indexes_ensured(mappings.len());
if skip_backfill {
tracing::info!("skipping backfill (skip_backfill set)");
} else {
backfill(pipeline, source).await?;
}
pipeline.observer.on_backfill_completed();
let stream = source.live().await?;
tracing::info!("following live changes");
pipeline.observer.on_live_started();
let result = pump(pipeline, stream, None).await;
match &result {
Ok(()) => tracing::info!("pipeline stopped: live stream ended"),
Err(error) => tracing::error!(%error, "pipeline stopped on error"),
}
result
}
/// Seeds every index the sink does not already report as seeded.
///
/// Walks the document builder's backfill scopes and keeps those whose index
/// is unseeded. Their distinct root tables are snapshotted from `source`, and
/// the snapshot is pumped with a filter restricting output to the unseeded
/// indexes. Each of those indexes is marked as seeded only after the pump
/// succeeds.
///
/// Returns early, without snapshotting, when nothing needs seeding.
///
/// # Errors
///
/// Returns the first error from the sink, the source, or the pump.
#[tracing::instrument(name = "backfill", skip_all)]
async fn backfill(pipeline: Pipeline<'_>, source: &dyn ChangeCapture) -> Result<()> {
    let mut roots: Vec<SnapshotTable> = Vec::new();
    let mut unseeded: HashSet<IndexName> = HashSet::new();

    for scope in pipeline.documents.backfill_scopes() {
        let already_seeded = pipeline.sink.is_seeded(&scope.index).await?;
        if !already_seeded {
            // Several scopes may share a root table; snapshot each only once.
            if !roots.contains(&scope.root) {
                roots.push(scope.root);
            }
            unseeded.insert(scope.index);
        }
    }

    if unseeded.is_empty() {
        tracing::info!("no unseeded indexes; skipping backfill");
        return Ok(());
    }

    let index_count = unseeded.len();
    tracing::info!(
        indexes = index_count,
        tables = roots.len(),
        "seeding indexes"
    );
    let announced: Vec<IndexName> = unseeded.iter().cloned().collect();
    pipeline.observer.on_backfill_started(&announced);

    let snapshot = source.snapshot(&roots).await?;
    pump(pipeline, snapshot, Some(&unseeded)).await?;

    for index in unseeded.iter() {
        pipeline.sink.mark_seeded(index).await?;
        pipeline.observer.on_index_seeded(index);
    }
    tracing::info!(indexes = index_count, "backfill complete");
    Ok(())
}
/// Drives one change stream through a queue until the worker finishes.
///
/// Creates a channel of `pipeline.queue_capacity`, spawns [`capture`] on a
/// Tokio task to publish `stream` into it, and runs [`work`] on the current
/// task to drain it (restricted to `filter` when one is given). Once the
/// worker returns, the capture task is aborted and joined before either
/// outcome is reported.
///
/// # Errors
///
/// A worker error takes precedence. Otherwise the capture task's own result
/// is returned; a join error that reports cancellation counts as success and
/// any other join error becomes `EngineError::Task`.
#[tracing::instrument(name = "pump", skip_all)]
async fn pump(
pipeline: Pipeline<'_>,
stream: BoxStream<'static, sources_core::Result<Change>>,
filter: Option<&HashSet<IndexName>>,
) -> Result<()> {
let (producer, mut consumer) = channel::<Change>(pipeline.queue_capacity);
// The guard aborts the capture task if this future is dropped before the
// handle is taken out below.
let mut capture = CaptureGuard(Some(tokio::spawn(capture(
stream,
producer,
Arc::clone(pipeline.observer),
))));
let worker = work(pipeline, &mut consumer, filter).await;
// Take the handle out of the guard, abort the task, and wait for it, so the
// task is no longer running when this function returns.
// NOTE(review): per Tokio's documented abort semantics, a task that had
// already finished still yields its real output here — confirm.
let captured = match capture.0.take() {
Some(handle) => {
handle.abort();
handle.await
}
// The guard was built with `Some` above and nothing else takes the handle,
// so this arm is only a fallback.
None => Ok(Ok(())),
};
// Report the worker's error only after the capture task has been joined.
worker?;
match captured {
Ok(result) => result,
// Cancellation comes from the abort above and is not treated as a failure.
Err(join) if join.is_cancelled() => Ok(()),
Err(join) => Err(EngineError::Task(join.to_string())),
}
}
/// Holds the capture task's join handle and aborts the task on drop if the
/// handle is still inside the guard.
///
/// Taking the handle out (leaving `None`) disarms the guard.
#[derive(Debug)]
struct CaptureGuard(Option<tokio::task::JoinHandle<Result<()>>>);

impl Drop for CaptureGuard {
    fn drop(&mut self) {
        let Some(task) = self.0.as_ref() else {
            return;
        };
        task.abort();
    }
}
/// Forwards every change from `stream` into the queue via `producer`.
///
/// After each successful publish the observer's `on_change_captured` is
/// called. The function returns `Ok(())` once the stream yields `None`.
///
/// # Errors
///
/// Stops at and returns the first error yielded by the stream or by
/// `publish`.
#[tracing::instrument(name = "capture", skip_all)]
async fn capture(
    mut stream: BoxStream<'static, sources_core::Result<Change>>,
    producer: queue_channel::ChannelProducer<Change>,
    observer: Arc<dyn Observer>,
) -> Result<()> {
    let mut forwarded: u64 = 0;
    loop {
        let Some(item) = stream.next().await else {
            break;
        };
        let change = item?;
        producer.publish(change).await?;
        forwarded += 1;
        observer.on_change_captured();
    }
    tracing::debug!(captured = forwarded, "capture stream ended");
    Ok(())
}
/// Drains `consumer`, grouping deliveries into batches and committing each.
///
/// A batch starts with the first delivery received and keeps growing until it
/// holds `pipeline.batch.max_changes` changes or `pipeline.batch.max_delay`
/// has elapsed since that first delivery was buffered, whichever comes first.
/// Every commit is passed `consumer.is_empty()` as its `caught_up` flag.
///
/// Returns once `recv` yields `None`, after committing whatever is buffered.
///
/// # Errors
///
/// Returns the first error from the queue, from [`buffer`], or from
/// [`commit`].
///
/// NOTE(review): `consumer.recv()` is raced against a timer below; whether a
/// timed-out `recv` can lose a delivery depends on its cancel-safety —
/// confirm against `ChannelConsumer`.
#[tracing::instrument(name = "worker", skip_all, fields(max_changes = pipeline.batch.max_changes))]
pub(crate) async fn work(
pipeline: Pipeline<'_>,
consumer: &mut ChannelConsumer<Change>,
filter: Option<&HashSet<IndexName>>,
) -> Result<()> {
let batch = pipeline.batch;
let mut pending: Batch = Batch::with_capacity(batch.max_changes);
'batches: loop {
// Wait without a deadline for the delivery that opens the next batch.
let Some(delivery) = consumer.recv().await? else {
break;
};
buffer(delivery, pipeline.documents, filter, &mut pending).await?;
// The delay window starts once the first delivery has been buffered.
let deadline = Instant::now() + batch.max_delay;
// Keep filling until the size limit is reached or the deadline passes.
while pending.len() < batch.max_changes {
match timeout_at(deadline, consumer.recv()).await {
// Deadline reached: commit what has been buffered so far.
Err(_elapsed) => break,
Ok(Ok(Some(delivery))) => {
buffer(delivery, pipeline.documents, filter, &mut pending).await?;
}
// No more deliveries: commit the partial batch and stop the worker.
Ok(Ok(None)) => {
commit(pipeline, &mut pending, consumer.is_empty()).await?;
break 'batches;
}
Ok(Err(queue_err)) => return Err(queue_err.into()),
}
}
commit(pipeline, &mut pending, consumer.is_empty()).await?;
}
// `commit` returns early when nothing is buffered, so this final call only
// does work if something is still pending.
commit(pipeline, &mut pending, consumer.is_empty()).await
}
/// Everything accumulated between two commits.
#[derive(Debug)]
struct Batch {
    /// Source acknowledgements, one per buffered change; confirmed by `commit`.
    source: Vec<Ack>,
    /// Queue acknowledgement handles, one per buffered delivery.
    handles: Vec<Box<dyn AckHandle>>,
    /// Distinct document ids touched by the buffered changes, in first-seen order.
    ids: Vec<DocumentId>,
    /// Membership set used to keep `ids` free of duplicates.
    seen: HashSet<DocumentId>,
}

impl Batch {
    /// Creates an empty batch with every collection pre-sized to `capacity`.
    fn with_capacity(capacity: usize) -> Self {
        let source = Vec::with_capacity(capacity);
        let handles = Vec::with_capacity(capacity);
        let ids = Vec::with_capacity(capacity);
        let seen = HashSet::with_capacity(capacity);
        Self {
            source,
            handles,
            ids,
            seen,
        }
    }

    /// Number of buffered changes (not documents).
    fn len(&self) -> usize {
        let buffered = &self.source;
        buffered.len()
    }
}
/// Adds one delivery to `pending`.
///
/// The change's table and key are resolved to the documents they affect. Ids
/// whose index is not in `filter` (when a filter is given) are skipped, and
/// each remaining id is recorded at most once per batch. The change's source
/// ack and the delivery's queue handle are then stored for `commit`.
///
/// # Errors
///
/// Returns the resolver's error; in that case nothing is added to `pending`.
async fn buffer(
    delivery: Delivery<Change>,
    documents: &dyn DocumentBuilder,
    filter: Option<&HashSet<IndexName>>,
    pending: &mut Batch,
) -> Result<()> {
    let (change, handle) = delivery.into_parts();

    // Upserts and deletes are resolved the same way; only the table and key matter.
    let (table, key) = match &change.event {
        ChangeEvent::Upsert { table, key } | ChangeEvent::Delete { table, key } => (table, key),
    };
    let affected = documents.resolve(table, key).await?;
    tracing::trace!(documents = affected.len(), "change resolved to documents");

    for id in affected {
        let filtered_out = matches!(filter, Some(allowed) if !allowed.contains(&id.index));
        if filtered_out {
            continue;
        }
        let first_time = pending.seen.insert(id.clone());
        if first_time {
            pending.ids.push(id);
        }
    }

    pending.source.push(change.ack);
    pending.handles.push(handle);
    Ok(())
}
/// Builds, writes, flushes, and acknowledges everything buffered in `pending`.
///
/// Does nothing when no changes are buffered. Otherwise it:
/// 1. builds documents for the pending ids and hands each to the sink as an
///    upsert or a delete;
/// 2. flushes the sink, passing `caught_up` through, and times the flush;
/// 3. inspects the flush report. If any rejected document's index resolves to
///    `FailurePolicy::Stop`, the commit fails; otherwise each rejection is
///    logged and reported to the observer as quarantined;
/// 4. confirms the source acks, acks the queue handles, clears the batch, and
///    reports batch statistics to the observer.
///
/// # Errors
///
/// Returns the first error from the document builder, the sink, or a queue
/// ack. Returns `EngineError::DocumentsRejected` with the number of
/// Stop-policy rejections and a description of the first one. On that error
/// nothing is acknowledged and `pending` is left as it was.
#[tracing::instrument(name = "commit", level = "debug", skip_all, fields(changes = pending.len(), documents = pending.ids.len(), caught_up))]
async fn commit(pipeline: Pipeline<'_>, pending: &mut Batch, caught_up: bool) -> Result<()> {
    let changes = pending.len();
    if changes == 0 {
        return Ok(());
    }
    let documents_built = pending.ids.len();

    // Per-index tally of the ids in this batch, reported to the observer below.
    let mut by_index: HashMap<IndexName, usize> = HashMap::new();
    for id in &pending.ids {
        let tally = by_index.entry(id.index.clone()).or_default();
        *tally += 1;
    }

    let sink = pipeline.sink;
    let built = pipeline.documents.build_many(&pending.ids).await?;
    for document in built {
        match document {
            Document::Delete { id } => {
                let key = document_id(&id);
                sink.delete(&id.index, &key).await?;
            }
            Document::Upsert { id, body } => {
                let key = document_id(&id);
                sink.upsert(&id.index, &key, &body).await?;
            }
        }
    }

    let flush_started = Instant::now();
    let report = sink.flush(caught_up).await?;
    let flush = flush_started.elapsed();

    if !report.is_clean() {
        // First pass: count Stop-policy rejections and remember the first one.
        let mut stop_total = 0usize;
        let mut first_stop: Option<String> = None;
        for rejected in &report.rejected {
            if pipeline.failure_policies.resolve(&rejected.index) != FailurePolicy::Stop {
                continue;
            }
            stop_total += 1;
            if first_stop.is_none() {
                first_stop = Some(format!(
                    "{}/{}: {}",
                    rejected.index, rejected.id, rejected.reason
                ));
            }
        }
        if let Some(example) = first_stop {
            return Err(EngineError::DocumentsRejected(stop_total, example));
        }

        // No rejection demands a stop: quarantine all of them and carry on.
        for rejected in &report.rejected {
            tracing::warn!(
                index = %rejected.index,
                id = %rejected.id,
                reason = %rejected.reason,
                "document rejected by sink; quarantining and continuing",
            );
            pipeline
                .observer
                .on_document_quarantined(&rejected.index, &rejected.id, &rejected.reason);
        }
    }

    // Acknowledge only after the flush has been accepted.
    for source_ack in pending.source.drain(..) {
        source_ack.confirm();
    }
    for queue_handle in pending.handles.drain(..) {
        queue_handle.ack().await?;
    }
    pending.ids.clear();
    pending.seen.clear();

    let stats = BatchStats {
        changes,
        documents: documents_built,
        documents_by_index: by_index.into_iter().collect(),
        flush,
    };
    pipeline.observer.on_batch_committed(stats);
    tracing::debug!("batch built, flushed, and acked");
    Ok(())
}
/// Renders a document id's key as a single string for the sink.
///
/// Each key value is rendered with `value_to_string` and the pieces are joined
/// with `:` in key order. An empty key yields an empty string.
///
/// NOTE(review): the separator is not escaped, so values that themselves
/// contain `:` could render to the same string as a different key — confirm
/// this is acceptable for the key types in use.
fn document_id(id: &DocumentId) -> String {
    let mut rendered = String::new();
    for (position, (_, value)) in id.key.0.iter().enumerate() {
        if position > 0 {
            rendered.push(':');
        }
        rendered.push_str(&value_to_string(value));
    }
    rendered
}
/// Renders a single key value as text for use inside a document id.
///
/// - Scalars use their `to_string` form; `TimestampTz` uses `to_rfc3339`.
/// - `String` is copied as is.
/// - `Bytes` becomes `\x` followed by two lowercase hex digits per byte.
/// - `Null` becomes the literal text `null`.
/// - `Array` and `Map` render as an empty string.
fn value_to_string(value: &GenericValue) -> String {
    // Brings `write!` support for `String` into scope for the `Bytes` arm only.
    use std::fmt::Write as _;

    match value {
        GenericValue::Null => "null".to_owned(),
        GenericValue::String(text) => text.clone(),
        GenericValue::Bool(flag) => flag.to_string(),
        GenericValue::SmallInt(number) => number.to_string(),
        GenericValue::Int(number) => number.to_string(),
        GenericValue::BigInt(number) => number.to_string(),
        GenericValue::Float(number) => number.to_string(),
        GenericValue::Double(number) => number.to_string(),
        GenericValue::Decimal(number) => number.to_string(),
        GenericValue::Uuid(uuid) => uuid.to_string(),
        GenericValue::Date(date) => date.to_string(),
        GenericValue::Time(time) => time.to_string(),
        GenericValue::Timestamp(stamp) => stamp.to_string(),
        GenericValue::TimestampTz(stamp) => stamp.to_rfc3339(),
        GenericValue::Bytes(bytes) => {
            // Two characters for the `\x` prefix plus two hex digits per byte.
            let mut hex = String::with_capacity(2 + bytes.len() * 2);
            hex.push_str("\\x");
            for byte in bytes {
                // Format straight into the buffer instead of allocating a
                // temporary `String` per byte. Writing to a `String` cannot
                // fail, so the result is deliberately ignored.
                let _ = write!(hex, "{byte:02x}");
            }
            hex
        }
        GenericValue::Array(_) | GenericValue::Map(_) => String::new(),
    }
}