use thiserror::Error;
/// Error reported when a pipeline stage fails.
///
/// Its `Display` output has the form `stage '<stage>' failed: <source>`, and
/// the wrapped error is also exposed as this error's source.
#[derive(Debug, Error)]
#[error("stage '{stage}' failed: {source}")]
pub struct StageError {
/// Name of the stage that reported the failure.
pub stage: &'static str,
/// Underlying error that caused the failure, boxed so any thread-safe
/// error type can be carried.
#[source]
pub source: Box<dyn std::error::Error + Send + Sync>,
}
impl StageError {
    /// Creates an error for `stage`, boxing `source` as the underlying cause.
    ///
    /// `stage` is the name that appears in the error message. `source` may be
    /// any `'static` error type that is `Send + Sync`; it is moved onto the
    /// heap and stored as a trait object.
    ///
    /// Marked `#[must_use]` like the other constructors in this file:
    /// building an error and dropping it is almost certainly a mistake.
    #[must_use]
    pub fn new(
        stage: &'static str,
        source: impl std::error::Error + Send + Sync + 'static,
    ) -> Self {
        Self {
            stage,
            source: Box::new(source),
        }
    }
}
/// Content being passed through a pipeline, plus a truncation flag.
///
/// Derives the common value-type traits so contexts can be compared in
/// assertions (`PartialEq`/`Eq`) and built with `Default` (empty content,
/// flag cleared).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SanitizeContext {
    /// The text that stages read and modify.
    pub content: String,
    /// Truncation marker. `SanitizeContext::new` initialises it to `false`;
    /// presumably a stage that shortens `content` sets it to `true`.
    pub was_truncated: bool,
}
impl SanitizeContext {
    /// Wraps `content` in a fresh context whose `was_truncated` flag is
    /// cleared.
    #[must_use]
    pub fn new(content: String) -> Self {
        // A newly created context has not been through any stage yet.
        let was_truncated = false;
        Self {
            content,
            was_truncated,
        }
    }
}
/// A single processing step that can be added to a [`Pipeline`].
///
/// Implementors must be `Send + Sync`.
pub trait Stage: Send + Sync {
/// Returns the name of this stage.
fn name(&self) -> &str;
/// Consumes `ctx` and returns the resulting context, or a [`StageError`]
/// if the stage fails.
fn process(&self, ctx: SanitizeContext) -> Result<SanitizeContext, StageError>;
}
/// An ordered list of [`Stage`]s applied one after another to a
/// [`SanitizeContext`].
pub struct Pipeline {
    /// Boxed stages, stored in the order they were added.
    stages: Vec<Box<dyn Stage>>,
}

// `Box<dyn Stage>` is not `Debug`, so `Debug` cannot be derived. The manual
// impl prints each stage's name instead, which keeps `Pipeline` usable in
// `{:?}` output and inside other `#[derive(Debug)]` types.
impl std::fmt::Debug for Pipeline {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let stage_names: Vec<&str> = self.stages.iter().map(|stage| stage.name()).collect();
        f.debug_struct("Pipeline")
            .field("stages", &stage_names)
            .finish()
    }
}
impl Pipeline {
    /// Builds a pipeline that contains no stages.
    #[must_use]
    pub fn new() -> Self {
        let stages = Vec::new();
        Self { stages }
    }

    /// Appends `stage` to the end of the pipeline, so it runs after every
    /// stage added before it.
    pub fn add_stage(&mut self, stage: impl Stage + 'static) {
        let boxed: Box<dyn Stage> = Box::new(stage);
        self.stages.push(boxed);
    }

    /// Threads `ctx` through every stage in insertion order and returns the
    /// final context.
    ///
    /// # Errors
    ///
    /// Returns the first [`StageError`] produced by a stage; stages after the
    /// failing one are not run.
    pub fn process(&self, ctx: SanitizeContext) -> Result<SanitizeContext, StageError> {
        // `try_fold` stops at the first `Err`, matching a `?` inside a loop.
        self.stages
            .iter()
            .try_fold(ctx, |current, stage| stage.process(current))
    }
}
impl Default for Pipeline {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    /// Appends a fixed suffix to the content.
    struct AppendStage(&'static str);

    impl Stage for AppendStage {
        fn name(&self) -> &str {
            "append"
        }

        fn process(&self, mut ctx: SanitizeContext) -> Result<SanitizeContext, StageError> {
            ctx.content.push_str(self.0);
            Ok(ctx)
        }
    }

    /// Always fails with an I/O error whose message is "intentional".
    struct FailStage;

    impl Stage for FailStage {
        fn name(&self) -> &str {
            "fail"
        }

        fn process(&self, _ctx: SanitizeContext) -> Result<SanitizeContext, StageError> {
            Err(StageError::new(
                "fail",
                std::io::Error::other("intentional"),
            ))
        }
    }

    /// Passes the context through unchanged and records that it was invoked.
    struct RecordStage(Arc<AtomicBool>);

    impl Stage for RecordStage {
        fn name(&self) -> &str {
            "record"
        }

        fn process(&self, ctx: SanitizeContext) -> Result<SanitizeContext, StageError> {
            self.0.store(true, Ordering::SeqCst);
            Ok(ctx)
        }
    }

    #[test]
    fn empty_pipeline_passes_through() {
        let pipeline = Pipeline::new();
        let ctx = SanitizeContext::new("hello".to_owned());
        let out = pipeline.process(ctx).unwrap();
        assert_eq!(out.content, "hello");
        assert!(!out.was_truncated);
    }

    #[test]
    fn stages_run_in_order() {
        let mut pipeline = Pipeline::new();
        pipeline.add_stage(AppendStage(" world"));
        pipeline.add_stage(AppendStage("!"));
        let out = pipeline
            .process(SanitizeContext::new("hello".to_owned()))
            .unwrap();
        assert_eq!(out.content, "hello world!");
    }

    #[test]
    fn error_aborts_pipeline() {
        let later_stage_ran = Arc::new(AtomicBool::new(false));
        let mut pipeline = Pipeline::new();
        pipeline.add_stage(FailStage);
        pipeline.add_stage(RecordStage(Arc::clone(&later_stage_ran)));
        let err = pipeline
            .process(SanitizeContext::new("x".to_owned()))
            .unwrap_err();
        // Every `StageError` message contains the word "failed", so a
        // substring check for "fail" would pass for any stage. Pin the stage
        // name and the full message instead.
        assert_eq!(err.stage, "fail");
        assert_eq!(err.to_string(), "stage 'fail' failed: intentional");
        // The stage after the failing one must never be invoked.
        assert!(!later_stage_ran.load(Ordering::SeqCst));
    }

    #[test]
    fn truncated_flag_propagates() {
        struct TruncateStage;

        impl Stage for TruncateStage {
            fn name(&self) -> &str {
                "truncate"
            }

            fn process(&self, mut ctx: SanitizeContext) -> Result<SanitizeContext, StageError> {
                ctx.content.truncate(3);
                ctx.was_truncated = true;
                Ok(ctx)
            }
        }

        let mut pipeline = Pipeline::new();
        pipeline.add_stage(TruncateStage);
        let out = pipeline
            .process(SanitizeContext::new("hello".to_owned()))
            .unwrap();
        assert!(out.was_truncated);
        assert_eq!(out.content, "hel");
    }
}