Skip to main content

datasynth_core/traits/
post_processor.rs

//! Post-processor trait for data quality variations and other post-generation transformations.
//!
//! Post-processors modify records after generation to inject data quality issues,
//! format variations, typos, and other realistic flakiness. They produce labels
//! that can be used for ML training.
6
7use crate::error::SynthResult;
8use std::collections::HashMap;
9
/// Per-record context handed to a post-processor while it runs.
#[derive(Debug, Clone, Default)]
pub struct ProcessContext {
    /// Position of the current record within its batch (0 is the first record).
    pub record_index: usize,
    /// Number of records in the batch being processed.
    pub batch_size: usize,
    /// Output format in use (csv, json, parquet), if one has been set.
    pub output_format: Option<String>,
    /// Free-form key/value data accompanying the record.
    pub metadata: HashMap<String, String>,
}

impl ProcessContext {
    /// Build a context for the record at `record_index` in a batch of
    /// `batch_size` records, with no output format and empty metadata.
    pub fn new(record_index: usize, batch_size: usize) -> Self {
        Self {
            record_index,
            batch_size,
            ..Self::default()
        }
    }

    /// Return this context with its output format set to `format`.
    pub fn with_format(self, format: impl Into<String>) -> Self {
        Self {
            output_format: Some(format.into()),
            ..self
        }
    }

    /// Return this context with `key` mapped to `value` in its metadata,
    /// replacing any value previously stored under the same key.
    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let (owned_key, owned_value) = (key.into(), value.into());
        self.metadata.insert(owned_key, owned_value);
        self
    }

    /// Whether the current record is the first one of the batch.
    pub fn is_first(&self) -> bool {
        matches!(self.record_index, 0)
    }

    /// Whether the current record is the last one of the batch.
    ///
    /// The last index is computed with a saturating subtraction, so with a
    /// `batch_size` of 0 an index of 0 is reported as last.
    pub fn is_last(&self) -> bool {
        let last_index = self.batch_size.saturating_sub(1);
        self.record_index == last_index
    }
}
56
/// Counters collected while a post-processor runs.
#[derive(Debug, Clone, Default)]
pub struct ProcessorStats {
    /// How many records were processed.
    pub records_processed: u64,
    /// How many records were modified.
    pub records_modified: u64,
    /// How many labels were generated.
    pub labels_generated: u64,
    /// How many errors were encountered.
    pub errors_encountered: u64,
    /// Time spent processing, in microseconds.
    pub processing_time_us: u64,
}

impl ProcessorStats {
    /// Fraction of processed records that were modified.
    ///
    /// Returns `0.0` when no record has been processed, so the division is
    /// never performed with a zero denominator.
    pub fn modification_rate(&self) -> f64 {
        match self.records_processed {
            0 => 0.0,
            total => self.records_modified as f64 / total as f64,
        }
    }

