use super::zerocopy::ZeroCopyBuffer;
use crate::error::{Result, StreamingError};
use async_trait::async_trait;
use std::time::Instant;
/// Outcome of running a single pipeline stage.
#[derive(Debug)]
pub struct StageResult {
    /// Buffer produced by the stage.
    pub data: ZeroCopyBuffer,
    /// Time the stage took to run, in milliseconds.
    pub execution_time_ms: u64,
    /// Byte count for this result; `StageResult::new` sets it to the length of `data`.
    pub bytes_processed: usize,
    /// String key/value annotations attached to the result.
    pub metadata: std::collections::HashMap<String, String>,
}
impl StageResult {
    /// Builds a result around `data`, recording `execution_time_ms` as given.
    ///
    /// `bytes_processed` is taken from `data.len()` and the metadata map starts empty.
    pub fn new(data: ZeroCopyBuffer, execution_time_ms: u64) -> Self {
        Self {
            // Read the length before `data` is moved into the struct on the next line.
            bytes_processed: data.len(),
            data,
            execution_time_ms,
            metadata: std::collections::HashMap::new(),
        }
    }

    /// Consumes the result, stores `value` under `key` in its metadata, and returns it.
    ///
    /// An existing entry for the same key is replaced.
    pub fn with_metadata(self, key: String, value: String) -> Self {
        let mut annotated = self;
        annotated.metadata.insert(key, value);
        annotated
    }
}
/// A named, asynchronous processing step that turns one buffer into a [`StageResult`].
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait PipelineStage: Send + Sync {
    /// Returns the name of this stage.
    fn name(&self) -> &str;

    /// Processes `input` and returns the stage's result.
    ///
    /// # Errors
    ///
    /// Returns an error when the implementation cannot produce a result for `input`.
    async fn process(&self, input: ZeroCopyBuffer) -> Result<StageResult>;

    /// Setup hook. The default implementation does nothing and succeeds.
    async fn initialize(&self) -> Result<()> {
        Ok(())
    }

    /// Teardown hook. The default implementation does nothing and succeeds.
    async fn finalize(&self) -> Result<()> {
        Ok(())
    }
}
/// Pipeline stage that maps its input bytes to new bytes through a fallible closure.
pub struct TransformStage<F>
where
    F: Fn(&[u8]) -> Result<Vec<u8>> + Send + Sync,
{
    /// Name reported by `PipelineStage::name`.
    name: String,
    /// Closure that receives the input bytes and returns the transformed bytes or an error.
    transform_fn: F,
}
impl<F> TransformStage<F>
where
    F: Fn(&[u8]) -> Result<Vec<u8>> + Send + Sync,
{
    /// Creates a transform stage called `name` that applies `transform_fn` to each input.
    pub fn new(name: String, transform_fn: F) -> Self {
        Self { name, transform_fn }
    }
}
#[async_trait]
impl<F> PipelineStage for TransformStage<F>
where
    F: Fn(&[u8]) -> Result<Vec<u8>> + Send + Sync,
{
    /// Returns the name given at construction.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Runs the transform closure on `input` and wraps its output in a new buffer.
    ///
    /// The recorded execution time covers the closure call and the construction of
    /// the output buffer, in whole milliseconds.
    ///
    /// # Errors
    ///
    /// Returns the closure's error unchanged when the transform fails.
    async fn process(&self, input: ZeroCopyBuffer) -> Result<StageResult> {
        let started = Instant::now();
        (self.transform_fn)(input.as_ref()).map(|transformed| {
            let payload = ZeroCopyBuffer::new(bytes::Bytes::from(transformed));
            // Sample the clock only after the output buffer exists, so building it is timed too.
            let elapsed_ms = started.elapsed().as_millis() as u64;
            StageResult::new(payload, elapsed_ms)
        })
    }
}
/// Pipeline stage that lets its input through only when a predicate accepts it.
pub struct FilterStage<F>
where
    F: Fn(&[u8]) -> bool + Send + Sync,
{
    /// Name reported by `PipelineStage::name`.
    name: String,
    /// Predicate over the input bytes; `true` keeps the input, `false` rejects it.
    filter_fn: F,
}
impl<F> FilterStage<F>
where
    F: Fn(&[u8]) -> bool + Send + Sync,
{
    /// Creates a filter stage called `name` that keeps inputs for which `filter_fn` is `true`.
    pub fn new(name: String, filter_fn: F) -> Self {
        Self {
            name,
            filter_fn,
        }
    }
}
#[async_trait]
impl<F> PipelineStage for FilterStage<F>
where
    F: Fn(&[u8]) -> bool + Send + Sync,
{
    /// Returns the name given at construction.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Passes `input` through unchanged when the predicate accepts it.
    ///
    /// The recorded execution time covers the predicate call, in whole milliseconds.
    ///
    /// # Errors
    ///
    /// Returns `StreamingError::Other("Filtered out")` when the predicate rejects `input`.
    async fn process(&self, input: ZeroCopyBuffer) -> Result<StageResult> {
        let started = Instant::now();
        let accepted = (self.filter_fn)(input.as_ref());
        if !accepted {
            return Err(StreamingError::Other("Filtered out".to_string()));
        }
        let elapsed_ms = started.elapsed().as_millis() as u64;
        Ok(StageResult::new(input, elapsed_ms))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;

    /// A transform that doubles every byte yields the doubled payload.
    #[tokio::test]
    async fn test_transform_stage() {
        let stage = TransformStage::new(
            "double".to_string(),
            |data| Ok(data.iter().map(|&x| x * 2).collect()),
        );

        let outcome = stage
            .process(ZeroCopyBuffer::new(Bytes::from(vec![1, 2, 3])))
            .await;

        assert!(outcome.is_ok());
        if let Ok(stage_result) = outcome {
            assert_eq!(stage_result.data.as_ref(), &[2, 4, 6]);
        }
    }

    /// A non-empty predicate accepts populated buffers and rejects empty ones.
    #[tokio::test]
    async fn test_filter_stage() {
        let stage = FilterStage::new("non_empty".to_string(), |data| !data.is_empty());

        let accepted = stage
            .process(ZeroCopyBuffer::new(Bytes::from(vec![1, 2, 3])))
            .await;
        assert!(accepted.is_ok());

        let rejected = stage.process(ZeroCopyBuffer::new(Bytes::new())).await;
        assert!(rejected.is_err());
    }
}