burn_train/metric/store/
client.rs

1use super::EventStore;
2use super::{Aggregate, Direction, Event, Split};
3use std::{sync::mpsc, thread::JoinHandle};
4
5/// Type that allows to communicate with an [event store](EventStore).
6pub struct EventStoreClient {
7    sender: mpsc::Sender<Message>,
8    handler: Option<JoinHandle<()>>,
9}
10
11impl EventStoreClient {
12    /// Create a new [event store](EventStore) client.
13    pub(crate) fn new<C>(store: C) -> Self
14    where
15        C: EventStore + 'static,
16    {
17        let (sender, receiver) = mpsc::channel();
18        let thread = WorkerThread::new(store, receiver);
19
20        let handler = std::thread::spawn(move || thread.run());
21        let handler = Some(handler);
22
23        Self { sender, handler }
24    }
25}
26
27impl EventStoreClient {
28    /// Add a training event to the [event store](EventStore).
29    pub(crate) fn add_event_train(&self, event: Event) {
30        self.sender
31            .send(Message::OnEventTrain(event))
32            .expect("Can send event to event store thread.");
33    }
34
35    /// Add a validation event to the [event store](EventStore).
36    pub(crate) fn add_event_valid(&self, event: Event) {
37        self.sender
38            .send(Message::OnEventValid(event))
39            .expect("Can send event to event store thread.");
40    }
41
42    /// Add a testing event to the [event store](EventStore).
43    pub(crate) fn add_event_test(&self, event: Event) {
44        self.sender
45            .send(Message::OnEventTest(event))
46            .expect("Can send event to event store thread.");
47    }
48
49    /// Find the epoch following the given criteria from the collected data.
50    pub fn find_epoch(
51        &self,
52        name: &str,
53        aggregate: Aggregate,
54        direction: Direction,
55        split: Split,
56    ) -> Option<usize> {
57        let (sender, receiver) = mpsc::sync_channel(1);
58        self.sender
59            .send(Message::FindEpoch(
60                name.to_string(),
61                aggregate,
62                direction,
63                split,
64                sender,
65            ))
66            .expect("Can send event to event store thread.");
67
68        match receiver.recv() {
69            Ok(value) => value,
70            Err(err) => panic!("Event store thread crashed: {err:?}"),
71        }
72    }
73
74    /// Find the metric value for the current epoch following the given criteria.
75    pub fn find_metric(
76        &self,
77        name: &str,
78        epoch: usize,
79        aggregate: Aggregate,
80        split: Split,
81    ) -> Option<f64> {
82        let (sender, receiver) = mpsc::sync_channel(1);
83        self.sender
84            .send(Message::FindMetric(
85                name.to_string(),
86                epoch,
87                aggregate,
88                split,
89                sender,
90            ))
91            .expect("Can send event to event store thread.");
92
93        match receiver.recv() {
94            Ok(value) => value,
95            Err(err) => panic!("Event store thread crashed: {err:?}"),
96        }
97    }
98}
99
100#[derive(new)]
101struct WorkerThread<S> {
102    store: S,
103    receiver: mpsc::Receiver<Message>,
104}
105
106impl<C> WorkerThread<C>
107where
108    C: EventStore,
109{
110    fn run(mut self) {
111        for item in self.receiver.iter() {
112            match item {
113                Message::End => {
114                    return;
115                }
116                Message::FindEpoch(name, aggregate, direction, split, callback) => {
117                    let response = self.store.find_epoch(&name, aggregate, direction, split);
118                    callback
119                        .send(response)
120                        .expect("Can send response using callback channel.");
121                }
122                Message::FindMetric(name, epoch, aggregate, split, callback) => {
123                    let response = self.store.find_metric(&name, epoch, aggregate, split);
124                    callback
125                        .send(response)
126                        .expect("Can send response using callback channel.");
127                }
128                Message::OnEventTrain(event) => self.store.add_event(event, Split::Train),
129                Message::OnEventValid(event) => self.store.add_event(event, Split::Valid),
130                Message::OnEventTest(event) => self.store.add_event(event, Split::Test),
131            }
132        }
133    }
134}
135
136enum Message {
137    OnEventTest(Event),
138    OnEventTrain(Event),
139    OnEventValid(Event),
140    End,
141    FindEpoch(
142        String,
143        Aggregate,
144        Direction,
145        Split,
146        mpsc::SyncSender<Option<usize>>,
147    ),
148    FindMetric(
149        String,
150        usize,
151        Aggregate,
152        Split,
153        mpsc::SyncSender<Option<f64>>,
154    ),
155}
156
157impl Drop for EventStoreClient {
158    fn drop(&mut self) {
159        self.sender
160            .send(Message::End)
161            .expect("Can send the end message to the event store thread.");
162        let handler = self.handler.take();
163
164        if let Some(handler) = handler {
165            handler.join().expect("The event store thread should stop.");
166        }
167    }
168}