use crate::*;
use std::sync::mpsc;
use std::sync::{Arc, Mutex};
use std::thread::JoinHandle;
use std::time::{Duration, Instant};
/// An evaluator that scores whole batches of states at once (for example a
/// neural network that benefits from batched inference).
///
/// Implementors receive several `(state, legal_moves)` pairs per call and must
/// return exactly one `(move_evaluations, state_evaluation)` pair per input,
/// index-aligned with the inputs.
pub trait BatchEvaluator<Spec: MCTS>: Send + Sync + 'static {
    /// Per-state evaluation payload produced alongside the move evaluations.
    type StateEvaluation: Sync + Send + Clone;

    /// Evaluate every `(state, moves)` pair in `states`.
    ///
    /// Must return exactly `states.len()` results in input order — the
    /// collector thread asserts this.
    fn evaluate_batch(
        &self,
        states: &[(Spec::State, MoveList<Spec>)],
    ) -> Vec<(Vec<MoveEvaluation<Spec>>, Self::StateEvaluation)>;

    /// Re-evaluate a state that already has an evaluation.
    ///
    /// Defaults to returning a clone of the existing evaluation unchanged.
    fn evaluate_existing_state(
        &self,
        _state: &Spec::State,
        existing_evaln: &Self::StateEvaluation,
    ) -> Self::StateEvaluation {
        existing_evaln.clone()
    }

    /// Interpret `evaluation` from `player`'s point of view as a scalar score.
    fn interpret_evaluation_for_player(
        &self,
        evaluation: &Self::StateEvaluation,
        player: &Player<Spec>,
    ) -> i64;
}
/// Tuning knobs for how evaluation requests are grouped into batches.
pub struct BatchConfig {
    /// Hard cap on requests per batch; a full batch is evaluated immediately.
    pub max_batch_size: usize,
    /// How long the collector waits for more requests after the first one
    /// before evaluating a partially filled batch.
    pub max_wait: Duration,
}

