use std::alloc::{GlobalAlloc, Layout, System};
use std::cell::Cell;
use pounce_nl::nl_tape::{Tape, TapeOp};
/// Zero-sized allocator type; its `GlobalAlloc` impl in this file forwards to
/// `System` and tallies allocation calls while per-thread counting is enabled.
struct CountingAlloc;
thread_local! {
    /// Per-thread switch: allocator calls are tallied only while this is `true`.
    /// Starts out `false`.
    static COUNTING: Cell<bool> = const { Cell::new(false) };

    /// Per-thread number of allocator calls tallied so far. Starts at zero.
    static ALLOCS: Cell<usize> = const { Cell::new(0) };
}
/// Records one allocator call on the current thread, but only while counting
/// is switched on.
///
/// Uses `try_with` instead of `with` so that a failed thread-local access is
/// silently ignored rather than panicking inside the allocator.
fn tally() {
    let enabled = matches!(COUNTING.try_with(|flag| flag.get()), Ok(true));
    if !enabled {
        return;
    }
    let _ = ALLOCS.try_with(|count| count.set(count.get() + 1));
}
/// Switches allocation counting on (`true`) or off (`false`) for the current
/// thread.
fn set_counting(on: bool) {
    COUNTING.with(|flag| {
        flag.set(on);
    });
}
/// Resets the current thread's allocation tally to zero.
fn reset_allocs() {
    ALLOCS.with(|counter| {
        counter.set(0);
    });
}
fn allocs() -> usize {
ALLOCS.with(|n| n.get())
}
// SAFETY: every method hands its arguments unchanged to `System` and returns
// `System`'s result, so the allocator contract is exactly the one `System`
// already provides. The only extra work is `tally`, which just updates
// thread-local `Cell`s.
unsafe impl GlobalAlloc for CountingAlloc {
    /// Tallies one call, then allocates through `System`.
    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
        tally();
        // SAFETY: the caller upholds `GlobalAlloc::alloc`'s requirements for
        // `layout`, which is passed through untouched.
        unsafe { System.alloc(layout) }
    }

    /// Frees through `System`. Deallocations are deliberately not tallied.
    unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
        // SAFETY: `ptr` and `layout` come from the caller, who guarantees they
        // describe a live block from this allocator, i.e. one from `System`.
        unsafe { System.dealloc(ptr, layout) }
    }

    /// Tallies one call, then resizes through `System`.
    unsafe fn realloc(&self, ptr: *mut u8, layout: Layout, new_size: usize) -> *mut u8 {
        tally();
        // SAFETY: same pass-through reasoning as `dealloc`; `new_size` is the
        // caller's responsibility under `GlobalAlloc::realloc`'s contract.
        unsafe { System.realloc(ptr, layout, new_size) }
    }
}
// Route every heap allocation in this binary through `CountingAlloc`. Calls
// are only tallied on threads that have switched counting on.
#[global_allocator]
static GLOBAL: CountingAlloc = CountingAlloc;
/// Builds a fixed eight-op tape mixing `Var`, `Mul`, `Exp` and `Add` ops.
///
/// NOTE(review): the operands presumably index earlier tape slots, which would
/// make the final slot `x0*x1 + exp(x0)*x1 + x0*x0` — confirm against
/// `TapeOp`'s definition.
fn sample_tape() -> Tape {
    let ops = vec![
        TapeOp::Var(0),
        TapeOp::Var(1),
        TapeOp::Mul(0, 1),
        TapeOp::Exp(0),
        TapeOp::Mul(3, 1),
        TapeOp::Add(2, 4),
        TapeOp::Mul(0, 0),
        TapeOp::Add(5, 6),
    ];
    Tape { ops }
}
/// Checks that `gradient_seed_into`, given caller-owned scratch buffers, makes
/// no tallied allocator calls across repeated invocations, while
/// `gradient_seed` is expected to make at least one per call. Also checks that
/// both produce identical gradients.
#[test]
fn gradient_seed_into_does_not_allocate_per_call() {
    /// Runs `work` with allocation counting enabled on this thread and returns
    /// how many allocator calls were tallied while it ran.
    fn counted(work: impl FnOnce()) -> usize {
        reset_allocs();
        set_counting(true);
        work();
        set_counting(false);
        allocs()
    }

    const ITERATIONS: usize = 1000;

    let tape = sample_tape();
    let slots = tape.ops.len();
    let x = [1.3_f64, -0.7_f64];

    // Output gradients, one per entry point, plus the two scratch buffers
    // sized to the tape length that `gradient_seed_into` takes.
    let mut grad_into = vec![0.0_f64; 2];
    let mut grad_seed = vec![0.0_f64; 2];
    let mut vals = vec![0.0_f64; slots];
    let mut adj = vec![0.0_f64; slots];

    // One call outside the counted region gives the reference result.
    grad_into.fill(0.0);
    tape.gradient_seed_into(&x, 1.0, &mut grad_into, &mut vals, &mut adj);
    let reference = grad_into.clone();

    let into_allocs = counted(|| {
        for _ in 0..ITERATIONS {
            grad_into.fill(0.0);
            tape.gradient_seed_into(&x, 1.0, &mut grad_into, &mut vals, &mut adj);
        }
    });

    let seed_allocs = counted(|| {
        for _ in 0..ITERATIONS {
            grad_seed.fill(0.0);
            tape.gradient_seed(&x, 1.0, &mut grad_seed);
        }
    });

    // Numerical agreement is checked before the allocation counts.
    assert_eq!(
        grad_into, grad_seed,
        "gradient_seed_into must match gradient_seed numerically"
    );
    assert_eq!(grad_into, reference, "result must be stable across calls");

    // The baseline must actually allocate, otherwise a zero count for the
    // `_into` variant would prove nothing about the counter working.
    assert!(
        seed_allocs >= ITERATIONS,
        "baseline gradient_seed should allocate per call; saw {seed_allocs}"
    );
    assert_eq!(
        into_allocs, 0,
        "gradient_seed_into must not allocate per call; saw {into_allocs}"
    );
}