/// Activation function selected for the network.
///
/// NOTE(review): the variant names suggest clipped ReLU and squared clipped ReLU; the
/// clipping bounds themselves are not defined here — confirm against the inference code.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Activation {
    /// Clipped ReLU (presumably `clamp(x, 0, Q)`; bounds defined elsewhere).
    CReLU,
    /// Squared clipped ReLU (presumably `clamp(x, 0, Q)^2`; bounds defined elsewhere).
    SCReLU,
}
/// `Copy`able description of an NNUE network's shape.
///
/// The hidden-layer widths are held as a `&'static [usize]`. That slice is either a literal or
/// static array, or one obtained by leaking an `OwnedNnueConfig`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct NnueConfig {
    /// Number of input features.
    pub feature_size: usize,
    /// Width of a single accumulator; the first hidden layer consumes twice this width
    /// (presumably one accumulator per perspective — confirm).
    pub accumulator_size: usize,
    /// Output width of each hidden layer, in layer order.
    pub hidden_sizes: &'static [usize],
    /// Activation function selected for the network.
    pub activation: Activation,
}

impl NnueConfig {
    /// Width of the concatenated accumulator vector fed into the first hidden layer.
    #[inline]
    pub fn concat_size(&self) -> usize {
        2 * self.accumulator_size
    }

    /// Number of hidden layers, i.e. the length of `hidden_sizes`.
    #[inline]
    pub fn num_hidden_layers(&self) -> usize {
        self.hidden_sizes.len()
    }

    /// Input width of the layer at `layer_idx`.
    ///
    /// Layer 0 consumes the concatenated accumulators ([`Self::concat_size`]). Every later
    /// layer consumes the output of the hidden layer before it. `layer_idx` equal to
    /// [`Self::num_hidden_layers`] is accepted and yields the width of the last hidden layer.
    ///
    /// # Panics
    ///
    /// Panics if `layer_idx > self.num_hidden_layers()`.
    #[inline]
    pub fn layer_input_size(&self, layer_idx: usize) -> usize {
        match layer_idx {
            0 => self.concat_size(),
            later => self.hidden_sizes[later - 1],
        }
    }

    /// Width of the final hidden layer.
    ///
    /// # Panics
    ///
    /// Panics if `hidden_sizes` is empty.
    #[inline]
    pub fn last_hidden_size(&self) -> usize {
        self.hidden_sizes
            .last()
            .copied()
            .expect("hidden_sizes must not be empty")
    }
}
/// Heap-owning counterpart of [`NnueConfig`], for network shapes only known at runtime.
///
/// The hidden-layer widths live in a `Vec<usize>` instead of a `'static` slice. Use
/// [`OwnedNnueConfig::leak`] to obtain a `Copy`able [`NnueConfig`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OwnedNnueConfig {
    /// Number of input features.
    pub feature_size: usize,
    /// Width of a single accumulator.
    pub accumulator_size: usize,
    /// Output width of each hidden layer, in layer order.
    pub hidden_sizes: Vec<usize>,
    /// Activation function selected for the network.
    pub activation: Activation,
}

impl OwnedNnueConfig {
    /// Builds an owned configuration from its parts. No validation is performed.
    #[must_use]
    pub fn new(
        feature_size: usize,
        accumulator_size: usize,
        hidden_sizes: Vec<usize>,
        activation: Activation,
    ) -> Self {
        Self { feature_size, accumulator_size, hidden_sizes, activation }
    }

