use rayon::prelude::*;
use zer_core::{
comparison::{ComparisonBatch, ComparisonLevel, ComparisonVector},
field_mapping::{FieldMapping, NullPolicy},
record::Record,
record_pool::RecordPool,
schema::{FieldKind, Schema},
traits::Comparator,
};
use crate::{
discretize::LevelThresholds,
similarity::{default_fns_for, SimilarityFn},
};
/// Field-by-field record comparator.
///
/// Both vectors are indexed by field position: slot `i` holds the similarity
/// functions and the level thresholds used for the `i`-th schema field (or the
/// `i`-th field mapping, depending on which constructor built the comparator).
pub struct FieldComparator {
    /// Similarity functions per field; when comparing, the highest score
    /// produced by a field's functions is the one that gets discretised.
    field_fns: Vec<Vec<Box<dyn SimilarityFn>>>,
    /// Thresholds that turn a field's best similarity score into a level.
    thresholds: Vec<LevelThresholds>,
}
impl FieldComparator {
    /// Builds a comparator with one slot per field of `schema`, in schema order.
    ///
    /// Every slot gets the default similarity functions and the default level
    /// thresholds for that field's kind.
    pub fn from_schema(schema: &Schema) -> Self {
        let (field_fns, thresholds): (Vec<_>, Vec<_>) = schema
            .fields
            .iter()
            .map(|field| {
                (
                    default_fns_for(field.kind),
                    LevelThresholds::for_kind(field.kind),
                )
            })
            .unzip();
        Self { field_fns, thresholds }
    }

    /// Builds a comparator with one slot per entry of `mappings`, in mapping order.
    ///
    /// The kind of each slot is found by looking up the mapping's `a_field`
    /// name in `a_schema`; a name that is not present in the schema falls back
    /// to `FieldKind::Categorical`.
    pub fn from_mapping(mappings: &[FieldMapping], a_schema: &Schema) -> Self {
        let kind_of = |name: &str| {
            a_schema
                .fields
                .iter()
                .find(|f| f.name == name)
                .map_or(FieldKind::Categorical, |f| f.kind)
        };
        let (field_fns, thresholds): (Vec<_>, Vec<_>) = mappings
            .iter()
            .map(|m| {
                let kind = kind_of(&m.a_field);
                (default_fns_for(kind), LevelThresholds::for_kind(kind))
            })
            .unzip();
        Self { field_fns, thresholds }
    }

    /// Scores slot `field_idx` with each of its similarity functions via
    /// `score`, keeps the maximum (starting from `0.0`), and discretises it
    /// with the slot's thresholds.
    ///
    /// # Panics
    ///
    /// Panics if `field_idx` is not a valid slot of this comparator.
    fn level_for(
        &self,
        field_idx: usize,
        score: impl Fn(&dyn SimilarityFn) -> f32,
    ) -> ComparisonLevel {
        let best = self.field_fns[field_idx]
            .iter()
            .map(|sim_fn| score(&**sim_fn))
            .fold(0.0_f32, f32::max);
        self.thresholds[field_idx].apply(best)
    }

    /// Compares `a` and `b` across `mappings`.
    ///
    /// Mapping `i` reads `a_field` from `a` and `b_field` from `b` and scores
    /// them with slot `i`. When either value is absent the mapping's
    /// `null_policy` decides the level: `PenaliseAbsence` yields
    /// `ComparisonLevel::None`, `Skip` yields `ComparisonLevel::Null`.
    ///
    /// # Panics
    ///
    /// Panics if a mapping with both values present has an index that is not a
    /// valid slot of this comparator.
    pub fn compare_pair_mapped(
        &self,
        a: &Record,
        b: &Record,
        mappings: &[FieldMapping],
    ) -> ComparisonVector {
        let levels: Vec<ComparisonLevel> = mappings
            .iter()
            .enumerate()
            .map(|(i, m)| {
                match (a.fields.get(&m.a_field), b.fields.get(&m.b_field)) {
                    (Some(va), Some(vb)) => self.level_for(i, |f| f.similarity(va, vb)),
                    _ => match m.null_policy {
                        NullPolicy::PenaliseAbsence => ComparisonLevel::None,
                        NullPolicy::Skip => ComparisonLevel::Null,
                    },
                }
            })
            .collect();
        ComparisonVector::new(a.id, b.id, levels)
    }

