1use crate::error::SynthResult;
8use std::collections::HashMap;
9
/// Per-record information handed to a post-processor.
///
/// Carries the record's position inside its batch, the batch length, an
/// optional output format name and free-form string metadata.
#[derive(Debug, Clone, Default)]
pub struct ProcessContext {
    /// Zero-based position of the current record within the batch.
    pub record_index: usize,
    /// Number of records in the batch.
    pub batch_size: usize,
    /// Output format name, if one has been set.
    pub output_format: Option<String>,
    /// Arbitrary string key/value pairs attached to the context.
    pub metadata: HashMap<String, String>,
}

impl ProcessContext {
    /// Builds a context for the record at `record_index` in a batch of
    /// `batch_size` records, with no output format and empty metadata.
    pub fn new(record_index: usize, batch_size: usize) -> Self {
        Self {
            record_index,
            batch_size,
            ..Self::default()
        }
    }

    /// Returns the context with its output format set to `format`,
    /// replacing any previously stored value.
    pub fn with_format(self, format: impl Into<String>) -> Self {
        Self {
            output_format: Some(format.into()),
            ..self
        }
    }

    /// Returns the context with `key` mapped to `value` in the metadata.
    ///
    /// An existing entry under the same key is overwritten.
    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let entry_key = key.into();
        let entry_value = value.into();
        self.metadata.insert(entry_key, entry_value);
        self
    }

    /// Returns `true` when this context describes the record at index 0.
    pub fn is_first(&self) -> bool {
        matches!(self.record_index, 0)
    }

    /// Returns `true` when `record_index` equals `batch_size - 1`.
    ///
    /// The subtraction saturates at zero, so a context with
    /// `batch_size == 0` and `record_index == 0` also reports `true`.
    pub fn is_last(&self) -> bool {
        let final_index = self.batch_size.saturating_sub(1);
        final_index == self.record_index
    }
}
56
/// Counters describing the work done by a processor or a pipeline.
#[derive(Debug, Clone, Default)]
pub struct ProcessorStats {
    /// Number of records processed.
    pub records_processed: u64,
    /// Number of records counted as modified.
    pub records_modified: u64,
    /// Number of labels generated.
    pub labels_generated: u64,
    /// Number of errors encountered.
    pub errors_encountered: u64,
    /// Accumulated processing time; the field name indicates microseconds.
    pub processing_time_us: u64,
}

impl ProcessorStats {
    /// Returns `records_modified / records_processed` as a floating-point
    /// ratio, or `0.0` when no records have been processed.
    pub fn modification_rate(&self) -> f64 {
        match self.records_processed {
            0 => 0.0,
            total => self.records_modified as f64 / total as f64,
        }
    }

