use crate::pattern_table::PATTERN_NUM_IDS;
use noru::trainer::SimpleRng;
use std::sync::OnceLock;
/// Width (number of `f32` components) of every dense pattern-embedding row.
pub const PATTERN4_DENSE_DIM: usize = 64;

/// Fixed seed for the embedding table's RNG.
///
/// Written as bytes for readability: its big-endian bytes spell ASCII `OMOKU051`,
/// i.e. the same value as `0x4F4D_4F4B_5530_3531`.
const EMBEDDING_SEED: u64 = u64::from_be_bytes(*b"OMOKU051");

/// Dense embedding table holding one `PATTERN4_DENSE_DIM`-wide row per pattern id.
///
/// Rows are stored back to back in a single flat buffer, so row `pid` starts at
/// offset `pid * PATTERN4_DENSE_DIM`.
pub struct Pattern4Embedding {
    /// Flattened row-major storage for all rows.
    table: Vec<f32>,
}
impl Pattern4Embedding {
    /// Borrows the `PATTERN4_DENSE_DIM`-long embedding row for pattern id `pid`.
    ///
    /// # Panics
    ///
    /// Panics if `pid` addresses a row past the end of the table.
    #[inline]
    pub fn row(&self, pid: u16) -> &[f32] {
        let begin = usize::from(pid) * PATTERN4_DENSE_DIM;
        let end = begin + PATTERN4_DENSE_DIM;
        &self.table[begin..end]
    }

    /// Builds a `PATTERN_NUM_IDS x PATTERN4_DENSE_DIM` table from a fixed-seed RNG.
    ///
    /// Every entry is one `SimpleRng::next_normal` draw, taken in storage order and
    /// multiplied by `sqrt(2 / PATTERN4_DENSE_DIM) * 0.1`. Because the seed is the
    /// constant `EMBEDDING_SEED`, repeated builds reproduce the same table,
    /// assuming `SimpleRng` is deterministic per seed (not visible here).
    fn build_random() -> Self {
        let scale = (2.0f32 / PATTERN4_DENSE_DIM as f32).sqrt() * 0.1;
        let mut rng = SimpleRng::new(EMBEDDING_SEED);
        let table: Vec<f32> = (0..PATTERN_NUM_IDS * PATTERN4_DENSE_DIM)
            .map(|_| rng.next_normal() * scale)
            .collect();
        Self { table }
    }
}
/// Returns the process-wide embedding table, building it on first use.
///
/// The table is created once by `Pattern4Embedding::build_random`. Every later
/// call, from any thread, receives a reference to that same `'static` instance.
pub fn embedding() -> &'static Pattern4Embedding {
    static SHARED: OnceLock<Pattern4Embedding> = OnceLock::new();
    match SHARED.get() {
        Some(ready) => ready,
        None => SHARED.get_or_init(Pattern4Embedding::build_random),
    }
}
/// Sum-pools the embedding rows of every pattern id into one dense feature vector.
///
/// Each element of `line_pattern_ids` carries four pattern ids. The embedding row
/// of every id, taken from the shared [`embedding`] table, is added element-wise
/// into the result. The output is therefore the sum of `4 * line_pattern_ids.len()`
/// rows, and all zeros for an empty slice.
///
/// # Panics
///
/// Panics if any id lies outside the embedding table (see [`Pattern4Embedding::row`]).
pub fn pool_dense_input(line_pattern_ids: &[[u16; 4]]) -> [f32; PATTERN4_DENSE_DIM] {
    let emb = embedding();
    let mut out = [0.0f32; PATTERN4_DENSE_DIM];
    for &pid in line_pattern_ids.iter().flatten() {
        // Zip instead of indexing `out[k]` / `row[k]`. This avoids per-element
        // bounds checks and keeps the per-component summation order unchanged,
        // so results stay bit-identical.
        for (acc, &x) in out.iter_mut().zip(emb.row(pid)) {
            *acc += x;
        }
    }
    out
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::board::{to_idx, Board};

    #[test]
    fn embedding_table_has_expected_shape() {
        let emb = embedding();
        assert_eq!(emb.table.len(), PATTERN_NUM_IDS * PATTERN4_DENSE_DIM);
        // Ids are `u16`. Check the id space fits before casting, otherwise
        // `PATTERN_NUM_IDS as u16` could wrap (e.g. 65536 -> 0) and skip the loop.
        assert!(
            PATTERN_NUM_IDS <= 1 << 16,
            "pattern ids must be addressable as u16"
        );
        for pid in 0..PATTERN_NUM_IDS {
            assert_eq!(emb.row(pid as u16).len(), PATTERN4_DENSE_DIM);
        }
    }

    #[test]
    fn embedding_is_deterministic_across_threads() {
        // The seeded generator must reproduce the same table on another thread.
        let local = Pattern4Embedding::build_random();
        let remote = std::thread::spawn(Pattern4Embedding::build_random)
            .join()
            .expect("embedding build panicked on worker thread");
        assert_eq!(local.table, remote.table);

        // The shared accessor must hand every thread the very same instance,
        // and that instance must match a fresh build.
        let shared_remote = std::thread::spawn(embedding)
            .join()
            .expect("embedding() panicked on worker thread");
        assert!(std::ptr::eq(embedding(), shared_remote));
        assert_eq!(embedding().table, local.table);
    }

    #[test]
    fn pool_dense_input_empty_board_matches_full_recompute() {
        let board = Board::new();
        let ids = &board.line_pattern_ids[..];

        // Independent reference: add every id's row directly, in the same order,
        // so the float sums are bit-identical.
        let emb = embedding();
        let mut expected = [0.0f32; PATTERN4_DENSE_DIM];
        for cell_dirs in ids {
            for &pid in cell_dirs {
                for (acc, &x) in expected.iter_mut().zip(emb.row(pid)) {
                    *acc += x;
                }
            }
        }
        assert_eq!(pool_dense_input(ids), expected);

        // Repeated calls on unchanged input must agree exactly.
        assert_eq!(pool_dense_input(ids), pool_dense_input(ids));
    }

    #[test]
    fn pool_dense_input_changes_after_a_move() {
        let mut board = Board::new();
        let baseline = pool_dense_input(&board.line_pattern_ids[..]);
        board.make_move(to_idx(7, 7));
        let after_move = pool_dense_input(&board.line_pattern_ids[..]);
        assert_ne!(baseline, after_move);
    }
}