    /// Converts into a [`NnueConfig`] by leaking `hidden_sizes` to get a `'static` slice.
    ///
    /// The leaked allocation is never freed unless the returned config's `hidden_sizes` is
    /// later passed to `reclaim_leaked_hidden_sizes`. Dropping the return value therefore
    /// loses the allocation for the rest of the process, hence `#[must_use]`.
    #[must_use = "discarding the result leaks `hidden_sizes` with no way to reclaim it"]
    pub fn leak(self) -> NnueConfig {
        // Deliberately go through `Box<[usize]>` instead of `Vec::leak`.
        // `into_boxed_slice` shrinks the allocation to exactly `len` elements. That is
        // required for later rebuilding a `Box<[usize]>` from the slice and freeing it.
        let hidden_sizes: &'static [usize] = Box::leak(self.hidden_sizes.into_boxed_slice());
        NnueConfig {
            feature_size: self.feature_size,
            accumulator_size: self.accumulator_size,
            hidden_sizes,
            activation: self.activation,
        }
    }
}
/// Frees a `hidden_sizes` slice previously leaked by `OwnedNnueConfig::leak`.
///
/// # Safety
///
/// The caller must guarantee all of the following:
///
/// * `hidden_sizes` was obtained by leaking a `Box<[usize]>` with `Box::leak`, as
///   `OwnedNnueConfig::leak` does. It must be that exact slice: the same pointer and the
///   same length. It must not be a `static` array, a sub-slice, or memory from any other
///   allocation.
/// * The slice has not already been reclaimed.
/// * No reference to the slice is used after this call. `NnueConfig` is `Copy`, so every
///   copy of a config that shares this slice becomes dangling once this returns.
pub unsafe fn reclaim_leaked_hidden_sizes(hidden_sizes: &'static [usize]) {
    let raw = hidden_sizes as *const [usize] as *mut [usize];
    // SAFETY: by this function's contract, `raw` is the unmodified pointer/length pair that
    // `Box::leak` produced from a `Box<[usize]>` using the global allocator. It has not been
    // freed yet, and nothing observes it afterwards, so rebuilding the `Box` and dropping it
    // releases the allocation exactly once.
    drop(unsafe { Box::from_raw(raw) });
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Leaking an owned config keeps every field and the derived sizes intact,
    /// and the leaked slice can be handed back for deallocation.
    #[test]
    fn owned_config_leak_preserves_fields() {
        let owned = OwnedNnueConfig::new(768, 256, vec![256, 32, 32], Activation::SCReLU);
        let config = owned.leak();
        assert_eq!(config.feature_size, 768);
        assert_eq!(config.accumulator_size, 256);
        assert_eq!(config.hidden_sizes, &[256, 32, 32]);
        assert_eq!(config.activation, Activation::SCReLU);
        assert_eq!(config.concat_size(), 512);
        assert_eq!(config.num_hidden_layers(), 3);
        assert_eq!(config.last_hidden_size(), 32);
        // SAFETY: `config.hidden_sizes` comes straight from `OwnedNnueConfig::leak`. It is
        // reclaimed exactly once, and neither `config` nor any copy of it is used afterwards.
        unsafe { reclaim_leaked_hidden_sizes(config.hidden_sizes) };
    }

    /// Layer 0 reads the concatenated accumulators; each later layer reads the previous
    /// hidden layer's width, including the index one past the last hidden layer.
    #[test]
    fn layer_input_sizes_chain_through_hidden_layers() {
        let config = NnueConfig {
            feature_size: 768,
            accumulator_size: 256,
            hidden_sizes: &[16, 32, 8],
            activation: Activation::CReLU,
        };
        assert_eq!(config.layer_input_size(0), 512);
        assert_eq!(config.layer_input_size(1), 16);
        assert_eq!(config.layer_input_size(2), 32);
        assert_eq!(config.layer_input_size(3), 8);
    }

    /// `last_hidden_size` documents a panic for an empty `hidden_sizes`; pin it.
    #[test]
    #[should_panic(expected = "hidden_sizes must not be empty")]
    fn last_hidden_size_panics_when_empty() {
        let config = NnueConfig {
            feature_size: 768,
            accumulator_size: 256,
            hidden_sizes: &[],
            activation: Activation::CReLU,
        };
        let _ = config.last_hidden_size();
    }
}