use crate::errors::AnalysisError;
use std::marker::PhantomData;
/// A single typed step of a processing pipeline.
///
/// An implementor declares what it consumes, what it produces and how it can
/// fail, and exposes a label through [`Stage::name`].
pub trait Stage {
    /// Value consumed by [`Stage::execute`].
    type Input;

    /// Value produced by a successful [`Stage::execute`] call.
    type Output;

    /// Failure reported by an unsuccessful [`Stage::execute`] call.
    type Error;

    /// Runs the stage on `input`, taking ownership of it.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` when the stage cannot produce an output.
    fn execute(&self, input: Self::Input) -> Result<Self::Output, Self::Error>;

    /// Returns the name of this stage.
    fn name(&self) -> &str;
}
/// A named stage backed by an infallible function `F: Fn(I) -> O`.
///
/// `I` and `O` only describe the signature of `func`; no value of either type
/// is ever stored in the struct.
pub struct PureStage<F, I, O> {
    /// Label reported by the stage's `name` method.
    name: String,
    /// The transformation applied to each input.
    func: F,
    /// Ties `I` and `O` to the type without claiming to own them.
    ///
    /// A function-pointer marker is used instead of `PhantomData<(I, O)>` so
    /// that the stage's `Send`/`Sync` status depends only on `F`. With the
    /// tuple marker, a stage whose input or output is not `Send`/`Sync`
    /// (e.g. `Rc<T>`) could never satisfy the `Send + Sync` bound required by
    /// the type-erased `AnyStage` wrapper, even though it holds no such value.
    _phantom: PhantomData<fn(I) -> O>,
}
impl<F, I, O> PureStage<F, I, O>
where
F: Fn(I) -> O,
{
pub fn new(name: impl Into<String>, func: F) -> Self {
Self {
name: name.into(),
func,
_phantom: PhantomData,
}
}
}
/// Runs the wrapped function as a stage that can never fail.
impl<F, I, O> Stage for PureStage<F, I, O>
where
    F: Fn(I) -> O,
{
    type Input = I;
    type Output = O;
    /// The wrapped function has no failure path, so no error value can exist.
    type Error = std::convert::Infallible;

    /// Applies the wrapped function to `input`; the result is always `Ok`.
    fn execute(&self, input: I) -> Result<O, Self::Error> {
        let output = (self.func)(input);
        Ok(output)
    }

    /// Returns the name given at construction.
    fn name(&self) -> &str {
        self.name.as_str()
    }
}
/// A named stage backed by a fallible function `F: Fn(I) -> Result<O, E>`.
///
/// `I`, `O` and `E` only describe the signature of `func`; no value of any of
/// them is ever stored in the struct.
pub struct FallibleStage<F, I, O, E> {
    /// Label reported by the stage's `name` method.
    name: String,
    /// The transformation applied to each input.
    func: F,
    /// Ties `I`, `O` and `E` to the type without claiming to own them.
    ///
    /// A function-pointer marker is used instead of `PhantomData<(I, O, E)>`
    /// so that the stage's `Send`/`Sync` status depends only on `F`. With the
    /// tuple marker, a stage whose input, output or error type is not
    /// `Send`/`Sync` could never satisfy the `Send + Sync` bound required by
    /// the type-erased `AnyStage` wrapper, even though it holds no such value.
    _phantom: PhantomData<fn(I) -> Result<O, E>>,
}
impl<F, I, O, E> FallibleStage<F, I, O, E>
where
F: Fn(I) -> Result<O, E>,
{
pub fn new(name: impl Into<String>, func: F) -> Self {
Self {
name: name.into(),
func,
_phantom: PhantomData,
}
}
}
/// Runs the wrapped function as a stage, passing its `Result` through untouched.
impl<F, I, O, E> Stage for FallibleStage<F, I, O, E>
where
    F: Fn(I) -> Result<O, E>,
{
    type Input = I;
    type Output = O;
    type Error = E;

    /// Applies the wrapped function to `input` and returns whatever it returns.
    ///
    /// # Errors
    ///
    /// Returns the wrapped function's own error, unchanged.
    fn execute(&self, input: I) -> Result<O, E> {
        let run = &self.func;
        run(input)
    }

    /// Returns the name given at construction.
    fn name(&self) -> &str {
        self.name.as_str()
    }
}
/// Type-erased view of a stage.
///
/// Inputs and outputs travel as `Box<dyn Any>` and every failure is reported
/// as an [`AnalysisError`]. Implementors must be `Send + Sync`.
pub(crate) trait AnyStage: Send + Sync {
    /// Runs the stage on a type-erased `input` and returns a type-erased output.
    ///
    /// # Errors
    ///
    /// Returns an [`AnalysisError`] when the stage cannot produce an output.
    fn execute_any(
        &self,
        input: Box<dyn std::any::Any>,
    ) -> Result<Box<dyn std::any::Any>, AnalysisError>;

    /// Returns the name of this stage.
    fn name(&self) -> &str;
}
/// Every thread-safe [`Stage`] whose error converts into [`AnalysisError`]
/// is automatically usable as a type-erased [`AnyStage`].
impl<S> AnyStage for S
where
    S: Stage + Send + Sync,
    S::Input: 'static,
    S::Output: 'static,
    S::Error: Into<AnalysisError>,
{
    /// Downcasts `input` to `S::Input`, runs the stage and boxes the output.
    ///
    /// # Errors
    ///
    /// Returns an [`AnalysisError`] when `input` is not an `S::Input`, or the
    /// stage's own error converted via `Into<AnalysisError>` when it fails.
    fn execute_any(
        &self,
        input: Box<dyn std::any::Any>,
    ) -> Result<Box<dyn std::any::Any>, AnalysisError> {
        // Recover the concrete input type; a failed downcast means the caller
        // handed this stage a value of the wrong type.
        let typed_input: Box<S::Input> = match input.downcast() {
            Ok(boxed) => boxed,
            Err(_) => {
                return Err(AnalysisError::other(
                    "Type mismatch in pipeline stage input",
                ));
            }
        };

        match self.execute(*typed_input) {
            Ok(output) => Ok(Box::new(output)),
            Err(stage_err) => Err(stage_err.into()),
        }
    }

    /// Forwards to [`Stage::name`]; the path is explicit because both traits
    /// define a `name` method.
    fn name(&self) -> &str {
        <S as Stage>::name(self)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A pure stage applies its function and always yields `Ok`.
    #[test]
    fn test_pure_stage_execution() {
        let doubler = PureStage::new("Double", |x: i32| x * 2);
        assert_eq!(doubler.execute(21).unwrap(), 42);
    }

    /// A fallible stage returns the parsed value when parsing succeeds.
    #[test]
    fn test_fallible_stage_success() {
        let parser = FallibleStage::new("Parse", |s: String| {
            s.parse::<i32>().map_err(|_| "Parse error")
        });
        let parsed = parser.execute(String::from("42")).unwrap();
        assert_eq!(parsed, 42);
    }

    /// A fallible stage surfaces an error when parsing fails.
    #[test]
    fn test_fallible_stage_failure() {
        let parser = FallibleStage::new("Parse", |s: String| {
            s.parse::<i32>().map_err(|_| "Parse error")
        });
        assert!(parser.execute(String::from("not a number")).is_err());
    }

    /// The name given at construction is reported back unchanged.
    #[test]
    fn test_stage_name() {
        let identity = PureStage::new("Test Stage", |x: i32| x);
        // Both `Stage` and `AnyStage` define `name`, so the call is qualified.
        assert_eq!(Stage::name(&identity), "Test Stage");
    }
}