1use crate::error::SynthResult;
8use std::collections::HashMap;
9
/// Positional and formatting information handed to a [`PostProcessor`] for one record.
#[derive(Debug, Clone, Default)]
pub struct ProcessContext {
    /// Zero-based position of the current record within its batch.
    pub record_index: usize,
    /// Number of records in the batch the current record belongs to.
    pub batch_size: usize,
    /// Optional output format name; `None` unless set via [`ProcessContext::with_format`].
    pub output_format: Option<String>,
    /// Free-form string key/value pairs attached to the context.
    pub metadata: HashMap<String, String>,
}

impl ProcessContext {
    /// Creates a context for record `record_index` of a batch of `batch_size`
    /// records, with no output format and empty metadata.
    pub fn new(record_index: usize, batch_size: usize) -> Self {
        Self {
            record_index,
            batch_size,
            ..Self::default()
        }
    }

    /// Returns this context with its output format set to `format`,
    /// replacing any previously set format.
    pub fn with_format(self, format: impl Into<String>) -> Self {
        Self {
            output_format: Some(format.into()),
            ..self
        }
    }

    /// Returns this context with `key` mapped to `value` in its metadata.
    ///
    /// An existing entry for the same key is overwritten.
    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let (key, value) = (key.into(), value.into());
        self.metadata.insert(key, value);
        self
    }

    /// `true` when this is the first record of the batch (index 0).
    pub fn is_first(&self) -> bool {
        matches!(self.record_index, 0)
    }

    /// `true` when this is the final record of the batch.
    ///
    /// For an empty batch (`batch_size == 0`) index 0 is treated as the last
    /// position, so a default context reports both `is_first` and `is_last`.
    pub fn is_last(&self) -> bool {
        match self.batch_size {
            0 => self.record_index == 0,
            size => self.record_index == size - 1,
        }
    }
}
56
/// Running counters describing the work done by a processor or a pipeline.
#[derive(Debug, Clone, Default)]
pub struct ProcessorStats {
    /// Number of records handled.
    pub records_processed: u64,
    /// Number of records reported as modified.
    pub records_modified: u64,
    /// Number of labels emitted.
    pub labels_generated: u64,
    /// Number of errors recorded.
    pub errors_encountered: u64,
    /// Accumulated processing time, in microseconds.
    pub processing_time_us: u64,
}

impl ProcessorStats {
    /// Ratio `records_modified / records_processed`, or `0.0` when no record
    /// has been processed (avoids dividing by zero).
    pub fn modification_rate(&self) -> f64 {
        match self.records_processed {
            0 => 0.0,
            total => self.records_modified as f64 / total as f64,
        }
    }

