use std::thread;
use std::thread::{JoinHandle};
use std::sync::mpsc;
use std::sync::mpsc::{Sender, Receiver, TryRecvError, RecvError};
use crate::dlx::{Matrix};
use crate::problem::{Problem, Value};
use crate::callback::Callback;
/// Events yielded to the caller when iterating over a [`Solver`].
///
/// These are the public counterparts of the raw events produced by the solver
/// thread; `N` is the type of the subset names reported in solutions.
pub enum SolverEvent<N: Value> {
/// A solution was found; holds the names of the subsets that make it up.
SolutionFound(Vec<N>),
/// A progress value forwarded unchanged from the solver thread.
ProgressUpdated(f32),
/// The solver thread acknowledged a pause request.
Paused,
// `Finished` (second variant on the line below) is forwarded when the solver
// thread reports that it has finished.
/// Carries the [`Matrix`] forwarded from the solver thread's abort event.
/// NOTE(review): no code visible in this file sends that event — confirm it is reachable.
Aborted(Matrix), Finished,
}
/// Control signals sent from the `Solver` to the solver thread.
enum SolverThreadSignal {
/// Resume a paused search; ignored while the search is already running.
Run,
/// Ask the thread to emit a progress event; ignored while the thread is paused.
RequestProgress,
/// Block the solver thread until a `Run` or `Abort` signal arrives.
Pause,
/// Ask the solver thread to abort the search.
Abort,
}
/// Raw events sent from the solver thread back to the `Solver`.
///
/// `Solver::map_event` translates these into the public `SolverEvent` type.
enum SolverThreadEvent {
/// A solution expressed as row numbers; `map_event` subtracts 1 from each
/// before looking up the subset name, so they are treated as 1-based.
SolutionFound(Vec<usize>),
/// A progress value reported in response to a progress request.
ProgressUpdated(f32),
/// Sent when the thread starts waiting inside a pause.
Paused,
/// Abort notification carrying a `Matrix`.
/// NOTE(review): never constructed in the code visible here; the leading
/// underscore presumably silences the unused-variant warning — confirm.
_Aborted(Matrix),
/// Sent from the `on_finish` callback.
Finished,
}
/// Front end that solves a [`Problem`] on a background thread.
///
/// The thread is created by `run`, controlled through `pause`, `abort` and
/// `request_progress`, and its events are consumed through the `Iterator`
/// implementation.
pub struct Solver<N: Value, C: Value> {
/// The problem being solved; also used to translate solution row numbers
/// back into subset names.
problem: Problem<N, C>,
/// Handle to the solver thread; `None` until `run` has been called.
solver_thread: Option<SolverThread>,
}
impl<N: Value, C: Value> Solver<N, C> {
    /// Creates a solver for `problem`.
    ///
    /// No thread is spawned here; that happens on the first call to [`Solver::run`].
    pub fn new(problem: Problem<N, C>) -> Solver<N, C> {
        Self {
            solver_thread: None,
            problem,
        }
    }

    /// Builds the [`Matrix`] describing `problem`.
    ///
    /// The matrix is created with the number of constraints, and one row is added
    /// per subset (in subset order). Each row lists the 1-based positions of the
    /// subset's elements within the constraint collection.
    ///
    /// # Panics
    ///
    /// Panics if a subset contains an element that is not a constraint of `problem`.
    pub fn generate_matrix(problem: &Problem<N, C>) -> Matrix {
        let subsets = problem.subsets();
        let constraints = problem.constraints();
        let mut matrix = Matrix::new(constraints.len());
        for key in subsets.keys() {
            let columns: Vec<_> = subsets[key]
                .iter()
                // Constraint positions are shifted by one: rows use 1-based columns.
                .map(|member| constraints.get_index_of(member).unwrap() + 1)
                .collect();
            matrix.add_row(&columns);
        }
        matrix
    }

    /// Forwards `signal` to the solver thread.
    ///
    /// Returns `Err(())` when no thread has been started or the signal could not be sent.
    fn send_signal(&self, signal: SolverThreadSignal) -> Result<(), ()> {
        match &self.solver_thread {
            Some(worker) => worker.send(signal),
            None => Err(()),
        }
    }

