use std::sync::Arc;
use zer_blocking::BlockerFactory;
use zer_cluster::ConnectedComponentsClusterer;
use zer_compare::{FieldComparator, FellegiSunterScorer};
use zer_core::{
error::ZerError,
schema::Schema,
traits::{Blocker, Clusterer, Comparator, EntityStore, Judge, RecordStore, Scorer},
VecRecordStore,
};
use zer_schema::SchemaRegistry;
use crate::{cluster_view::ClusterView, config::PipelineConfig, ingester::Ingester, progress::PipelineEvent};
/// Fully assembled pipeline: the schema plus every component selected by
/// [`PipelineBuilder::build`].
///
/// The builder hands instances out as `Arc<Pipeline>`; the fields are
/// crate-visible and exposed to outside callers only through the accessor
/// methods on this type.
pub struct Pipeline {
/// Schema supplied to the builder.
pub(crate) schema: Schema,
/// Blocker in use; derived from the schema when the builder was given none.
pub(crate) blocker: Arc<dyn Blocker>,
/// Comparator in use; derived from the schema when the builder was given none.
pub(crate) comparator: Arc<dyn Comparator>,
/// Comparator built from `config.field_mappings`; `None` when the mappings
/// are empty.
pub(crate) mapped_comparator: Option<Arc<FieldComparator>>,
/// Scorer in use; `FellegiSunterScorer` when the builder was given none.
pub(crate) scorer: Arc<dyn Scorer>,
/// Clusterer in use; a default `ConnectedComponentsClusterer` when the
/// builder was given none.
pub(crate) clusterer: Arc<dyn Clusterer>,
/// Entity store; mandatory at build time.
pub(crate) store: Arc<dyn EntityStore>,
/// Record store; a fresh `VecRecordStore` when the builder was given none.
pub(crate) record_store: Arc<dyn RecordStore>,
/// Schema registry opened from `config.registry_path` at build time.
pub(crate) registry: Arc<SchemaRegistry>,
/// Optional judge; `None` unless one was set on the builder.
pub(crate) judge: Option<Arc<dyn Judge>>,
/// Configuration the pipeline was built with.
pub(crate) config: PipelineConfig,
/// Optional sender for `PipelineEvent` values; `None` unless one was set on
/// the builder.
pub(crate) progress: Option<tokio::sync::mpsc::UnboundedSender<PipelineEvent>>,
}
impl Pipeline {
    /// Returns a [`PipelineBuilder`] in its default state.
    pub fn builder() -> PipelineBuilder {
        Default::default()
    }

    /// Consumes a shared handle to this pipeline and passes it to
    /// [`Ingester::new`].
    pub fn ingester(self: Arc<Self>) -> Ingester {
        let pipeline = self;
        Ingester::new(pipeline)
    }

    /// Borrows the shared entity store.
    pub fn store(&self) -> &Arc<dyn EntityStore> {
        let Self { store, .. } = self;
        store
    }

    /// Borrows the shared record store.
    pub fn record_store(&self) -> &Arc<dyn RecordStore> {
        let Self { record_store, .. } = self;
        record_store
    }

    /// Builds a [`ClusterView`] over new handles to the entity store and the
    /// record store (the `Arc`s are cloned, not the stores themselves).
    pub fn cluster_view(&self) -> ClusterView {
        let entities = Arc::clone(&self.store);
        let records = Arc::clone(&self.record_store);
        ClusterView::new(entities, records)
    }

    /// Borrows the shared schema registry.
    pub fn registry(&self) -> &Arc<SchemaRegistry> {
        let Self { registry, .. } = self;
        registry
    }

