1use std::thread;
4use std::thread::{JoinHandle};
5use std::sync::mpsc;
6use std::sync::mpsc::{Sender, Receiver, TryRecvError, RecvError};
7use crate::dlx::{Matrix, Callback};
8use crate::problem::{Problem, Value};
9
10pub enum SolverEvent<N: Value> {
12 SolutionFound(Vec<N>),
13 ProgressUpdated(f32),
14 Paused,
15 Aborted(Matrix), Finished,
17}
18
/// Control messages sent from the owning `Solver` to the worker thread.
enum SolverThreadSignal {
    /// Continue (or resume) the search.
    Run,
    /// Ask the worker to emit a progress event.
    RequestProgress,
    /// Suspend the search until a `Run` or `Abort` signal arrives.
    Pause,
    /// Stop the search.
    Abort,
}
25
26enum SolverThreadEvent {
27 SolutionFound(Vec<usize>),
28 ProgressUpdated(f32),
29 Paused,
30 _Aborted(Matrix),
31 Finished,
32}
33
34pub struct Solver<N: Value, C: Value> {
36 problem: Problem<N, C>,
37 solver_thread: Option<SolverThread>,
38}
39
40impl<N: Value, C: Value> Solver<N, C> {
41 pub fn new(problem: Problem<N, C>) -> Solver<N, C> {
43 Solver {
44 problem,
45 solver_thread: None,
46 }
47 }
48
49 pub fn generate_matrix(problem: &Problem<N, C>) -> Matrix {
50 let names = problem.subsets().keys();
51
52 let mut mat = Matrix::new(problem.constraints().len());
53 for name in names {
54 let row: Vec<_> = problem.subsets()[name].iter().map(|x| {
55 problem.constraints().get_index_of(x).unwrap() + 1
57 }).collect();
58 mat.add_row(&row);
59 }
60 mat
61 }
62
63 fn send_signal(&self, signal: SolverThreadSignal) -> Result<(), ()> {
64 let thread = self.solver_thread.as_ref().ok_or(())?;
65 thread.send(signal)
66 }
67
68 pub fn run(&mut self) {
70 if let Some(thread) = &self.solver_thread {
72 thread.send(SolverThreadSignal::Run).ok();
73 } else {
74 let mat = Solver::generate_matrix(&self.problem);
75 self.solver_thread = Some(SolverThread::new(mat));
76 }
77 }
78 pub fn request_progress(&self) { self.send_signal(SolverThreadSignal::RequestProgress).ok(); }
79 pub fn pause(&self) { self.send_signal(SolverThreadSignal::Pause).ok(); }
80 pub fn abort(&self) { self.send_signal(SolverThreadSignal::Abort).ok(); }
81
82 fn map_event(&self, event: SolverThreadEvent) -> SolverEvent<N> {
83 match event {
84 SolverThreadEvent::SolutionFound(sol) => SolverEvent::SolutionFound(
85 sol.iter()
86 .map(|x| { self.problem.subsets().get_index(x-1).unwrap().0.clone() })
87 .collect()
88 ),
89 SolverThreadEvent::ProgressUpdated(progress) => SolverEvent::ProgressUpdated(progress),
90 SolverThreadEvent::Paused => SolverEvent::Paused,
91 SolverThreadEvent::_Aborted(mat) => SolverEvent::Aborted(mat),
92 SolverThreadEvent::Finished => SolverEvent::Finished,
93 }
94 }
95}
96
97pub struct SolverIter<N: Value, C: Value> {
99 solver: Solver<N, C>,
100}
101
102impl<N: Value, C: Value> Iterator for SolverIter<N, C> {
103 type Item = SolverEvent<N>;
104
105 fn next(&mut self) -> Option<SolverEvent<N>> {
106 if let Ok(e) = self.solver.solver_thread.as_ref()?.recv() {
107 Some(self.solver.map_event(e))
108 } else {
109 None
110 }
111 }
112}
113
114impl<N: Value, C: Value> IntoIterator for Solver<N, C> {
116 type Item = SolverEvent<N>;
117 type IntoIter = SolverIter<N, C>;
118
119 fn into_iter(self) -> Self::IntoIter {
121 SolverIter { solver: self }
122 }
123}
124
125
126struct SolverThread {
128 tx_signal: Sender<SolverThreadSignal>,
129 rx_event: Receiver<SolverThreadEvent>,
130 _thread: JoinHandle<()>, }
132
133impl SolverThread {
134 fn new(mut mat: Matrix) -> SolverThread {
136 let (tx_signal, rx_signal) = mpsc::channel();
137 let (tx_event, rx_event) = mpsc::channel();
138
139 let mut callback = ThreadCallback::new(rx_signal, tx_event);
140 let thread = thread::spawn(move || { mat.solve(&mut callback); });
141
142 SolverThread {
143 tx_signal,
144 rx_event,
145 _thread: thread,
146 }
147 }
148
149 fn send(&self, signal: SolverThreadSignal) -> Result<(), ()> {
150 self.tx_signal.send(signal).map_err(|_| {()})
153 }
154
155 fn recv(&self) -> Result<SolverThreadEvent, RecvError> {
156 self.rx_event.recv()
158 }
159}
160
161struct ThreadCallback {
162 signal: Receiver<SolverThreadSignal>,
163 event: Sender<SolverThreadEvent>,
164}
165
166impl ThreadCallback {
167 fn new(
168 signal: Receiver<SolverThreadSignal>,
169 event: Sender<SolverThreadEvent>,
170 ) -> ThreadCallback {
171 ThreadCallback { signal, event }
172 }
173
174 fn update_progress(&self) {
175 self.event.send(SolverThreadEvent::ProgressUpdated(0.0)).ok();
177 todo!()
178 }
179
180 fn pause(&self) -> SolverThreadSignal {
182 self.event.send(SolverThreadEvent::Paused).ok();
183 loop {
184 match self.signal.recv() {
185 Ok(SolverThreadSignal::Run) => break SolverThreadSignal::Run,
186 Ok(SolverThreadSignal::RequestProgress) => (),
187 Ok(SolverThreadSignal::Pause) => (),
188 Ok(SolverThreadSignal::Abort) => break SolverThreadSignal::Abort,
189 Err(RecvError) => break SolverThreadSignal::Abort,
190 }
191 }
192 }
193}
194
195impl Callback for ThreadCallback {
196 fn on_solution(&mut self, sol: Vec<usize>, _mat: &mut Matrix) {
197 self.event.send(SolverThreadEvent::SolutionFound(sol)).ok();
198 }
199
200 fn on_iteration(&mut self, mat: &mut Matrix) {
201 let mut pause_signal = None; let abort = loop {
204 let signal = match pause_signal {
205 Some(s) => Ok(s),
206 None => self.signal.try_recv(),
207 };
208 pause_signal = None;
209
210 match signal {
211 Ok(SolverThreadSignal::Run) => (),
212 Ok(SolverThreadSignal::RequestProgress) => self.update_progress(),
213 Ok(SolverThreadSignal::Pause) => pause_signal = Some(self.pause()),
214 Ok(SolverThreadSignal::Abort) => break true,
215 Err(TryRecvError::Disconnected) => break true,
216 Err(TryRecvError::Empty) => break false,
217 }
218 };
219
220 if abort { mat.abort(); }
221 }
222
223 fn on_abort(&mut self, _mat: &mut Matrix) {
224 }
227
228 fn on_finish(&mut self) {
229 self.event.send(SolverThreadEvent::Finished).ok();
230 }
231}
232
233
#[cfg(test)]
mod tests {
    use super::*;

    /// Constraints {1, 2, 3} have exactly four exact covers among the subsets
    /// below: A, B+C+D, D+E and B+F.
    #[test]
    fn solver_can_solve_problem() {
        let mut prob = Problem::default();
        prob.add_constraints(1..=3);
        prob.add_subset("A", vec![1, 2, 3]);
        prob.add_subset("B", vec![1]);
        prob.add_subset("C", vec![2]);
        prob.add_subset("D", vec![3]);
        prob.add_subset("E", vec![1, 2]);
        prob.add_subset("F", vec![2, 3]);

        let mut solver = Solver::new(prob);
        solver.run();

        // Keep only the solution events; the iterator ends once the worker is done.
        let solutions: Vec<_> = solver
            .into_iter()
            .filter_map(|event| match event {
                SolverEvent::SolutionFound(sol) => Some(sol),
                _ => None,
            })
            .collect();

        assert_eq!(solutions.len(), 4);
    }
}