use itertools::izip;
use rand::distributions::uniform::SampleUniform;
use std::{
f64,
ops::{AddAssign, Div, Range, Sub},
sync::{
atomic::{AtomicBool, AtomicU64, AtomicU8, Ordering},
Arc, Mutex,
},
thread,
time::Duration,
};
use crate::util::{poll, update_execution_position, Polling};
/// Exhaustively evaluates `f` on a regular grid and returns the grid point with the lowest value.
///
/// Dimension `i` is sampled at `points[i]` values starting at `ranges[i].start`, each one
/// obtained by repeatedly adding `step_i = (ranges[i].end - ranges[i].start) / points[i]`.
/// The (exclusive) end of each range is therefore never sampled. For integer `T` the step is
/// computed with integer division, so it may be zero (every sample equals `start`) when a
/// range is narrower than its point count.
///
/// One worker thread is spawned per sampled value of the first dimension. Each worker walks
/// the remaining dimensions depth-first and calls `f` once per grid point. Every call
/// receives a clone of `evaluation_data`.
///
/// When `polling` is `Some`, [`poll`] is called on the current thread after all workers are
/// spawned and before they are joined. It receives the per-worker evaluation counters and
/// per-worker best values. Workers stop early once the exit flag handed to `poll` is set
/// (presumably by `poll` itself — confirm in `crate::util`). In that case the result is the
/// best among the points evaluated so far.
///
/// Ties keep the point that was visited first. A point whose value is NaN is returned only
/// when no evaluated point produced a non-NaN value.
///
/// # Panics
///
/// Panics if any entry of `points` is `0`, or if a point count cannot be converted to `T`.
/// If a worker thread panics (for example because `f` panicked), that panic is propagated.
pub fn grid_search<
    A: 'static + Send + Sync,
    T: 'static
        + Copy
        + Send
        + Sync
        + Default
        + SampleUniform
        + PartialOrd
        + AddAssign
        + Sub<Output = T>
        + Div<Output = T>
        + num::FromPrimitive,
    const N: usize,
