use std::alloc::{Alloc, Layout, Global, handle_alloc_error};
use std::sync::atomic::{AtomicBool,Ordering};
use std::thread::{JoinHandle,self};
use std::sync::{Once, ONCE_INIT};
use std::ptr::NonNull;
use std::{ptr,mem};
use std::cell::UnsafeCell;
use pleco::MoveList;
use pleco::tools::pleco_arc::Arc;
use pleco::board::*;
use pleco::core::piece_move::BitMove;
use sync::LockLatch;
use time::uci_timer::*;
use search::Searcher;
use consts::*;
const KILOBYTE: usize = 1000;
const THREAD_STACK_SIZE: usize = 18000 * KILOBYTE;
const POOL_SIZE: usize = mem::size_of::<ThreadPool>();
type DummyThreadPool = [u8; POOL_SIZE];
pub static mut THREADPOOL: DummyThreadPool = [0; POOL_SIZE];
static THREADPOOL_INIT: Once = ONCE_INIT;
/// Constructs the global `ThreadPool` exactly once.
///
/// The pool is built on a temporary thread named "Starter" that runs with a
/// `THREAD_STACK_SIZE` stack. The calling thread blocks until that thread has written the
/// pool into `THREADPOOL`. Calls after the first are no-ops, because of `THREADPOOL_INIT`.
///
/// # Panics
/// Panics if the starter thread cannot be spawned or if it panics.
#[cold]
pub fn init_threadpool() {
    THREADPOOL_INIT.call_once(|| {
        let starter = thread::Builder::new()
            .name("Starter".to_string())
            .stack_size(THREAD_STACK_SIZE);

        // The closure captures nothing, so the checked `spawn` is sufficient here.
        let spawned = starter.spawn(|| unsafe {
            // SAFETY: guarded by `THREADPOOL_INIT`, so this is the only writer, and nobody
            // can observe the storage as a `ThreadPool` before this write completes.
            let slot: *mut ThreadPool = mem::transmute(&mut THREADPOOL);
            ptr::write(slot, ThreadPool::new());
        });

        let starter_handle = spawned.unwrap();
        starter_handle.join().unwrap();
    });
}
/// Returns a mutable reference to the global thread pool.
///
/// Contract (not enforced): `init_threadpool` must have run first. Otherwise the storage
/// still holds zero bytes rather than a constructed `ThreadPool`. Each call hands out a
/// fresh `&'static mut`, so callers must not keep two of them alive at once.
#[inline(always)]
pub fn threadpool() -> &'static mut ThreadPool {
    unsafe {
        let raw_storage: *mut DummyThreadPool = &mut THREADPOOL;
        &mut *(raw_storage as *mut ThreadPool)
    }
}
/// Which subset of the pool's searchers an operation should apply to.
#[derive(Copy, Clone)]
enum ThreadSelection {
    /// Only the searcher with id 0.
    Main,
    /// Every searcher except id 0.
    NonMain,
    /// Every searcher.
    All
}

impl ThreadSelection {
    /// Returns `true` when a searcher with the given `id` falls within this selection.
    /// Id 0 is treated as the main searcher.
    #[inline(always)]
    pub fn is_selection(self, id: usize) -> bool {
        let is_main = id == 0;
        match self {
            ThreadSelection::All => true,
            ThreadSelection::Main => is_main,
            ThreadSelection::NonMain => !is_main,
        }
    }
}
/// Wraps a raw `Searcher` pointer so it can be moved into a freshly spawned thread.
struct SearcherPtr {
    ptr: UnsafeCell<*mut Searcher>,
}

// SAFETY: the wrapped pointer refers to a heap-allocated `Searcher` that `ThreadPool` only
// deallocates after joining the thread holding this wrapper (see `ThreadPool::kill_all`).
// NOTE(review): soundness of concurrent access also depends on `Searcher`'s own fields
// (atomics/latches) — confirm against its definition.
unsafe impl Send for SearcherPtr {}
unsafe impl Sync for SearcherPtr {}
/// Owns the engine's searcher threads and the primitives used to coordinate them.
pub struct ThreadPool {
    /// Raw pointers to the heap-allocated searchers; index 0 is the main searcher.
    pub threads: Vec<UnsafeCell<*mut Searcher>>,
    /// Join handles of the spawned searcher threads, in spawn order.
    handles: Vec<JoinHandle<()>>,
    /// Latch handed to the main searcher (id 0).
    pub main_cond: Arc<LockLatch>,
    /// Latch shared by every non-main searcher.
    pub thread_cond: Arc<LockLatch>,
    /// Stop flag: cleared by `uci_search` before a search starts, set by `kill_all`.
    pub stop: AtomicBool,
}
impl ThreadPool {
    /// Creates a pool containing only the main searcher (id 0).
    ///
    /// Both latches are locked before the first thread is attached, and `stop` starts out
    /// `true`. Presumably the locked latches keep spawned searchers idle until `uci_search`
    /// sets `main_cond` — confirm against `LockLatch`/`Searcher::idle_loop`.
    pub fn new() -> Self {
        let mut pool: ThreadPool = ThreadPool {
            threads: Vec::new(),
            handles: Vec::new(),
            main_cond: Arc::new(LockLatch::new()),
            thread_cond: Arc::new(LockLatch::new()),
            stop: AtomicBool::new(true)
        };
        pool.main_cond.lock();
        pool.thread_cond.lock();
        pool.attach_thread();
        pool
    }

