use crate::error::SynthResult;
use std::collections::HashMap;
/// Per-record information handed to a processor while it works through a batch.
#[derive(Debug, Clone, Default)]
pub struct ProcessContext {
    /// Zero-based position of the current record within its batch.
    pub record_index: usize,
    /// Number of records in the batch being processed.
    pub batch_size: usize,
    /// Optional output format name; `None` until set via [`ProcessContext::with_format`].
    pub output_format: Option<String>,
    /// Free-form string key/value pairs travelling with the context.
    pub metadata: HashMap<String, String>,
}

impl ProcessContext {
    /// Creates a context for the record at `record_index` in a batch of
    /// `batch_size` records, with no output format and empty metadata.
    pub fn new(record_index: usize, batch_size: usize) -> Self {
        Self {
            record_index,
            batch_size,
            ..Self::default()
        }
    }

    /// Builder-style setter: returns the context with `output_format` set to `format`.
    pub fn with_format(self, format: impl Into<String>) -> Self {
        Self {
            output_format: Some(format.into()),
            ..self
        }
    }

    /// Builder-style setter: returns the context with `key` mapped to `value`
    /// in `metadata`, replacing any previous value stored under `key`.
    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let (entry_key, entry_value) = (key.into(), value.into());
        self.metadata.insert(entry_key, entry_value);
        self
    }

    /// Returns `true` when this context describes the first record (index 0).
    pub fn is_first(&self) -> bool {
        self.record_index == 0
    }

    /// Returns `true` when this context describes the final record of the batch.
    ///
    /// The last index is computed with a saturating subtraction, so a
    /// `batch_size` of 0 is treated like a batch whose last index is 0.
    pub fn is_last(&self) -> bool {
        let last_index = self.batch_size.saturating_sub(1);
        self.record_index == last_index
    }
}
/// Counters describing the work done by a processor or a pipeline.
#[derive(Debug, Clone, Default)]
pub struct ProcessorStats {
    /// Number of records processed.
    pub records_processed: u64,
    /// Number of records counted as modified.
    pub records_modified: u64,
    /// Number of labels generated.
    pub labels_generated: u64,
    /// Number of errors encountered.
    pub errors_encountered: u64,
    /// Processing time in microseconds.
    pub processing_time_us: u64,
}

impl ProcessorStats {
    /// Returns `records_modified / records_processed` as a fraction, or `0.0`
    /// when nothing has been processed yet (avoiding a division by zero).
    pub fn modification_rate(&self) -> f64 {
        match self.records_processed {
            0 => 0.0,
            total => self.records_modified as f64 / total as f64,
        }
    }

    /// Adds every counter of `other` onto the matching counter of `self`.
    pub fn merge(&mut self, other: &ProcessorStats) {
        // Destructure exhaustively so a newly added field cannot be silently
        // left out of the merge.
        let &ProcessorStats {
            records_processed,
            records_modified,
            labels_generated,
            errors_encountered,
            processing_time_us,
        } = other;

        self.records_processed += records_processed;
        self.records_modified += records_modified;
        self.labels_generated += labels_generated;
        self.errors_encountered += errors_encountered;
        self.processing_time_us += processing_time_us;
    }
}
pub trait PostProcessor: Send + Sync {
type Record;
type Label;
fn process(
&mut self,
record: &mut Self::Record,
context: &ProcessContext,
) -> SynthResult<Vec<Self::Label>>;
fn process_batch(
&mut self,
records: &mut [Self::Record],
base_context: &ProcessContext,
) -> SynthResult<Vec<Self::Label>> {
let mut all_labels = Vec::new();
let batch_size = records.len();
for (i, record) in records.iter_mut().enumerate() {
let context = ProcessContext {
record_index: i,
batch_size,
output_format: base_context.output_format.clone(),
metadata: base_context.metadata.clone(),
};
let labels = self.process(record, &context)?;
all_labels.extend(labels);
}
Ok(all_labels)
}
fn name(&self) -> &'static str;
fn is_enabled(&self) -> bool;
fn stats(&self) -> ProcessorStats;
fn reset_stats(&mut self);
}
/// An ordered chain of boxed [`PostProcessor`]s that share one record type `R`
/// and one label type `L`.
pub struct PostProcessorPipeline<R, L> {
    /// Processors in insertion order, which is also the order they run in.
    processors: Vec<Box<dyn PostProcessor<Record = R, Label = L>>>,
    /// Aggregate counters maintained by the pipeline itself.
    stats: ProcessorStats,
}