    /// Add every counter of `other` onto the matching counter of `self`.
    pub fn merge(&mut self, other: &ProcessorStats) {
        // Destructuring names every field, so a newly added counter cannot
        // be forgotten here without a compile error.
        let ProcessorStats {
            records_processed,
            records_modified,
            labels_generated,
            errors_encountered,
            processing_time_us,
        } = other;

        self.records_processed += *records_processed;
        self.records_modified += *records_modified;
        self.labels_generated += *labels_generated;
        self.errors_encountered += *errors_encountered;
        self.processing_time_us += *processing_time_us;
    }
}
91
92/// Core trait for post-processors that modify records and generate labels.
93///
94/// Post-processors are applied after generation to inject realistic data quality
95/// issues. Each processor can modify records in place and generate labels
96/// describing the modifications for ML training.
97pub trait PostProcessor: Send + Sync {
98    /// The type of records this processor modifies.
99    type Record;
100    /// The type of labels this processor produces.
101    type Label;
102
103    /// Process a single record, potentially modifying it and generating labels.
104    ///
105    /// Returns a vector of labels describing any modifications made.
106    fn process(
107        &mut self,
108        record: &mut Self::Record,
109        context: &ProcessContext,
110    ) -> SynthResult<Vec<Self::Label>>;
111
112    /// Process a batch of records.
113    ///
114    /// Default implementation calls process for each record.
115    fn process_batch(
116        &mut self,
117        records: &mut [Self::Record],
118        base_context: &ProcessContext,
119    ) -> SynthResult<Vec<Self::Label>> {
120        let mut all_labels = Vec::new();
121        let batch_size = records.len();
122
123        for (i, record) in records.iter_mut().enumerate() {
124            let context = ProcessContext {
125                record_index: i,
126                batch_size,
127                output_format: base_context.output_format.clone(),
128                metadata: base_context.metadata.clone(),
129            };
130            let labels = self.process(record, &context)?;
131            all_labels.extend(labels);
132        }
133
134        Ok(all_labels)
135    }
136
137    /// Get the name of this processor.
138    fn name(&self) -> &'static str;
139
140    /// Check if this processor is enabled.
141    fn is_enabled(&self) -> bool;
142
143    /// Get processing statistics.
144    fn stats(&self) -> ProcessorStats;
145
146    /// Reset statistics (for testing or between batches).
147    fn reset_stats(&mut self);
148}
149
150/// A pipeline of post-processors applied in sequence.
151pub struct PostProcessorPipeline<R, L> {
152    processors: Vec<Box<dyn PostProcessor<Record = R, Label = L>>>,
153    stats: ProcessorStats,
154}
155
156impl<R, L> PostProcessorPipeline<R, L> {
157    /// Create a new empty pipeline.
158    pub fn new() -> Self {
159        Self {
160            processors: Vec::new(),
161            stats: ProcessorStats::default(),
162        }
163    }
164
165    /// Add a processor to the pipeline.
166    pub fn add<P>(&mut self, processor: P)
167    where
168        P: PostProcessor<Record = R, Label = L> + 'static,
169    {
170        self.processors.push(Box::new(processor));
171    }
172
173    /// Add a processor and return self for chaining.
174    pub fn with<P>(mut self, processor: P) -> Self
175    where
176        P: PostProcessor<Record = R, Label = L> + 'static,
177    {
178        self.add(processor);
179        self
180    }
181
182    /// Process a single record through all processors.
183    pub fn process(&mut self, record: &mut R, context: &ProcessContext) -> SynthResult<Vec<L>> {
184        let mut all_labels = Vec::new();
185
186        for processor in &mut self.processors {
187            if processor.is_enabled() {
188                let labels = processor.process(record, context)?;
189                all_labels.extend(labels);
190            }
191        }
192
193        self.stats.records_processed += 1;
194        if !all_labels.is_empty() {
195            self.stats.records_modified += 1;
196        }
197        self.stats.labels_generated += all_labels.len() as u64;
198
199        Ok(all_labels)
200    }
201
202    /// Process a batch of records through all processors.
203    pub fn process_batch(
204        &mut self,
205        records: &mut [R],
206        base_context: &ProcessContext,
207    ) -> SynthResult<Vec<L>> {
208        let mut all_labels = Vec::new();
209        let batch_size = records.len();
210
211        for (i, record) in records.iter_mut().enumerate() {
212            let context = ProcessContext {
213                record_index: i,
214                batch_size,
215                output_format: base_context.output_format.clone(),
216                metadata: base_context.metadata.clone(),
217            };
218            let labels = self.process(record, &context)?;
219            all_labels.extend(labels);
220        }
221
222        Ok(all_labels)
223    }
224
225    /// Get aggregate statistics for the pipeline.
226    ///
227    /// Returns the pipeline's own stats tracking records processed through
228    /// the entire pipeline. Use `processor_stats()` to get individual
229    /// processor statistics.
230    pub fn stats(&self) -> ProcessorStats {
231        self.stats.clone()
232    }
233
234    /// Get individual processor statistics.
235    pub fn processor_stats(&self) -> Vec<(&'static str, ProcessorStats)> {
236        self.processors
237            .iter()
238            .map(|p| (p.name(), p.stats()))
239            .collect()
240    }
241
242    /// Check if pipeline has any enabled processors.
243    pub fn has_enabled_processors(&self) -> bool {
244        self.processors.iter().any(|p| p.is_enabled())
245    }
246
247    /// Get number of processors in the pipeline.
248    pub fn len(&self) -> usize {
249        self.processors.len()
250    }
251
252    /// Check if pipeline is empty.
253    pub fn is_empty(&self) -> bool {
254        self.processors.is_empty()
255    }
256
257    /// Reset all statistics.
258    pub fn reset_stats(&mut self) {
259        self.stats = ProcessorStats::default();
260        for processor in &mut self.processors {
261            processor.reset_stats();
262        }
263    }
264}
265
266impl<R, L> Default for PostProcessorPipeline<R, L> {
267    fn default() -> Self {
268        Self::new()
269    }
270}
271
272/// A no-op processor that passes records through unchanged.
273pub struct PassthroughProcessor<R, L> {
274    enabled: bool,
275    stats: ProcessorStats,
276    _phantom: std::marker::PhantomData<(R, L)>,
277}
278
279impl<R, L> PassthroughProcessor<R, L> {
280    /// Create a new passthrough processor.
281    pub fn new() -> Self {
282        Self {
283            enabled: true,
284            stats: ProcessorStats::default(),
285            _phantom: std::marker::PhantomData,
286        }
287    }
288
289    /// Create a disabled passthrough processor.
290    pub fn disabled() -> Self {
291        Self {
292            enabled: false,
293            stats: ProcessorStats::default(),
294            _phantom: std::marker::PhantomData,
295        }
296    }
297}
298
299impl<R, L> Default for PassthroughProcessor<R, L> {
300    fn default() -> Self {
301        Self::new()
302    }
303}
304
305impl<R: Send + Sync, L: Send + Sync> PostProcessor for PassthroughProcessor<R, L> {
306    type Record = R;
307    type Label = L;
308
309    fn process(
310        &mut self,
311        _record: &mut Self::Record,
312        _context: &ProcessContext,
313    ) -> SynthResult<Vec<Self::Label>> {
314        self.stats.records_processed += 1;
315        Ok(Vec::new())
316    }
317
318    fn name(&self) -> &'static str {
319        "passthrough"
320    }
321
322    fn is_enabled(&self) -> bool {
323        self.enabled
324    }
325
326    fn stats(&self) -> ProcessorStats {
327        self.stats.clone()
328    }
329
330    fn reset_stats(&mut self) {
331        self.stats = ProcessorStats::default();
332    }
333}
334
335/// Builder for creating post-processor pipelines.
336pub struct PipelineBuilder<R, L> {
337    pipeline: PostProcessorPipeline<R, L>,
338}
339
340impl<R, L> PipelineBuilder<R, L> {
341    /// Create a new pipeline builder.
342    pub fn new() -> Self {
343        Self {
344            pipeline: PostProcessorPipeline::new(),
345        }
346    }
347
348    /// Add a processor to the pipeline.
349    #[allow(clippy::should_implement_trait)]
350    pub fn add<P>(mut self, processor: P) -> Self
351    where
352        P: PostProcessor<Record = R, Label = L> + 'static,
353    {
354        self.pipeline.add(processor);
355        self
356    }
357
358    /// Conditionally add a processor.
359    pub fn add_if<P>(mut self, condition: bool, processor: P) -> Self
360    where
361        P: PostProcessor<Record = R, Label = L> + 'static,
362    {
363        if condition {
364            self.pipeline.add(processor);
365        }
366        self
367    }
368
369    /// Build the pipeline.
370    pub fn build(self) -> PostProcessorPipeline<R, L> {
371        self.pipeline
372    }
373}
374
375impl<R, L> Default for PipelineBuilder<R, L> {
376    fn default() -> Self {
377        Self::new()
378    }
379}
380
#[cfg(test)]
#[allow(clippy::unwrap_used)] // Tests unwrap on purpose: a failure should abort the test.
mod tests {
    use super::*;