    /// Compares every `(i, j)` pair of `indices` (positions into `records`)
    /// across `mappings`, in parallel, and returns the result field-major.
    ///
    /// An empty `indices` slice short-circuits to
    /// `ComparisonBatch::new(0, n_fields, vec![])`.
    ///
    /// # Panics
    ///
    /// Panics if an index is out of bounds for `records`, or under the
    /// conditions listed for [`Self::compare_pair_mapped`].
    pub fn compare_batch_mapped(
        &self,
        records: &[Record],
        indices: &[(usize, usize)],
        mappings: &[FieldMapping],
    ) -> ComparisonBatch {
        let n_pairs = indices.len();
        let n_fields = mappings.len();
        if n_pairs == 0 {
            return ComparisonBatch::new(0, n_fields, vec![]);
        }
        let rows: Vec<((u64, u64), Vec<u8>)> = indices
            .par_iter()
            .map(|&(i, j)| {
                let (left, right) = (&records[i], &records[j]);
                let cv = self.compare_pair_mapped(left, right, mappings);
                let levels = cv.levels.iter().map(|&l| l as u8).collect();
                ((left.id, right.id), levels)
            })
            .collect();
        Self::scatter_to_batch(n_pairs, n_fields, rows)
    }

    /// Transposes per-pair level rows into the field-major layout of
    /// `ComparisonBatch`: the level of field `f` for pair `p` is stored at
    /// `levels[f * n_pairs + p]`.
    ///
    /// Expects exactly `n_pairs` rows, each holding at least `n_fields` levels.
    fn scatter_to_batch(
        n_pairs: usize,
        n_fields: usize,
        pair_ids_and_levels: Vec<((u64, u64), Vec<u8>)>,
    ) -> ComparisonBatch {
        debug_assert_eq!(pair_ids_and_levels.len(), n_pairs);
        let mut pair_ids = Vec::with_capacity(n_pairs);
        let mut levels = vec![0u8; n_fields * n_pairs];
        // Consume the rows so ids are moved rather than copied in a second pass.
        for (p, (ids, pair_lvls)) in pair_ids_and_levels.into_iter().enumerate() {
            pair_ids.push(ids);
            for f in 0..n_fields {
                levels[f * n_pairs + p] = pair_lvls[f];
            }
        }
        ComparisonBatch { n_pairs, n_fields, pair_ids, levels }
    }

    /// Replaces the level thresholds of slot `field_idx` (builder style).
    ///
    /// # Panics
    ///
    /// Panics if `field_idx` is not a valid slot of this comparator.
    pub fn with_thresholds(mut self, field_idx: usize, thresholds: LevelThresholds) -> Self {
        self.thresholds[field_idx] = thresholds;
        self
    }

    /// Replaces the similarity functions of slot `field_idx` (builder style).
    ///
    /// # Panics
    ///
    /// Panics if `field_idx` is not a valid slot of this comparator.
    pub fn with_fns(mut self, field_idx: usize, fns: Vec<Box<dyn SimilarityFn>>) -> Self {
        self.field_fns[field_idx] = fns;
        self
    }

    /// Compares `a` and `b` on every field of `schema`, reading the same field
    /// name from both records. A field missing from either record yields
    /// `ComparisonLevel::None`.
    ///
    /// # Panics
    ///
    /// Panics if a field present in both records has a position that is not a
    /// valid slot of this comparator.
    fn compare_pair(&self, a: &Record, b: &Record, schema: &Schema) -> ComparisonVector {
        let levels: Vec<ComparisonLevel> = schema
            .fields
            .iter()
            .enumerate()
            .map(|(i, field)| {
                match (a.fields.get(&field.name), b.fields.get(&field.name)) {
                    (Some(va), Some(vb)) => self.level_for(i, |f| f.similarity(va, vb)),
                    _ => ComparisonLevel::None,
                }
            })
            .collect();
        ComparisonVector::new(a.id, b.id, levels)
    }

    /// Compares two pooled string values for slot `f` and returns the level as
    /// a `u8`. An empty string on either side is treated as an absent value and
    /// yields `ComparisonLevel::None`.
    #[inline]
    fn compare_pool_field(&self, f: usize, a_str: &str, b_str: &str) -> u8 {
        if a_str.is_empty() || b_str.is_empty() {
            return ComparisonLevel::None as u8;
        }
        self.level_for(f, |sim_fn| sim_fn.similarity_str(a_str, b_str)) as u8
    }