    /// Adds every counter of `other` into the matching counter of `self`.
    pub fn merge(&mut self, other: &ProcessorStats) {
        // Exhaustive destructuring: adding a field to the struct without
        // updating this method becomes a compile error instead of a silent gap.
        let ProcessorStats {
            records_processed,
            records_modified,
            labels_generated,
            errors_encountered,
            processing_time_us,
        } = *other;

        self.records_processed += records_processed;
        self.records_modified += records_modified;
        self.labels_generated += labels_generated;
        self.errors_encountered += errors_encountered;
        self.processing_time_us += processing_time_us;
    }
}
91
92pub trait PostProcessor: Send + Sync {
98 type Record;
100 type Label;
102
103 fn process(
107 &mut self,
108 record: &mut Self::Record,
109 context: &ProcessContext,
110 ) -> SynthResult<Vec<Self::Label>>;
111
112 fn process_batch(
116 &mut self,
117 records: &mut [Self::Record],
118 base_context: &ProcessContext,
119 ) -> SynthResult<Vec<Self::Label>> {
120 let mut all_labels = Vec::new();
121 let batch_size = records.len();
122
123 for (i, record) in records.iter_mut().enumerate() {
124 let context = ProcessContext {
125 record_index: i,
126 batch_size,
127 output_format: base_context.output_format.clone(),
128 metadata: base_context.metadata.clone(),
129 };
130 let labels = self.process(record, &context)?;
131 all_labels.extend(labels);
132 }
133
134 Ok(all_labels)
135 }
136
137 fn name(&self) -> &'static str;
139
140 fn is_enabled(&self) -> bool;
142
143 fn stats(&self) -> ProcessorStats;
145
146 fn reset_stats(&mut self);
148}
149
150pub struct PostProcessorPipeline<R, L> {
152 processors: Vec<Box<dyn PostProcessor<Record = R, Label = L>>>,
153 stats: ProcessorStats,
154}
155
156impl<R, L> PostProcessorPipeline<R, L> {
157 pub fn new() -> Self {
159 Self {
160 processors: Vec::new(),
161 stats: ProcessorStats::default(),
162 }
163 }
164
165 pub fn add<P>(&mut self, processor: P)
167 where
168 P: PostProcessor<Record = R, Label = L> + 'static,
169 {
170 self.processors.push(Box::new(processor));
171 }
172
173 pub fn with<P>(mut self, processor: P) -> Self
175 where
176 P: PostProcessor<Record = R, Label = L> + 'static,
177 {
178 self.add(processor);
179 self
180 }
181
182 pub fn process(&mut self, record: &mut R, context: &ProcessContext) -> SynthResult<Vec<L>> {
184 let mut all_labels = Vec::new();
185
186 for processor in &mut self.processors {
187 if processor.is_enabled() {
188 let labels = processor.process(record, context)?;
189 all_labels.extend(labels);
190 }
191 }
192
193 self.stats.records_processed += 1;
194 if !all_labels.is_empty() {
195 self.stats.records_modified += 1;
196 }
197 self.stats.labels_generated += all_labels.len() as u64;
198
199 Ok(all_labels)
200 }
201
202 pub fn process_batch(
204 &mut self,
205 records: &mut [R],
206 base_context: &ProcessContext,
207 ) -> SynthResult<Vec<L>> {
208 let mut all_labels = Vec::new();
209 let batch_size = records.len();
210
211 for (i, record) in records.iter_mut().enumerate() {
212 let context = ProcessContext {
213 record_index: i,
214 batch_size,
215 output_format: base_context.output_format.clone(),
216 metadata: base_context.metadata.clone(),
217 };
218 let labels = self.process(record, &context)?;
219 all_labels.extend(labels);
220 }
221
222 Ok(all_labels)
223 }
224
225 pub fn stats(&self) -> ProcessorStats {
231 self.stats.clone()
232 }
233
234 pub fn processor_stats(&self) -> Vec<(&'static str, ProcessorStats)> {
236 self.processors
237 .iter()
238 .map(|p| (p.name(), p.stats()))
239 .collect()
240 }
241
242 pub fn has_enabled_processors(&self) -> bool {
244 self.processors.iter().any(|p| p.is_enabled())
245 }
246
247 pub fn len(&self) -> usize {
249 self.processors.len()
250 }
251
252 pub fn is_empty(&self) -> bool {
254 self.processors.is_empty()
255 }
256
257 pub fn reset_stats(&mut self) {
259 self.stats = ProcessorStats::default();
260 for processor in &mut self.processors {
261 processor.reset_stats();
262 }
263 }
264}
265
266impl<R, L> Default for PostProcessorPipeline<R, L> {
267 fn default() -> Self {
268 Self::new()
269 }
270}
271
272pub struct PassthroughProcessor<R, L> {
274 enabled: bool,
275 stats: ProcessorStats,
276 _phantom: std::marker::PhantomData<(R, L)>,
277}
278
279impl<R, L> PassthroughProcessor<R, L> {
280 pub fn new() -> Self {
282 Self {
283 enabled: true,
284 stats: ProcessorStats::default(),
285 _phantom: std::marker::PhantomData,
286 }
287 }
288
289 pub fn disabled() -> Self {
291 Self {
292 enabled: false,
293 stats: ProcessorStats::default(),
294 _phantom: std::marker::PhantomData,
295 }
296 }
297}
298
299impl<R, L> Default for PassthroughProcessor<R, L> {
300 fn default() -> Self {
301 Self::new()
302 }
303}
304
305impl<R: Send + Sync, L: Send + Sync> PostProcessor for PassthroughProcessor<R, L> {
306 type Record = R;
307 type Label = L;
308
309 fn process(
310 &mut self,
311 _record: &mut Self::Record,
312 _context: &ProcessContext,
313 ) -> SynthResult<Vec<Self::Label>> {
314 self.stats.records_processed += 1;
315 Ok(Vec::new())
316 }
317
318 fn name(&self) -> &'static str {
319 "passthrough"
320 }
321
322 fn is_enabled(&self) -> bool {
323 self.enabled
324 }
325
326 fn stats(&self) -> ProcessorStats {
327 self.stats.clone()
328 }
329
330 fn reset_stats(&mut self) {
331 self.stats = ProcessorStats::default();
332 }
333}
334
335pub struct PipelineBuilder<R, L> {
337 pipeline: PostProcessorPipeline<R, L>,
338}
339
340impl<R, L> PipelineBuilder<R, L> {
341 pub fn new() -> Self {
343 Self {
344 pipeline: PostProcessorPipeline::new(),
345 }
346 }
347
348 #[allow(clippy::should_implement_trait)]
350 pub fn add<P>(mut self, processor: P) -> Self
351 where
352 P: PostProcessor<Record = R, Label = L> + 'static,
353 {
354 self.pipeline.add(processor);
355 self
356 }
357
358 pub fn add_if<P>(mut self, condition: bool, processor: P) -> Self
360 where
361 P: PostProcessor<Record = R, Label = L> + 'static,
362 {
363 if condition {
364 self.pipeline.add(processor);
365 }
366 self
367 }
368
369 pub fn build(self) -> PostProcessorPipeline<R, L> {
371 self.pipeline
372 }
373}
374
375impl<R, L> Default for PipelineBuilder<R, L> {
376 fn default() -> Self {
377 Self::new()
378 }
379}
380
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal record type used by the test processors.
    #[derive(Debug, Clone)]
    struct TestRecord {
        value: String,
    }

    /// Label describing one change made by a test processor.
    #[derive(Debug, Clone)]
    struct TestLabel {
        field: String,
        change: String,
    }

    /// Uppercases `TestRecord::value`, emitting a label only when the value changed.
    struct UppercaseProcessor {
        enabled: bool,
        stats: ProcessorStats,
    }

    impl UppercaseProcessor {
        fn new() -> Self {
            Self {
                enabled: true,
                stats: ProcessorStats::default(),
            }
        }
    }

    impl PostProcessor for UppercaseProcessor {
        type Record = TestRecord;
        type Label = TestLabel;

        fn process(
            &mut self,
            record: &mut Self::Record,
            _context: &ProcessContext,
        ) -> SynthResult<Vec<Self::Label>> {
            self.stats.records_processed += 1;
            let original = record.value.clone();
            record.value = record.value.to_uppercase();
            if record.value != original {
                self.stats.records_modified += 1;
                self.stats.labels_generated += 1;
                Ok(vec![TestLabel {
                    field: "value".to_string(),
                    change: format!("{} -> {}", original, record.value),
                }])
            } else {
                Ok(vec![])
            }
        }

        fn name(&self) -> &'static str {
            "uppercase"
        }

        fn is_enabled(&self) -> bool {
            self.enabled
        }

        fn stats(&self) -> ProcessorStats {
            self.stats.clone()
        }

        fn reset_stats(&mut self) {
            self.stats = ProcessorStats::default();
        }
    }

    /// Leaves records untouched and emits one label per record that encodes
    /// the context it received as `"index/batch_size/format"`.
    struct ContextEchoProcessor {
        stats: ProcessorStats,
    }

    impl PostProcessor for ContextEchoProcessor {
        type Record = TestRecord;
        type Label = TestLabel;

        fn process(
            &mut self,
            _record: &mut Self::Record,
            context: &ProcessContext,
        ) -> SynthResult<Vec<Self::Label>> {
            self.stats.records_processed += 1;
            Ok(vec![TestLabel {
                field: "context".to_string(),
                change: format!(
                    "{}/{}/{}",
                    context.record_index,
                    context.batch_size,
                    context.output_format.as_deref().unwrap_or("-")
                ),
            }])
        }

        fn name(&self) -> &'static str {
            "context_echo"
        }

        fn is_enabled(&self) -> bool {
            true
        }

        fn stats(&self) -> ProcessorStats {
            self.stats.clone()
        }

        fn reset_stats(&mut self) {
            self.stats = ProcessorStats::default();
        }
    }

    /// Builds one `TestRecord` per input string.
    fn records(values: &[&str]) -> Vec<TestRecord> {
        values
            .iter()
            .map(|v| TestRecord {
                value: String::from(*v),
            })
            .collect()
    }

    #[test]
    fn test_pipeline_basic() {
        let mut pipeline = PostProcessorPipeline::new();
        pipeline.add(UppercaseProcessor::new());

        let mut record = TestRecord {
            value: "hello".to_string(),
        };
        let context = ProcessContext::new(0, 1);

        let labels = pipeline.process(&mut record, &context).unwrap();

        assert_eq!(record.value, "HELLO");
        assert_eq!(labels.len(), 1);
        assert_eq!(labels[0].field, "value");
        assert_eq!(labels[0].change, "hello -> HELLO");
    }

    #[test]
    fn test_pipeline_batch() {
        let mut pipeline = PostProcessorPipeline::new();
        pipeline.add(UppercaseProcessor::new());

        let mut records = records(&["a", "b", "c"]);
        let context = ProcessContext::new(0, 3);

        let labels = pipeline.process_batch(&mut records, &context).unwrap();

        assert_eq!(records[0].value, "A");
        assert_eq!(records[1].value, "B");
        assert_eq!(records[2].value, "C");
        assert_eq!(labels.len(), 3);
    }

    #[test]
    fn test_batch_assigns_per_record_context() {
        // The base index/size must be overridden per record; format is inherited.
        let base = ProcessContext::new(7, 99).with_format("csv");
        let expected = ["0/3/csv", "1/3/csv", "2/3/csv"];
        let mut recs = records(&["a", "b", "c"]);

        // Trait default implementation.
        let mut echo = ContextEchoProcessor {
            stats: ProcessorStats::default(),
        };
        let labels = echo.process_batch(&mut recs, &base).unwrap();
        let changes: Vec<&str> = labels.iter().map(|l| l.change.as_str()).collect();
        assert_eq!(changes, expected);
        assert_eq!(echo.stats().records_processed, 3);

        // Pipeline implementation.
        let mut pipeline = PostProcessorPipeline::new();
        pipeline.add(ContextEchoProcessor {
            stats: ProcessorStats::default(),
        });
        let labels = pipeline.process_batch(&mut recs, &base).unwrap();
        let changes: Vec<&str> = labels.iter().map(|l| l.change.as_str()).collect();
        assert_eq!(changes, expected);
    }

    #[test]
    fn test_pipeline_stats() {
        let mut pipeline = PostProcessorPipeline::new();
        pipeline.add(UppercaseProcessor::new());

        let context = ProcessContext::new(0, 1);

        for _ in 0..5 {
            let mut record = TestRecord {
                value: "test".to_string(),
            };
            let _ = pipeline.process(&mut record, &context);
        }

        let stats = pipeline.stats();
        assert_eq!(stats.records_processed, 5);
        assert_eq!(stats.records_modified, 5);
        assert_eq!(stats.labels_generated, 5);
    }

    #[test]
    fn test_disabled_processor_is_skipped() {
        let mut pipeline = PostProcessorPipeline::new();
        pipeline.add(UppercaseProcessor {
            enabled: false,
            stats: ProcessorStats::default(),
        });
        assert!(!pipeline.has_enabled_processors());

        let mut record = TestRecord {
            value: "quiet".to_string(),
        };
        let labels = pipeline
            .process(&mut record, &ProcessContext::new(0, 1))
            .unwrap();

        assert!(labels.is_empty());
        assert_eq!(record.value, "quiet");
        let stats = pipeline.stats();
        assert_eq!(stats.records_processed, 1);
        assert_eq!(stats.records_modified, 0);
        // The disabled processor itself was never invoked.
        assert_eq!(pipeline.processor_stats()[0].1.records_processed, 0);
    }

    #[test]
    fn test_reset_stats() {
        let mut pipeline = PostProcessorPipeline::new();
        pipeline.add(UppercaseProcessor::new());

        let mut record = TestRecord {
            value: "abc".to_string(),
        };
        pipeline
            .process(&mut record, &ProcessContext::new(0, 1))
            .unwrap();
        assert_eq!(pipeline.processor_stats()[0].1.records_processed, 1);

        pipeline.reset_stats();

        let stats = pipeline.stats();
        assert_eq!(stats.records_processed, 0);
        assert_eq!(stats.labels_generated, 0);
        let per_processor = pipeline.processor_stats();
        assert_eq!(per_processor.len(), 1);
        assert_eq!(per_processor[0].0, "uppercase");
        assert_eq!(per_processor[0].1.records_processed, 0);
    }

    #[test]
    fn test_context_position_helpers() {
        let first = ProcessContext::new(0, 3);
        assert!(first.is_first() && !first.is_last());
        let middle = ProcessContext::new(1, 3);
        assert!(!middle.is_first() && !middle.is_last());
        let last = ProcessContext::new(2, 3);
        assert!(!last.is_first() && last.is_last());
        let single = ProcessContext::new(0, 1);
        assert!(single.is_first() && single.is_last());
    }

    #[test]
    fn test_stats_merge_and_rate() {
        let mut total = ProcessorStats {
            records_processed: 4,
            records_modified: 1,
            ..ProcessorStats::default()
        };
        assert_eq!(total.modification_rate(), 0.25);

        let other = ProcessorStats {
            records_processed: 4,
            records_modified: 3,
            labels_generated: 5,
            errors_encountered: 1,
            processing_time_us: 7,
        };
        total.merge(&other);

        assert_eq!(total.records_processed, 8);
        assert_eq!(total.records_modified, 4);
        assert_eq!(total.labels_generated, 5);
        assert_eq!(total.errors_encountered, 1);
        assert_eq!(total.processing_time_us, 7);
        assert_eq!(total.modification_rate(), 0.5);
        assert_eq!(ProcessorStats::default().modification_rate(), 0.0);
    }

    #[test]
    fn test_passthrough_processor() {
        let mut processor = PassthroughProcessor::<TestRecord, TestLabel>::new();
        let mut record = TestRecord {
            value: "unchanged".to_string(),
        };
        let context = ProcessContext::new(0, 1);

        let labels = processor.process(&mut record, &context).unwrap();

        assert_eq!(record.value, "unchanged");
        assert!(labels.is_empty());
        assert!(processor.is_enabled());
        assert_eq!(processor.stats().records_processed, 1);

        let disabled = PassthroughProcessor::<TestRecord, TestLabel>::disabled();
        assert!(!disabled.is_enabled());
    }

    #[test]
    fn test_pipeline_builder() {
        let pipeline: PostProcessorPipeline<TestRecord, TestLabel> = PipelineBuilder::new()
            .add(UppercaseProcessor::new())
            .add_if(false, PassthroughProcessor::new())
            .build();

        assert_eq!(pipeline.len(), 1);

        let with_optional: PostProcessorPipeline<TestRecord, TestLabel> = PipelineBuilder::new()
            .add(UppercaseProcessor::new())
            .add_if(true, PassthroughProcessor::new())
            .build();

        assert_eq!(with_optional.len(), 2);
        assert!(!with_optional.is_empty());
    }
}
540}