impl<R, L> PostProcessorPipeline<R, L> {
    /// Creates an empty pipeline with zeroed statistics.
    pub fn new() -> Self {
        Self {
            processors: Vec::new(),
            stats: ProcessorStats::default(),
        }
    }

    /// Appends `processor` to the end of the pipeline.
    pub fn add<P>(&mut self, processor: P)
    where
        P: PostProcessor<Record = R, Label = L> + 'static,
    {
        self.processors.push(Box::new(processor));
    }

    /// Chaining variant of [`PostProcessorPipeline::add`]: appends `processor`
    /// and returns the pipeline.
    pub fn with<P>(mut self, processor: P) -> Self
    where
        P: PostProcessor<Record = R, Label = L> + 'static,
    {
        self.add(processor);
        self
    }

    /// Runs every enabled processor on `record`, in insertion order, and
    /// returns all labels they produced.
    ///
    /// On success the pipeline statistics are updated: `records_processed` is
    /// incremented, `labels_generated` grows by the number of labels, and
    /// `records_modified` is incremented when at least one label was produced.
    ///
    /// # Errors
    ///
    /// Returns the first error reported by a processor. Later processors are
    /// not run for this record, `errors_encountered` is incremented, and the
    /// success counters are left untouched.
    pub fn process(&mut self, record: &mut R, context: &ProcessContext) -> SynthResult<Vec<L>> {
        let mut all_labels = Vec::new();
        for processor in &mut self.processors {
            if !processor.is_enabled() {
                continue;
            }
            match processor.process(record, context) {
                Ok(labels) => all_labels.extend(labels),
                Err(err) => {
                    // Record the failure before propagating it; otherwise the
                    // `errors_encountered` counter would never move.
                    self.stats.errors_encountered += 1;
                    return Err(err);
                }
            }
        }

        self.stats.records_processed += 1;
        if !all_labels.is_empty() {
            self.stats.records_modified += 1;
        }
        self.stats.labels_generated += all_labels.len() as u64;
        Ok(all_labels)
    }

    /// Runs [`PostProcessorPipeline::process`] on every record of `records`, in
    /// order, and returns all labels concatenated in processing order.
    ///
    /// Each record is given a context whose `record_index` is its position in
    /// the slice and whose `batch_size` is `records.len()`; `output_format`
    /// and `metadata` are taken from `base_context`.
    ///
    /// # Errors
    ///
    /// Stops at the first failing record and returns its error; labels
    /// collected up to that point are discarded.
    pub fn process_batch(
        &mut self,
        records: &mut [R],
        base_context: &ProcessContext,
    ) -> SynthResult<Vec<L>> {
        // Clone the caller's context once instead of once per record: only
        // `record_index` differs between records.
        let mut context = base_context.clone();
        context.batch_size = records.len();

        let mut all_labels = Vec::new();
        for (i, record) in records.iter_mut().enumerate() {
            context.record_index = i;
            all_labels.extend(self.process(record, &context)?);
        }
        Ok(all_labels)
    }

    /// Returns a snapshot of the pipeline-level statistics.
    pub fn stats(&self) -> ProcessorStats {
        self.stats.clone()
    }

    /// Returns `(name, stats)` for every processor, in insertion order.
    pub fn processor_stats(&self) -> Vec<(&'static str, ProcessorStats)> {
        self.processors
            .iter()
            .map(|p| (p.name(), p.stats()))
            .collect()
    }

    /// Returns `true` if at least one processor reports itself as enabled.
    pub fn has_enabled_processors(&self) -> bool {
        self.processors.iter().any(|p| p.is_enabled())
    }

    /// Returns the number of processors in the pipeline, enabled or not.
    pub fn len(&self) -> usize {
        self.processors.len()
    }

    /// Returns `true` if the pipeline contains no processors.
    pub fn is_empty(&self) -> bool {
        self.processors.is_empty()
    }