    /// Adds every counter of `other` onto the matching counter of `self`.
    pub fn merge(&mut self, other: &ProcessorStats) {
        // Destructuring makes the compiler flag this method if a counter is
        // ever added to the struct without being merged here.
        let ProcessorStats {
            records_processed,
            records_modified,
            labels_generated,
            errors_encountered,
            processing_time_us,
        } = other;

        self.records_processed += *records_processed;
        self.records_modified += *records_modified;
        self.labels_generated += *labels_generated;
        self.errors_encountered += *errors_encountered;
        self.processing_time_us += *processing_time_us;
    }
}
91
92pub trait PostProcessor: Send + Sync {
98 type Record;
100 type Label;
102
103 fn process(
107 &mut self,
108 record: &mut Self::Record,
109 context: &ProcessContext,
110 ) -> SynthResult<Vec<Self::Label>>;
111
112 fn process_batch(
116 &mut self,
117 records: &mut [Self::Record],
118 base_context: &ProcessContext,
119 ) -> SynthResult<Vec<Self::Label>> {
120 let mut all_labels = Vec::new();
121 let batch_size = records.len();
122
123 for (i, record) in records.iter_mut().enumerate() {
124 let context = ProcessContext {
125 record_index: i,
126 batch_size,
127 output_format: base_context.output_format.clone(),
128 metadata: base_context.metadata.clone(),
129 };
130 let labels = self.process(record, &context)?;
131 all_labels.extend(labels);
132 }
133
134 Ok(all_labels)
135 }
136
137 fn name(&self) -> &'static str;
139
140 fn is_enabled(&self) -> bool;
142
143 fn stats(&self) -> ProcessorStats;
145
146 fn reset_stats(&mut self);
148}
149
150pub struct PostProcessorPipeline<R, L> {
152 processors: Vec<Box<dyn PostProcessor<Record = R, Label = L>>>,
153 stats: ProcessorStats,
154}
155
156impl<R, L> PostProcessorPipeline<R, L> {
157 pub fn new() -> Self {
159 Self {
160 processors: Vec::new(),
161 stats: ProcessorStats::default(),
162 }
163 }
164
165 pub fn add<P>(&mut self, processor: P)
167 where
168 P: PostProcessor<Record = R, Label = L> + 'static,
169 {
170 self.processors.push(Box::new(processor));
171 }
172
173 pub fn with<P>(mut self, processor: P) -> Self
175 where
176 P: PostProcessor<Record = R, Label = L> + 'static,
177 {
178 self.add(processor);
179 self
180 }
181
182 pub fn process(&mut self, record: &mut R, context: &ProcessContext) -> SynthResult<Vec<L>> {
184 let mut all_labels = Vec::new();
185
186 for processor in &mut self.processors {
187 if processor.is_enabled() {
188 let labels = processor.process(record, context)?;
189 all_labels.extend(labels);
190 }
191 }
192
193 self.stats.records_processed += 1;
194 if !all_labels.is_empty() {
195 self.stats.records_modified += 1;
196 }
197 self.stats.labels_generated += all_labels.len() as u64;
198
199 Ok(all_labels)
200 }
201
202 pub fn process_batch(
204 &mut self,
205 records: &mut [R],
206 base_context: &ProcessContext,
207 ) -> SynthResult<Vec<L>> {
208 let mut all_labels = Vec::new();
209 let batch_size = records.len();
210
211 for (i, record) in records.iter_mut().enumerate() {
212 let context = ProcessContext {
213 record_index: i,
214 batch_size,
215 output_format: base_context.output_format.clone(),
216 metadata: base_context.metadata.clone(),
217 };
218 let labels = self.process(record, &context)?;
219 all_labels.extend(labels);
220 }
221
222 Ok(all_labels)
223 }
224
225 pub fn stats(&self) -> ProcessorStats {
231 self.stats.clone()
232 }
233
234 pub fn processor_stats(&self) -> Vec<(&'static str, ProcessorStats)> {
236 self.processors
237 .iter()
238 .map(|p| (p.name(), p.stats()))
239 .collect()
240 }
241
242 pub fn has_enabled_processors(&self) -> bool {
244 self.processors.iter().any(|p| p.is_enabled())
245 }
246
247 pub fn len(&self) -> usize {
249 self.processors.len()
250 }
251
252 pub fn is_empty(&self) -> bool {
254 self.processors.is_empty()
255 }
256
257 pub fn reset_stats(&mut self) {
259 self.stats = ProcessorStats::default();
260 for processor in &mut self.processors {
261 processor.reset_stats();
262 }
263 }
264}
265
266impl<R, L> Default for PostProcessorPipeline<R, L> {
267 fn default() -> Self {
268 Self::new()
269 }
270}
271
272pub struct PassthroughProcessor<R, L> {
274 enabled: bool,
275 stats: ProcessorStats,
276 _phantom: std::marker::PhantomData<(R, L)>,
277}
278
279impl<R, L> PassthroughProcessor<R, L> {
280 pub fn new() -> Self {
282 Self {
283 enabled: true,
284 stats: ProcessorStats::default(),
285 _phantom: std::marker::PhantomData,
286 }
287 }
288
289 pub fn disabled() -> Self {
291 Self {
292 enabled: false,
293 stats: ProcessorStats::default(),
294 _phantom: std::marker::PhantomData,
295 }
296 }
297}
298
299impl<R, L> Default for PassthroughProcessor<R, L> {
300 fn default() -> Self {
301 Self::new()
302 }
303}
304
305impl<R: Send + Sync, L: Send + Sync> PostProcessor for PassthroughProcessor<R, L> {
306 type Record = R;
307 type Label = L;
308
309 fn process(
310 &mut self,
311 _record: &mut Self::Record,
312 _context: &ProcessContext,
313 ) -> SynthResult<Vec<Self::Label>> {
314 self.stats.records_processed += 1;
315 Ok(Vec::new())
316 }
317
318 fn name(&self) -> &'static str {
319 "passthrough"
320 }
321
322 fn is_enabled(&self) -> bool {
323 self.enabled
324 }
325
326 fn stats(&self) -> ProcessorStats {
327 self.stats.clone()
328 }
329
330 fn reset_stats(&mut self) {
331 self.stats = ProcessorStats::default();
332 }
333}
334
335pub struct PipelineBuilder<R, L> {
337 pipeline: PostProcessorPipeline<R, L>,
338}
339
340impl<R, L> PipelineBuilder<R, L> {
341 pub fn new() -> Self {
343 Self {
344 pipeline: PostProcessorPipeline::new(),
345 }
346 }
347
348 #[allow(clippy::should_implement_trait)]
350 pub fn add<P>(mut self, processor: P) -> Self
351 where
352 P: PostProcessor<Record = R, Label = L> + 'static,
353 {
354 self.pipeline.add(processor);
355 self
356 }
357
358 pub fn add_if<P>(mut self, condition: bool, processor: P) -> Self
360 where
361 P: PostProcessor<Record = R, Label = L> + 'static,
362 {
363 if condition {
364 self.pipeline.add(processor);
365 }
366 self
367 }
368
369 pub fn build(self) -> PostProcessorPipeline<R, L> {
371 self.pipeline
372 }
373}
374
375impl<R, L> Default for PipelineBuilder<R, L> {
376 fn default() -> Self {
377 Self::new()
378 }
379}
380
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal record type: a single string payload.
    #[derive(Debug, Clone)]
    struct TestRecord {
        value: String,
    }

    /// Label describing which field changed and how.
    #[derive(Debug, Clone)]
    struct TestLabel {
        field: String,
        change: String,
    }

    /// Test processor that upper-cases `TestRecord::value` and emits one
    /// label whenever that actually changes the value.
    struct UppercaseProcessor {
        enabled: bool,
        stats: ProcessorStats,
    }

    impl UppercaseProcessor {
        fn new() -> Self {
            Self {
                enabled: true,
                stats: ProcessorStats::default(),
            }
        }
    }

    impl PostProcessor for UppercaseProcessor {
        type Record = TestRecord;
        type Label = TestLabel;

        fn process(
            &mut self,
            record: &mut Self::Record,
            _context: &ProcessContext,
        ) -> SynthResult<Vec<Self::Label>> {
            self.stats.records_processed += 1;
            let original = record.value.clone();
            record.value = record.value.to_uppercase();
            if record.value != original {
                self.stats.records_modified += 1;
                self.stats.labels_generated += 1;
                Ok(vec![TestLabel {
                    field: "value".to_string(),
                    change: format!("{} -> {}", original, record.value),
                }])
            } else {
                Ok(vec![])
            }
        }

        fn name(&self) -> &'static str {
            "uppercase"
        }

        fn is_enabled(&self) -> bool {
            self.enabled
        }

        fn stats(&self) -> ProcessorStats {
            self.stats.clone()
        }

        fn reset_stats(&mut self) {
            self.stats = ProcessorStats::default();
        }
    }

    #[test]
    fn test_pipeline_basic() {
        let mut pipeline = PostProcessorPipeline::new();
        pipeline.add(UppercaseProcessor::new());

        let mut record = TestRecord {
            value: "hello".to_string(),
        };
        let context = ProcessContext::new(0, 1);

        let labels = pipeline.process(&mut record, &context).unwrap();

        assert_eq!(record.value, "HELLO");
        assert_eq!(labels.len(), 1);
        assert_eq!(labels[0].field, "value");
        assert_eq!(labels[0].change, "hello -> HELLO");
    }

    #[test]
    fn test_pipeline_batch() {
        let mut pipeline = PostProcessorPipeline::new();
        pipeline.add(UppercaseProcessor::new());

        let mut records = vec![
            TestRecord {
                value: "a".to_string(),
            },
            TestRecord {
                value: "b".to_string(),
            },
            TestRecord {
                value: "c".to_string(),
            },
        ];
        let context = ProcessContext::new(0, 3);

        let labels = pipeline.process_batch(&mut records, &context).unwrap();

        assert_eq!(records[0].value, "A");
        assert_eq!(records[1].value, "B");
        assert_eq!(records[2].value, "C");
        assert_eq!(labels.len(), 3);
        // Labels come back in record order.
        assert_eq!(labels[0].change, "a -> A");
        assert_eq!(labels[2].change, "c -> C");
        assert_eq!(pipeline.stats().records_processed, 3);
    }

    #[test]
    fn test_pipeline_stats() {
        let mut pipeline = PostProcessorPipeline::new();
        pipeline.add(UppercaseProcessor::new());

        let context = ProcessContext::new(0, 1);

        for _ in 0..5 {
            let mut record = TestRecord {
                value: "test".to_string(),
            };
            // Fail loudly instead of discarding the result: a processing
            // error would otherwise only show up as a wrong count below.
            pipeline
                .process(&mut record, &context)
                .expect("uppercase processor never fails");
        }

        let stats = pipeline.stats();
        assert_eq!(stats.records_processed, 5);
        assert_eq!(stats.records_modified, 5);
        assert_eq!(stats.labels_generated, 5);

        let per_processor = pipeline.processor_stats();
        assert_eq!(per_processor.len(), 1);
        assert_eq!(per_processor[0].0, "uppercase");
        assert_eq!(per_processor[0].1.records_processed, 5);

        pipeline.reset_stats();
        assert_eq!(pipeline.stats().records_processed, 0);
        assert_eq!(pipeline.processor_stats()[0].1.records_processed, 0);
    }

    #[test]
    fn test_passthrough_processor() {
        let mut processor = PassthroughProcessor::<TestRecord, TestLabel>::new();
        let mut record = TestRecord {
            value: "unchanged".to_string(),
        };
        let context = ProcessContext::new(0, 1);

        let labels = processor.process(&mut record, &context).unwrap();

        assert_eq!(record.value, "unchanged");
        assert!(labels.is_empty());
        assert!(processor.is_enabled());
        assert_eq!(processor.stats().records_processed, 1);
    }

    #[test]
    fn test_pipeline_builder() {
        let pipeline: PostProcessorPipeline<TestRecord, TestLabel> = PipelineBuilder::new()
            .add(UppercaseProcessor::new())
            .add_if(false, PassthroughProcessor::new())
            .build();

        assert_eq!(pipeline.len(), 1);
        assert!(!pipeline.is_empty());
        assert!(pipeline.has_enabled_processors());
    }
}
539}