burn_train/metric/store/
client.rs1use super::EventStore;
2use super::{Aggregate, Direction, Event, Split};
3use std::sync::Arc;
4use std::{sync::mpsc, thread::JoinHandle};
5
6pub struct EventStoreClient {
8 sender: mpsc::Sender<Message>,
9 handler: Option<JoinHandle<()>>,
10}
11
12impl EventStoreClient {
13 pub(crate) fn new<C>(store: C) -> Self
15 where
16 C: EventStore + 'static,
17 {
18 let (sender, receiver) = mpsc::channel();
19 let thread = WorkerThread::new(store, receiver);
20
21 let handler = std::thread::spawn(move || thread.run());
22 let handler = Some(handler);
23
24 Self { sender, handler }
25 }
26}
27
28impl EventStoreClient {
29 pub(crate) fn add_event_train(&self, event: Event) {
31 self.sender
32 .send(Message::OnEventTrain(event))
33 .expect("Can send event to event store thread.");
34 }
35
36 pub(crate) fn add_event_valid(&self, event: Event) {
38 self.sender
39 .send(Message::OnEventValid(event))
40 .expect("Can send event to event store thread.");
41 }
42
43 pub(crate) fn add_event_test(&self, event: Event, tag: Arc<String>) {
45 self.sender
46 .send(Message::OnEventTest(event, tag))
47 .expect("Can send event to event store thread.");
48 }
49
50 pub fn find_epoch(
52 &self,
53 name: &str,
54 aggregate: Aggregate,
55 direction: Direction,
56 split: Split,
57 ) -> Option<usize> {
58 let (sender, receiver) = mpsc::sync_channel(1);
59 self.sender
60 .send(Message::FindEpoch(
61 name.to_string(),
62 aggregate,
63 direction,
64 split,
65 sender,
66 ))
67 .expect("Can send event to event store thread.");
68
69 match receiver.recv() {
70 Ok(value) => value,
71 Err(err) => panic!("Event store thread crashed: {err:?}"),
72 }
73 }
74
75 pub fn find_metric(
77 &self,
78 name: &str,
79 epoch: usize,
80 aggregate: Aggregate,
81 split: Split,
82 ) -> Option<f64> {
83 let (sender, receiver) = mpsc::sync_channel(1);
84 self.sender
85 .send(Message::FindMetric(
86 name.to_string(),
87 epoch,
88 aggregate,
89 split,
90 sender,
91 ))
92 .expect("Can send event to event store thread.");
93
94 match receiver.recv() {
95 Ok(value) => value,
96 Err(err) => panic!("Event store thread crashed: {err:?}"),
97 }
98 }
99}
100
101#[derive(new)]
102struct WorkerThread<S> {
103 store: S,
104 receiver: mpsc::Receiver<Message>,
105}
106
107impl<C> WorkerThread<C>
108where
109 C: EventStore,
110{
111 fn run(mut self) {
112 for item in self.receiver.iter() {
113 match item {
114 Message::End => {
115 return;
116 }
117 Message::FindEpoch(name, aggregate, direction, split, callback) => {
118 let response = self.store.find_epoch(&name, aggregate, direction, split);
119 callback
120 .send(response)
121 .expect("Can send response using callback channel.");
122 }
123 Message::FindMetric(name, epoch, aggregate, split, callback) => {
124 let response = self.store.find_metric(&name, epoch, aggregate, split);
125 callback
126 .send(response)
127 .expect("Can send response using callback channel.");
128 }
129 Message::OnEventTrain(event) => self.store.add_event(event, Split::Train, None),
130 Message::OnEventValid(event) => self.store.add_event(event, Split::Valid, None),
131 Message::OnEventTest(event, tag) => {
132 self.store.add_event(event, Split::Test, Some(tag))
133 }
134 }
135 }
136 }
137}
138
139enum Message {
140 OnEventTest(Event, Arc<String>),
141 OnEventTrain(Event),
142 OnEventValid(Event),
143 End,
144 FindEpoch(
145 String,
146 Aggregate,
147 Direction,
148 Split,
149 mpsc::SyncSender<Option<usize>>,
150 ),
151 FindMetric(
152 String,
153 usize,
154 Aggregate,
155 Split,
156 mpsc::SyncSender<Option<f64>>,
157 ),
158}
159
160impl Drop for EventStoreClient {
161 fn drop(&mut self) {
162 self.sender
163 .send(Message::End)
164 .expect("Can send the end message to the event store thread.");
165 let handler = self.handler.take();
166
167 if let Some(handler) = handler {
168 handler.join().expect("The event store thread should stop.");
169 }
170 }
171}