datasynth_core/traits/
post_processor.rs

1//! Post-processor trait for data quality variations and other post-generation transformations.
2//!
3//! Post-processors modify records after generation to inject data quality issues,
4//! format variations, typos, and other realistic flakiness. They produce labels
5//! that can be used for ML training.
6
7use crate::error::SynthResult;
8use std::collections::HashMap;
9
/// Per-record context handed to a post-processor while records are processed.
#[derive(Debug, Clone, Default)]
pub struct ProcessContext {
    /// Zero-based position of the current record within its batch.
    pub record_index: usize,
    /// Number of records in the batch being processed.
    pub batch_size: usize,
    /// Output format the records are destined for (csv, json, parquet), if set.
    pub output_format: Option<String>,
    /// Free-form key/value data made available to processors.
    pub metadata: HashMap<String, String>,
}

impl ProcessContext {
    /// Build a context for the record at `record_index` in a batch of
    /// `batch_size` records, with no output format and empty metadata.
    pub fn new(record_index: usize, batch_size: usize) -> Self {
        Self {
            record_index,
            batch_size,
            ..Self::default()
        }
    }

    /// Return the context with its output format replaced by `format`.
    pub fn with_format(self, format: impl Into<String>) -> Self {
        Self {
            output_format: Some(format.into()),
            ..self
        }
    }

    /// Return the context with `key` mapped to `value` in its metadata.
    ///
    /// An existing entry for the same key is overwritten.
    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let (key, value) = (key.into(), value.into());
        self.metadata.insert(key, value);
        self
    }

    /// Whether this context describes the first record (index 0).
    pub fn is_first(&self) -> bool {
        matches!(self.record_index, 0)
    }

    /// Whether this context describes the last record of the batch.
    ///
    /// The last index is computed as `batch_size - 1`, saturating at zero, so a
    /// context with `batch_size == 0` and `record_index == 0` also reports `true`.
    pub fn is_last(&self) -> bool {
        let last_index = self.batch_size.saturating_sub(1);
        last_index == self.record_index
    }
}
56
/// Counters describing the work done during a post-processor run.
#[derive(Debug, Clone, Default)]
pub struct ProcessorStats {
    /// Number of records processed
    pub records_processed: u64,
    /// Number of records modified
    pub records_modified: u64,
    /// Number of labels generated
    pub labels_generated: u64,
    /// Number of errors encountered
    pub errors_encountered: u64,
    /// Processing time in microseconds
    pub processing_time_us: u64,
}

impl ProcessorStats {
    /// Fraction of processed records that were modified.
    ///
    /// Returns `0.0` when nothing has been processed yet, so the result is
    /// never NaN.
    pub fn modification_rate(&self) -> f64 {
        if self.records_processed == 0 {
            return 0.0;
        }
        self.records_modified as f64 / self.records_processed as f64
    }

