// burn_train/metric/store/client.rs
use super::EventStore;
use super::{Aggregate, Direction, Event, Split};
use std::thread;
use std::{sync::mpsc, thread::JoinHandle};

5pub struct EventStoreClient {
7 sender: mpsc::Sender<Message>,
8 handler: Option<JoinHandle<()>>,
9}
10
11impl EventStoreClient {
12 pub(crate) fn new<C>(store: C) -> Self
14 where
15 C: EventStore + 'static,
16 {
17 let (sender, receiver) = mpsc::channel();
18 let thread = WorkerThread::new(store, receiver);
19
20 let handler = std::thread::spawn(move || thread.run());
21 let handler = Some(handler);
22
23 Self { sender, handler }
24 }
25}
26
27impl EventStoreClient {
28 pub(crate) fn add_event_train(&self, event: Event) {
30 self.sender
31 .send(Message::OnEventTrain(event))
32 .expect("Can send event to event store thread.");
33 }
34
35 pub(crate) fn add_event_valid(&self, event: Event) {
37 self.sender
38 .send(Message::OnEventValid(event))
39 .expect("Can send event to event store thread.");
40 }
41
42 pub(crate) fn add_event_test(&self, event: Event) {
44 self.sender
45 .send(Message::OnEventTest(event))
46 .expect("Can send event to event store thread.");
47 }
48
49 pub fn find_epoch(
51 &self,
52 name: &str,
53 aggregate: Aggregate,
54 direction: Direction,
55 split: Split,
56 ) -> Option<usize> {
57 let (sender, receiver) = mpsc::sync_channel(1);
58 self.sender
59 .send(Message::FindEpoch(
60 name.to_string(),
61 aggregate,
62 direction,
63 split,
64 sender,
65 ))
66 .expect("Can send event to event store thread.");
67
68 match receiver.recv() {
69 Ok(value) => value,
70 Err(err) => panic!("Event store thread crashed: {err:?}"),
71 }
72 }
73
74 pub fn find_metric(
76 &self,
77 name: &str,
78 epoch: usize,
79 aggregate: Aggregate,
80 split: Split,
81 ) -> Option<f64> {
82 let (sender, receiver) = mpsc::sync_channel(1);
83 self.sender
84 .send(Message::FindMetric(
85 name.to_string(),
86 epoch,
87 aggregate,
88 split,
89 sender,
90 ))
91 .expect("Can send event to event store thread.");
92
93 match receiver.recv() {
94 Ok(value) => value,
95 Err(err) => panic!("Event store thread crashed: {err:?}"),
96 }
97 }
98}
99
100#[derive(new)]
101struct WorkerThread<S> {
102 store: S,
103 receiver: mpsc::Receiver<Message>,
104}
105
106impl<C> WorkerThread<C>
107where
108 C: EventStore,
109{
110 fn run(mut self) {
111 for item in self.receiver.iter() {
112 match item {
113 Message::End => {
114 return;
115 }
116 Message::FindEpoch(name, aggregate, direction, split, callback) => {
117 let response = self.store.find_epoch(&name, aggregate, direction, split);
118 callback
119 .send(response)
120 .expect("Can send response using callback channel.");
121 }
122 Message::FindMetric(name, epoch, aggregate, split, callback) => {
123 let response = self.store.find_metric(&name, epoch, aggregate, split);
124 callback
125 .send(response)
126 .expect("Can send response using callback channel.");
127 }
128 Message::OnEventTrain(event) => self.store.add_event(event, Split::Train),
129 Message::OnEventValid(event) => self.store.add_event(event, Split::Valid),
130 Message::OnEventTest(event) => self.store.add_event(event, Split::Test),
131 }
132 }
133 }
134}
135
136enum Message {
137 OnEventTest(Event),
138 OnEventTrain(Event),
139 OnEventValid(Event),
140 End,
141 FindEpoch(
142 String,
143 Aggregate,
144 Direction,
145 Split,
146 mpsc::SyncSender<Option<usize>>,
147 ),
148 FindMetric(
149 String,
150 usize,
151 Aggregate,
152 Split,
153 mpsc::SyncSender<Option<f64>>,
154 ),
155}
156
157impl Drop for EventStoreClient {
158 fn drop(&mut self) {
159 self.sender
160 .send(Message::End)
161 .expect("Can send the end message to the event store thread.");
162 let handler = self.handler.take();
163
164 if let Some(handler) = handler {
165 handler.join().expect("The event store thread should stop.");
166 }
167 }
168}