use executor::SelectionExecutor;
use rand::prelude::*;
use rayon::prelude::*;
use crate::{
  execution::strategy::*,
  operator::{
    tag::SelectionOperatorTag,
    ParBatch,
    ParBatchOperator,
    ParEach,
    ParEachOperator,
  },
  score::Scores,
  Score,
};
/// A per-solution predicate deciding whether a single solution is kept.
///
/// Implementors receive one solution together with its scores and return
/// `true` to keep it. Any `Fn(&S, &Scores<N>) -> bool` closure or function is
/// a `Selection` through the blanket impl that follows.
pub trait Selection<S, const N: usize> {
  /// Returns `true` if `solution` (scored by `scores`) should be selected.
  fn select(&self, solution: &S, scores: &Scores<N>) -> bool;
}
/// Lets plain closures and functions act as a [`Selection`].
impl<S, const N: usize, F> Selection<S, N> for F
where
  F: Fn(&S, &Scores<N>) -> bool,
{
  fn select(&self, solution: &S, scores: &Scores<N>) -> bool {
    let predicate = self;
    predicate(solution, scores)
  }
}
// Opt every `Selection` into `ParEach`, which (per the tests' `.par_each()`
// calls) lets it be wrapped for the parallel-each execution strategy.
// NOTE(review): the meaning of the `0` const argument is defined by `ParEach`
// outside this file — confirm.
impl<S, const N: usize, L> ParEach<SelectionOperatorTag, S, N, 0> for L where
  L: Selection<S, N>
{
}
// Opt every `Selection` into `ParBatch`, which (per the tests' `.par_batch()`
// calls) lets it be wrapped for the parallel-batch execution strategy.
impl<S, const N: usize, L> ParBatch<SelectionOperatorTag, S, N> for L where
  L: Selection<S, N>
{
}
/// A population-level selection strategy.
///
/// Unlike [`Selection`], which judges one solution at a time, a `Selector`
/// sees every solution and all scores at once and returns references to the
/// chosen solutions. Any function or closure with the matching higher-ranked
/// signature is a `Selector` through the blanket impl that follows.
pub trait Selector<S, const N: usize> {
  /// Returns references (borrowed from `solutions`) to the chosen solutions.
  fn select<'a>(&self, solutions: &'a [S], scores: &[Scores<N>]) -> Vec<&'a S>;
}
/// Lets plain functions and closures act as a [`Selector`].
impl<S, const N: usize, F> Selector<S, N> for F
where
  F: for<'a> Fn(&'a [S], &[Scores<N>]) -> Vec<&'a S>,
{
  fn select<'a>(&self, solutions: &'a [S], scores: &[Scores<N>]) -> Vec<&'a S> {
    let strategy = self;
    strategy(solutions, scores)
  }
}
/// Crate-internal home of the strategy-dispatch trait for selection.
pub(crate) mod executor {
  use crate::score::Scores;
  /// Runs a selection operator over a population under the execution
  /// strategy named by the `ExecutionStrategy` type parameter.
  ///
  /// No value of `ExecutionStrategy` is ever passed; the parameter only
  /// distinguishes the per-strategy impls in the parent module (custom,
  /// sequential, parallel-each, parallel-batch) so their blanket impls do not
  /// overlap.
  pub trait SelectionExecutor<S, const N: usize, ExecutionStrategy> {
    /// Returns references to the selected entries of `solutions`. The
    /// per-solution impls in the parent module pair `scores` with
    /// `solutions` positionally.
    fn execute_selection<'a>(
      &self,
      solutions: &'a [S],
      scores: &[Scores<N>],
    ) -> Vec<&'a S>;
  }
}
/// A [`Selector`] runs as-is: the whole population and all scores are handed
/// to [`Selector::select`] in a single call.
impl<S, const N: usize, L> SelectionExecutor<S, N, CustomExecutionStrategy>
  for L
where
  L: Selector<S, N>,
{
  fn execute_selection<'a>(
    &self,
    solutions: &'a [S],
    scores: &[Scores<N>],
  ) -> Vec<&'a S> {
    <L as Selector<S, N>>::select(self, solutions, scores)
  }
}
/// A per-solution [`Selection`] runs sequentially: each solution is paired
/// with its scores and kept, in input order, when the predicate is `true`.
impl<S, const N: usize, L> SelectionExecutor<S, N, SequentialExecutionStrategy>
  for L
where
  L: Selection<S, N>,
{
  fn execute_selection<'a>(
    &self,
    solutions: &'a [S],
    scores: &[Scores<N>],
  ) -> Vec<&'a S> {
    // `zip` stops at the shorter slice, so unscored solutions are never kept.
    let mut kept = Vec::new();
    for (solution, solution_scores) in solutions.iter().zip(scores) {
      if <L as Selection<S, N>>::select(self, solution, solution_scores) {
        kept.push(solution);
      }
    }
    kept
  }
}
/// A [`Selection`] wrapped in a [`ParEachOperator`] evaluates the predicate
/// for every (solution, scores) pair as an individual rayon work item; the
/// collected result keeps input order.
impl<S, const N: usize, L>
  SelectionExecutor<S, N, ParallelEachExecutionStrategy>
  for ParEachOperator<SelectionOperatorTag, S, L>