    /// Allocates a new `Searcher`, registers it in `threads`, and spawns an OS thread that
    /// locks the searcher's `cond` and then runs its `idle_loop`.
    ///
    /// The thread is named after the pool size *after* registration, i.e. searcher id + 1.
    ///
    /// # Panics
    /// Panics if the OS thread cannot be spawned.
    fn attach_thread(&mut self) {
        let thread_ptr: SearcherPtr = self.create_thread();
        let builder = thread::Builder::new()
            .name(self.size().to_string())
            .stack_size(THREAD_STACK_SIZE);
        // SAFETY: the searcher lives on the heap and is only freed in `kill_all`, after this
        // thread has been joined, so the raw pointer outlives the spawned closure.
        let handle = unsafe {
            builder.spawn_unchecked(move || {
                let thread = &mut **thread_ptr.ptr.get();
                thread.cond.lock();
                thread.idle_loop();
            })
        }.unwrap();
        self.handles.push(handle);
    }

    /// Heap-allocates and initialises a `Searcher` whose id is the current pool size.
    ///
    /// The first searcher (id 0) receives `main_cond`; all later ones share `thread_cond`.
    /// The pointer is stored in `threads` and a second copy is returned for the spawned
    /// thread to use.
    fn create_thread(&mut self) -> SearcherPtr {
        let len: usize = self.threads.len();
        let layout = Layout::new::<Searcher>();
        let cond = if len == 0 {self.main_cond.clone()} else {self.thread_cond.clone()};
        unsafe {
            let result = Global.alloc_zeroed(layout);
            let new_ptr: *mut Searcher = match result {
                Ok(ptr) => ptr.cast().as_ptr() as *mut Searcher,
                Err(_err) => handle_alloc_error(layout),
            };
            // SAFETY: `new_ptr` is a fresh allocation sized/aligned for `Searcher`.
            ptr::write(new_ptr, Searcher::new(len, cond));
            self.threads.push(UnsafeCell::new(new_ptr));
            SearcherPtr {ptr: UnsafeCell::new(new_ptr)}
        }
    }

    /// Number of searchers currently in the pool.
    #[inline(always)]
    pub fn size(&self) -> usize {
        self.threads.len()
    }

    /// Returns the main searcher (id 0).
    ///
    /// # Panics
    /// Panics if the pool has no searchers, e.g. after `kill_all`. The previous
    /// `get_unchecked(0)` was undefined behaviour in that case.
    fn main(&mut self) -> &mut Searcher {
        let main_thread: *mut Searcher = unsafe { *self.threads[0].get() };
        // SAFETY: every pointer in `threads` refers to a live, initialised searcher until
        // `kill_all` removes it from the vector.
        unsafe { &mut *main_thread }
    }

    /// Sets the global `USE_STDOUT` flag; `self` is not otherwise used.
    #[inline(always)]
    pub fn stdout(&mut self, use_stdout: bool) {
        USE_STDOUT.store(use_stdout, Ordering::Relaxed);
    }

    /// Rebuilds the pool with `num` searchers, clamped to `MAX_THREADS`.
    ///
    /// Waits for any running search, tears down every existing thread, then attaches fresh
    /// ones. A `num` of 0 is ignored and leaves the pool untouched.
    pub fn set_thread_count(&mut self, mut num: usize) {
        if num >= 1 {
            num = num.min(MAX_THREADS);
            self.wait_for_finish();
            self.kill_all();
            while self.size() < num {
                self.attach_thread();
            }
        }
    }