    /// Minimal record type used to exercise the pipeline.
    #[derive(Debug, Clone)]
    struct TestRecord {
        value: String,
    }

    /// Minimal label type; `change` is written but never read by the tests,
    /// hence the `dead_code` allowance.
    #[allow(dead_code)]
    #[derive(Debug, Clone)]
    struct TestLabel {
        field: String,
        change: String,
    }

    /// Test processor that uppercases `TestRecord::value` and emits one label
    /// whenever the value actually changed.
    struct UppercaseProcessor {
        enabled: bool,
        stats: ProcessorStats,
    }

    impl UppercaseProcessor {
        /// Create an enabled processor with zeroed statistics.
        fn new() -> Self {
            Self {
                enabled: true,
                stats: ProcessorStats::default(),
            }
        }
    }

    impl PostProcessor for UppercaseProcessor {
        type Record = TestRecord;
        type Label = TestLabel;

        fn process(
            &mut self,
            record: &mut Self::Record,
            _context: &ProcessContext,
        ) -> SynthResult<Vec<Self::Label>> {
            self.stats.records_processed += 1;
            let original = record.value.clone();
            record.value = record.value.to_uppercase();
            if record.value != original {
                self.stats.records_modified += 1;
                self.stats.labels_generated += 1;
                Ok(vec![TestLabel {
                    field: "value".to_string(),
                    change: format!("{} -> {}", original, record.value),
                }])
            } else {
                // Already uppercase: nothing changed, so no label.
                Ok(vec![])
            }
        }