    /// Add every counter of `other` onto the matching counter of `self`.
    ///
    /// Each sum saturates at `u64::MAX`: aggregating statistics should never
    /// panic (debug builds) or wrap around to a small value (release builds).
    pub fn merge(&mut self, other: &ProcessorStats) {
        self.records_processed = self.records_processed.saturating_add(other.records_processed);
        self.records_modified = self.records_modified.saturating_add(other.records_modified);
        self.labels_generated = self.labels_generated.saturating_add(other.labels_generated);
        self.errors_encountered = self
            .errors_encountered
            .saturating_add(other.errors_encountered);
        self.processing_time_us = self
            .processing_time_us
            .saturating_add(other.processing_time_us);
    }
}
91
92/// Core trait for post-processors that modify records and generate labels.
93///
94/// Post-processors are applied after generation to inject realistic data quality
95/// issues. Each processor can modify records in place and generate labels
96/// describing the modifications for ML training.
97pub trait PostProcessor: Send + Sync {
98    /// The type of records this processor modifies.
99    type Record;
100    /// The type of labels this processor produces.
101    type Label;
102
103    /// Process a single record, potentially modifying it and generating labels.
104    ///
105    /// Returns a vector of labels describing any modifications made.
106    fn process(
107        &mut self,
108        record: &mut Self::Record,
109        context: &ProcessContext,
110    ) -> SynthResult<Vec<Self::Label>>;
111
112    /// Process a batch of records.
113    ///
114    /// Default implementation calls process for each record.
115    fn process_batch(
116        &mut self,
117        records: &mut [Self::Record],
118        base_context: &ProcessContext,
119    ) -> SynthResult<Vec<Self::Label>> {
120        let mut all_labels = Vec::new();
121        let batch_size = records.len();
122
123        for (i, record) in records.iter_mut().enumerate() {
124            let context = ProcessContext {
125                record_index: i,
126                batch_size,
127                output_format: base_context.output_format.clone(),
128                metadata: base_context.metadata.clone(),
129            };
130            let labels = self.process(record, &context)?;
131            all_labels.extend(labels);
132        }
133
134        Ok(all_labels)
135    }
136
137    /// Get the name of this processor.
138    fn name(&self) -> &'static str;
139
140    /// Check if this processor is enabled.
141    fn is_enabled(&self) -> bool;
142
143    /// Get processing statistics.
144    fn stats(&self) -> ProcessorStats;
145
146    /// Reset statistics (for testing or between batches).
147    fn reset_stats(&mut self);
148}
149
150/// A pipeline of post-processors applied in sequence.
151pub struct PostProcessorPipeline<R, L> {
152    processors: Vec<Box<dyn PostProcessor<Record = R, Label = L>>>,
153    stats: ProcessorStats,
154}
155
156impl<R, L> PostProcessorPipeline<R, L> {
157    /// Create a new empty pipeline.
158    pub fn new() -> Self {
159        Self {
160            processors: Vec::new(),
161            stats: ProcessorStats::default(),
162        }
163    }
164
165    /// Add a processor to the pipeline.
166    pub fn add<P>(&mut self, processor: P)
167    where
168        P: PostProcessor<Record = R, Label = L> + 'static,
169    {
170        self.processors.push(Box::new(processor));
171    }
172
173    /// Add a processor and return self for chaining.
174    pub fn with<P>(mut self, processor: P) -> Self
175    where
176        P: PostProcessor<Record = R, Label = L> + 'static,
177    {
178        self.add(processor);
179        self
180    }
181
182    /// Process a single record through all processors.
183    pub fn process(&mut self, record: &mut R, context: &ProcessContext) -> SynthResult<Vec<L>> {
184        let mut all_labels = Vec::new();
185
186        for processor in &mut self.processors {
187            if processor.is_enabled() {
188                let labels = processor.process(record, context)?;
189                all_labels.extend(labels);
190            }
191        }
192
193        self.stats.records_processed += 1;
194        if !all_labels.is_empty() {
195            self.stats.records_modified += 1;
196        }
197        self.stats.labels_generated += all_labels.len() as u64;
198
199        Ok(all_labels)
200    }
201
202    /// Process a batch of records through all processors.
203    pub fn process_batch(
204        &mut self,
205        records: &mut [R],
206        base_context: &ProcessContext,
207    ) -> SynthResult<Vec<L>> {
208        let mut all_labels = Vec::new();
209        let batch_size = records.len();
210
211        for (i, record) in records.iter_mut().enumerate() {
212            let context = ProcessContext {
213                record_index: i,
214                batch_size,
215                output_format: base_context.output_format.clone(),
216                metadata: base_context.metadata.clone(),
217            };
218            let labels = self.process(record, &context)?;
219            all_labels.extend(labels);
220        }
221
222        Ok(all_labels)
223    }
224
225    /// Get aggregate statistics for the pipeline.
226    ///
227    /// Returns the pipeline's own stats tracking records processed through
228    /// the entire pipeline. Use `processor_stats()` to get individual
229    /// processor statistics.
230    pub fn stats(&self) -> ProcessorStats {
231        self.stats.clone()
232    }
233
234    /// Get individual processor statistics.
235    pub fn processor_stats(&self) -> Vec<(&'static str, ProcessorStats)> {
236        self.processors
237            .iter()
238            .map(|p| (p.name(), p.stats()))
239            .collect()
240    }
241
242    /// Check if pipeline has any enabled processors.
243    pub fn has_enabled_processors(&self) -> bool {
244        self.processors.iter().any(|p| p.is_enabled())
245    }
246
247    /// Get number of processors in the pipeline.
248    pub fn len(&self) -> usize {
249        self.processors.len()
250    }
251
252    /// Check if pipeline is empty.
253    pub fn is_empty(&self) -> bool {
254        self.processors.is_empty()
255    }
256
257    /// Reset all statistics.
258    pub fn reset_stats(&mut self) {
259        self.stats = ProcessorStats::default();
260        for processor in &mut self.processors {
261            processor.reset_stats();
262        }
263    }
264}
265
266impl<R, L> Default for PostProcessorPipeline<R, L> {
267    fn default() -> Self {
268        Self::new()
269    }
270}
271
272/// A no-op processor that passes records through unchanged.
273pub struct PassthroughProcessor<R, L> {
274    enabled: bool,
275    stats: ProcessorStats,
276    _phantom: std::marker::PhantomData<(R, L)>,
277}
278
279impl<R, L> PassthroughProcessor<R, L> {
280    /// Create a new passthrough processor.
281    pub fn new() -> Self {
282        Self {
283            enabled: true,
284            stats: ProcessorStats::default(),
285            _phantom: std::marker::PhantomData,
286        }
287    }
288
289    /// Create a disabled passthrough processor.
290    pub fn disabled() -> Self {
291        Self {
292            enabled: false,
293            stats: ProcessorStats::default(),
294            _phantom: std::marker::PhantomData,
295        }
296    }
297}
298
299impl<R, L> Default for PassthroughProcessor<R, L> {
300    fn default() -> Self {
301        Self::new()
302    }
303}
304
305impl<R: Send + Sync, L: Send + Sync> PostProcessor for PassthroughProcessor<R, L> {
306    type Record = R;
307    type Label = L;
308
309    fn process(
310        &mut self,
311        _record: &mut Self::Record,
312        _context: &ProcessContext,
313    ) -> SynthResult<Vec<Self::Label>> {
314        self.stats.records_processed += 1;
315        Ok(Vec::new())
316    }
317
318    fn name(&self) -> &'static str {
319        "passthrough"
320    }
321
322    fn is_enabled(&self) -> bool {
323        self.enabled
324    }
325
326    fn stats(&self) -> ProcessorStats {
327        self.stats.clone()
328    }
329
330    fn reset_stats(&mut self) {
331        self.stats = ProcessorStats::default();
332    }
333}
334
335/// Builder for creating post-processor pipelines.
336pub struct PipelineBuilder<R, L> {
337    pipeline: PostProcessorPipeline<R, L>,
338}
339
340impl<R, L> PipelineBuilder<R, L> {
341    /// Create a new pipeline builder.
342    pub fn new() -> Self {
343        Self {
344            pipeline: PostProcessorPipeline::new(),
345        }
346    }
347
348    /// Add a processor to the pipeline.
349    #[allow(clippy::should_implement_trait)]
350    pub fn add<P>(mut self, processor: P) -> Self
351    where
352        P: PostProcessor<Record = R, Label = L> + 'static,
353    {
354        self.pipeline.add(processor);
355        self
356    }
357
358    /// Conditionally add a processor.
359    pub fn add_if<P>(mut self, condition: bool, processor: P) -> Self
360    where
361        P: PostProcessor<Record = R, Label = L> + 'static,
362    {
363        if condition {
364            self.pipeline.add(processor);
365        }
366        self
367    }
368
369    /// Build the pipeline.
370    pub fn build(self) -> PostProcessorPipeline<R, L> {
371        self.pipeline
372    }
373}
374
375impl<R, L> Default for PipelineBuilder<R, L> {
376    fn default() -> Self {
377        Self::new()
378    }
379}
380
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal record type: one string field for processors to rewrite.
    #[derive(Debug, Clone)]
    struct TestRecord {
        value: String,
    }

    /// Minimal label type describing a single field-level change.
    #[allow(dead_code)]
    #[derive(Debug, Clone)]
    struct TestLabel {
        field: String,
        change: String,
    }

    /// Test processor that upper-cases `TestRecord::value` and emits one label
    /// whenever the value actually changed.
    struct UppercaseProcessor {
        enabled: bool,
        stats: ProcessorStats,
    }

    impl UppercaseProcessor {
        fn new() -> Self {
            Self {
                enabled: true,
                stats: ProcessorStats::default(),
            }
        }
    }

    impl PostProcessor for UppercaseProcessor {
        type Record = TestRecord;
        type Label = TestLabel;

        fn process(
            &mut self,
            record: &mut Self::Record,
            _context: &ProcessContext,
        ) -> SynthResult<Vec<Self::Label>> {
            self.stats.records_processed += 1;

            let upper = record.value.to_uppercase();
            if upper == record.value {
                // Already upper-case: nothing to change, nothing to label.
                return Ok(Vec::new());
            }

            let before = std::mem::replace(&mut record.value, upper);
            self.stats.records_modified += 1;
            self.stats.labels_generated += 1;
            Ok(vec![TestLabel {
                field: "value".to_string(),
                change: format!("{} -> {}", before, record.value),
            }])
        }

        fn name(&self) -> &'static str {
            "uppercase"
        }

        fn is_enabled(&self) -> bool {
            self.enabled
        }

        fn stats(&self) -> ProcessorStats {
            self.stats.clone()
        }

        fn reset_stats(&mut self) {
            self.stats = ProcessorStats::default();
        }
    }

    /// Shorthand for building a `TestRecord` from a string slice.
    fn record(value: &str) -> TestRecord {
        TestRecord {
            value: value.to_string(),
        }
    }

    /// Pipeline containing a single `UppercaseProcessor`.
    fn uppercase_pipeline() -> PostProcessorPipeline<TestRecord, TestLabel> {
        let mut pipeline = PostProcessorPipeline::new();
        pipeline.add(UppercaseProcessor::new());
        pipeline
    }

    #[test]
    fn test_pipeline_basic() {
        let mut pipeline = uppercase_pipeline();
        let mut rec = record("hello");
        let context = ProcessContext::new(0, 1);

        let labels = pipeline.process(&mut rec, &context).unwrap();

        assert_eq!(rec.value, "HELLO");
        assert_eq!(labels.len(), 1);
        assert_eq!(labels[0].field, "value");
    }

    #[test]
    fn test_pipeline_batch() {
        let mut pipeline = uppercase_pipeline();
        let mut records: Vec<TestRecord> = ["a", "b", "c"].iter().map(|v| record(v)).collect();
        let context = ProcessContext::new(0, 3);

        let labels = pipeline.process_batch(&mut records, &context).unwrap();

        let values: Vec<&str> = records.iter().map(|r| r.value.as_str()).collect();
        assert_eq!(values, ["A", "B", "C"]);
        assert_eq!(labels.len(), 3);
    }

    #[test]
    fn test_pipeline_stats() {
        let mut pipeline = uppercase_pipeline();
        let context = ProcessContext::new(0, 1);

        for _ in 0..5 {
            let mut rec = record("test");
            // The outcome is intentionally ignored; only the stats matter here.
            let _ = pipeline.process(&mut rec, &context);
        }

        let stats = pipeline.stats();
        assert_eq!(stats.records_processed, 5);
        assert_eq!(stats.records_modified, 5);
    }

    #[test]
    fn test_passthrough_processor() {
        let mut processor = PassthroughProcessor::<TestRecord, TestLabel>::new();
        let mut rec = record("unchanged");
        let context = ProcessContext::new(0, 1);

        let labels = processor.process(&mut rec, &context).unwrap();

        assert_eq!(rec.value, "unchanged");
        assert!(labels.is_empty());
    }

    #[test]
    fn test_pipeline_builder() {
        let pipeline: PostProcessorPipeline<TestRecord, TestLabel> = PipelineBuilder::new()
            .add(UppercaseProcessor::new())
            .add_if(false, PassthroughProcessor::new())
            .build();

        assert_eq!(pipeline.len(), 1);
    }
}
540}