    /// Stops searching, joins every searcher thread, and destroys and frees every searcher.
    ///
    /// Blocks until all searchers report they are no longer searching. Thread panics are
    /// reported with `println!` instead of being propagated.
    pub fn kill_all(&mut self) {
        self.stop.store(true, Ordering::Relaxed);
        self.wait_for_finish();
        let mut join_handles = Vec::with_capacity(self.size());
        unsafe {
            // Raise every kill flag first, then wake the searchers, so each woken thread
            // already sees its flag set.
            self.threads.iter()
                .map(|s| &**s.get())
                .for_each(|s: &Searcher| { s.kill.store(true, Ordering::SeqCst) });
            self.threads.iter()
                .map(|s| &**s.get())
                .for_each(|s: &Searcher| { s.cond.set(); });
            // Join every thread before touching the searchers' memory. Results are collected
            // rather than unwrapped, so one failed thread does not leave the others un-joined.
            while let Some(handle) = self.handles.pop() {
                join_handles.push(handle.join());
            }
            // Every thread is joined, so nothing else references the searchers. Run each
            // destructor before releasing the allocation. Deallocating alone leaked whatever
            // the searcher owned (presumably including the `Arc<LockLatch>` clone handed to
            // `Searcher::new`).
            while let Some(unc) = self.threads.pop() {
                let th: *mut Searcher = *unc.get();
                ptr::drop_in_place(th);
                let raw: NonNull<u8> = mem::transmute(NonNull::new_unchecked(th));
                let layout = Layout::new::<Searcher>();
                Global.dealloc(raw, layout);
            }
        }
        while let Some(handle_result) = join_handles.pop() {
            handle_result.unwrap_or_else(|e| println!("Thread failed: {:?}",e));
        }
    }

    /// Stores `stop` into the pool's stop flag.
    #[inline(always)]
    pub fn set_stop(&mut self, stop: bool) {
        self.stop.store(stop, Ordering::Relaxed);
    }

    /// Blocks until every searcher reports it is not searching.
    pub fn wait_for_finish(&self) {
        self.await_search_cond(ThreadSelection::All, false);
    }

    /// Blocks until every searcher reports it is searching.
    pub fn wait_for_start(&self) {
        self.await_search_cond(ThreadSelection::All, true);
    }

    /// Blocks until every non-main searcher reports it is not searching.
    pub fn wait_for_non_main(&self) {
        self.await_search_cond(ThreadSelection::NonMain, false);
    }

    /// Blocks until the main searcher reports it is searching.
    pub fn wait_for_main_start(&self) {
        self.await_search_cond(ThreadSelection::Main, true);
    }

    /// Waits, one searcher at a time, until each selected searcher's `searching` flag
    /// equals `await_search`.
    fn await_search_cond(&self, thread_sel: ThreadSelection, await_search: bool) {
        self.threads.iter()
            .map(|s| unsafe {&**s.get()})
            .filter(|t| thread_sel.is_selection(t.id))
            .for_each(|t: &Searcher|{ t.searching.await(await_search); });
    }

    /// Calls `clear()` on every searcher.
    pub fn clear_all(&mut self) {
        self.threads.iter_mut()
            .map(|thread_ptr| unsafe { &mut **(*thread_ptr).get() })
            .for_each(|t| t.clear());
    }

    /// Starts a search of `board` under `limits` and returns once the main searcher has begun.
    ///
    /// It does not wait for the search to finish; use `search` for a blocking call. The global
    /// timer is initialised from `limits`, and every searcher gets a shallow clone of the board,
    /// a copy of `limits`, the root move list, and reset node/depth counters.
    ///
    /// # Panics
    /// Panics if `board.generate_moves()` returns no moves.
    pub fn uci_search(&mut self, board: &Board, limits: &Limits) {
        if let Some(uci_timer) = limits.use_time_management() {
            timer().init(limits.start, &uci_timer, board.turn(), board.moves_played());
        } else {
            timer().start_timer(limits.start);
        }
        let root_moves: MoveList = board.generate_moves();
        assert!(!root_moves.is_empty());
        self.wait_for_finish();
        self.stop.store(false, Ordering::Relaxed);
        for thread_ptr in self.threads.iter_mut() {
            let thread: &mut Searcher = unsafe {&mut **(*thread_ptr).get()};
            thread.nodes.store(0, Ordering::Relaxed);
            thread.depth_completed = 0;
            thread.board = board.shallow_clone();
            thread.limit = limits.clone();
            thread.root_moves().replace(&root_moves);
        }
        // Release the main searcher, wait until it has started, then re-lock the latch.
        self.main_cond.set();
        self.wait_for_main_start();
        self.main_cond.lock();
    }

    /// Runs a search to completion and returns the main searcher's best move.
    pub fn search(&mut self, board: &Board, limits: &Limits) -> BitMove {
        self.uci_search(board, limits);
        self.wait_for_finish();
        self.best_move()
    }

    /// Returns the move at position 0 of the main searcher's root move list.
    ///
    /// # Panics
    /// Panics if the pool is empty or the root move list is empty.
    pub fn best_move(&mut self) -> BitMove {
        self.main().root_moves().get(0).unwrap().bit_move
    }

    /// Sum of the `nodes` counters of all searchers.
    pub fn nodes(&self) -> u64 {
        self.threads.iter()
            .map(|s| unsafe {&**s.get()})
            .map(|s: &Searcher| s.nodes.load(Ordering::Relaxed))
            .sum()
    }
}
impl Drop for ThreadPool {
fn drop(&mut self) {
self.kill_all();
}
}