burn_train/metric/store/
client.rs

1use super::EventStore;
2use super::{Aggregate, Direction, Event, Split};
3use std::{sync::mpsc, thread::JoinHandle};
4
5/// Type that allows to communicate with an [event store](EventStore).
6pub struct EventStoreClient {
7    sender: mpsc::Sender<Message>,
8    handler: Option<JoinHandle<()>>,
9}
10
11impl EventStoreClient {
12    /// Create a new [event store](EventStore) client.
13    pub(crate) fn new<C>(store: C) -> Self
14    where
15        C: EventStore + 'static,
16    {
17        let (sender, receiver) = mpsc::channel();
18        let thread = WorkerThread::new(store, receiver);
19
20        let handler = std::thread::spawn(move || thread.run());
21        let handler = Some(handler);
22
23        Self { sender, handler }
24    }
25}
26
27impl EventStoreClient {
28    /// Add a training event to the [event store](EventStore).
29    pub(crate) fn add_event_train(&self, event: Event) {
30        self.sender
31            .send(Message::OnEventTrain(event))
32            .expect("Can send event to event store thread.");
33    }
34
35    /// Add a validation event to the [event store](EventStore).
36    pub(crate) fn add_event_valid(&self, event: Event) {
37        self.sender
38            .send(Message::OnEventValid(event))
39            .expect("Can send event to event store thread.");
40    }
41
42    /// Find the epoch following the given criteria from the collected data.
43    pub fn find_epoch(
44        &self,
45        name: &str,
46        aggregate: Aggregate,
47        direction: Direction,
48        split: Split,
49    ) -> Option<usize> {
50        let (sender, receiver) = mpsc::sync_channel(1);
51        self.sender
52            .send(Message::FindEpoch(
53                name.to_string(),
54                aggregate,
55                direction,
56                split,
57                sender,
58            ))
59            .expect("Can send event to event store thread.");
60
61        match receiver.recv() {
62            Ok(value) => value,
63            Err(err) => panic!("Event store thread crashed: {:?}", err),
64        }
65    }
66
67    /// Find the metric value for the current epoch following the given criteria.
68    pub fn find_metric(
69        &self,
70        name: &str,
71        epoch: usize,
72        aggregate: Aggregate,
73        split: Split,
74    ) -> Option<f64> {
75        let (sender, receiver) = mpsc::sync_channel(1);
76        self.sender
77            .send(Message::FindMetric(
78                name.to_string(),
79                epoch,
80                aggregate,
81                split,
82                sender,
83            ))
84            .expect("Can send event to event store thread.");
85
86        match receiver.recv() {
87            Ok(value) => value,
88            Err(err) => panic!("Event store thread crashed: {:?}", err),
89        }
90    }
91}
92
93#[derive(new)]
94struct WorkerThread<S> {
95    store: S,
96    receiver: mpsc::Receiver<Message>,
97}
98
99impl<C> WorkerThread<C>
100where
101    C: EventStore,
102{
103    fn run(mut self) {
104        for item in self.receiver.iter() {
105            match item {
106                Message::End => {
107                    return;
108                }
109                Message::FindEpoch(name, aggregate, direction, split, callback) => {
110                    let response = self.store.find_epoch(&name, aggregate, direction, split);
111                    callback
112                        .send(response)
113                        .expect("Can send response using callback channel.");
114                }
115                Message::FindMetric(name, epoch, aggregate, split, callback) => {
116                    let response = self.store.find_metric(&name, epoch, aggregate, split);
117                    callback
118                        .send(response)
119                        .expect("Can send response using callback channel.");
120                }
121                Message::OnEventTrain(event) => self.store.add_event(event, Split::Train),
122                Message::OnEventValid(event) => self.store.add_event(event, Split::Valid),
123            }
124        }
125    }
126}
127
128enum Message {
129    OnEventTrain(Event),
130    OnEventValid(Event),
131    End,
132    FindEpoch(
133        String,
134        Aggregate,
135        Direction,
136        Split,
137        mpsc::SyncSender<Option<usize>>,
138    ),
139    FindMetric(
140        String,
141        usize,
142        Aggregate,
143        Split,
144        mpsc::SyncSender<Option<f64>>,
145    ),
146}
147
148impl Drop for EventStoreClient {
149    fn drop(&mut self) {
150        self.sender
151            .send(Message::End)
152            .expect("Can send the end message to the event store thread.");
153        let handler = self.handler.take();
154
155        if let Some(handler) = handler {
156            handler.join().expect("The event store thread should stop.");
157        }
158    }
159}