>(
    ranges: [Range<T>; N],
    f: fn(&[T; N], Option<Arc<A>>) -> f64,
    evaluation_data: Option<Arc<A>>,
    polling: Option<Polling>,
    points: [u64; N],
) -> [T; N] {
    // Without this check a zero count panics later with an unrelated-looking
    // division-by-zero or out-of-bounds message.
    assert!(
        points.iter().all(|&count| count > 0),
        "grid_search: every dimension needs at least one grid point"
    );

    let mut steps = [T::default(); N];
    for (range, count, step) in izip!(ranges.iter(), points.iter(), steps.iter_mut()) {
        let count = T::from_u64(*count)
            .expect("grid_search: point count is not representable in the parameter type");
        *step = (range.end - range.start) / count;
    }

    // Sample values for every dimension: start, start + step, start + 2*step, ...
    let point_values: Vec<Vec<T>> = izip!(ranges.iter(), points.iter(), steps.iter())
        .map(|(range, count, step)| {
            (0..*count)
                .scan(range.start, |state, _| {
                    let current = *state;
                    *state += *step;
                    Some(current)
                })
                .collect()
        })
        .collect();

    // Seed every coordinate with its first sample; each search level overwrites its own slot.
    let mut start = [T::default(); N];
    for (slot, values) in start.iter_mut().zip(point_values.iter()) {
        *slot = values[0];
    }

    let (_, params) = thread_search(f, evaluation_data, polling, point_values, start);
    return params;

    /// Returns `true` if a result of value `candidate` should replace the current best
    /// (`incumbent`, or `None` when nothing has been evaluated yet).
    ///
    /// Smaller wins. Ties keep the incumbent. A non-NaN value always beats NaN, so an
    /// actual evaluated grid point is always reported.
    fn improves(candidate: f64, incumbent: Option<f64>) -> bool {
        match incumbent {
            None => true,
            Some(best) => candidate < best || (best.is_nan() && !candidate.is_nan()),
        }
    }

    /// Spawns one worker per value of the first dimension, optionally polls their progress,
    /// and returns the best `(value, point)` over all workers.
    ///
    /// With zero dimensions the single (empty) point is evaluated directly.
    fn thread_search<A, T, const N: usize>(
        f: fn(&[T; N], Option<Arc<A>>) -> f64,
        evaluation_data: Option<Arc<A>>,
        polling: Option<Polling>,
        point_values: Vec<Vec<T>>,
        mut point: [T; N],
    ) -> (f64, [T; N])
    where
        A: 'static + Send + Sync,
        T: 'static + Copy + Send + Sync,
    {
        if point_values.is_empty() {
            return (f(&point, evaluation_data), point);
        }

        // One shared, read-only copy of the grid for every worker (no per-thread deep clone).
        let point_values = Arc::new(point_values);
        let thread_exit = Arc::new(AtomicBool::new(false));
        let thread_count = point_values[0].len();

        let mut handles = Vec::with_capacity(thread_count);
        let mut counters: Vec<Arc<AtomicU64>> = Vec::with_capacity(thread_count);
        let mut thread_bests: Vec<Arc<Mutex<f64>>> = Vec::with_capacity(thread_count);

        for &first in point_values[0].iter() {
            point[0] = first;
            let counter = Arc::new(AtomicU64::new(0));
            let thread_best = Arc::new(Mutex::new(f64::MAX));

            let handle = {
                let point_values = Arc::clone(&point_values);
                let evaluation_data = evaluation_data.clone();
                let counter = Arc::clone(&counter);
                let thread_best = Arc::clone(&thread_best);
                let thread_exit = Arc::clone(&thread_exit);
                thread::spawn(move || {
                    search(
                        point_values.as_slice(),
                        f,
                        &evaluation_data,
                        &counter,
                        &thread_best,
                        &thread_exit,
                        point,
                        1,
                    )
                })
            };

            handles.push(handle);
            counters.push(counter);
            thread_bests.push(thread_best);
        }

        if let Some(poll_data) = polling {
            let iterations = point_values
                .iter()
                .map(|values| values.len() as u64)
                .product();
            // The workers here never record execution positions or timings, so `poll`
            // receives fresh zero-initialised trackers, one per worker thread.
            let thread_execution_positions = (0..thread_count)
                .map(|_| Arc::new(AtomicU8::new(0)))
                .collect();
            let thread_execution_times = (0..thread_count)
                .map(|_| {
                    Arc::new([
                        Mutex::new((Duration::new(0, 0), 0)),
                        Mutex::new((Duration::new(0, 0), 0)),
                    ])
                })
                .collect();
            poll(
                poll_data,
                counters,
                0,
                iterations,
                thread_bests,
                thread_exit,
                thread_execution_positions,
                thread_execution_times,
            );
        }

        handles
            .into_iter()
            // Re-raise a worker's panic with its original payload.
            .map(|handle| {
                handle
                    .join()
                    .unwrap_or_else(|payload| std::panic::resume_unwind(payload))
            })
            .reduce(|best, candidate| {
                if improves(candidate.0, Some(best.0)) {
                    candidate
                } else {
                    best
                }
            })
            .expect("the first dimension has at least one grid point")
    }

    /// Depth-first search over dimensions `index..` with coordinates `..index` fixed in `point`.
    ///
    /// Returns the best `(value, point)` found in this subtree. Each evaluation increments
    /// `counter`. The worker's shared `best` is lowered whenever an evaluation beats it, so it
    /// never regresses. After each completed subtree the loop stops if `thread_exit` is set.
    #[allow(clippy::too_many_arguments)] // private recursive helper; bundling args adds noise
    fn search<A, T, const N: usize>(
        point_values: &[Vec<T>],
        f: fn(&[T; N], Option<Arc<A>>) -> f64,
        evaluation_data: &Option<Arc<A>>,
        counter: &AtomicU64,
        best: &Mutex<f64>,
        thread_exit: &AtomicBool,
        mut point: [T; N],
        index: usize,
    ) -> (f64, [T; N])
    where
        T: Copy,
    {
        if index == point_values.len() {
            counter.fetch_add(1, Ordering::SeqCst);
            let value = f(&point, evaluation_data.clone());
            // The guarded value is a plain f64, so a poisoned lock cannot hold broken state.
            let mut worker_best = best
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            if value < *worker_best {
                *worker_best = value;
            }
            return (value, point);
        }

        let mut best_found: Option<(f64, [T; N])> = None;
        for &coordinate in point_values[index].iter() {
            point[index] = coordinate;
            let candidate = search(
                point_values,
                f,
                evaluation_data,
                counter,
                best,
                thread_exit,
                point,
                index + 1,
            );
            if improves(candidate.0, best_found.as_ref().map(|&(value, _)| value)) {
                best_found = Some(candidate);
            }
            if thread_exit.load(Ordering::SeqCst) {
                break;
            }
        }
        // Every dimension has at least one value (checked in `grid_search`) and the loop body
        // runs at least once before checking the exit flag, so this is always `Some`.
        best_found.expect("every dimension has at least one grid point")
    }
}