    /// Starts the solver thread on first use; afterwards sends it a `Run` signal.
    ///
    /// A failure to deliver the `Run` signal is ignored.
    pub fn run(&mut self) {
        if self.solver_thread.is_none() {
            let matrix = Self::generate_matrix(&self.problem);
            self.solver_thread = Some(SolverThread::new(matrix));
            return;
        }
        let _ = self.send_signal(SolverThreadSignal::Run);
    }

    /// Asks the solver thread for a progress event; does nothing if that is not possible.
    pub fn request_progress(&self) {
        let _ = self.send_signal(SolverThreadSignal::RequestProgress);
    }

    /// Asks the solver thread to pause; does nothing if that is not possible.
    pub fn pause(&self) {
        let _ = self.send_signal(SolverThreadSignal::Pause);
    }

    /// Asks the solver thread to abort; does nothing if that is not possible.
    pub fn abort(&self) {
        let _ = self.send_signal(SolverThreadSignal::Abort);
    }

    /// Converts a raw thread event into the public [`SolverEvent`].
    ///
    /// Solution row numbers are turned into subset names by looking up the
    /// subset at position `row - 1`; all other events are forwarded as they are.
    ///
    /// # Panics
    ///
    /// Panics if a row number does not correspond to a subset of the problem.
    fn map_event(&self, event: SolverThreadEvent) -> SolverEvent<N> {
        match event {
            SolverThreadEvent::SolutionFound(rows) => {
                let names = rows
                    .into_iter()
                    .map(|row| {
                        let (name, _) = self.problem.subsets().get_index(row - 1).unwrap();
                        name.clone()
                    })
                    .collect();
                SolverEvent::SolutionFound(names)
            }
            SolverThreadEvent::ProgressUpdated(value) => SolverEvent::ProgressUpdated(value),
            SolverThreadEvent::Paused => SolverEvent::Paused,
            SolverThreadEvent::_Aborted(matrix) => SolverEvent::Aborted(matrix),
            SolverThreadEvent::Finished => SolverEvent::Finished,
        }
    }
}
impl<N: Value, C: Value> Iterator for Solver<N, C> {
    type Item = SolverEvent<N>;

    /// Blocks until the solver thread delivers its next event and returns it.
    ///
    /// Yields `None` when no thread has been started, or when receiving fails
    /// because the thread's end of the event channel is gone.
    fn next(&mut self) -> Option<SolverEvent<N>> {
        let worker = self.solver_thread.as_ref()?;
        let raw = worker.recv().ok()?;
        Some(self.map_event(raw))
    }
}
/// Handle to the spawned solver thread together with the channel ends used
/// to communicate with it.
struct SolverThread {
/// Sending end for control signals going to the thread.
tx_signal: Sender<SolverThreadSignal>,
/// Receiving end for events coming back from the thread.
rx_event: Receiver<SolverThreadEvent>,
/// Join handle of the spawned thread; stored but not joined by the code visible here.
_thread: JoinHandle<()>, }
impl SolverThread {
    /// Spawns a thread that runs `mat.solve` with a [`ThreadCallback`] wired to
    /// two fresh channels, and returns the handle holding the opposite channel ends.
    fn new(mut mat: Matrix) -> SolverThread {
        let (tx_signal, signal_inbox) = mpsc::channel::<SolverThreadSignal>();
        let (event_outbox, rx_event) = mpsc::channel::<SolverThreadEvent>();
        let mut callback = ThreadCallback::new(signal_inbox, event_outbox);
        // The matrix and the callback are moved into the thread and live there.
        let worker = thread::spawn(move || {
            mat.solve(&mut callback);
        });
        Self {
            _thread: worker,
            rx_event,
            tx_signal,
        }
    }

    /// Sends `signal` to the thread; any send failure is reported as `Err(())`.
    fn send(&self, signal: SolverThreadSignal) -> Result<(), ()> {
        match self.tx_signal.send(signal) {
            Ok(()) => Ok(()),
            Err(_) => Err(()),
        }
    }