    /// Zeroes the pipeline statistics and asks every processor to reset its own.
    pub fn reset_stats(&mut self) {
        self.stats = ProcessorStats::default();
        for processor in &mut self.processors {
            processor.reset_stats();
        }
    }
}

impl<R, L> Default for PostProcessorPipeline<R, L> {
    /// Equivalent to [`PostProcessorPipeline::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// A processor that leaves records untouched and never emits labels.
pub struct PassthroughProcessor<R, L> {
    /// Value reported by `is_enabled`.
    enabled: bool,
    /// Counters for this processor; only `records_processed` is ever bumped.
    stats: ProcessorStats,
    /// Ties the otherwise unused type parameters to the struct.
    _phantom: std::marker::PhantomData<(R, L)>,
}

impl<R, L> PassthroughProcessor<R, L> {
    /// Creates an enabled passthrough processor with zeroed statistics.
    pub fn new() -> Self {
        Self {
            enabled: true,
            stats: ProcessorStats::default(),
            _phantom: std::marker::PhantomData,
        }
    }

    /// Creates a passthrough processor that reports itself as disabled.
    pub fn disabled() -> Self {
        Self {
            enabled: false,
            ..Self::new()
        }
    }
}

impl<R, L> Default for PassthroughProcessor<R, L> {
    /// Equivalent to [`PassthroughProcessor::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl<R: Send + Sync, L: Send + Sync> PostProcessor for PassthroughProcessor<R, L> {
    type Record = R;
    type Label = L;

    /// Counts the record as processed and returns an empty label list; the
    /// record and the context are not looked at.
    fn process(
        &mut self,
        _record: &mut Self::Record,
        _context: &ProcessContext,
    ) -> SynthResult<Vec<Self::Label>> {
        self.stats.records_processed += 1;
        let no_labels = Vec::new();
        Ok(no_labels)
    }

    /// Always `"passthrough"`.
    fn name(&self) -> &'static str {
        "passthrough"
    }

    /// Returns the flag chosen at construction time.
    fn is_enabled(&self) -> bool {
        self.enabled
    }

    /// Returns a snapshot of this processor's counters.
    fn stats(&self) -> ProcessorStats {
        self.stats.clone()
    }

    /// Zeroes this processor's counters.
    fn reset_stats(&mut self) {
        self.stats = ProcessorStats::default();
    }
}
/// Fluent builder that assembles a [`PostProcessorPipeline`].
pub struct PipelineBuilder<R, L> {
    /// Pipeline under construction.
    pipeline: PostProcessorPipeline<R, L>,
}

impl<R, L> PipelineBuilder<R, L> {
    /// Starts a builder wrapping an empty pipeline.
    pub fn new() -> Self {
        Self {
            pipeline: PostProcessorPipeline::new(),
        }
    }

    /// Appends `processor` to the pipeline and returns the builder.
    // The name matches `std::ops::Add::add`, which is what the clippy lint
    // below complains about; this is a builder step, not arithmetic.
    #[allow(clippy::should_implement_trait)]
    pub fn add<P>(self, processor: P) -> Self
    where
        P: PostProcessor<Record = R, Label = L> + 'static,
    {
        Self {
            pipeline: self.pipeline.with(processor),
        }
    }

    /// Appends `processor` only when `condition` is `true`; otherwise the
    /// processor is dropped and the builder is returned unchanged.
    pub fn add_if<P>(self, condition: bool, processor: P) -> Self
    where
        P: PostProcessor<Record = R, Label = L> + 'static,
    {
        if condition {
            self.add(processor)
        } else {
            self
        }
    }

    /// Finishes building and returns the assembled pipeline.
    pub fn build(self) -> PostProcessorPipeline<R, L> {
        let Self { pipeline } = self;
        pipeline
    }
}

impl<R, L> Default for PipelineBuilder<R, L> {
    /// Equivalent to [`PipelineBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Minimal record carrying a single string payload.
    #[derive(Debug, Clone)]
    struct TestRecord {
        value: String,
    }

    /// Label describing one change. `change` is written but never read by the
    /// tests, hence the `dead_code` allowance.
    #[allow(dead_code)]
    #[derive(Debug, Clone)]
    struct TestLabel {
        field: String,
        change: String,
    }

