#[cfg(not(feature = "search-no-pass-rules"))]
use crate::eval::evaluate_pass_rights;
use crate::position::Position;
use crate::types::{Bound, Color, DEPTH_UNSEARCHED, Depth, MAX_PLY, Move, Value};
use super::alpha_beta::{
    EvalContext, ProbeOutcome, SearchContext, SearchState, TTContext, to_corrected_static_eval,
};
use super::history::CORRECTION_HISTORY_SIZE;
#[cfg(feature = "use-lazy-evaluate")]
use super::search_helpers::ensure_nnue_accumulator;
use super::search_helpers::nnue_evaluate;
use super::stats::inc_stat_by_depth;
#[cfg(feature = "tt-trace")]
use super::tt_sanity::{
    InvalidTtLog, TtCutoffTrace, TtProbeTrace, TtWriteTrace, helper_tt_write_enabled_for_depth,
    maybe_log_invalid_tt_data, maybe_trace_tt_cutoff, maybe_trace_tt_probe, maybe_trace_tt_write,
};
use super::tt_sanity::{is_valid_tt_eval, is_valid_tt_stored_value};
use super::types::{ContHistKey, NodeType, value_from_tt, value_to_tt};
#[inline]
pub(super) fn correction_value(
st: &SearchState,
ctx: &SearchContext<'_>,
pos: &Position,
ply: i32,
) -> i32 {
let us = pos.side_to_move();
let pawn_idx = (pos.pawn_key() as usize) & (CORRECTION_HISTORY_SIZE - 1);
let minor_idx = (pos.minor_piece_key() as usize) & (CORRECTION_HISTORY_SIZE - 1);
let non_pawn_idx_w = (pos.non_pawn_key(Color::White) as usize) & (CORRECTION_HISTORY_SIZE - 1);
let non_pawn_idx_b = (pos.non_pawn_key(Color::Black) as usize) & (CORRECTION_HISTORY_SIZE - 1);
let prev_move = if ply >= 1 {
debug_assert!(((ply - 1) as usize) < st.stack.len());
unsafe { st.stack.get_unchecked((ply - 1) as usize) }.current_move
} else {
Move::NONE
};
let move_ok = prev_move.is_normal();
let sentinel_key = ContHistKey::null_sentinel();
let cont_key_2 = if move_ok {
if ply >= 2 {
debug_assert!(((ply - 2) as usize) < st.stack.len());
match unsafe { st.stack.get_unchecked((ply - 2) as usize) }.cont_hist_key {
Some(key) => key,
None => sentinel_key,
}
} else {
sentinel_key
}
} else {
sentinel_key
};
let cont_key_4 = if move_ok {
if ply >= 4 {
debug_assert!(((ply - 4) as usize) < st.stack.len());
match unsafe { st.stack.get_unchecked((ply - 4) as usize) }.cont_hist_key {
Some(key) => key,
None => sentinel_key,
}
} else {
sentinel_key
}
} else {
sentinel_key
};
let h = unsafe { ctx.history.as_ref_unchecked() };
let pcv = h.correction_history.pawn_value(pawn_idx, us) as i32;
let micv = h.correction_history.minor_value(minor_idx, us) as i32;
let wnpcv = h.correction_history.non_pawn_value(non_pawn_idx_w, Color::White, us) as i32;
let bnpcv = h.correction_history.non_pawn_value(non_pawn_idx_b, Color::Black, us) as i32;
let cntcv = if move_ok {
let pc = pos.piece_on(prev_move.to());
let cv2 = h.correction_history.continuation_value(
cont_key_2.piece,
cont_key_2.to,
pc,
prev_move.to(),
) as i32;
let cv4 = h.correction_history.continuation_value(
cont_key_4.piece,
cont_key_4.to,
pc,
prev_move.to(),
) as i32;
cv2 + cv4
} else {
8
};
let tp = ctx.tune_params;
tp.correction_value_pcv_weight * pcv
+ tp.correction_value_micv_weight * micv
+ tp.correction_value_nonpawn_weight * (wnpcv + bnpcv)
+ tp.correction_value_cnt_weight * cntcv
}
#[inline]
pub(super) fn update_correction_history(
st: &SearchState,
ctx: &SearchContext<'_>,
pos: &Position,
ply: i32,
bonus: i32,
) {
let us = pos.side_to_move();
let pawn_idx = (pos.pawn_key() as usize) & (CORRECTION_HISTORY_SIZE - 1);
let minor_idx = (pos.minor_piece_key() as usize) & (CORRECTION_HISTORY_SIZE - 1);
let non_pawn_idx_w = (pos.non_pawn_key(Color::White) as usize) & (CORRECTION_HISTORY_SIZE - 1);
let non_pawn_idx_b = (pos.non_pawn_key(Color::Black) as usize) & (CORRECTION_HISTORY_SIZE - 1);
let prev_move = if ply >= 1 {
debug_assert!(((ply - 1) as usize) < st.stack.len());
unsafe { st.stack.get_unchecked((ply - 1) as usize) }.current_move
} else {
Move::NONE
};
let move_ok = prev_move.is_normal();
let sentinel_key = ContHistKey::null_sentinel();
let cont_key_2 = if ply >= 2 {
debug_assert!(((ply - 2) as usize) < st.stack.len());
match unsafe { st.stack.get_unchecked((ply - 2) as usize) }.cont_hist_key {
Some(key) => key,
None => sentinel_key,
}
} else {
sentinel_key
};
let cont_key_4 = if ply >= 4 {
debug_assert!(((ply - 4) as usize) < st.stack.len());
match unsafe { st.stack.get_unchecked((ply - 4) as usize) }.cont_hist_key {
Some(key) => key,
None => sentinel_key,
}
} else {
sentinel_key
};
let tp = ctx.tune_params;
let h = unsafe { ctx.history.as_mut_unchecked() };
h.correction_history.update_pawn(pawn_idx, us, bonus);
h.correction_history.update_minor(
minor_idx,
us,
bonus * tp.correction_history_minor_piece_mult / 128,
);
h.correction_history.update_non_pawn(
non_pawn_idx_w,
Color::White,
us,
bonus * tp.correction_history_non_pawn_weight / 128,
);
h.correction_history.update_non_pawn(
non_pawn_idx_b,
Color::Black,
us,
bonus * tp.correction_history_non_pawn_weight / 128,
);
if move_ok {
let pc = pos.piece_on(prev_move.to());
let prev_to = prev_move.to();
h.correction_history.update_continuation(
cont_key_2.piece,
cont_key_2.to,
pc,
prev_to,
bonus * tp.correction_history_cont_ss2_weight / 128,
);
h.correction_history.update_continuation(
cont_key_4.piece,
cont_key_4.to,
pc,
prev_to,
bonus * tp.correction_history_cont_ss4_weight / 128,
);
}
}
#[allow(clippy::too_many_arguments)]
pub(super) fn probe_transposition<const NT: u8>(
st: &mut SearchState,
ctx: &SearchContext<'_>,
pos: &mut Position,
depth: Depth,
beta: Value,
ply: i32,
pv_node: bool,
in_check: bool,
excluded_move: Move,
cut_node: bool,
) -> ProbeOutcome {
let key = pos.key();
let tt_result = ctx.tt.probe(key, pos);
let tt_hit = tt_result.found;
let mut tt_data = tt_result.data;
st.stack[ply as usize].tt_hit = tt_hit;
st.stack[ply as usize].tt_pv = if excluded_move.is_some() {
st.stack[ply as usize].tt_pv
} else {
pv_node || (tt_hit && tt_data.is_pv)
};
let tt_move = tt_data.mv;
let mut tt_value = if tt_hit {
value_from_tt(tt_data.value, ply)
} else {
Value::NONE
};
if tt_hit && !is_valid_tt_stored_value(tt_data.value) {
#[cfg(feature = "tt-trace")]
maybe_log_invalid_tt_data(InvalidTtLog {
reason: "invalid_value",
stage: "ab_probe",
thread_id: ctx.thread_id,
ply,
key,
depth: tt_data.depth,
bound: tt_data.bound,
tt_move,
stored_value: tt_data.value,
converted_value: tt_value,
eval: tt_data.eval,
});
tt_value = Value::NONE;
}
if tt_hit && !is_valid_tt_eval(tt_data.eval) {
#[cfg(feature = "tt-trace")]
maybe_log_invalid_tt_data(InvalidTtLog {
reason: "invalid_eval",
stage: "ab_probe",
thread_id: ctx.thread_id,
ply,
key,
depth: tt_data.depth,
bound: tt_data.bound,
tt_move,
stored_value: tt_data.value,
converted_value: tt_value,
eval: tt_data.eval,
});
tt_data.eval = Value::NONE;
}
#[cfg(feature = "tt-trace")]
maybe_trace_tt_probe(TtProbeTrace {
stage: "ab_probe",
thread_id: ctx.thread_id,
ply,
key,
hit: tt_hit,
depth: tt_data.depth,
bound: tt_data.bound,
tt_move,
stored_value: tt_data.value,
converted_value: tt_value,
eval: tt_data.eval,
root_move: if ply >= 1 {
st.stack[0].current_move
} else {
Move::NONE
},
});
let tt_capture = tt_move.is_some() && pos.capture_stage(tt_move);
inc_stat_by_depth!(st, tt_probe_by_depth, depth);
if tt_hit {
inc_stat_by_depth!(st, tt_hit_by_depth, depth);
}
let tt_value_lte_beta = tt_value != Value::NONE && tt_value.raw() <= beta.raw();
if !pv_node
&& excluded_move.is_none()
&& tt_hit
&& tt_data.depth > depth - tt_value_lte_beta as i32
&& tt_value != Value::NONE
&& tt_data.bound.can_cutoff(tt_value, beta)
&& (cut_node == (tt_value.raw() >= beta.raw()) || depth > 5)
{
#[cfg(feature = "tt-trace")]
maybe_trace_tt_cutoff(TtCutoffTrace {
stage: "ab_probe_cutoff",
thread_id: ctx.thread_id,
ply,
key,
search_depth: depth,
depth: tt_data.depth,
bound: tt_data.bound,
value: tt_value,
beta,
root_move: if ply >= 1 {
st.stack[0].current_move
} else {
Move::NONE
},
});
return ProbeOutcome::Cutoff {
value: tt_value,
tt_move,
tt_capture,
};
}
#[cfg(feature = "search-stats")]
if !pv_node && excluded_move.is_none() && tt_hit && tt_value != Value::NONE {
if tt_data.depth <= depth - tt_value_lte_beta as i32 {
inc_stat_by_depth!(st, tt_fail_depth_by_depth, depth);
} else if !tt_data.bound.can_cutoff(tt_value, beta) {
inc_stat_by_depth!(st, tt_fail_bound_by_depth, depth);
}
}
if NT != NodeType::Root as u8 && !in_check && !tt_hit && excluded_move.is_none() {
let mate_move = pos.mate_1ply();
if mate_move.is_some() {
let value = Value::mate_in(ply + 1);
let mate1_depth_boost = {
use std::sync::LazyLock;
static BOOST: LazyLock<i32> = LazyLock::new(|| {
std::env::var("RS_MATE1_DEPTH_BOOST")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(6)
});
*BOOST
};
let stored_depth = (depth + mate1_depth_boost).min(MAX_PLY - 1);
#[cfg(feature = "tt-trace")]
let allow_write = ctx.allow_tt_write
&& helper_tt_write_enabled_for_depth(ctx.thread_id, Bound::Exact, stored_depth);
#[cfg(not(feature = "tt-trace"))]
let allow_write = ctx.allow_tt_write;
if allow_write {
#[cfg(feature = "tt-trace")]
maybe_trace_tt_write(TtWriteTrace {
stage: "ab_mate1_store",
thread_id: ctx.thread_id,
ply,
key,
depth: stored_depth,
bound: Bound::Exact,
is_pv: st.stack[ply as usize].tt_pv,
tt_move: mate_move,
stored_value: value,
eval: Value::NONE,
root_move: if ply >= 1 {
st.stack[0].current_move
} else {
Move::NONE
},
});
tt_result.write(
key,
value,
st.stack[ply as usize].tt_pv,
Bound::Exact,
stored_depth,
mate_move,
Value::NONE,
ctx.tt.generation(),
);
inc_stat_by_depth!(st, tt_write_by_depth, stored_depth);
}
return ProbeOutcome::Cutoff {
value,
tt_move: Move::NONE,
tt_capture: false,
};
}
}
ProbeOutcome::Continue(TTContext {
key,
result: tt_result,
data: tt_data,
hit: tt_hit,
mv: tt_move,
value: tt_value,
capture: tt_capture,
})
}
#[allow(clippy::too_many_arguments)]
pub(super) fn compute_eval_context(
st: &mut SearchState,
ctx: &SearchContext<'_>,
pos: &mut Position,
ply: i32,
in_check: bool,
pv_node: bool,
tt_ctx: &TTContext,
excluded_move: Move,
) -> EvalContext {
let corr_value = correction_value(st, ctx, pos, ply);
if excluded_move.is_some() {
let static_eval = st.stack[ply as usize].static_eval;
let prev2_eval = if ply >= 2 {
st.stack[(ply - 2) as usize].static_eval
} else {
Value::NONE
};
let prev_eval = if ply >= 1 {
st.stack[(ply - 1) as usize].static_eval
} else {
Value::NONE
};
let improving = static_eval > prev2_eval;
let opponent_worsening = static_eval > -prev_eval;
return EvalContext {
eval: static_eval,
static_eval,
unadjusted_static_eval: static_eval, correction_value: corr_value,
improving,
opponent_worsening,
};
}
let mut unadjusted_static_eval = Value::NONE;
let mut static_eval = if in_check {
if ply >= 2 {
st.stack[(ply - 2) as usize].static_eval
} else {
Value::NONE
}
} else if tt_ctx.hit && tt_ctx.data.eval != Value::NONE && !pv_node {
#[cfg(feature = "use-lazy-evaluate")]
{
ensure_nnue_accumulator(st, pos);
unadjusted_static_eval = tt_ctx.data.eval;
}
#[cfg(not(feature = "use-lazy-evaluate"))]
{
unadjusted_static_eval = nnue_evaluate(st, pos);
}
unadjusted_static_eval
} else {
unadjusted_static_eval = nnue_evaluate(st, pos);
unadjusted_static_eval
};
if !in_check && unadjusted_static_eval != Value::NONE {
static_eval = to_corrected_static_eval(unadjusted_static_eval, corr_value);
let pass_rights_eval = {
#[cfg(feature = "search-no-pass-rules")]
{
Value::ZERO
}
#[cfg(not(feature = "search-no-pass-rules"))]
{
evaluate_pass_rights(pos, pos.game_ply() as u16)
}
};
static_eval += pass_rights_eval;
}
#[cfg(feature = "tt-trace")]
let eval_allow_write = !in_check
&& !tt_ctx.hit
&& ctx.allow_tt_write
&& helper_tt_write_enabled_for_depth(ctx.thread_id, Bound::None, DEPTH_UNSEARCHED);
#[cfg(not(feature = "tt-trace"))]
let eval_allow_write = !in_check && !tt_ctx.hit && ctx.allow_tt_write;
if eval_allow_write {
#[cfg(feature = "tt-trace")]
maybe_trace_tt_write(TtWriteTrace {
stage: "ab_eval_store_none",
thread_id: ctx.thread_id,
ply,
key: tt_ctx.key,
depth: DEPTH_UNSEARCHED,
bound: Bound::None,
is_pv: st.stack[ply as usize].tt_pv,
tt_move: Move::NONE,
stored_value: Value::NONE,
eval: unadjusted_static_eval,
root_move: if ply >= 1 {
st.stack[0].current_move
} else {
Move::NONE
},
});
tt_ctx.result.write(
tt_ctx.key,
Value::NONE,
st.stack[ply as usize].tt_pv,
Bound::None,
DEPTH_UNSEARCHED,
Move::NONE,
unadjusted_static_eval,
ctx.tt.generation(),
);
inc_stat_by_depth!(st, tt_write_by_depth, 0);
}
let mut eval = static_eval;
if !in_check && tt_ctx.hit && tt_ctx.value != Value::NONE && {
if tt_ctx.value > eval {
tt_ctx.data.bound.is_lower_or_exact()
} else {
matches!(tt_ctx.data.bound, Bound::Upper | Bound::Exact)
}
} {
eval = tt_ctx.value;
}
st.stack[ply as usize].static_eval = static_eval;
let prev2_eval = if ply >= 2 {
st.stack[(ply - 2) as usize].static_eval
} else {
Value::NONE
};
let prev_eval = if ply >= 1 {
st.stack[(ply - 1) as usize].static_eval
} else {
Value::NONE
};
let improving = static_eval > prev2_eval;
let opponent_worsening = static_eval > -prev_eval;
EvalContext {
eval,
static_eval,
unadjusted_static_eval,
correction_value: corr_value,
improving,
opponent_worsening,
}
}