logo
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
use std::sync::Arc;

use serde::Deserialize;

use crate::dataflow::{
    context::OneInOneOutContext,
    message::Message,
    operator::{OneInOneOut, OperatorConfig},
    stream::{OperatorStream, Stream, WriteStreamT},
    Data,
};

/// Filters an incoming stream of type `D`, forwarding only the messages for which the
/// provided condition function returns `true`.
///
/// Each received message is passed by reference to the condition function. Messages for which
/// it returns `true` are cloned and sent on the output stream with the same timestamp; all other
/// messages are dropped.
///
/// # Example
/// The below example shows how to use a `FilterOperator` to keep only messages greater than 10
/// in an incoming stream of `usize` messages, and send them.
///
/// ```
/// # use erdos::dataflow::{stream::IngestStream, operator::{OperatorConfig}, operators::{FilterOperator}};
/// # let source_stream = IngestStream::new();
/// // The condition function is passed to `FilterOperator::new`; the `OperatorConfig` only
/// // names the operator.
/// let filter_config = OperatorConfig::new().name("FilterOperator");
/// let filter_stream = erdos::connect_one_in_one_out(
///     || -> FilterOperator<usize> { FilterOperator::new(|a: &usize| -> bool { a > &10 }) },
///     || {},
///     filter_config,
///     &source_stream,
/// );
/// ```
pub struct FilterOperator<D>
where
    D: Data + for<'a> Deserialize<'a>,
{
    /// Condition applied to each incoming message; `true` means the message is forwarded.
    filter_function: Arc<dyn Fn(&D) -> bool + Send + Sync>,
}

impl<D> FilterOperator<D>
where
    D: Data + for<'a> Deserialize<'a>,
{
    pub fn new<F>(filter_function: F) -> Self
    where
        F: 'static + Fn(&D) -> bool + Send + Sync,
    {
        Self {
            filter_function: Arc::new(filter_function),
        }
    }
}

impl<D> OneInOneOut<(), D, D> for FilterOperator<D>
where
    D: Data + for<'a> Deserialize<'a>,
{
    /// Sends a clone of `data` on the write stream, stamped with the current timestamp, if the
    /// filter function accepts it; otherwise the message is dropped without further work.
    ///
    /// # Panics
    /// Panics if sending an accepted message on the write stream fails.
    fn on_data(&mut self, ctx: &mut OneInOneOutContext<(), D>, data: &D) {
        if !(self.filter_function)(data) {
            return;
        }
        // Only pay for the timestamp clone on the accept path. It must be copied out before
        // `write_stream()` takes its mutable borrow of `ctx`.
        let timestamp = ctx.timestamp().clone();
        ctx.write_stream()
            .send(Message::new_message(timestamp, data.clone()))
            .expect("FilterOperator: failed to send a message on the write stream");
        tracing::debug!(
            "{} @ {:?}: received {:?} and sent it",
            ctx.operator_config().get_name(),
            ctx.timestamp(),
            data,
        );
    }

    /// This operator performs no work when a watermark arrives.
    fn on_watermark(&mut self, _ctx: &mut OneInOneOutContext<(), D>) {}
}

/// Extension trait that adds a [`filter`](Filter::filter) method to streams, connecting a
/// [`FilterOperator`] downstream of the stream it is called on.
pub trait Filter<D>
where
    D: Data + for<'a> Deserialize<'a>,
{
    /// Returns a new stream containing only the messages of `self` for which `filter_fn`
    /// returns `true`.
    ///
    /// `filter_fn` must be `Clone`. The implementation clones it each time it constructs a
    /// `FilterOperator`, which keeps its operator factory callable as an `Fn` closure.
    fn filter<F>(&self, filter_fn: F) -> OperatorStream<D>
    where
        F: 'static + Fn(&D) -> bool + Send + Sync + Clone;
}

impl<S, D> Filter<D> for S
where
    S: Stream<D>,
    D: Data + for<'a> Deserialize<'a>,
{
    /// Connects a [`FilterOperator`] named `FilterOp_<stream id>` to this stream and returns
    /// the operator's output stream.
    fn filter<F>(&self, filter_fn: F) -> OperatorStream<D>
    where
        F: 'static + Fn(&D) -> bool + Send + Sync + Clone,
    {
        let operator_name = format!("FilterOp_{}", self.id());
        let config = OperatorConfig::new().name(&operator_name);

        // Cloning inside the closure, rather than moving `filter_fn` out of it, keeps the
        // factory an `Fn` (callable repeatedly) instead of an `FnOnce`.
        let operator_factory =
            move || -> FilterOperator<D> { FilterOperator::new(filter_fn.clone()) };
        let setup = || {};

        crate::connect_one_in_one_out(operator_factory, setup, config, self)
    }
}