use std::borrow::Cow;
/// Nonlinearity selected for the network.
///
/// This type only names the activation; clipping bounds and any quantisation scaling are not
/// defined here (NOTE(review): presumably applied by the evaluator/inference code — confirm).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Activation {
    /// Clipped ReLU, going by the conventional NNUE naming.
    CReLU,
    /// Squared clipped ReLU, going by the conventional NNUE naming.
    SCReLU,
}

/// Topology description of an NNUE network.
///
/// `hidden_sizes` is a [`Cow`] so that compile-time descriptions can borrow a `'static` slice
/// (see [`NnueConfig::new_static`]) while runtime-built ones own a `Vec`
/// (see [`NnueConfig::new_owned`]).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NnueConfig {
    /// Feature dimension. Stored as given; no method of this type reads it.
    pub feature_size: usize,
    /// Width of a single accumulator; [`NnueConfig::concat_size`] is twice this value.
    pub accumulator_size: usize,
    /// Output width of each hidden layer, in layer order. May be empty.
    pub hidden_sizes: Cow<'static, [usize]>,
    /// Activation selected for the network.
    pub activation: Activation,
}

impl NnueConfig {
    /// Builds a configuration whose hidden-layer sizes borrow a `'static` slice.
    ///
    /// Being a `const fn`, this can initialise `const`/`static` network descriptions without
    /// allocating.
    pub const fn new_static(
        feature_size: usize,
        accumulator_size: usize,
        hidden_sizes: &'static [usize],
        activation: Activation,
    ) -> Self {
        Self {
            feature_size,
            accumulator_size,
            hidden_sizes: Cow::Borrowed(hidden_sizes),
            activation,
        }
    }

    /// Builds a configuration that takes ownership of its hidden-layer sizes.
    pub fn new_owned(
        feature_size: usize,
        accumulator_size: usize,
        hidden_sizes: Vec<usize>,
        activation: Activation,
    ) -> Self {
        Self {
            feature_size,
            accumulator_size,
            hidden_sizes: Cow::Owned(hidden_sizes),
            activation,
        }
    }

    /// Width of the vector fed to the first hidden layer: `2 * accumulator_size`.
    ///
    /// NOTE(review): the factor of two presumably reflects one accumulator per perspective
    /// concatenated together — confirm against the inference code.
    #[inline]
    pub fn concat_size(&self) -> usize {
        self.accumulator_size * 2
    }

    /// Number of hidden layers, i.e. `hidden_sizes.len()`.
    #[inline]
    pub fn num_hidden_layers(&self) -> usize {
        self.hidden_sizes.len()
    }

    /// Input width of the layer at `layer_idx`.
    ///
    /// Layer 0 consumes the concatenated accumulators ([`NnueConfig::concat_size`]). Layer
    /// `i > 0` consumes the output of hidden layer `i - 1`. `layer_idx == num_hidden_layers()`
    /// is accepted and yields the input width of the layer that follows the last hidden one.
    ///
    /// # Panics
    ///
    /// Panics if `layer_idx > self.num_hidden_layers()`.
    #[inline]
    pub fn layer_input_size(&self, layer_idx: usize) -> usize {
        if layer_idx == 0 {
            return self.concat_size();
        }
        let num_hidden = self.hidden_sizes.len();
        // Explicit check so callers get an actionable message instead of a bare
        // slice-index panic; it also lets the optimiser drop the indexing bounds check.
        assert!(
            layer_idx <= num_hidden,
            "layer_idx {} out of range: config has {} hidden layer(s), valid indices are 0..={}",
            layer_idx,
            num_hidden,
            num_hidden
        );
        self.hidden_sizes[layer_idx - 1]
    }

    /// Output width of the last hidden layer.
    ///
    /// # Panics
    ///
    /// Panics if `hidden_sizes` is empty.
    #[inline]
    pub fn last_hidden_size(&self) -> usize {
        *self
            .hidden_sizes
            .last()
            .expect("hidden_sizes must not be empty")
    }
}
/// Fully owned counterpart of [`NnueConfig`], convenient for topologies assembled at runtime.
///
/// Convert with [`OwnedNnueConfig::into_config`] or `NnueConfig::from`. Both move the
/// `hidden_sizes` vector into the result via `NnueConfig::new_owned` rather than copying it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OwnedNnueConfig {
    /// Feature dimension, carried over unchanged on conversion.
    pub feature_size: usize,
    /// Accumulator width, carried over unchanged on conversion.
    pub accumulator_size: usize,
    /// Hidden-layer widths in layer order, moved into the converted config.
    pub hidden_sizes: Vec<usize>,
    /// Activation, carried over unchanged on conversion.
    pub activation: Activation,
}

impl OwnedNnueConfig {
    /// Bundles the given fields as-is; no validation is performed.
    pub fn new(
        feature_size: usize,
        accumulator_size: usize,
        hidden_sizes: Vec<usize>,
        activation: Activation,
    ) -> Self {
        Self {
            feature_size,
            accumulator_size,
            hidden_sizes,
            activation,
        }
    }

    /// Consumes `self` and yields the equivalent [`NnueConfig`].
    pub fn into_config(self) -> NnueConfig {
        NnueConfig::from(self)
    }
}

impl From<OwnedNnueConfig> for NnueConfig {
    /// Moves every field of `value` into a new owned-topology [`NnueConfig`].
    fn from(value: OwnedNnueConfig) -> Self {
        let OwnedNnueConfig {
            feature_size,
            accumulator_size,
            hidden_sizes,
            activation,
        } = value;
        Self::new_owned(feature_size, accumulator_size, hidden_sizes, activation)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Converting an owned config must carry every field across and keep derived sizes coherent.
    #[test]
    fn owned_config_into_config_preserves_fields() {
        let config =
            OwnedNnueConfig::new(768, 256, vec![256, 32, 32], Activation::SCReLU).into_config();

        assert_eq!(
            (config.feature_size, config.accumulator_size, config.activation),
            (768, 256, Activation::SCReLU)
        );
        assert_eq!(&*config.hidden_sizes, [256, 32, 32].as_slice());
        assert_eq!(
            (
                config.concat_size(),
                config.num_hidden_layers(),
                config.last_hidden_size()
            ),
            (512, 3, 32)
        );
    }

    /// A `'static` topology passed to `new_static` must stay borrowed, not be copied.
    #[test]
    fn static_config_keeps_borrowed_topology() {
        let NnueConfig { hidden_sizes, .. } =
            NnueConfig::new_static(530, 256, &[64], Activation::CReLU);

        assert!(matches!(hidden_sizes, Cow::Borrowed(_)));
        assert_eq!(&*hidden_sizes, &[64]);
    }
}