    /// Compares every `(i, j)` pair of `indices` (positions into `pool`) on
    /// every field of `schema`, in parallel, and returns the result
    /// field-major: the level of field `f` for pair `p` is stored at
    /// `levels[f * n_pairs + p]`.
    ///
    /// An empty `indices` slice short-circuits to
    /// `ComparisonBatch::new(0, n_fields, vec![])`. A schema without fields
    /// yields a batch that carries the pair ids and no levels.
    ///
    /// # Panics
    ///
    /// Panics if an index is out of bounds for `pool.ids`, or if `schema` has
    /// more fields than this comparator has slots.
    pub fn compare_batch_from_pool(
        &self,
        pool: &RecordPool,
        indices: &[(usize, usize)],
        schema: &Schema,
    ) -> ComparisonBatch {
        let n_pairs = indices.len();
        let n_fields = schema.fields.len();
        if n_pairs == 0 {
            return ComparisonBatch::new(0, n_fields, vec![]);
        }
        let pair_ids: Vec<(u64, u64)> = indices
            .iter()
            .map(|&(i, j)| (pool.ids[i], pool.ids[j]))
            .collect();
        if n_fields == 0 {
            // Chunking by a size of zero panics; with no fields there is
            // nothing to compare, so return the ids with an empty level table.
            return ComparisonBatch { n_pairs, n_fields, pair_ids, levels: Vec::new() };
        }

        // Fill a pair-major scratch buffer first so each parallel task owns one
        // contiguous row, then transpose into the field-major output.
        let mut pair_major = vec![0u8; n_pairs * n_fields];
        pair_major
            .par_chunks_mut(n_fields)
            .zip(indices.par_iter())
            .for_each(|(row, &(i, j))| {
                for (f, slot) in row.iter_mut().enumerate() {
                    *slot = self.compare_pool_field(f, pool.get(f, i), pool.get(f, j));
                }
            });

        let mut levels = vec![0u8; n_fields * n_pairs];
        for (p, row) in pair_major.chunks_exact(n_fields).enumerate() {
            for (f, &lvl) in row.iter().enumerate() {
                levels[f * n_pairs + p] = lvl;
            }
        }
        ComparisonBatch { n_pairs, n_fields, pair_ids, levels }
    }
}
impl Comparator for FieldComparator {
fn compare(&self, a: &Record, b: &Record, schema: &Schema) -> ComparisonVector {
self.compare_pair(a, b, schema)
}
fn compare_batch_from_pool(
&self,
pool: &RecordPool,
indices: &[(usize, usize)],
schema: &Schema,
) -> ComparisonBatch {
self.compare_batch_from_pool(pool, indices, schema)
}
}
#[cfg(test)]
mod tests {
    use zer_core::{
        comparison::ComparisonLevel,
        record::FieldValue,
        record_pool::RecordPool,
        schema::{FieldKind, SchemaBuilder},
    };

    use super::*;

    /// Four-field person schema shared by all tests.
    fn person_schema() -> Schema {
        let builder = SchemaBuilder::new()
            .field("voornamen", FieldKind::Name)
            .field("achternaam", FieldKind::Name)
            .field("geboortedatum", FieldKind::Date)
            .field("postcode", FieldKind::Id);
        builder.build().unwrap()
    }

    /// Builds a record carrying all four person fields as text values.
    fn make_record(id: u64, voornamen: &str, achternaam: &str, dob: &str, postcode: &str) -> Record {
        let text = |s: &str| FieldValue::Text(s.into());
        Record::new(id)
            .insert("voornamen", text(voornamen))
            .insert("achternaam", text(achternaam))
            .insert("geboortedatum", text(dob))
            .insert("postcode", text(postcode))
    }

    /// The canonical "Jan Jansen" record under the given id.
    fn jan(id: u64) -> Record {
        make_record(id, "Jan", "Jansen", "1990-06-15", "1011AB")
    }