    /// Borrows the schema this pipeline was built with.
    pub fn schema(&self) -> &Schema {
        let Self { schema, .. } = self;
        schema
    }
}
/// Step-by-step builder for [`Pipeline`].
///
/// The default value has every optional component unset and `config` equal
/// to `PipelineConfig::default()`. `Default` is derived rather than written
/// by hand so that it cannot fall out of sync when a field is added.
#[derive(Default)]
pub struct PipelineBuilder {
    /// Schema for the pipeline; `build` fails when this is still `None`.
    schema: Option<Schema>,
    /// Custom blocker; `build` derives one from the schema when `None`.
    blocker: Option<Arc<dyn Blocker>>,
    /// Custom comparator; `build` derives one from the schema when `None`.
    comparator: Option<Arc<dyn Comparator>>,
    /// Custom scorer; `build` falls back to `FellegiSunterScorer` when `None`.
    scorer: Option<Arc<dyn Scorer>>,
    /// Custom clusterer; `build` falls back to a default
    /// `ConnectedComponentsClusterer` when `None`.
    clusterer: Option<Arc<dyn Clusterer>>,
    /// Entity store; `build` fails when this is still `None`.
    store: Option<Arc<dyn EntityStore>>,
    /// Record store; `build` falls back to a new `VecRecordStore` when `None`.
    record_store: Option<Arc<dyn RecordStore>>,
    /// Optional judge, passed through to the pipeline unchanged.
    judge: Option<Arc<dyn Judge>>,
    /// Pipeline configuration; starts as `PipelineConfig::default()`.
    config: PipelineConfig,
    /// Optional sender for `PipelineEvent` values, passed through unchanged.
    progress: Option<tokio::sync::mpsc::UnboundedSender<PipelineEvent>>,
}
/// Applies `with_source(source)` to every record in `records` and returns the
/// results in the same order.
///
/// The input vector is consumed.
pub fn label_source(records: Vec<zer_core::record::Record>, source: &str) -> Vec<zer_core::record::Record> {
    let mut labelled = Vec::with_capacity(records.len());
    for record in records {
        labelled.push(record.with_source(source));
    }
    labelled
}
impl PipelineBuilder {
    /// Sets the schema. Required: `build` fails without it.
    pub fn schema(self, schema: Schema) -> Self {
        Self {
            schema: Some(schema),
            ..self
        }
    }

    /// Sets a custom blocker, replacing the schema-derived default.
    pub fn blocker(self, b: impl Blocker + 'static) -> Self {
        Self {
            blocker: Some(Arc::new(b)),
            ..self
        }
    }

    /// Sets a custom comparator, replacing the schema-derived default.
    pub fn comparator(self, c: impl Comparator + 'static) -> Self {
        Self {
            comparator: Some(Arc::new(c)),
            ..self
        }
    }

    /// Sets a custom scorer, replacing the `FellegiSunterScorer` default.
    pub fn scorer(self, s: impl Scorer + 'static) -> Self {
        Self {
            scorer: Some(Arc::new(s)),
            ..self
        }
    }

    /// Sets a custom clusterer, replacing the default
    /// `ConnectedComponentsClusterer`.
    pub fn clusterer(self, c: impl Clusterer + 'static) -> Self {
        Self {
            clusterer: Some(Arc::new(c)),
            ..self
        }
    }

    /// Sets the entity store. Required: `build` fails without it.
    pub fn store(self, s: impl EntityStore + 'static) -> Self {
        Self {
            store: Some(Arc::new(s)),
            ..self
        }
    }

    /// Sets the record store from an owned value, wrapping it in a new `Arc`.
    pub fn record_store(self, s: impl RecordStore + 'static) -> Self {
        Self {
            record_store: Some(Arc::new(s)),
            ..self
        }
    }

    /// Sets the record store from an existing shared handle, so the caller
    /// can keep its own clone of the same `Arc`.
    pub fn record_store_arc(self, s: Arc<dyn RecordStore>) -> Self {
        Self {
            record_store: Some(s),
            ..self
        }
    }

    /// Sets the optional judge.
    pub fn judge(self, j: impl Judge + 'static) -> Self {
        Self {
            judge: Some(Arc::new(j)),
            ..self
        }
    }

    /// Replaces the whole pipeline configuration.
    pub fn config(self, c: PipelineConfig) -> Self {
        Self { config: c, ..self }
    }

    /// Sets the sender that will be stored on the pipeline as `progress`.
    pub fn progress(self, tx: tokio::sync::mpsc::UnboundedSender<PipelineEvent>) -> Self {
        Self {
            progress: Some(tx),
            ..self
        }
    }

