use std::fmt::Debug;
use std::marker::PhantomData;
use rayon::prelude::*;
use solverforge_core::domain::PlanningSolution;
use solverforge_scoring::Director;
use crate::phase::Phase;
use crate::scope::ProgressCallback;
use crate::scope::SolverScope;
use super::child_phases::ChildPhases;
use super::config::PartitionedSearchConfig;
use super::partitioner::SolutionPartitioner;
/// Phase that splits the working solution into partitions, solves each
/// partition with its own child score director and child phases, and merges
/// the solved partitions back into the working solution.
///
/// Type parameters:
/// - `S`: the planning solution type.
/// - `D`: the score director driving the outer solver scope.
/// - `PD`: the score director built for each partition by `SDF`.
/// - `Part`: the partitioner that splits and merges solutions.
/// - `SDF`: factory turning a partition solution into a `PD`.
/// - `PF`: factory producing a fresh set of child phases (`CP`) per partition.
pub struct PartitionedSearchPhase<S, D, PD, Part, SDF, PF, CP>
where
    S: PlanningSolution,
    D: Director<S>,
    PD: Director<S>,
    CP: ChildPhases<S, PD>,
    Part: SolutionPartitioner<S>,
    SDF: Send + Sync + Fn(S) -> PD,
    PF: Send + Sync + Fn() -> CP,
{
    /// Splits the working solution into partitions and merges solved ones back.
    partitioner: Part,
    /// Builds a score director for a single partition.
    score_director_factory: SDF,
    /// Builds the child phases that run on each partition.
    phase_factory: PF,
    /// Thread count and progress-logging settings.
    config: PartitionedSearchConfig,
    /// Ties the otherwise-unused type parameters to the struct.
    _marker: PhantomData<fn(S, D, PD, CP)>,
}
impl<S, D, PD, Part, SDF, PF, CP> PartitionedSearchPhase<S, D, PD, Part, SDF, PF, CP>
where
    S: PlanningSolution,
    D: Director<S>,
    PD: Director<S>,
    CP: ChildPhases<S, PD>,
    Part: SolutionPartitioner<S>,
    SDF: Send + Sync + Fn(S) -> PD,
    PF: Send + Sync + Fn() -> CP,
{
    /// Creates a phase that uses `PartitionedSearchConfig::default()`.
    pub fn new(partitioner: Part, score_director_factory: SDF, phase_factory: PF) -> Self {
        Self::with_config(
            partitioner,
            score_director_factory,
            phase_factory,
            PartitionedSearchConfig::default(),
        )
    }

    /// Creates a phase with an explicit configuration.
    ///
    /// - `partitioner`: splits and merges solutions.
    /// - `score_director_factory`: builds a score director per partition.
    /// - `phase_factory`: builds the child phases run on each partition.
    /// - `config`: thread count and progress-logging settings.
    pub fn with_config(
        partitioner: Part,
        score_director_factory: SDF,
        phase_factory: PF,
        config: PartitionedSearchConfig,
    ) -> Self {
        let _marker = PhantomData;
        Self {
            partitioner,
            score_director_factory,
            phase_factory,
            config,
            _marker,
        }
    }
}
/// Debug output shows the partitioner and config; the factory closures are
/// omitted because closures have no `Debug` implementation.
impl<S, D, PD, Part, SDF, PF, CP> Debug for PartitionedSearchPhase<S, D, PD, Part, SDF, PF, CP>
where
    S: PlanningSolution,
    D: Director<S>,
    PD: Director<S>,
    CP: ChildPhases<S, PD>,
    Part: SolutionPartitioner<S> + Debug,
    SDF: Send + Sync + Fn(S) -> PD,
    PF: Send + Sync + Fn() -> CP,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            partitioner,
            config,
            ..
        } = self;
        let mut builder = f.debug_struct("PartitionedSearchPhase");
        builder.field("partitioner", partitioner);
        builder.field("config", config);
        builder.finish()
    }
}
impl<S, D, BestCb, PD, Part, SDF, PF, CP> Phase<S, D, BestCb>
for PartitionedSearchPhase<S, D, PD, Part, SDF, PF, CP>
where
S: PlanningSolution + 'static,
D: Director<S>,
BestCb: ProgressCallback<S>,
PD: Director<S> + 'static,
Part: SolutionPartitioner<S>,
SDF: Fn(S) -> PD + Send + Sync,
PF: Fn() -> CP + Send + Sync,
CP: ChildPhases<S, PD> + Send,
{
fn solve(&mut self, solver_scope: &mut SolverScope<S, D, BestCb>) {
let solution = solver_scope.score_director().working_solution().clone();
let partitions = self.partitioner.partition(&solution);
let partition_count = partitions.len();
if partition_count == 0 {
return;
}
let thread_count = self.config.thread_count.resolve(partition_count);
if self.config.log_progress {
tracing::info!(event = "phase_start", phase = "PartitionedSearch",);
}
let solved_partitions: Vec<S> = if thread_count == 1 || partition_count == 1 {
partitions
.into_iter()
.map(|p| self.solve_partition(p))
.collect()
} else {
partitions
.into_par_iter()
.map(|partition| {
let director = (self.score_director_factory)(partition);
let mut solver_scope = SolverScope::new(director);
let mut phases = (self.phase_factory)();
phases.solve_all(&mut solver_scope);
solver_scope.take_best_or_working_solution()
})
.collect()
};
let merged = self.partitioner.merge(&solution, solved_partitions);
let director = solver_scope.score_director_mut();
let working = director.working_solution_mut();
*working = merged;
solver_scope.calculate_score();
solver_scope.update_best_solution();
if self.config.log_progress {
if let Some(score) = solver_scope.best_score() {
tracing::info!(
event = "phase_end",
phase = "PartitionedSearch",
score = %format!("{:?}", score),
);
}
}
}
fn phase_type_name(&self) -> &'static str {
"PartitionedSearch"
}
}
impl<S, D, PD, Part, SDF, PF, CP> PartitionedSearchPhase<S, D, PD, Part, SDF, PF, CP>
where
    S: PlanningSolution,
    D: Director<S>,
    PD: Director<S>,
    CP: ChildPhases<S, PD>,
    Part: SolutionPartitioner<S>,
    SDF: Send + Sync + Fn(S) -> PD,
    PF: Send + Sync + Fn() -> CP,
{
    /// Solves one partition in isolation and returns the result.
    ///
    /// Builds a score director for `partition`, wraps it in a new solver
    /// scope, runs a freshly built set of child phases on it, and returns the
    /// scope's best solution (or its working solution if no best was recorded).
    fn solve_partition(&self, partition: S) -> S {
        let mut child_scope = SolverScope::new((self.score_director_factory)(partition));
        let mut child_phases = (self.phase_factory)();
        child_phases.solve_all(&mut child_scope);
        child_scope.take_best_or_working_solution()
    }
}
#[cfg(test)]
mod tests {
    use super::super::partitioner::ThreadCount;
    use super::*;

    /// The default config resolves the thread count automatically and keeps
    /// progress logging off.
    #[test]
    fn test_config_default() {
        let PartitionedSearchConfig {
            thread_count,
            log_progress,
            ..
        } = PartitionedSearchConfig::default();
        assert_eq!(thread_count, ThreadCount::Auto);
        assert!(!log_progress);
    }
}