Skip to main content

zer_compute/
comparator.rs

1//! `DeviceComparator`, implements the `Comparator` trait.
2//!
3//! `compare_batch_from_pool` always routes to the CPU (Rayon parallel) path.
4//! String comparison (Jaro-Winkler) is branch-heavy and dominated by PCIe
5//! transfer overhead on GPU, making CPU faster for all observed batch sizes.
6
7use std::sync::Arc;
8
9use zer_core::{
10    comparison::{ComparisonBatch, ComparisonVector},
11    record::Record,
12    record_pool::RecordPool,
13    schema::Schema,
14    traits::Comparator,
15};
16
17use crate::{
18    backend::{cpu::CpuFallbackComparator, DeviceBackend},
19    error::GpuError,
20};
21
22pub struct DeviceComparator {
23    backend: Arc<DeviceBackend>,
24    cpu_fallback: CpuFallbackComparator,
25}
26
27impl DeviceComparator {
28    pub fn new(backend: Arc<DeviceBackend>, schema: &Schema) -> Result<Self, GpuError> {
29        let cpu_fallback = CpuFallbackComparator::from_schema(schema);
30        Ok(Self {
31            backend,
32            cpu_fallback,
33        })
34    }
35
36    pub fn backend_name(&self) -> &'static str {
37        self.backend.name()
38    }
39}
40
41impl Comparator for DeviceComparator {
42    fn compare(&self, a: &Record, b: &Record, schema: &Schema) -> ComparisonVector {
43        self.cpu_fallback.compare(a, b, schema)
44    }
45
46    fn compare_batch_from_pool(
47        &self,
48        pool: &RecordPool,
49        indices: &[(usize, usize)],
50        schema: &Schema,
51    ) -> ComparisonBatch {
52        if indices.is_empty() {
53            return ComparisonBatch::new(0, schema.fields.len(), vec![]);
54        }
55
56        // compare_batch always runs on CPU: string comparison (Jaro-Winkler) is branch-heavy
57        // and dominated by PCIe transfer overhead on GPU for all observed batch sizes.
58        self.cpu_fallback
59            .compare_batch_from_pool(pool, indices, schema)
60    }
61}
62
63// ── Unit tests ────────────────────────────────────────────────────────────────
64
65#[cfg(test)]
66mod tests {
67    use super::*;
68    use zer_core::{
69        comparison::ComparisonLevel,
70        record::{FieldValue, Record},
71        record_pool::RecordPool,
72        schema::{FieldKind, SchemaBuilder},
73    };
74
75    fn test_schema() -> Schema {
76        SchemaBuilder::new()
77            .field("naam", FieldKind::Name)
78            .field("datum", FieldKind::Date)
79            .field("kenteken", FieldKind::LicensePlate)
80            .build()
81            .unwrap()
82    }
83
84    fn make_record(id: u64) -> Record {
85        Record::new(id)
86            .insert("naam", FieldValue::Text("Alice de Vries".into()))
87            .insert("datum", FieldValue::Text("1990-03-15".into()))
88            .insert("kenteken", FieldValue::Text("12-ABC-3".into()))
89    }
90
91    fn make_record_b(id: u64) -> Record {
92        Record::new(id)
93            .insert("naam", FieldValue::Text("Alicia de Vrees".into()))
94            .insert("datum", FieldValue::Text("1990-03-15".into()))
95            .insert("kenteken", FieldValue::Text("12-ABC-3".into()))
96    }
97
98    #[test]
99    fn single_pair_uses_cpu_path() {
100        let schema = test_schema();
101        let backend = Arc::new(DeviceBackend::cpu());
102        let cmp = DeviceComparator::new(backend, &schema).unwrap();
103
104        let a = make_record(1);
105        let b = make_record_b(2);
106        let vec = cmp.compare(&a, &b, &schema);
107
108        assert_eq!(vec.record_a, 1);
109        assert_eq!(vec.record_b, 2);
110        assert_eq!(vec.levels.len(), 3);
111    }
112
113    #[test]
114    fn small_batch_uses_cpu_fallback() {
115        let schema = test_schema();
116        let backend = Arc::new(DeviceBackend::cpu());
117        let cmp = DeviceComparator::new(backend, &schema).unwrap();
118
119        let records: Vec<Record> = (0..20)
120            .map(|i| {
121                if i % 2 == 0 {
122                    make_record(i)
123                } else {
124                    make_record_b(i)
125                }
126            })
127            .collect();
128        let pool = RecordPool::from_records(&records, &schema);
129        let indices: Vec<(usize, usize)> = (0..10).map(|i| (i * 2, i * 2 + 1)).collect();
130
131        let batch = cmp.compare_batch_from_pool(&pool, &indices, &schema);
132        assert_eq!(batch.n_pairs, 10);
133        assert_eq!(batch.n_fields, 3);
134        assert_eq!(batch.levels.len(), 3 * 10);
135    }
136
137    #[test]
138    fn empty_batch_returns_empty() {
139        let schema = test_schema();
140        let backend = Arc::new(DeviceBackend::cpu());
141        let cmp = DeviceComparator::new(backend, &schema).unwrap();
142        let pool = RecordPool::new(schema.fields.len());
143        let batch = cmp.compare_batch_from_pool(&pool, &[], &schema);
144        assert_eq!(batch.n_pairs, 0);
145        assert!(batch.levels.is_empty());
146    }
147
148    #[test]
149    fn exact_match_produces_exact_levels() {
150        let schema = test_schema();
151        let backend = Arc::new(DeviceBackend::cpu());
152        let cmp = DeviceComparator::new(backend, &schema).unwrap();
153
154        let r = Record::new(1)
155            .insert("naam", FieldValue::Text("Jan Jansen".into()))
156            .insert("datum", FieldValue::Text("1980-01-01".into()))
157            .insert("kenteken", FieldValue::Text("AB-123-C".into()));
158        let vec = cmp.compare(&r.clone(), &r, &schema);
159
160        for level in &vec.levels {
161            assert_eq!(
162                *level,
163                ComparisonLevel::Exact,
164                "identical records should give Exact"
165            );
166        }
167    }
168
169    #[test]
170    fn completely_different_records_produce_none_levels() {
171        let schema = test_schema();
172        let backend = Arc::new(DeviceBackend::cpu());
173        let cmp = DeviceComparator::new(backend, &schema).unwrap();
174
175        let a = Record::new(1)
176            .insert("naam", FieldValue::Text("Henk".into()))
177            .insert("datum", FieldValue::Text("1950-01-01".into()))
178            .insert("kenteken", FieldValue::Text("XX-000-X".into()));
179        let b = Record::new(2)
180            .insert("naam", FieldValue::Text("Zäzä".into()))
181            .insert("datum", FieldValue::Text("2010-12-31".into()))
182            .insert("kenteken", FieldValue::Text("YY-999-Y".into()));
183
184        let vec = cmp.compare(&a, &b, &schema);
185        for level in &vec.levels {
186            assert!(
187                matches!(level, ComparisonLevel::None | ComparisonLevel::Partial),
188                "very different records should produce None or Partial levels"
189            );
190        }
191    }
192
193    fn synthetic_records(n: usize, schema: &Schema) -> Vec<Record> {
194        use rand::Rng;
195        let mut rng = rand::thread_rng();
196
197        let names = [
198            "Alice",
199            "Alicia",
200            "Bob",
201            "Robert",
202            "Eva",
203            "Eva-Marie",
204            "Jan",
205            "Johan",
206            "Petra",
207            "Pietra",
208            "Lena",
209            "Lena-Marie",
210        ];
211        let dates = [
212            "1990-01-15",
213            "1990-01-16",
214            "1985-06-20",
215            "1975-03-03",
216            "2000-12-31",
217            "2001-01-01",
218            "1960-07-07",
219            "1970-11-22",
220        ];
221        let plates = [
222            "12-ABC-3", "12-ABD-3", "45-XYZ-6", "46-XYZ-6", "AB-123-C", "AB-124-C", "ZZ-999-Z",
223            "ZZ-998-Z",
224        ];
225
226        let fields: Vec<&str> = schema.fields.iter().map(|f| f.name.as_str()).collect();
227
228        (0..n)
229            .map(|i| {
230                let mut r = Record::new(i as u64);
231                for field in &fields {
232                    let val = match *field {
233                        "naam" => names[rng.gen_range(0..names.len())],
234                        "datum" => dates[rng.gen_range(0..dates.len())],
235                        "kenteken" => plates[rng.gen_range(0..plates.len())],
236                        _ => "unknown",
237                    };
238                    r = r.insert(*field, FieldValue::Text(val.into()));
239                }
240                r
241            })
242            .collect()
243    }
244
245    #[test]
246    fn large_batch_auto_detect_returns_correct_count() {
247        let schema = test_schema();
248        let backend = Arc::new(DeviceBackend::auto_detect());
249        let cmp = DeviceComparator::new(backend, &schema).unwrap();
250
251        let records = synthetic_records(4_000, &schema);
252        let pool = RecordPool::from_records(&records, &schema);
253        let indices: Vec<(usize, usize)> = (0..2_000).map(|i| (i * 2, i * 2 + 1)).collect();
254
255        let batch = cmp.compare_batch_from_pool(&pool, &indices, &schema);
256
257        assert_eq!(
258            batch.n_pairs, 2_000,
259            "compare_batch_from_pool must return one row per pair"
260        );
261        assert_eq!(
262            batch.n_fields, 3,
263            "each batch must have one column per field"
264        );
265    }
266}