where
  S: Sync,
  L: Selection<S, N> + Sync,
{
  fn execute_selection<'a>(
    &self,
    solutions: &'a [S],
    scores: &[Scores<N>],
  ) -> Vec<&'a S> {
    solutions
      .par_iter()
      .zip(scores.par_iter())
      .filter(|&(solution, solution_scores)| {
        self.operator().select(solution, solution_scores)
      })
      .map(|(solution, _)| solution)
      .collect()
  }
}
/// A [`Selection`] wrapped in a [`ParBatchOperator`] is evaluated in parallel
/// over contiguous chunks of the population: each rayon task runs the
/// predicate sequentially over its chunk instead of scheduling one task per
/// solution. The result keeps input order, like the other strategies.
impl<S, const N: usize, L>
  SelectionExecutor<S, N, ParallelBatchExecutionStrategy>
  for ParBatchOperator<SelectionOperatorTag, S, L>
where
  S: Sync,
  L: Selection<S, N> + Sync,
{
  fn execute_selection<'a>(
    &self,
    solutions: &'a [S],
    scores: &[Scores<N>],
  ) -> Vec<&'a S> {
    // At most one chunk per worker thread (rounding up, so e.g. 7 items on 4
    // threads give 4 chunks rather than 7). `max(1)` because a zero chunk
    // size panics, which would otherwise happen for an empty population.
    let chunk_size = solutions
      .len()
      .div_ceil(rayon::current_num_threads())
      .max(1);
    // Indexed `par_chunks` (unlike `par_bridge`, which hands items out in
    // arbitrary order) keeps the collected output in input order.
    solutions
      .par_chunks(chunk_size)
      .zip(scores.par_chunks(chunk_size))
      .flat_map_iter(|(solution_chunk, score_chunk)| {
        solution_chunk.iter().zip(score_chunk).filter_map(|(sol, sc)| {
          self.operator().select(sol, sc).then_some(sol)
        })
      })
      .collect()
  }
}
/// Selects every solution, in input order, ignoring scores.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug, Default)]
pub struct AllSelector();
impl<const N: usize, S> Selector<S, N> for AllSelector {
  fn select<'a>(&self, solutions: &'a [S], _: &[Scores<N>]) -> Vec<&'a S> {
    let mut everything: Vec<&'a S> = Vec::with_capacity(solutions.len());
    everything.extend(solutions.iter());
    everything
  }
}
/// Selects the first `self.0` solutions in input order (all of them when
/// there are fewer), ignoring scores.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct FirstSelector(pub usize);
impl<const N: usize, S> Selector<S, N> for FirstSelector {
  fn select<'a>(&self, solutions: &'a [S], _: &[Scores<N>]) -> Vec<&'a S> {
    let count = self.0.min(solutions.len());
    solutions[..count].iter().collect()
  }
}
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct RandomSelector(pub usize);
impl<const N: usize, S> Selector<S, N> for RandomSelector {
  fn select<'a>(&self, solutions: &'a [S], _: &[Scores<N>]) -> Vec<&'a S> {
    solutions
      .iter()
      .choose_multiple(&mut rand::thread_rng(), self.0)
  }
}
/// Selects the `self.0` solutions whose scores have the smallest sum of
/// absolute values, returned in ascending order of that sum (ties keep input
/// order). When there are at most `self.0` solutions, all of them are
/// returned in input order without looking at the scores.
///
/// # Panics
/// Panics with "NaN encountered" if two sums cannot be compared (a NaN sum).
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct BestSelector(pub usize);
impl<const N: usize, S> Selector<S, N> for BestSelector {
  fn select<'a>(&self, solutions: &'a [S], scores: &[Scores<N>]) -> Vec<&'a S> {
    let limit = self.0;
    if solutions.len() > limit {
      let mut ranked: Vec<(&'a S, Score)> = Vec::new();
      for (solution, sc) in solutions.iter().zip(scores) {
        let total: Score = sc.map(Score::abs).iter().sum();
        ranked.push((solution, total));
      }
      // Stable sort so equal totals keep their input order.
      ranked.sort_by(|(_, lhs), (_, rhs)| {
        lhs.partial_cmp(rhs).expect("NaN encountered")
      });
      ranked.truncate(limit);
      ranked.into_iter().map(|(solution, _)| solution).collect()
    } else {
      solutions.iter().collect()
    }
  }
}
/// Tournament selection: shuffles the population, splits it into consecutive
/// tournaments of `self.0` solutions (the last one may be smaller) and keeps
/// one winner per tournament, namely the solution whose scores have the
/// smallest sum of absolute values (the first such one on ties).
///
/// A tournament size of `0` selects nothing.
///
/// # Panics
/// Panics with "NaN encountered" if two sums in the same tournament cannot be
/// compared (a NaN sum).
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct TournamentSelector(pub usize);
impl<const N: usize, S> Selector<S, N> for TournamentSelector {
  fn select<'a>(&self, solutions: &'a [S], scores: &[Scores<N>]) -> Vec<&'a S> {
    // `chunks(0)` panics; an empty tournament has no winner, so select
    // nothing (consistent with the other selectors given a count of 0).
    if self.0 == 0 {
      return Vec::new();
    }
    let mut sol_sc = solutions
      .iter()
      .zip(
        scores
          .iter()
          .map(|sc| sc.map(Score::abs).iter().sum::<Score>()),
      )
      .collect::<Vec<_>>();
    sol_sc.shuffle(&mut rand::thread_rng());
    sol_sc
      .chunks(self.0)
      .filter_map(|chunk| {
        chunk
          .iter()
          .min_by(|a, b| a.1.partial_cmp(&b.1).expect("NaN encountered"))
          .map(|ch| ch.0)
      })
      .collect()
  }
}
#[cfg(test)]
mod tests {
  use super::*;
  type Solution = f32;
  /// Compile-time check that `L` is usable as a selection executor for some
  /// strategy, plus a smoke run on an empty population.
  fn takes_selector<ES, L: SelectionExecutor<Solution, 3, ES>>(l: &L) {
    l.execute_selection(&[], &[]);
  }
  #[test]
  fn test_selection_from_closure() {
    let selection = |_: &Solution, _: &Scores<3>| true;
    takes_selector(&selection);
    takes_selector(&selection.par_each());
    takes_selector(&selection.par_batch());
  }
  #[test]
  fn test_selector_from_fn() {
    fn selector<'a>(
      solutions: &'a [Solution],
      _: &[Scores<3>],
    ) -> Vec<&'a Solution> {
      solutions.iter().collect()
    }
    takes_selector(&selector);
  }
  #[test]
  fn test_custom_selection() {
    #[derive(Clone, Copy)]
    struct CustomSelection {}
    impl<S> Selection<S, 3> for CustomSelection {
      fn select(&self, _: &S, _: &Scores<3>) -> bool {
        true
      }
    }
    let selection = CustomSelection {};
    takes_selector(&selection);
    takes_selector(&selection.par_each());
    takes_selector(&selection.par_batch());
  }
  #[test]
  fn test_custom_selector() {
    #[derive(Clone, Copy)]
    struct CustomSelector {}
    impl<S> Selector<S, 3> for CustomSelector {
      fn select<'a>(&self, solutions: &'a [S], _: &[Scores<3>]) -> Vec<&'a S> {
        solutions.iter().collect()
      }
    }
    let selector = CustomSelector {};
    takes_selector(&selector);
  }
  #[test]
  fn test_all_selector() {
    let selector = AllSelector();
    takes_selector(&selector);
    let solutions: [Solution; 3] = [1.0, 2.0, 3.0];
    let selected = Selector::<Solution, 3>::select(&selector, &solutions, &[]);
    assert_eq!(selected, solutions.iter().collect::<Vec<_>>());
  }
  #[test]
  fn test_first_selector() {
    let selector = FirstSelector(10);
    takes_selector(&selector);
    let solutions: [Solution; 3] = [1.0, 2.0, 3.0];
    let first_two =
      Selector::<Solution, 3>::select(&FirstSelector(2), &solutions, &[]);
    assert_eq!(first_two, vec![&solutions[0], &solutions[1]]);
    // Asking for more than are available returns everything.
    let all = Selector::<Solution, 3>::select(&selector, &solutions, &[]);
    assert_eq!(all.len(), solutions.len());
  }
  #[test]
  fn test_random_selector() {
    let selector = RandomSelector(10);
    takes_selector(&selector);
    let solutions: [Solution; 3] = [1.0, 2.0, 3.0];
    let two =
      Selector::<Solution, 3>::select(&RandomSelector(2), &solutions, &[]);
    assert_eq!(two.len(), 2);
    assert!(two.iter().all(|s| solutions.contains(s)));
    // Asking for more than are available returns everything.
    let all = Selector::<Solution, 3>::select(&selector, &solutions, &[]);
    assert_eq!(all.len(), solutions.len());
  }
  #[test]
  fn test_best_selector() {
    let selector = BestSelector(10);
    takes_selector(&selector);
  }
  #[test]
  fn test_tournament_selector() {
    let selector = TournamentSelector(10);
    takes_selector(&selector);
  }
}