    /// `2 * n_pairs` identical records with ids `0..2 * n_pairs`, plus the
    /// index pairs `(0, 1), (2, 3), ...` that match each record with its twin.
    fn twin_pairs(n_pairs: usize) -> (Vec<Record>, Vec<(usize, usize)>) {
        let records: Vec<Record> = (0..2 * n_pairs as u64).map(jan).collect();
        let indices: Vec<(usize, usize)> = (0..n_pairs).map(|p| (2 * p, 2 * p + 1)).collect();
        (records, indices)
    }

    #[test]
    fn compare_returns_correct_field_count() {
        let schema = person_schema();
        let comparator = FieldComparator::from_schema(&schema);
        let cv = comparator.compare(&jan(1), &jan(2), &schema);
        assert_eq!(cv.levels.len(), schema.len());
    }

    #[test]
    fn identical_records_score_exact_on_all_fields() {
        let schema = person_schema();
        let comparator = FieldComparator::from_schema(&schema);
        let cv = comparator.compare(&jan(1), &jan(2), &schema);
        let any_non_exact = cv.levels.iter().any(|&l| l != ComparisonLevel::Exact);
        assert!(!any_non_exact,
            "identical records should have all Exact levels: {:?}", cv.levels);
    }

    #[test]
    fn completely_different_records_score_none_or_low() {
        let schema = person_schema();
        let comparator = FieldComparator::from_schema(&schema);
        let left = jan(1);
        let right = make_record(2, "Maria", "Bakker", "1955-12-01", "3001XY");
        let cv = comparator.compare(&left, &right, &schema);
        let n_none = cv
            .levels
            .iter()
            .filter(|lvl| **lvl == ComparisonLevel::None)
            .count();
        assert!(n_none >= 2, "dissimilar records should have several None levels: {:?}", cv.levels);
    }

    #[test]
    fn missing_field_produces_none() {
        // Position of `postcode` in `person_schema`.
        const POSTCODE_IDX: usize = 3;
        let schema = person_schema();
        let comparator = FieldComparator::from_schema(&schema);
        let complete = jan(1);
        // Same person, but without a postcode.
        let without_postcode = Record::new(2)
            .insert("voornamen", FieldValue::Text("Jan".into()))
            .insert("achternaam", FieldValue::Text("Jansen".into()))
            .insert("geboortedatum", FieldValue::Text("1990-06-15".into()));
        let cv = comparator.compare(&complete, &without_postcode, &schema);
        assert_eq!(cv.levels[POSTCODE_IDX], ComparisonLevel::None,
            "missing postcode should yield None, got {:?}", cv.levels[POSTCODE_IDX]);
    }

    #[test]
    fn compare_batch_field_major_layout() {
        let schema = person_schema();
        let comparator = FieldComparator::from_schema(&schema);
        let n_fields = schema.len();
        let (records, indices) = twin_pairs(5);
        let pool = RecordPool::from_records(&records, &schema);
        let batch = comparator.compare_batch_from_pool(&pool, &indices, &schema);

        assert_eq!(batch.n_pairs, 5);
        assert_eq!(batch.n_fields, n_fields);
        assert_eq!(batch.levels.len(), n_fields * 5);

        let cells = (0..n_fields).flat_map(|f| (0..5).map(move |p| (f, p)));
        for (f, p) in cells {
            assert_eq!(
                batch.level(f, p),
                ComparisonLevel::Exact,
                "field {f} pair {p} should be Exact"
            );
        }
    }

    #[test]
    fn compare_batch_from_pool_matches_individual_compare() {
        let schema = person_schema();
        let comparator = FieldComparator::from_schema(&schema);
        let (records, indices) = twin_pairs(20);
        let pool = RecordPool::from_records(&records, &schema);
        let batch = comparator.compare_batch_from_pool(&pool, &indices, &schema);

        for (p, &(i, j)) in indices.iter().enumerate() {
            let single = comparator.compare(&records[i], &records[j], &schema);
            for (f, &expected) in single.levels.iter().enumerate() {
                assert_eq!(
                    batch.level(f, p), expected,
                    "batch and individual disagree at field {f} pair {p}"
                );
            }
        }
    }

    #[test]
    fn empty_batch_is_valid() {
        let schema = person_schema();
        let comparator = FieldComparator::from_schema(&schema);
        let empty_pool = RecordPool::new(schema.fields.len());
        let batch = comparator.compare_batch_from_pool(&empty_pool, &[], &schema);
        assert_eq!(batch.n_pairs, 0);
        assert!(batch.levels.is_empty());
    }
}