data_pipeline_rs/
node.rs

1use std::{
2    sync::{Arc, Mutex},
3    time::{Duration, Instant},
4};
5
6use anyhow::anyhow;
7use serde_json::json;
8
9use crate::{
10    data_handler::SomeDataHandler, node_visitor::NodeVisitor, stats_producer::StatsProducer,
11};
12
/// Running counters and timings for a single [`Node`], reported through its `get_stats`.
#[derive(Default)]
struct NodeStatsTracker {
    /// Items handed to the node.
    data_ingress: u32,
    /// Items the node passed onward (to its `next` node or to a demuxer-chosen node).
    data_egress: u32,
    /// Items the handler rejected (filtered out, or no demuxer path matched).
    data_discarded: u32,
    /// Items whose handling failed with an error.
    errors: u32,
    /// Time spent inside this node's own handler, summed over all items.
    total_processing_time: Duration,
}
21
22/// [`NodeRef`] is really just a convenience helper to hide how references are held and take care
23/// of the mutex logic
24pub struct NodeRef<T>(Arc<Mutex<Node<T>>>);
25
26// derive(Clone) doesn't work with the generic
27impl<T> Clone for NodeRef<T> {
28    fn clone(&self) -> Self {
29        Self(self.0.clone())
30    }
31}
32
33impl<T> From<Node<T>> for NodeRef<T> {
34    fn from(val: Node<T>) -> Self {
35        NodeRef::new(val)
36    }
37}
38
39impl<T> NodeRef<T> {
40    pub fn new(node: Node<T>) -> NodeRef<T> {
41        NodeRef(Arc::new(Mutex::new(node)))
42    }
43
44    pub fn name(&self) -> String {
45        self.0.lock().unwrap().name.clone()
46    }
47
48    pub fn set_next(&self, next: NodeRef<T>) {
49        self.0.lock().unwrap().next = Some(next)
50    }
51
52    pub fn set_prev(&self, prev: NodeRef<T>) {
53        self.0.lock().unwrap().prev = Some(prev)
54    }
55
56    pub fn process_data(&self, data: T) {
57        self.0.lock().unwrap().process_data(data);
58    }
59
60    pub fn visit(&self, visitor: &mut dyn NodeVisitor<T>) {
61        self.0.lock().unwrap().visit(visitor)
62    }
63}
64
65/// A helper type to model the possible outcomes of passing data to [`SomeDataHandler`].
66// Note: I had a thought to get rid of 'ForwardToNext' and always use 'ForwardTo', but non-demux
67// nodes don't know/care if there _is_ a next node.  For a demuxer, which uses ForwardTo, there
68// will be a definitive next node to forward to, so at this point I think the two variants make
69// sense?
70enum SomeDataHandlerResult<'a, T> {
71    // The given data should be forwarded to the next node
72    ForwardToNext(T),
73    // The given data should be forwarded to the given node
74    ForwardTo(T, &'a NodeRef<T>),
75    // The data was consumed
76    Consumed,
77    // The data should be discarded
78    Discard,
79}
80
81pub struct Node<T> {
82    name: String,
83    handler: SomeDataHandler<T>,
84    stats: NodeStatsTracker,
85    next: Option<NodeRef<T>>,
86    prev: Option<NodeRef<T>>,
87}
88
89impl<T> Node<T> {
90    pub fn new<U: Into<String>, V: Into<SomeDataHandler<T>>>(name: U, handler: V) -> Node<T> {
91        Self {
92            name: name.into(),
93            handler: handler.into(),
94            stats: NodeStatsTracker::default(),
95            next: None,
96            prev: None,
97        }
98    }
99
100    pub fn name(&self) -> &str {
101        self.name.as_str()
102    }
103
104    pub fn set_next(&mut self, next: NodeRef<T>) {
105        self.next = Some(next)
106    }
107
108    pub fn set_prev(&mut self, prev: NodeRef<T>) {
109        self.prev = Some(prev)
110    }
111
112    pub fn process_data(&mut self, data: T) {
113        self.stats.data_ingress += 1;
114        let start = Instant::now();
115        let data_handler_result = match self.handler {
116            SomeDataHandler::Observer(ref mut o) => {
117                o.observe(&data);
118                Ok(SomeDataHandlerResult::ForwardToNext(data))
119            }
120            SomeDataHandler::Transformer(ref mut t) => match t.transform(data) {
121                Ok(transformed) => Ok(SomeDataHandlerResult::ForwardToNext(transformed)),
122                Err(e) => Err(anyhow!("Data transformer {} failed: {e:?}", self.name)),
123            },
124            SomeDataHandler::Filter(ref mut f) => match f.should_forward(&data) {
125                true => Ok(SomeDataHandlerResult::ForwardToNext(data)),
126                false => Ok(SomeDataHandlerResult::Discard),
127            },
128            SomeDataHandler::Consumer(ref mut c) => {
129                c.consume(data);
130                Ok(SomeDataHandlerResult::Consumed)
131            }
132            SomeDataHandler::Demuxer(ref mut d) => {
133                if let Some(path) = d.find_path(&data) {
134                    Ok(SomeDataHandlerResult::ForwardTo(data, path))
135                } else {
136                    Ok(SomeDataHandlerResult::Discard)
137                }
138            }
139        };
140        let processing_duration = Instant::now() - start;
141        self.stats.total_processing_time += processing_duration;
142        match data_handler_result {
143            Ok(SomeDataHandlerResult::ForwardToNext(p)) => {
144                self.stats.data_egress += 1;
145                if let Some(ref n) = self.next {
146                    n.process_data(p);
147                }
148            }
149            Ok(SomeDataHandlerResult::ForwardTo(p, next)) => {
150                self.stats.data_egress += 1;
151                next.process_data(p);
152            }
153            Ok(SomeDataHandlerResult::Discard) => {
154                self.stats.data_discarded += 1;
155            }
156            Ok(SomeDataHandlerResult::Consumed) => {
157                // no-op
158            }
159            Err(e) => {
160                self.stats.errors += 1;
161                println!("Error processing data: {e:?}")
162            }
163        }
164    }
165
166    pub fn visit(&mut self, visitor: &mut dyn NodeVisitor<T>) {
167        visitor.visit(self);
168        if let SomeDataHandler::Demuxer(ref mut d) = self.handler {
169            d.visit(visitor)
170        };
171        if let Some(ref mut n) = self.next {
172            n.visit(visitor);
173        }
174    }
175}
176
177impl<T> StatsProducer for Node<T> {
178    fn get_stats(&self) -> Option<serde_json::Value> {
179        Some(json!({
180            "data_ingress": self.stats.data_ingress,
181            "data_egress": self.stats.data_egress,
182            "data_discarded": self.stats.data_discarded,
183            "errors": self.stats.errors,
184            "total processing time": format!("{:?}", self.stats.total_processing_time),
185            "process time per item": format!("{:?}", (self.stats.total_processing_time / self.stats.data_ingress)),
186            "handler_stats": self.handler.get_stats(),
187        }))
188    }
189}