Skip to main content

burn_train/metric/store/
client.rs

1use super::EventStore;
2use super::{Aggregate, Direction, Event, Split};
3use std::sync::Arc;
4use std::{sync::mpsc, thread::JoinHandle};
5
6/// Type that allows to communicate with an [event store](EventStore).
7pub struct EventStoreClient {
8    sender: mpsc::Sender<Message>,
9    handler: Option<JoinHandle<()>>,
10}
11
12impl EventStoreClient {
13    /// Create a new [event store](EventStore) client.
14    pub(crate) fn new<C>(store: C) -> Self
15    where
16        C: EventStore + 'static,
17    {
18        let (sender, receiver) = mpsc::channel();
19        let thread = WorkerThread::new(store, receiver);
20
21        let handler = std::thread::spawn(move || thread.run());
22        let handler = Some(handler);
23
24        Self { sender, handler }
25    }
26}
27
28impl EventStoreClient {
29    /// Add a training event to the [event store](EventStore).
30    pub(crate) fn add_event_train(&self, event: Event) {
31        self.sender
32            .send(Message::OnEventTrain(event))
33            .expect("Can send event to event store thread.");
34    }
35
36    /// Add a validation event to the [event store](EventStore).
37    pub(crate) fn add_event_valid(&self, event: Event) {
38        self.sender
39            .send(Message::OnEventValid(event))
40            .expect("Can send event to event store thread.");
41    }
42
43    /// Add a testing event to the [event store](EventStore).
44    pub(crate) fn add_event_test(&self, event: Event, tag: Arc<String>) {
45        self.sender
46            .send(Message::OnEventTest(event, tag))
47            .expect("Can send event to event store thread.");
48    }
49
50    /// Find the epoch following the given criteria from the collected data.
51    pub fn find_epoch(
52        &self,
53        name: &str,
54        aggregate: Aggregate,
55        direction: Direction,
56        split: Split,
57    ) -> Option<usize> {
58        let (sender, receiver) = mpsc::sync_channel(1);
59        self.sender
60            .send(Message::FindEpoch(
61                name.to_string(),
62                aggregate,
63                direction,
64                split,
65                sender,
66            ))
67            .expect("Can send event to event store thread.");
68
69        match receiver.recv() {
70            Ok(value) => value,
71            Err(err) => panic!("Event store thread crashed: {err:?}"),
72        }
73    }
74
75    /// Find the metric value for the current epoch following the given criteria.
76    pub fn find_metric(
77        &self,
78        name: &str,
79        epoch: usize,
80        aggregate: Aggregate,
81        split: Split,
82    ) -> Option<f64> {
83        let (sender, receiver) = mpsc::sync_channel(1);
84        self.sender
85            .send(Message::FindMetric(
86                name.to_string(),
87                epoch,
88                aggregate,
89                split,
90                sender,
91            ))
92            .expect("Can send event to event store thread.");
93
94        match receiver.recv() {
95            Ok(value) => value,
96            Err(err) => panic!("Event store thread crashed: {err:?}"),
97        }
98    }
99}
100
101#[derive(new)]
102struct WorkerThread<S> {
103    store: S,
104    receiver: mpsc::Receiver<Message>,
105}
106
107impl<C> WorkerThread<C>
108where
109    C: EventStore,
110{
111    fn run(mut self) {
112        for item in self.receiver.iter() {
113            match item {
114                Message::End => {
115                    return;
116                }
117                Message::FindEpoch(name, aggregate, direction, split, callback) => {
118                    let response = self.store.find_epoch(&name, aggregate, direction, split);
119                    callback
120                        .send(response)
121                        .expect("Can send response using callback channel.");
122                }
123                Message::FindMetric(name, epoch, aggregate, split, callback) => {
124                    let response = self.store.find_metric(&name, epoch, aggregate, split);
125                    callback
126                        .send(response)
127                        .expect("Can send response using callback channel.");
128                }
129                Message::OnEventTrain(event) => self.store.add_event(event, Split::Train, None),
130                Message::OnEventValid(event) => self.store.add_event(event, Split::Valid, None),
131                Message::OnEventTest(event, tag) => {
132                    self.store.add_event(event, Split::Test, Some(tag))
133                }
134            }
135        }
136    }
137}
138
139enum Message {
140    OnEventTest(Event, Arc<String>),
141    OnEventTrain(Event),
142    OnEventValid(Event),
143    End,
144    FindEpoch(
145        String,
146        Aggregate,
147        Direction,
148        Split,
149        mpsc::SyncSender<Option<usize>>,
150    ),
151    FindMetric(
152        String,
153        usize,
154        Aggregate,
155        Split,
156        mpsc::SyncSender<Option<f64>>,
157    ),
158}
159
160impl Drop for EventStoreClient {
161    fn drop(&mut self) {
162        self.sender
163            .send(Message::End)
164            .expect("Can send the end message to the event store thread.");
165        let handler = self.handler.take();
166
167        if let Some(handler) = handler {
168            handler.join().expect("The event store thread should stop.");
169        }
170    }
171}