    /// Assembles the [`Pipeline`], filling in defaults for every component
    /// that was not set explicitly.
    ///
    /// # Errors
    ///
    /// * `ZerError::EmptySchema` when no schema was set.
    /// * `ZerError::Store` when no entity store was set.
    /// * Whatever `SchemaRegistry::open` returns when the registry at
    ///   `config.registry_path` cannot be opened.
    ///
    /// The checks run in that order, so a missing schema is reported before a
    /// missing store, and the registry is only opened once both are present.
    pub fn build(self) -> Result<Arc<Pipeline>, ZerError> {
        // Mandatory pieces first; bail out before building any defaults.
        let schema = match self.schema {
            Some(schema) => schema,
            None => return Err(ZerError::EmptySchema),
        };
        let store = match self.store {
            Some(store) => store,
            None => return Err(ZerError::Store("no entity store configured".into())),
        };

        let blocker: Arc<dyn Blocker> = match self.blocker {
            Some(custom) => custom,
            None => Arc::new(BlockerFactory::from_schema(&schema)),
        };

        // The mapped comparator exists only when field mappings are configured;
        // it is built independently of any custom comparator.
        let field_mappings = &self.config.field_mappings;
        let mapped_comparator: Option<Arc<FieldComparator>> = if !field_mappings.is_empty() {
            Some(Arc::new(FieldComparator::from_mapping(field_mappings, &schema)))
        } else {
            None
        };

        let comparator: Arc<dyn Comparator> = match self.comparator {
            Some(custom) => custom,
            None => Arc::new(FieldComparator::from_schema(&schema)),
        };
        let scorer: Arc<dyn Scorer> = match self.scorer {
            Some(custom) => custom,
            None => Arc::new(FellegiSunterScorer),
        };
        let clusterer: Arc<dyn Clusterer> = match self.clusterer {
            Some(custom) => custom,
            None => Arc::new(ConnectedComponentsClusterer::default()),
        };
        let record_store: Arc<dyn RecordStore> = match self.record_store {
            Some(custom) => custom,
            None => Arc::new(VecRecordStore::new()),
        };

        // Opened last so the fallible registry call happens only after all
        // validation has passed.
        let opened_registry = SchemaRegistry::open(&self.config.registry_path)?;
        let registry = Arc::new(opened_registry);

        let pipeline = Pipeline {
            schema,
            blocker,
            comparator,
            mapped_comparator,
            scorer,
            clusterer,
            store,
            record_store,
            registry,
            judge: self.judge,
            config: self.config,
            progress: self.progress,
        };
        Ok(Arc::new(pipeline))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;
    use zer_cluster::ZalEntityStore;
    use zer_core::schema::{FieldKind, SchemaBuilder};

    /// Three-field schema (two names and a date) shared by the tests.
    fn person_schema() -> Schema {
        let builder = SchemaBuilder::new()
            .field("voornamen", FieldKind::Name)
            .field("achternaam", FieldKind::Name)
            .field("geboortedatum", FieldKind::Date);
        builder.build().unwrap()
    }

    /// Default configuration whose registry path points inside `dir`.
    fn config_in(dir: &TempDir) -> PipelineConfig {
        PipelineConfig {
            registry_path: dir.path().join("test.zsm"),
            ..PipelineConfig::default()
        }
    }

    /// Pipeline with the person schema, an in-memory entity store and a
    /// registry file located inside `dir`.
    fn temp_pipeline(dir: &TempDir) -> Arc<Pipeline> {
        let config = config_in(dir);
        let store = ZalEntityStore::open_in_memory().unwrap();
        let built = Pipeline::builder()
            .schema(person_schema())
            .store(store)
            .config(config)
            .build();
        built.unwrap()
    }

    #[test]
    fn builder_with_schema_and_store_succeeds() {
        let dir = TempDir::new().unwrap();
        let pipeline = temp_pipeline(&dir);
        let field_count = pipeline.schema().fields.len();
        assert_eq!(field_count, 3);
    }

    #[test]
    fn builder_missing_schema_returns_error() {
        let dir = TempDir::new().unwrap();
        let store = ZalEntityStore::open_in_memory().unwrap();
        // No `.schema(..)` call: the build must be rejected.
        let result = Pipeline::builder()
            .store(store)
            .config(config_in(&dir))
            .build();
        assert!(result.is_err(), "missing schema must return an error");
    }

    #[test]
    fn builder_missing_store_returns_error() {
        let dir = TempDir::new().unwrap();
        // No `.store(..)` call: the build must be rejected.
        let result = Pipeline::builder()
            .schema(person_schema())
            .config(config_in(&dir))
            .build();
        assert!(result.is_err(), "missing store must return an error");
    }

    #[test]
    fn store_and_registry_accessors_work() {
        let dir = TempDir::new().unwrap();
        let pipeline = temp_pipeline(&dir);
        // Only checks that the accessors can be called on a built pipeline.
        let _store = pipeline.store();
        let _registry = pipeline.registry();
    }
}