use std::cell::RefCell;
use std::ffi::CString;
use std::os::raw::c_char;
use std::panic::{catch_unwind, AssertUnwindSafe};
use std::ptr;
use std::slice;
use crate::config::{reclaim_leaked_hidden_sizes, Activation, OwnedNnueConfig};
use crate::network::{forward as nnue_forward, Accumulator, FeatureDelta, NnueWeights};
use crate::trainer::{AdamState, ForwardResult, Gradients, SimpleRng, TrainableWeights, TrainingSample};
// Status codes returned by every fallible `noru_*` entry point. On any non-zero code a
// description of the failure is available from `noru_last_error` on the same thread.
/// The call succeeded.
pub const NORU_OK: i32 = 0;
/// A required handle or out-parameter pointer was null.
pub const NORU_ERR_NULL_PTR: i32 = -1;
/// An argument value was rejected (e.g. empty/null hidden sizes, zero sizes, unknown
/// activation, non-positive batch size, oversized feature delta).
pub const NORU_ERR_INVALID_ARG: i32 = -2;
/// A Rust panic was caught at the FFI boundary instead of unwinding into C.
pub const NORU_ERR_PANIC: i32 = -3;
/// Deserializing weights from a byte buffer failed.
pub const NORU_ERR_IO: i32 = -4;
/// The call is not valid in the handle's current state (e.g. backward before forward).
pub const NORU_ERR_STATE: i32 = -5;
// Message for the most recent error recorded on the current thread, or `None`.
// Thread-local so callers on different threads never see each other's errors; stored as
// a `CString` so `noru_last_error` can hand out a NUL-terminated pointer.
thread_local! {
static LAST_ERROR: RefCell<Option<CString>> = RefCell::new(None);
}
/// Records `msg` as the current thread's last error, retrievable via `noru_last_error`.
///
/// C strings cannot contain interior NUL bytes. Any such bytes are stripped instead of
/// throwing the whole message away, so the diagnostic text survives.
fn set_last_error(msg: impl Into<String>) {
    let c = CString::new(msg.into()).unwrap_or_else(|err| {
        // `into_vec` hands back the original bytes; once every NUL is removed the
        // conversion cannot fail.
        let mut bytes = err.into_vec();
        bytes.retain(|&b| b != 0);
        CString::new(bytes).expect("all NUL bytes were removed")
    });
    LAST_ERROR.with(|cell| *cell.borrow_mut() = Some(c));
}
/// Forgets any error message previously recorded for the current thread.
fn clear_last_error() {
    LAST_ERROR.with(|slot| {
        // Taking the old value out and dropping it leaves the slot as `None`.
        drop(slot.borrow_mut().take());
    });
}
/// Runs `f` and returns its status code. A panic becomes `NORU_ERR_PANIC` instead of
/// unwinding across the `extern "C"` boundary.
///
/// When the panic payload is a string (as produced by `panic!`, `assert!`, `expect`, ...),
/// it is appended to the recorded error message so C callers can see why the call failed.
fn guard<F: FnOnce() -> i32>(f: F) -> i32 {
    match catch_unwind(AssertUnwindSafe(f)) {
        Ok(code) => code,
        Err(payload) => {
            // Literal-only panics carry `&'static str`; formatted panics carry `String`.
            let message = if let Some(detail) = payload.downcast_ref::<&str>() {
                format!("panic in noru FFI call: {}", detail)
            } else if let Some(detail) = payload.downcast_ref::<String>() {
                format!("panic in noru FFI call: {}", detail)
            } else {
                String::from("panic in noru FFI call")
            };
            set_last_error(message);
            NORU_ERR_PANIC
        }
    }
}
/// Returns the last error message recorded on the calling thread, or null if there is none.
///
/// The string is owned by the library. Do not free it. It stays valid only until the next
/// call on this thread that records or clears an error, so copy it if it must outlive that.
#[no_mangle]
pub extern "C" fn noru_last_error() -> *const c_char {
    LAST_ERROR.with(|slot| {
        slot.borrow()
            .as_ref()
            .map_or(ptr::null(), |message| message.as_ptr())
    })
}
/// Opaque trainer handle given to C callers. It is created by `noru_trainer_new` or
/// `noru_trainer_load_fp32` and released with `noru_trainer_free`.
pub struct NoruTrainer {
// fp32 parameters being trained.
weights: TrainableWeights,
// Adam optimizer state, built from the same config as `weights`.
adam: AdamState,
// Gradient buffer written by `noru_trainer_backward_mse`, read by
// `noru_trainer_adam_step`, and reset by `noru_trainer_zero_grad`.
grad: Gradients,
// Inputs of the most recent `noru_trainer_forward`; `target` is filled in later by
// `noru_trainer_backward_mse`.
last_sample: Option<TrainingSample>,
// Result of the most recent `noru_trainer_forward`, reused by the backward pass.
last_fwd: Option<ForwardResult>,
}
impl Drop for NoruTrainer {
fn drop(&mut self) {
// The config's `hidden_sizes` slice was leaked to build the config
// (`OwnedNnueConfig::leak` in `noru_trainer_new`); hand it back here.
// NOTE(review): this is sound only if every constructor leaks a fresh, unshared slice.
// `noru_trainer_load_fp32` relies on `TrainableWeights::load_from_bytes` doing so;
// confirm in crate::trainer.
unsafe { reclaim_leaked_hidden_sizes(self.weights.config.hidden_sizes) };
}
}
/// Opaque handle to inference weights. It is created by `noru_trainer_quantize` or
/// `noru_weights_load` and released with `noru_weights_free`.
pub struct NoruWeights {
weights: NnueWeights,
}
impl Drop for NoruWeights {
fn drop(&mut self) {
// Reclaim the leaked `hidden_sizes` slice owned by this handle's config.
// `noru_trainer_quantize` gives each handle a freshly leaked config for this reason.
// NOTE(review): assumes `NnueWeights::load_from_bytes(.., None)` also leaks a fresh
// slice for handles made by `noru_weights_load`; confirm in crate::network.
unsafe { reclaim_leaked_hidden_sizes(self.weights.config.hidden_sizes) };
}
}
/// Opaque handle to an accumulator for incremental evaluation. It is created by
/// `noru_accumulator_new` and released with `noru_accumulator_free`.
///
/// The weights are passed separately to every refresh/update/forward call. Presumably they
/// must be the same weights the accumulator was created from; confirm against crate::network.
pub struct NoruAccumulator {
acc: Accumulator,
}
/// Copies `len` `u32` feature indices from a C buffer into an owned `Vec<usize>`.
///
/// A zero `len` yields an empty vector regardless of `ptr`, so callers may pass
/// `NULL, 0` for an empty feature list.
///
/// # Safety
/// When `len > 0`, `ptr` must point to `len` initialized, properly aligned `u32`s that stay
/// valid for the duration of the call.
///
/// # Panics
/// Panics if `ptr` is null while `len > 0`. Building a slice from a null pointer is undefined
/// behavior, so the case becomes a panic instead. Inside the FFI entry points that panic is
/// caught and reported as `NORU_ERR_PANIC`.
unsafe fn slice_from_raw_usize(ptr: *const u32, len: usize) -> Vec<usize> {
    if len == 0 {
        return Vec::new();
    }
    assert!(
        !ptr.is_null(),
        "feature pointer is null but length is {}",
        len
    );
    slice::from_raw_parts(ptr, len)
        .iter()
        .map(|&v| v as usize)
        .collect()
}
/// Decodes the C-side activation selector: `0` is CReLU and `1` is SCReLU.
/// Any other value is rejected.
fn activation_from_u8(v: u8) -> Result<Activation, &'static str> {
    let decoded = match v {
        0 => Activation::CReLU,
        1 => Activation::SCReLU,
        _ => return Err("unknown activation type"),
    };
    Ok(decoded)
}
/// Builds a `FeatureDelta` from the added and removed feature lists of one side.
///
/// Each list may hold at most 32 entries; longer lists are rejected before anything
/// is inserted. Added features are recorded before removed ones.
fn build_feature_delta(added: &[usize], removed: &[usize]) -> Result<FeatureDelta, &'static str> {
    const MAX_PER_SIDE: usize = 32;
    if added.len().max(removed.len()) > MAX_PER_SIDE {
        return Err("feature delta exceeds 32 entries per side");
    }
    let mut delta = FeatureDelta::new();
    added.iter().for_each(|&feature| {
        delta.add(feature);
    });
    removed.iter().for_each(|&feature| {
        delta.remove(feature);
    });
    Ok(delta)
}
/// Creates a trainer with randomly initialized fp32 weights.
///
/// * `feature_size`, `accumulator_size`: input and accumulator widths; both must be non-zero.
/// * `hidden_sizes_ptr` / `hidden_sizes_len`: hidden-layer widths; the array must be
///   non-empty and every width non-zero.
/// * `activation`: `0` = CReLU, `1` = SCReLU.
/// * `seed`: seed for the weight-initialization RNG.
/// * `out_handle`: receives the new handle on success and is left untouched on failure.
///   Release the handle with `noru_trainer_free`.
///
/// Returns `NORU_OK`, `NORU_ERR_NULL_PTR`, `NORU_ERR_INVALID_ARG` or `NORU_ERR_PANIC`.
///
/// # Safety
/// `hidden_sizes_ptr` must point to `hidden_sizes_len` readable `usize` values, and
/// `out_handle` must be valid for a pointer-sized write.
#[no_mangle]
pub unsafe extern "C" fn noru_trainer_new(
    feature_size: usize,
    accumulator_size: usize,
    hidden_sizes_ptr: *const usize,
    hidden_sizes_len: usize,
    activation: u8,
    seed: u64,
    out_handle: *mut *mut NoruTrainer,
) -> i32 {
    guard(|| {
        clear_last_error();
        if out_handle.is_null() {
            set_last_error("out_handle is null");
            return NORU_ERR_NULL_PTR;
        }
        if hidden_sizes_ptr.is_null() || hidden_sizes_len == 0 {
            set_last_error("hidden_sizes is empty or null");
            return NORU_ERR_INVALID_ARG;
        }
        if feature_size == 0 || accumulator_size == 0 {
            set_last_error("feature_size and accumulator_size must be non-zero");
            return NORU_ERR_INVALID_ARG;
        }
        let act = match activation_from_u8(activation) {
            Ok(a) => a,
            Err(e) => {
                set_last_error(e);
                return NORU_ERR_INVALID_ARG;
            }
        };
        let hidden = slice::from_raw_parts(hidden_sizes_ptr, hidden_sizes_len).to_vec();
        // Zero-width layers are rejected like zero input/accumulator widths. The check runs
        // before any config memory is leaked.
        if hidden.contains(&0) {
            set_last_error("hidden layer sizes must be non-zero");
            return NORU_ERR_INVALID_ARG;
        }
        // The leaked `hidden_sizes` slice is reclaimed by `NoruTrainer`'s `Drop` impl.
        let config = OwnedNnueConfig::new(feature_size, accumulator_size, hidden, act).leak();
        let mut rng = SimpleRng::new(seed);
        let weights = TrainableWeights::init_random(config, &mut rng);
        let adam = AdamState::new(config);
        let grad = Gradients::new(config);
        let handle = Box::new(NoruTrainer {
            weights,
            adam,
            grad,
            last_sample: None,
            last_fwd: None,
        });
        *out_handle = Box::into_raw(handle);
        NORU_OK
    })
}
/// Releases a trainer created by `noru_trainer_new` or `noru_trainer_load_fp32`.
/// Passing null is a no-op.
///
/// # Safety
/// `handle` must be null or a live handle from this library that has not been freed yet.
#[no_mangle]
pub unsafe extern "C" fn noru_trainer_free(handle: *mut NoruTrainer) {
    if !handle.is_null() {
        // Re-owning the box runs `NoruTrainer`'s destructor and frees the allocation.
        drop(Box::from_raw(handle));
    }
}
#[no_mangle]
pub unsafe extern "C" fn noru_trainer_forward(
handle: *mut NoruTrainer,
stm_ptr: *const u32,
stm_len: usize,
nstm_ptr: *const u32,
nstm_len: usize,
out_eval: *mut f32,
) -> i32 {
guard(|| {
clear_last_error();
if handle.is_null() {
set_last_error("handle is null");
return NORU_ERR_NULL_PTR;
}
let trainer = &mut *handle;
let stm = slice_from_raw_usize(stm_ptr, stm_len);
let nstm = slice_from_raw_usize(nstm_ptr, nstm_len);
let fwd = trainer.weights.forward(&stm, &nstm);
if !out_eval.is_null() {
*out_eval = fwd.output;
}
trainer.last_sample = Some(TrainingSample {
stm_features: stm,
nstm_features: nstm,
target: 0.0,
});
trainer.last_fwd = Some(fwd);
NORU_OK
})
}
/// Runs the MSE backward pass for the most recent `noru_trainer_forward`, using `target` as
/// the label and writing into the trainer's gradient buffer.
///
/// Returns `NORU_ERR_STATE` if no forward pass has been recorded yet.
///
/// # Safety
/// `handle` must be null or a live trainer handle.
#[no_mangle]
pub unsafe extern "C" fn noru_trainer_backward_mse(
    handle: *mut NoruTrainer,
    target: f32,
) -> i32 {
    guard(|| {
        clear_last_error();
        let trainer = match handle.as_mut() {
            Some(t) => t,
            None => {
                set_last_error("handle is null");
                return NORU_ERR_NULL_PTR;
            }
        };
        if let Some(sample) = trainer.last_sample.as_mut() {
            sample.target = target;
            if let Some(fwd) = trainer.last_fwd.as_ref() {
                trainer.weights.backward_mse(sample, fwd, &mut trainer.grad);
                NORU_OK
            } else {
                set_last_error("backward_mse called without forward result");
                NORU_ERR_STATE
            }
        } else {
            set_last_error("backward_mse called before forward");
            NORU_ERR_STATE
        }
    })
}
/// Resets the trainer's gradient buffer via `Gradients::zero`.
///
/// # Safety
/// `handle` must be null or a live trainer handle.
#[no_mangle]
pub unsafe extern "C" fn noru_trainer_zero_grad(handle: *mut NoruTrainer) -> i32 {
    guard(|| {
        clear_last_error();
        match handle.as_mut() {
            Some(trainer) => {
                trainer.grad.zero();
                NORU_OK
            }
            None => {
                set_last_error("handle is null");
                NORU_ERR_NULL_PTR
            }
        }
    })
}
/// Applies one Adam update to the trainer's weights from its gradient buffer.
///
/// `lr` must be finite. `batch_size` must be positive and finite; it is passed to
/// `TrainableWeights::adam_update`, presumably to average the accumulated gradients.
/// A NaN/infinite `lr` would silently poison every weight. An infinite `batch_size` would
/// zero out the gradients. Both are rejected with `NORU_ERR_INVALID_ARG`.
///
/// # Safety
/// `handle` must be null or a live trainer handle.
#[no_mangle]
pub unsafe extern "C" fn noru_trainer_adam_step(
    handle: *mut NoruTrainer,
    lr: f32,
    batch_size: f32,
) -> i32 {
    guard(|| {
        clear_last_error();
        if handle.is_null() {
            set_last_error("handle is null");
            return NORU_ERR_NULL_PTR;
        }
        // Written as `!(x > 0.0)` so NaN is rejected too.
        if !(batch_size > 0.0) {
            set_last_error("batch_size must be positive");
            return NORU_ERR_INVALID_ARG;
        }
        if !batch_size.is_finite() {
            set_last_error("batch_size must be finite");
            return NORU_ERR_INVALID_ARG;
        }
        if !lr.is_finite() {
            set_last_error("lr must be finite");
            return NORU_ERR_INVALID_ARG;
        }
        let trainer = &mut *handle;
        trainer
            .weights
            .adam_update(&trainer.grad, &mut trainer.adam, lr, batch_size);
        NORU_OK
    })
}
/// Hands `buf` to C by writing its pointer and length to `out_ptr` / `out_len`.
///
/// Ownership passes to the caller, who must release the buffer with `noru_free_bytes` using
/// the reported length. If either out-pointer is null, the buffer is dropped here and
/// `NORU_ERR_NULL_PTR` is returned.
///
/// # Safety
/// Non-null `out_ptr` and `out_len` must be valid for writes.
unsafe fn vec_into_out_buf(
    buf: Vec<u8>,
    out_ptr: *mut *mut u8,
    out_len: *mut usize,
) -> i32 {
    if !out_ptr.is_null() && !out_len.is_null() {
        // Shrink to an exact-size allocation so `noru_free_bytes` can rebuild a `Box<[u8]>`
        // from only the pointer and length.
        let exact: Box<[u8]> = buf.into_boxed_slice();
        let byte_count = exact.len();
        *out_ptr = Box::into_raw(exact) as *mut u8;
        *out_len = byte_count;
        NORU_OK
    } else {
        set_last_error("output pointer(s) are null");
        NORU_ERR_NULL_PTR
    }
}
/// Serializes the trainer's fp32 weights into a new byte buffer.
///
/// On success `*out_ptr` / `*out_len` describe a buffer to release with `noru_free_bytes`.
///
/// # Safety
/// `handle` must be null or a live trainer. `out_ptr` and `out_len` must be valid for
/// writes when non-null.
#[no_mangle]
pub unsafe extern "C" fn noru_trainer_save_fp32(
    handle: *mut NoruTrainer,
    out_ptr: *mut *mut u8,
    out_len: *mut usize,
) -> i32 {
    guard(|| {
        clear_last_error();
        match handle.as_ref() {
            None => {
                set_last_error("handle is null");
                NORU_ERR_NULL_PTR
            }
            Some(trainer) => {
                let serialized = trainer.weights.save_to_bytes();
                vec_into_out_buf(serialized, out_ptr, out_len)
            }
        }
    })
}
/// Creates a trainer from fp32 weights previously produced by `noru_trainer_save_fp32`.
///
/// Optimizer state and gradients start fresh; no forward pass is recorded. On parse failure
/// the parser's message is stored for `noru_last_error` and `NORU_ERR_IO` is returned.
///
/// # Safety
/// `data` must point to `len` readable bytes; `out_handle` must be valid for writes.
#[no_mangle]
pub unsafe extern "C" fn noru_trainer_load_fp32(
    data: *const u8,
    len: usize,
    out_handle: *mut *mut NoruTrainer,
) -> i32 {
    guard(|| {
        clear_last_error();
        if data.is_null() || out_handle.is_null() {
            set_last_error("null pointer argument");
            return NORU_ERR_NULL_PTR;
        }
        let bytes = slice::from_raw_parts(data, len);
        match TrainableWeights::load_from_bytes(bytes) {
            Err(e) => {
                set_last_error(e);
                NORU_ERR_IO
            }
            Ok(weights) => {
                // Field initializers run in written order, so the config is read before
                // `weights` is moved into the struct.
                let trainer = NoruTrainer {
                    adam: AdamState::new(weights.config),
                    grad: Gradients::new(weights.config),
                    weights,
                    last_sample: None,
                    last_fwd: None,
                };
                *out_handle = Box::into_raw(Box::new(trainer));
                NORU_OK
            }
        }
    })
}
/// Quantizes the trainer's current fp32 weights into a new inference-weights handle.
///
/// The trainer is only read. On success `*out_weights` receives a handle to release with
/// `noru_weights_free`.
///
/// # Safety
/// `handle` must be null or a live trainer; a non-null `out_weights` must be valid for writes.
#[no_mangle]
pub unsafe extern "C" fn noru_trainer_quantize(
handle: *mut NoruTrainer,
out_weights: *mut *mut NoruWeights,
) -> i32 {
guard(|| {
clear_last_error();
if handle.is_null() || out_weights.is_null() {
set_last_error("null pointer argument");
return NORU_ERR_NULL_PTR;
}
let trainer = &*handle;
let quantized = trainer.weights.quantize();
// Give the quantized weights their own leaked copy of the config. `NoruTrainer` and
// `NoruWeights` both reclaim `config.hidden_sizes` in their `Drop` impls. If the new
// weights kept the trainer's slice (presumably `quantize` copies the config; confirm
// in crate::trainer), the same allocation would be reclaimed twice.
let owned = OwnedNnueConfig::new(
quantized.config.feature_size,
quantized.config.accumulator_size,
quantized.config.hidden_sizes.to_vec(),
quantized.config.activation,
);
let fresh_config = owned.leak();
let mut weights = quantized;
weights.config = fresh_config;
let boxed = Box::new(NoruWeights { weights });
*out_weights = Box::into_raw(boxed);
NORU_OK
})
}
/// Loads inference weights from an in-memory byte buffer.
///
/// On success `*out_weights` receives a handle to release with `noru_weights_free`. On parse
/// failure the parser's message is stored for `noru_last_error` and `NORU_ERR_IO` is returned.
///
/// # Safety
/// `data` must point to `len` readable bytes; `out_weights` must be valid for writes.
#[no_mangle]
pub unsafe extern "C" fn noru_weights_load(
    data: *const u8,
    len: usize,
    out_weights: *mut *mut NoruWeights,
) -> i32 {
    guard(|| {
        clear_last_error();
        if data.is_null() || out_weights.is_null() {
            set_last_error("null pointer argument");
            return NORU_ERR_NULL_PTR;
        }
        let bytes = slice::from_raw_parts(data, len);
        match NnueWeights::load_from_bytes(bytes, None) {
            Ok(weights) => {
                *out_weights = Box::into_raw(Box::new(NoruWeights { weights }));
                NORU_OK
            }
            Err(e) => {
                set_last_error(e);
                NORU_ERR_IO
            }
        }
    })
}
/// Serializes inference weights into a new byte buffer.
///
/// On success `*out_ptr` / `*out_len` describe a buffer to release with `noru_free_bytes`.
///
/// # Safety
/// `handle` must be null or a live weights handle. `out_ptr` and `out_len` must be valid for
/// writes when non-null.
#[no_mangle]
pub unsafe extern "C" fn noru_weights_save(
    handle: *mut NoruWeights,
    out_ptr: *mut *mut u8,
    out_len: *mut usize,
) -> i32 {
    guard(|| {
        clear_last_error();
        if let Some(w) = handle.as_ref() {
            let serialized = w.weights.save_to_bytes();
            return vec_into_out_buf(serialized, out_ptr, out_len);
        }
        set_last_error("handle is null");
        NORU_ERR_NULL_PTR
    })
}
/// Releases a weights handle created by `noru_trainer_quantize` or `noru_weights_load`.
/// Passing null is a no-op.
///
/// # Safety
/// `handle` must be null or a live handle from this library that has not been freed yet.
#[no_mangle]
pub unsafe extern "C" fn noru_weights_free(handle: *mut NoruWeights) {
    if !handle.is_null() {
        // Re-owning the box runs `NoruWeights`' destructor and frees the allocation.
        drop(Box::from_raw(handle));
    }
}
/// Frees a byte buffer returned by `noru_trainer_save_fp32` or `noru_weights_save`.
///
/// `len` must be the length reported with `ptr`. A null pointer or zero length is a no-op,
/// since a zero-length buffer owns no allocation.
///
/// # Safety
/// `ptr`/`len` must come from one of this library's save calls and must not be freed twice.
#[no_mangle]
pub unsafe extern "C" fn noru_free_bytes(ptr: *mut u8, len: usize) {
    if !ptr.is_null() && len != 0 {
        // Rebuild the fat pointer directly, without materializing an intermediate `&mut [u8]`.
        let fat: *mut [u8] = std::ptr::slice_from_raw_parts_mut(ptr, len);
        drop(Box::from_raw(fat));
    }
}
/// Creates an accumulator initialized from `weights`' feature biases.
///
/// On success `*out_acc` receives a handle to release with `noru_accumulator_free`.
///
/// # Safety
/// `weights` must be null or a live weights handle; a non-null `out_acc` must be valid for
/// writes.
#[no_mangle]
pub unsafe extern "C" fn noru_accumulator_new(
    weights: *mut NoruWeights,
    out_acc: *mut *mut NoruAccumulator,
) -> i32 {
    guard(|| {
        clear_last_error();
        let w = match (weights.as_ref(), out_acc.is_null()) {
            (Some(w), false) => w,
            _ => {
                set_last_error("null pointer argument");
                return NORU_ERR_NULL_PTR;
            }
        };
        let fresh = NoruAccumulator {
            acc: Accumulator::new(&w.weights.feature_bias),
        };
        *out_acc = Box::into_raw(Box::new(fresh));
        NORU_OK
    })
}
/// Releases an accumulator created by `noru_accumulator_new`. Passing null is a no-op.
///
/// # Safety
/// `handle` must be null or a live handle from this library that has not been freed yet.
#[no_mangle]
pub unsafe extern "C" fn noru_accumulator_free(handle: *mut NoruAccumulator) {
    if !handle.is_null() {
        drop(Box::from_raw(handle));
    }
}
#[no_mangle]
pub unsafe extern "C" fn noru_accumulator_refresh(
acc: *mut NoruAccumulator,
weights: *mut NoruWeights,
stm_ptr: *const u32,
stm_len: usize,
nstm_ptr: *const u32,
nstm_len: usize,
) -> i32 {
guard(|| {
clear_last_error();
if acc.is_null() || weights.is_null() {
set_last_error("null pointer argument");
return NORU_ERR_NULL_PTR;
}
let acc_ref = &mut *acc;
let w = &*weights;
let stm = slice_from_raw_usize(stm_ptr, stm_len);
let nstm = slice_from_raw_usize(nstm_ptr, nstm_len);
acc_ref.acc.refresh(&w.weights, &stm, &nstm);
NORU_OK
})
}
/// Applies an incremental feature change to the accumulator.
///
/// Each side takes an added and a removed list of at most 32 entries. A list pointer may be
/// null only when its length is 0, and every index must be below the weights' `feature_size`.
///
/// Return codes:
/// * `NORU_OK` on success.
/// * `NORU_ERR_NULL_PTR` for a null handle, or a null list with non-zero length.
/// * `NORU_ERR_INVALID_ARG` for an out-of-range index or an oversized list.
/// * `NORU_ERR_PANIC` if a panic is caught.
///
/// # Safety
/// Handles must be null or live. Non-null feature pointers must reference their length in
/// readable `u32`s.
#[no_mangle]
pub unsafe extern "C" fn noru_accumulator_update(
    acc: *mut NoruAccumulator,
    weights: *mut NoruWeights,
    stm_added_ptr: *const u32,
    stm_added_len: usize,
    stm_removed_ptr: *const u32,
    stm_removed_len: usize,
    nstm_added_ptr: *const u32,
    nstm_added_len: usize,
    nstm_removed_ptr: *const u32,
    nstm_removed_len: usize,
) -> i32 {
    guard(|| {
        clear_last_error();
        if acc.is_null() || weights.is_null() {
            set_last_error("null pointer argument");
            return NORU_ERR_NULL_PTR;
        }
        // A null list is only acceptable when it is empty.
        let lists = [
            (stm_added_ptr, stm_added_len),
            (stm_removed_ptr, stm_removed_len),
            (nstm_added_ptr, nstm_added_len),
            (nstm_removed_ptr, nstm_removed_len),
        ];
        if lists.iter().any(|&(p, n)| p.is_null() && n != 0) {
            set_last_error("feature pointer is null but length is non-zero");
            return NORU_ERR_NULL_PTR;
        }
        let acc_ref = &mut *acc;
        let w = &*weights;
        let stm_added = slice_from_raw_usize(stm_added_ptr, stm_added_len);
        let stm_removed = slice_from_raw_usize(stm_removed_ptr, stm_removed_len);
        let nstm_added = slice_from_raw_usize(nstm_added_ptr, nstm_added_len);
        let nstm_removed = slice_from_raw_usize(nstm_removed_ptr, nstm_removed_len);
        // Indices come straight from C; validate them before they index weight rows.
        let feature_size = w.weights.config.feature_size;
        let out_of_range = stm_added
            .iter()
            .chain(&stm_removed)
            .chain(&nstm_added)
            .chain(&nstm_removed)
            .any(|&f| f >= feature_size);
        if out_of_range {
            set_last_error("feature index out of range");
            return NORU_ERR_INVALID_ARG;
        }
        let stm_delta = match build_feature_delta(&stm_added, &stm_removed) {
            Ok(d) => d,
            Err(e) => {
                set_last_error(e);
                return NORU_ERR_INVALID_ARG;
            }
        };
        let nstm_delta = match build_feature_delta(&nstm_added, &nstm_removed) {
            Ok(d) => d,
            Err(e) => {
                set_last_error(e);
                return NORU_ERR_INVALID_ARG;
            }
        };
        acc_ref
            .acc
            .update_incremental(&w.weights, &stm_delta, &nstm_delta);
        NORU_OK
    })
}
/// Calls `Accumulator::swap` on the handle. Presumably this exchanges the two perspectives;
/// confirm in crate::network.
///
/// # Safety
/// `acc` must be null or a live accumulator handle.
#[no_mangle]
pub unsafe extern "C" fn noru_accumulator_swap(acc: *mut NoruAccumulator) -> i32 {
    guard(|| {
        clear_last_error();
        if let Some(acc_ref) = acc.as_mut() {
            acc_ref.acc.swap();
            NORU_OK
        } else {
            set_last_error("null pointer argument");
            NORU_ERR_NULL_PTR
        }
    })
}
/// Evaluates the quantized network from the accumulator's current state and writes the
/// integer result to `*out_eval`.
///
/// # Safety
/// `acc` and `weights` must be null or live handles; a non-null `out_eval` must be valid for
/// writes.
#[no_mangle]
pub unsafe extern "C" fn noru_accumulator_forward(
    acc: *mut NoruAccumulator,
    weights: *mut NoruWeights,
    out_eval: *mut i32,
) -> i32 {
    guard(|| {
        clear_last_error();
        let any_null = [acc.is_null(), weights.is_null(), out_eval.is_null()].contains(&true);
        if any_null {
            set_last_error("null pointer argument");
            return NORU_ERR_NULL_PTR;
        }
        let eval = nnue_forward(&(*acc).acc, &(*weights).weights);
        *out_eval = eval;
        NORU_OK
    })
}
// Exercises the exported C ABI directly, calling each function the way a C caller would
// (raw pointers, out-parameters, integer status codes).
#[cfg(test)]
mod tests {
use super::*;
// Widths for two tiny hidden layers so building the test network stays cheap.
fn small_hidden() -> Vec<usize> {
vec![16, 8]
}
// End-to-end happy path:
// 1. Build a trainer.
// 2. Run one forward / zero_grad / backward / Adam step.
// 3. Round-trip the fp32 weights through save and load.
// 4. Quantize, then evaluate the quantized net through an accumulator.
// Every handle and byte buffer is freed at the end.
#[test]
fn trainer_roundtrip_via_ffi() {
unsafe {
let hidden = small_hidden();
let mut trainer: *mut NoruTrainer = ptr::null_mut();
// feature_size = 64, accumulator_size = 32, activation 0 (CReLU), seed 1234.
let rc = noru_trainer_new(
64,
32,
hidden.as_ptr(),
hidden.len(),
0,
1234,
&mut trainer as *mut _,
);
assert_eq!(rc, NORU_OK);
assert!(!trainer.is_null());
// Feature indices for both lists; all are below feature_size (64).
let stm: [u32; 3] = [1, 5, 10];
let nstm: [u32; 3] = [2, 7, 15];
let mut eval: f32 = 0.0;
let rc = noru_trainer_forward(
trainer,
stm.as_ptr(),
stm.len(),
nstm.as_ptr(),
nstm.len(),
&mut eval,
);
assert_eq!(rc, NORU_OK);
// One training step on the recorded forward pass.
assert_eq!(noru_trainer_zero_grad(trainer), NORU_OK);
assert_eq!(noru_trainer_backward_mse(trainer, 0.5), NORU_OK);
assert_eq!(noru_trainer_adam_step(trainer, 1e-3, 1.0), NORU_OK);
// fp32 save -> load round trip into a second trainer.
let mut save_ptr: *mut u8 = ptr::null_mut();
let mut save_len: usize = 0;
assert_eq!(
noru_trainer_save_fp32(trainer, &mut save_ptr, &mut save_len),
NORU_OK
);
assert!(!save_ptr.is_null() && save_len > 0);
let mut trainer2: *mut NoruTrainer = ptr::null_mut();
assert_eq!(
noru_trainer_load_fp32(save_ptr, save_len, &mut trainer2),
NORU_OK
);
assert!(!trainer2.is_null());
// The saved buffer is released only after trainer2 has been built from it.
noru_free_bytes(save_ptr, save_len);
// Quantize the original trainer and evaluate through an accumulator.
let mut weights_handle: *mut NoruWeights = ptr::null_mut();
assert_eq!(
noru_trainer_quantize(trainer, &mut weights_handle),
NORU_OK
);
let mut acc_handle: *mut NoruAccumulator = ptr::null_mut();
assert_eq!(
noru_accumulator_new(weights_handle, &mut acc_handle),
NORU_OK
);
assert_eq!(
noru_accumulator_refresh(
acc_handle,
weights_handle,
stm.as_ptr(),
stm.len(),
nstm.as_ptr(),
nstm.len()
),
NORU_OK
);
let mut int_eval: i32 = 0;
assert_eq!(
noru_accumulator_forward(acc_handle, weights_handle, &mut int_eval),
NORU_OK
);
// Release everything this test allocated.
noru_accumulator_free(acc_handle);
noru_weights_free(weights_handle);
noru_trainer_free(trainer2);
noru_trainer_free(trainer);
}
}
// A null handle must produce NORU_ERR_NULL_PTR plus a retrievable message, not a crash.
#[test]
fn null_handle_is_an_error_not_a_crash() {
unsafe {
let rc = noru_trainer_zero_grad(ptr::null_mut());
assert_eq!(rc, NORU_ERR_NULL_PTR);
assert!(!noru_last_error().is_null());
}
}
}