impl Default for BatchConfig {
    /// Defaults: batches of up to 8 requests, waiting at most 1 ms to fill.
    fn default() -> Self {
        BatchConfig {
            max_wait: Duration::from_millis(1),
            max_batch_size: 8,
        }
    }
}
/// One evaluation request travelling from a search thread to the collector
/// thread, carrying its own one-shot reply channel.
struct EvalRequest<Spec: MCTS, SE> {
    state: Spec::State,
    moves: MoveList<Spec>,
    // Bounded (capacity 1) reply channel; the collector sends exactly one
    // result back per request.
    response: mpsc::SyncSender<(Vec<MoveEvaluation<Spec>>, SE)>,
}
/// Adapts a [`BatchEvaluator`] to the per-state [`Evaluator`] interface by
/// funnelling individual requests through a background collector thread that
/// groups them into batches.
#[allow(clippy::type_complexity)]
pub struct BatchedEvaluatorBridge<Spec: MCTS, B: BatchEvaluator<Spec>> {
    // `None` once `Drop` has run; dropping the sender is what signals the
    // collector thread to shut down.
    sender: Option<Mutex<mpsc::Sender<EvalRequest<Spec, B::StateEvaluation>>>>,
    batch_eval: Arc<B>,
    // Joined in `Drop` so the collector finishes before the bridge is gone.
    eval_thread: Option<JoinHandle<()>>,
}
// SAFETY: shared (`&self`) access only touches the `Mutex`-guarded sender,
// the `Arc<B>` (where `B: Sync` by the `BatchEvaluator` bound), and the join
// handle. NOTE(review): this asserts `Sync` unconditionally for all `Spec`;
// it relies on `EvalRequest` (and hence `Spec::State`) being `Send` — confirm
// the `MCTS` trait guarantees `State: Send`, otherwise prefer removing this
// and letting the compiler's auto-impl apply conditionally.
unsafe impl<Spec: MCTS, B: BatchEvaluator<Spec>> Sync for BatchedEvaluatorBridge<Spec, B> {}
impl<Spec, B> BatchedEvaluatorBridge<Spec, B>
where
    Spec: MCTS,
    B: BatchEvaluator<Spec>,
    Spec::State: Clone,
    MoveList<Spec>: Clone + Send + 'static,
    MoveEvaluation<Spec>: Send + 'static,
    B::StateEvaluation: Send + 'static,
{
    /// Builds the bridge and spawns the background collector thread.
    ///
    /// The spawned thread owns the receiving end of the request channel and
    /// exits once every sender has been dropped (see `Drop`).
    pub fn new(batch_eval: B, config: BatchConfig) -> Self {
        let (tx, rx) = mpsc::channel::<EvalRequest<Spec, B::StateEvaluation>>();
        let shared_eval = Arc::new(batch_eval);
        let worker_eval = Arc::clone(&shared_eval);
        let worker = std::thread::spawn(move || collector_loop(&rx, &*worker_eval, &config));
        Self {
            sender: Some(Mutex::new(tx)),
            batch_eval: shared_eval,
            eval_thread: Some(worker),
        }
    }
}
impl<Spec, B> Evaluator<Spec> for BatchedEvaluatorBridge<Spec, B>
where
    Spec: MCTS<Eval = Self>,
    B: BatchEvaluator<Spec>,
    Spec::State: Clone,
    MoveList<Spec>: Clone + Send + 'static,
    MoveEvaluation<Spec>: Send + 'static,
    B::StateEvaluation: Send + 'static,
{
    type StateEvaluation = B::StateEvaluation;

    /// Ships the state to the collector thread and blocks until the batched
    /// evaluation for it comes back.
    ///
    /// Panics if the bridge has been shut down or the collector thread died.
    fn evaluate_new_state(
        &self,
        state: &Spec::State,
        moves: &MoveList<Spec>,
        _handle: Option<SearchHandle<Spec>>,
    ) -> (Vec<MoveEvaluation<Spec>>, Self::StateEvaluation) {
        // Capacity-1 reply channel: exactly one result comes back per request.
        let (reply_tx, reply_rx) = mpsc::sync_channel(1);
        let request = EvalRequest {
            state: state.clone(),
            moves: moves.clone(),
            response: reply_tx,
        };
        {
            // Lock is held only for the (non-blocking) send on the unbounded
            // request channel, then released before we wait for the reply.
            let guard = self
                .sender
                .as_ref()
                .expect("bridge already shut down")
                .lock()
                .unwrap();
            guard.send(request).expect("batch collector thread died");
        }
        reply_rx.recv().expect("batch collector dropped response")
    }

    /// Delegates to the wrapped batch evaluator.
    fn evaluate_existing_state(
        &self,
        state: &Spec::State,
        existing_evaln: &Self::StateEvaluation,
        _handle: SearchHandle<Spec>,
    ) -> Self::StateEvaluation {
        self.batch_eval
            .evaluate_existing_state(state, existing_evaln)
    }

    /// Delegates to the wrapped batch evaluator.
    fn interpret_evaluation_for_player(
        &self,
        evaluation: &Self::StateEvaluation,
        player: &Player<Spec>,
    ) -> i64 {
        self.batch_eval
            .interpret_evaluation_for_player(evaluation, player)
    }
}
impl<Spec: MCTS, B: BatchEvaluator<Spec>> Drop for BatchedEvaluatorBridge<Spec, B> {
    /// Shuts the collector down: dropping the request sender makes the
    /// collector's `recv` fail, after which the thread exits and is joined.
    fn drop(&mut self) {
        drop(self.sender.take());
        if let Some(worker) = self.eval_thread.take() {
            // A panic on the collector thread is deliberately ignored here;
            // the bridge is being torn down anyway.
            let _ = worker.join();
        }
    }
}
/// Collector-thread body: groups incoming evaluation requests into batches.
///
/// Blocks until a first request opens a batch, then keeps accepting more
/// until either `config.max_batch_size` requests are queued or
/// `config.max_wait` has elapsed since the batch was opened, runs
/// `evaluate_batch`, and sends each result back on its request's reply
/// channel. Returns once every request sender has been dropped.
///
/// Panics if `evaluate_batch` returns the wrong number of results.
fn collector_loop<Spec, B>(
    receiver: &mpsc::Receiver<EvalRequest<Spec, B::StateEvaluation>>,
    eval: &B,
    config: &BatchConfig,
) where
    Spec: MCTS,
    B: BatchEvaluator<Spec>,
    Spec::State: Clone,
    MoveList<Spec>: Clone + Send,
    MoveEvaluation<Spec>: Send,
    B::StateEvaluation: Send,
{
    let mut batch_requests: Vec<EvalRequest<Spec, B::StateEvaluation>> =
        Vec::with_capacity(config.max_batch_size);
    loop {
        // Block for the first request of the batch; a disconnect here means
        // every bridge is gone and the thread can exit.
        match receiver.recv() {
            Ok(req) => batch_requests.push(req),
            Err(_) => return,
        }
        // Top the batch up until it is full or the deadline passes. A
        // disconnect mid-batch still evaluates what we have; the `recv`
        // above then returns the exit error on the next iteration.
        let deadline = Instant::now() + config.max_wait;
        while batch_requests.len() < config.max_batch_size {
            let remaining = deadline.saturating_duration_since(Instant::now());
            if remaining.is_zero() {
                break;
            }
            match receiver.recv_timeout(remaining) {
                Ok(req) => batch_requests.push(req),
                // Timeout or disconnect: evaluate the partial batch.
                Err(_) => break,
            }
        }
        let batch_input: Vec<(Spec::State, MoveList<Spec>)> = batch_requests
            .iter()
            .map(|r| (r.state.clone(), r.moves.clone()))
            .collect();
        let results = eval.evaluate_batch(&batch_input);
        assert_eq!(
            results.len(),
            batch_requests.len(),
            "evaluate_batch returned {} results for {} inputs",
            results.len(),
            batch_requests.len()
        );
        // `drain` empties the vec (keeping its capacity) so the next batch
        // starts from an empty buffer — no separate `clear` is needed. A
        // failed send just means the requester gave up waiting; ignore it.
        for (request, result) in batch_requests.drain(..).zip(results) {
            let _ = request.response.send(result);
        }
    }
}