    /// Blocks until the thread delivers an event, or fails once the channel is closed.
    fn recv(&self) -> Result<SolverThreadEvent, RecvError> {
        let received = self.rx_event.recv();
        received
    }
}
/// Callback state owned by the solver thread: the thread-side ends of the
/// signal and event channels.
struct ThreadCallback {
/// Receiving end for control signals.
signal: Receiver<SolverThreadSignal>,
/// Sending end for events reported by the thread.
event: Sender<SolverThreadEvent>,
}
impl ThreadCallback {
    /// Creates a callback that reads control signals from `signal` and reports
    /// events on `event`.
    fn new(
        signal: Receiver<SolverThreadSignal>,
        event: Sender<SolverThreadEvent>,
    ) -> ThreadCallback {
        ThreadCallback { signal, event }
    }

    /// Reports the current progress on the event channel.
    ///
    /// Send failures are ignored: they only mean nobody is listening any more.
    fn update_progress(&self) {
        // TODO: report real progress once it can be computed; `0.0` is a placeholder.
        // This used to end in `todo!()`, which panicked the solver thread on every
        // progress request and silently ended the consumer's event stream.
        self.event.send(SolverThreadEvent::ProgressUpdated(0.0)).ok();
    }

    /// Announces the pause, then blocks until the thread is told how to continue.
    ///
    /// Returns `SolverThreadSignal::Run` when the search should resume and
    /// `SolverThreadSignal::Abort` when it should stop. A closed signal channel
    /// is treated as an abort. Progress requests and further pause requests
    /// received while paused are discarded.
    fn pause(&self) -> SolverThreadSignal {
        self.event.send(SolverThreadEvent::Paused).ok();
        loop {
            match self.signal.recv() {
                Ok(SolverThreadSignal::Run) => break SolverThreadSignal::Run,
                Ok(SolverThreadSignal::Abort) => break SolverThreadSignal::Abort,
                // Nothing to do for these while paused; keep waiting.
                Ok(SolverThreadSignal::RequestProgress) | Ok(SolverThreadSignal::Pause) => (),
                // All senders are gone, so nobody can ever resume the search.
                Err(RecvError) => break SolverThreadSignal::Abort,
            }
        }
    }
}
impl Callback for ThreadCallback {
fn on_solution(&mut self, sol: Vec<usize>, _mat: &mut Matrix) {
self.event.send(SolverThreadEvent::SolutionFound(sol)).ok();
}
fn on_iteration(&mut self, mat: &mut Matrix) {
let mut pause_signal = None;
let abort = loop {
let signal = match pause_signal {
Some(s) => Ok(s),
None => self.signal.try_recv(),
};
pause_signal = None;
match signal {
Ok(SolverThreadSignal::Run) => (),
Ok(SolverThreadSignal::RequestProgress) => self.update_progress(),
Ok(SolverThreadSignal::Pause) => pause_signal = Some(self.pause()),
Ok(SolverThreadSignal::Abort) => break true,
Err(TryRecvError::Disconnected) => break true,
Err(TryRecvError::Empty) => break false,
}
};
if abort { mat.abort(); }
}
fn on_abort(&mut self, _mat: &mut Matrix) {
}
fn on_finish(&mut self) {
self.event.send(SolverThreadEvent::Finished).ok();
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the solver on a three-constraint problem with six subsets and
    /// checks that exactly four solutions are reported.
    #[test]
    fn solver_can_solve_problem() {
        let mut problem = Problem::default();
        problem.add_constraints(1..=3);
        problem.add_subset("A", vec![1, 2, 3]);
        problem.add_subset("B", vec![1]);
        problem.add_subset("C", vec![2]);
        problem.add_subset("D", vec![3]);
        problem.add_subset("E", vec![1, 2]);
        problem.add_subset("F", vec![2, 3]);

        let mut solver = Solver::new(problem);
        solver.run();

        // Drain the event stream, keeping only the solutions.
        let mut solutions = Vec::new();
        for event in solver {
            if let SolverEvent::SolutionFound(names) = event {
                solutions.push(names);
            }
        }
        assert_eq!(solutions.len(), 4);
    }
}