    /// Builds a [`TestRecord`] holding `text`.
    fn record(text: &str) -> TestRecord {
        TestRecord {
            value: text.to_string(),
        }
    }

    /// Processor that uppercases `TestRecord::value` and emits one label
    /// whenever that actually changed the value.
    struct UppercaseProcessor {
        enabled: bool,
        stats: ProcessorStats,
    }

    impl UppercaseProcessor {
        fn new() -> Self {
            Self {
                enabled: true,
                stats: ProcessorStats::default(),
            }
        }
    }

    impl PostProcessor for UppercaseProcessor {
        type Record = TestRecord;
        type Label = TestLabel;

        fn process(
            &mut self,
            record: &mut Self::Record,
            _context: &ProcessContext,
        ) -> SynthResult<Vec<Self::Label>> {
            self.stats.records_processed += 1;

            let uppercased = record.value.to_uppercase();
            if uppercased == record.value {
                // Already uppercase: nothing changed, nothing to report.
                return Ok(vec![]);
            }

            let original = std::mem::replace(&mut record.value, uppercased);
            self.stats.records_modified += 1;
            self.stats.labels_generated += 1;
            Ok(vec![TestLabel {
                field: "value".to_string(),
                change: format!("{} -> {}", original, record.value),
            }])
        }

        fn name(&self) -> &'static str {
            "uppercase"
        }

        fn is_enabled(&self) -> bool {
            self.enabled
        }

        fn stats(&self) -> ProcessorStats {
            self.stats.clone()
        }

        fn reset_stats(&mut self) {
            self.stats = ProcessorStats::default();
        }
    }

    /// A single record is transformed and yields exactly one label.
    #[test]
    fn test_pipeline_basic() {
        let mut pipeline = PostProcessorPipeline::new();
        pipeline.add(UppercaseProcessor::new());

        let mut item = record("hello");
        let context = ProcessContext::new(0, 1);
        let labels = pipeline.process(&mut item, &context).unwrap();

        assert_eq!(item.value, "HELLO");
        assert_eq!(labels.len(), 1);
        assert_eq!(labels[0].field, "value");
    }

    /// Every record of a batch is transformed and labelled.
    #[test]
    fn test_pipeline_batch() {
        let mut pipeline = PostProcessorPipeline::new();
        pipeline.add(UppercaseProcessor::new());

        let mut records: Vec<TestRecord> =
            ["a", "b", "c"].iter().map(|text| record(text)).collect();
        let context = ProcessContext::new(0, 3);
        let labels = pipeline.process_batch(&mut records, &context).unwrap();

        let values: Vec<&str> = records.iter().map(|r| r.value.as_str()).collect();
        assert_eq!(values, ["A", "B", "C"]);
        assert_eq!(labels.len(), 3);
    }

    /// Pipeline-level counters track processed and modified records.
    #[test]
    fn test_pipeline_stats() {
        let mut pipeline = PostProcessorPipeline::new();
        pipeline.add(UppercaseProcessor::new());

        let context = ProcessContext::new(0, 1);
        for _ in 0..5 {
            let mut item = record("test");
            let _ = pipeline.process(&mut item, &context);
        }

        let stats = pipeline.stats();
        assert_eq!(stats.records_processed, 5);
        assert_eq!(stats.records_modified, 5);
    }

    /// The passthrough processor neither changes the record nor emits labels.
    #[test]
    fn test_passthrough_processor() {
        let mut processor = PassthroughProcessor::<TestRecord, TestLabel>::new();
        let mut item = record("unchanged");
        let context = ProcessContext::new(0, 1);

        let labels = processor.process(&mut item, &context).unwrap();

        assert_eq!(item.value, "unchanged");
        assert!(labels.is_empty());
    }

    /// `add_if(false, ..)` must not add the processor to the pipeline.
    #[test]
    fn test_pipeline_builder() {
        let pipeline: PostProcessorPipeline<TestRecord, TestLabel> = PipelineBuilder::new()
            .add(UppercaseProcessor::new())
            .add_if(false, PassthroughProcessor::new())
            .build();

        assert_eq!(pipeline.len(), 1);
    }
}