        fn name(&self) -> &'static str {
            "uppercase"
        }

        fn is_enabled(&self) -> bool {
            self.enabled
        }

        fn stats(&self) -> ProcessorStats {
            self.stats.clone()
        }

        fn reset_stats(&mut self) {
            self.stats = ProcessorStats::default();
        }
    }

    /// A single record goes through the pipeline and yields one label.
    #[test]
    fn test_pipeline_basic() {
        let mut pipeline = PostProcessorPipeline::new();
        pipeline.add(UppercaseProcessor::new());

        let mut record = TestRecord {
            value: "hello".to_string(),
        };
        let context = ProcessContext::new(0, 1);

        let labels = pipeline.process(&mut record, &context).unwrap();

        assert_eq!(record.value, "HELLO");
        assert_eq!(labels.len(), 1);
        assert_eq!(labels[0].field, "value");
    }

    /// Every record of a batch is processed and contributes its label.
    #[test]
    fn test_pipeline_batch() {
        let mut pipeline = PostProcessorPipeline::new();
        pipeline.add(UppercaseProcessor::new());

        let mut records = vec![
            TestRecord {
                value: "a".to_string(),
            },
            TestRecord {
                value: "b".to_string(),
            },
            TestRecord {
                value: "c".to_string(),
            },
        ];
        let context = ProcessContext::new(0, 3);

        let labels = pipeline.process_batch(&mut records, &context).unwrap();

        assert_eq!(records[0].value, "A");
        assert_eq!(records[1].value, "B");
        assert_eq!(records[2].value, "C");
        assert_eq!(labels.len(), 3);
    }

    /// Pipeline-level counters track processed and modified records.
    #[test]
    fn test_pipeline_stats() {
        let mut pipeline = PostProcessorPipeline::new();
        pipeline.add(UppercaseProcessor::new());

        let context = ProcessContext::new(0, 1);

        for _ in 0..5 {
            let mut record = TestRecord {
                value: "test".to_string(),
            };
            // Unwrap instead of discarding the result so a processing error
            // fails the test right where it happens.
            pipeline.process(&mut record, &context).unwrap();
        }

        let stats = pipeline.stats();
        assert_eq!(stats.records_processed, 5);
        assert_eq!(stats.records_modified, 5);
        assert_eq!(stats.labels_generated, 5);
    }

    /// The passthrough processor neither changes the record nor labels it.
    #[test]
    fn test_passthrough_processor() {
        let mut processor = PassthroughProcessor::<TestRecord, TestLabel>::new();
        let mut record = TestRecord {
            value: "unchanged".to_string(),
        };
        let context = ProcessContext::new(0, 1);

        let labels = processor.process(&mut record, &context).unwrap();

        assert_eq!(record.value, "unchanged");
        assert!(labels.is_empty());
    }

    /// `add_if(false, …)` must not add the processor to the pipeline.
    #[test]
    fn test_pipeline_builder() {
        let pipeline: PostProcessorPipeline<TestRecord, TestLabel> = PipelineBuilder::new()
            .add(UppercaseProcessor::new())
            .add_if(false, PassthroughProcessor::new())
            .build();

        assert_eq!(pipeline.len(), 1);
    }
}
541}